In [None]:
import torch
import torch.nn as nn
import torchvision
from torchvision import models
from torchvision import datasets
from torchvision import transforms
from torchmetrics.classification import MulticlassAccuracy
import numpy as np
from collections import OrderedDict

In [None]:
no_epochs = 300
learning_rate = 0.0001
batch_size = 128

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# Micro-averaged multiclass accuracy over the 102 flower classes.
acc_function = MulticlassAccuracy(num_classes=102, average='micro').to(device)
loss_fn = nn.CrossEntropyLoss()

SEED = 42
np.random.seed(SEED)
# Seed torch's global RNG as well: weight init of newly created layers, DropBlock masks and the
# torchvision random augmentations draw from torch's RNG, not numpy's.
torch.manual_seed(SEED)
# Dedicated generator handed to the DataLoaders.
generator = torch.Generator().manual_seed(SEED)

# NOTE(review): not referenced in the cells shown here — confirm whether it is still needed.
PATH = './model.pth'

In [None]:
# Data augmentation, applied to the training split only.
train_transforms = transforms.Compose([
    transforms.RandomRotation(30),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # ImageNet per-channel mean/std, the statistics the pretrained backbone expects.
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Deterministic preprocessing for validation/test, taken from the same ResNet-34 weights the
# model loads, so eval-time resize/crop/normalisation matches the backbone.
default_transforms = models.ResNet34_Weights.DEFAULT.transforms()

flowers_train = datasets.Flowers102(root='./data', split='train', download=True, transform=train_transforms)
flowers_test = datasets.Flowers102(root='./data', split='test', download=True, transform=default_transforms)
flowers_val = datasets.Flowers102(root='./data', split='val', download=True, transform=default_transforms)


In [None]:
def get_data_loader(batch_size):
    """Build DataLoaders for the Flowers102 train/test/val datasets defined above.

    Only the training loader shuffles. The evaluation loaders keep a fixed order, so their
    batch composition (and therefore any per-batch statistics) is the same on every run.

    Args:
        batch_size: number of samples per batch for all three loaders.

    Returns:
        ``(train_loader, test_loader, val_loader)`` in that order.
    """
    train_loader = torch.utils.data.DataLoader(flowers_train, batch_size=batch_size, shuffle=True, generator=generator)
    test_loader = torch.utils.data.DataLoader(flowers_test, batch_size=batch_size, shuffle=False, generator=generator)
    val_loader = torch.utils.data.DataLoader(flowers_val, batch_size=batch_size, shuffle=False, generator=generator)
    return train_loader, test_loader, val_loader

In [None]:
# Early stopping based on accuracy
class AccuracyEarlyStopper:
    """Tell a training loop when validation accuracy has stopped improving.

    A call counts as an improvement only if the new accuracy beats the best seen so far by
    more than ``min_delta``. After ``patience`` consecutive non-improving calls, ``early_stop``
    returns True.
    """

    def __init__(self, patience=5, min_delta=0.01):
        self.patience = patience
        self.min_delta = min_delta
        # Consecutive calls without a sufficient improvement.
        self.counter = 0
        # Best accuracy recorded so far; starts at 0.
        self.max_validation_accuracy = 0

    def early_stop(self, validation_accuracy):
        """Record one validation result and return True when training should stop."""
        threshold = self.max_validation_accuracy + self.min_delta
        if validation_accuracy > threshold:
            # New best: remember it and reset the patience window.
            self.max_validation_accuracy = validation_accuracy
            self.counter = 0
            return False
        self.counter += 1
        return self.counter >= self.patience

In [None]:
def train(model, optimizer, dataloader, loss_fn=loss_fn):
    """Run one training epoch and return the mean per-sample training loss.

    Args:
        model: network to train; it is put into training mode first.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        dataloader: iterable of ``(images, labels)`` batches; both are moved to ``device``.
        loss_fn: criterion called as ``loss_fn(outputs, labels)``. Assumed to return a batch
            mean, as the module-level ``nn.CrossEntropyLoss()`` default does.

    Returns:
        Loss averaged over all samples (batch means weighted by batch size).

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    # Enable DropBlock and batch-statistics BatchNorm even if the model was left in eval mode.
    model.train()
    total_loss = 0.0
    n_samples = 0
    for images, labels in dataloader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        batch_n = labels.size(0)
        # Weight by batch size so a short final batch is not over-counted.
        total_loss += loss.item() * batch_n
        n_samples += batch_n
    if n_samples == 0:
        raise ValueError("dataloader yielded no samples")
    return total_loss / n_samples

def test_eval(model, dataloader, loss_fn=loss_fn):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    The model is switched to eval mode for the pass, so DropBlock is inactive and BatchNorm
    uses (and does not update) its running statistics. Afterwards the model is restored to
    the mode it was in before the call.

    Args:
        model: network to evaluate; called as ``model(images)``.
        dataloader: iterable of ``(images, labels)`` batches; both are moved to ``device``.
        loss_fn: criterion called as ``loss_fn(outputs, labels)``. Assumed to return a batch
            mean, as the module-level ``nn.CrossEntropyLoss()`` default does.

    Returns:
        ``(accuracy_percent, mean_loss)``. Both are averaged per sample over the whole loader,
        and accuracy is scaled to 0-100.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.0
    total_acc = 0.0
    n_samples = 0
    try:
        with torch.no_grad():
            for images, labels in dataloader:
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                batch_n = labels.size(0)
                # Both metrics are batch means; weight by batch size so a short final batch
                # does not count as much as a full one.
                total_loss += loss_fn(outputs, labels).item() * batch_n
                total_acc += acc_function(outputs, labels).item() * batch_n
                n_samples += batch_n
    finally:
        model.train(was_training)
    if n_samples == 0:
        raise ValueError("dataloader yielded no samples")
    return total_acc / n_samples * 100, total_loss / n_samples

def train_eval_test(model, train_dataloader, val_dataloader, test_dataloader, no_epochs=10,
                    checkpoint_path="dropblock_model.zip"):
    """Train ``model`` with Adam, validate after every epoch, then evaluate on the test loader.

    Each epoch runs ``train`` on ``train_dataloader`` and ``test_eval`` on ``val_dataloader``.
    The weights with the best validation accuracy are snapshotted. Training stops early when
    ``AccuracyEarlyStopper`` (default settings) reports no sufficient improvement. The best
    snapshot is saved to ``checkpoint_path`` and loaded back into ``model`` before the test
    evaluation.

    Args:
        model: network to train (modified in place).
        train_dataloader, val_dataloader, test_dataloader: ``(images, labels)`` loaders.
        no_epochs: maximum number of epochs.
        checkpoint_path: file the best ``state_dict`` snapshot is written to.

    Returns:
        ``(train_loss_arr, train_acc_arr, eval_loss_arr, eval_acc_arr, test_acc, test_loss)``.
        ``train_acc_arr`` is always empty because ``train`` only reports loss. It is kept so
        the tuple layout stays the same.
    """
    es = AccuracyEarlyStopper()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    train_loss_arr, train_acc_arr, eval_loss_arr, eval_acc_arr = [], [], [], []
    # -inf guarantees the first epoch produces a snapshot, even at 0% accuracy.
    best_acc = float("-inf")
    best_state = None
    for i in range(no_epochs):
        train_loss = train(model, optimizer, train_dataloader)
        eval_acc, eval_loss = test_eval(model, val_dataloader)
        print(f'Epoch {i+1} Train Loss: {train_loss:>8f}, Eval Accuracy: {eval_acc:>0.2f}%, Eval Loss: {eval_loss:>8f}')
        train_loss_arr.append(train_loss)
        eval_loss_arr.append(eval_loss)
        eval_acc_arr.append(eval_acc)

        if eval_acc > best_acc:
            best_acc = eval_acc
            # state_dict() returns references to the live tensors. Clone them, otherwise the
            # "best" snapshot silently follows every later optimizer step.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

        # Periodic checkpoint of the best weights so far (after epochs 6, 106, 206, ...).
        if i % 100 == 5:
            torch.save(best_state, checkpoint_path)

        # NOTE(review): test_eval reports accuracy in percent, while the stopper's default
        # min_delta (0.01) reads as a fraction, so compare in fraction units — confirm intent.
        if es.early_stop(eval_acc / 100):
            print(f'Early stopping after epoch {i+1}')
            break

    if best_state is not None:
        torch.save(best_state, checkpoint_path)
        # Evaluate the selected (best-validation) weights rather than the last epoch's.
        model.load_state_dict(best_state)
    test_acc, test_loss = test_eval(model, test_dataloader)
    print(f"Test Accuracy: {test_acc}, Test Loss: {test_loss}")
    return train_loss_arr, train_acc_arr, eval_loss_arr, eval_acc_arr, test_acc, test_loss

In [None]:
# get_data_loader returns the loaders in (train, test, val) order.
train_data_loader, test_data_loader, val_data_loader = get_data_loader(batch_size=batch_size)

In [None]:
def add_dropblock(block, p=0.1, block_size=3):
    """Insert a ``DropBlock2d`` after every direct ``nn.ReLU`` child of ``block``.

    Each ReLU child is replaced, in place, by ``nn.Sequential(relu, DropBlock2d(...))``. The
    block keeps its own ``forward``, so any routing it does (e.g. a residual addition) is
    preserved. Rebuilding the block as a plain ``nn.Sequential`` of its children would
    discard that logic.

    If the block's ``forward`` calls the same ReLU module more than once, DropBlock is applied
    at each of those call sites. Presumably this is the case for torchvision's ``BasicBlock``
    (after ``bn1`` and after the residual add) — verify against the installed torchvision.

    Only direct children are inspected. Parameter names of the block are unchanged, because
    the wrappers add no parameters.

    Args:
        block: module to modify in place.
        p: DropBlock drop probability.
        block_size: side length of the square blocks that are dropped.

    Returns:
        ``block`` itself, for use as ``layer[i] = add_dropblock(layer[i])``.
    """
    # Snapshot the children first: setattr below replaces entries while we iterate.
    for name, layer in list(block.named_children()):
        if isinstance(layer, nn.ReLU):
            setattr(block, name, nn.Sequential(
                layer,
                torchvision.ops.DropBlock2d(p=p, block_size=block_size),
            ))
    return block

class ResNet34WithDB(nn.Module):
    """ImageNet-pretrained torchvision ResNet-34 with DropBlock and a new classification head.

    ``add_dropblock`` is applied to every block of ``layer1`` and to every block except the
    first one of ``layer2``, ``layer3`` and ``layer4``. The final ``fc`` layer is replaced
    by a fresh ``nn.Linear`` producing ``num_classes`` logits.

    Args:
        num_classes: number of output classes.
        p: DropBlock drop probability passed to ``add_dropblock``.
        block_size: DropBlock block size passed to ``add_dropblock``.
    """

    # (stage attribute on the ResNet, index of the first block that receives DropBlock)
    _DROPBLOCK_STAGES = (("layer1", 0), ("layer2", 1), ("layer3", 1), ("layer4", 1))

    def __init__(self, num_classes, p=0.1, block_size=3):
        super().__init__()
        self.resnet = torchvision.models.resnet34(weights=models.ResNet34_Weights.DEFAULT)
        for stage_name, first in self._DROPBLOCK_STAGES:
            stage = getattr(self.resnet, stage_name)
            for i in range(first, len(stage)):
                stage[i] = add_dropblock(stage[i], p=p, block_size=block_size)

        # Read the feature width from the existing head instead of hard-coding it.
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)

    def forward(self, x):
        """Return class logits for the batch ``x``."""
        return self.resnet(x)

num_classes = 102  # Flowers102 label count
# Build the DropBlock ResNet-34 and move it to the configured device in one step.
model = ResNet34WithDB(num_classes).to(device)

In [None]:
# train_eval_test returns
# (train_loss_arr, train_acc_arr, eval_loss_arr, eval_acc_arr, test_acc, test_loss);
# unpack in that order so each name holds what it says.
train_loss, train_acc, eval_loss, eval_acc, test_acc, _test_loss = train_eval_test(
    model,
    train_data_loader,
    val_data_loader,
    test_data_loader,
    no_epochs=no_epochs
)