In [None]:
import torch
import torch.nn as nn
import torchvision
from torchvision import models
from torchvision import datasets
from torchvision import transforms
from torchmetrics.classification import MulticlassAccuracy
from torchinfo import summary
import os.path

In [None]:
# --- Training hyperparameters -------------------------------------------
no_epochs = 50
learning_rate = 1e-4
batch_size = 128

# --- Compute device: first CUDA GPU when one is available, otherwise CPU --
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

# --- Metric and loss shared by the training / evaluation helpers ---------
# Micro-averaged accuracy over the 102 classes, kept on the same device as the model.
acc_function = MulticlassAccuracy(num_classes=102, average='micro')
acc_function = acc_function.to(device)
loss_fn = nn.CrossEntropyLoss()

# --- Checkpoint location --------------------------------------------------
PATH = './model.pth'

In [None]:
# Training-time augmentation: random rotation, random resized crop and
# horizontal flip, followed by tensor conversion and per-channel normalisation.
train_transforms = transforms.Compose([
    transforms.RandomRotation(degrees=30),
    transforms.RandomResizedCrop(size=224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Validation / test images are not augmented; they go through the preprocessing
# preset that ships with the VGG16-BN ImageNet weights.
default_transforms = transforms.Compose(
    [models.VGG16_BN_Weights.IMAGENET1K_V1.transforms()]
)

# The three Flowers102 splits, downloaded into ./data on first use.
flowers_train = datasets.Flowers102(
    root='./data', split='train', transform=train_transforms, download=True
)
flowers_test = datasets.Flowers102(
    root='./data', split='test', transform=default_transforms, download=True
)
flowers_val = datasets.Flowers102(
    root='./data', split='val', transform=default_transforms, download=True
)


In [None]:
def get_data_loader(batch_size):
    """Build the DataLoaders for the three Flowers102 splits.

    Args:
        batch_size: number of samples per batch for every loader.

    Returns:
        Tuple ``(train_loader, test_loader, val_loader)`` -- note the order:
        the test loader comes before the validation loader.
    """
    # Only the training data is shuffled. Evaluation loaders keep a fixed order
    # so that repeated evaluations of the same model see identical batches and
    # therefore report identical metrics.
    train_loader = torch.utils.data.DataLoader(flowers_train, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(flowers_test, batch_size=batch_size, shuffle=False)
    val_loader = torch.utils.data.DataLoader(flowers_val, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader, val_loader

In [None]:
# Early stopping based on accuracy
class AccuracyEarlyStopper:
    """Signal when validation accuracy has stopped improving.

    An epoch counts as an improvement only when its accuracy exceeds the best
    accuracy seen so far by more than ``min_delta``. After ``patience``
    consecutive epochs without such an improvement, ``early_stop`` returns True.
    """

    def __init__(self, patience=5, min_delta=0.01):
        self.patience, self.min_delta = patience, min_delta
        self.counter = 0                    # consecutive non-improving epochs
        self.max_validation_accuracy = 0    # best accuracy observed so far

    def early_stop(self, validation_accuracy):
        """Record one epoch's accuracy; return True when training should stop."""
        threshold = self.max_validation_accuracy + self.min_delta
        if validation_accuracy > threshold:
            # New best: remember it and restart the patience countdown.
            self.max_validation_accuracy = validation_accuracy
            self.counter = 0
            return False

        self.counter += 1
        return self.counter >= self.patience

In [None]:
def train(model, optimizer, dataloader, loss_fn=loss_fn):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: network to optimise; its parameters are updated in place.
        optimizer: optimizer stepping the model's parameters.
        dataloader: iterable yielding ``(images, labels)`` batches.
        loss_fn: callable ``loss_fn(outputs, labels)`` returning a scalar loss.

    Returns:
        Mean of the per-batch loss values over the pass.
    """
    # Make sure layers such as BatchNorm and Dropout are in training mode,
    # regardless of the mode an earlier evaluation may have left the model in.
    model.train()

    running_loss_value = 0
    for images, labels in dataloader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss_value += loss.item()
    return running_loss_value / len(dataloader)

def test_eval(model, dataloader, loss_fn=loss_fn):
    """Evaluate ``model`` on ``dataloader`` without computing gradients.

    The model is switched to evaluation mode for the duration of the call and
    returned to whatever mode it was in before, so callers that keep training
    afterwards are unaffected.

    Args:
        model: network to evaluate.
        dataloader: iterable yielding ``(images, labels)`` batches.
        loss_fn: callable ``loss_fn(outputs, labels)`` returning the batch loss.
            The aggregation below assumes it is a per-batch mean (the default
            reduction of ``nn.CrossEntropyLoss``).

    Returns:
        Tuple ``(accuracy_percent, mean_loss)``.
    """
    was_training = model.training
    model.eval()  # use BatchNorm running statistics, disable Dropout

    total_loss = 0.0
    total_acc = 0.0
    total_samples = 0
    try:
        with torch.no_grad():
            for images, labels in dataloader:
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                loss = loss_fn(outputs, labels)
                acc = acc_function(outputs, labels)

                # Weight each batch by its size so a short final batch does not
                # count as much as a full one.
                n_samples = labels.size(0)
                total_loss += loss.item() * n_samples
                total_acc += acc.item() * n_samples
                total_samples += n_samples
    finally:
        model.train(was_training)

    running_acc_value = total_acc / total_samples
    running_loss_value = total_loss / total_samples
    return running_acc_value*100, running_loss_value

def train_eval_test(model, train_dataloader, val_dataloader, test_dataloader, no_epochs=10):
    """Train with early stopping on validation accuracy, then evaluate on the test set.

    An Adam optimizer is created internally with the module-level
    ``learning_rate``. After every epoch the model is scored on the training
    and validation loaders; training stops early when ``AccuracyEarlyStopper``
    reports that validation accuracy has stalled.

    Args:
        model: network to train.
        train_dataloader: loader used for optimisation (and for train accuracy).
        val_dataloader: loader used for per-epoch validation / early stopping.
        test_dataloader: loader used once, after training, for the final score.
        no_epochs: maximum number of epochs.

    Returns:
        Tuple ``(train_loss_arr, train_acc_arr, eval_loss_arr, eval_acc_arr,
        test_acc, test_loss)`` where the four lists hold one entry per
        completed epoch.
    """
    es = AccuracyEarlyStopper()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    train_loss_arr, train_acc_arr, eval_loss_arr, eval_acc_arr = [], [], [], []
    for i in range(no_epochs):
        train_loss = train(model, optimizer, train_dataloader)
        # Previously train_acc_arr was returned but never filled. Score the
        # training loader too so the list lines up with the other histories.
        # (The loader's own transforms apply, so this is accuracy on whatever
        # samples the training loader yields.)
        train_acc, _ = test_eval(model, train_dataloader)
        eval_acc, eval_loss = test_eval(model, val_dataloader)
        print(f'Epoch {i+1} Train Loss: {train_loss:>8f}, Eval Accuracy: {eval_acc:>0.2f}%, Eval Loss: {eval_loss:>8f}')
        train_loss_arr.append(train_loss)
        train_acc_arr.append(train_acc)
        eval_loss_arr.append(eval_loss)
        eval_acc_arr.append(eval_acc)
        if es.early_stop(eval_acc):
            print('Early stopping activated')
            break
    test_acc, test_loss = test_eval(model, test_dataloader)
    print(f"Test Accuracy: {test_acc}, Test Loss: {test_loss}")
    return train_loss_arr, train_acc_arr, eval_loss_arr, eval_acc_arr, test_acc, test_loss

In [None]:
class ModifiedResNet(nn.Module):
    """Pre-trained ResNet-34 whose classification head is resized to ``num_classes``.

    NOTE(review): the original cell was cut off mid-statement
    (``new_last_layer = nn.``) and paired ``resnet34`` with ``ResNet18_Weights``.
    The head replacement and ``forward`` below are the presumed intent --
    confirm before relying on this class.
    """

    def __init__(self, num_classes=102):
        super(ModifiedResNet, self).__init__()
        # The weights enum must match the architecture being built.
        self.resnet = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)
        # Swap the final fully connected layer for one with ``num_classes`` outputs.
        new_last_layer = nn.Linear(self.resnet.fc.in_features, num_classes)
        self.resnet.fc = new_last_layer

    def forward(self, x):
        """Return the wrapped network's output for the batch ``x``."""
        return self.resnet(x)

In [None]:


# Define the Deformable Conv2d module
class DeformConv2d(nn.Module):
    """Deformable 2-D convolution that predicts its own sampling offsets.

    A regular convolution (``offset_conv``) maps the input to an offset tensor
    with ``2 * kernel_size * kernel_size`` channels, which is then passed
    together with the input to ``torchvision.ops.DeformConv2d``.
    The offset branch is initialised to all zeros.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: side length of the (square) kernel, as an int.
        stride, padding, dilation: forwarded to both convolutions.
        groups: forwarded to the deformable convolution only.
        bias: whether the deformable convolution has a bias term.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=False):
        super(DeformConv2d, self).__init__()
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups

        # Offset convolution: same geometry as the main convolution so the
        # offset map has the same spatial size as the output.
        self.offset_conv = nn.Conv2d(in_channels,
                                     2 * kernel_size * kernel_size,
                                     kernel_size=kernel_size,
                                     stride=stride,
                                     padding=padding,
                                     dilation=dilation)

        # Deformable convolution. Referenced through the already-imported
        # ``torchvision`` package; a bare ``ops`` name is not defined in this file.
        self.deform_conv = torchvision.ops.DeformConv2d(in_channels,
                                                        out_channels,
                                                        kernel_size=kernel_size,
                                                        stride=stride,
                                                        padding=padding,
                                                        dilation=dilation,
                                                        groups=groups,
                                                        bias=bias)

        self.init_offset()

    def init_offset(self):
        """Zero the offset branch so every predicted offset starts at 0."""
        torch.nn.init.constant_(self.offset_conv.weight, 0)
        torch.nn.init.constant_(self.offset_conv.bias, 0)

    def forward(self, x):
        """Predict offsets from ``x`` and apply the deformable convolution."""
        offset = self.offset_conv(x)
        return self.deform_conv(x, offset)

# Function to replace standard convolutions in ResNet with deformable convolutions
def replace_conv_with_deform_conv(resnet_layer):
    """Swap the ``conv2`` of every residual block under ``resnet_layer`` in place.

    ``resnet_layer`` may be a single stage (e.g. ``resnet.layer4``), a whole
    ResNet, or a single block: every sub-module (including ``resnet_layer``
    itself) whose ``conv2`` attribute is an ``nn.Conv2d`` is updated. The
    previous implementation only looked for ``nn.Sequential`` children, so
    passing a stage such as ``layer4`` -- whose children are the blocks
    themselves -- replaced nothing.

    The existing convolution's weights (and bias, if any) are copied into the
    new deformable convolution so pre-trained parameters are not thrown away.

    Args:
        resnet_layer: ``nn.Module`` to modify in place.
    """
    # Snapshot the matching blocks first: the module tree is mutated below.
    blocks = [
        module for module in resnet_layer.modules()
        if isinstance(getattr(module, 'conv2', None), nn.Conv2d)
    ]
    for basic_block in blocks:
        old_conv = basic_block.conv2
        new_conv = DeformConv2d(old_conv.in_channels,
                                old_conv.out_channels,
                                kernel_size=old_conv.kernel_size[0],
                                stride=old_conv.stride[0],
                                padding=old_conv.padding[0],
                                dilation=old_conv.dilation[0],
                                groups=old_conv.groups,
                                bias=(old_conv.bias is not None))
        new_conv = new_conv.to(device=old_conv.weight.device, dtype=old_conv.weight.dtype)

        # Carry the existing kernel over to the deformable convolution, whose
        # weight has the same (out, in // groups, kh, kw) layout.
        with torch.no_grad():
            new_conv.deform_conv.weight.copy_(old_conv.weight)
            if old_conv.bias is not None:
                new_conv.deform_conv.bias.copy_(old_conv.bias)

        setattr(basic_block, 'conv2', new_conv)

# Custom ResNet34 with deformable convolutions in layer4
class CustomResNet34(nn.Module):
    """Pre-trained ResNet-34 adapted for ``num_classes``-way classification.

    The backbone's ``layer4`` is handed to ``replace_conv_with_deform_conv`` and
    its fully connected head is replaced by a fresh ``nn.Linear`` with
    ``num_classes`` outputs.
    """

    def __init__(self, num_classes=102):
        super().__init__()
        # Start from the default pre-trained ResNet-34 weights.
        backbone = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)
        replace_conv_with_deform_conv(backbone.layer4)
        # New classification head sized for this task.
        head_inputs = backbone.fc.in_features
        backbone.fc = nn.Linear(head_inputs, num_classes)
        self.resnet34 = backbone

    def forward(self, x):
        """Run ``x`` through the wrapped backbone and return its output."""
        logits = self.resnet34(x)
        return logits

# Build the data loaders used below. They were referenced but never created;
# note that get_data_loader returns them in (train, test, val) order.
train_data_loader, test_data_loader, val_data_loader = get_data_loader(batch_size)

# Instantiate the custom model
model = CustomResNet34(num_classes=102).to(device)

# Define the optimizer and loss function
# NOTE: train_eval_test creates its own Adam optimizer and uses the module-level
# loss_fn, so these two objects are not consumed by the call below.
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
loss_function = nn.CrossEntropyLoss()

# Train and evaluate the model
no_epochs = 10  # Define the number of epochs (overrides the value set in the config cell)
ResNet_train_loss, ResNet_train_acc, ResNet_eval_loss, ResNet_eval_acc, ResNet_test_acc, ResNet_test_loss = train_eval_test(
    model,
    train_data_loader,
    val_data_loader,
    test_data_loader,
    no_epochs=no_epochs
)

# train_eval_test prints the per-epoch training progress and the final test accuracy and loss