In [None]:
import torch
import torch.nn as nn
import torchvision
from torchvision import models
from torchvision import datasets
from torchvision import transforms
from torchvision.transforms import v2
from torchmetrics.classification import MulticlassAccuracy
import numpy as np
from collections import OrderedDict
import matplotlib.pyplot as plt
import datetime
import pandas as pd

In [None]:
# Training hyperparameters
no_epochs = 50
learning_rate = 0.0001
batch_size = 128

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# Flowers102 has 102 classes; micro-averaged MulticlassAccuracy is the fraction of
# correct top-1 predictions over all samples.
NUM_CLASSES = 102
acc_function = MulticlassAccuracy(num_classes=NUM_CLASSES, average='micro').to(device)
loss_fn = nn.CrossEntropyLoss()

# Seed every RNG the notebook relies on: numpy, torch's global RNG (used e.g. when a
# new layer such as the replacement classifier head is initialised, and by the random
# torchvision transforms), and the explicit generator handed to the DataLoaders.
SEED = 196
np.random.seed(SEED)
torch.manual_seed(SEED)
generator = torch.Generator().manual_seed(SEED)

In [None]:
# Data augmentation, applied to the training split only.
train_transforms = v2.Compose([
    v2.RandomRotation(30),
    v2.RandomResizedCrop(224),
    v2.RandomHorizontalFlip(),
    # v2.ToTensor() is deprecated; ToImage + ToDtype(scale=True) is its documented
    # replacement (uint8 [0, 255] image -> float32 tensor in [0, 1]).
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    # Standard ImageNet channel mean/std, matching the pretrained backbone.
    v2.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Deterministic preprocessing for test/val: use the inference preset shipped with the
# backbone that is actually fine-tuned below (ResNet-34) instead of VGG16-BN's preset.
default_transforms = models.ResNet34_Weights.DEFAULT.transforms()

flowers_train = datasets.Flowers102(root='./data', split='train', download=True, transform=train_transforms)
flowers_test = datasets.Flowers102(root='./data', split='test', download=True, transform=default_transforms)
flowers_val = datasets.Flowers102(root='./data', split='val', download=True, transform=default_transforms)


In [None]:
def get_data_loader(batch_size):
    """Build DataLoaders for the Flowers102 train, test and val splits.

    Only the training loader is shuffled, using the seeded global ``generator``.
    The evaluation loaders keep a fixed order: metrics are computed over the whole
    split, so shuffling them would only advance the generator shared with the
    training loader and make evaluation depend on batch composition.

    Args:
        batch_size: Number of samples per batch for all three loaders.

    Returns:
        Tuple ``(train_loader, test_loader, val_loader)`` -- note the order.
    """
    train_loader = torch.utils.data.DataLoader(flowers_train, batch_size=batch_size, shuffle=True, generator=generator)
    test_loader = torch.utils.data.DataLoader(flowers_test, batch_size=batch_size, shuffle=False)
    val_loader = torch.utils.data.DataLoader(flowers_val, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader, val_loader

In [None]:
# Early stopping based on accuracy
class AccuracyEarlyStopper:
    """Signal when validation accuracy has stopped improving.

    A reported accuracy counts as an improvement only if it exceeds the best value
    seen so far by more than ``min_delta``. Once ``patience`` consecutive reports
    fail to improve, ``early_stop`` returns True.

    NOTE(review): ``min_delta`` is in the same units as the accuracy passed in. The
    training loop in this notebook reports accuracy as a percentage (0-100), so the
    default 0.01 means 0.01 percentage points -- confirm that is intended.
    """

    def __init__(self, patience=5, min_delta=0.01):
        self.patience = patience
        self.min_delta = min_delta
        # Number of consecutive reports without a sufficient improvement.
        self.counter = 0
        # Best accuracy recorded so far.
        self.max_validation_accuracy = 0

    def early_stop(self, validation_accuracy):
        """Record one validation accuracy; return True once patience is exhausted."""
        threshold = self.max_validation_accuracy + self.min_delta
        if validation_accuracy > threshold:
            self.max_validation_accuracy = validation_accuracy
            self.counter = 0
            return False
        self.counter += 1
        return self.counter >= self.patience

In [None]:
def train(model, optimizer, dataloader, loss_fn=loss_fn):
    """Run one training epoch and return the mean per-sample training loss.

    Args:
        model: Network to optimise; it is put into training mode first.
        optimizer: Optimizer over ``model``'s parameters.
        dataloader: Iterable of ``(images, labels)`` batches; tensors are moved to
            the global ``device``.
        loss_fn: Criterion; assumed to return a batch-mean loss, as the default
            ``nn.CrossEntropyLoss`` does.

    Returns:
        Training loss averaged over samples. Each batch is weighted by its size, so a
        short final batch is not over-weighted.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    # Enable training behaviour explicitly (BatchNorm batch statistics, dropout)
    # rather than relying on whatever mode the model was left in.
    model.train()
    running_loss_sum = 0.0
    n_samples = 0
    for images, labels in dataloader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        # loss is a batch mean: re-weight by batch size before averaging over the epoch.
        batch_n = labels.size(0)
        running_loss_sum += loss.item() * batch_n
        n_samples += batch_n
    if n_samples == 0:
        raise ValueError("dataloader yielded no samples")
    return running_loss_sum / n_samples

def test_eval(model, dataloader, loss_fn=loss_fn):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Args:
        model: Network to evaluate. It is switched to eval mode for the pass, and
            its previous train/eval mode is restored afterwards.
        dataloader: Iterable of ``(images, labels)`` batches; tensors are moved to
            the global ``device``.
        loss_fn: Criterion; assumed to return a batch-mean loss, as the default
            ``nn.CrossEntropyLoss`` does.

    Returns:
        ``(accuracy_percent, mean_loss)``: micro-averaged top-1 accuracy over all
        samples (0-100) and the loss averaged per sample.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    # Eval mode makes BatchNorm use its running statistics, so evaluation batches
    # neither drive nor update them. The caller's mode is restored in `finally`.
    was_training = model.training
    model.eval()
    # Accumulate the metric over the whole split instead of averaging per-batch
    # values, which would over-weight a short final batch.
    acc_function.reset()
    running_loss_sum = 0.0
    n_samples = 0
    try:
        with torch.no_grad():
            for images, labels in dataloader:
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                loss = loss_fn(outputs, labels)
                acc_function.update(outputs, labels)
                batch_n = labels.size(0)
                running_loss_sum += loss.item() * batch_n
                n_samples += batch_n
    finally:
        model.train(was_training)
    if n_samples == 0:
        raise ValueError("dataloader yielded no samples")
    accuracy = acc_function.compute().item()
    return accuracy * 100, running_loss_sum / n_samples

def train_eval_test(model, train_dataloader, val_dataloader, test_dataloader, no_epochs=10, lr=None, early_stopper=None):
    """Train with per-epoch validation and accuracy-based early stopping, then test.

    Args:
        model: Network to train; assumed to already be on ``device``.
        train_dataloader: Loader used for the training epochs.
        val_dataloader: Loader evaluated after every epoch; its accuracy drives
            early stopping.
        test_dataloader: Loader evaluated once at the end.
        no_epochs: Maximum number of epochs.
        lr: Adam learning rate; ``None`` uses the notebook-level ``learning_rate``.
        early_stopper: Object with ``early_stop(accuracy) -> bool``; ``None`` creates
            a fresh ``AccuracyEarlyStopper()`` with its defaults.

    Returns:
        ``(train_loss_arr, train_acc_arr, eval_loss_arr, eval_acc_arr, test_acc,
        test_loss, train_time)``:
        - ``train_acc_arr`` is never filled (always ``[]``); it is kept so the tuple
          shape is stable.
        - ``train_time`` holds per-epoch training wall-clock seconds; validation is
          excluded.

    The test set is evaluated with the weights from the last epoch that ran. After
    early stopping, that is not necessarily the epoch with the best validation
    accuracy.
    """
    es = AccuracyEarlyStopper() if early_stopper is None else early_stopper
    # The global learning_rate is read at call time unless the caller overrides it.
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate if lr is None else lr)
    train_loss_arr, train_acc_arr, eval_loss_arr, eval_acc_arr, train_time = [], [], [], [], []
    for i in range(no_epochs):
        # Time only the training pass.
        start = datetime.datetime.now()
        train_loss = train(model, optimizer, train_dataloader)
        end = datetime.datetime.now()
        eval_acc, eval_loss = test_eval(model, val_dataloader)
        time_taken = (end-start).total_seconds()
        print(f'Epoch {i+1} Train Loss: {train_loss:>8f}, Eval Accuracy: {eval_acc:>0.2f}%, Eval Loss: {eval_loss:>8f}, Train Time: {time_taken:>0.2f}s')
        train_loss_arr.append(train_loss)
        eval_loss_arr.append(eval_loss)
        eval_acc_arr.append(eval_acc)
        train_time.append(time_taken)
        # eval_acc is a percentage (0-100); the stopper compares min_delta in those units.
        if es.early_stop(eval_acc):
            print('Early stopping activated')
            break
    test_acc, test_loss = test_eval(model, test_dataloader)
    print(f"Test Accuracy: {test_acc}, Test Loss: {test_loss}")
    return train_loss_arr, train_acc_arr, eval_loss_arr, eval_acc_arr, test_acc, test_loss, train_time

In [None]:
def get_subset_data_loader(dataset, subset_size, batch_size, generator):
    """Return a DataLoader over a random subset of ``dataset``.

    ``subset_size`` distinct indices are drawn with ``generator``, so the choice is
    reproducible for a seeded generator. The same generator drives the order in
    which the subset is visited on each pass.

    Args:
        dataset: Map-style dataset (must support ``len()`` and integer indexing).
        subset_size: Number of samples to keep. Values larger than ``len(dataset)``
            are capped at the dataset size.
        batch_size: Samples per batch.
        generator: ``torch.Generator`` used for both the subset draw and the shuffle.

    Returns:
        A ``torch.utils.data.DataLoader`` sampling only the chosen indices.

    Raises:
        ValueError: If ``subset_size`` is not positive. Otherwise a negative value
            would silently slice from the end of the permutation, and 0 would give
            an empty loader.
    """
    if subset_size < 1:
        raise ValueError(f"subset_size must be a positive integer, got {subset_size}")
    subset_indices = torch.randperm(len(dataset), generator=generator)[:subset_size].tolist()
    # SubsetRandomSampler reshuffles the chosen indices on every iteration. A sampler
    # is mutually exclusive with shuffle=True, so shuffle stays at its default (False).
    sampler = torch.utils.data.SubsetRandomSampler(subset_indices, generator=generator)
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=sampler)


def few_shot_train_eval_test(model, train_dataset, val_dataloader, test_dataloader, few_shot_size, no_epochs=10):
    """Fine-tune ``model`` on a random few-shot subset of the training data.

    ``few_shot_size`` samples are drawn uniformly at random from the whole training
    set using the global seeded ``generator``, and batched with the global
    ``batch_size``. This is a total count, not a per-class count, so classes can be
    missing from the subset. Training, validation and testing are delegated to
    ``train_eval_test``, whose result tuple is returned unchanged.

    Args:
        model: Network to fine-tune.
        train_dataset: Map-style training dataset. A ``DataLoader`` is also accepted
            and unwrapped to its ``.dataset``; a DataLoader is not indexable, so
            sampling from it directly would fail.
        val_dataloader: Loader used for per-epoch validation / early stopping.
        test_dataloader: Loader evaluated once at the end.
        few_shot_size: Total number of training samples to use.
        no_epochs: Maximum number of epochs.

    Returns:
        The 7-tuple returned by ``train_eval_test``.
    """
    if isinstance(train_dataset, torch.utils.data.DataLoader):
        train_dataset = train_dataset.dataset
    train_data_loader = get_subset_data_loader(train_dataset, few_shot_size, batch_size, generator)
    return train_eval_test(model, train_data_loader, val_dataloader, test_dataloader, no_epochs)

In [None]:
# Full-size loaders; get_data_loader returns them in (train, test, val) order.
(train_data_loader,
 test_data_loader,
 val_data_loader) = get_data_loader(batch_size=batch_size)

In [None]:
# Apply few-shot learning.
# few_shot_size is the TOTAL number of training images sampled at random from the
# training split (not a per-class count). With 20 images, at most 20 of the 102
# classes can appear during fine-tuning.
few_shot_size = 20
base_model = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)
# Replace the ImageNet head with a fresh 102-way classifier sized from the backbone.
base_model.fc = nn.Linear(base_model.fc.in_features, 102)
base_model = base_model.to(device)

# Fine-tune the pre-trained model on the few-shot subset. Pass the training *dataset*,
# not a DataLoader: the subset sampler needs an indexable dataset.
# The unpacking order follows train_eval_test's return tuple (loss before accuracy).
few_shot_train_loss, few_shot_train_acc, few_shot_eval_loss, few_shot_eval_acc, few_shot_test_acc, few_shot_test_loss, few_shot_train_time = few_shot_train_eval_test(
    base_model,
    flowers_train,
    val_data_loader,
    test_data_loader,
    few_shot_size,
    no_epochs=no_epochs
)

# Compare these results with full-data fine-tuning to gauge few-shot performance.