# Transfer Learning using a pre-trained Deep Neural Network

### Cifar-10 dataset (https://www.cs.toronto.edu/~kriz/cifar.html)

In [None]:
%matplotlib inline
import os
import torch
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import models,datasets,transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import copy
import time

In [None]:
# Detect CUDA once; the cells that build the network and run the loops branch on this flag.
use_gpu = torch.cuda.is_available()
# pinMem is handed to the DataLoaders as `pin_memory`; it is enabled exactly
# when a GPU is present.
pinMem = bool(use_gpu)
if use_gpu:
    print('GPU is available!')

### Downloading dataset

In [None]:
# Pre-processing applied to every CIFAR-10 image before it reaches the network:
#   1. Resize so the image matches the 224-px input the pre-trained AlexNet was trained on.
#   2. Convert the PIL image to a float tensor.
#   3. Normalise each channel with the ImageNet mean/std. The network below is
#      loaded with ImageNet-pretrained weights, which were fitted on inputs
#      normalised this way; feeding raw [0, 1] pixels shifts the input
#      distribution the pretrained filters expect.
apply_transforms = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

# Training split: shuffled every epoch so that each capped pass sees a
# different subset/order of mini-batches.
trainLoader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./CIFAR10/', train=True, download=True,
                     transform=apply_transforms),
    batch_size=8, shuffle=True, num_workers=1, pin_memory=pinMem)

# Test split: NOT shuffled. The evaluation loop only consumes the first few
# batches, so a fixed order is required for the reported accuracy to be
# measured on the same images every epoch. `download=True` is a no-op when the
# archive is already present and lets this loader be built on its own.
testLoader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./CIFAR10/', train=False, download=True,
                     transform=apply_transforms),
    batch_size=8, shuffle=False, num_workers=1, pin_memory=pinMem)

In [None]:
# Number of samples in each split, read from the datasets wrapped by the loaders
trainSetSize = len(trainLoader.dataset)
testSetSize = len(testLoader.dataset)
for caption, count in (('No. of samples in train set: ', trainSetSize),
                       ('No. of samples in test set: ', testSetSize)):
    print(caption + str(count))

### Initialize the network

In [None]:
# Start from AlexNet with pre-trained weights
net = models.alexnet(pretrained=True)
# Keep every layer of the original classifier except its last one ...
pretrained_layers = list(net.classifier.children())
new_classifier = nn.Sequential(*pretrained_layers[:-1])
# ... and append a fresh 4096 -> 10 linear layer (one output per CIFAR-10 class)
new_classifier.add_module('fc', nn.Linear(4096, 10))
net.classifier = new_classifier
print(net)
# Move the model to the GPU when one is available
net = net.cuda() if use_gpu else net

### Define loss function and optimizer

In [None]:
# Loss: negative log-likelihood; the training loop feeds it log-softmax outputs
criterion = nn.NLLLoss()
# Optimiser: stochastic gradient descent with momentum over all network parameters
optimizer = optim.SGD(net.parameters(), momentum=0.9, lr=1e-4)

### Train the network

In [None]:
# Fine-tune the network for a few epochs on a capped number of mini-batches and,
# after every epoch, measure accuracy on a capped number of test mini-batches.
iterations = 2           # number of epochs
maxTrainBatches = 1000   # training mini-batches consumed per epoch
maxTestBatches = 100     # test mini-batches evaluated per epoch
trainLoss = []           # mean training loss per sample, one entry per epoch
testAcc = []             # test accuracy as a fraction in [0, 1], one entry per epoch
start = time.time()
for epoch in range(iterations):
    epochStart = time.time()

    # ---------------- Training phase ----------------
    net.train(True)  # training mode
    runningLoss = 0.0
    trainSamples = 0
    for i, data in enumerate(trainLoader, 0):
        if i == maxTrainBatches:
            break
        inputs, labels = data
        if use_gpu:
            inputs, labels = inputs.cuda(), labels.cuda()
        # Initialize gradients to zero
        optimizer.zero_grad()
        # Feed-forward input data through the network
        outputs = net(inputs)
        # Compute loss/error (NLLLoss expects log-probabilities)
        loss = criterion(F.log_softmax(outputs, dim=1), labels)
        # Backpropagate loss and compute gradients
        loss.backward()
        # Update the network parameters
        optimizer.step()
        # `loss` is the mean over the batch, so weight it by the batch size
        # before accumulating; dividing by the number of samples seen then
        # yields a true per-sample average instead of one that is scaled down
        # by the batch size.
        batchSize = labels.size(0)
        runningLoss += loss.item() * batchSize
        trainSamples += batchSize

    # max(..., 1) guards against an empty loader
    avgTrainLoss = runningLoss / max(trainSamples, 1)
    trainLoss.append(avgTrainLoss)

    # ---------------- Evaluation phase ----------------
    net.train(False)  # evaluation mode
    running_correct = 0
    testSamples = 0
    # No gradients are needed for evaluation; skipping graph construction
    # saves memory and time.
    with torch.no_grad():
        for i, data in enumerate(testLoader, 0):
            if i == maxTestBatches:
                break
            inputs, labels = data
            if use_gpu:
                inputs = inputs.cuda()
            outputs = net(inputs)
            _, predicted = torch.max(outputs, 1)
            # Labels stay on the CPU, so bring the predictions back before comparing
            running_correct += (predicted.cpu() == labels).sum().item()
            testSamples += labels.size(0)
    # Divide by the number of samples actually evaluated rather than a
    # hard-coded count, so a short or empty loader cannot skew the result.
    avgTestAcc = running_correct / float(max(testSamples, 1))
    testAcc.append(avgTestAcc)

    epochEnd = time.time()-epochStart
    print('At Iteration: {:.0f} /{:.0f}  ;  Training Loss: {:.6f} ; Testing Acc: {:.3f} ; Time consumed: {:.0f}m {:.0f}s '\
          .format(epoch + 1,iterations,avgTrainLoss,avgTestAcc*100,epochEnd//60,epochEnd%60))
end = time.time()-start
print('Training completed in {:.0f}m {:.0f}s'.format(end//60,end%60))

# Draw each curve once, after training, instead of stacking a new copy of the
# whole history on the same axes every epoch.
# Plotting training loss vs Epochs
fig1 = plt.figure(1)
plt.plot(range(len(trainLoss)), trainLoss, 'r-', label='train')
plt.legend(loc='upper left')
plt.xlabel('Epochs')
plt.ylabel('Training loss')
# Plotting testing accuracy vs Epochs
fig2 = plt.figure(2)
plt.plot(range(len(testAcc)), testAcc, 'g-', label='test')
plt.legend(loc='upper left')
plt.xlabel('Epochs')
plt.ylabel('Testing accuracy')