# CIFAR-10 Classification using Pretrained VGG16

This notebook applies transfer learning with a pretrained VGG16 backbone to CIFAR-10 image classification. It is organized into the following sections:
1. Setup and Dependencies
2. Dataset Exploration
3. Data Preprocessing
4. Model Architecture
5. Training Implementation
6. Model Training and Evaluation

## 1. Setup and Dependencies

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.models import vgg16, VGG16_Weights
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

In [None]:
# Compute device: use a CUDA GPU when one is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Hyper-parameters shared by the data pipeline and the training loop.
BATCH_SIZE = 128  # samples per mini-batch
IMG_SIZE = 224    # images are resized to VGG16's expected input resolution
NUM_EPOCHS = 35   # number of full passes over the training split

## 2. Dataset Exploration
We will load and explore the CIFAR-10 dataset.

In [None]:
# Raw CIFAR-10 splits with only tensor conversion applied; these copies are
# used purely to inspect the data before any preprocessing/augmentation.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)

# Summary of the splits. `trainset.data` holds the raw image array; the label
# line reports the sample count in an (N, 1) layout for readability.
print(f"x_train shape: {trainset.data.shape}")
print(f"y_train shape: ({len(trainset)}, 1)")
print(f"{len(trainset)} train samples")
print(f"{len(testset)} test samples")

### 2.1 Visualize Sample Images

### 2.2 Class Distribution

## 3. Data Preprocessing
Set up our data transforms and create the data loaders.

In [None]:
# Training pipeline: resize to the network input size, apply random geometric
# (flip / rotation / translation / resized crop) and photometric (brightness /
# contrast) augmentation, then convert to a normalized tensor.
# The normalization statistics were derived from the training split during EDA.
# NOTE(review): the pretrained torchvision VGG16 weights were trained on
# ImageNet-normalized inputs; consider ImageNet mean/std instead — confirm.
train_transform = transforms.Compose(
    [
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(degrees=20),
        transforms.RandomAffine(degrees=0, translate=(0.14, 0.14)),
        transforms.RandomResizedCrop(size=IMG_SIZE, scale=(0.8, 1.0)),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.4863, 0.4532, 0.4155],
            std=[0.2621, 0.2557, 0.2582],
        ),
    ]
)

# Evaluation pipeline: deterministic resize + the same normalization, with no
# augmentation, so test-time inputs are reproducible.
test_transform = transforms.Compose(
    [
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.4863, 0.4532, 0.4155],
            std=[0.2621, 0.2557, 0.2582],
        ),
    ]
)

In [None]:
# Full CIFAR-10 splits wired to the preprocessing pipelines defined above:
# augmentation for training, deterministic resizing for evaluation.
train_dataset = torchvision.datasets.CIFAR10(root='data',
                                             train=True,
                                             transform=train_transform,
                                             download=True)

# download=True lets the test split load even when it has not been fetched yet;
# torchvision skips the download when the files are already present.
test_dataset = torchvision.datasets.CIFAR10(root='data',
                                            train=False,
                                            transform=test_transform,
                                            download=True)

# Batch size comes from the BATCH_SIZE constant (lower-case `batch_size` is not
# defined anywhere). Pinned host memory speeds up host->GPU copies and is only
# useful when training on CUDA.
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=BATCH_SIZE,
                          num_workers=8,
                          pin_memory=(device.type == 'cuda'),
                          shuffle=True)

test_loader = DataLoader(dataset=test_dataset,
                         batch_size=BATCH_SIZE,
                         num_workers=8,
                         pin_memory=(device.type == 'cuda'),
                         shuffle=False)

In [None]:
# Sanity check: fetch only the first mini-batch from the training loader and
# report its tensor shapes and label values (nothing happens if it is empty).
_first_batch = next(iter(train_loader), None)
if _first_batch is not None:
    images, labels = _first_batch
    print('Image batch dimensions:', images.shape)
    print('Image label dimensions:', labels.shape)
    print(labels)
del _first_batch

## 4. Model Architecture

In [None]:
class VGG16Model(nn.Module):
    """VGG16 convolutional backbone with a small MLP head for CIFAR-10.

    The convolutional ``features`` trunk of torchvision's VGG16 is used as a
    feature extractor. Its output is globally average-pooled to a 512-dim
    vector, which a BatchNorm/Softplus/Dropout MLP maps to class scores.

    Args:
        num_classes: Number of output classes (10 for CIFAR-10).
        weights: Passed to ``torchvision.models.vgg16``. Defaults to the
            ImageNet weights (``"IMAGENET1K_V1"``; a ``VGG16_Weights`` member
            is also accepted). Pass ``None`` for a randomly initialized
            backbone, e.g. when no network access is available.
        freeze_features: If True (default), backbone parameters get
            ``requires_grad=False`` so only the classifier head is trained.

    ``forward`` returns raw, unnormalized logits of shape
    ``(batch, num_classes)``. This is what ``nn.CrossEntropyLoss`` expects,
    since it applies log-softmax internally. Use ``torch.softmax(logits, 1)``
    when probabilities are needed.
    """

    def __init__(self, num_classes=10, weights="IMAGENET1K_V1", freeze_features=True):
        super().__init__()
        # Keep only the convolutional trunk. The large fully-connected VGG
        # classifier is not used, so it is not stored as a submodule (it would
        # otherwise waste memory and bloat every checkpoint).
        backbone = vgg16(weights=weights)
        # features[0] is deliberately NOT replaced. It is already a
        # Conv2d(3, 64, kernel_size=3, padding=1) and works at any resolution.
        # Swapping it for a new layer would discard its pretrained weights
        # and, combined with freezing, leave a random first layer that never
        # trains.
        self.features = backbone.features

        if freeze_features:
            for param in self.features.parameters():
                param.requires_grad = False

        # The trunk downsamples by 32: a 32x32 input gives 512x1x1 and a
        # 224x224 input (the resized pipeline) gives 512x7x7. Global average
        # pooling reduces either case to 512x1x1, so the head always sees 512
        # features.
        self.pool = nn.AdaptiveAvgPool2d(1)

        # NOTE: BatchNorm1d in training mode needs batches of more than one sample.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.BatchNorm1d(512),
            nn.Linear(512, 256),
            nn.Softplus(),
            nn.BatchNorm1d(256),
            nn.Dropout(0.5),

            nn.Linear(256, 256),
            nn.Softplus(),
            nn.BatchNorm1d(256),
            nn.Dropout(0.5),

            nn.Linear(256, 256),
            nn.Softplus(),
            nn.BatchNorm1d(256),
            nn.Dropout(0.5),

            nn.Linear(256, 256),
            nn.Softplus(),
            nn.BatchNorm1d(256),
            nn.Dropout(0.5),

            # Final layer emits logits. No Softmax here: CrossEntropyLoss would
            # otherwise apply softmax twice and badly flatten the gradients.
            nn.Linear(256, num_classes),
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize the classifier head; the backbone weights are left untouched.

        Linear layers get Kaiming-normal weights (fan_out, ReLU gain, used here
        as an approximation for Softplus) and zero biases. BatchNorm layers
        start as the identity transform (weight 1, bias 0).
        """
        for m in self.classifier.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for image batch ``x``."""
        x = self.features(x)
        x = self.pool(x)
        return self.classifier(x)

In [None]:
# Instantiate the network on the selected compute device.
model = VGG16Model().to(device)

# Parameter summary: every tensor counts toward the total, and only those with
# requires_grad=True count as trainable.
total_params = sum(map(torch.numel, model.parameters()))
trainable_params = sum(t.numel() for t in model.parameters() if t.requires_grad)
print(f'Total parameters: {total_params:,}')
print(f'Trainable parameters: {trainable_params:,}')

## 5. Training Implementation

In [None]:
def plot_metrics(train_losses, val_losses, train_accs, val_accs):
    """Draw loss (left) and accuracy (right) learning curves side by side.

    Args:
        train_losses: Per-epoch training losses.
        val_losses: Per-epoch validation losses.
        train_accs: Per-epoch training accuracies, in percent.
        val_accs: Per-epoch validation accuracies, in percent.
    """
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    # One row per panel: (axis, train series, val series, train label,
    # val label, title, y-axis label).
    panels = (
        (axes[0], train_losses, val_losses,
         'Train Loss', 'Val Loss', 'Loss vs. Epoch', 'Loss'),
        (axes[1], train_accs, val_accs,
         'Train Acc', 'Val Acc', 'Accuracy vs. Epoch', 'Accuracy (%)'),
    )
    for ax, train_series, val_series, train_label, val_label, title, y_label in panels:
        ax.plot(train_series, label=train_label)
        ax.plot(val_series, label=val_label)
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.legend()

    plt.tight_layout()
    plt.show()

In [None]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimization pass over ``train_loader``.

    Puts ``model`` in training mode and takes one optimizer step per batch.
    Progress is printed every 100 batches.

    Returns:
        ``(mean_loss, accuracy)``: the loss averaged over batches, and the
        percentage of samples whose highest-scoring class equals the label.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    n_batches = len(train_loader)

    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Standard update: clear stale gradients, forward, backprop, step.
        optimizer.zero_grad()
        scores = model(inputs)
        batch_loss = criterion(scores, labels)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()
        predicted = scores.argmax(dim=1)
        n_seen += labels.size(0)
        n_correct += (predicted == labels).sum().item()

        if step % 100 == 0:
            print(f'Batch: {step}/{n_batches} | Loss: {batch_loss.item():.4f}')

    return loss_sum / n_batches, 100. * n_correct / n_seen

def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without computing gradients.

    Returns:
        ``(mean_loss, accuracy)``: the loss averaged over batches, and the
        percentage of samples whose highest-scoring class equals the label.
    """
    model.eval()
    batch_losses = []
    n_correct = 0
    n_total = 0

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            scores = model(inputs)
            batch_losses.append(criterion(scores, labels).item())
            n_total += labels.size(0)
            n_correct += (scores.argmax(dim=1) == labels).sum().item()

    return sum(batch_losses) / len(val_loader), 100. * n_correct / n_total

In [None]:
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, device,
                checkpoint_path='best_vgg16_cifar.pth'):
    """Train ``model`` for ``num_epochs`` epochs, validating after each one.

    After every epoch the LR scheduler is stepped. Whenever the validation
    loss reaches a new minimum, a checkpoint is written to ``checkpoint_path``.
    The checkpoint holds the epoch index, the model/optimizer state dicts, and
    the validation loss/accuracy.

    Args:
        model: Network to train.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` validation batches.
        criterion: Loss function applied to ``model`` outputs and labels.
        optimizer: Optimizer updating the model's trainable parameters.
        scheduler: LR scheduler, or ``None`` to disable scheduling.
            ``ReduceLROnPlateau`` is stepped with the validation loss; any
            other scheduler is stepped once per epoch with no arguments.
        num_epochs: Number of training epochs.
        device: Device the batches are moved to.
        checkpoint_path: File the best checkpoint (lowest validation loss)
            is saved to.

    Returns:
        ``(train_losses, val_losses, train_accs, val_accs)``: per-epoch lists.
    """
    best_val_loss = float('inf')
    train_losses, val_losses = [], []
    train_accs, val_accs = [], []

    for epoch in range(num_epochs):
        print(f'\nEpoch {epoch+1}/{num_epochs}')
        print('-' * 20)

        # Training phase
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)

        # Validation phase
        val_loss, val_acc = validate(model, val_loader, criterion, device)

        # Only plateau schedulers take a metric. Passing the loss to other
        # schedulers would be misread as an epoch number.
        if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(val_loss)
        elif scheduler is not None:
            scheduler.step()

        # Store metrics
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accs.append(train_acc)
        val_accs.append(val_acc)

        print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%')
        print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%')

        # Save best model (by validation loss)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_loss,
                'val_acc': val_acc,
            }, checkpoint_path)

    return train_losses, val_losses, train_accs, val_accs


In [None]:
def main():
    """Build the VGG16 model, train it on CIFAR-10 and plot the learning curves.

    Uses the module-level ``device``, ``train_loader``, ``test_loader`` and
    ``NUM_EPOCHS`` defined in earlier cells.

    Returns:
        The trained model, holding the weights from the final epoch.
    """
    # Model
    model = VGG16Model().to(device)

    # CrossEntropyLoss expects raw logits (it applies log-softmax internally).
    criterion = nn.CrossEntropyLoss()

    # Optimize only parameters that still require gradients. Frozen backbone
    # weights are excluded, which keeps the optimizer state small.
    params_to_update = [p for p in model.parameters() if p.requires_grad]

    # SGD with momentum. An alternative worth trying is
    # optim.Adam(params_to_update, lr=initial_lr, weight_decay=1e-4).
    initial_lr = 0.001
    optimizer = optim.SGD(params=params_to_update, lr=initial_lr, momentum=0.9)

    # Cut the LR by 10x when the validation loss has not improved for 5 epochs.
    # `verbose` is not passed: it is deprecated in recent PyTorch releases.
    # The current LR can be read from optimizer.param_groups[0]['lr'].
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

    # NOTE(review): the test split doubles as the validation set here. The
    # checkpoint/LR decisions are therefore made on test data, so the reported
    # test accuracy is optimistically biased. Hold out part of the training
    # split for an unbiased estimate.
    train_losses, val_losses, train_accs, val_accs = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=test_loader,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        num_epochs=NUM_EPOCHS,
        device=device
    )

    # Plot results
    plot_metrics(train_losses, val_losses, train_accs, val_accs)

    return model

In [None]:
if __name__ == "__main__":
    # Run the full pipeline (build the model, train it, plot the curves) only
    # when this code is executed directly rather than imported as a module.
    main()