In [None]:
import io
import os
import random

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from PIL import Image
from sklearn.metrics import accuracy_score
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader, Dataset, Subset, random_split
from torchvision import datasets, models, transforms
from tqdm import tqdm

In [None]:
def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's legacy global RNG and PyTorch.
    When CUDA is available it also seeds CUDA and switches cuDNN to deterministic
    kernels with autotuning (``benchmark``) disabled, trading speed for repeatability.

    Note: ``PYTHONHASHSEED`` is read only when a Python interpreter starts, so
    setting it here does not change hashing in the already-running kernel.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


seed_everything(7)

In [None]:
# Use CUDA when PyTorch can see a GPU, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
class JpegCompression(transforms.Lambda):
    """Randomly re-encode a PIL image as JPEG to simulate compression artifacts.

    With probability ``p`` the image is written to an in-memory JPEG at a quality
    drawn uniformly from ``[quality_lower, quality_upper]`` (both inclusive) and
    decoded again; otherwise it is returned unchanged.

    Args:
        quality_lower: Lowest JPEG quality that may be sampled (inclusive).
        quality_upper: Highest JPEG quality that may be sampled (inclusive).
        p: Probability of applying the re-encoding to a given image.

    Raises:
        ValueError: if ``quality_lower > quality_upper`` (would otherwise only
            surface mid-training, from ``random.randint``).
    """

    def __init__(self, quality_lower=60, quality_upper=100, p=0.5):
        if quality_lower > quality_upper:
            raise ValueError(
                f"quality_lower ({quality_lower}) must be <= quality_upper ({quality_upper})"
            )
        super().__init__(self.apply_jpeg_compression)
        self.quality_lower = quality_lower
        self.quality_upper = quality_upper
        self.probability = p

    def apply_jpeg_compression(self, img):
        """Return ``img`` re-encoded as JPEG with probability ``self.probability``, else unchanged."""
        if random.random() >= self.probability:
            return img

        quality = random.randint(self.quality_lower, self.quality_upper)
        # JPEG cannot store alpha or palette data; encode anything other than RGB/L as RGB
        # instead of letting PIL raise "cannot write mode ... as JPEG".
        source = img if img.mode in ("RGB", "L") else img.convert("RGB")

        buffer = io.BytesIO()
        source.save(buffer, format="JPEG", quality=quality)
        buffer.seek(0)

        # Image.open is lazy; decode now so the returned image is fully materialised.
        decoded = Image.open(buffer)
        decoded.load()
        return decoded

In [None]:
def random_color_jitter():
    """Build a ``ColorJitter`` whose strengths are themselves drawn at random.

    On every call, the brightness, contrast and saturation strengths are each
    sampled uniformly from [0.1, 0.3], in that order. ColorJitter then samples
    the actual per-image adjustment factors from ranges derived from these strengths.
    """
    brightness, contrast, saturation = (random.uniform(0.1, 0.3) for _ in range(3))
    return transforms.ColorJitter(
        brightness=brightness,
        contrast=contrast,
        saturation=saturation,
    )

In [None]:
class Crop20(object):
    """Trim the bottom 20% off 1080x1920 portrait frames; pass every other image through.

    PIL reports ``size`` as (width, height), so this matches images 1080 px wide and
    1920 px tall. It keeps the box (0, 0, 1080, 1536), i.e. the top 1536 rows (0.8 * 1920).
    """

    _PORTRAIT_SIZE = (1080, 1920)
    _KEEP_BOX = (0, 0, 1080, 1536)

    def __call__(self, image):
        if image.size != self._PORTRAIT_SIZE:
            return image
        return image.crop(self._KEEP_BOX)
        
class ConditionalResize:
    """Upscale an image only when it is too small for a ``target_size`` center crop.

    If the *shorter* side is below ``target_size``, the image is resized
    (bilinear, aspect ratio kept) so that its shorter side equals ``target_size``.
    Otherwise it is returned untouched, so large images keep their native resolution.

    The check uses the shorter side rather than the longer one. ``CenterCrop``
    zero-pads any edge smaller than the crop, so an image such as 300x500 would
    previously skip the resize and end up with black borders.
    """

    def __init__(self, target_size):
        self.target_size = target_size

    def __call__(self, img):
        if min(img.size) < self.target_size:
            resize = transforms.Resize(
                self.target_size, interpolation=transforms.InterpolationMode.BILINEAR
            )
            return resize(img)
        return img

In [None]:
# Training-time preprocessing + augmentation pipeline producing 380x380 normalised tensors.
target_size = 380
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose([
    # Deterministic geometry: trim 1080x1920 portraits, upscale small images, centre-crop.
    Crop20(),
    ConditionalResize(target_size),
    transforms.CenterCrop(target_size),
    # Random augmentation.
    JpegCompression(quality_lower=60, quality_upper=90, p=0.2),
    transforms.RandomApply([
        # A fresh ColorJitter (with freshly drawn strengths) is built for each image.
        transforms.Lambda(lambda img: random_color_jitter()(img))
    ], p=0.2),
    transforms.RandomPerspective(distortion_scale=0.1, p=0.1),
    transforms.RandomHorizontalFlip(p=0.2),
    # Tensor conversion + the standard ImageNet channel statistics.
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

In [None]:
# ImageFolder assigns one integer label per class sub-directory (sorted by name).
dataset = datasets.ImageFolder(
    root='/kaggle/input/image-classification-2024-spring/dataset/train',
    transform=transform,
)

In [None]:
# Hold out 5% of the labelled images for validation (random_split draws from the torch RNG seeded earlier).
dataset_size = len(dataset)
train_size = int(dataset_size * 0.95)
val_size = dataset_size - train_size
assert val_size > 0, "validation split is empty; dataset too small for a 5% hold-out"

trainset, valset = random_split(dataset, [train_size, val_size])

# `dataset` applies the *training* augmentations (random JPEG re-encoding, colour jitter,
# perspective, flips). A plain Subset of it would therefore evaluate on randomly distorted
# images and make val accuracy noisy. Instead, re-open the same folder with a
# deterministic pipeline and keep exactly the indices held out by random_split.
eval_transform = transforms.Compose([
    Crop20(),
    ConditionalResize(target_size),
    transforms.CenterCrop(target_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
eval_dataset = datasets.ImageFolder(root=dataset.root, transform=eval_transform)
# Same files in the same order, so the held-out indices refer to the same images.
assert eval_dataset.samples == dataset.samples
valset = Subset(eval_dataset, valset.indices)

In [None]:
# Only the training split is shuffled; evaluation order does not affect accuracy.
BATCH_SIZE = 32
train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(valset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# EfficientNetV2-S with ImageNet-1k weights. `weights=` replaces the deprecated
# `pretrained=True`, which selected IMAGENET1K_V1 for this architecture.
model = models.efficientnet_v2_s(weights=models.EfficientNet_V2_S_Weights.IMAGENET1K_V1)

# Swap the final Linear of the classifier head for one emitting a logit per ImageFolder class.
num_classes = len(dataset.classes)
num_ftrs = model.classifier[1].in_features
model.classifier[1] = nn.Linear(num_ftrs, num_classes)
model = model.to(device)

In [None]:
criterion = nn.CrossEntropyLoss()

# The lr given to AdamW is effectively a placeholder. OneCycleLR sets each param
# group's lr to its own schedule, which starts at max_lr / div_factor.
optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

# OneCycleLR is sized as epochs * steps_per_epoch and expects one step() per batch.
steps_per_epoch = len(train_loader)
scheduler = OneCycleLR(optimizer, max_lr=1e-3, epochs=7, steps_per_epoch=steps_per_epoch)

In [None]:
# Directory that receives the model state_dict checkpoints. exist_ok makes the cell
# idempotent and fails loudly if a non-directory named 'checkpoint' is in the way.
os.makedirs('checkpoint', exist_ok=True)

best_acc = 0.0  # best-validation-accuracy tracker; starts at 0.0

In [None]:
def train_one_epoch(model, loader, criterion, optimizer, scheduler, device):
    """Run one optimisation pass over ``loader`` and return ``(mean_loss, accuracy)``.

    ``scheduler`` is stepped after every optimiser update. OneCycleLR is sized as
    ``epochs * steps_per_epoch`` total steps, so stepping it once per epoch would
    leave the learning rate stuck near its initial warm-up value.
    """
    model.train()
    running_loss = 0.0
    preds, labels = [], []

    for inputs, label in tqdm(loader):
        inputs = inputs.to(device)
        label = label.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, label.long())
        loss.backward()
        optimizer.step()
        scheduler.step()  # per batch, as OneCycleLR requires

        running_loss += loss.item() * inputs.size(0)
        preds += outputs.argmax(dim=1).cpu().tolist()
        labels += label.cpu().tolist()

    mean_loss = running_loss / max(len(labels), 1)  # guard against an empty loader
    return mean_loss, accuracy_score(labels, preds)


def evaluate(model, loader, device):
    """Return the classification accuracy of ``model`` on ``loader`` (eval mode, no gradients)."""
    model.eval()
    preds, labels = [], []
    with torch.no_grad():
        for inputs, label in tqdm(loader):
            outputs = model(inputs.to(device))
            preds += outputs.argmax(dim=1).cpu().tolist()
            labels += label.tolist()
    return accuracy_score(labels, preds)


# Derive the epoch count from the OneCycleLR schedule (total_steps = epochs * steps_per_epoch)
# so the loop cannot drift out of sync with the scheduler's configured epochs.
num_epochs = scheduler.total_steps // len(train_loader)
final_ckpt = 'checkpoint/model3.pth'
best_ckpt = 'checkpoint/model3_best.pth'

for epoch in range(num_epochs):
    print(f'epoch {epoch + 1}/{num_epochs}')
    train_loss, train_accuracy = train_one_epoch(
        model, train_loader, criterion, optimizer, scheduler, device
    )
    print(f'train_loss: {train_loss:.4f}')
    print(f'train_accuracy: {train_accuracy}')

    val_accuracy = evaluate(model, val_loader, device)
    print(f'val_accuracy: {val_accuracy}')

    # Keep the best-on-validation weights in addition to the final ones.
    if val_accuracy > best_acc:
        best_acc = val_accuracy
        torch.save(model.state_dict(), best_ckpt)

    if epoch == num_epochs - 1:
        torch.save(model.state_dict(), final_ckpt)