<a href="https://colab.research.google.com/github/inverter404/bosch-interiit/blob/main/Action_Classification_Black_Box_Settings.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# REQUIREMENTS

## INSTALLING LIBRARIES

In [None]:
!pip install -q timm
!pip install -q einops
!pip install -q -I mmcv==1.4.0
!pip install -q pytorchvideo

## IMPORTING LIBRARIES

In [None]:
import json
from typing import Dict
from collections import OrderedDict

from mmcv import Config, DictAction
from mmaction.models import build_model
from mmcv.runner import get_dist_info, init_dist, load_checkpoint
from mmcv.parallel import MMDataParallel

import torch
import torch.nn as nn
from torch.nn import Parameter
from torch.nn import functional as F
from torch.nn import init
from torchvision.transforms import Compose, Lambda
from torchvision.transforms._transforms_video import (
    CenterCropVideo,
    NormalizeVideo,
)

from pytorchvideo.data.encoded_video import EncodedVideo
from pytorchvideo.transforms import (
    ApplyTransformToKey,
    ShortSideScale,
    UniformTemporalSubsample
    # UniformCropVideo
)

# Swin Transformer

In [None]:
!rm -rf Video-Swin-Transformer
!git clone https://github.com/SwinTransformer/Video-Swin-Transformer.git
%cd Video-Swin-Transformer
!rm -rf /content/checkpoints
!mkdir /content/checkpoints 
%cd /content/checkpoints
!wget -q https://github.com/SwinTransformer/storage/releases/download/v1.0.4/swin_tiny_patch244_window877_kinetics400_1k.pth
%cd /content/

In [None]:
# Video Swin-T (Kinetics-400) config file and the weights downloaded by the
# previous cell.
config = '/content/Video-Swin-Transformer/configs/recognition/swin/swin_tiny_patch244_window877_kinetics400_1k.py'
checkpoint = '/content/checkpoints/swin_tiny_patch244_window877_kinetics400_1k.pth'

cfg = Config.fromfile(config)
# Build the full recognizer for inference (no train_cfg); only its backbone
# is reused for the black-box model.
model = build_model(
    cfg.model,
    train_cfg=None,
    test_cfg=cfg.get('test_cfg'),
)



class SwinT(nn.Module):
    """Video backbone followed by a pooled linear classification head.

    The head is rebuilt so the Swin-T checkpoint's ``cls_head.fc_cls`` weights
    can be loaded into ``cls_head``. Layer order is pool -> dropout -> linear,
    which is assumed to match the head the checkpoint was trained with
    (mmaction's ``I3DHead``) -- NOTE(review): confirm against the Swin config.

    Args:
        model: backbone module; its output must be a 4-D/5-D feature map
            accepted by ``nn.AdaptiveAvgPool3d`` with ``in_channels`` channels.
        in_channels: channels of the backbone output (768 for Swin-T).
        num_classes: number of output logits (400 for Kinetics-400).
    """

    def __init__(self, model, in_channels=768, num_classes=400):
        super(SwinT, self).__init__()
        self.backbone = model
        self.cls_head = nn.Linear(in_channels, num_classes)
        self.dropout = nn.Dropout(p=0.5, inplace=False)
        self.pool = nn.AdaptiveAvgPool3d(output_size=(1, 1, 1))

    def forward(self, x):
        feat = self.backbone(x)
        # Global average pool first, then dropout on the pooled features.
        feat = self.pool(feat)
        feat = self.dropout(feat)
        # Keep the batch dimension explicit instead of view(-1, 768), which
        # would silently mix samples if the channel count were different.
        feat = torch.flatten(feat, 1)
        return self.cls_head(feat)

# Wrap the Swin-T backbone with a plain linear head and load the pretrained
# weights: backbone tensors keep their 'backbone.*' keys (matching
# SwinT.backbone); the mmaction head 'cls_head.fc_cls.*' maps to 'cls_head.*'.
black_box_model = SwinT(model.backbone)

# map_location='cpu' lets the checkpoint load on machines without a GPU;
# the model is moved to the training device in a later cell.
ckpt_state = torch.load(checkpoint, map_location='cpu')['state_dict']
new_state_dict = OrderedDict(
    (k, v) for k, v in ckpt_state.items() if 'backbone' in k
)
new_state_dict["cls_head.weight"] = ckpt_state["cls_head.fc_cls.weight"]
new_state_dict["cls_head.bias"] = ckpt_state["cls_head.fc_cls.bias"]

black_box_model.load_state_dict(new_state_dict)

/content/checkpoints
/content


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


<All keys matched successfully>

# Generator (DVD-GAN: [Source](https://github.com/Harrypotterrrr/DVD-GAN))

### Normalization

In [None]:
def l2normalize(v, eps=1e-12):
    """Scale ``v`` to unit L2 norm.

    ``eps`` is added to the norm so an all-zero input returns zeros rather
    than NaNs.
    """
    magnitude = v.norm() + eps
    return v / magnitude

class SpectralNorm(nn.Module):
    """Wrap ``module`` so its ``name`` weight is divided by its spectral norm.

    The largest singular value of the weight (reshaped to 2-D) is estimated by
    power iteration on every forward call. The raw weight is stored on the
    wrapped module as ``<name>_bar``; the singular-vector estimates are
    non-trainable parameters ``<name>_u`` / ``<name>_v``. Before each forward,
    ``<name>`` is re-set as ``<name>_bar / sigma``.

    Args:
        module: layer whose weight is normalized (e.g. ``nn.Conv2d``).
        name: attribute name of the weight on ``module``.
        power_iterations: power-iteration steps per forward call.
    """

    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        """Refresh u/v by power iteration and install the normalized weight."""
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        w_mat = w.view(height, -1)
        for _ in range(self.power_iterations):
            v.data = l2normalize(torch.mv(torch.t(w_mat.data), u.data))
            u.data = l2normalize(torch.mv(w_mat.data, v.data))

        # u and v do not require grad, so gradients flow only through w.
        sigma = u.dot(w_mat.mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        """Return True if the u / v / bar parameters already exist."""
        return all(
            hasattr(self.module, self.name + suffix)
            for suffix in ("_u", "_v", "_bar")
        )

    def _make_params(self):
        """Move the weight to ``<name>_bar`` and create random unit u / v."""
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = Parameter(w.data.new_empty(height).normal_(0, 1), requires_grad=False)
        v = Parameter(w.data.new_empty(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = Parameter(w.data)

        # The original weight stops being a registered parameter; it is
        # recomputed from <name>_bar on every forward.
        del self.module._parameters[self.name]

        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)

    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)

class ConditionalNorm(nn.Module):
    """BatchNorm2d whose scale and shift are predicted from a condition vector.

    ``out = gamma(c) * BN(x) + beta(c)`` where ``[gamma, beta] = Linear(c)``
    split into two halves of ``in_channel`` each.

    Args:
        in_channel: channels of ``x``.
        n_condition: size of the condition vector ``c``.
    """

    def __init__(self, in_channel, n_condition=96):
        super().__init__()

        self.in_channel = in_channel
        self.bn = nn.BatchNorm2d(self.in_channel, affine=False)

        self.embed = nn.Linear(n_condition, self.in_channel * 2)
        # Start close to a plain BatchNorm (gamma ~= 1, beta ~= 0).
        # nn.Linear stores weight as (out_features, in_features): gamma/beta
        # are the first/second halves of the *output* rows, so slicing
        # weight[:, :C] (input columns) does not select them. The offsets
        # live in the bias; small weights let the condition only perturb the
        # initial modulation.
        self.embed.weight.data.normal_(0, 0.02)
        self.embed.bias.data[:self.in_channel].fill_(1.0)
        self.embed.bias.data[self.in_channel:].zero_()

    def forward(self, x, class_id):
        """Normalize ``x`` (N, C, H, W) and modulate it with ``class_id`` (N, n_condition)."""
        out = self.bn(x)
        gamma, beta = self.embed(class_id).chunk(2, 1)
        gamma = gamma.view(-1, self.in_channel, 1, 1)
        beta = beta.view(-1, self.in_channel, 1, 1)
        return gamma * out + beta


## GResBlock

In [None]:
class GResBlock(nn.Module):
    """Residual generator block with optional up- or down-sampling.

    Main path: [CBN] -> act -> [upsample] -> conv0 -> [CBN] -> act -> conv1
    -> [avg-pool]. Skip path: [upsample] -> 1x1 conv -> [avg-pool].
    Time is folded into the batch dimension, so each frame is processed
    independently.

    Args:
        in_channel, out_channel: input / output channels.
        kernel_size: kernel of conv0 / conv1 (default ``[3, 3]``).
        padding, stride: passed to conv0 / conv1.
        n_class: size of the condition vector fed to the conditional norms.
        bn: use ConditionalNorm (forced off when downsampling).
        activation: non-linearity applied before each conv.
        upsample_factor: ``F.interpolate`` scale factor (forced to 1 when
            downsampling).
        downsample_factor: ``F.avg_pool2d`` kernel size.
    """

    def __init__(self, in_channel, out_channel, kernel_size=None,
                 padding=1, stride=1, n_class=96, bn=True,
                 activation=F.relu, upsample_factor=2, downsample_factor=1):
        super().__init__()

        # Up- and down-sampling are mutually exclusive; a downsampling block
        # also skips the conditional norms. (`==`, not `is`: identity
        # comparison of ints is implementation-dependent.)
        self.upsample_factor = upsample_factor if downsample_factor == 1 else 1
        self.downsample_factor = downsample_factor
        self.activation = activation
        self.bn = bn if downsample_factor == 1 else False

        if kernel_size is None:
            kernel_size = [3, 3]

        self.conv0 = SpectralNorm(nn.Conv2d(in_channel, out_channel,
                                             kernel_size, stride, padding,
                                             bias=True))
        self.conv1 = SpectralNorm(nn.Conv2d(out_channel, out_channel,
                                             kernel_size, stride, padding,
                                             bias=True))

        # 1x1 projection so the skip path has out_channel channels.
        self.skip_proj = True
        self.conv_sc = SpectralNorm(nn.Conv2d(in_channel, out_channel, 1, 1, 0))

        if bn:
            self.CBNorm1 = ConditionalNorm(in_channel, n_class)  # TODO 2 x noise.size[1]
            self.CBNorm2 = ConditionalNorm(out_channel, n_class)

    def forward(self, x, condition=None):
        """Apply the block to ``x`` of shape (B*T, C, W, H).

        ``condition`` holds one row per frame and is only used when
        ``self.bn`` is set.
        """
        # The time dimension is combined with the batch dimension here, so each frame proceeds
        # through the blocks independently
        BT, C, W, H = x.size()
        out = x

        if self.bn:
            out = self.CBNorm1(out, condition)

        out = self.activation(out)

        if self.upsample_factor != 1:
            out = F.interpolate(out, scale_factor=self.upsample_factor)

        out = self.conv0(out)

        if self.bn:
            out = out.view(BT, -1, W * self.upsample_factor, H * self.upsample_factor)
            out = self.CBNorm2(out, condition)

        out = self.activation(out)
        out = self.conv1(out)

        if self.downsample_factor != 1:
            out = F.avg_pool2d(out, self.downsample_factor)

        if self.skip_proj:
            skip = x
            if self.upsample_factor != 1:
                skip = F.interpolate(skip, scale_factor=self.upsample_factor)
            skip = self.conv_sc(skip)
            if self.downsample_factor != 1:
                skip = F.avg_pool2d(skip, self.downsample_factor)
        else:
            skip = x

        y = out + skip
        y = y.view(
            BT, -1,
            W * self.upsample_factor // self.downsample_factor,
            H * self.upsample_factor // self.downsample_factor
        )

        return y


## ConvGRUCell

In [None]:
class ConvGRUCell(nn.Module):
    """
    Generate a convolutional GRU cell

    All three gates are ``Conv2d(input_size + hidden_size -> hidden_size)``
    with ``padding = kernel_size // 2``, so odd kernels keep H and W.

    Args:
        input_size: channels of the input ``x``.
        hidden_size: channels of the hidden state.
        kernel_size: int kernel size of the gate convolutions.
        activation: non-linearity for the update / reset gates.
    """

    def __init__(self, input_size, hidden_size, kernel_size, activation=torch.sigmoid):

        super().__init__()
        padding = kernel_size // 2
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.reset_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding)
        self.update_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding)
        self.out_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=padding)
        self.activation = activation

        init.orthogonal_(self.reset_gate.weight)
        init.orthogonal_(self.update_gate.weight)
        init.orthogonal_(self.out_gate.weight)
        init.constant_(self.reset_gate.bias, 0.)
        init.constant_(self.update_gate.bias, 0.)
        init.constant_(self.out_gate.bias, 0.)

    def forward(self, x, prev_state=None):
        """Advance the cell by one step.

        Args:
            x: input of shape (batch, input_size, height, width).
            prev_state: hidden state (batch, hidden_size, height, width);
                a zero state is used when None.

        Returns:
            The new hidden state.
        """
        if prev_state is None:
            # Zero state on x's device and dtype, so the cell also works when
            # the model runs on CPU while CUDA happens to be available.
            prev_state = x.new_zeros((x.size(0), self.hidden_size, *x.shape[2:]))

        # data size is [batch, channel, height, width]
        stacked_inputs = torch.cat([x, prev_state], dim=1)

        update = self.activation(self.update_gate(stacked_inputs))
        reset = self.activation(self.reset_gate(stacked_inputs))
        out_inputs = torch.tanh(self.out_gate(torch.cat([x, prev_state * reset], dim=1)))
        new_state = prev_state * (1 - update) + out_inputs * update

        return new_state


class ConvGRU(nn.Module):

    def __init__(self, input_size, hidden_sizes, kernel_sizes, n_layers):
        """
        Generates a multi-layer convolutional GRU.

        Parameters
        -------
        input_size : integer. depth dimension of input tensors.
        hidden_sizes : integer, list or tuple. depth dimensions of hidden state.
                      if integer, the same hidden size is used for all cells.
        kernel_sizes : integer, list or tuple. sizes of Conv2d gate kernels.
                      if integer, the same kernel size is used for all cells.
        n_layers : integer. number of chained `ConvGRUCell`.
        """

        super().__init__()

        self.input_size = input_size

        if isinstance(hidden_sizes, (list, tuple)):
            assert len(hidden_sizes) == n_layers, '`hidden_sizes` must have the same length as n_layers'
            self.hidden_sizes = list(hidden_sizes)
        else:
            self.hidden_sizes = [hidden_sizes] * n_layers
        if isinstance(kernel_sizes, (list, tuple)):
            assert len(kernel_sizes) == n_layers, '`kernel_sizes` must have the same length as n_layers'
            self.kernel_sizes = list(kernel_sizes)
        else:
            self.kernel_sizes = [kernel_sizes] * n_layers

        self.n_layers = n_layers

        # Layer 0 reads the input; every later layer reads the hidden state of
        # the layer below it.
        input_dims = [self.input_size] + self.hidden_sizes[:-1]
        self.cells = nn.ModuleList([
            ConvGRUCell(in_dim, hidden, kernel)
            for in_dim, hidden, kernel in zip(input_dims, self.hidden_sizes, self.kernel_sizes)
        ])

    def forward(self, x, hidden=None):
        '''
        Parameters
        -------
        x : 4D input tensor. (batch, channels, height, width).
        hidden : list of 4D hidden state representations, one per layer
                 (batch, channels, height, width), or None for zero states.

        Returns
        -------
        list of n_layers 4D tensors (batch, hidden_sizes[i], height, width):
        the updated hidden state of each layer, bottom to top.
        '''
        if hidden is None:
            hidden = [None] * self.n_layers

        output = []
        input_ = x
        for i, cell in enumerate(self.cells):
            # Each layer consumes the freshly updated state of the layer below.
            input_ = cell(input_, hidden[i])
            output.append(input_)

        # retain tensors in list to allow different hidden sizes
        return output



## Generator

In [None]:
class Generator(nn.Module):
    """DVD-GAN style conditional video generator.

    A (noise, class) pair is projected to a ``latent_dim x latent_dim`` map.
    ConvGRU stages (recurrence over ``n_frames`` steps) then alternate with
    per-frame GResBlocks (four of which upsample by 2). The result is
    ``n_frames`` 3-channel frames in [-1, 1] (tanh).
    """

    def __init__(self, in_dim=120, latent_dim=4, n_class=4, ch=32, n_frames=48, hierar_flag=False):
        '''
        Parameters
        -------
        in_dim : size of the noise vector; the class embedding has the same size
        latent_dim : height/width of the initial feature map
        n_class : number of classes in the class-embedding table
        ch : channel-width multiplier of the internal layers (the produced
             video always has 3 channels)
        n_frames : no of frames for produced video
        hierar_flag : split ``x`` into ``in_dim`` chunks and feed only the
             first to the initial projection. NOTE(review): the GResBlock
             conditioning below concatenates the whole split result and raises
             when this is True -- hierarchical latents are not implemented.
        '''
        super().__init__()

        self.in_dim = in_dim
        self.latent_dim = latent_dim
        self.n_class = n_class
        self.ch = ch
        self.hierar_flag = hierar_flag
        self.n_frames = n_frames

        self.embedding = nn.Embedding(n_class, in_dim)

        # (attribute name kept as-is: it is part of saved state_dict keys)
        self.affine_transfrom = nn.Linear(in_dim * 2, latent_dim * latent_dim * 8 * ch)

        self.conv = nn.ModuleList([
            ConvGRU(8 * ch, hidden_sizes=[8 * ch, 16 * ch, 8 * ch], kernel_sizes=[3, 5, 3], n_layers=3),
            GResBlock(8 * ch, 8 * ch, n_class=in_dim * 2, upsample_factor=1),
            GResBlock(8 * ch, 8 * ch, n_class=in_dim * 2),
            ConvGRU(8 * ch, hidden_sizes=[8 * ch, 16 * ch, 8 * ch], kernel_sizes=[3, 5, 3], n_layers=3),
            GResBlock(8 * ch, 8 * ch, n_class=in_dim * 2, upsample_factor=1),
            GResBlock(8 * ch, 8 * ch, n_class=in_dim * 2),
            ConvGRU(8 * ch, hidden_sizes=[8 * ch, 16 * ch, 8 * ch], kernel_sizes=[3, 5, 3], n_layers=3),
            GResBlock(8 * ch, 8 * ch, n_class=in_dim * 2, upsample_factor=1),
            GResBlock(8 * ch, 4 * ch, n_class=in_dim * 2),
            ConvGRU(4 * ch, hidden_sizes=[4 * ch, 8 * ch, 4 * ch], kernel_sizes=[3, 5, 5], n_layers=3),
            GResBlock(4 * ch, 4 * ch, n_class=in_dim * 2, upsample_factor=1),
            GResBlock(4 * ch, 2 * ch, n_class=in_dim * 2)
        ])

        # TODO impl ScaledCrossReplicaBatchNorm

        self.colorize = SpectralNorm(nn.Conv2d(2 * ch, 3, kernel_size=(3, 3), padding=1))

    def forward(self, x, class_id):
        '''
        Parameters
        -------
        x : noise of shape (B, in_dim)
        class_id : class indices of shape (B,)

        Return
        -------
        y : video tensor of shape (B, n_frames, 3, W, H)
        '''
        if self.hierar_flag is True:
            noise_emb = torch.split(x, self.in_dim, dim=1)
        else:
            noise_emb = x

        class_emb = self.embedding(class_id)

        if self.hierar_flag is True:
            y = self.affine_transfrom(torch.cat((noise_emb[0], class_emb), dim=1))  # B x (2 x ld x ch)
        else:
            y = self.affine_transfrom(torch.cat((noise_emb, class_emb), dim=1))  # B x (2 x ld x ch)

        y = y.view(-1, 8 * self.ch, self.latent_dim, self.latent_dim)  # B x ch x ld x ld

        # One condition row per frame. Frames are laid out batch-major
        # (row b * n_frames + t) after each ConvGRU stage, so every sample's
        # row is repeated n_frames times in place. The previous
        # `repeat(n_frames, 1)` tiled the whole batch instead, conditioning
        # frames on other samples' noise/class.
        condition = torch.cat([noise_emb, class_emb], dim=1)
        condition = condition.repeat_interleave(self.n_frames, dim=0)

        for k, conv in enumerate(self.conv):
            if isinstance(conv, ConvGRU):

                if k > 0:
                    # (B*T, C, W, H) -> (B, T, C, W, H)
                    _, C, W, H = y.size()
                    y = y.view(-1, self.n_frames, C, W, H).contiguous()

                # Unroll the recurrence over time. The first stage has no
                # per-frame input yet, so it receives the same map every step.
                hidden = None
                top_states = []
                for i in range(self.n_frames):
                    frame_in = y if k == 0 else y[:, i]
                    hidden = conv(frame_in, hidden)
                    top_states.append(hidden[-1])  # top layer: B x ch x ld x ld

                y = torch.stack(top_states, dim=1)  # B x T x ch x ld x ld
                B, T, C, W, H = y.size()
                y = y.view(-1, C, W, H)  # (B*T) x ch x ld x ld, batch-major

            elif isinstance(conv, GResBlock):
                y = conv(y, condition)  # BT, C, W, H

        y = F.relu(y)
        y = self.colorize(y)
        y = torch.tanh(y)

        BT, C, W, H = y.size()
        y = y.view(-1, self.n_frames, C, W, H)  # B, T, C, W, H

        return y

In [None]:
# Release cached, currently unused GPU memory held by PyTorch's allocator.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# SlowFast Model

### Input Transform Settings

In [None]:
##########################
##  SlowFast transform  ##
##########################

# Frame geometry: the short side of each frame is resized to `side_size`.
side_size = 256
crop_size = 256
# Per-channel normalization statistics (zero mean / unit std = identity).
mean = [0] * 3
std = [1] * 3
# Temporal settings: frames kept per clip; `sampling_rate` and
# `frames_per_second` only enter `clip_duration`; `alpha` is the
# fast-to-slow frame ratio used by PackPathway.
num_frames = 32
sampling_rate = 2
frames_per_second = 30
alpha = 4

class PackPathway(torch.nn.Module):
    """
    Transform for converting video frames as a list of tensors.

    Returns ``[slow, fast]`` SlowFast inputs: ``fast`` is the clip unchanged;
    ``slow`` keeps ``T // alpha`` frames (at least one), evenly spaced along
    dim 1.

    Args:
        alpha: fast/slow frame ratio. When None, the module-level ``alpha``
            setting is read at call time.
    """
    def __init__(self, alpha=None):
        super().__init__()
        self.alpha = alpha

    def forward(self, frames: torch.Tensor):
        # frames: (C, T, H, W), as produced by the transform pipeline in this
        # notebook -- TODO confirm for other callers.
        step = self.alpha if self.alpha is not None else alpha
        fast_pathway = frames
        # Never produce an empty slow pathway for clips shorter than alpha.
        n_slow = max(1, frames.shape[1] // step)
        # Build the index on the clip's device so CUDA inputs also work.
        index = torch.linspace(
            0, frames.shape[1] - 1, n_slow, device=frames.device
        ).long()
        # Perform temporal sampling from the fast pathway.
        slow_pathway = torch.index_select(frames, 1, index)
        return [slow_pathway, fast_pathway]
  
# Per-clip preprocessing for the SlowFast student, applied to the "video"
# key: resample to `num_frames` frames, normalize with `mean`/`std`, resize
# the short side to `side_size`, then split into [slow, fast] pathways.
_video_pipeline = Compose([
    UniformTemporalSubsample(num_frames),
    NormalizeVideo(mean, std),
    ShortSideScale(size=side_size),
    PackPathway(),
])
transform = ApplyTransformToKey(key="video", transform=_video_pipeline)

# The duration of the input clip is also specific to the model.
clip_duration = num_frames * sampling_rate / frames_per_second

## Training Loop


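Tensor shapes used below:

- generator output: `[batch, n_frames, 3, 64, 64]` (`n_frames` is 5 in the training cell)
- TensorFlow model input (not used in the cells below): `[batch, 8, 224, 224, 3]`
- PyTorch SlowFast student input: a `[slow, fast]` list per clip, with shapes `[1, 3, 8, 256, 256]` and `[1, 3, 32, 256, 256]`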

In [None]:
!cp -r /content/drive/MyDrive/test /content/

In [None]:
!pip install -q pytorch_ranger

In [None]:
import numpy as np
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from pytorch_ranger import Ranger

# Free cached, currently unused GPU memory before training starts.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
# Only NumPy's global RNG is seeded; torch's RNGs are left unseeded.
np.random.seed(1)

# Let cuDNN auto-tune convolution algorithms.
torch.backends.cudnn.benchmark = True

In [None]:
# Mount Google Drive: the training cell resumes from and saves checkpoints
# under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
def load_state(wt_path, map_location=None):
  """Restore generator/student weights and optimizer states from a checkpoint.

  Loads the dict written by the training loop (keys ``generator_state``,
  ``student_state``, ``gen_optimizer``, ``stud_optimizer``, ``datapoints``)
  into the module-level ``generator``, ``student``, ``gen_opt`` and
  ``stud_opt``.

  Args:
    wt_path: path of the ``.pth`` checkpoint.
    map_location: forwarded to ``torch.load`` (e.g. ``'cpu'``).

  Returns:
    Number of datapoint iterations already completed. The checkpoint stores
    the 0-based index of the last *finished* datapoint, so this is that
    index + 1 -- the first datapoint the training loop still has to run.
  """
  checkpoint = torch.load(wt_path, map_location=map_location)
  generator.load_state_dict(checkpoint['generator_state'])
  student.load_state_dict(checkpoint['student_state'])
  gen_opt.load_state_dict(checkpoint['gen_optimizer'])
  stud_opt.load_state_dict(checkpoint['stud_optimizer'])
  return checkpoint['datapoints'] + 1

In [None]:
import os

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
black_box_model = black_box_model.to(device)
num_epochs = 20
gen_iter = 1
stud_iter = 3
n_datapoints = 100

batch_size = 4
in_dim = 120
n_class = 400
n_frames = 5

Loss = nn.L1Loss()
scaler = torch.cuda.amp.GradScaler()
save_dir = '/content/drive/MyDrive/SwinT Weights'

# Initializing Generator.
# torch.randint's `high` is exclusive: use n_class so every label can occur.
class_label = torch.randint(low=0, high=n_class, size=(batch_size,)).to(device)
generator = Generator(in_dim, n_class=n_class, ch=3, n_frames=n_frames).to(device)

model_name = "slowfast_r50"
student = torch.hub.load("facebookresearch/pytorchvideo", model=model_name, pretrained=True)
# For kinetics 400: replace the projection with a fresh n_class-way head.
student.blocks[6].proj = nn.Linear(in_features=2304, out_features=n_class, bias=True)
student = student.to(device)
black_box_model.eval()
generator.requires_grad_(True)
student.requires_grad_(True)
black_box_model.requires_grad_(False)  # the black box is only queried

gen_opt = Ranger(generator.parameters(), lr=5e-4, weight_decay=0.0001)
stud_opt = Ranger(student.parameters(), lr=8e-5, weight_decay=0.0001)

gen_loss_per_dp = []
stud_loss_per_dp = []


def _query_models(videos):
    """Return (victim_logits, student_logits) for generated ``videos``.

    ``videos`` is the generator output (B, T, C, H, W). The black box sees the
    whole batch as (B, C, T, H, W) without gradients; the student sees one
    clip at a time after ``transform`` ([slow, fast] pathways), so gradients
    flow from the student logits back into ``videos``.
    """
    clips = videos.permute(0, 2, 1, 3, 4)
    with torch.no_grad():
        pred_vic = black_box_model(clips)
    pred_att = []
    for clip in clips.cpu():
        pathways = transform({"video": clip})["video"]
        pred_att.append(student([p.to(device)[None, ...] for p in pathways]))
    return pred_vic, torch.cat(pred_att, 0).to(device)


def _scaled_step(optimizer, params, loss, max_norm):
    """Backprop ``loss`` via the shared GradScaler and take one clipped step."""
    scaler.scale(loss).backward()
    # Unscale first so the clipping threshold applies to the real gradients,
    # not to gradients multiplied by the loss scale.
    scaler.unscale_(optimizer)
    nn.utils.clip_grad_norm_(params, max_norm=max_norm)
    scaler.step(optimizer)
    scaler.update()


# torch.autograd.detect_anomaly()
dp_done = load_state('/content/drive/MyDrive/model_dp56.pth')
# dp_done = 0  # start from scratch instead of resuming
os.makedirs(save_dir, exist_ok=True)
tk0 = tqdm(range(dp_done, n_datapoints), desc='Datapoints')
for datapoints in tk0:
    print('x' * 80)
    print('x' * 80)
    batch = torch.randn(batch_size, in_dim).to(device)  # fix this
    gen_epoch_loss = []
    stud_epoch_loss = []
    tk1 = tqdm(range(num_epochs), desc='Epochs')
    for epoch in tk1:
        print("=" * 80)
        print(f'Starting Epoch: {epoch+1} / {num_epochs}')
        print("=" * 80)

        # Generator phase: maximize the student / black-box disagreement.
        student.eval()
        generator.train()
        gen_loss_arr = []
        tk2 = tqdm(range(gen_iter), desc='Generator training...')
        for _ in tk2:
            # Clear stale gradients; without this they accumulate over steps.
            gen_opt.zero_grad()
            y = generator(batch, class_label)
            ## Prediction API Output ([batch_size, 400])
            pred_vic, pred_att = _query_models(y)
            gen_loss = (100 - Loss(pred_att, pred_vic)) / 100
            _scaled_step(gen_opt, generator.parameters(), gen_loss, max_norm=1.0)

            torch.cuda.empty_cache()
            gen_loss_arr.append(gen_loss.item())
            tk2.set_postfix(loss=gen_loss.item())

        gen_epoch_loss.append(np.mean(gen_loss_arr))

        # Student phase: imitate the black box on generated clips.
        student.train()
        generator.eval()
        stud_loss_arr = []
        tk3 = tqdm(range(stud_iter), desc='Student training...')
        for _ in tk3:
            # Also discards student gradients left by the generator phase.
            stud_opt.zero_grad()
            # The generator is not updated here, so skip building its graph.
            with torch.no_grad():
                y = generator(batch, class_label)
            ## Prediction API Output ([batch_size, 400])
            pred_vic, pred_att = _query_models(y)
            stud_loss = Loss(pred_att, pred_vic)
            _scaled_step(stud_opt, student.parameters(), stud_loss, max_norm=2.0)

            torch.cuda.empty_cache()
            stud_loss_arr.append(stud_loss.item())
            tk3.set_postfix(loss=stud_loss.item())
        stud_epoch_loss.append(np.mean(stud_loss_arr))
        torch.cuda.empty_cache()

        print(f'Student Loss: {stud_epoch_loss[-1]} \t Generator Loss: {gen_epoch_loss[-1]}')
    gen_loss_per_dp.append(np.mean(gen_epoch_loss))
    stud_loss_per_dp.append(np.mean(stud_epoch_loss))
    print("x" * 80)
    print("Finished another datapoint iteration....Loading metrics....")
    print(f'Student Loss: {stud_loss_per_dp[-1]} \t Generator Loss: {gen_loss_per_dp[-1]}')
    print("x" * 80)
    plt.plot(stud_epoch_loss, label="Student loss")
    plt.plot(gen_epoch_loss, label="Generator loss")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()
    save_state = {'datapoints': datapoints, 'student_state': student.state_dict(),
                  'generator_state': generator.state_dict(), 'stud_optimizer': stud_opt.state_dict(),
                  'gen_optimizer': gen_opt.state_dict()}
    print("Saving model...")
    torch.save(save_state, os.path.join(save_dir, f'model_dp{datapoints}.pth'))
plt.plot(stud_loss_per_dp, label="Student loss")
plt.plot(gen_loss_per_dp, label="Generator loss")
plt.xlabel("DataPoints")
plt.ylabel("Loss")
plt.legend()
plt.show()