In [None]:
## Removing mac specific dotfiles
import os
import shutil

def rem_macfiles(directory):
    """Delete every macOS ``.DS_Store`` file found anywhere under ``directory``.

    Args:
        directory: Root of the tree to clean; it is traversed with ``os.walk``.
    """
    ds_store = ".DS_Store"
    for folder, _subdirs, filenames in os.walk(directory):
        # File names are unique within a folder, so a membership test is
        # enough to know whether this folder holds a file to delete.
        if ds_store in filenames:
            os.remove(os.path.join(folder, ds_store))

# Clean the dataset tree before any later cell enumerates its contents.
rem_macfiles("data")


In [None]:
## import libraries
import os
import numpy as np
import pandas as pd
import splitfolders
from glob import glob
# from tqdm import tqdm, trange
from tqdm.autonotebook import tqdm, trange
from numba import njit, jit
from natsort import natsorted

import matplotlib.pyplot as plt

#PIL
from PIL import Image, ImageOps

#random
from random import sample

#open cv
import cv2

#sklearn
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, confusion_matrix
from sklearn.model_selection import StratifiedKFold, train_test_split, KFold
from sklearn.preprocessing import LabelBinarizer

#scipy
from scipy import stats 

#pytorch
import torch
from torch.utils.data import Dataset, Subset
import torch.nn as nn
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

#timm
import timm
import timm.optim
import timm.scheduler
from timm.data import ImageDataset, create_dataset, create_loader
from timm.data.transforms_factory import create_transform

import evaluate

import time

import warnings
warnings.filterwarnings('ignore')

In [None]:
## Pick the compute backend: Apple Silicon (MPS) first, then CUDA, otherwise CPU
device = (
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

device

In [None]:
## Visualize one image

paths = glob('./data/train/**/**')

# Open the file inside a context manager so its handle is released right away;
# convert() hands back a separate, fully loaded image that stays usable afterwards.
with Image.open(paths[0]) as first_image:
    sample_image = first_image.convert("RGB")

sample_image

In [None]:
## defining image transformations

# Input resolutions; choose the values listed on the Hugging Face page of the selected model.
img_size_train = 256
img_size_test = 320


def _resize_to_tensor(side):
    """Return the pipeline that bicubic-resizes an image to ``side`` x ``side`` and converts it to a tensor."""
    return transforms.Compose([
        transforms.Resize(size=(side, side), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
    ])


transform = {
    "train": _resize_to_tensor(img_size_train),
    "test": _resize_to_tensor(img_size_test),
}

In [None]:
## preparing the data

train_path = './data/train/'

dataset = ImageDataset(train_path, transform=transform["train"])

# One class per sub-folder of the training directory. Plain files and hidden
# entries (e.g. ".DS_Store", ".ipynb_checkpoints") are skipped so they are not
# counted as classes when `num_classes` is derived from the length of this list.
id2label = natsorted([
    entry for entry in os.listdir(train_path)
    if not entry.startswith('.') and os.path.isdir(os.path.join(train_path, entry))
])
id2label

In [None]:
# Class name of every training image: the name of the folder that contains it,
# listed in natural-sort order of the file paths.
y = [
    os.path.split(os.path.split(image_path)[0])[1]
    for image_path in natsorted(glob('./data/train/**/**'))
]

y[:10]

In [None]:
## Check data count
# `y` is passed to StratifiedKFold together with `dataset`, so there must be
# exactly one label per dataset sample. Comparing against `paths` alone cannot
# reveal a mismatch because both lists come from the same glob pattern.
assert len(y) == len(dataset), (
    f"label list has {len(y)} entries but the dataset has {len(dataset)} samples"
)

len(paths), len(y)

In [None]:
## Training parameters

model_name = 'resnet18'
num_classes = len(id2label)      # one model output per entry of id2label
num_epochs = 15

# Batching: the training loop accumulates gradients over `num_accumulate`
# mini-batches before each optimizer step.
train_batch_size = 4
eval_batch_size = 4
num_accumulate = 4

# Loss function, moved to the selected device
criterion = nn.CrossEntropyLoss().to(device)

## Cross validation configuration
k_splits = 5
metric = evaluate.load("f1", additional_keys=["accuracy", "precision", "recall", "support"])

In [None]:
# Best validation macro-F1 of every fold, appended in fold order.
all_eval_scores = []

# Start from an empty checkpoint directory. Done in pure Python so the cell
# does not depend on the `rm` / `mkdir` shell commands being available.
shutil.rmtree("first_model", ignore_errors=True)
os.makedirs("first_model", exist_ok=True)

# Fixed seed so the fold assignment is identical on every run.
cv_seed = 42
skf = StratifiedKFold(n_splits=k_splits, shuffle=True, random_state=cv_seed)

training_start_time = time.time()

for fold, (train_idx, val_idx) in enumerate(skf.split(dataset, y)):

    # display fold number
    print(f"\nFold: {fold+1} / {k_splits}")

    # load a fresh pretrained model for this fold
    model = timm.create_model(model_name=model_name, pretrained=True, num_classes=num_classes).to(device)

    # optimizer and scheduler
    optimizer = timm.optim.create_optimizer_v2(model, opt="AdamW", lr=1e-3)
    scheduler = timm.scheduler.create_scheduler_v2(optimizer, num_epochs=num_epochs)[0]

    # split main dataset into train and val set with kfold
    train_dataset = Subset(dataset, train_idx)
    val_dataset = Subset(dataset, val_idx)

    # dataloaders (validation order does not influence the metrics, so no shuffling there)
    train_dataloader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=eval_batch_size, shuffle=False)

    # Gradient-accumulation bookkeeping: the last accumulation window of an
    # epoch may contain fewer than `num_accumulate` batches.
    num_train_batches = len(train_dataloader)
    last_window_size = (num_train_batches % num_accumulate) or num_accumulate
    last_window_start = num_train_batches - last_window_size
    updates_per_epoch = (num_train_batches + num_accumulate - 1) // num_accumulate

    # Reset Model Info
    info = {
        "metric_train": {"f1": [], "accuracy": [], "precision": [], "recall": [], "auc_roc": []},
        "metric_val": {"f1": [], "accuracy": [], "precision": [], "recall": [], "auc_roc": []},
        "train_loss": [],
        "val_loss": [],
        "best_metric_val": -np.inf,
        "confusion_matrix": None,
    }
    # Epochs since the last improvement (tracked only; nothing in this cell acts on it).
    count = 0

    # Create a LabelBinarizer object to convert labels to binary class matrix
    lb = LabelBinarizer()

    for epoch in range(num_epochs):
        train_loss_epoch = []
        val_loss_epoch = []

        train_preds = []
        train_targets = []

        val_preds = []
        val_targets = []

        # Number of optimizer steps taken before this epoch.
        num_updates = epoch * updates_per_epoch

        # Training loop
        model.train()
        # The batch variables are deliberately NOT called `X` / `y`: `y` is the
        # global list of stratification labels and must survive this cell.
        for idx, (images, targets) in enumerate(tqdm(train_dataloader)):
            images, targets = images.to(device), targets.to(device)
            logits = model(images)
            loss = criterion(logits, targets)

            # Scale so the accumulated gradient is the mean (not the sum) over the window.
            window_size = last_window_size if idx >= last_window_start else num_accumulate
            (loss / window_size).backward()

            if ((idx + 1) % num_accumulate == 0) or (idx + 1 == num_train_batches):
                optimizer.step()
                num_updates += 1
                scheduler.step_update(num_updates=num_updates)
                optimizer.zero_grad()

            # The unscaled loss is what gets logged.
            train_loss_epoch.append(loss.item())
            train_preds += logits.argmax(-1).detach().tolist()
            train_targets += targets.tolist()

        # The last batch always triggers an optimizer step above, so no extra
        # step is needed here; only the per-epoch scheduler update remains.
        scheduler.step(epoch + 1)

        # Evaluation loop
        model.eval()
        with torch.no_grad():
            for (images, targets) in tqdm(val_dataloader):
                images, targets = images.to(device), targets.to(device)
                logits = model(images)
                loss = criterion(logits, targets)

                val_loss_epoch.append(loss.item())
                val_preds += logits.argmax(-1).detach().tolist()
                val_targets += targets.tolist()

        # Convert labels to binary class matrix
        # NOTE(review): the AUC below is computed from hard (argmax) predictions,
        # not class probabilities, so it is only a coarse proxy for ROC-AUC.
        train_targets_bin = lb.fit_transform(train_targets)
        train_preds_bin = lb.transform(train_preds)

        val_targets_bin = lb.fit_transform(val_targets)
        val_preds_bin = lb.transform(val_preds)

        # Validation macro-F1 used for model selection
        metric_val = metric.compute(predictions=val_preds, references=val_targets, average="macro")

        # Calculate metrics for the training set
        precision_train, recall_train, f1_train, support_train = precision_recall_fscore_support(train_targets, train_preds, average='macro')
        acc_train = np.sum(np.array(train_preds) == np.array(train_targets)) / len(train_preds)

        # Calculate metrics for the validation set
        precision_val, recall_val, f1_val, support_val = precision_recall_fscore_support(val_targets, val_preds, average='macro')
        acc_val = np.sum(np.array(val_preds) == np.array(val_targets)) / len(val_preds)

        info["metric_train"]["f1"].append(f1_train)
        info["metric_train"]["precision"].append(precision_train)
        info["metric_train"]["recall"].append(recall_train)
        info["metric_train"]["auc_roc"].append(roc_auc_score(train_targets_bin, train_preds_bin, multi_class='ovo'))
        info["metric_train"]["accuracy"].append(acc_train)

        info["metric_val"]["f1"].append(f1_val)
        info["metric_val"]["precision"].append(precision_val)
        info["metric_val"]["recall"].append(recall_val)
        info["metric_val"]["auc_roc"].append(roc_auc_score(val_targets_bin, val_preds_bin, multi_class='ovo'))
        info["metric_val"]["accuracy"].append(acc_val)

        info["train_loss"].append(np.average(train_loss_epoch))
        info["val_loss"].append(np.average(val_loss_epoch))

        # Overwritten every epoch, so this always holds the most recent epoch's matrix.
        cm = confusion_matrix(val_targets, val_preds)
        info["confusion_matrix"] = cm.tolist()

        # Keep the checkpoint of the epoch with the best validation F1.
        if metric_val["f1"] > info["best_metric_val"]:
            print("\nNew Best Score!")
            info["best_metric_val"] = metric_val["f1"]
            torch.save(model, f"./first_model/checkpoint_fold{fold}.pt")
            count = 0
        else:
            count += 1

        print(info)
        print(f"Fold: {fold} | Epoch: {epoch} | Metric: {metric_val['f1']} | Training Loss: {np.average(train_loss_epoch)} | Validation Loss: {np.average(val_loss_epoch)}\n")

    # save all best metric val
    all_eval_scores.append(info["best_metric_val"])

print('Training finished, took {:.2f}s'.format(time.time() - training_start_time))


In [None]:
# Best validation macro-F1 recorded for each fold, in fold order.
all_eval_scores

In [None]:
## get the labels

# The class indices used for scoring must be the ones the model was trained
# with, so the index -> name list is built from the training directory rather
# than the test directory (a class folder missing from the test set would
# otherwise shift every later index). Hidden entries such as ".DS_Store" and
# plain files are skipped.
# NOTE(review): assumes timm's folder-based dataset orders class folders the
# same way natsorted does - confirm for folder names with mixed case.
label_root = './data/train/'
id2label = natsorted([
    entry for entry in os.listdir(label_root)
    if not entry.startswith('.') and os.path.isdir(os.path.join(label_root, entry))
])
id2label

In [None]:
## Load the test data

# Entries matching ./data/test/<folder>/<name>, in natural-sort order.
test_paths = natsorted(glob("./data/test/**/**"))

# The ground-truth class of each entry is the name of its parent folder.
test_y = []
for image_path in test_paths:
    parent_dir, _filename = os.path.split(image_path)
    test_y.append(os.path.split(parent_dir)[1])

test_paths[:5], test_y[:5]

In [None]:
## convert literal classes in y_test to numeric labels

# Number the classes by their position in `id2label`. Deriving the indices from
# the list itself keeps the mapping complete for any number of classes, instead
# of silently truncating to a hard-coded range.
label_map = {class_name: class_idx for class_idx, class_name in enumerate(id2label)}

lables = [label_map[class_name] for class_name in test_y]

In [None]:
#for ensemble method: one list of predicted class indices per fold model
pred_all = []

# Iterate over exactly the folds that were trained (k_splits), not a hard-coded count.
for fold in trange(k_splits):

    # Load the fold's checkpoint. The training cell saved the whole model object
    # (not a state_dict), so full unpickling is required: weights_only=False.
    # Only do this for checkpoints produced by this notebook - unpickling can
    # execute arbitrary code. map_location places the weights on the active device.
    model = torch.load(
        f"./first_model/checkpoint_fold{fold}.pt",
        map_location=device,
        weights_only=False,
    )
    model = model.to(device)
    model.eval()

    pred = []
    with torch.no_grad():
        for image_path in tqdm(test_paths):

            # image: close the file handle as soon as the tensor has been built
            with Image.open(image_path) as img:
                transformed = transform["test"](img.convert("RGB")).unsqueeze(0).to(device)

            # predicted class index for this single-image batch
            predicted_idx = model(transformed).argmax(-1).item()

            pred.append(predicted_idx)

    #for ensemble method
    pred_all.append(pred)

In [None]:
# Majority-vote ensemble: for every test sample keep the class index predicted
# by the most fold models.
sub_ensemble = []

# Number of fold models that actually produced predictions.
kfold = len(pred_all)

for i in trange(len(pred_all[0])):
    # votes of every fold model for sample i
    votes = [pred_all[j][i] for j in range(kfold)]

    # Depending on the SciPy version, `.mode` is either a scalar or a one-element
    # array; ravel()[0] turns both into a single value, stored as a plain int.
    voted_class = int(np.ravel(stats.mode(votes).mode)[0])
    sub_ensemble.append(voted_class)

In [None]:
# Score the ensemble predictions against the numeric test labels (macro-averaged).
acc = accuracy_score(lables, sub_ensemble)
precision_test, recall_test, f1_test, support_test = precision_recall_fscore_support(
    lables, sub_ensemble, average='macro'
)

# One "<name> <value>" line per metric.
for metric_label, metric_value in (
    ("accuracy", acc),
    ("precision", precision_test),
    ("recall", recall_test),
    ("f1-score", f1_test),
):
    print(f'{metric_label} {metric_value}')