# In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms, models
from torchmetrics.classification import MulticlassAccuracy
import numpy as np

# Training hyperparameters.
no_epochs = 10
learning_rate = 0.0001
batch_size = 128

# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# Micro-averaged accuracy over the 102 classes, kept on the same device as the model inputs.
acc_function = MulticlassAccuracy(num_classes=102, average='micro').to(device)

SEED = 42
# numpy's global RNG is used for triplet sampling.
np.random.seed(SEED)
# torch's global RNG drives the initialisation of newly created layers and the
# random torchvision augmentations; seed it too so runs are reproducible.
torch.manual_seed(SEED)
# Dedicated generator handed to DataLoaders for their shuffling.
generator = torch.Generator().manual_seed(SEED)

# Training-only data augmentation: random rotation (up to 30 degrees), random
# resized crop to 224x224 and random horizontal flip, followed by tensor
# conversion and per-channel normalisation (the usual ImageNet mean/std).
train_transforms = transforms.Compose([
    transforms.RandomRotation(30),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Validation/test preprocessing: the preset bundled with the ResNet-34
# IMAGENET1K_V1 weights, so evaluation inputs are prepared the way the
# pretrained backbone expects. Keep it in sync with the weights chosen in
# create_model().
default_transforms = transforms.Compose([
    models.ResNet34_Weights.IMAGENET1K_V1.transforms()
])

# Flowers102 splits stored under ./data (download=True lets torchvision fetch
# them if missing). Only the training split receives random augmentation.
flowers_train = datasets.Flowers102(root='./data', split='train', download=True, transform=train_transforms)
flowers_test = datasets.Flowers102(root='./data', split='test', download=True, transform=default_transforms)
flowers_val = datasets.Flowers102(root='./data', split='val', download=True, transform=default_transforms)


def get_data_loader(batch_size):
    """Build the train, test and validation DataLoaders.

    The training loader wraps ``flowers_train`` in ``TripletDataLoader`` so
    each item is an ``(anchor, positive, negative)`` triplet; the test and
    validation loaders yield the underlying ``(image, label)`` pairs.

    Args:
        batch_size: number of items per batch for all three loaders.

    Returns:
        ``(train_loader, test_loader, val_loader)``.
    """
    # Give the training loader the seeded generator as well, so its shuffle
    # order is reproducible across runs.
    train_loader = torch.utils.data.DataLoader(
        TripletDataLoader(flowers_train),
        batch_size=batch_size,
        shuffle=True,
        generator=generator,
    )
    # Evaluation gains nothing from shuffling; a fixed order makes every
    # evaluation pass repeatable.
    test_loader = torch.utils.data.DataLoader(flowers_test, batch_size=batch_size, shuffle=False)
    val_loader = torch.utils.data.DataLoader(flowers_val, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader, val_loader

# Custom dataloader for triplet loss
class TripletDataLoader(torch.utils.data.Dataset):
    """Wrap a labelled dataset so that each item is a triplet for metric learning.

    Despite its name this is a ``Dataset`` (it is passed to a ``DataLoader``).
    Item ``i`` is ``(anchor, positive, negative)``:

    * ``anchor``   -- the sample of ``dataset[i]``;
    * ``positive`` -- a sample with the same label, drawn from a different
      index whenever the class has more than one member;
    * ``negative`` -- a sample drawn uniformly from all indices whose label
      differs.

    Labels are read from the wrapped dataset's private ``_labels`` attribute
    (one label per index). Sampling uses numpy's global RNG, so seed it with
    ``np.random.seed`` for reproducible triplets.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        # Read all labels once; they are compared against on every item.
        self._labels = np.asarray(dataset._labels)
        # Group indices by class once instead of rescanning all labels per item.
        self._indices_by_label = {
            label: np.flatnonzero(self._labels == label)
            for label in np.unique(self._labels)
        }

    def _sample_negative_index(self, anchor_label):
        """Return a random index whose label differs from ``anchor_label``.

        Raises:
            ValueError: if every sample carries ``anchor_label``.
        """
        num_samples = len(self._labels)
        if len(self._indices_by_label[anchor_label]) == num_samples:
            raise ValueError("cannot sample a negative: every sample has the same label")
        # Rejection sampling is uniform over the other classes' samples; the
        # expected number of draws stays small unless one class dominates.
        while True:
            candidate = np.random.randint(num_samples)
            if self._labels[candidate] != anchor_label:
                return candidate

    def __getitem__(self, index):
        anchor, _ = self.dataset[index]
        anchor_label = self._labels[index]

        # Positive: same class, excluding the anchor itself unless it is the
        # only member of its class.
        same_class = self._indices_by_label[anchor_label]
        if len(same_class) > 1:
            same_class = same_class[same_class != index]
        positive, _ = self.dataset[np.random.choice(same_class)]

        negative, _ = self.dataset[self._sample_negative_index(anchor_label)]

        return anchor, positive, negative

    def __len__(self):
        return len(self.dataset)

# Early stopping based on accuracy
class AccuracyEarlyStopper:
    """Decide when to stop training because validation accuracy has plateaued.

    Call :meth:`early_stop` once per epoch with the latest validation
    accuracy. A value counts as an improvement only if it exceeds the best
    accuracy seen so far by more than ``min_delta`` (expressed in the same
    units as the accuracies passed in). Once ``patience`` consecutive calls
    bring no improvement, ``early_stop`` returns True.
    """

    def __init__(self, patience=5, min_delta=0.01):
        self.patience = patience          # allowed run of non-improving epochs
        self.min_delta = min_delta        # minimum gain that counts as improvement
        self.counter = 0                  # current run of non-improving epochs
        self.max_validation_accuracy = 0  # best accuracy observed so far

    def early_stop(self, validation_accuracy):
        """Record ``validation_accuracy``; return True when training should stop."""
        improvement_bar = self.max_validation_accuracy + self.min_delta
        if validation_accuracy > improvement_bar:
            # New best: remember it and restart the patience window.
            self.max_validation_accuracy = validation_accuracy
            self.counter = 0
            return False

        self.counter += 1
        return bool(self.counter >= self.patience)

def create_model(num_classes=102):
    """Build a pretrained ResNet-34 whose final layer is replaced by a fresh head.

    Args:
        num_classes: output size of the new final linear layer. The default,
            102, matches the ``num_classes`` used for ``acc_function``.

    Returns:
        The model, moved to the module-level ``device``.
    """
    model = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)
    # Size the head from the backbone instead of hard-coding ResNet-34's 512
    # features, so the width stays correct if the backbone changes.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model.to(device)

def train(model, optimizer, dataloader, loss_fn):
    """Run one epoch of triplet training and return the mean per-batch loss.

    Args:
        model: network mapping an image batch to an output batch; it is put
            into training mode before the first batch.
        optimizer: optimizer over ``model``'s parameters, stepped once per batch.
        dataloader: iterable of ``(anchor, positive, negative)`` batches.
        loss_fn: callable ``loss_fn(anchor_out, positive_out, negative_out)``
            returning a scalar tensor (e.g. ``nn.TripletMarginLoss``).

    Returns:
        Mean of the per-batch loss values, or 0.0 if ``dataloader`` yielded
        no batches.
    """
    # Make BatchNorm/Dropout behave as in training, whatever mode a previous
    # pass (e.g. an evaluation) left the model in.
    model.train()
    running_loss_value = 0.0
    num_batches = 0
    for anchor, positive, negative in dataloader:
        anchor = anchor.to(device)
        positive = positive.to(device)
        negative = negative.to(device)

        optimizer.zero_grad()

        # Separate forward passes for the three triplet members.
        anchor_embedding = model(anchor)
        positive_embedding = model(positive)
        negative_embedding = model(negative)

        loss = loss_fn(anchor_embedding, positive_embedding, negative_embedding)
        loss.backward()
        optimizer.step()

        running_loss_value += loss.item()
        num_batches += 1

    # Divide by the batches actually seen so an empty loader cannot raise
    # ZeroDivisionError.
    return running_loss_value / num_batches if num_batches else 0.0

def test_eval(model, dataloader, loss_fn=None):
    """Evaluate ``model`` on a labelled dataloader without tracking gradients.

    Args:
        model: network whose outputs are scored as per-class scores by
            ``loss_fn`` and the module-level ``acc_function``.
        dataloader: iterable of ``(images, labels)`` batches.
        loss_fn: criterion ``loss_fn(outputs, labels)``. It is assumed to
            return the batch *mean*, the default reduction of
            ``nn.CrossEntropyLoss``. Defaults to ``nn.CrossEntropyLoss()``.

    Returns:
        ``(accuracy_percent, mean_loss)``, both averaged over samples rather
        than batches so a short final batch is not over-weighted. Returns
        ``(0.0, 0.0)`` for an empty dataloader.
    """
    if loss_fn is None:
        loss_fn = nn.CrossEntropyLoss()

    # Evaluate with BatchNorm running statistics and dropout disabled, then
    # restore whatever mode the caller had the model in.
    was_training = model.training
    model.eval()
    total_loss = 0.0
    total_correct = 0.0
    total_samples = 0
    try:
        with torch.no_grad():
            for images, labels in dataloader:
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                batch_size = labels.size(0)
                # Loss and micro accuracy are batch means; weight by batch size.
                total_loss += loss_fn(outputs, labels).item() * batch_size
                total_correct += acc_function(outputs, labels).item() * batch_size
                total_samples += batch_size
    finally:
        model.train(was_training)

    if total_samples == 0:
        return 0.0, 0.0
    return total_correct / total_samples * 100, total_loss / total_samples

def train_eval_test(model, train_dataloader, val_dataloader, test_dataloader, no_epochs=10):
    """Train with triplet loss, early-stop on validation accuracy, then test.

    Args:
        model: network to optimise with Adam at the module-level ``learning_rate``.
        train_dataloader: loader of ``(anchor, positive, negative)`` batches.
        val_dataloader: loader of ``(images, labels)`` used for early stopping.
        test_dataloader: loader of ``(images, labels)`` evaluated once at the end.
        no_epochs: maximum number of training epochs.

    Returns:
        ``(train_loss_arr, train_acc_arr, test_acc)``:

        * ``train_loss_arr``: the mean training loss of each epoch;
        * ``train_acc_arr``: the *validation* accuracy (%) of each epoch (no
          training accuracy is computed);
        * ``test_acc``: the test accuracy (%).

    NOTE(review): training optimises a triplet margin loss on the model's
    outputs, while evaluation scores those same outputs as class logits
    (cross-entropy + accuracy). The classification metrics may therefore say
    little about the learned embedding; confirm this is intended.
    """
    es = AccuracyEarlyStopper()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # Build the criteria once instead of once per epoch.
    triplet_loss = nn.TripletMarginLoss(margin=1.0).to(device)
    eval_criterion = nn.CrossEntropyLoss()
    train_loss_arr, train_acc_arr = [], []

    for i in range(no_epochs):
        train_loss = train(model, optimizer, train_dataloader, triplet_loss)

        # test_eval returns (accuracy %, loss); early stopping uses accuracy only.
        eval_acc, _ = test_eval(model, val_dataloader, eval_criterion)

        print(f'Epoch {i + 1} Train Loss: {train_loss:>8f}, Eval Accuracy: {eval_acc:>0.2f}%')

        train_loss_arr.append(train_loss)
        train_acc_arr.append(eval_acc)

        if es.early_stop(eval_acc):
            print('Early stopping activated')
            break

    test_acc, _ = test_eval(model, test_dataloader, eval_criterion)
    print(f"Test Accuracy: {test_acc:.2f}%")

    return train_loss_arr, train_acc_arr, test_acc

# Data: triplet-wrapped training loader plus plain test/validation loaders.
train_data_loader, test_data_loader, val_data_loader = get_data_loader(batch_size)

# Model: create_model() already moves the network onto `device`.
model = create_model()

# Train with early stopping on validation accuracy, then evaluate on the test split.
train_loss_arr, train_acc_arr, test_acc = train_eval_test(
    model,
    train_data_loader,
    val_data_loader,
    test_data_loader,
    no_epochs=no_epochs,
)


# Output: cuda:0
