In [1]:
from IPython.display import clear_output

In [2]:
# !mkdir ~/.kaggle
# !mv kaggle.json ~/.kaggle/kaggle.json
# !pip3 install kaggle
# !kaggle datasets download -d imran2002/imagenet-top50-400train-50val
# clear_output()

In [3]:
# !apt-get update
# !apt-get install p7zip-full -y
# !7z x imagenet-top50-400train-50val.zip -odata/
# !rm imagenet-top50-400train-50val
# clear_output()

In [4]:
# !git clone https://github.com/imrnh/beyond_intuition_code.git
# !cp -r beyond_intuition_code/* /workspace/
# !rm -rf beyond_intuition_code
# clear_output()

In [5]:
!pip3 install h5py einops tqdm
clear_output()

In [6]:
import warnings
warnings.filterwarnings('ignore')

# **Training Loop**

In [7]:
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data.dataloader import DataLoader

import os
import numpy as np
from tqdm.auto import tqdm
from PIL import Image
from utils.trainer_callback import Callback
from utils.model_loaders import vit_base_patch16_224

np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x7f04ddbb57b0>

In [8]:
def embd_to_class(batched_embed):
    class_indices = []
    for embd in batched_embed:
        cls_idx = torch.argmax(embd)
        class_indices.append(cls_idx)
    return torch.Tensor(class_indices)

def measure_accuracy(yhat, y):
    yhat_idx = embd_to_class(yhat)

    accuracy = 0.0
    for yh, yval in zip(yhat_idx, y):
        accuracy += 1 if (yh == yval) else 0
        
    return accuracy

In [9]:
class ModelTrainer:
    def __init__(self, data_path, epochs, batch_size, val_batch_size ,num_workers, lr, weight_decay) -> None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        self.data_path = data_path
        self.callbacks = Callback("train_logs.txt", "custom_trained_models/", True, True, False, 0, 0)

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.epochs = epochs
        
        self.lr = lr
        self.weight_decay = weight_decay
        
        # Model, Optimizer and Loss function setup
        self.model = vit_base_patch16_224(pretrained=False, num_classes=50).cuda()

        self.optimizer = Adam(self.model.parameters(), lr=self.lr, weight_decay=weight_decay)
        self.criterion = CrossEntropyLoss()

        # Data loading and transforms
        self.image_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

        self.dataset = datasets.ImageFolder(root=self.data_path + "/train", transform=self.image_transform)
        self.dataloader = DataLoader(self.dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, drop_last=False)

        self.validation_dataset = datasets.ImageFolder(root=self.data_path + "/val", transform=self.image_transform)
        self.validation_dataloader = DataLoader(self.validation_dataset, batch_size= val_batch_size, shuffle=False, num_workers=self.num_workers, drop_last=False)
    
        print(f"{len(self.dataloader) * batch_size } Train Images found")
        print(f"{len(self.validation_dataloader) * val_batch_size} Validation Images found")
        
        
    """
        Train the model and finally save it.
    """
    def train(self):
        for epoch in range(self.epochs):
            train_loss, val_loss, val_accuracy = 0.0, 0.0, 0.0

            for batch in tqdm(self.dataloader):  # Training
                x, y = batch 
                x, y = x.cuda(), y.cuda()
                y_hat = self.model(x)
                loss = self.criterion(y_hat, y)
                train_loss += loss.detach().cpu().item() / len(self.dataloader)

                self.optimizer.zero_grad()
                loss.backward()  # backprop calculation
                self.optimizer.step()  # Updating the weight based on these calculation.


            for batch in tqdm(self.validation_dataloader):  # Validation
                x_val, y_val = batch
                x_val, y_val = x_val.cuda(), y_val.cuda()
                pred = self.model(x_val)
                vloss = self.criterion(pred, y_val)
                
                val_loss += vloss.detach().cpu().item() / len(self.validation_dataloader)
                val_accuracy += measure_accuracy(pred, y_val) / len(self.validation_dataloader)
                
            self.callbacks.__call__(self.model, "r", epoch, 
                                    {'train_loss': train_loss, 'validation_loss': val_loss, 'validation_accuracy': val_accuracy})

In [10]:
trainer = ModelTrainer(data_path="data", epochs=50, batch_size=96, val_batch_size=16, num_workers=50, lr=0.001, weight_decay=0.01)

20064 Train Images found
2512 Validation Images found


In [None]:
trainer.train()

  0%|          | 0/157 [00:00<?, ?it/s]

{'train_loss': 0.0, 'validation_loss': 4.077100434880348, 'validation_accuracy': 0.3375796178343949}


  0%|          | 0/157 [00:00<?, ?it/s]

{'train_loss': 0.0, 'validation_loss': 4.077100434880348, 'validation_accuracy': 0.3375796178343949}


  0%|          | 0/157 [00:00<?, ?it/s]