In [None]:
import torch
import torch.nn as nn

from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim import AdamW

from torchvision import datasets, transforms

from sklearn.metrics import accuracy_score

import numpy as np
import pandas as pd

import os
import random
from tqdm import tqdm
from PIL import Image

In [None]:
def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy and PyTorch, and exports
    ``PYTHONHASHSEED``. When CUDA is available it additionally seeds the
    CUDA generator and puts cuDNN into deterministic, non-benchmarking mode.

    Args:
        seed: Integer seed shared by all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # CPU-side generators.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if not torch.cuda.is_available():
        return

    # GPU-side generator plus cuDNN settings that trade speed for repeatability.
    torch.cuda.manual_seed(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


seed_everything(7)

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

In [None]:
# Directory holding the labelled training images, loaded without any transform
# so that each split can attach its own pipeline later.
TRAIN_ROOT = '/kaggle/input/image-classification-2024-spring/dataset/train'
dataset = datasets.ImageFolder(root=TRAIN_ROOT)

In [None]:
# Split the labelled data 70% / 30% into training and validation subsets.
dataset_size = len(dataset)
train_size = int(dataset_size * 0.7)
val_size = dataset_size - train_size

# A dedicated fixed-seed generator makes the partition independent of how much
# of the global RNG stream has been consumed. Re-running this cell (or running
# cells out of order) therefore always yields the same split, and validation
# images can never drift into the training subset between runs.
split_generator = torch.Generator().manual_seed(7)
trainset, valset = random_split(
    dataset, [train_size, val_size], generator=split_generator
)

In [None]:
# Per-channel statistics handed to Normalize in both pipelines.
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]


def _resize_to_normalized_tensor():
    """Return fresh Resize -> ToTensor -> Normalize steps shared by both pipelines."""
    return [
        transforms.Resize((380, 380)),
        transforms.ToTensor(),
        transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
    ]


# Training pipeline: with probability 0.2 run RandAugment on the incoming
# image, then apply the shared resize / tensor / normalize steps.
maybe_rand_augment = transforms.RandomApply(
    [transforms.RandAugment(num_ops=14, magnitude=15)],
    p=0.2,
)
train_transform = transforms.Compose(
    [maybe_rand_augment] + _resize_to_normalized_tensor()
)

# Evaluation pipeline: the shared steps only, no augmentation.
test_transform = transforms.Compose(_resize_to_normalized_tensor())

class CustomImageDataset(Dataset):
    """Wrap an indexable collection of ``(image, label)`` pairs with a transform.

    Lets several views of the same underlying data (e.g. a training and a
    validation subset) each apply their own preprocessing pipeline.
    """

    def __init__(self, data, transform):
        """Store the wrapped data and the callable applied to every image.

        Args:
            data: Indexable, sized collection whose items are
                ``(image, label)`` pairs.
            transform: Callable applied to the image of each item.
        """
        self.data = data
        self.transform = transform

    def __getitem__(self, idx):
        """Return ``(transform(image), label)`` for the item at ``idx``."""
        # Fetch the item exactly once: indexing the wrapped dataset can be
        # expensive (e.g. reading and decoding an image file), and looking it
        # up separately for the image and the label would do that work twice.
        image, label = self.data[idx][:2]
        return self.transform(image), label

    def __len__(self):
        """Return the number of items in the wrapped data."""
        return len(self.data)

In [None]:
def _unwrap(split):
    """Return the data underneath ``split`` if it is already a CustomImageDataset."""
    return split.data if isinstance(split, CustomImageDataset) else split


# This cell rebinds `trainset` / `valset` to wrapped versions of themselves.
# Unwrapping first keeps it idempotent: re-running it rebuilds the wrapper
# around the original split (picking up any edited transform) instead of
# nesting wrappers and applying the transform to an already-transformed image.
trainset = CustomImageDataset(_unwrap(trainset), train_transform)
valset = CustomImageDataset(_unwrap(valset), test_transform)

In [None]:
# Settings common to both loaders; only the training loader shuffles.
loader_kwargs = dict(batch_size=64, num_workers=2, pin_memory=True)

train_loader = DataLoader(trainset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(valset, shuffle=False, **loader_kwargs)

In [None]:
import torchvision.models as models

# Size of the replacement classification head.
NUM_CLASSES = 2

# The boolean `pretrained=` flag is deprecated in torchvision in favour of the
# `weights=` enum. IMAGENET1K_V1 is the checkpoint `pretrained=True` resolved
# to, so the loaded parameters are unchanged.
model = models.efficientnet_v2_s(
    weights=models.EfficientNet_V2_S_Weights.IMAGENET1K_V1
)

# Swap the final linear layer for one with NUM_CLASSES outputs, keeping the
# input width of the layer it replaces.
num_ftrs = model.classifier[1].in_features
model.classifier[1] = nn.Linear(num_ftrs, NUM_CLASSES)
model = model.to(device)

In [None]:
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import OneCycleLR

# Number of batches in one pass over the training loader.
steps_per_epoch = len(train_loader)

# Classification loss and the optimizer over every model parameter.
criterion = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

# One-cycle learning-rate schedule with cosine annealing, sized for
# 15 epochs of `steps_per_epoch` steps each.
one_cycle_settings = dict(
    max_lr=1e-3,
    epochs=15,
    steps_per_epoch=steps_per_epoch,
    pct_start=0.3,
    anneal_strategy='cos',
    div_factor=25.0,
    final_div_factor=1e4,
)
scheduler = OneCycleLR(optimizer, **one_cycle_settings)

# Gradient scaler used together with autocast for mixed-precision training.
scaler = GradScaler()

In [None]:
# Create the checkpoint directory if needed. `exist_ok=True` replaces the
# separate exists-check (no check-then-create race) and still raises if the
# path is occupied by something that is not a directory, instead of silently
# continuing and failing later when a checkpoint is written.
os.makedirs('checkpoint', exist_ok=True)

# Best validation accuracy seen so far.
best_acc = 0.

In [None]:
# Mixed-precision training with a validation pass after every epoch.
# The checkpoint is written once, after the final epoch.
num_epochs = 3

for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    running_loss = 0.0
    preds = []
    labels = []

    for inputs, label in tqdm(train_loader):
        inputs = inputs.to(device)
        label = label.to(device)

        optimizer.zero_grad()
        with autocast():
            outputs = model(inputs)
            loss = criterion(outputs, label.long())

        scaler.scale(loss).backward()
        scale_before = scaler.get_scale()
        scaler.step(optimizer)
        scaler.update()

        # OneCycleLR counts its schedule in optimizer updates
        # (epochs * steps_per_epoch), so it must advance once per batch;
        # stepping once per epoch would leave the learning rate stuck at the
        # very start of warm-up. GradScaler lowers its scale when it drops an
        # update because of inf/NaN gradients, so a smaller scale means no
        # optimizer step happened and the schedule should not advance either.
        # NOTE: re-running this cell continues the same schedule; OneCycleLR
        # raises once its configured total number of steps is exceeded.
        if scaler.get_scale() >= scale_before:
            scheduler.step()

        running_loss += loss.item() * inputs.size(0)
        predicted = outputs.argmax(dim=1)
        preds += predicted.cpu().tolist()
        labels += label.cpu().tolist()

    # Sample-weighted mean loss over the epoch (guarded against an empty loader).
    train_loss = running_loss / max(len(labels), 1)
    train_accuracy = accuracy_score(labels, preds)
    print(f'epoch {epoch} - train_loss: {train_loss}')
    print(f'epoch {epoch} - train_accuracy: {train_accuracy}')

    # ---- validation pass ----
    model.eval()
    val_preds = []
    val_labels = []
    with torch.no_grad():
        for inputs, label in tqdm(val_loader):
            inputs = inputs.to(device)
            label = label.to(device)

            with autocast():
                outputs = model(inputs)

            val_preds += outputs.argmax(dim=1).cpu().tolist()
            val_labels += label.cpu().tolist()

    val_accuracy = accuracy_score(val_labels, val_preds)
    print(f'epoch {epoch} - val_accuracy: {val_accuracy}')

    # Best-accuracy checkpointing is disabled: the weights are saved once,
    # after the last epoch, regardless of val_accuracy.
    if epoch == num_epochs - 1:
        torch.save(model.state_dict(), 'checkpoint/model1.pth')