In [64]:
import timm
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Assuming other parts of your imports and code remain the same

# Load MNIST with a preprocessing pipeline: resize, convert to tensor, normalize.
transform = transforms.Compose([
    transforms.Resize(32),  # Resize to 32, matching the 32*32 flatten used by Net and the train/test helpers. NOTE(review): the timm model created below is named "..._224" -- confirm 32 is really the size it expects.
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Single-channel mean 0.5 and std 0.5
])

# Build the train/test datasets (downloaded into ./data if missing) and their loaders.
# Training batches hold 64 shuffled samples; test batches hold 1000 samples in fixed order.
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=1000, shuffle=False)

# Create a timm transformer model with a 10-class head and no pretrained weights,
# together with an Adam optimizer (lr=0.001) over its parameters.
# NOTE(review): the model name refers to a Swin variant with "224" in it, while the
# transform above resizes images to 32 and normalizes a single channel -- confirm the
# input size / channel count before training this model.
# NOTE(review): neither model_vit nor optimizer_vit is used by the training loop
# visible in this cell, which trains Net instances instead.
model_vit = timm.create_model('swin_small_patch4_window7_224', num_classes=10, pretrained=False)
optimizer_vit = optim.Adam(model_vit.parameters(), lr=0.001)

# Define a simple neural network
class Net(nn.Module):
    """Two-layer fully connected classifier: flatten -> Linear -> ReLU -> Linear.

    Args:
        in_features: Number of values per sample after flattening. Defaults
            to 32*32, i.e. one 32x32 single-channel image.
        hidden_features: Width of the hidden layer. Defaults to 512.
        num_classes: Number of output logits. Defaults to 10.

    Calling ``Net()`` with no arguments builds the same layers (``fc1`` and
    ``fc2``, same shapes) as before the sizes were parameterized.
    """

    def __init__(self, in_features=32 * 32, hidden_features=512, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, num_classes)

    def forward(self, x):
        """Return raw class logits of shape (N, num_classes) for input ``x``.

        ``x`` may have any shape whose element count is a multiple of
        ``in_features``; it is reshaped to (-1, in_features) first.
        """
        # reshape (unlike view) also accepts non-contiguous tensors, e.g. the
        # result of permute/transpose, copying only when it has to.
        x = x.reshape(-1, self.fc1.in_features)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)


# Training and testing functions adapted for use with both FC and ViT models
def train_model(model, optimizer, criterion, trainloader, epochs=10, is_vit=False):
    """Train ``model`` and return the mean batch loss of every epoch.

    Args:
        model: Module to optimize; it is switched to training mode.
        optimizer: Optimizer built over ``model``'s parameters.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar
            loss tensor.
        trainloader: Iterable of ``(inputs, labels)`` batches. It must support
            ``len()`` and be non-empty, because the epoch loss is divided by
            ``len(trainloader)``.
        epochs: Number of passes over ``trainloader``.
        is_vit: When True the batch is passed to the model unchanged; when
            False each sample is flattened to one vector first.

    Returns:
        A list with one float per epoch: the average of the per-batch losses.
    """
    model.train()
    losses = []
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            if not is_vit:
                # Flatten per sample instead of assuming 32x32 images, so the
                # batch dimension is kept whatever the image size is.
                inputs = inputs.reshape(inputs.size(0), -1)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        avg_loss = running_loss / len(trainloader)
        losses.append(avg_loss)
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}')
    return losses

def test_model(model, testloader, is_vit=False):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            if is_vit:
                outputs = model(inputs)  # ViT uses the whole image
            else:
                inputs = inputs.view(-1, 32*32)  # Correct the flattening to match 32x32 images
                outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    return accuracy

# Initialize models, criterion, and optimizers.
# Five independent Net instances are created, one per optimizer under comparison,
# so each optimizer updates its own set of parameters.
# NOTE(review): AdamOptimizerWithMorse, Lion and Adan are neither defined nor imported
# in the visible part of this file -- presumably they come from earlier notebook cells.
models = [Net() for _ in range(5)]
optimizers = [
    AdamOptimizerWithMorse(models[0].parameters(), lr=0.001),
    optim.Adam(models[1].parameters(), lr=0.001),
    optim.AdamW(models[2].parameters(), lr=0.001),
    Lion(models[3].parameters(), lr=0.001),
    Adan(models[4].parameters(), lr=0.001)
]
# Display names, in the same order as `optimizers` (the loop below pairs them by position).
names = ["Adam with Morse theory and Floer Homotopy", "Built-in Adam", "Built-in AdamW", "New Lion", "Adan"]
criterion = nn.CrossEntropyLoss()

# Train and test all models: each (model, optimizer, name) triple is trained with the
# train_model defaults (epochs=10, is_vit=False) and then scored on the test loader.
# `losses` is overwritten on every iteration, so only the last run's loss history
# remains bound after the loop.
for model, optimizer, name in zip(models, optimizers, names):
    print(f"Training with {name}:")
    losses = train_model(model, optimizer, criterion, trainloader)
    accuracy = test_model(model, testloader)
    print(f"{name} Accuracy: {accuracy:.2f}%")

Training with Adam with Morse theory and Floer Homotopy:
Epoch [1/10], Loss: 0.3195
Epoch [2/10], Loss: 0.1484
Epoch [3/10], Loss: 0.1115
Epoch [4/10], Loss: 0.0932
Epoch [5/10], Loss: 0.0830
Epoch [6/10], Loss: 0.0727
Epoch [7/10], Loss: 0.0652
Epoch [8/10], Loss: 0.0596
Epoch [9/10], Loss: 0.0559
Epoch [10/10], Loss: 0.0505
Adam with Morse theory and Floer Homotopy Accuracy: 97.63%
Training with Built-in Adam:
Epoch [1/10], Loss: 0.3140
Epoch [2/10], Loss: 0.1468
Epoch [3/10], Loss: 0.1096
Epoch [4/10], Loss: 0.0936
Epoch [5/10], Loss: 0.0813
Epoch [6/10], Loss: 0.0712
Epoch [7/10], Loss: 0.0657
Epoch [8/10], Loss: 0.0602
Epoch [9/10], Loss: 0.0569
Epoch [10/10], Loss: 0.0489
Built-in Adam Accuracy: 97.25%
Training with Built-in AdamW:
Epoch [1/10], Loss: 0.3093
Epoch [2/10], Loss: 0.1476
Epoch [3/10], Loss: 0.1122
Epoch [4/10], Loss: 0.0939
Epoch [5/10], Loss: 0.0829
Epoch [6/10], Loss: 0.0705
Epoch [7/10], Loss: 0.0659
Epoch [8/10], Loss: 0.0598
Epoch [9/10], Loss: 0.0568
Epoch [10