In [1]:
import torch
import torch.nn as nn
import torchvision
from torchvision import models
from torchvision import datasets
from torchvision import transforms
from torchmetrics.classification import MulticlassAccuracy
from torchinfo import summary
import os.path

In [2]:
no_epochs = 50
learning_rate = 0.001
batch_size = 128

acc_function = MulticlassAccuracy(num_classes=102, average='micro')
loss_fn = nn.CrossEntropyLoss()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

PATH = './model.pth'

cuda:0


In [3]:
# Data Augmentation
train_transforms = transforms.Compose([
    transforms.RandomRotation(30),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Assuming you want to keep the default transformations for testing/validation:
default_transforms = transforms.Compose([
    models.VGG16_BN_Weights.IMAGENET1K_V1.transforms()
])

flowers_train = datasets.Flowers102(root='./data', split='train', download=True, transform=train_transforms)
flowers_test = datasets.Flowers102(root='./data', split='test', download=True, transform=default_transforms)
flowers_val = datasets.Flowers102(root='./data', split='val', download=True, transform=default_transforms)


In [4]:
def get_data_loader(batch_size):
    train_loader = torch.utils.data.DataLoader(flowers_train, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(flowers_test, batch_size=batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(flowers_val, batch_size=batch_size, shuffle=True)
    return train_loader, test_loader, val_loader

In [5]:
# Early stopping based on accuracy
class AccuracyEarlyStopper:
    def __init__(self, patience=5, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.max_validation_accuracy = 0

    def early_stop(self, validation_accuracy):
        if validation_accuracy > (self.max_validation_accuracy + self.min_delta):
            self.max_validation_accuracy = validation_accuracy
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False

In [6]:
def train(model, optimizer, dataloader, loss_fn=loss_fn):
    running_loss_value = 0
    for images, labels in dataloader:
        optimizer.zero_grad()
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        running_loss_value += loss.item()
        loss.backward()
        optimizer.step()
    return running_loss_value / len(dataloader)

def test_eval(model, dataloader, loss_fn=loss_fn):
    running_loss_value = 0
    running_acc_value = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss = loss_fn(outputs, labels)
            acc = acc_function(outputs, labels)
            running_loss_value += loss.item()
            running_acc_value += acc.item()
    running_acc_value /= len(dataloader)
    running_loss_value /= len(dataloader)
    return running_acc_value*100, running_loss_value

def train_eval_loop(model, train_dataloader, val_dataloader, no_epochs=10):
    es = AccuracyEarlyStopper()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    train_loss_arr, train_acc_arr, eval_loss_arr, eval_acc_arr = [], [], [], []
    for i in range(no_epochs):
        train_loss = train(model, optimizer, train_dataloader)
        eval_acc, eval_loss = test_eval(model, val_dataloader)
        print(f'Epoch {i+1} Train Loss: {train_loss:>8f}, Eval Accuracy: {eval_acc:>0.2f}%, Eval Loss: {eval_loss:>8f}')
        train_loss_arr.append(train_loss)
        eval_loss_arr.append(eval_loss)
        eval_acc_arr.append(eval_acc)
#         if es.early_stop(eval_acc):
#             print('Early stopping activated')
#             break
    return train_loss_arr, train_acc_arr, eval_loss_arr, eval_acc_arr

In [7]:
def create_model():
    model = models.vgg16_bn(weights=models.VGG16_BN_Weights.DEFAULT)
    new_classifier_head = nn.Sequential(
        nn.Linear(25088, 4096),
        nn.ReLU(inplace=True),
        nn.Dropout(0.5),
        nn.Linear(4096, 4096),
        nn.ReLU(inplace=True),
        nn.Dropout(0.5),
        nn.Linear(4096, 102)
    )
    
    for param in model.features.parameters():
        param.requires_grad = False
        
    model.classifier = new_classifier_head

    return model.to(device=device)

In [8]:
model = create_model()
if os.path.isfile(PATH):
    model.load_state_dict(torch.load(PATH))
    print("Model loaded")
acc_function = acc_function.to(device)
train_data_loader, test_data_loader, val_data_loader = get_data_loader(batch_size)

In [9]:
train_acc, train_loss, eval_acc, eval_loss = train_eval_loop(model, train_data_loader, val_data_loader, no_epochs=no_epochs)

Epoch 1 Train Loss: 4.894980, Eval Accuracy: 3.03%, Eval Loss: 4.540098
Epoch 2 Train Loss: 4.469771, Eval Accuracy: 11.57%, Eval Loss: 4.046946
Epoch 3 Train Loss: 3.883506, Eval Accuracy: 23.13%, Eval Loss: 3.229333


In [None]:
torch.save(model.state_dict(), PATH)