In [None]:
# !mkdir ~/.kaggle
# !mv kaggle.json ~/.kaggle/kaggle.json
# !pip3 install kaggle
# !kaggle datasets download -d imran2002/imagenet-top50-400train-50val

In [None]:
# !apt-get update
# !apt-get install p7zip-full -y
!7z x data_raw.zip -odata/

In [None]:
# !git clone https://github.com/imrnh/beyond_intuition_code.git
# !cp -r beyond_intuition_code/* /workspace/
# !rm -rf beyond_intuition_code

In [None]:
!pip3 install h5py einops tqdm

In [4]:
import warnings
warnings.filterwarnings('ignore')

# **Training Loop**

In [None]:
def embd_to_class(self, batched_embed):
    class_indices = []
    for embd in batched_embed:
        cls_idx = torch.argmax(embd)
        class_indices.append(cls_idx)
    return torch.Tensor(class_indices)

def measure_accuracy(self, yhat, y):
    yhat_idx = self.embd_to_class(yhat)

    accuracy = 0.0
    for yh, yval in zip(yhat_idx, y):
        accuracy += 1 if (yh == yval) else 0
        
    return accuracy


def callbacks(t_loss, v_loss, vacc, epoch, train_logs):
    if epoch > 1:
        prev_vl = train_logs[-2]['val_loss']
        prev_vacc = train_logs[-2]['val_accuracy']

        if vacc >= prev_vacc :
            save_model(f"raw_images_{str(epoch)}_max_validation_accuracy")
            print(f"Saved for acc:{vacc}")

        if v_loss <= prev_vl :
            save_model(f"raw_images_{str(epoch)}_min_validation_loss")
            print(f"Saved for loss: {v_loss}")

    return train_logs

In [5]:
"""
    Train the models for given saliency map.
"""

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data.dataloader import DataLoader

import os
import numpy as np
from tqdm.auto import tqdm
from PIL import Image
from utils.model_loaders import vit_base_patch16_224

np.random.seed(0)
torch.manual_seed(0)


class ModelTrainer:
    def __init__(self, data_path, epochs, batch_size, num_workers, lr, weight_decay) -> None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        self.save_dir = "custom_trained_models/"
        self.data_path = data_path
        self.train_logs = []


        self.batch_size = batch_size
        self.num_workers = num_workers
        self.epochs = epochs
        
        self.lr = lr
        self.weight_decay = weight_decay
        
        # Model, Optimizer and Loss function setup
        self.model = vit_base_patch16_224(pretrained=False, num_classes=50).cuda()

        self.optimizer = Adam(self.model.parameters(), lr=self.lr, weight_decay=weight_decay)
        self.criterion = CrossEntropyLoss()

        # Data loading and transforms
        self.image_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

        self.dataset = datasets.ImageFolder(root=self.data_path + "/train", transform=self.image_transform)
        self.dataloader = DataLoader(self.dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, drop_last=False)

        self.validation_dataset = datasets.ImageFolder(root=self.data_path + "/val", transform=self.image_transform)
        self.validation_dataloader = DataLoader(self.validation_dataset, batch_size=int(self.batch_size / 4), shuffle=False, num_workers=self.num_workers, drop_last=False)
    
        print(f"{len(self.dataloader) * batch_size } Train Images found")
        print(f"{len(self.validation_dataloader) * batch_size} Validation Images found")
        
        
    """
        Train the model and finally save it.
    """
    def train(self):
        for epoch in range(self.epochs):
            train_loss, val_loss, val_accuracy = 0.0, 0.0, 0.0

            for batch in tqdm(self.dataloader):  # Training
                x, y = batch 
                x, y = x.cuda(), y.cuda()
                y_hat = self.model(x)
                loss = self.criterion(y_hat, y)
                train_loss += loss.detach().cpu().item() / len(self.dataloader)

                self.optimizer.zero_grad()
                loss.backward()  # backprop calculation
                self.optimizer.step()  # Updating the weight based on these calculation.


            for batch in tqdm(self.validation_dataloader):  # Validation
                x_val, y_val = batch
                x_val, y_val = x_val.cuda(), y_val.cuda()
                pred = self.model(x_val)
                vloss = self.criterion(pred, y_val)
                
                val_loss += vloss.detach().cpu().item() / len(self.validation_dataloader)
                val_accuracy += self.measure_accuracy(pred, y_val) / len(self.validation_dataloader)
                
            self.callbacks(train_loss, val_loss, val_accuracy, epoch)

In [6]:
trainer = ModelTrainer(data_path="data", epochs=50, batch_size=192, num_workers=100, lr=0.0001, weight_decay=0.00004)

20160 Train Images found
10176 Validation Images found


In [None]:
trainer.train()

In [None]:
print("K")