In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.datasets import ImageFolder
import torchdiffeq
import time

# ResNet-50 feature extractor (final FC layer removed), used as the image encoder for skin cancer classification
class ResNet50(nn.Module):
    """ResNet-50 backbone with its final fully-connected layer removed.

    ``forward`` returns one flattened feature vector per image (2048 values
    for torchvision's ResNet-50, the width ``ODEFunc`` and the classifier
    expect).

    Args:
        pretrained: Load the ImageNet weights (``IMAGENET1K_V1``, the set the
            legacy ``pretrained=True`` flag selected). Defaults to ``True``,
            the original behaviour; pass ``False`` for a randomly initialised
            backbone that needs no weight download.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        weights_enum = getattr(models, 'ResNet50_Weights', None)
        if weights_enum is not None:
            # torchvision >= 0.13 deprecates ``pretrained=`` in favour of
            # ``weights=``; use the new API when available to avoid the warning.
            weights = weights_enum.IMAGENET1K_V1 if pretrained else None
            backbone = models.resnet50(weights=weights)
        else:
            # Older torchvision only understands the boolean flag.
            backbone = models.resnet50(pretrained=pretrained)
        # Keep every child except the last one (the FC classifier head).
        self.resnet = nn.Sequential(*list(backbone.children())[:-1])

    def forward(self, x):
        features = self.resnet(x)
        # Collapse everything after the batch dim -> (batch, features).
        return torch.flatten(features, 1)

# ODE function definition for the hidden state dynamics
class ODEFunc(nn.Module):
    """Vector field ``dy/dt = f(y)`` driving the hidden state of the ODE block.

    ``f`` is a width-2048 MLP (Linear -> ReLU -> Linear -> ReLU -> Linear).
    Each ``nn.Linear`` is re-initialised with Kaiming-normal weights (ReLU
    gain) and a zero bias. The time argument exists only to satisfy the
    ``func(t, y)`` convention of ``odeint``; the dynamics ignore it.
    """

    def __init__(self):
        super().__init__()
        width = 2048
        self.net = nn.Sequential(
            nn.Linear(width, width), nn.ReLU(),
            nn.Linear(width, width), nn.ReLU(),
            nn.Linear(width, width),
        )
        # Re-initialise the layers in their sequential order; non-linear
        # modules are skipped by the initialiser.
        for layer in self.net:
            self._initialize_weights(layer)

    def _initialize_weights(self, m):
        """Apply Kaiming-normal weights / zero bias to ``nn.Linear``; ignore others."""
        if not isinstance(m, nn.Linear):
            return
        nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)

    def forward(self, t, y):
        # ``t`` is unused: the field depends on the state alone.
        return self.net(y)

# Combined model with ResNet and Neural ODE
class ODEModel(nn.Module):
    """ResNet-50 features evolved by a neural ODE, then linearly classified.

    The backbone output is used as the ODE state at t=0, integrated with
    ``ODEFunc`` up to t=1, and the state at t=1 is mapped to class logits.

    Args:
        num_classes: Number of output logits.
        rtol: Relative tolerance forwarded to ``torchdiffeq.odeint``.
        atol: Absolute tolerance forwarded to ``torchdiffeq.odeint``.
            The defaults (1e-7 / 1e-9) equal torchdiffeq's own defaults, so
            behaviour is unchanged unless overridden. Looser values may cut
            the number of function evaluations of the adaptive solver (and
            hence training time) -- NOTE(review): tune against accuracy.
        method: torchdiffeq solver name; defaults to ``'dopri5'``.
    """

    def __init__(self, num_classes, rtol=1e-7, atol=1e-9, method='dopri5'):
        super().__init__()
        self.resnet = ResNet50()
        self.ode_func = ODEFunc()
        self.classifier = nn.Linear(2048, num_classes)
        self.rtol = rtol
        self.atol = atol
        self.method = method

    def forward(self, x):
        features = self.resnet(x)
        # Integration interval [0, 1], created directly on the features'
        # device instead of on the CPU followed by a copy.
        t = torch.tensor([0.0, 1.0], device=features.device)
        trajectory = torchdiffeq.odeint(
            self.ode_func, features, t,
            rtol=self.rtol, atol=self.atol, method=self.method,
        )
        # odeint returns the state at every time in ``t``; keep the one at t=1.
        return self.classifier(trajectory[-1])

# Function to compute loss
def compute_loss(model, data, target):
    """Return the mean cross-entropy between ``model(data)`` and ``target``.

    ``target`` holds class indices; the result is a scalar tensor suitable
    for ``backward()``.
    """
    logits = model(data)
    return nn.functional.cross_entropy(logits, target)

# Training routine
def train(epoch, model, device, train_loader, optimizer):
    """Run one training epoch over ``train_loader``.

    Performs a forward/backward/optimizer step per batch, prints progress
    every 10 batches, and prints the epoch's wall-clock duration at the end.

    Args:
        epoch: Epoch number, used only in log messages.
        model: Network being trained; switched to train mode here.
        device: Device every batch is moved to.
        train_loader: Iterable of ``(inputs, labels)`` batches with a
            ``.dataset`` attribute (used for the progress message).
        optimizer: Optimizer stepped after each batch.
    """
    model.train()
    t0 = time.time()
    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        batch_loss = compute_loss(model, inputs, labels)
        batch_loss.backward()
        optimizer.step()

        if step % 10:
            continue
        seen = step * len(inputs)
        pct = 100. * step / len(train_loader)
        print(f'Train Epoch: {epoch} [{seen}/{len(train_loader.dataset)} '
              f'({pct:.0f}%)]\tLoss: {batch_loss.item():.6f}')
    elapsed = time.time() - t0
    print(f'Epoch {epoch} training time: {elapsed:.2f} seconds')

# Testing routine
def test(model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` with gradients disabled.

    Accumulates the summed cross-entropy and the number of correct top-1
    predictions, prints a summary, and returns the per-sample metrics.

    Returns:
        tuple: ``(average_loss, accuracy_percent)`` over
        ``len(test_loader.dataset)`` samples.
    """
    model.eval()
    summed_loss = 0
    n_correct = 0
    criterion = nn.CrossEntropyLoss(reduction='sum')
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            logits = model(inputs)
            summed_loss += criterion(logits, labels).item()
            preds = logits.argmax(dim=1)
            n_correct += preds.eq(labels.view_as(preds)).sum().item()
    n_total = len(test_loader.dataset)
    avg_loss = summed_loss / n_total
    accuracy = 100. * n_correct / n_total
    print(f'\nTest set: Average loss: {avg_loss:.4f}, Accuracy: {n_correct}/{n_total} '
          f'({accuracy:.2f}%)\n')
    return avg_loss, accuracy

# Load ISIC data with updated data augmentation
def load_data(batch_size=32, data_dir='../data/Skin_cancer_ISIC', num_workers=4):
    """Build train/test loaders for the ISIC skin-cancer ImageFolder splits.

    Random crop/flip augmentation is applied to the training split only. The
    test split gets a deterministic resize + centre crop, so evaluation is
    not performed on randomly cropped/flipped images and is reproducible.

    Args:
        batch_size: Samples per batch for both loaders.
        data_dir: Root directory containing ``Train`` and ``Test`` class folders.
        num_workers: Worker processes per DataLoader.

    Returns:
        ``(train_loader, test_loader, num_classes)``; ``num_classes`` is the
        number of class folders found under ``Train``.

    Raises:
        ValueError: If ``Train`` and ``Test`` do not contain the same class
            folders. ImageFolder assigns label indices from the sorted folder
            names, so a mismatch would silently mislabel the test set.
    """
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    # Deterministic evaluation preprocessing (no random crop or flip).
    test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    train_dataset = ImageFolder(f'{data_dir}/Train', transform=train_transform)
    test_dataset = ImageFolder(f'{data_dir}/Test', transform=test_transform)
    if test_dataset.class_to_idx != train_dataset.class_to_idx:
        raise ValueError(
            'Train/Test class folders differ: '
            f'{sorted(train_dataset.class_to_idx)} vs {sorted(test_dataset.class_to_idx)}'
        )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers)
    return train_loader, test_loader, len(train_dataset.classes)

# Main function to train and evaluate the model
def main(epochs=50, checkpoint_every=10):
    """Train ``ODEModel`` on the ISIC data, evaluating after every epoch.

    The learning rate is reduced 10x by ``ReduceLROnPlateau`` after 5 epochs
    without test-loss improvement. A checkpoint holding the model, optimizer
    and scheduler state plus the latest test metrics is written every
    ``checkpoint_every`` epochs and after the final epoch.

    Args:
        epochs: Number of training epochs (default 50, as before).
        checkpoint_every: Positive checkpoint interval in epochs (default 10).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_loader, test_loader, num_classes = load_data()
    model = ODEModel(num_classes).to(device)

    # Adam with L2 weight decay.
    optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

    # The scheduler's ``verbose`` flag is deprecated in recent PyTorch
    # releases, so LR reductions are logged explicitly in the loop instead.
    # NOTE(review): the test split drives LR scheduling here; a separate
    # validation split would keep the reported test accuracy unbiased.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

    for epoch in range(1, epochs + 1):
        train(epoch, model, device, train_loader, optimizer)
        test_loss, accuracy = test(model, device, test_loader)

        previous_lrs = [group['lr'] for group in optimizer.param_groups]
        scheduler.step(test_loss)
        current_lrs = [group['lr'] for group in optimizer.param_groups]
        if current_lrs != previous_lrs:
            print(f'Epoch {epoch}: reducing learning rate to {current_lrs}')

        if epoch % checkpoint_every == 0 or epoch == epochs:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # Needed to resume with the plateau counter and reduced LR intact.
                'scheduler_state_dict': scheduler.state_dict(),
                'loss': test_loss,
                'accuracy': accuracy,
            }, f'checkpoint_epoch_{epoch}.pth')

if __name__ == "__main__":
    # Entry point when run as a script (also true inside a notebook kernel).
    main()




Epoch 1 training time: 976.82 seconds

Test set: Average loss: 1.8091, Accuracy: 45/118 (38.14%)

Epoch 2 training time: 849.61 seconds

Test set: Average loss: 2.0114, Accuracy: 47/118 (39.83%)

Epoch 3 training time: 813.29 seconds

Test set: Average loss: 2.7975, Accuracy: 32/118 (27.12%)

Epoch 4 training time: 988.65 seconds

Test set: Average loss: 1.8852, Accuracy: 54/118 (45.76%)

Epoch 5 training time: 1071.99 seconds

Test set: Average loss: 1.9483, Accuracy: 40/118 (33.90%)

Epoch 6 training time: 1303.45 seconds

Test set: Average loss: 2.6007, Accuracy: 37/118 (31.36%)

Epoch 7 training time: 1274.11 seconds

Test set: Average loss: 1.7331, Accuracy: 52/118 (44.07%)

Epoch 8 training time: 1458.01 seconds

Test set: Average loss: 2.0309, Accuracy: 53/118 (44.92%)

Epoch 9 training time: 1469.49 seconds

Test set: Average loss: 2.6271, Accuracy: 42/118 (35.59%)

Epoch 10 training time: 1490.46 seconds

Test set: Average loss: 2.6311, Accuracy: 49/118 (41.53%)

Epoch 11 trai