In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchdiffeq import odeint_adjoint as odeint
import time

class CNN(nn.Module):
    """Convolutional feature extractor producing a 128-dimensional embedding.

    Three 3x3 convolutions (each followed by batch norm and ReLU) are
    interleaved with two 2x2 max-pools, then a dropout + linear head maps the
    flattened feature map to 128 non-negative features.

    ``fc1`` expects ``64 * 7 * 7`` inputs, i.e. single-channel 28x28 images
    (the padded 3x3 convolutions keep the spatial size, and the two pools
    halve it twice: 28 -> 14 -> 7).
    """

    def __init__(self):
        super().__init__()
        # Registration order is kept stable so parameter order / seeded
        # initialisation and state_dict keys do not change.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.batch_norm1 = nn.BatchNorm2d(32)
        self.batch_norm2 = nn.BatchNorm2d(64)
        self.batch_norm3 = nn.BatchNorm2d(64)
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        """Map a batch of images ``(N, 1, H, W)`` to features ``(N, 128)``.

        Raises:
            RuntimeError: if the pooled feature map does not flatten to
                ``64 * 7 * 7`` values per sample (e.g. the input is not 28x28).
        """
        x = F.relu(self.batch_norm1(self.conv1(x)))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.batch_norm2(self.conv2(x)))
        x = F.relu(self.batch_norm3(self.conv3(x)))
        x = F.max_pool2d(x, 2)
        # Flatten per sample instead of ``view(-1, 64 * 7 * 7)``: the latter
        # silently folds surplus features into the batch dimension when the
        # input size is wrong, whereas this fails loudly inside ``fc1``.
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc1(self.dropout(x)))
        return x

class ODEFunc(nn.Module):
    """Right-hand side ``dy/dt = f(y)`` of the ODE, parameterised by an MLP.

    The network is three 128 -> 128 linear layers separated by ``tanh``
    activations; the last layer has no activation.
    """

    def __init__(self):
        super().__init__()
        width = 128
        layers = [
            nn.Linear(width, width),
            nn.Tanh(),
            nn.Linear(width, width),
            nn.Tanh(),
            nn.Linear(width, width),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, t, y):
        """Return the time derivative of the state ``y``.

        ``t`` is accepted to satisfy the ``func(t, y)`` calling convention but
        is not used: the dynamics depend on ``y`` only.
        """
        dy_dt = self.net(y)
        return dy_dt

class ODEModel(nn.Module):
    """Image classifier: CNN features -> neural-ODE flow -> linear classifier.

    The 128-dimensional CNN embedding is used as the initial state of the ODE
    defined by ``ODEFunc``; the state at ``t = 1`` is fed to a 10-way linear
    classifier.
    """

    def __init__(self):
        super().__init__()
        self.cnn = CNN()
        self.ode_func = ODEFunc()
        self.classifier = nn.Linear(128, 10)

    def forward(self, x):
        """Return class logits of shape ``(N, 10)`` for a batch of images."""
        features = self.cnn(x)
        # Build the integration interval on the same device as the state so
        # the solver does not have to coerce a CPU time grid on every call.
        t_span = torch.tensor([0.0, 1.0], device=features.device)
        # NOTE(review): the previous call passed options={'step_size': 0.1}.
        # 'step_size' belongs to the fixed-grid solvers (e.g. 'rk4', 'euler');
        # the adaptive 'dopri5' solver appears to ignore it with an
        # "Unexpected arguments" warning, so dropping it should not change the
        # numerics. If fixed 0.1 steps were intended, switch the method.
        trajectory = odeint(self.ode_func, features, t_span, method='dopri5')
        # odeint returns the state at every requested time; keep t = 1 only.
        return self.classifier(trajectory[-1])

def compute_gradients(ode_func, y0, t):
    """Differentiate the summed ODE right-hand side with respect to the state.

    Args:
        ode_func: callable invoked as ``ode_func(t, y)`` returning a tensor.
        y0: state tensor at which to evaluate; it is cloned, not modified.
        t: time value forwarded unchanged to ``ode_func``.

    Returns:
        Tensor shaped like ``y0`` holding the vector-Jacobian product of an
        all-ones vector with ``d ode_func / d y``, i.e. for each state entry
        the sum of the partial derivatives of every output with respect to it.
    """
    state = y0.clone()
    state.requires_grad_(True)
    derivative = ode_func(t, state)
    ones = torch.ones_like(derivative)
    (grad_wrt_state,) = torch.autograd.grad(
        outputs=derivative,
        inputs=state,
        grad_outputs=ones,
    )
    return grad_wrt_state

def train(epoch, model, device, train_loader, optimizer):
    """Run one optimisation pass over ``train_loader`` and print its duration.

    Args:
        epoch: epoch number, used only in the printed message.
        model: module to optimise; switched to training mode.
        device: device each batch is moved to before the forward pass.
        train_loader: iterable yielding ``(data, target)`` batches.
        optimizer: optimiser stepping the model's parameters.
    """
    model.train()
    started = time.time()

    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        batch_loss = F.cross_entropy(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()

    elapsed = time.time() - started
    print(f'Epoch {epoch} training time: {elapsed:.2f} seconds')


def test(model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print a summary line.

    Args:
        model: module to evaluate; switched to evaluation mode.
        device: device each batch is moved to before the forward pass.
        test_loader: iterable of ``(data, target)`` batches exposing a
            ``dataset`` attribute whose length is the number of samples.

    Returns:
        ``(test_loss, accuracy)``: the summed cross-entropy divided by the
        dataset size, and the percentage of correctly classified samples.
    """
    model.eval()
    loss_sum = 0
    correct = 0

    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs)
            # Sum (not mean) per batch so the final average is per sample.
            loss_sum += F.cross_entropy(logits, labels, reduction='sum').item()
            predicted = logits.argmax(dim=1)
            correct += (predicted == labels.view_as(predicted)).sum().item()

    test_loss = loss_sum / len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} '
          f'({accuracy:.2f}%)\n')
    return test_loss, accuracy

def main():
    """Train ``ODEModel`` on MNIST for 30 epochs, evaluating after each one.

    Uses CUDA when available, otherwise the CPU. The training split is
    downloaded to ``./data`` if missing. Optimisation uses Adam with its
    default hyperparameters.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    # Convert images to tensors, then normalise the single channel.
    # NOTE(review): 0.1307 / 0.3081 are presumably the MNIST mean / std.
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    train_set = datasets.MNIST(
        root='./data', train=True, download=True, transform=preprocess
    )
    eval_set = datasets.MNIST(root='./data', train=False, transform=preprocess)

    train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
    test_loader = DataLoader(eval_set, batch_size=1000, shuffle=False)

    model = ODEModel().to(device)
    optimizer = optim.Adam(model.parameters())

    first_epoch, last_epoch = 1, 30
    for epoch in range(first_epoch, last_epoch + 1):
        train(epoch, model, device, train_loader, optimizer)
        # test() prints its own summary; the returned metrics are not used.
        test(model, device, test_loader)

# Start training only when this file is executed directly, not on import.
if __name__ == '__main__':
    main()




Epoch 1 training time: 138.29 seconds

Test set: Average loss: 0.0391, Accuracy: 9879/10000 (98.79%)

Epoch 2 training time: 145.86 seconds

Test set: Average loss: 0.0408, Accuracy: 9862/10000 (98.62%)

Epoch 3 training time: 146.97 seconds

Test set: Average loss: 0.0393, Accuracy: 9883/10000 (98.83%)

Epoch 4 training time: 149.52 seconds

Test set: Average loss: 0.0489, Accuracy: 9846/10000 (98.46%)

Epoch 5 training time: 150.48 seconds

Test set: Average loss: 0.0389, Accuracy: 9888/10000 (98.88%)

Epoch 6 training time: 151.31 seconds

Test set: Average loss: 0.0236, Accuracy: 9927/10000 (99.27%)

Epoch 7 training time: 156.93 seconds

Test set: Average loss: 0.0325, Accuracy: 9907/10000 (99.07%)

Epoch 8 training time: 161.70 seconds

Test set: Average loss: 0.0239, Accuracy: 9924/10000 (99.24%)

Epoch 9 training time: 160.50 seconds

Test set: Average loss: 0.0260, Accuracy: 9921/10000 (99.21%)

Epoch 10 training time: 167.77 seconds

Test set: Average loss: 0.0318, Accuracy: 