## Part I: ODE-Net / ResNet Replacement


In [None]:
!pip install torch torchvision torchdiffeq matplotlib
import torch
import torch.nn as nn
import torch.optim as optim
from torchdiffeq import odeint_adjoint as odeint  # Adjoint method for memory efficiency
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

# Device configuration: run on CUDA when PyTorch reports it as available, otherwise on the CPU.
# `device` is a notebook-level global that the training and evaluation code below reads.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Collecting torchdiffeq
  Downloading torchdiffeq-0.2.5-py3-none-any.whl.metadata (440 bytes)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
C

### ODE Dynamics (ODEFunc)

In [None]:
class ODEFunc(nn.Module):
    """Right-hand side of the ODE: dh/dt = f(h(t), t, theta).

    The dynamics are a two-layer MLP (Linear -> Tanh -> Linear) that maps a
    `dim`-sized hidden state to a `dim`-sized derivative.

    Attributes:
        nfe: Number of times `forward` has been called on this instance.
            It is only ever incremented here; nothing in this class resets it.
        net: The MLP implementing f.
    """

    def __init__(self, dim=64):
        super().__init__()
        self.nfe = 0
        mlp_layers = [nn.Linear(dim, dim), nn.Tanh(), nn.Linear(dim, dim)]
        self.net = nn.Sequential(*mlp_layers)

    def forward(self, t, h):
        """Return dh/dt for state `h`; `t` is accepted but not used by the MLP."""
        self.nfe = self.nfe + 1
        derivative = self.net(h)
        return derivative

### ODE Block (ODEBlock)

In [None]:
class ODEBlock(nn.Module):
    """Layer that replaces a stack of residual blocks with one ODE solve.

    The input is treated as the state at t=0 and the layer returns the state
    at t=1, obtained by integrating `odefunc` with `odeint` (imported at the
    top of the notebook as torchdiffeq's `odeint_adjoint`).

    Args:
        odefunc: Module called as `odefunc(t, h)` that returns dh/dt.
        tol: Value passed to the solver as both `rtol` and `atol`.
    """

    def __init__(self, odefunc, tol=1e-3):
        super().__init__()
        self.tol = tol
        self.odefunc = odefunc

    def forward(self, x):
        # The [0, 1] interval is an arbitrary choice of integration window.
        time_span = torch.tensor([0., 1.], dtype=torch.float32, device=x.device)
        trajectory = odeint(self.odefunc, x, time_span, rtol=self.tol, atol=self.tol)
        # Entry 0 corresponds to t=0 and entry 1 to t=1; only the latter is returned.
        final_state = trajectory[1]
        return final_state

AttributeError: 'ODEBlock' object has no attribute 'get_nfe'

### ODE-Net Architecture

In [None]:
class ODE_NET(nn.Module):
    """Small MNIST classifier: conv/linear stem -> ODE block -> linear classifier.

    The stem maps a 1-channel image to a 64-dimensional vector, the ODE block
    evolves that vector, and the classifier produces 10 logits.
    """

    def __init__(self):
        super().__init__()
        # Stem (plays the role of ResNet downsampling). The Linear input size
        # 16*13*13 corresponds to a 28x28 input: a 3x3 stride-2 conv gives 13x13.
        stem_layers = [
            nn.Conv2d(1, 16, 3, stride=2),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(16 * 13 * 13, 64),
        ]
        self.downsample = nn.Sequential(*stem_layers)
        # Continuous-depth block operating on the 64-dim hidden state.
        self.odeblock = ODEBlock(ODEFunc(dim=64))
        # Final linear layer producing the 10 class logits.
        self.classifier = nn.Linear(64, 10)

    def forward(self, x):
        hidden = self.downsample(x)
        evolved = self.odeblock(hidden)
        logits = self.classifier(evolved)
        return logits

### ResNet Baseline (Simplified)

In [None]:
class _ResidualLinear(nn.Module):
    """Fully connected residual block: y = x + ReLU(Linear(x))."""

    def __init__(self, dim):
        super().__init__()
        self.fc = nn.Linear(dim, dim)
        self.act = nn.ReLU()

    def forward(self, x):
        # The identity skip is what makes this a residual block; it is also the
        # discrete counterpart of the update h <- h + f(h) that the ODE block
        # replaces with a continuous solve.
        return x + self.act(self.fc(x))


class ResNet(nn.Module):
    """Simplified residual baseline for the MLP-style ODE-Net.

    Same stem and classifier as ODE_NET, with six residual blocks on the
    64-dimensional hidden state in place of the ODE block.

    The previous version stacked plain ``Linear -> ReLU`` layers with no skip
    connection, so the "ResNet" baseline was an ordinary MLP. Each block now
    adds its input back to its output. The number of parameters is unchanged
    (one 64x64 Linear per block).
    """

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=2),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(16*13*13, 64),
            *[_ResidualLinear(64) for _ in range(6)],  # 6 residual blocks
            nn.Linear(64, 10)
        )

    def forward(self, x):
        """Return the 10 class logits for a batch of 1x28x28 images."""
        return self.net(x)

### Training Loop (MNIST)
#### Data Loading

In [None]:
# Images are only converted to tensors here: no augmentation and no normalisation.
transform = transforms.Compose([transforms.ToTensor()])

# Training split, requested with download=True, in shuffled batches of 128.
train_loader = DataLoader(
    datasets.MNIST('./data', train=True, download=True, transform=transform),
    batch_size=128,
    shuffle=True,
)

# Test split, read from the same ./data directory, in unshuffled batches of 1000.
test_loader = DataLoader(
    datasets.MNIST('./data', train=False, transform=transform),
    batch_size=1000,
)

100%|██████████| 9.91M/9.91M [00:00<00:00, 13.3MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 414kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 3.25MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 9.79MB/s]


#### Training Function

In [None]:
def train(model, optimizer, epochs=10, loader=None):
    """Train `model` with cross-entropy loss.

    Args:
        model: Module mapping a batch of inputs to class logits.
        optimizer: Optimizer already bound to `model`'s parameters.
        epochs: Number of full passes over the data.
        loader: Iterable of `(inputs, labels)` batches. When omitted, the
            notebook-level `train_loader` that exists at call time is used,
            which is what this function always did before the parameter was
            added. Passing it explicitly avoids picking up a `train_loader`
            that a later cell has rebound.

    Batches are moved to the notebook-level `device` before the forward pass.
    Returns nothing; `model` is updated in place and left in training mode.
    """
    if loader is None:
        loader = train_loader
    # Built once: the loss module holds no per-batch state, so there is no
    # reason to re-instantiate it on every iteration.
    criterion = nn.CrossEntropyLoss()
    model.train()
    for epoch in tqdm(range(epochs)):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            output = model(x)
            loss = criterion(output, y)
            loss.backward()
            optimizer.step()

#### Evaluation

In [None]:
def accuracy(model, loader):
    """Return the fraction of samples in `loader` that `model` classifies correctly.

    Args:
        model: Module mapping a batch of inputs to class logits.
        loader: DataLoader yielding `(inputs, labels)` batches; its `.dataset`
            length is used as the denominator.

    Returns:
        A float in [0, 1]: correct predictions divided by `len(loader.dataset)`.

    The model is put in evaluation mode for the duration of the pass (it was
    previously evaluated in whatever mode the caller left it in) and its
    original training/eval mode is restored afterwards, even if a batch raises.
    """
    was_training = model.training
    model.eval()
    correct = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                pred = model(x).argmax(dim=1)
                correct += (pred == y).sum().item()
    finally:
        # Leave the model in the mode the caller had it in.
        model.train(was_training)
    return correct / len(loader.dataset)

### Run Experiments
#### Train ODE-Net

In [None]:
# Build the ODE-Net on the selected device and train it with Adam (lr=1e-3) for 10 epochs.
odenet = ODE_NET().to(device)
optimizer = optim.Adam(odenet.parameters(), lr=1e-3)
train(odenet, optimizer, epochs=10)
# accuracy() returns a fraction (correct / dataset size); the `.2%` format turns it into a percentage.
# `odenet` and `odenet_acc` are reused by the results table further down.
odenet_acc = accuracy(odenet, test_loader)
print(f"ODE-Net Test Accuracy: {odenet_acc:.2%}")

100%|██████████| 10/10 [04:09<00:00, 24.94s/it]


ODE-Net Test Accuracy: 98.10%


#### Train ResNet

In [None]:
# Same recipe as the ODE-Net run: Adam (lr=1e-3), 10 epochs, evaluated on the test loader.
# Note that `optimizer` is rebound here, replacing the ODE-Net's optimizer.
resnet = ResNet().to(device)
optimizer = optim.Adam(resnet.parameters(), lr=1e-3)
train(resnet, optimizer, epochs=10)
# accuracy() returns a fraction (correct / dataset size); `resnet` and `resnet_acc`
# are reused by the results table further down.
resnet_acc = accuracy(resnet, test_loader)
print(f"ResNet Test Accuracy: {resnet_acc:.2%}")

100%|██████████| 10/10 [01:13<00:00,  7.31s/it]


ResNet Test Accuracy: 97.68%


### Visualize Results
#### Table 1 (Accuracy Comparison)

In [None]:
import pandas as pd

# One row per model, ODE-Net first and the ResNet baseline second.
results = pd.DataFrame({
    "Model": ["ODE-Net", "ResNet"],
    # The accuracies are fractions, so the error in percent is 100 * (1 - accuracy).
    "Test Error (%)": [100 * (1 - acc) for acc in (odenet_acc, resnet_acc)],
    # Total number of parameter elements, expressed in millions.
    "Params (M)": [
        sum(p.numel() for p in net.parameters()) / 1e6
        for net in (odenet, resnet)
    ],
})
print(results)

     Model  Test Error (%)  Params (M)
0  ODE-Net            1.90     0.18225
1   ResNet            2.32     0.19889


#### Report Function Evaluations (NFE)

In [None]:
# Report the ODE function's evaluation counter after training.
# `nfe` is incremented on every ODEFunc.forward call and is never reset, so this
# is a cumulative total over everything run so far (the training run and the
# test-set accuracy pass), not a per-batch or per-forward-pass figure.
print(f"ODE-Net NFE: {odenet.odeblock.odefunc.nfe}")

ODE-Net NFE: 141198


#### Scaled-Up Training

In [None]:
import time

# Device configuration (repeats the selection made at the top of the notebook):
# CUDA when PyTorch reports it as available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# 1. Define GroupNorm function
def norm(dim):
    """Return a GroupNorm layer over `dim` channels using min(32, dim) groups."""
    num_groups = dim if dim < 32 else 32
    return nn.GroupNorm(num_groups, dim)

# 2. Implement ConcatConv2d
class ConcatConv2d(nn.Module):
    """2-D convolution that also sees the current time `t`.

    A constant plane filled with `t` is prepended to the input as an extra
    channel, so the wrapped convolution has `dim_in + 1` input channels.

    Args:
        dim_in: Number of channels of the input `x` (excluding the time plane).
        dim_out: Number of output channels.
        kernel_size, stride, padding: Forwarded unchanged to `nn.Conv2d`.
    """

    def __init__(self, dim_in, dim_out, kernel_size=3, stride=1, padding=0):
        super().__init__()
        conv_options = dict(kernel_size=kernel_size, stride=stride, padding=padding)
        # One extra input channel for the time plane.
        self._layer = nn.Conv2d(dim_in + 1, dim_out, **conv_options)

    def forward(self, t, x):
        # A single-channel tensor matching x's batch and spatial shape, filled with t.
        time_plane = t * torch.ones_like(x[:, :1])
        # The time plane goes first, followed by the original channels.
        stacked = torch.cat((time_plane, x), dim=1)
        return self._layer(stacked)

# 3. Flatten layer
class Flatten(nn.Module):
    """Reshape `(batch, ...)` to `(batch, -1)`, keeping the first dimension."""

    def forward(self, x):
        batch_size = x.shape[0]
        return x.view(batch_size, -1)

# 4. ODE Function (without NFE tracking)
class ODEFunc(nn.Module):
    """Convolutional ODE dynamics: norm -> ReLU -> conv, twice, then a final norm.

    This definition reuses the name of the MLP-based `ODEFunc` defined earlier
    in the notebook and replaces it from this point on. Unlike that version it
    keeps no `nfe` counter.

    Args:
        dim: Number of feature channels; input and output have the same count.
    """

    def __init__(self, dim):
        super().__init__()
        self.norm1 = norm(dim)
        # One in-place ReLU module is shared by both activation sites.
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = ConcatConv2d(dim, dim, kernel_size=3, padding=1)
        self.norm2 = norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, kernel_size=3, padding=1)
        self.norm3 = norm(dim)

    def forward(self, t, x):
        """Return the derivative of state `x` at time `t`."""
        # First norm/activation/time-conditioned conv stage.
        hidden = self.conv1(t, self.relu(self.norm1(x)))
        # Second stage, same shape.
        hidden = self.conv2(t, self.relu(self.norm2(hidden)))
        return self.norm3(hidden)

# 5. ODE Block (simplified)
class ODEBlock(nn.Module):
    """Integrate `odefunc` from t=0 to t=1 and return the state at t=1.

    This definition reuses the name of the `ODEBlock` defined earlier in the
    notebook and replaces it from this point on.

    Args:
        odefunc: Module called as `odefunc(t, x)` that returns dx/dt.
        tol: Value passed to the solver as both `rtol` and `atol`.
    """

    def __init__(self, odefunc, tol=1e-3):
        super().__init__()
        self.odefunc = odefunc
        self.tol = tol
        # Stored as a plain tensor attribute rather than a registered buffer;
        # forward() re-casts it to match the input on every call.
        self.integration_time = torch.tensor([0.0, 1.0], dtype=torch.float32)

    def forward(self, x):
        # Cast the time grid to the input's tensor type and keep the converted
        # copy on the module for the next call.
        time_grid = self.integration_time.type_as(x)
        self.integration_time = time_grid
        solution = odeint(self.odefunc, x, time_grid, rtol=self.tol, atol=self.tol)
        # Index 1 is the second requested time point, t=1.
        return solution[1]

# Model architecture: downsampling stem -> one ODE block -> classification head.
# The three lists hold module *instances*; any nn.Sequential assembled from them
# shares the same parameters as `model` below.

# Stem: one 3x3 conv followed by two 4x4 stride-2 convs, each with GroupNorm and ReLU.
# (For 28x28 MNIST inputs the spatial size should go 28 -> 26 -> 13 -> 6 — TODO confirm.)
downsampling_layers = [
    nn.Conv2d(1, 64, kernel_size=3, stride=1),
    norm(64),
    nn.ReLU(inplace=True),
    nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1),
    norm(64),
    nn.ReLU(inplace=True),
    nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1),
    norm(64),
    nn.ReLU(inplace=True)
]

# Feature extractor: a single ODE block over 64-channel feature maps.
odefunc = ODEFunc(64)
feature_layers = [ODEBlock(odefunc)]
# Head: norm + ReLU, global average pool to 1x1, flatten to 64 values, 10 logits.
fc_layers = [
    norm(64),
    nn.ReLU(inplace=True),
    nn.AdaptiveAvgPool2d((1, 1)),
    Flatten(),
    nn.Linear(64, 10)
]

model = nn.Sequential(*downsampling_layers, *feature_layers, *fc_layers).to(device)

# Training setup: cross-entropy loss on the class logits. `criterion` is a
# notebook-level global that later training code also uses.
criterion = nn.CrossEntropyLoss()
# The optimizer and LR scheduler are created further down in this cell, after the
# data loaders. An earlier draft used plain SGD (lr=0.1, momentum=0.9) with
# MultiStepLR milestones [60, 100, 140] (gamma=0.1) for a 160-epoch run; the
# active scheduler rescales those milestones to 30 epochs.

# Data loading
# ==================== DATA LOADING ====================
# Training images get random-crop augmentation. Validation and test images must
# not: they are only converted to tensors.
transform = transforms.Compose([
    transforms.RandomCrop(28, padding=4),  # Data augmentation
    transforms.ToTensor(),
])
eval_transform = transforms.ToTensor()

# Two views of the same MNIST training images, differing only in their transform.
# Splitting a single augmented dataset (as was done before) makes the validation
# subset inherit RandomCrop, so validation accuracy was measured on augmented images.
augmented_train_full = datasets.MNIST('./data', train=True, download=True, transform=transform)
plain_train_full = datasets.MNIST('./data', train=True, download=True, transform=eval_transform)

# Split the original training set 90% / 10% into train and validation.
train_size = int(0.9 * len(augmented_train_full))  # 90% for training
val_size = len(augmented_train_full) - train_size  # 10% for validation

# A seeded permutation makes the split reproducible across kernel restarts and
# guarantees the two subsets are disjoint even though they wrap different views.
split_generator = torch.Generator().manual_seed(0)
shuffled_indices = torch.randperm(len(augmented_train_full), generator=split_generator).tolist()
train_dataset = torch.utils.data.Subset(augmented_train_full, shuffled_indices[:train_size])
val_dataset = torch.utils.data.Subset(plain_train_full, shuffled_indices[train_size:])

# Test set (no augmentation)
test_dataset = datasets.MNIST('./data', train=False, transform=transforms.ToTensor())

# Create data loaders. Note these rebind the notebook-level `train_loader` and
# `test_loader` defined in the earlier data-loading cell.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=1000, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=1000, num_workers=2)

# ==================== TRAINING SETUP ====================
# SGD with momentum over all parameters of `model`; weight_decay=1e-4 adds
# L2-style regularization.
optimizer = torch.optim.SGD(model.parameters(),
                          lr=0.1,
                          momentum=0.9,
                          weight_decay=1e-4)  # Added regularization

# Step schedule: the learning rate is multiplied by gamma=0.1 at each milestone.
# The milestones are the 160-epoch schedule [60, 100, 140] rescaled to 30 epochs.
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[11, 19, 26],  # 60/160*30≈11, 100/160*30≈19, 140/160*30≈26
    gamma=0.1
)
# Train for 30 epochs (no NFE tracking). After each epoch, step the LR schedule,
# measure accuracy on the test split, and print it with the epoch's wall-clock time.
for epoch in range(30):
    model.train()
    epoch_start = time.time()

    # ---- one optimisation pass over the training loader ----
    for x, y in train_loader:
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()

    # The scheduler advances once per epoch, after all optimizer steps.
    scheduler.step()

    # ---- per-epoch evaluation; this uses the *test* loader, not val_loader ----
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            y = y.to(device)
            outputs = model(x)
            predicted = outputs.argmax(dim=1)
            total += y.size(0)
            correct += (predicted == y).sum().item()

    acc = 100 * correct / total
    # The reported time covers both the training pass and the evaluation pass.
    epoch_time = time.time() - epoch_start
    print(f'Epoch {epoch+1}/30 | Time: {epoch_time:.2f}s | Test Acc: {acc:.2f}%')

Using device: cuda
Epoch 1/30 | Time: 57.04s | Test Acc: 98.12%
Epoch 2/30 | Time: 54.82s | Test Acc: 98.53%
Epoch 3/30 | Time: 56.96s | Test Acc: 98.81%
Epoch 4/30 | Time: 55.40s | Test Acc: 98.96%
Epoch 5/30 | Time: 54.44s | Test Acc: 99.18%
Epoch 6/30 | Time: 56.48s | Test Acc: 99.29%
Epoch 7/30 | Time: 56.88s | Test Acc: 99.15%
Epoch 8/30 | Time: 59.13s | Test Acc: 99.09%
Epoch 9/30 | Time: 60.39s | Test Acc: 98.94%
Epoch 10/30 | Time: 60.08s | Test Acc: 99.20%
Epoch 11/30 | Time: 58.80s | Test Acc: 99.09%
Epoch 12/30 | Time: 60.24s | Test Acc: 99.12%
Epoch 13/30 | Time: 60.72s | Test Acc: 99.32%
Epoch 14/30 | Time: 61.22s | Test Acc: 99.27%
Epoch 15/30 | Time: 60.89s | Test Acc: 99.42%
Epoch 16/30 | Time: 61.12s | Test Acc: 99.37%
Epoch 17/30 | Time: 61.18s | Test Acc: 98.92%
Epoch 18/30 | Time: 60.82s | Test Acc: 99.26%
Epoch 19/30 | Time: 61.12s | Test Acc: 99.34%
Epoch 20/30 | Time: 60.81s | Test Acc: 99.17%
Epoch 21/30 | Time: 61.14s | Test Acc: 99.17%
Epoch 22/30 | Time: 60.9

#### Comparison with ResNet

In [None]:
# 1. First let's define a comparable ResNet
class BasicResNet(nn.Module):
    """Convolutional residual baseline for the scaled-up ODE-Net.

    Layout: 3x3 conv (1 -> 64 channels) + GroupNorm + ReLU, six `ResBlock`s at
    64 channels, global average pooling, flatten, and a 10-way linear layer.

    The class previously had no `forward`, so calling the model raised
    ``NotImplementedError: Module [BasicResNet] is missing the required
    "forward" function``. `forward` now runs the input through `self.net`.
    """

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1),
            norm(64),
            nn.ReLU(),
            self._make_layer(64, 64, 6),  # 6 residual blocks
            nn.AdaptiveAvgPool2d((1, 1)),
            Flatten(),
            nn.Linear(64, 10)
        )

    def _make_layer(self, in_channels, out_channels, blocks):
        """Return `blocks` ResBlocks chained in an nn.Sequential.

        Every block is built with the same `(in_channels, out_channels)` pair.
        """
        layers = [ResBlock(in_channels, out_channels) for _ in range(blocks)]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the 10 class logits for a batch of 1-channel images."""
        return self.net(x)

# 2. Define a residual block (for the ResNet)
class ResBlock(nn.Module):
    """Residual block: conv -> norm -> ReLU -> conv -> norm, plus identity skip, then ReLU.

    Both convolutions are 3x3 with padding 1. The skip connection adds the
    unmodified input, and no projection is defined, so the block presumably
    expects `in_channels == out_channels` (it is only constructed with 64 -> 64
    in this notebook).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.norm1 = norm(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.norm2 = norm(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        # Residual branch.
        branch = self.relu(self.norm1(self.conv1(x)))
        branch = self.norm2(self.conv2(branch))
        # Add the identity path, then apply the final activation.
        return self.relu(branch + x)

# 3. Train and compare both models
def _build_ode_net():
    """Return a freshly initialised ODE-Net with the scaled-up architecture.

    The layout matches the notebook-level `model` (downsampling stem, one ODE
    block, pooling head), but every layer is constructed anew so that no
    parameters are shared with a previously trained network.
    """
    downsampling = [
        nn.Conv2d(1, 64, kernel_size=3, stride=1),
        norm(64),
        nn.ReLU(inplace=True),
        nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1),
        norm(64),
        nn.ReLU(inplace=True),
        nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1),
        norm(64),
        nn.ReLU(inplace=True),
    ]
    features = [ODEBlock(ODEFunc(64))]
    head = [
        norm(64),
        nn.ReLU(inplace=True),
        nn.AdaptiveAvgPool2d((1, 1)),
        Flatten(),
        nn.Linear(64, 10),
    ]
    return nn.Sequential(*downsampling, *features, *head)


def _evaluate(model, loader, num_samples):
    """Return the accuracy of `model` on `loader`, in percent of `num_samples`."""
    model.eval()
    correct = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            predicted = model(x).argmax(dim=1)
            correct += (predicted == y).sum().item()
    return 100 * correct / num_samples


def train_and_compare(epochs=30):
    """Train an ODE-Net and a BasicResNet from scratch and print a comparison table.

    Both models use the same recipe: SGD (lr=0.1, momentum=0.9,
    weight_decay=1e-4) with cosine annealing over `epochs` epochs, validation
    accuracy printed after every epoch, and a final test-set evaluation.

    Args:
        epochs: Number of training epochs per model (also the cosine
            schedule's `T_max`). Defaults to 30, the previously hard-coded value.

    Uses the notebook-level `device`, `criterion`, `train_loader`,
    `val_loader`, `test_loader`, `val_dataset` and `test_dataset`.

    The ODE-Net used to be assembled from the notebook-level
    `downsampling_layers` / `feature_layers` / `fc_layers` lists. Those lists
    hold the very module objects that `model` was built from, so the
    "ODE-Net" in this comparison shared weights with any network already
    trained from them and did not start from scratch. It is now built from
    freshly constructed layers.
    """
    # Initialize models, both with new, untrained parameters.
    ode_model = _build_ode_net().to(device)
    resnet_model = BasicResNet().to(device)

    def train_model(model, name):
        """Train `model`; return (test accuracy in %, parameter count in millions)."""
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

        print(f"\nTraining {name}...")
        for epoch in range(epochs):
            model.train()
            for x, y in train_loader:
                x, y = x.to(device), y.to(device)
                optimizer.zero_grad()
                output = model(x)
                loss = criterion(output, y)
                loss.backward()
                optimizer.step()
            scheduler.step()

            # Validation after every epoch (leaves the model in eval mode until
            # the next epoch switches it back to train mode).
            val_acc = _evaluate(model, val_loader, len(val_dataset))
            print(f'Epoch {epoch+1}/{epochs} | Val Acc: {val_acc:.2f}%')

        # Final test evaluation
        test_acc = _evaluate(model, test_loader, len(test_dataset))
        params = sum(p.numel() for p in model.parameters()) / 1e6  # in millions
        return test_acc, params

    # Train and compare
    ode_acc, ode_params = train_model(ode_model, "ODE-Net")
    resnet_acc, resnet_params = train_model(resnet_model, "ResNet")

    # Print comparison table. The memory column is a fixed label, not a measurement.
    print("\n=== Final Comparison ===")
    print(f"{'Model':<10} | {'Test Acc (%)':<12} | {'Params (M)':<10} | {'Memory Efficiency':<16}")
    print("-"*50)
    print(f"{'ODE-Net':<10} | {ode_acc:<12.2f} | {ode_params:<10.2f} | O(1) (constant)")
    print(f"{'ResNet':<10} | {resnet_acc:<12.2f} | {resnet_params:<10.2f} | O(L) (linear)")

# Run the comparison
# train_and_compare()


Training ODE-Net...
Epoch 1/30 | Val Acc: 99.50%
Epoch 2/30 | Val Acc: 99.22%
Epoch 3/30 | Val Acc: 99.32%
Epoch 4/30 | Val Acc: 98.83%
Epoch 5/30 | Val Acc: 99.28%
Epoch 6/30 | Val Acc: 99.33%
Epoch 7/30 | Val Acc: 99.18%
Epoch 8/30 | Val Acc: 99.33%
Epoch 9/30 | Val Acc: 99.37%
Epoch 10/30 | Val Acc: 99.17%
Epoch 11/30 | Val Acc: 99.48%
Epoch 12/30 | Val Acc: 99.42%
Epoch 13/30 | Val Acc: 99.37%
Epoch 14/30 | Val Acc: 99.38%
Epoch 15/30 | Val Acc: 99.48%
Epoch 16/30 | Val Acc: 99.33%
Epoch 17/30 | Val Acc: 99.48%
Epoch 18/30 | Val Acc: 99.52%
Epoch 19/30 | Val Acc: 99.57%
Epoch 20/30 | Val Acc: 99.60%
Epoch 21/30 | Val Acc: 99.60%
Epoch 22/30 | Val Acc: 99.38%
Epoch 23/30 | Val Acc: 99.67%
Epoch 24/30 | Val Acc: 99.60%
Epoch 25/30 | Val Acc: 99.60%
Epoch 26/30 | Val Acc: 99.55%
Epoch 27/30 | Val Acc: 99.60%
Epoch 28/30 | Val Acc: 99.68%
Epoch 29/30 | Val Acc: 99.65%
Epoch 30/30 | Val Acc: 99.63%

Training ResNet...


NotImplementedError: Module [BasicResNet] is missing the required "forward" function