In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torchdiffeq
import time
import matplotlib.pyplot as plt
import copy

# Define a tagged module class
class TaggedModule(nn.Module):
    """Base ``nn.Module`` that records a caller-supplied ``tag`` attribute.

    Args:
        tag: Label stored unchanged on the instance as ``self.tag``.
    """

    def __init__(self, tag):
        super(TaggedModule, self).__init__()
        self.tag = tag

# CNN for feature extraction
class CNN(nn.Module):
    """Convolutional feature extractor that outputs 128 features per sample.

    Three 3x3 convolutions (each followed by batch norm and ReLU) and two
    2x2 max-pools feed a dropout + linear head. ``fc1`` expects
    ``64 * 7 * 7`` input features, which is what a 1-channel 28x28 input
    produces after the two poolings.
    """

    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.batch_norm1 = nn.BatchNorm2d(32)
        self.batch_norm2 = nn.BatchNorm2d(64)
        self.batch_norm3 = nn.BatchNorm2d(64)
        self.dropout = nn.Dropout(0.5)  # dropout rate; tune as needed

    def forward(self, x):
        """Map a batch of images to a ``(batch, 128)`` feature tensor.

        Args:
            x: Tensor of shape ``(batch, 1, H, W)``; ``fc1`` requires
                ``H = W = 28`` (any size that pools down to 7x7).

        Returns:
            ReLU-activated output of ``fc1``, shape ``(batch, 128)``.

        Raises:
            RuntimeError: If the flattened feature size does not match
                ``fc1`` (i.e. the input has an unsupported spatial size).
        """
        x = F.relu(self.batch_norm1(self.conv1(x)))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.batch_norm2(self.conv2(x)))
        x = F.relu(self.batch_norm3(self.conv3(x)))
        x = F.max_pool2d(x, 2)
        # Flatten per sample, keeping the batch dimension fixed. The former
        # view(-1, 64 * 7 * 7) silently moved extra spatial positions into the
        # batch dimension for wrongly sized inputs; now fc1 raises instead.
        x = x.flatten(1)
        x = F.relu(self.fc1(self.dropout(x)))
        return x

# ODE Function for AFGE
class ODEFunc(TaggedModule):
    """Vector field for the ODE block: a 3-layer tanh MLP on 128-d states.

    The module is tagged ``"ode"``.
    """

    def __init__(self):
        super(ODEFunc, self).__init__(tag="ode")
        width = 128
        layers = [
            nn.Linear(width, width),
            nn.Tanh(),
            nn.Linear(width, width),
            nn.Tanh(),
            nn.Linear(width, width),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, t, y):
        """Return ``self.net(y)``.

        ``t`` is accepted so the signature is ``f(t, y)``, but it is not
        used: the returned value depends on ``y`` only.
        """
        derivative = self.net(y)
        return derivative

# Augmented ODE Function with gradients for AFGE
class AugmentedODEFunc(TaggedModule):
    """Dynamics for the concatenated state ``[y, adj_y]``.

    The first 128 columns of the state are ``y`` and the remaining columns
    are ``adj_y``. The module is tagged ``"ode"`` and wraps ``func``.

    Args:
        func: Callable ``func(t, y)`` giving the derivative of ``y``.
    """

    def __init__(self, func):
        super(AugmentedODEFunc, self).__init__(tag="ode")
        self.func = func

    def forward(self, t, state):
        """Return ``[dy, adj_dy]`` concatenated along dim 1.

        ``dy`` is ``func(t, y)``; ``adj_dy`` is the gradient of ``dy`` with
        respect to ``y`` weighted by ``adj_y`` (a vector-Jacobian product),
        computed with ``create_graph=True`` so it stays differentiable.
        """
        y = state[:, :128]
        adj_y = state[:, 128:]
        # Autograd is forced on so the product below can be computed even if
        # the caller runs under no_grad.
        with torch.enable_grad():
            y.requires_grad_(True)
            dy = self.func(t, y)
            (adj_dy,) = torch.autograd.grad(
                dy,
                y,
                grad_outputs=adj_y,
                create_graph=True,
                retain_graph=True,
            )
        return torch.cat((dy, adj_dy), dim=1)

# Initializer network for dy(0)/dθ
class InitializerNetwork(nn.Module):
    """Two-layer MLP: ``input_dim`` -> 64 -> ``output_dim`` with a ReLU.

    Args:
        input_dim: Size of the last dimension of the input.
        output_dim: Size of the last dimension of the output.
    """

    def __init__(self, input_dim, output_dim):
        super(InitializerNetwork, self).__init__()
        hidden_dim = 64
        stages = (
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the MLP to ``x`` and return the result."""
        projected = self.net(x)
        return projected

# Define the full model
class ODEModel(nn.Module):
    """CNN features -> ODE integration over t in [0, 1] -> linear classifier.

    In training mode the 128-d CNN features are concatenated with an
    auxiliary ``adj_y`` state and integrated with ``AugmentedODEFunc``.
    In eval mode only the plain ``ODEFunc`` is integrated.

    Args:
        num_classes: Number of output logits of the classifier.

    Attributes:
        last_adj_y: List holding the (detached) final ``adj_y`` of the most
            recent training forward pass; empty before the first one.
    """

    def __init__(self, num_classes=10):
        super(ODEModel, self).__init__()
        self.cnn = CNN()
        self.ode_func = ODEFunc()
        self.aug_ode_func = AugmentedODEFunc(self.ode_func)
        self.classifier = nn.Linear(128, num_classes)
        self.classifier.tag = "non_ode"
        self.initializer = InitializerNetwork(128, 128)
        self.last_adj_y = []  # a list rather than a single tensor

    def forward(self, x, is_first_batch=False):
        """Run the model on a batch.

        Args:
            x: Input batch passed to the CNN.
            is_first_batch: When true (training only), ignore any stored
                ``last_adj_y`` and start from the initializer output alone.

        Returns:
            Training mode: ``(logits, y1, adj_y1)`` where ``y1`` / ``adj_y1``
            are the two halves of the augmented state at ``t = 1``.
            Eval mode: ``logits`` only.
        """
        cnn_output = self.cnn(x)
        batch_size = cnn_output.shape[0]
        # Integration endpoints, created directly on the input's device.
        t = torch.tensor([0., 1.], device=x.device)

        # NOTE(review): dopri5 is an adaptive-step solver; the 'step_size'
        # option is presumably ignored (torchdiffeq may warn about it) --
        # confirm against the installed torchdiffeq version.
        solver_options = {'step_size': 0.1}  # tune the time steps as needed

        if not self.training:
            y1 = torchdiffeq.odeint(self.ode_func, cnn_output, t, method='dopri5', options=solver_options)
            output = self.classifier(y1[-1])
            return output

        init_adj_y = self.initializer(cnn_output)

        if is_first_batch or not self.last_adj_y:
            adj_y = init_adj_y
        else:
            # Reuse the last stored values, tiled if the batch is larger
            # than the stored one and truncated if it is smaller.
            last_adj_y = torch.cat(self.last_adj_y, dim=0)
            last_adj_y = last_adj_y.repeat(batch_size // last_adj_y.shape[0] + 1, 1)[:batch_size]
            adj_y = 0.5 * init_adj_y + 0.5 * last_adj_y.detach()

        aug_state = torch.cat([cnn_output, adj_y], dim=1)
        aug_state = torchdiffeq.odeint(self.aug_ode_func, aug_state, t, method='dopri5', options=solver_options)
        y1, adj_y1 = aug_state[-1, :, :128], aug_state[-1, :, 128:]

        # Remember the final adjoint state for the next batch. Store it
        # detached: it is only ever consumed through .detach(), and keeping
        # the attached tensor would hold this batch's whole autograd graph
        # alive and leave a non-leaf tensor on the module.
        self.last_adj_y = [adj_y1.detach()]

        output = self.classifier(y1)
        return output, y1, adj_y1

# Function to get parameters based on tags
def get_parameters_by_tag(model, tag):
    """Yield each parameter owned by a module of ``model`` tagged ``tag``.

    Every module returned by ``model.modules()`` whose ``tag`` attribute
    equals ``tag`` contributes all of its parameters (including those of its
    submodules). Each parameter object is yielded at most once, even when it
    is reachable through several tagged modules (e.g. a tagged wrapper
    around a tagged module).

    Args:
        model: Module whose module tree is searched.
        tag: Tag value to match against each module's ``tag`` attribute.

    Yields:
        Matching parameters, in first-encounter order.
    """
    seen = set()
    for module in model.modules():
        if hasattr(module, 'tag') and module.tag == tag:
            for param in module.parameters():
                # De-duplicate by identity; nested tagged modules would
                # otherwise yield the same parameter more than once.
                if id(param) not in seen:
                    seen.add(id(param))
                    yield param

# Training function
def train(epoch, model, device, train_loader, optimizer):
    """Train ``model`` for one pass over ``train_loader`` and print the time.

    Args:
        epoch: Epoch number, used only in the printed message.
        model: Module called as ``model(data, is_first_batch)`` and returning
            a 3-tuple whose first element is the logits.
        device: Device the batches are moved to.
        train_loader: Iterable of ``(data, target)`` batches.
        optimizer: Optimizer stepped once per batch.
    """
    model.train()
    start_time = time.time()

    # The previous version deep-copied the ODE parameters here and built an
    # "after" list at the end, but never compared them; that dead work was
    # removed.
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()

        is_first_batch = (batch_idx == 0)
        # Only the logits enter the loss; the two extra outputs are unused.
        output, _, _ = model(data, is_first_batch)
        loss = F.cross_entropy(output, target)
        loss.backward()

        optimizer.step()

    end_time = time.time()
    print(f'Epoch {epoch} training time: {end_time - start_time:.2f} seconds')

# Testing function
def test(model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print a summary line.

    Args:
        model: Module called as ``model(data)`` and returning logits.
        device: Device the batches are moved to.
        test_loader: Iterable of ``(data, target)`` batches exposing a
            ``dataset`` attribute with a length.

    Returns:
        ``(test_loss, accuracy)``: summed cross-entropy divided by the
        dataset size, and the percentage of correct predictions.
    """
    model.eval()
    loss_sum = 0
    num_correct = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs)
            # Sum (not mean) per batch so the final division is by sample count.
            loss_sum += F.cross_entropy(logits, labels, reduction='sum').item()
            predicted = logits.argmax(dim=1)
            num_correct += predicted.eq(labels.view_as(predicted)).sum().item()

    num_samples = len(test_loader.dataset)
    avg_loss = loss_sum / num_samples
    accuracy = 100. * num_correct / num_samples
    print(f'\nTest set: Average loss: {avg_loss:.4f}, Accuracy: {num_correct}/{num_samples} '
          f'({accuracy:.2f}%)\n')
    return avg_loss, accuracy

# Load data
def load_data():
    """Build the MNIST train and test data loaders.

    Images are converted to tensors and normalized with mean ``0.1307`` and
    std ``0.3081``. The training split is downloaded to ``./data`` if needed.

    Returns:
        ``(train_loader, test_loader)``: the training loader uses batches of
        64 with shuffling; the test loader uses batches of 1000, unshuffled.
    """
    mean, std = (0.1307,), (0.3081,)
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_split = datasets.MNIST(root='./data', train=True, download=True, transform=preprocessing)
    test_split = datasets.MNIST(root='./data', train=False, transform=preprocessing)

    return (
        DataLoader(train_split, batch_size=64, shuffle=True),
        DataLoader(test_split, batch_size=1000, shuffle=False),
    )

# Main function to train and test the model
def main(num_epochs=30):
    """Train and evaluate ``ODEModel`` on the loaders from ``load_data``.

    Uses AdamW (lr 1e-3, weight decay 1e-4) and a ``ReduceLROnPlateau``
    scheduler driven by the test loss (factor 0.1, patience 5).

    Args:
        num_epochs: Number of train/test epochs to run (default 30).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_loader, test_loader = load_data()
    model = ODEModel().to(device)
    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)  # tune the weight decay as needed
    # The scheduler's `verbose` argument is deprecated in recent PyTorch
    # releases, so learning-rate reductions are logged explicitly below.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

    for epoch in range(1, num_epochs + 1):
        train(epoch, model, device, train_loader, optimizer)
        test_loss, _ = test(model, device, test_loader)

        lrs_before = [group['lr'] for group in optimizer.param_groups]
        scheduler.step(test_loss)
        for idx, (old_lr, group) in enumerate(zip(lrs_before, optimizer.param_groups)):
            if group['lr'] != old_lr:
                print(f"Epoch {epoch}: reducing learning rate of group {idx} to {group['lr']:.4e}.")

# Run the training/evaluation loop only when executed as a script.
if __name__ == '__main__':
    main()




Epoch 1 training time: 141.12 seconds

Test set: Average loss: 0.0471, Accuracy: 9845/10000 (98.45%)

Epoch 2 training time: 147.06 seconds

Test set: Average loss: 0.0364, Accuracy: 9884/10000 (98.84%)

Epoch 3 training time: 146.55 seconds

Test set: Average loss: 0.0341, Accuracy: 9883/10000 (98.83%)

Epoch 4 training time: 144.93 seconds

Test set: Average loss: 0.0371, Accuracy: 9879/10000 (98.79%)

Epoch 5 training time: 145.10 seconds

Test set: Average loss: 0.0249, Accuracy: 9923/10000 (99.23%)

Epoch 6 training time: 144.69 seconds

Test set: Average loss: 0.0245, Accuracy: 9926/10000 (99.26%)

Epoch 7 training time: 147.35 seconds

Test set: Average loss: 0.0260, Accuracy: 9920/10000 (99.20%)

Epoch 8 training time: 147.58 seconds

Test set: Average loss: 0.0231, Accuracy: 9926/10000 (99.26%)

Epoch 9 training time: 148.34 seconds

Test set: Average loss: 0.0201, Accuracy: 9944/10000 (99.44%)

Epoch 10 training time: 152.50 seconds

Test set: Average loss: 0.0200, Accuracy: 