# Transfer Learning on the CIFAR-10 Dataset with Pretrained ResNet-18 Weights in PyTorch

In [0]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision
from torchvision import datasets, models, transforms

# Transform the CIFAR-10 images to match the input format of the ImageNet-pretrained model

In [0]:
# Number of images per mini-batch, shared by the training and validation loaders.
batch_size = 50

# Per-channel mean / std used for normalization; both pipelines below must use
# the same values so training and validation inputs are on the same scale.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Training pipeline: resize to 224, random horizontal and vertical flips for
# augmentation, then convert to a tensor and normalize.
train_transform = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std),
    ]
)

# Validation pipeline: same resize and normalization, but no random
# augmentation so evaluation is deterministic.
valid_transform = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std),
    ]
)

In [3]:
# CIFAR-10 training split, stored under ./data (downloaded if missing) and
# passed through the augmenting training transform.
train_data = datasets.CIFAR10(
    root='./data',
    train=True,
    transform=train_transform,
    download=True,
)

Files already downloaded and verified


In [4]:
# CIFAR-10 test split (train=False), used here as the validation set and
# passed through the non-augmenting validation transform.
valid_data = datasets.CIFAR10(
    root='./data',
    train=False,
    transform=valid_transform,
    download=True,
)

Files already downloaded and verified


# Data Loaders for PyTorch

In [0]:
# Training batches are reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

# Validation batches keep the dataset order; shuffling is not needed for evaluation.
valid_loader = torch.utils.data.DataLoader(
    valid_data,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)

In [0]:
# Use the first CUDA GPU when one is available; otherwise run on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Define the training function

In [0]:
def train_model(model, loss_function, optimizer, data_loader):
    """Train ``model`` for one pass over ``data_loader`` and report metrics.

    Args:
        model: Module to train; it is switched to training mode.
        loss_function: Callable ``(outputs, labels) -> loss``. The per-batch
            loss is weighted by the batch size, which assumes the loss is a
            batch mean (the default reduction of ``nn.CrossEntropyLoss``).
        optimizer: Optimizer whose gradients are reset and stepped per batch.
        data_loader: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        Tuple ``(average_loss, accuracy_percent)`` over the samples that were
        actually processed.

    Raises:
        ValueError: If ``data_loader`` yields no samples.
    """
    model.train()

    # Move batches to wherever the model lives so the function does not rely
    # on hidden state; fall back to the module-level `device` only for models
    # that have no parameters to inspect.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    running_loss = 0.0
    correct = 0
    seen = 0

    for inputs, labels in data_loader:
        inputs = inputs.to(target_device)
        labels = labels.to(target_device)

        optimizer.zero_grad()

        # Force gradient tracking on even if the caller disabled it.
        with torch.set_grad_enabled(True):
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()

        batch_count = inputs.size(0)
        running_loss += loss.item() * batch_count
        correct += (predicted == labels).sum().item()
        seen += batch_count

    if seen == 0:
        raise ValueError('data_loader yielded no samples')

    # Normalize by the samples seen, not len(dataset): the two differ when the
    # loader drops the last batch or samples only a subset.
    total_loss = running_loss / seen
    total_acc = 100.0 * correct / seen

    print('Train Loss: {:.4f}, Accuracy: {:.4f}%'.format(total_loss, total_acc))
    return total_loss, total_acc

# Define the validation function

In [0]:
def val_model(model, loss_function, data_loader):
    """Evaluate ``model`` on ``data_loader`` without updating it.

    Args:
        model: Module to evaluate; it is switched to evaluation mode.
        loss_function: Callable ``(outputs, labels) -> loss``. The per-batch
            loss is weighted by the batch size, which assumes the loss is a
            batch mean (the default reduction of ``nn.CrossEntropyLoss``).
        data_loader: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        Tuple ``(average_loss, accuracy_percent)`` over the samples that were
        actually processed.

    Raises:
        ValueError: If ``data_loader`` yields no samples.
    """
    model.eval()

    # Move batches to wherever the model lives so the function does not rely
    # on hidden state; fall back to the module-level `device` only for models
    # that have no parameters to inspect.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    running_loss = 0.0
    correct = 0
    seen = 0

    for inputs, labels in data_loader:
        inputs = inputs.to(target_device)
        labels = labels.to(target_device)

        # No gradients are needed for evaluation.
        with torch.no_grad():
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            loss = loss_function(outputs, labels)

        batch_count = inputs.size(0)
        running_loss += loss.item() * batch_count
        correct += (predicted == labels).sum().item()
        seen += batch_count

    if seen == 0:
        raise ValueError('data_loader yielded no samples')

    # Normalize by the samples seen, not len(dataset): the two differ when the
    # loader drops the last batch or samples only a subset.
    total_loss = running_loss / seen
    total_acc = 100.0 * correct / seen

    print('Validation Loss: {:.4f}, Accuracy: {:.4f}%'.format(total_loss, total_acc))
    return total_loss, total_acc

# Train the Model

In [11]:
# Total number of passes over the training set.
epoch = 3

# Classification loss shared by the training and validation steps.
loss_function = nn.CrossEntropyLoss()

# Start from ResNet-18 with pretrained weights.
model = models.resnet18(pretrained=True)

# Freeze every pretrained weight; only the replacement head created below
# will have trainable parameters.
for param in model.parameters():
    param.requires_grad_(False)

# Replace the final fully connected layer with a fresh 10-way classifier
# (CIFAR-10 has 10 classes). New layers require gradients by default.
num_features = model.fc.in_features
model.fc = nn.Linear(in_features=num_features, out_features=10)

model = model.to(device)

# Optimize the new head only; the frozen backbone is not handed to Adam.
optimizer = torch.optim.Adam(params=model.fc.parameters())

# One training pass followed by one validation pass per epoch.
for epch in range(epoch):
    print('Epoch {}/{}'.format(epch + 1, epoch))
    train_model(model, loss_function, optimizer, train_loader)
    val_model(model, loss_function, valid_loader)

Epoch 1/3
Train Loss: 1.0464, Accuracy: 64.1760%
Validation Loss: 0.7157, Accuracy: 75.6800%
Epoch 2/3
Train Loss: 0.8564, Accuracy: 69.9600%
Validation Loss: 0.7112, Accuracy: 75.5800%
Epoch 3/3
Train Loss: 0.8378, Accuracy: 70.6960%
Validation Loss: 0.6813, Accuracy: 76.7800%
