# In [None]:
import torch
import torchvision.models as models
from torchvision import datasets, transforms
from torchvision.ops import deform_conv2d
from torch.utils.data import DataLoader
import torch.nn as nn

# Prefer the first CUDA device when one is present; otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Per-channel mean/std used by torchvision's ImageNet-pretrained models.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: random rotation, random-resized 224x224 crop and random
# horizontal flip for augmentation, then tensor conversion and normalisation.
train_transforms = transforms.Compose(
    [
        transforms.RandomRotation(30),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)

# Deterministic evaluation pipeline for the test/validation splits:
# resize to 256, centre-crop 224x224, convert to a tensor and normalise.
default_transforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)


def _load_flowers102(split, transform):
    """Return one Flowers-102 split stored under ./data (downloaded on demand)."""
    return datasets.Flowers102(root='./data', split=split, download=True, transform=transform)


# Only the training split receives the augmenting transforms.
flowers_train = _load_flowers102('train', train_transforms)
flowers_test = _load_flowers102('test', default_transforms)
flowers_val = _load_flowers102('val', default_transforms)

def get_data_loader(batch_size):
    """Build DataLoaders for the Flowers-102 train, test and validation splits.

    Only the training loader shuffles. Evaluation metrics do not depend on
    sample order, so the test and validation loaders iterate in a fixed order,
    which keeps evaluation passes reproducible.

    Args:
        batch_size: number of samples per batch, shared by all three loaders.

    Returns:
        ``(train_loader, test_loader, val_loader)``. Note the order: test comes
        before validation.
    """
    train_loader = DataLoader(flowers_train, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(flowers_test, batch_size=batch_size, shuffle=False)
    val_loader = DataLoader(flowers_val, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader, val_loader

class _DeformableConv2d(nn.Module):
    """Deformable drop-in replacement for a plain ``nn.Conv2d``.

    A small regular convolution predicts sampling offsets for every output
    location. ``deform_conv2d`` then convolves the input at those shifted
    positions, using a kernel initialised from the wrapped convolution.

    The offset predictor starts at zero. At construction time the module
    therefore computes the same function as the convolution it replaces, so
    pretrained behaviour is preserved until the offsets are trained.
    """

    def __init__(self, conv):
        super().__init__()
        kernel_h, kernel_w = conv.kernel_size
        # Assumes numeric (not string) padding, as in torchvision's ResNets.
        self.stride = conv.stride
        self.padding = conv.padding
        self.dilation = conv.dilation
        # One (dy, dx) pair per kernel tap. Reusing the wrapped conv's geometry
        # makes the offset map's spatial size equal the conv output size, which
        # is what deform_conv2d expects.
        self.offset = nn.Conv2d(
            conv.in_channels,
            2 * kernel_h * kernel_w,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            dilation=conv.dilation,
        )
        # Zero offsets mean sampling on the regular grid, i.e. exactly `conv`.
        nn.init.zeros_(self.offset.weight)
        nn.init.zeros_(self.offset.bias)
        # New Parameters initialised from the pretrained kernel.
        # deform_conv2d infers the group count from the weight's shape.
        self.weight = nn.Parameter(conv.weight.detach().clone())
        if conv.bias is None:
            self.register_parameter("bias", None)
        else:
            self.bias = nn.Parameter(conv.bias.detach().clone())

    def forward(self, x):
        return deform_conv2d(
            x,
            self.offset(x),
            self.weight,
            self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
        )


def insert_deformable_convs(resnet_model):
    """Replace ``conv1`` of every block in ``layer1``..``layer4`` with a deformable conv.

    Each replacement keeps the original convolution's kernel weights, stride,
    padding and dilation, and starts with zero offsets. The model's outputs are
    therefore unchanged until the offset predictors are trained. All
    parameters of the replacement modules are newly created, so they have
    ``requires_grad=True``.

    Args:
        resnet_model: a ResNet-style model exposing ``layer1``..``layer4``,
            whose blocks each have a ``conv1`` convolution.

    Returns:
        The same model object, modified in place.
    """
    for layer_name in ("layer1", "layer2", "layer3", "layer4"):
        for block in getattr(resnet_model, layer_name):
            block.conv1 = _DeformableConv2d(block.conv1)
    return resnet_model

class AccuracyEarlyStopper:
    """Signal when validation accuracy has stopped improving.

    A new accuracy counts as an improvement only if it beats the best value
    seen so far by more than ``min_delta``. After ``patience`` consecutive
    non-improving calls, ``early_stop`` returns True.
    """

    def __init__(self, patience=3, min_delta=0.5):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0                   # consecutive non-improving epochs
        self.max_validation_accuracy = 0   # best accuracy recorded so far

    def early_stop(self, validation_accuracy):
        """Record ``validation_accuracy``; return True when training should stop."""
        improved = validation_accuracy > self.max_validation_accuracy + self.min_delta
        if improved:
            self.max_validation_accuracy = validation_accuracy
            self.counter = 0
            return False
        self.counter += 1
        return self.counter >= self.patience

def train(model, optimizer, dataloader, loss_fn=None):
    """Run one training epoch and return the mean per-sample training loss.

    Args:
        model: network to optimise; switched to training mode here.
        optimizer: optimiser over (a subset of) ``model``'s parameters.
        dataloader: iterable of ``(images, labels)`` batches.
        loss_fn: ``loss_fn(outputs, labels)``; defaults to
            ``nn.CrossEntropyLoss()``. Assumed to return the batch mean, which
            is re-weighted by batch size when averaging.

    Returns:
        Average loss over all samples seen in the epoch.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    # Resolve the default at call time. A `loss_fn=loss_fn` default would be
    # evaluated at definition time, before any global loss exists.
    if loss_fn is None:
        loss_fn = nn.CrossEntropyLoss()
    model.train()
    total_loss = 0.0
    total_seen = 0
    for images, labels in dataloader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        total_seen += batch_size
    if total_seen == 0:
        raise ValueError("dataloader yielded no samples")
    return total_loss / total_seen

def test_eval(model, dataloader, loss_fn=None):
    """Evaluate ``model`` on ``dataloader`` without updating it.

    The model is switched to eval mode for the pass, so BatchNorm uses its
    running statistics instead of batch statistics and does not update them.
    The model's previous train/eval mode is restored afterwards, even if an
    exception is raised.

    Args:
        model: network to evaluate.
        dataloader: iterable of ``(images, labels)`` batches.
        loss_fn: ``loss_fn(outputs, labels)``; defaults to
            ``nn.CrossEntropyLoss()``. Assumed to return the batch mean, which
            is re-weighted by batch size when averaging.

    Returns:
        ``(accuracy_percent, mean_loss)``. Accuracy is top-1. Both values are
        averaged over samples rather than batches, so a short final batch is
        not over-weighted.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    if loss_fn is None:
        loss_fn = nn.CrossEntropyLoss()
    was_training = model.training
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_seen = 0
    try:
        with torch.no_grad():
            for images, labels in dataloader:
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                batch_size = labels.size(0)
                total_loss += loss_fn(outputs, labels).item() * batch_size
                # Top-1 accuracy is computed inline so no external helper is needed.
                total_correct += (outputs.argmax(dim=1) == labels).sum().item()
                total_seen += batch_size
    finally:
        model.train(was_training)
    if total_seen == 0:
        raise ValueError("dataloader yielded no samples")
    return 100.0 * total_correct / total_seen, total_loss / total_seen

def train_eval_test(model, train_dataloader, val_dataloader, test_dataloader, no_epochs=10,
                    optimizer=None, learning_rate=0.001, early_stopping=False):
    """Train for up to ``no_epochs`` epochs, validating each epoch, then test once.

    Args:
        model: network to train.
        train_dataloader, val_dataloader, test_dataloader: ``(images, labels)`` loaders.
        no_epochs: maximum number of training epochs.
        optimizer: optimiser to use. When None, an Adam optimiser is built over
            the parameters with ``requires_grad=True`` only.
        learning_rate: learning rate for the default Adam optimiser. Ignored
            when ``optimizer`` is given.
        early_stopping: when True, stop once ``AccuracyEarlyStopper`` (default
            settings) reports no validation-accuracy improvement.

    Returns:
        ``(train_loss_arr, train_acc_arr, eval_loss_arr, eval_acc_arr, test_acc, test_loss)``.
        ``train_acc_arr`` is always empty because ``train`` reports only loss;
        it is kept for interface compatibility.
    """
    if optimizer is None:
        # Frozen parameters never receive gradients, so there is no point
        # handing them to the optimiser.
        trainable = [p for p in model.parameters() if p.requires_grad]
        optimizer = torch.optim.Adam(trainable, lr=learning_rate)
    stopper = AccuracyEarlyStopper() if early_stopping else None

    train_loss_arr, train_acc_arr, eval_loss_arr, eval_acc_arr = [], [], [], []
    for epoch in range(1, no_epochs + 1):
        train_loss = train(model, optimizer, train_dataloader)
        eval_acc, eval_loss = test_eval(model, val_dataloader)
        print(f'Epoch {epoch} Train Loss: {train_loss:>8f}, Eval Accuracy: {eval_acc:>0.2f}%, Eval Loss: {eval_loss:>8f}')
        train_loss_arr.append(train_loss)
        eval_loss_arr.append(eval_loss)
        eval_acc_arr.append(eval_acc)
        if stopper is not None and stopper.early_stop(eval_acc):
            print('Early stopping activated')
            break
    test_acc, test_loss = test_eval(model, test_dataloader)
    print(f"Test Accuracy: {test_acc}, Test Loss: {test_loss}")
    return train_loss_arr, train_acc_arr, eval_loss_arr, eval_acc_arr, test_acc, test_loss

def create_deformable_resnet_model(size=18, num_classes=102):
    """Build a pretrained ResNet whose blocks use deformable ``conv1`` layers.

    Every parameter loaded from the pretrained checkpoint is frozen. Only
    modules created afterwards remain trainable: the new classifier head and
    whatever ``insert_deformable_convs`` adds.

    NOTE: BatchNorm running statistics are buffers, not parameters. They still
    update whenever the model is in training mode.

    Args:
        size: ResNet depth; selects ``torchvision.models.resnet{size}``
            (e.g. 18, 34, 50).
        num_classes: number of outputs of the new classifier head
            (102 for Flowers-102).

    Returns:
        The model, moved to ``device``.

    Raises:
        ValueError: if ``torchvision.models`` has no ``resnet{size}`` builder.
    """
    builder = getattr(models, f"resnet{size}", None)
    if builder is None:
        raise ValueError(f"Unsupported ResNet size: {size}")
    model = builder(pretrained=True)

    # Freeze the pretrained backbone *before* adding new layers. Parameters
    # created afterwards default to requires_grad=True, so the new head and the
    # deformable layers stay trainable without depending on parameter names.
    for param in model.parameters():
        param.requires_grad = False

    model.fc = nn.Linear(model.fc.in_features, num_classes)
    model = insert_deformable_convs(model)
    return model.to(device)

# Main code
batch_size = 64
no_epochs = 10
learning_rate = 0.001
train_loader, test_loader, val_loader = get_data_loader(batch_size)

# Create the deformable ResNet model (pretrained backbone frozen inside the factory).
deformable_resnet_model = create_deformable_resnet_model()

# train_eval_test constructs its own optimizer, so none is created here.
train_loss_arr, train_acc_arr, eval_loss_arr, eval_acc_arr, test_acc, test_loss = train_eval_test(
    deformable_resnet_model,
    train_loader,
    val_loader,
    test_loader,
    no_epochs=no_epochs,
)

# At this point, you can print out the results or save the model
print(f"Test accuracy: {test_acc}, Test loss: {test_loss}")