In [None]:
import torch
import torch.nn as nn
import torchvision
from torchvision import models
from torchvision import datasets
from torchvision import transforms
from torchmetrics.classification import MulticlassAccuracy
from torchinfo import summary

In [None]:
# Training hyperparameters.
no_epochs = 100
learning_rate = 0.001
batch_size = 256
num_classes = 102  # number of categories in Flowers102
seed = 42

# Seed torch's RNGs (CPU and CUDA) so the new head's initialization, dropout
# masks and DataLoader shuffling are reproducible under Restart & Run All.
torch.manual_seed(seed)

# 'micro' averaging counts correct predictions over all samples (global accuracy).
acc_function = MulticlassAccuracy(num_classes=num_classes, average='micro')
# CrossEntropyLoss applies log-softmax internally, so it expects raw logits.
loss_fn = nn.CrossEntropyLoss()

In [None]:
# Preprocessing bundled with the VGG16-BN ImageNet weights, so every image is
# prepared the way the pretrained backbone expects.
transformation = transforms.Compose([models.VGG16_BN_Weights.IMAGENET1K_V1.transforms()])

# Download (if needed) and wrap each Flowers102 split, in train/test/val order.
flowers_train, flowers_test, flowers_val = (
    datasets.Flowers102(root='./data', split=split_name, download=True, transform=transformation)
    for split_name in ('train', 'test', 'val')
)

In [None]:
def get_data_loader(batch_size, num_workers=0):
    """Build DataLoaders for the Flowers102 train/test/val splits.

    Only the training loader is shuffled. Evaluation passes do not need a random
    order, and a fixed order keeps batch composition (and any per-batch-averaged
    metric) reproducible between runs.

    Args:
        batch_size: Number of samples per batch for all three loaders.
        num_workers: Worker processes for each DataLoader. 0 loads in the main
            process, which is the DataLoader default.

    Returns:
        Tuple ``(train_loader, test_loader, val_loader)``.
    """
    def make_loader(dataset, shuffle):
        # Shared construction so the three loaders differ only in dataset/shuffle.
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
        )

    train_loader = make_loader(flowers_train, shuffle=True)
    test_loader = make_loader(flowers_test, shuffle=False)
    val_loader = make_loader(flowers_val, shuffle=False)
    return train_loader, test_loader, val_loader

In [None]:
def train(model, optimizer, dataloader, loss_fn=loss_fn):
    """Run one training epoch over ``dataloader``.

    Args:
        model: Network to optimize. Batches are moved to the device of its first
            parameter.
        optimizer: Optimizer holding the parameters to update.
        dataloader: Iterable of ``(images, labels)`` batches.
        loss_fn: Loss called as ``loss_fn(outputs, labels)``. Assumed to use mean
            reduction, like the default ``nn.CrossEntropyLoss()``.

    Returns:
        Mean training loss per sample over the epoch.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    # Enable training-time behaviour (e.g. Dropout) even if an earlier
    # evaluation pass left the model in eval mode.
    model.train()
    device = next(model.parameters()).device
    total_loss = 0.0
    num_samples = 0
    for images, labels in dataloader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        # Weight the batch-mean loss by batch size so a smaller final batch
        # does not skew the epoch average.
        batch_len = labels.size(0)
        total_loss += loss.item() * batch_len
        num_samples += batch_len
    if num_samples == 0:
        raise ValueError("dataloader yielded no samples")
    return total_loss / num_samples

def test_eval(model, dataloader, loss_fn=loss_fn):
    """Evaluate ``model`` on ``dataloader`` without tracking gradients.

    The model is put into eval mode for the pass (Dropout disabled, BatchNorm
    uses its running statistics). Its previous mode is restored afterwards, even
    if an error occurs.

    Args:
        model: Network to evaluate. Batches are moved to the device of its first
            parameter.
        dataloader: Iterable of ``(images, labels)`` batches.
        loss_fn: Loss called as ``loss_fn(outputs, labels)``. Assumed to use mean
            reduction, like the default ``nn.CrossEntropyLoss()``.

    Returns:
        Tuple ``(acc, loss)``:
            - ``acc`` is the accuracy over all samples, computed with the
              module-level ``acc_function`` metric.
            - ``loss`` is the mean loss per sample.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    device = next(model.parameters()).device
    was_training = model.training
    model.eval()
    # Start a fresh accumulation. update()/compute() yields exact accuracy over
    # every sample instead of an average of per-batch accuracies.
    acc_function.reset()
    total_loss = 0.0
    num_samples = 0
    try:
        with torch.no_grad():
            for images, labels in dataloader:
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                # Weight by batch size so the partial last batch counts correctly.
                batch_len = labels.size(0)
                total_loss += loss_fn(outputs, labels).item() * batch_len
                num_samples += batch_len
                acc_function.update(outputs, labels)
    finally:
        model.train(was_training)
    if num_samples == 0:
        raise ValueError("dataloader yielded no samples")
    acc = acc_function.compute().item()
    return acc, total_loss / num_samples

def train_eval_loop(model, train_dataloader, val_dataloader, no_epochs=10, lr=None,
                    track_train_acc=False):
    """Train ``model`` with Adam, evaluating on the validation loader every epoch.

    Args:
        model: Network to train.
        train_dataloader: Batches used for optimization.
        val_dataloader: Batches used for per-epoch evaluation.
        no_epochs: Number of epochs to run.
        lr: Adam learning rate. ``None`` uses the module-level ``learning_rate``,
            read at call time.
        track_train_acc: If True, run an extra no-grad pass over
            ``train_dataloader`` after each epoch to record training accuracy.
            It is off by default because it costs a full extra inference epoch.

    Returns:
        Tuple ``(train_loss_arr, train_acc_arr, eval_loss_arr, eval_acc_arr)``
        with one entry per epoch. ``train_acc_arr`` stays empty unless
        ``track_train_acc`` is True.
    """
    if lr is None:
        lr = learning_rate
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    train_loss_arr, train_acc_arr, eval_loss_arr, eval_acc_arr = [], [], [], []
    for epoch in range(no_epochs):
        print(f'Epoch {epoch+1}')
        train_loss = train(model, optimizer, train_dataloader)
        # test_eval returns (accuracy, loss), in that order.
        eval_acc, eval_loss = test_eval(model, val_dataloader)
        print('Train Loss: {}, Eval Accuracy: {}, Eval Loss: {}'.format(train_loss, eval_acc, eval_loss))
        train_loss_arr.append(train_loss)
        eval_loss_arr.append(eval_loss)
        eval_acc_arr.append(eval_acc)
        if track_train_acc:
            train_acc, _ = test_eval(model, train_dataloader)
            train_acc_arr.append(train_acc)
    return train_loss_arr, train_acc_arr, eval_loss_arr, eval_acc_arr

In [None]:
def create_model(num_classes=102, weights='DEFAULT'):
    """Build a VGG16-BN transfer-learning model with a fresh classification head.

    The convolutional ``features`` backbone keeps its pretrained weights and is
    frozen. Only the newly created classifier head is trainable.

    Args:
        num_classes: Width of the output layer (102 for Flowers102).
        weights: Forwarded to ``torchvision.models.vgg16_bn``.

    Returns:
        The model. It outputs raw logits with no Softmax, because
        ``nn.CrossEntropyLoss`` applies log-softmax itself and argmax-based
        accuracy is unaffected by a monotonic softmax.
    """
    model = models.vgg16_bn(weights=weights)
    # Reuse the original head's input width (25088 for VGG16) instead of a
    # hard-coded constant.
    in_features = model.classifier[0].in_features
    model.classifier = nn.Sequential(
        nn.Linear(in_features, 4096),
        nn.ReLU(),
        nn.Dropout(0.5),
        # No Softmax: CrossEntropyLoss expects logits. Applying softmax first
        # squashes the gradients and slows or stalls training.
        nn.Linear(4096, num_classes),
    )
    # Freeze the pretrained backbone. The new head's parameters already
    # default to requires_grad=True.
    for param in model.features.parameters():
        param.requires_grad = False
    # NOTE(review): freezing parameters does not stop the BatchNorm layers in
    # `features` from updating their running statistics while the model is in
    # train mode. Consider keeping `model.features` in eval mode during
    # training if the pretrained statistics should be preserved.
    return model

In [None]:
# Use the GPU when available. The training/eval helpers and the metric follow
# the device of the model's parameters, so the model must be moved first.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = create_model().to(device)
acc_function = acc_function.to(device)
train_data_loader, test_data_loader, val_data_loader = get_data_loader(batch_size)
# train_eval_loop returns (train_loss, train_acc, eval_loss, eval_acc) histories.
train_loss, train_acc, eval_loss, eval_acc = train_eval_loop(
    model, train_data_loader, val_data_loader, no_epochs=no_epochs
)