**Transfer Learning example with CIFAR10 Dataset using PyTorch and ResNet50**

In [1]:
import torch
import numpy as np
from torchvision import datasets 
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

In [2]:
#Transform operation to normalized torch.FloatTensor
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [3]:
#Downloading datasets
train_data = datasets.CIFAR10('data', train=True, download=True, transform=transform)
test_data = datasets.CIFAR10('data', train=False, download=True, transform=transform)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting data/cifar-10-python.tar.gz to data
Files already downloaded and verified


In [4]:
#Downloading the model
model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet50', pretrained=True)

Downloading: "https://github.com/pytorch/vision/archive/v0.6.0.zip" to /root/.cache/torch/hub/v0.6.0.zip
Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to /root/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth


HBox(children=(FloatProgress(value=0.0, max=102502400.0), HTML(value='')))




In [5]:
#Spliting data
len_train=len(train_data)
index = list(range(len_train))
np.random.shuffle(index)
split = int(np.floor(0.2*len_train)) 
train_idx, valid_idx = index[split:], index[:split]

In [6]:
#Samplers and Dataloaders
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

train_loader = torch.utils.data.DataLoader(train_data, batch_size=30,sampler = train_sampler)
valid_loader = torch.utils.data.DataLoader(train_data, batch_size=30,sampler = valid_sampler)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=30,sampler = train_sampler)

In [7]:
cls = ['airplane', 'automobile','bird','cat','deer','dog','frog','horse','ship','truck']

In [8]:
for param in model.parameters():
  param.requires_grad=False


In [9]:
print(model.fc)

Linear(in_features=2048, out_features=1000, bias=True)


In [10]:
model.fc = nn.Sequential(nn.Linear(2048,1000), nn.ReLU(),nn.Dropout(0.4),nn.Linear(1000,10))

In [11]:
#Sending the model to CUDA
if torch.cuda.is_available():model.cuda()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.fc.parameters(),lr=0.001)

In [13]:
n_epochs = 30

valid_loss_min = np.Inf

for epoch in range(n_epochs):
  train_loss = 0
  valid_loss = 0

  model.train()
  for data, target in train_loader:
    if torch.cuda.is_available(): data, target = data.cuda(), target.cuda()
    optimizer.zero_grad()
    out=model(data)
    loss = criterion(out, target)
    loss.backward()
    optimizer.step()
    train_loss +=loss.item()*data.size(0)

  model.eval()
  for data, target in valid_loader:
    if torch.cuda.is_available(): data, target = data.cuda(), target.cuda()
    out=model(data)
    loss = criterion(out, target)
    valid_loss +=loss.item()*data.size(0)

  
  train_loss = train_loss/len(train_loader.sampler)
  valid_loss = valid_loss/len(valid_loader.sampler)

  print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:6f}'.format(epoch, train_loss, valid_loss))

  if valid_loss <= valid_loss_min: print('Saving. Valid loss {:.6f}'.format(valid_loss))


Epoch: 0 	Training Loss: 1.000123 	Validation Loss: 0.693187
Saving. Valid loss 0.693187
Epoch: 1 	Training Loss: 0.840692 	Validation Loss: 0.677902
Saving. Valid loss 0.677902
Epoch: 2 	Training Loss: 0.811567 	Validation Loss: 0.645140
Saving. Valid loss 0.645140
Epoch: 3 	Training Loss: 0.787260 	Validation Loss: 0.623511
Saving. Valid loss 0.623511
Epoch: 4 	Training Loss: 0.774209 	Validation Loss: 0.627023
Saving. Valid loss 0.627023
Epoch: 5 	Training Loss: 0.753063 	Validation Loss: 0.626805
Saving. Valid loss 0.626805
Epoch: 6 	Training Loss: 0.738036 	Validation Loss: 0.642504
Saving. Valid loss 0.642504
Epoch: 7 	Training Loss: 0.724547 	Validation Loss: 0.624163
Saving. Valid loss 0.624163
Epoch: 8 	Training Loss: 0.720632 	Validation Loss: 0.607064
Saving. Valid loss 0.607064
Epoch: 9 	Training Loss: 0.717522 	Validation Loss: 0.603427
Saving. Valid loss 0.603427
Epoch: 10 	Training Loss: 0.705051 	Validation Loss: 0.619655
Saving. Valid loss 0.619655
Epoch: 11 	Training 