In [83]:
import torch
import torch.optim as optim
import torch.nn as nn
import numpy as np
from torchvision import models
from torchvision import datasets
import matplotlib.pyplot as plt
import math
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision.models as models
import torchvision
import torchvision.transforms as transforms
import os
import argparse
from torchvision.models import vgg16
from torchvision.models import vgg19
from grad_cam  import *

In [64]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy seen so far; test() updates it

# Data
print('Data transformation')

# Resize-to-224 preprocessing with its own mean/std constants.
# NOTE(review): nothing in this cell uses `transform`; it is kept in case
# another cell references it.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Training pipeline: random 32x32 crop (4px padding) + horizontal flip,
# then tensor conversion and per-channel normalisation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Test pipeline: same normalisation as training, no augmentation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)

# Sample indices grouped by class. torch.nonzero returns an (N, 1) tensor, so
# the result is flattened and converted to plain ints: Subset then indexes the
# dataset with integers instead of one-element tensors. The targets tensor is
# built once rather than once per class.
train_targets = torch.tensor(trainset.targets)
class_inds = [
    torch.nonzero(train_targets == class_idx).flatten().tolist()
    for class_idx in trainset.class_to_idx.values()
]

# One DataLoader per class (ordered like trainset.class_to_idx), so every
# batch drawn from trainloader[k] contains samples of class k only.
trainloader = [
    torch.utils.data.DataLoader(
        dataset=torch.utils.data.Subset(trainset, inds),
        batch_size=128,
        shuffle=True,
        drop_last=False,
        num_workers=2)
    for inds in class_inds]

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

Data transformation
Files already downloaded and verified
Files already downloaded and verified


In [65]:
# Student network (optimised below) and teacher network (only used as the
# reference model passed to Grad_cam_loss).
# NOTE(review): both models keep their pretrained classification heads while
# the data cell loads CIFAR-10 -- confirm whether the final layer should be
# replaced with a 10-class head.
net = models.resnet50(pretrained=True)
teacher = models.resnet18(pretrained=True)
#teacher=VGG_19()
#net=VGG_16()

# Layers whose activations are handed to the Grad-CAM helper for each model.
teacher_target_layers = [teacher.layer4[-1]]
student_target_layers = [net.layer4[-1]]

# Was misspelled `earning_rate`, which made the optimizer line below fail with
# NameError on a fresh kernel.
learning_rate = 1e-3
momentum = 0.9

loss = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=momentum)
#optimizer=optim.Adam(net.parameters(),lr=learning_rate)

net = net.to(device)
# The teacher receives the same device-resident inputs as the student in
# train(), so it must live on the same device. It is never optimised, so it is
# put in eval mode. Its parameters are deliberately NOT frozen: the Grad-CAM
# helper presumably needs gradients through the teacher -- TODO confirm.
teacher = teacher.to(device)
teacher.eval()

loss_sum = 0

In [66]:
def train(epoch):
    """Earlier, per-sample draft of the training loop (one optimizer step per image).

    NOTE(review): this notebook defines `train` again in a later cell; once
    both cells have run, the later definition is the one in effect.

    NOTE(review): as written this draft does not match the rest of the notebook:
      * `trainloader` is built as a list of per-class DataLoaders, so the
        `(inputs, targets)` unpacking below does not receive batches;
      * `transform1` only appears in a commented-out line;
      * `Grad_cam_loss` as defined in this notebook takes six required
        positional arguments, while three are passed here.
    `Get_Model` and `get_transformed_image` are not defined in this notebook;
    presumably they come from `from grad_cam import *` -- TODO confirm.

    Args:
        epoch: epoch number; only used in the progress message.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    for batch_idx, (inputs, targets) in enumerate(trainloader):
      inputs, targets = inputs.to(device), targets.to(device)
      # Walk the batch one sample at a time; unsqueeze(0) restores the batch axis.
      for idx in range(len(inputs)):
        input=inputs[idx].unsqueeze(0)
        target=targets[idx].unsqueeze(0)
        #print("Batch Index",batch_idx)        
        # Write your code here
        out = net(input)
        # NOTE(review): the wrappers are rebuilt for every sample although
        # their arguments do not change inside the loop.
        t=Get_Model("VGG19",teacher)
        s=Get_Model("VGG16",net)
        # NOTE(review): `image` is never read afterwards.
        image = get_transformed_image(input, transform1)
        # Combined objective: classification loss plus the Grad-CAM term.
        l=loss(out,target)+Grad_cam_loss(input,t,s)
        print("loss:",l.item())
        #loss_sum+=loss
        optimizer.zero_grad()
        l.backward()
        optimizer.step()

In [78]:
def train(epoch):
    """Train the student `net` for one pass over every per-class loader.

    `trainloader` is a list with one DataLoader per class, so every batch holds
    samples of a single class. The label of the first sample is therefore used
    as the Grad-CAM target category for the whole batch.

    For each batch the objective is the classification loss plus
    `Grad_cam_loss` between `teacher` and `net`; one optimizer step is taken
    per batch and the batch loss is printed.

    Args:
        epoch: epoch number; only used in the progress message.

    Relies on the notebook-level names `net`, `teacher`, `trainloader`,
    `device`, `loss`, `optimizer`, `teacher_target_layers`,
    `student_target_layers` and `Grad_cam_loss`.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    # CUDA availability cannot change mid-epoch; query it once.
    use_cuda = torch.cuda.is_available()
    for class_loader in trainloader:
        for inputs, targets in class_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            out = net(inputs)
            # All targets in a per-class batch are equal; take the first.
            category = targets[0]
            l = loss(out, targets) + Grad_cam_loss(
                inputs, teacher, net,
                teacher_target_layers, student_target_layers,
                category, use_cuda, True)
            print("loss:", l.item())
            optimizer.zero_grad()
            l.backward()
            optimizer.step()

In [68]:
def test(epoch):
    """Evaluate `net` on `testloader` and checkpoint it when accuracy improves.

    Runs the whole test loader under `torch.no_grad()`, counts top-1 hits and
    compares the resulting accuracy (in percent) with the module-level
    `best_acc`. On improvement, `best_acc` is updated and a checkpoint is
    written to 'ckpt.pth'.

    Args:
        epoch: epoch number stored in the checkpoint.

    Returns:
        The test accuracy in percent.

    Relies on the notebook-level names `net`, `testloader`, `device`, `loss`,
    `optimizer` and `best_acc`.
    """
    global best_acc
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            out = net(inputs)
            l = loss(out, targets)
            _, prediction = torch.max(out, dim=1)
            correct += (prediction == targets).sum().item()
            total += targets.size(0)
    accuracy = (correct / total) * 100

    # Save checkpoint for the model which yields best accuracy
    if accuracy > best_acc:
        print("Saving checkpoint with accuracy = ", accuracy)
        best_acc = accuracy
        torch.save({
            'epoch': epoch,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # Loss of the LAST test batch only (not an average). Moved to CPU
            # so the checkpoint can be loaded on a machine without CUDA.
            'loss': l.cpu(),
            'accuracy': best_acc
        }, 'ckpt.pth')
    return accuracy

In [69]:
def Grad_cam_loss(img,teacher,student,teacher_target_layers,student_target_layers,target_category,cuda=False,isTensor=False):
    """Sum of element-wise absolute differences between teacher and student heatmaps.

    Calls `heatmap` once for the teacher and once for the student, then, for
    every target layer and every image in `img`, adds the L1 distance between
    the two maps.

    The previous formula, `sum(abs(sum(t - s)))`, added the signed differences
    along the first axis before taking the absolute value, so errors of
    opposite sign cancelled and clearly different maps could score 0. The
    absolute value is now taken per element.

    Args:
        img: batch of images; `len(img)` gives the number of maps per layer.
        teacher, student: models passed through to `heatmap`.
        teacher_target_layers, student_target_layers: equally long lists of
            layers passed through to `heatmap`.
        target_category: passed through to `heatmap`.
        cuda, isTensor: passed through to `heatmap`.

    Returns:
        The accumulated loss; the int 0 when there are no layers or images.

    Raises:
        AssertionError: if the two layer lists differ in length.

    NOTE(review): `heatmap` is not defined in this notebook (presumably it
    comes from `from grad_cam import *`). It is assumed to return an object
    indexable as [layer][image] whose entries support `-`, `abs()` and
    `.sum()` (NumPy arrays or torch tensors) -- TODO confirm.
    """
    assert len(teacher_target_layers) == len(student_target_layers), 'Number of student model and teacher model target layers must be same.'
    t_hm = heatmap(img, teacher, teacher_target_layers, target_category, cuda=cuda, isTensor=isTensor)
    s_hm = heatmap(img, student, student_target_layers, target_category, cuda=cuda, isTensor=isTensor)
    total_loss = 0
    for layer_idx in range(len(teacher_target_layers)):
        for img_idx in range(len(img)):
            diff = t_hm[layer_idx][img_idx] - s_hm[layer_idx][img_idx]
            total_loss += abs(diff).sum()
    return total_loss

In [70]:
def train_test(num_epochs=1):
    """Run `num_epochs` rounds of train() followed by test().

    The module-level `best_acc` is reset first, so a re-run in the same kernel
    behaves like a run on a fresh kernel. The previous `best_acc=0` only bound
    an unused local and left the value tracked by test() untouched.

    Args:
        num_epochs: number of train/test rounds; defaults to 1, the previously
            hard-coded value.
    """
    global best_acc
    best_acc = 0
    for epoch in range(num_epochs):
        print("Training")
        train(epoch)
        print("Testing")
        test(epoch)