# Assignment 3: Image Classification

### Import Packages

In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from tqdm import tqdm
from torchvision.datasets import DatasetFolder, VisionDataset
from torch.utils.data import ConcatDataset, DataLoader, Subset, Dataset

In [None]:
# Pin every RNG the notebook relies on so repeated runs give the same results.
myseed = 6666  # set a random seed for reproducibility

# NumPy drives the file shuffle in FoodDataset; torch drives model init and
# (presumably) torchvision's random augmentations.
np.random.seed(myseed)
torch.manual_seed(myseed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(myseed)

# Prefer deterministic cuDNN kernels over autotuned (possibly non-deterministic) ones.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

### Transforms

In [None]:
# Inference-time preprocessing: a deterministic resize to 224x224 followed by
# conversion to a tensor (no random steps).
test_tfm = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)

# Training-time preprocessing: the same resize, then random augmentation
# (horizontal/vertical flips at p=0.5, a rotation of up to +/-90 degrees with
# bicubic resampling, grayscale at p=0.2) before tensor conversion.
train_tfm = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomRotation(
            degrees=90,
            interpolation=transforms.InterpolationMode.BICUBIC,
        ),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
    ]
)


### Datasets

In [None]:
class FoodDataset(Dataset):
    """Image dataset backed by ``.jpg`` files on disk.

    In training mode the ``.jpg`` files of ``<path>/train`` and
    ``<path>/valid`` are merged into one list and shuffled with NumPy's global
    RNG. In test mode the ``.jpg`` files of ``<path>/test`` are sorted by path.

    Args:
        tfm: transform applied to every loaded image (an RGB ``PIL.Image``).
        path: dataset root directory that holds ``train``/``valid``/``test``.
        isTrain: ``True`` for the merged train+valid pool, ``False`` for test.

    Items are ``(transformed_image, label)``; the label is the integer prefix
    of a ``<label>_<n>.jpg`` file name, or ``-1`` when the name has none.
    """

    def __init__(self, tfm, path="/kaggle/input/ml2023spring-hw3", isTrain=True):
        # Zero-argument super() so the Dataset base is actually initialised
        # (``super(FoodDataset)`` alone is an unbound super object).
        super().__init__()
        if isTrain:
            self.train_path = path + "/train"
            self.valid_path = path + "/valid"
            self.files = self._list_jpgs(self.train_path)
            self.files += self._list_jpgs(self.valid_path)
            np.random.shuffle(self.files)
        else:
            self.path = path + "/test"
            # Sorted so the prediction order is stable across runs.
            self.files = sorted(self._list_jpgs(self.path))

        self.transform = tfm

    @staticmethod
    def _list_jpgs(folder):
        """Return ``folder + "/" + name`` for every ``.jpg`` entry in ``folder``."""
        return [folder + "/" + name for name in os.listdir(folder) if name.endswith(".jpg")]

    @staticmethod
    def _label_from_path(fname):
        """Parse the class id from ``.../<label>_<n>.jpg``; return -1 if absent."""
        try:
            return int(fname.split("/")[-1].split("_")[0])
        except ValueError:
            return -1  # test has no label

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        fname = self.files[idx]
        # Close the file handle once the image is transformed, and force three
        # channels: downstream code views images as (3, 224, 224), which a
        # grayscale or CMYK JPEG would otherwise break.
        with Image.open(fname) as raw:
            im = self.transform(raw.convert("RGB"))
        return im, self._label_from_path(fname)


### Models

In [None]:
from torchvision import models


class Resnet(nn.Module):
    """Randomly initialised torchvision ResNet-18 with an ``n_class``-way head.

    The backbone is stored as ``self.cnn`` (checkpoint keys are ``cnn.*``);
    only its final fully connected layer ``fc`` is replaced.
    """

    def __init__(self, n_class):
        super().__init__()
        backbone = models.resnet18(weights=None)  # no pretrained weights
        # Swap the stock classifier for one sized to our label set
        # (ResNet-18's fc takes 512 input features).
        backbone.fc = nn.Linear(backbone.fc.in_features, n_class)
        self.cnn = backbone

    def forward(self, x):
        logits = self.cnn(x)
        return logits


### Configurations

In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# hyperparameters
batch_size = 64  # images per mini-batch for every DataLoader
n_epochs = 100   # upper bound on training epochs

# Early stopping: training ends once validation accuracy has failed to improve
# for more than `patience` consecutive epochs.
patience = 8


### Construct Dataset

In [None]:
# Merged, shuffled train+valid pool with train-time augmentation; the KFold
# loop carves its training and validation folds out of this one dataset.
dataset = FoodDataset(tfm=train_tfm)

### Start Training

In [None]:
from sklearn.model_selection import KFold
from torch.utils.data import SubsetRandomSampler

In [None]:
fold_idx = 0  # which KFold split (0-based) this run trains
_exp_name = "resnet18"  # prefix for the log, checkpoint and submission files

In [None]:
def adjust_learning_rate(optimizer, factor=0.8):
    """Multiply the learning rate of every parameter group by ``factor``.

    The old and new rates of the first parameter group are printed once, so
    each decay is visible in the training output.

    Args:
        optimizer: a ``torch.optim`` optimizer; its ``param_groups`` are
            updated in place.
        factor: multiplicative decay for each group's ``"lr"`` (default 0.8).
    """
    groups = optimizer.param_groups
    if groups:
        lr = groups[0]["lr"]
        print(f"--- Learning rate decreases from {lr:.6f} to {lr * factor:.6f}. ---")
    for param_group in groups:
        param_group["lr"] = param_group["lr"] * factor


In [None]:
# K-fold training. The train+valid pool in `dataset` is split into 4 folds and
# only fold `fold_idx` is trained in this run. Every epoch trains on the other
# folds, evaluates on the held-out fold, appends one line to
# `./<_exp_name>_log.txt` and keeps the best-accuracy weights in
# `<_exp_name>_best.ckpt`.
import copy

kf = KFold(n_splits=4)

# Validation images must not be randomly flipped/rotated/grayscaled; otherwise
# checkpoint selection and early stopping follow a noisy, augmented accuracy.
# A shallow copy shares `dataset.files` (same order, so the KFold indices still
# refer to the same images) but uses the deterministic `test_tfm`.
valid_dataset = copy.copy(dataset)
valid_dataset.transform = test_tfm

for fold, (train_idx, valid_idx) in enumerate(kf.split(dataset)):
    if fold != fold_idx:
        continue

    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler, num_workers=0, pin_memory=True)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, sampler=valid_sampler, num_workers=0, pin_memory=True)

    model = Resnet(11).to(device)
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)

    # Initialize trackers, these are not parameters and should not be changed
    stale = 0  # consecutive epochs without a new best validation accuracy
    best_acc = 0

    for epoch in range(n_epochs):
        # After more than 5 non-improving epochs in a row, decay the learning
        # rate at the start of each epoch until accuracy improves again.
        if stale > 5:
            adjust_learning_rate(optimizer)

        # ---------- Training ----------
        model.train()
        train_loss = []
        train_accs = []

        with tqdm(total=len(train_loader), unit="batch") as tqdm_bar:
            tqdm_bar.set_description(f"Epoch {epoch + 1:03d}/{n_epochs:03d}")
            for imgs, labels in train_loader:
                imgs, labels = imgs.to(device), labels.to(device)

                logits = model(imgs)
                loss = criterion(logits, labels)

                optimizer.zero_grad()
                loss.backward()
                # Clip the gradient norms for stable training.
                nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
                optimizer.step()

                # Record plain Python floats rather than device tensors.
                acc = (logits.argmax(dim=-1) == labels).float().mean().item()
                train_loss.append(loss.item())
                train_accs.append(acc)

                tqdm_bar.update(1)
                tqdm_bar.set_postfix(loss=f"{sum(train_loss) / len(train_loss):.5f}", acc=f"{sum(train_accs) / len(train_accs):.5f}", val_loss=f"{0:.5f}", val_acc=f"{0:.5f}")

            train_loss = sum(train_loss) / len(train_loss)
            train_acc = sum(train_accs) / len(train_accs)
            tqdm_bar.set_postfix(loss=f"{train_loss:.5f}", acc=f"{train_acc:.5f}", val_loss=f"{0:.5f}", val_acc=f"{0:.5f}")

            # ---------- Validation ----------
            model.eval()
            valid_loss = []
            valid_accs = []

            # No gradients are needed anywhere during validation.
            with torch.no_grad():
                for imgs, labels in valid_loader:
                    imgs, labels = imgs.to(device), labels.to(device)

                    logits = model(imgs)
                    loss = criterion(logits, labels)
                    acc = (logits.argmax(dim=-1) == labels).float().mean().item()

                    valid_loss.append(loss.item())
                    valid_accs.append(acc)

                    tqdm_bar.set_postfix(
                        loss=f"{train_loss:.5f}", acc=f"{train_acc:.5f}", val_loss=f"{sum(valid_loss) / len(valid_loss):.5f}", val_acc=f"{sum(valid_accs) / len(valid_accs):.5f}"
                    )

            # Epoch metrics are the mean of the per-batch values.
            valid_loss = sum(valid_loss) / len(valid_loss)
            valid_acc = sum(valid_accs) / len(valid_accs)

            tqdm_bar.set_postfix(loss=f"{train_loss:.5f}", acc=f"{train_acc:.5f}", val_loss=f"{valid_loss:.5f}", val_acc=f"{valid_acc:.5f}")

        is_best = valid_acc > best_acc

        # update logs
        best_tag = " -> best" if is_best else ""
        with open(f"./{_exp_name}_log.txt", "a", encoding="utf-8") as f:
            f.write(f"[ Valid | {epoch + 1:03d}/{n_epochs:03d} ] loss = {valid_loss:.5f}, acc = {valid_acc:.5f}{best_tag}\n")

        # save models
        if is_best:
            print(f"Best model found at epoch {epoch+1}, saving model")
            torch.save(model.state_dict(), f"{_exp_name}_best.ckpt")  # only save best to prevent output memory exceed error
            best_acc = valid_acc
            stale = 0
        else:
            stale += 1
            if stale > patience:
                print(f"No improvment {patience} consecutive epochs, early stopping")
                break


### DataLoader for the test set

In [None]:
# Construct test datasets.
# Test split: deterministic transform, and files in sorted order, so row i of
# the predictions lines up with Id i in the submission.
test_set = FoodDataset(test_tfm, isTrain=False)
test_loader = DataLoader(
    test_set,
    batch_size=batch_size,
    shuffle=False,  # preserve the sorted file order
    num_workers=0,
    pin_memory=True,
)

### Test Time Augmentation

In [None]:
# Test-time augmentation: each already-tensorised test image is turned back
# into a PIL image, then resized and randomly flipped/rotated/grayscaled with
# the same settings as train_tfm, and finally converted to a tensor again.
tta_transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize(size=(224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomRotation(
            degrees=90,
            interpolation=transforms.InterpolationMode.BICUBIC,
        ),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
    ]
)

tta_num = 5  # augmented copies averaged per test image

### Testing and generating the prediction CSV

In [None]:
# Load the best checkpoint and predict the test set. Each image's score is
# 0.7 * logits(clean image) + 0.3 * mean logits over `tta_num` copies
# augmented with `tta_transform`; the predicted class is the argmax.
model_best = Resnet(11).to(device)
# map_location lets a checkpoint saved from the GPU be restored on a CPU-only host.
model_best.load_state_dict(torch.load(f"{_exp_name}_best.ckpt", map_location=device))
model_best.eval()

prediction = []
with torch.no_grad():
    for data, _ in tqdm(test_loader):
        # One forward pass per loader batch instead of one per image. In eval
        # mode BatchNorm uses running statistics, so batching leaves each
        # image's logits unchanged (up to floating-point rounding).
        clean_logits = model_best(data.to(device)).cpu().numpy()

        # test time augmentation: fresh random augmentation per image per round
        tta_logits = np.zeros(clean_logits.shape)
        for _ in range(tta_num):
            augmented = torch.stack([tta_transform(img) for img in data])
            tta_logits += model_best(augmented.to(device)).cpu().numpy()
        tta_logits /= tta_num

        # final prediction
        combined = clean_logits * 0.7 + tta_logits * 0.3
        prediction.extend(combined.argmax(axis=1).tolist())


In [None]:
# create test csv
def pad4(i):
    """Return ``str(i)`` left-padded with ``"0"`` to at least 4 characters."""
    return str(i).rjust(4, "0")


# Submission table: zero-padded Id per test image, predicted class per row.
df = pd.DataFrame(
    {
        "Id": [pad4(i) for i in range(len(test_set))],
        "Category": prediction,
    }
)
df.to_csv(f"{_exp_name}_submission.csv", index=False)
