<a href="https://colab.research.google.com/github/libra0901/loss-analysis-cka/blob/main/tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Import the required libraries



In [1]:
import torch
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
import torch.optim as  optim
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import time
from tqdm import tqdm
import torch.nn.functional as F
import torch.nn as nn


## Changing the device to execute on - either CPU or CUDA


In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cuda


### Loss Functions

In [3]:
class LossFunctions:
    def __init__(self, num_classes):
        super(LossFunctions, self).__init__()
        self.num_classes = int(num_classes)
    
    # Log (cross-entropy) Loss
    def cross_entropy_loss(self, x, y):
        loss = F.cross_entropy(x, y)
        return loss

    # Sum of sqaures
    def sos_loss(self, x, y):
        ones = torch.sparse.torch.eye(self.num_classes).to(device)
        y = ones.index_select(0, y)
        m = nn.Softmax(dim=1)
        criterion = nn.MSELoss(reduction='sum')
        loss = criterion(m(x), y)
        return loss
    
    # Mean Sqaured loss - L2 loss 
    def mse_loss(self, x, y):
        ones = torch.sparse.torch.eye(self.num_classes).to(device)
        y = ones.index_select(0, y)
        m = nn.Softmax(dim=1)
        criterion = nn.MSELoss(reduction='mean')
        loss = criterion(m(x), y)
        return loss
    
    # Negative log likelihood = logarithmic softmax
    def neg_loglike_loss(self, x, y):
        m = nn.LogSoftmax(dim=1)
        nll_loss = nn.NLLLoss()
        loss = nll_loss(m(x), y)
        return loss
    
    # Expectation Loss - L1 loss - Mean absoulte error
    def expectation_loss(self, x, y):
        ones = torch.sparse.torch.eye(self.num_classes).to(device)
        y = ones.index_select(0, y)
        m = nn.Softmax(dim=1)
        loss = F.l1_loss(m(x), y)
        return loss
    
    def bce_loss(self, x, y):
        ones = torch.sparse.torch.eye(self.num_classes).to(device)
        y = ones.index_select(0, y)
        loss = F.binary_cross_entropy_with_logits(x, y)
        return loss



## Assign Batch Size and required transformations for your data

In [4]:
batch_size = 256

traindata_transforms = transforms.Compose([
                        transforms.Resize((32,32)),
                        transforms.ToTensor(),
                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

testdata_transforms = transforms.Compose([
                        transforms.Resize((32,32)),  
                        transforms.ToTensor(),
                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])


## Dataloading

In [5]:
traindata = datasets.CIFAR10(root="\PATH", train=True, transform=traindata_transforms, download=True) 
testdata = datasets.CIFAR10(root="\PATH", train=False, transform=traindata_transforms, download=True)

train_dataloader = DataLoader(traindata, batch_size, shuffle=True) 
test_dataloader = DataLoader(testdata, batch_size, shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to \PATH/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting \PATH/cifar-10-python.tar.gz to \PATH
Files already downloaded and verified


## Train Function

In [6]:
class Train:
    def __init__(self,
                 optimizer,
                 loss,
                 epochs:int, 
                 modelname:str,
                 dataset:str):
        
        super(Train, self).__init__()
        self.optimizer = optimizer
        self.optimizer_name = str(optimizer)
        self.loss = loss
        self.epochs = epochs
        self.modelname = modelname
    
    def model_(self, modelname:str, num_classes:int, dataset:str):
        if modelname == 'resnet50':
            model = models.resnet50(pretrained=True).to(device)
            model.fc = nn.Sequential(
                        nn.Linear(2048, 1024, bias=True),
                        nn.Dropout(),
                        nn.Linear(1024, 512, bias=True),
                        nn.Dropout(),
                        nn.Linear(512, num_classes, bias=True)
                        ).to(device)
            return model
        elif modelname =='resnet18' and dataset=='\mnist':
            model = models.resnet18(pretrained=False).to(device)
            model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False).to(device)
            model.fc = nn.Linear(512, num_classes, bias=True).to(device)
            return model
        elif modelname == 'resnet18':
            model = models.resnet18(pretrained=False).to(device)
            model.fc = nn.Linear(512, num_classes, bias=True).to(device)
            return model
        else:
            print('ERROR_model_or_dataset_is_not_found')
      
    def top1_accuracy(self, outputs, labels):
        _, preds = torch.max(outputs, dim=1)
        acc = torch.tensor(torch.sum(preds == labels).item() / len(preds))
        return acc
    
    @torch.no_grad()
    def test(self, model, test_dl, dataset):
        model.eval()
        for batch in test_dl:
            inputs, labels = batch
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            loss = self.loss(outputs, labels)
            acc = self.top1_accuracy(outputs, labels)
        return loss, acc
    
    # Training Loop 
    def train(self, train_dl, test_dl, num_classes, filename, dataset):
        history = []
        since = time.time()
        model = self.model_(modelname=self.modelname,
                            num_classes = int(num_classes),
                            dataset = dataset)
        if self.optimizer_name == 'SGD':
            optimizer = self.optimizer(model.parameters(), lr=0.0001, momentum=0.93)
        else:
            optimizer = self.optimizer(model.parameters(), lr=0.001)
        
        for epoch in range(self.epochs):
            model.train()
            train_loss = []
            train_acc = []
            test_loss = []
            test_acc = []
            result = {}
            with tqdm(train_dl, unit="batch") as loop:
              for batch in loop:
                  inputs, labels = batch
                  inputs = inputs.to(device)
                  labels = labels.to(device)
                  outputs = model(inputs)
                  loss = self.loss(outputs, labels)
                  acc = self.top1_accuracy(outputs, labels)
                  train_acc.append(acc.cpu().detach().numpy())
                  train_loss.append(loss.cpu().detach().numpy())
                  loop.set_description(f"Epoch [{epoch}/{self.epochs}]")
                  loop.set_postfix(train_loss=np.average(train_loss),train_acc=np.average(train_acc))
                  loss.backward()
                  optimizer.step()
                  optimizer.zero_grad()
              
              test_losses,test_accu = self.test(model, test_dl, dataset)
              test_loss.append(test_losses.cpu().detach().numpy())
              test_acc.append(test_accu.cpu().detach().numpy())       
              result['train_loss'] = np.average(train_loss)
              result['train_acc'] = np.average(train_acc)
              result['test_loss'] = np.average(test_loss)
              result['test_acc'] = np.average(test_acc)
              print('\nEpoch',epoch,result)
              history.append(result)
            
        time_elapsed = time.time() - since
        print('Training Completed in {:.0f} min {:.0f} sec'.format(time_elapsed//60, time_elapsed%60))
        return history

## Assign:<br/>
-----------------------------------------------------------------------<br/>
Number of classes (w.r.t dataset) = n : int <br/>
-----------------------------------------------------------------------<br/>
Loss Function = (from set of declared 6 loss fucntions, assign one)<br/>
1.   L<sub>1</sub> = <br/>
2.   L<sub>2</sub> = <br/>
3.   Softmax Cross Entropy = <br/>
4.   Binary Cross Entropy = <br/>
5.   Mean-Squared-Error = <br/>
6.   Sum-of-Sqaures = <br/>
-----------------------------------------------------------------------<br/>
## Execute the model

In [7]:
num_classes = 10
lf = LossFunctions(num_classes)
loss = lf.cross_entropy_loss
batch_size = 512
optimizer = optim.Adam
epochs = 15
modelname = 'resnet18'
dataset = '\cifar10'

# Declaring Loss Function
l = LossFunctions(num_classes)
# Training 
t = Train(optimizer = optimizer,
          loss = loss,
          epochs = epochs,
          modelname = modelname,
          dataset = dataset,
          )
filename = dataset[1:] + modelname

history = t.train(
            train_dl = train_dataloader,
            test_dl = test_dataloader,
            num_classes = num_classes,
            filename = filename,
            dataset = dataset
            )

Epoch [0/15]: 100%|██████████| 196/196 [00:27<00:00,  7.06batch/s, train_acc=0.505, train_loss=1.38]



Epoch 0 {'train_loss': 1.3775111, 'train_acc': 0.5046157, 'test_loss': 1.0798844, 'test_acc': 0.5625}


Epoch [1/15]: 100%|██████████| 196/196 [00:28<00:00,  7.00batch/s, train_acc=0.655, train_loss=0.973]



Epoch 1 {'train_loss': 0.9733933, 'train_acc': 0.65464765, 'test_loss': 1.033646, 'test_acc': 0.625}


Epoch [2/15]: 100%|██████████| 196/196 [00:27<00:00,  7.09batch/s, train_acc=0.721, train_loss=0.793]



Epoch 2 {'train_loss': 0.7927615, 'train_acc': 0.7207351, 'test_loss': 0.4255637, 'test_acc': 0.875}


Epoch [3/15]: 100%|██████████| 196/196 [00:27<00:00,  7.09batch/s, train_acc=0.767, train_loss=0.662]



Epoch 3 {'train_loss': 0.6617474, 'train_acc': 0.7674266, 'test_loss': 0.34082595, 'test_acc': 0.875}


Epoch [4/15]: 100%|██████████| 196/196 [00:27<00:00,  7.13batch/s, train_acc=0.809, train_loss=0.546]



Epoch 4 {'train_loss': 0.5463242, 'train_acc': 0.80891263, 'test_loss': 0.6106515, 'test_acc': 0.75}


Epoch [5/15]: 100%|██████████| 196/196 [00:27<00:00,  7.12batch/s, train_acc=0.841, train_loss=0.454]



Epoch 5 {'train_loss': 0.4536835, 'train_acc': 0.8408761, 'test_loss': 0.45053098, 'test_acc': 0.8125}


Epoch [6/15]: 100%|██████████| 196/196 [00:27<00:00,  7.13batch/s, train_acc=0.866, train_loss=0.381]



Epoch 6 {'train_loss': 0.3811563, 'train_acc': 0.86615115, 'test_loss': 0.45118195, 'test_acc': 0.75}


Epoch [7/15]: 100%|██████████| 196/196 [00:27<00:00,  7.11batch/s, train_acc=0.895, train_loss=0.3]



Epoch 7 {'train_loss': 0.30007794, 'train_acc': 0.89517295, 'test_loss': 0.6390468, 'test_acc': 0.6875}


Epoch [8/15]: 100%|██████████| 196/196 [00:27<00:00,  7.13batch/s, train_acc=0.914, train_loss=0.243]



Epoch 8 {'train_loss': 0.24327156, 'train_acc': 0.9141024, 'test_loss': 1.2660936, 'test_acc': 0.625}


Epoch [9/15]: 100%|██████████| 196/196 [00:27<00:00,  7.13batch/s, train_acc=0.929, train_loss=0.2]



Epoch 9 {'train_loss': 0.19998111, 'train_acc': 0.9286113, 'test_loss': 0.72094035, 'test_acc': 0.8125}


Epoch [10/15]: 100%|██████████| 196/196 [00:27<00:00,  7.10batch/s, train_acc=0.943, train_loss=0.163]



Epoch 10 {'train_loss': 0.16342404, 'train_acc': 0.94333947, 'test_loss': 0.41008183, 'test_acc': 0.875}


Epoch [11/15]: 100%|██████████| 196/196 [00:27<00:00,  7.09batch/s, train_acc=0.952, train_loss=0.137]



Epoch 11 {'train_loss': 0.13715476, 'train_acc': 0.9522122, 'test_loss': 0.77681756, 'test_acc': 0.8125}


Epoch [12/15]: 100%|██████████| 196/196 [00:27<00:00,  7.09batch/s, train_acc=0.957, train_loss=0.122]



Epoch 12 {'train_loss': 0.122180365, 'train_acc': 0.9572744, 'test_loss': 0.7432883, 'test_acc': 0.75}


Epoch [13/15]: 100%|██████████| 196/196 [00:27<00:00,  7.10batch/s, train_acc=0.966, train_loss=0.0959]



Epoch 13 {'train_loss': 0.09590746, 'train_acc': 0.9661153, 'test_loss': 1.1599112, 'test_acc': 0.75}


Epoch [14/15]: 100%|██████████| 196/196 [00:27<00:00,  7.14batch/s, train_acc=0.966, train_loss=0.0967]



Epoch 14 {'train_loss': 0.09666071, 'train_acc': 0.96574855, 'test_loss': 1.2463706, 'test_acc': 0.6875}
Training Completed in 8 min 11 sec


## Visualize the Model's **CKA**


