**What is Transfer Learning**

Transfer learning is a machine learning technique where a model trained on one task is reused (partially or fully) for a different but related task.Instead of training a model from scratch, which can be computationally expensive and require large datasets, transfer learning leverages knowledge from a pre-trained model to improve learning efficiency and
performance.



**How Transfer Learning Works**

1. Pretraining on a Large Dataset :
    
    1.  A model is first trained on a large dataset (e.g., ImageNet for images, GPT for text).
    
    2. The model learns **general features, such as edges and shapes** in images or syntax and semantics in text.

2. Fine-Tuning for a New Task

    1. The pre-trained model is then adapted to a new, often smaller, dataset.

    2. Some layers may be frozen (not updated), while others are fine-tuned for the specific
task.


In [8]:
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

In [9]:
torch.manual_seed(42)

<torch._C.Generator at 0x1524313d0>

In [10]:
df = pd.read_csv('../2. Dataset/fmnist_small.csv')

In [11]:
x = df.iloc[:, 1:].values/255.0
y = df.iloc[:,0].values

In [12]:
xtrain , xtest , ytrain , ytest = train_test_split( x , y , test_size=0.2 , random_state=20)

In [13]:
from torchvision.transforms import transforms

custom_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # its better to send normalized input distribution to pretrained models.
])

In [14]:
from PIL import Image
import numpy as np

class CustomDataset(Dataset):

  def __init__(self, features, labels, transform):
    
    self.features = features
    self.labels = labels
    self.transform = transform

  def __len__(self):
    return len(self.features)

  def __getitem__(self, index):

    image = self.features[index].reshape(28,28)
    image = image.astype(np.uint8)
    image = np.stack([image]*3, axis=-1)

    # image = Image.fromarray(image) to PIL 
    image = self.transform(image)

    return image, torch.tensor(self.labels[index], dtype=torch.long)

In [None]:
train_dataset = CustomDataset(xtrain,ytrain,transform=custom_transform)
test_dataset = CustomDataset(xtest,ytest,transform=custom_transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True , pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True , pin_memory=True)

In [17]:
import torchvision.models as models

vgg16 = models.vgg16(pretrained=True)



Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /Users/akashjain/.cache/torch/hub/checkpoints/vgg16-397923af.pth


100%|██████████| 528M/528M [00:51<00:00, 10.7MB/s] 


In [18]:
for param in vgg16.features.parameters():
  param.requires_grad=False

In [19]:
vgg16.classifier = nn.Sequential(
    nn.Linear(25088, 1024),  # (32,28,28)
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(1024, 512),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(512, 10)
)

In [20]:
device = 'cpu'
if hasattr(torch,'mps') and torch.backends.mps.is_available():
    device = 'mps'
    print("MPS is available")

MPS is available


In [21]:
vgg16 = vgg16.to(device)

In [22]:
learning_rate = 0.0001
epochs = 10

In [23]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(vgg16.classifier.parameters(), lr=learning_rate)

In [25]:

for epoch in range(epochs):

  total_epoch_loss = 0

  for batch_features, batch_labels in train_loader:

    batch_features, batch_labels = batch_features.to(device), batch_labels.to(device)

    outputs = vgg16(batch_features)

    loss = criterion(outputs, batch_labels)

    optimizer.zero_grad()

    loss.backward()

    optimizer.step()

    total_epoch_loss = total_epoch_loss + loss.item()

  avg_loss = total_epoch_loss/len(train_loader)
  print(f'Epoch: {epoch + 1} , Loss: {avg_loss}')


Epoch: 1 , Loss: 2.307210826873779


KeyboardInterrupt: 

In [None]:
model.eval()

In [None]:

total = 0
correct = 0

with torch.no_grad():

  for batch_features, batch_labels in test_loader:

    batch_features, batch_labels = batch_features.to(device), batch_labels.to(device)

    outputs = vgg16(batch_features)

    _, predicted = torch.max(outputs, 1)

    total = total + batch_labels.shape[0]

    correct = correct + (predicted == batch_labels).sum().item()

print(correct/total)