### Transfer Learning using ViT Model


In [None]:
# Import required libraries
import os

import pandas as pd
import timm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import datasets, transforms

In [None]:
# Runtime configuration. Seed PyTorch's global RNG so that the new head's initialisation,
# DataLoader shuffling and random flips are reproducible on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)  # seeds the RNG on all devices (CPU and CUDA)

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Step 1: Define Image Transformations

In [None]:
# Image preprocessing for the ViT model. Inputs are resized to 224x224 (the resolution the
# pre-trained ViT uses) and standardised per channel with mean 0.5 / std 0.5.
vit_normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

# Pipeline for the train_val images: resize, a random left-right flip as light augmentation,
# then tensor conversion and normalisation
train_val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    vit_normalize,
])

# Test pipeline: same preprocessing, but deterministic (no augmentation)
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    vit_normalize,
])

### Step 2: Load Dataset

In [None]:
# Mount Google Drive (Colab only) so the dataset folders under MyDrive become readable
from google.colab import drive

MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT)

Mounted at /content/drive


In [None]:
# All data lives on the mounted Drive
root = "/content/drive/MyDrive/"
train_val_folder = f"{root}train_val_data/"

# ImageFolder takes the class labels from the sub-directory names of train_val_folder
train_val_dataset = datasets.ImageFolder(train_val_folder, transform=train_val_transform)
class_names = train_val_dataset.classes
num_classes = len(class_names)
print(f"Number of classes: {num_classes}")

Number of classes: 12


### Step 3: Split Data into Training and Validation

In [None]:
# 80/20 train/validation split; the fixed generator seed makes the partition reproducible.
train_percentage = 0.8
train_size = int(train_percentage * len(train_val_dataset))
val_size = len(train_val_dataset) - train_size

train_split, val_split = random_split(
    train_val_dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42)
)

# random_split only returns index views onto train_val_dataset, so both halves would share
# its augmenting transform (RandomHorizontalFlip) and validation accuracy would change from
# run to run. Re-index the same images through a second ImageFolder that uses the
# deterministic test-time transform instead.
val_source = datasets.ImageFolder(train_val_folder, transform=test_transform)
assert val_source.samples == train_val_dataset.samples, "folder contents changed between loads"

train_dataset = train_split
val_dataset = Subset(val_source, val_split.indices)

### Step 4: DataLoaders for Efficient Data Loading

In [None]:
batch_size = 32  # tune to what fits in GPU memory
loader_kwargs = dict(batch_size=batch_size)

# Shuffle only the training data; validation order does not matter for the metrics
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

n_train, n_val = len(train_dataset), len(val_dataset)
print(f"Training images: {n_train}, Validation images: {n_val}")

Training images: 2400, Validation images: 600


### Step 5: ViT Model Definition

For transfer learning, we freeze most of the pre-trained network:
* freezing the earlier layers retains the general-purpose features learned on a large dataset (e.g., ImageNet)
* we fine-tune only the last two transformer blocks and a newly initialized classification head

In [None]:
# Load a pre-trained Vision Transformer from the `timm` library and prepare it for fine-tuning
def create_model(num_classes, n_trainable_blocks=2, model_name='vit_base_patch16_224'):
    """Build a pre-trained timm ViT set up for transfer learning.

    Every pre-trained weight is frozen except the last ``n_trainable_blocks``
    transformer blocks, and the classification head is replaced by a fresh,
    trainable ``nn.Linear`` with ``num_classes`` outputs.

    Args:
        num_classes: number of output classes for the new head.
        n_trainable_blocks: how many of the final transformer blocks to unfreeze
            (0 = train the head only). Defaults to 2, the original setting.
        model_name: timm model identifier; the model must expose ``.blocks`` and a
            linear ``.head`` (true for the default ViT).

    Returns:
        The model, moved to the global ``device``.

    Raises:
        ValueError: if ``n_trainable_blocks`` is negative or larger than the number of blocks.
    """
    model = timm.create_model(model_name, pretrained=True)

    n_blocks = len(model.blocks)
    if not 0 <= n_trainable_blocks <= n_blocks:
        raise ValueError(
            f"n_trainable_blocks must be in [0, {n_blocks}], got {n_trainable_blocks}"
        )

    # Freeze all pre-trained parameters first ...
    for param in model.parameters():
        param.requires_grad = False

    # ... then swap in a new head for our task; newly created layers are trainable by default.
    model.head = nn.Linear(model.head.in_features, num_classes)

    # Unfreeze the final blocks. Guard 0 explicitly: blocks[-0:] would select *all* blocks.
    if n_trainable_blocks > 0:
        for param in model.blocks[-n_trainable_blocks:].parameters():
            param.requires_grad = True

    return model.to(device)


model = create_model(num_classes=num_classes)

### Step 6: Define Loss Function and Optimizer

For the AdamW optimizer, we use a separate learning rate for each parameter group:
* a lower learning rate (1e-4) for the unfrozen pre-trained blocks, so their learned features change only gradually (frozen layers are not passed to the optimizer at all)
* a higher learning rate (1e-3) for the new, randomly initialized classification head

In [None]:
# Cross-entropy on raw logits: the standard loss for single-label multi-class classification
criterion = nn.CrossEntropyLoss()

# Two parameter groups with different learning rates. The fine-tuned group is derived from
# `requires_grad` instead of re-hardcoding which blocks were unfrozen, so it always matches
# the freezing done in create_model. A trainable parameter that is missing from the
# optimizer would otherwise silently never be updated.
head_params = list(model.head.parameters())
head_param_ids = {id(p) for p in head_params}
fine_tuned_params = [
    p for p in model.parameters() if p.requires_grad and id(p) not in head_param_ids
]

param_groups = [{'params': head_params, 'lr': 1e-3}]  # new head: larger steps
if fine_tuned_params:  # empty when only the head is trained
    param_groups.append({'params': fine_tuned_params, 'lr': 1e-4})  # pre-trained blocks: smaller steps

optimizer = optim.AdamW(param_groups, weight_decay=1e-5)

### Step 7: Train Model

In [None]:
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Make one pass over `loader`: train when `optimizer` is given, otherwise evaluate.

    Returns:
        (mean_loss, accuracy_percent) over all samples yielded by `loader`.

    Raises:
        ValueError: if `loader` yields no samples (instead of a ZeroDivisionError).
    """
    is_training = optimizer is not None
    model.train(is_training)  # train(False) is equivalent to eval()
    total_loss, n_correct, n_seen = 0.0, 0, 0

    # Gradients are only needed when the optimizer will step
    with torch.set_grad_enabled(is_training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            if is_training:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            if is_training:
                loss.backward()
                optimizer.step()

            # Assumes the criterion averages over the batch (default 'mean' reduction),
            # so weight by batch size to get a per-sample mean at the end.
            total_loss += loss.item() * images.size(0)
            n_correct += (outputs.argmax(dim=1) == labels).sum().item()
            n_seen += labels.size(0)

    if n_seen == 0:
        raise ValueError("data loader yielded no samples")
    return total_loss / n_seen, 100 * n_correct / n_seen


def train_model(model, train_loader, val_loader, optimizer, criterion, num_epochs=10,
                checkpoint_path="best_model.pth", device=None):
    """Train `model`, validate after every epoch, and return it holding its best weights.

    The state dict of the epoch with the highest validation accuracy is written to
    `checkpoint_path`. After the last epoch it is loaded back into `model`, so the
    returned model is the best one seen rather than merely the last one trained.

    Args:
        model: network to train (updated in place).
        train_loader: DataLoader yielding (images, labels) training batches.
        val_loader: DataLoader yielding (images, labels) validation batches.
        optimizer: optimizer over the model's trainable parameters.
        criterion: loss applied as criterion(outputs, labels).
        num_epochs: number of passes over `train_loader`.
        checkpoint_path: file the best state dict is saved to.
        device: where batches are moved; defaults to the device of the model's parameters.

    Returns:
        `model` with the weights of the best validation epoch (unchanged if num_epochs == 0).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    best_val_accuracy = 0.0
    best_state = None

    for epoch in range(num_epochs):
        train_loss, train_accuracy = _run_epoch(model, train_loader, criterion, device, optimizer)
        val_loss, val_accuracy = _run_epoch(model, val_loader, criterion, device)

        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.2f}%, "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.2f}%")

        # Always keep the first epoch so a checkpoint exists even at 0% accuracy
        if best_state is None or val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            # Snapshot in memory as well: later epochs keep mutating the live parameters
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            torch.save(best_state, checkpoint_path)
            print("Best model saved.")

    print(f"Best Validation Accuracy: {best_val_accuracy:.2f}%")

    if best_state is not None:
        model.load_state_dict(best_state)
    return model

In [None]:
model = train_model(
    model,
    train_loader,
    val_loader,
    optimizer=optimizer,
    criterion=criterion,
    num_epochs=10,
)

Epoch 1/10, Train Loss: 0.6763, Train Acc: 78.25%, Val Loss: 0.4374, Val Acc: 86.33%
Best model saved.
Epoch 2/10, Train Loss: 0.2179, Train Acc: 92.33%, Val Loss: 0.5094, Val Acc: 84.50%
Epoch 3/10, Train Loss: 0.1112, Train Acc: 96.29%, Val Loss: 0.6716, Val Acc: 85.00%
Epoch 4/10, Train Loss: 0.0471, Train Acc: 98.21%, Val Loss: 0.7943, Val Acc: 86.17%
Epoch 5/10, Train Loss: 0.0327, Train Acc: 99.04%, Val Loss: 0.7802, Val Acc: 86.00%
Epoch 6/10, Train Loss: 0.0379, Train Acc: 98.75%, Val Loss: 0.8434, Val Acc: 83.33%
Epoch 7/10, Train Loss: 0.0288, Train Acc: 99.17%, Val Loss: 0.8553, Val Acc: 86.33%
Epoch 8/10, Train Loss: 0.0066, Train Acc: 99.92%, Val Loss: 0.8289, Val Acc: 87.00%
Best model saved.
Epoch 9/10, Train Loss: 0.0005, Train Acc: 100.00%, Val Loss: 0.8622, Val Acc: 87.17%
Best model saved.
Epoch 10/10, Train Loss: 0.0001, Train Acc: 100.00%, Val Loss: 0.9033, Val Acc: 87.67%
Best model saved.
Best Validation Accuracy: 87.67%


### Step 8: Test Model

In [None]:
def create_result_file(model, test_dataset, classes, output_path="result2.csv",
                       batch_size=32, device=None):
    """Run `model` over `test_dataset` and write one row of class scores per image to CSV.

    Columns are ``ImageName`` (the file's basename) followed by one column per entry of
    `classes`, holding the model's raw, unnormalised output for that class index.

    Args:
        model: trained classifier; output column i is taken as the score for classes[i].
        test_dataset: dataset with an ImageFolder-style ``imgs`` list of (path, label)
            pairs aligned with its indices; the labels are ignored.
        classes: class names in the label-index order the model was trained with.
        output_path: CSV destination (default keeps the previous "result2.csv").
        batch_size: number of images per forward pass.
        device: where to run inference; defaults to the device of the model's parameters.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    model.to(device)
    model.eval()

    # shuffle=False keeps dataset order, so score row i lines up with test_dataset.imgs[i]
    loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    score_batches = []
    with torch.no_grad():
        for images, _ in loader:
            score_batches.append(model(images.to(device)).cpu())

    scores = torch.cat(score_batches) if score_batches else torch.empty(0, len(classes))
    # Keep only the first len(classes) outputs, one column per class
    results = pd.DataFrame(scores.numpy()[:, :len(classes)], columns=list(classes))
    results.insert(0, "ImageName", [os.path.basename(path) for path, _ in test_dataset.imgs])

    results.to_csv(output_path, index=False)
    print(f"Results saved to {output_path}")

In [None]:
# Test images are read through ImageFolder too; only their order and file names are
# needed for the prediction file.
test_folder = f"{root}test_folder/"
test_dataset = datasets.ImageFolder(test_folder, transform=test_transform)

In [None]:
# Write predictions; the column order comes from the training classes, i.e. the
# label-index order the model learned.
classes = train_val_dataset.classes
create_result_file(model=model, test_dataset=test_dataset, classes=classes)

Results saved to result.csv


In [None]:
# Download the prediction CSV from the Colab runtime to the local machine
from google.colab import files

RESULT_CSV = 'result2.csv'
files.download(RESULT_CSV)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>