# In [None]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

import torch
import torchvision
import torchvision.transforms as transforms
import tarfile
import pandas as pd
import os
import re
from torch.utils.data import Dataset, DataLoader, ConcatDataset, random_split
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim
from sklearn.metrics import confusion_matrix
from sklearn.decomposition import PCA
from io import StringIO
from PIL import Image
import re
from sklearn.metrics import accuracy_score, f1_score, precision_score
import pickle
import torchvision.models as models

from google.colab import drive

# Make Google Drive available under /content/drive (Colab only); the trained
# model is written there at the end of the run.
drive.mount('/content/drive')

# Run on the first CUDA GPU when one is visible, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

class CNN3(nn.Module):
    """Small CNN with a NetVLAD-style aggregation head.

    Two ``3x3 conv -> ReLU -> 2x2 average pool`` stages turn a 3-channel image
    into a 16-channel feature map at a quarter of the input resolution (the
    convolutions keep the spatial size, each pool halves it; e.g. a
    3x224x224 input becomes 16x56x56). Every spatial location of that map is
    treated as a 16-dimensional local descriptor and softly assigned to ``k``
    learnable cluster centres. The assignment-weighted residuals
    ``descriptor - centre`` are summed over all locations, giving one
    16-dimensional vector per cluster. The flattened ``k * 16`` descriptor is
    fed to a linear layer that produces the class logits.

    No normalisation is applied to the aggregated descriptor.

    Args:
        n_classes: Number of output classes (size of the logit vector).
        k: Number of NetVLAD clusters.
    """

    def __init__(self, n_classes, k=4):
        super().__init__()

        # Channel count of the feature map fed to the NetVLAD head; shared by
        # the last conv layer, the cluster centres and the classifier input.
        feat_channels = 16

        self.n_classes = n_classes

        # Convolutional trunk: (3, H, W) -> (4, H, W) -> (4, H/2, W/2)
        #                      -> (16, H/2, W/2) -> (16, H/4, W/4)
        self.cl1 = nn.Conv2d(3, 4, 3, stride=1, padding=1)
        self.pl1 = nn.AvgPool2d(2, stride=2)
        self.cl2 = nn.Conv2d(4, feat_channels, 3, stride=1, padding=1)
        self.pl2 = nn.AvgPool2d(2, stride=2)

        # NetVLAD head.
        self.K = k

        # 1x1 convolution: one assignment score per cluster at every location.
        self.nv_conv = nn.Conv2d(feat_channels, self.K, 1)

        # Softmax2d normalises over the channel axis independently at each
        # spatial location, so the K scores of a location become a soft
        # assignment over clusters that sums to 1.
        self.nv_soft_ass = nn.Softmax2d()

        # Learnable cluster centres, one 16-dimensional row per cluster,
        # initialised uniformly in [0, 1) and updated during training.
        self.c = nn.Parameter(torch.rand(self.K, feat_channels))

        # Flattens the (K, 16) descriptor into a single vector per sample.
        self.flat = nn.Flatten(1, -1)

        # Output layer producing the class logits.
        self.out = nn.Linear(self.K * feat_channels, self.n_classes)

    def forward(self, x):
        """Return class logits of shape (batch, n_classes) for image batch ``x``.

        ``x`` must be a 4-D batch with 3 channels, as required by ``cl1``.
        """
        x = self.pl1(F.relu(self.cl1(x)))
        x = self.pl2(F.relu(self.cl2(x)))

        # Step 1: soft assignment of every location to the K clusters,
        # shape (B, K, H, W).
        a = self.nv_soft_ass(self.nv_conv(x))

        # Step 2: residual of every local descriptor to every cluster centre,
        # weighted by its assignment and summed over all locations. Done for
        # all clusters at once by broadcasting:
        #   x       -> (B, 1, C, H, W)
        #   centres -> (1, K, C, 1, 1)
        #   a       -> (B, K, 1, H, W)
        centres = self.c.reshape(1, self.K, -1, 1, 1)
        residuals = x.unsqueeze(1) - centres
        Z = (residuals * a.unsqueeze(2)).sum(dim=(3, 4))  # (B, K, C)

        # Cluster-major flattening: entries k*C .. k*C + C-1 belong to cluster k.
        Z = self.flat(Z)
        return self.out(Z)

    def predict(self, x):
        """Return the arg-max class index for every sample in ``x``.

        Runs without gradient tracking, since the integer result cannot be
        back-propagated through anyway.
        """
        with torch.no_grad():
            return torch.argmax(self(x), dim=1)
    
# Build the 7-class model on the target device and the training loss.
classifier = CNN3(7)
classifier = classifier.to(device)
criterion = nn.CrossEntropyLoss()

# Initialise both conv layers from previously trained values.
# NOTE(review): `params` is defined outside this section; it is presumably
# [cl1.weight, cl1.bias, cl2.weight, cl2.bias, ...] from an earlier model --
# confirm the order and shapes against where it is built.
#
# The values are copied *into* the existing parameters instead of rebinding
# the attributes. Rebinding (`classifier.cl1.weight = params[0]`) swaps in a
# different Parameter object, which (a) an optimizer built earlier would not
# know about, so these layers would silently never be updated, (b) may live
# on a different device than the rest of the model, and (c) stays shared with
# whatever model `params` came from. `copy_` avoids all three.
with torch.no_grad():
    classifier.cl1.weight.copy_(params[0])
    classifier.cl1.bias.copy_(params[1])

    classifier.cl2.weight.copy_(params[2])
    classifier.cl2.bias.copy_(params[3])

# Built last so it is guaranteed to track the parameters that are actually
# used in the forward pass. All layers, including the pre-initialised conv
# layers, are trained; if those are meant to stay fixed, call
# `classifier.cl1.requires_grad_(False)` / `classifier.cl2.requires_grad_(False)`
# before this line.
optimizer = optim.Adam(classifier.parameters(), lr=0.0001)

# Train until the epoch loss stops changing (relative change below `rel_tol`)
# or `max_epoch` epochs have run, then plot the loss curve.
# NOTE(review): `trainloader` and `train_size` are defined outside this
# section; `train_size` is presumably the number of training samples, which
# makes `running_loss` the sample-weighted mean loss of the epoch -- confirm.
old_loss = np.inf  # inf so the stopping test can never fire on the first epoch
from IPython.display import clear_output
losses = []
max_epoch = 200
rel_tol = 0.0001
for epoch in range(max_epoch):

    running_loss = 0.0

    for data in trainloader:

        X, y = data[0].to(device), data[1].to(device)

        optimizer.zero_grad()

        # Forward
        y_hat = classifier(X)

        # Calculate Loss (Cross Entropy)
        loss = criterion(y_hat, y)

        # Backpropagation
        loss.backward()

        # Update Parameters
        optimizer.step()

        # Weight each batch by its size so a short final batch is not
        # over-represented in the epoch loss.
        running_loss += loss.item()*len(X)/train_size

    print('Epoch', epoch+1, ': Loss =', running_loss)
    losses.append(running_loss)

    # Relative change with respect to the current loss. The denominator is
    # floored so an epoch loss of exactly 0 does not raise ZeroDivisionError
    # (two consecutive zero-loss epochs then count as converged).
    rel_change = abs(running_loss - old_loss) / max(running_loss, 1e-12)
    if rel_change < rel_tol:
        print('Converged')
        break

    old_loss = running_loss

print('Finished Training')
# One point per epoch.
plt.plot(losses)
plt.ylabel('Loss')
plt.xlabel('Iter Number')
plt.title('Convergence monitor plot')
plt.show()

# Evaluate the trained classifier on the training set: sample-weighted mean
# loss, confusion matrix, accuracy, precision and F1. Gradients are not
# needed here, so the pass runs under no_grad.
with torch.no_grad():

    # Start from a 0-dim tensor (not a Python float) so `train_loss.item()`
    # below also works when the loader yields no batches.
    train_loss = torch.zeros((), device=device)
    y_train = []
    y_train_pred = []

    for data in trainloader:

        X, y = data[0].to(device), data[1].to(device)
        y_hat = classifier(X)
        train_loss = train_loss + criterion(y_hat, y)*len(X)/train_size

        y_train.extend(y.cpu().numpy())
        y_train_pred.extend(torch.argmax(y_hat, dim=1).cpu().numpy())


print('Train Loss =', train_loss.item())
# Printed explicitly: as a bare expression in the middle of a cell/script the
# table would be built and then discarded.
print(pd.DataFrame(confusion_matrix(y_train, y_train_pred)))

acc_tr = accuracy_score(y_train, y_train_pred)
prec_tr = precision_score(y_train, y_train_pred, average='weighted')
f1_tr = f1_score(y_train, y_train_pred, average='weighted')

print('Train Accuracy =', acc_tr, 'Train Precision =', prec_tr, 'Train F1 =', f1_tr)

# NOTE(review): `y_test` and `y_test_pred` are not computed in this section;
# they must come from an earlier test-set pass -- confirm it ran on the
# current model, otherwise these numbers are stale.
# NOTE(review): test metrics use average='macro' while the train metrics above
# use average='weighted', so precision/F1 are not directly comparable between
# the two splits -- confirm which averaging is intended.
acc = accuracy_score(y_test, y_test_pred)
prec = precision_score(y_test, y_test_pred, average='macro')
f1 = f1_score(y_test, y_test_pred, average='macro')

print('Test Accuracy =', acc, 'Test Precision =', prec, 'Test F1 =', f1)
# Saves the whole module object (pickled), not just a state_dict: loading it
# later requires the CNN3 class definition to be importable.
torch.save(classifier, '/content/drive/My Drive/solution_netvlad_model.pt')
        