In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import time
import torchdiffeq

class TaggedModule(nn.Module):
    def __init__(self, tag):
        super().__init__()
        self.tag = tag

class ResNet50(TaggedModule):
    def __init__(self):
        super().__init__(tag="non_ode")
        self.resnet = models.resnet50(pretrained=True)
        self.resnet = nn.Sequential(*list(self.resnet.children())[:-1], nn.Dropout(0.5))  # Add Dropout

    def forward(self, x):
        x = self.resnet(x)
        return x.view(x.size(0), -1)

class ODEFunc(TaggedModule):
    def __init__(self):
        super().__init__(tag="ode")
        self.net = nn.Sequential(
            nn.Linear(2048, 2048),
            nn.Tanh(),
            nn.Linear(2048, 2048),
            nn.Tanh(),
            nn.Linear(2048, 2048)
        )

    def forward(self, t, y):
        return self.net(y)

class AugmentedODEFunc(TaggedModule):
    def __init__(self, func):
        super().__init__(tag="ode")
        self.func = func

    def forward(self, t, state):
        y, adj_y = state[:, :2048], state[:, 2048:]
        with torch.set_grad_enabled(True):
            y.requires_grad_(True)
            dy = self.func(t, y)
            adj_dy = torch.autograd.grad(dy, y, adj_y, retain_graph=True)[0]
        return torch.cat([dy, adj_dy], dim=1)

class ODEModel(nn.Module):
    def __init__(self, num_classes):
        super(ODEModel, self).__init__()
        self.resnet = ResNet50()
        self.ode_func = ODEFunc()
        self.aug_ode_func = AugmentedODEFunc(self.ode_func)
        self.classifier = nn.Linear(2048, num_classes)
        self.classifier.tag = "non_ode"

    def forward(self, x):
        x = self.resnet(x)
        if self.training:
            aug_state = torch.cat([x, torch.zeros_like(x)], dim=1)
            t = torch.tensor([0., 1.]).to(x.device)
            aug_state = torchdiffeq.odeint(self.aug_ode_func, aug_state, t, method='dopri5')
            y1, adj_y1 = aug_state[-1, :, :2048], aug_state[-1, :, 2048:]
            output = self.classifier(y1)
            return output, y1, adj_y1
        else:
            t = torch.tensor([0., 1.]).to(x.device)
            y1 = torchdiffeq.odeint(self.ode_func, x, t, method='dopri5')
            output = self.classifier(y1[-1])
            return output

def get_parameters_by_tag(model, tag):
    for module in model.modules():
        if hasattr(module, 'tag') and module.tag == tag:
            yield from module.parameters()

def compute_loss_and_grad(model, data, target):
    output, y1, adj_y1 = model(data)
    loss = F.cross_entropy(output, target)
    grad_output = torch.autograd.grad(loss, y1, create_graph=True)[0]
    return loss, grad_output, adj_y1

def train(epoch, model, device, train_loader, optimizer):
    model.train()
    start_time = time.time()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad(set_to_none=True)

        loss, grad_output, adj_y1 = compute_loss_and_grad(model, data, target)

        ode_params = list(get_parameters_by_tag(model, "ode"))
        non_ode_params = list(get_parameters_by_tag(model, "non_ode"))

        non_ode_grads = torch.autograd.grad(loss, non_ode_params, allow_unused=True, create_graph=False)

        with torch.no_grad():
            for param, grad in zip(non_ode_params, non_ode_grads):
                if param.requires_grad and grad is not None:
                    param.grad = grad

            adj_grad = torch.einsum('bi,bj->ij', adj_y1, grad_output)
            adj_grad_flat = adj_grad.flatten()

            start = 0
            for param in ode_params:
                if param.requires_grad:
                    num_param = param.numel()
                    if start + num_param <= adj_grad_flat.numel():
                        param_grad = adj_grad_flat[start:start+num_param].view(param.shape)
                        param.grad = param_grad
                    start += num_param

        optimizer.step()

        if batch_idx % 10 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')
    end_time = time.time()
    print(f'Epoch {epoch} training time: {end_time - start_time:.2f} seconds')

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} '
          f'({accuracy:.2f}%)\n')
    return test_loss, accuracy

def load_data():
    transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    train_dataset = ImageFolder('../data/Skin_cancer_ISIC/Train', transform=transform)
    test_dataset = ImageFolder('../data/Skin_cancer_ISIC/Test', transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=2)
    return train_loader, test_loader, len(train_dataset.classes)

def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_loader, test_loader, num_classes = load_data()
    model = ODEModel(num_classes).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)

    for epoch in range(1, 51):
        train(epoch, model, device, train_loader, optimizer)
        test_loss, accuracy = test(model, device, test_loader)
        scheduler.step(test_loss)

if __name__ == '__main__':
    main()





Epoch 1 training time: 601.51 seconds

Test set: Average loss: 1.8563, Accuracy: 39/118 (33.05%)

Epoch 2 training time: 754.47 seconds

Test set: Average loss: 2.0576, Accuracy: 43/118 (36.44%)

Epoch 3 training time: 761.41 seconds

Test set: Average loss: 1.6107, Accuracy: 59/118 (50.00%)

Epoch 4 training time: 754.09 seconds

Test set: Average loss: 1.7179, Accuracy: 57/118 (48.31%)

Epoch 5 training time: 759.09 seconds

Test set: Average loss: 1.5113, Accuracy: 65/118 (55.08%)

Epoch 6 training time: 757.07 seconds

Test set: Average loss: 1.7649, Accuracy: 57/118 (48.31%)

Epoch 7 training time: 770.21 seconds

Test set: Average loss: 1.4517, Accuracy: 65/118 (55.08%)

Epoch 8 training time: 773.86 seconds

Test set: Average loss: 1.4640, Accuracy: 65/118 (55.08%)

Epoch 9 training time: 768.90 seconds

Test set: Average loss: 1.7129, Accuracy: 61/118 (51.69%)

Epoch 10 training time: 761.42 seconds

Test set: Average loss: 1.7351, Accuracy: 55/118 (46.61%)

Epoch 11 training t