##### Transfer Learning  
Transfer learning is technique that allows us to use learning from a different task in some other related tasks.  
For example, models trained on classifying cats and dogs can be used in classifying cars and bikes.  
Training models from scratch is computaionally expensive, transfer learning therefore helps to efficiently train models rather than training models froms cratch.  

In [13]:
import torch
import torch.nn as nn
import glob
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision.models as models
from sklearn.metrics import accuracy_score

# Suppress warnings
import warnings
warnings.filterwarnings("ignore")


In [3]:
class Grayscaletorgb:
    def __call__(self, tensor):
        return tensor.repeat(3, 1, 1)


transform = transforms.Compose([
    transforms.Lambda(lambda x: x.convert("RGB")),
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
    ])  

# Load the Fashion MNIST training dataset
train_data = datasets.FashionMNIST(
    root='./data',  # Directory to store the dataset
    train=True,     # Specify this is the training set
    download=True,  # Download the dataset if not already present
    transform=transform
)

# Load the Fashion MNIST test dataset
test_data = datasets.FashionMNIST(
    root='./data',  # Directory to store the dataset
    train=False,    # Specify this is the test set
    download=True,  # Download the dataset if not already present
    transform=transform
)


In [4]:
import os

save_dir = "data/Preprocessed_Data"
os.makedirs(f'{save_dir}/train', exist_ok=True)
os.makedirs(f'{save_dir}/test', exist_ok=True)
for i, (img, label) in enumerate(train_data):
    if i%10000 == 0:
        print(f'Saving training sample {i}')
    torch.save({'image': img, 'label': label}, f'{save_dir}/train/train_{i}.pt')

Saving training sample 0
Saving training sample 10000
Saving training sample 20000
Saving training sample 30000
Saving training sample 40000
Saving training sample 50000


In [5]:
for i, (img, label) in enumerate(test_data):
    if i%10000 == 0:
        print(f'Saving testing sample {i}')
    torch.save({'image': img, 'label': label}, f'{save_dir}/test/test_{i}.pt')

Saving testing sample 0


In [6]:
class mnistdataset(Dataset):
    def __init__(self, path):
        super().__init__()
        self.file_list = sorted(glob.glob(f"{path}/*.pt"))

    def __len__(self):
        return len(self.file_list)
    
    def __getitem__(self, index):
        data = torch.load(self.file_list[index])
        return data['image'], data['label']

train_dataset = mnistdataset(f'{save_dir}/train')
test_dataset = mnistdataset(f'{save_dir}/test')

In [7]:
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=0)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [8]:
resnet = models.resnet18(pretrained=True)
resnet

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /Users/geetdesai/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:03<00:00, 13.9MB/s]


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [9]:
resnet.fc = nn.Linear(resnet.fc.in_features,10)

for param in resnet.parameters():
    param.requires_grad = False

for param in resnet.fc.parameters():
    param.requires_grad = True

In [11]:
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
resnet.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [12]:
epochs = 5 
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(resnet.fc.parameters(), lr=0.001)

for epoch in range(epochs):
    resnet.train()
    flag = 0
    for images, labels in train_dataloader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        
        optimizer.zero_grad()
        y_pred = resnet(images)
        loss = loss_fn(y_pred, labels)
        loss.backward()
        optimizer.step()

        flag += 1
        if flag %100 == 0:
            print(f"Batch Number: ", flag)
    print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}")


Batch Number:  100
Batch Number:  200
Batch Number:  300
Batch Number:  400
Batch Number:  500
Batch Number:  600
Batch Number:  700
Batch Number:  800
Batch Number:  900
Epoch 1/5, Loss: 0.3046
Batch Number:  100
Batch Number:  200
Batch Number:  300
Batch Number:  400
Batch Number:  500
Batch Number:  600
Batch Number:  700
Batch Number:  800
Batch Number:  900
Epoch 2/5, Loss: 0.3808
Batch Number:  100
Batch Number:  200
Batch Number:  300
Batch Number:  400
Batch Number:  500
Batch Number:  600
Batch Number:  700
Batch Number:  800
Batch Number:  900
Epoch 3/5, Loss: 0.1612
Batch Number:  100
Batch Number:  200
Batch Number:  300
Batch Number:  400
Batch Number:  500
Batch Number:  600
Batch Number:  700
Batch Number:  800
Batch Number:  900
Epoch 4/5, Loss: 0.5212
Batch Number:  100
Batch Number:  200
Batch Number:  300
Batch Number:  400
Batch Number:  500
Batch Number:  600
Batch Number:  700
Batch Number:  800
Batch Number:  900
Epoch 5/5, Loss: 0.1451


In [14]:
resnet.eval()
accuracy_list = []
for batch_features, batch_labels in test_dataloader:
    batch_features, batch_labels = batch_features.to(device), batch_labels.to(device)
    y_pred = resnet(batch_features)
    _, predicted = torch.max(y_pred, 1)
    accuracy = accuracy_score(batch_labels.cpu(), predicted.cpu())
    accuracy_list.append(accuracy)

average_accuracy = sum(accuracy_list) / len(accuracy_list)
print(f"Average Test Accuracy: {average_accuracy:.4f}")

Average Test Accuracy: 0.0995


In [15]:
resnet.eval()
accuracy_list = []
for batch_features, batch_labels in train_dataloader:
    batch_features, batch_labels = batch_features.to(device), batch_labels.to(device)
    y_pred = resnet(batch_features)
    _, predicted = torch.max(y_pred, 1)
    accuracy = accuracy_score(batch_labels.cpu(), predicted.cpu())
    accuracy_list.append(accuracy)

average_accuracy = sum(accuracy_list) / len(accuracy_list)
print(f"Average Test Accuracy: {average_accuracy:.4f}")

Average Test Accuracy: 0.1000
