### IMPORT

In [None]:
# Import required libraries
import os
import random
import warnings

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import albumentations as A
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchvision import models
import wandb

# Import custom modules
from utils.loss import ASLSingleLabel
from utils.optim import ASAM, SAM
from utils.dataset import Cifar10SearchDataset

# Set warning filter
warnings.filterwarnings("ignore")

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

In [None]:
# Select wandb's 'thread' start method through the environment; this has to be
# set before wandb.init() is called below.
os.environ['WANDB_START_METHOD'] = 'thread'

# Name under which this run is registered in the wandb project.
run_name = 'SAM_cosine_ASL_augv3_224'
wandb.init(project="cifa10_proj", name=run_name)

### HyperParameter Setting

In [None]:
# Training hyperparameters shared by the cells below.
CFG = dict(
    Learning_rate=1e-4,  # learning rate passed to the optimizer
    EPOCHS=100,          # number of training epochs (also the cosine schedule's T_max)
    BATCH_SIZE=128,      # batch size of the train and test loaders
    SEED=42,             # value handed to seed_everything
)

# Attach the hyperparameters to the active wandb run.
wandb.config.update(CFG)

### Fixed Seed

In [None]:
def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED`` and switches cuDNN to deterministic mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    # Only affects interpreters started after this point (e.g. subprocesses);
    # the hash seed of the running interpreter is already fixed.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Training runs on several GPUs through DataParallel, so seed every CUDA
    # device explicitly rather than just the current one. Safe without CUDA.
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Apply the configured seed before any dataset, loader or model is created.
seed_everything(CFG['SEED'])

### Dataset Load

In [None]:
# Per-channel statistics handed to A.Normalize.
# ImageNet statistics, kept for reference:
# mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]

# Training pipeline: resize to 224x224, apply random augmentations, normalise.
# NOTE(review): A.Cutout is deprecated in recent albumentations releases in
# favour of A.CoarseDropout — confirm against the installed version.
composed_train = A.Compose([
    A.Resize(224, 224),
    A.Rotate(limit=30, p=0.5),
    A.HorizontalFlip(p=0.2),
    A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.1),
    A.Cutout(num_holes=1, max_h_size=16, max_w_size=16, p=0.75),
    A.Normalize(mean=mean, std=std),
])

# Evaluation pipeline: deterministic resize + normalise only. A random flip
# here would make validation/test metrics vary between passes over the same
# weights, so no stochastic augmentation is applied at evaluation time.
composed_test = A.Compose([
    A.Resize(224, 224),
    A.Normalize(mean=mean, std=std),
])

In [None]:
# Directory that holds the CIFAR-10 data; the training split is constructed
# with download=True.
root_dir = "../data/CIFAR_10"

train_dataset = Cifar10SearchDataset(root=root_dir, train=True, download=True, transform=composed_train)
test_dataset = Cifar10SearchDataset(root=root_dir, train=False, transform=composed_test)

# Use up to 32 loader workers, but never more than the machine has CPUs:
# requesting more worker processes than cores oversubscribes the host.
num_workers = min(32, os.cpu_count() or 1)

# Settings shared by both loaders; only shuffling differs between them.
loader_kwargs = dict(
    batch_size=CFG['BATCH_SIZE'],
    pin_memory=True,
    num_workers=num_workers,
)

train_loader = DataLoader(dataset=train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(dataset=test_dataset, shuffle=False, **loader_kwargs)

### Model define

In [None]:
# ImageNet-pretrained ResNet-18 from torchvision with the classifier head
# replaced by a 10-way linear layer.
# `weights=` is the torchvision (>= 0.13) way to request pretrained weights;
# the previous `pretrained='IMAGENET1K_V1'` only worked because the deprecated
# boolean `pretrained` flag treats any truthy value as "use default weights".
model = models.resnet18(weights='IMAGENET1K_V1')
model.fc = nn.Linear(model.fc.in_features, 10)
# Alternative backbone (timm ResNet-18), requires `import timm`:
# model = timm.create_model('resnet18', pretrained=True, num_classes=10).to(device)

# SAM receives the base optimizer *class* (not an instance) plus the learning
# rate keyword. Presumably it instantiates AdamW internally — confirm against
# utils.optim.SAM.
# Plain SGD alternative:
# optimizer = torch.optim.SGD(model.parameters(), lr=CFG['Learning_rate'], momentum=0.9, weight_decay=1e-4)
base_optimizer = torch.optim.AdamW
optimizer = SAM(model.parameters(), base_optimizer, lr=CFG['Learning_rate'])

# Cosine annealing over the whole run, stepped once per epoch by train().
# StepLR alternative:
# exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
exp_lr_scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=CFG['EPOCHS'], eta_min=0)

### Train

In [None]:
from copy import deepcopy


class ModelEmaV2(nn.Module):
    """Keep an exponential moving average (EMA) copy of a model's state.

    A deep copy of ``model`` is stored in ``self.module`` (in eval mode) and
    blended towards the live model on every ``update`` call:
    ``ema = decay * ema + (1 - decay) * model``.

    Args:
        model: Module whose ``state_dict`` entries are tracked.
        decay: EMA decay factor; values closer to 1 average over more updates.
        device: Optional device on which to hold the EMA copy. When ``None``
            the copy stays wherever ``deepcopy`` put it.
    """

    def __init__(self, model, decay=0.9999, device=None):
        super().__init__()
        # make a copy of the model for accumulating moving average of weights
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay
        self.device = device  # perform ema on different device from model if set
        if self.device is not None:
            self.module.to(device=device)

    def _update(self, model, update_fn):
        """Write ``update_fn(ema_value, model_value)`` into each EMA entry.

        Entries are paired positionally between the two state dicts, so
        ``model`` must have the same structure as the tracked copy.
        """
        with torch.no_grad():
            for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
                if self.device is not None:
                    model_v = model_v.to(device=self.device)
                if ema_v.is_floating_point() or ema_v.is_complex():
                    ema_v.copy_(update_fn(ema_v, model_v))
                else:
                    # Integer buffers (e.g. BatchNorm's num_batches_tracked)
                    # cannot hold a fractional average: the blended value would
                    # be truncated on copy, so mirror the live value instead.
                    ema_v.copy_(model_v)

    def update(self, model):
        """Blend the current state of ``model`` into the EMA copy."""
        self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)

    def set(self, model):
        """Overwrite the EMA copy with the current state of ``model``."""
        self._update(model, update_fn=lambda e, m: m)

In [None]:
def train(model, optimizer, train_loader, val_loader, scheduler, device):
    """Train ``model`` for ``CFG['EPOCHS']`` epochs and log metrics to wandb.

    Each epoch runs one pass over ``train_loader``, evaluates on
    ``val_loader`` via ``validation``, steps ``scheduler`` (if given) and logs
    losses, accuracies and the current learning rate.

    Args:
        model: Network to train; moved to ``device`` inside this function.
        optimizer: Optimizer whose ``step`` accepts a closure (the SAM wrapper
            built above). Presumably the closure drives SAM's second
            forward/backward pass — confirm against utils.optim.SAM.
        train_loader: Iterable of ``(imgs, labels)`` training batches.
        val_loader: Iterable of ``(imgs, labels)`` batches for ``validation``.
        scheduler: LR scheduler stepped once per epoch, or ``None``.
        device: Device on which batches and the model are placed.

    Returns:
        The trained model, wrapped in ``nn.DataParallel`` when more than one
        GPU was used.
    """
    # Replicate across GPUs only when several are actually visible. The former
    # hard-coded device_ids=[0, 1] fails on a single-GPU host. The DataParallel
    # defaults (all visible devices, outputs gathered on the first one, split
    # along dim 0) reproduce the previous setup on a two-GPU machine.
    if torch.device(device).type == 'cuda' and torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    model.to(device)

    # NOTE(review): the EMA copy is updated every step but is neither
    # validated nor returned by this function.
    model_ema = ModelEmaV2(model, device=device)

    criterion = ASLSingleLabel().to(device)

    gradient_accumulation_steps = 1

    for epoch in range(1, CFG['EPOCHS'] + 1):
        model.train()
        train_loss = []
        total, correct = 0, 0
        for step, (imgs, labels) in enumerate(tqdm(train_loader), start=1):
            imgs = imgs.float().to(device)
            labels = labels.to(device)

            output = model(imgs)

            # First forward/backward pass: gradients at the current weights.
            loss = criterion(output, labels)
            loss.backward()

            if step % gradient_accumulation_steps == 0:
                # The closure repeats forward/backward on the same batch and is
                # handed to optimizer.step.
                def closure():
                    return criterion(model(imgs), labels).backward()

                optimizer.step(closure)

                optimizer.zero_grad()  # Reset gradients after accumulation

            model_ema.update(model)

            # Statistics come from the first pass (loss and logits before the
            # optimizer step).
            train_loss.append(loss.item())

            _, predicted = output.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        _val_loss, _val_score = validation(model, criterion, val_loader, device)
        _train_loss = np.mean(train_loss)
        train_acc = correct / total

        print(f'Epoch [{epoch}] Train Loss : [{_train_loss:.5f}] Train Acc [{train_acc:.5f}] Val Loss : [{_val_loss:.5f}] Val Acc : [{_val_score:.5f}]')

        if scheduler is not None:
            scheduler.step()

        wandb.log({"Val Loss": _val_loss, "Val Acc": _val_score, "Train Loss": _train_loss, "Train Acc": train_acc, "lr" : optimizer.param_groups[0]['lr']})

    return model

In [None]:
def validation(model, criterion, val_loader, device):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    The model is switched to eval mode and left in that mode on return.

    Args:
        model: Callable module producing per-class scores for a batch.
        criterion: Loss applied to ``(scores, labels)``.
        val_loader: Iterable of ``(imgs, labels)`` batches.
        device: Device on which each batch is placed.

    Returns:
        ``(mean_batch_loss, accuracy)``: the mean of the per-batch losses and
        the fraction of samples whose highest-scoring class equals the label.
    """
    model.eval()
    batch_losses = []
    n_seen = 0
    n_correct = 0

    with torch.no_grad():
        for batch_imgs, batch_labels in val_loader:
            batch_imgs = batch_imgs.float().to(device)
            batch_labels = batch_labels.to(device)

            scores = model(batch_imgs)
            batch_losses.append(criterion(scores, batch_labels).item())

            # Index of the highest score along the class dimension.
            top_class = scores.max(1)[1]
            n_seen += batch_labels.size(0)
            n_correct += (top_class == batch_labels).sum().item()

    return np.mean(batch_losses), n_correct / n_seen

### Run

In [None]:
# Run the full training loop; the test split doubles as the validation set.
# The returned model is used for the reports below.
infer_model = train(model, optimizer, train_loader, test_loader, exp_lr_scheduler, device)

### Classification Report

In [None]:
from sklearn.metrics import classification_report

# Display names for the ten classes; assumed to be in label-index order —
# TODO confirm against the dataset.
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# Collect predicted and true class indices over the whole test loader.
infer_model.eval()
preds, true_labels = [], []

with torch.no_grad():
    for imgs, labels in tqdm(test_loader):
        imgs, labels = imgs.float().to(device), labels.to(device)

        pred = infer_model(imgs)

        preds.extend(pred.argmax(1).tolist())
        true_labels.extend(labels.tolist())

# Per-class precision/recall/F1 as a DataFrame with one row per class/summary.
report_df = pd.DataFrame(
    classification_report(true_labels, preds, target_names=class_names, output_dict=True)
).T
print(report_df)

In [None]:
# Bare expression so the notebook renders the report table as the cell output.
report_df

In [None]:
# confusion matrix
from sklearn.metrics import confusion_matrix
import seaborn as sns
from matplotlib import pyplot as plt

# Divide each row by its total so a cell holds the fraction of that true
# class's samples assigned to each predicted class.
cm = confusion_matrix(true_labels, preds)
cm = cm / cm.sum(axis=1, keepdims=True)

plt.figure(figsize=(10, 10))
sns.heatmap(
    cm,
    annot=True,
    cmap='Blues',
    fmt='.2f',
    xticklabels=class_names,
    yticklabels=class_names,
).set(xlabel='Predicted Label', ylabel='True Label')
plt.show()