In [1]:
!pip install torch torchvision torchaudio




In [2]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from madgrad.madgrad import MADGRAD

In [3]:
# Preprocessing applied to every MNIST image: ToTensor turns the image into a
# float tensor in [0, 1], and Normalize computes (x - 0.5) / 0.5, mapping the
# pixel values into [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

In [4]:
# Training split of MNIST (downloaded into ./data on first run), preprocessed
# with `transform` and served in shuffled mini-batches of 64.
trainset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=64,
    shuffle=True,
)

In [5]:
# Held-out MNIST split, same preprocessing; kept in a fixed order (no shuffle)
# so evaluation is repeatable.
# NOTE(review): no cell shown here evaluates on testloader yet.
testset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=64,
    shuffle=False,
)


In [6]:
# Define your model
class Net(nn.Module):
    """Small two-convolution CNN classifier for single-channel 28x28 images.

    Pipeline: conv(1->32, 3x3) -> ReLU -> conv(32->64, 3x3) -> ReLU ->
    2x2 max-pool -> channel dropout -> flatten -> fc(9216->128) -> ReLU ->
    dropout -> fc(128->10) -> log_softmax.

    ``forward`` returns log-probabilities of shape (batch, 10).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # Channel-wise dropout on the pooled 4-D feature maps.
        self.dropout1 = nn.Dropout2d(0.25)
        # Element-wise dropout on the (batch, 128) hidden vector. nn.Dropout2d
        # only makes sense for 3-D/4-D input; on 2-D input PyTorch deprecates it
        # and zeroes entire rows (whole samples) instead of individual units.
        self.dropout2 = nn.Dropout(0.5)
        # 9216 = 64 * 12 * 12: a 28x28 input becomes 26x26 and then 24x24 after
        # the two unpadded 3x3 convs, and 12x12 after the 2x2 max-pool.
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Map a (batch, 1, 28, 28) tensor to (batch, 10) log-probabilities."""
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.relu(self.conv2(x))
        x = nn.functional.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = nn.functional.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return nn.functional.log_softmax(x, dim=1)

def train_model(optimizer, model, epochs=20, dataloader=None, log_interval=100):
    """Train ``model`` in place with ``optimizer``; return the mean loss of each epoch.

    Args:
        optimizer: A ``torch.optim``-style optimizer bound to ``model``'s parameters.
        model: Network whose outputs are scored by ``nn.CrossEntropyLoss``
            against integer class labels.
        epochs: Number of full passes over ``dataloader`` (default 20).
        dataloader: Iterable of ``(inputs, labels)`` batches. When ``None``,
            the notebook-level ``trainloader`` is used.
        log_interval: Every ``log_interval`` mini-batches, print the mean loss
            of those batches. Pass 0 to disable logging.

    Returns:
        list[float]: one value per epoch, the mean mini-batch loss over that
        whole epoch (NaN if the loader yielded no batches).
    """
    if dataloader is None:
        dataloader = trainloader

    # CrossEntropyLoss applies log_softmax itself. Net already outputs
    # log-probabilities, but log_softmax is idempotent on log-probabilities,
    # so this equals NLLLoss on Net's outputs and also accepts raw logits.
    criterion = nn.CrossEntropyLoss()

    model.train()  # training mode: layers such as dropout are active
    losses = []
    for epoch in range(epochs):
        epoch_loss = 0.0
        num_batches = 0
        running_loss = 0.0  # only for the periodic progress print
        for i, (inputs, labels) in enumerate(dataloader):
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            epoch_loss += batch_loss
            num_batches += 1
            running_loss += batch_loss
            if log_interval and i % log_interval == log_interval - 1:
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / log_interval))
                running_loss = 0.0
        # Average over every batch of the epoch, not just the tail that
        # happened to accumulate after the last progress print.
        losses.append(epoch_loss / num_batches if num_batches else float('nan'))
    return losses


In [7]:
# Instantiate one model per optimizer. Re-seeding before each construction
# gives all three networks identical initial weights, so differences between
# the loss curves come from the optimizers rather than from the initialization.
INIT_SEED = 0

torch.manual_seed(INIT_SEED)
model_madgrad = Net()
torch.manual_seed(INIT_SEED)
model_adam = Net()
torch.manual_seed(INIT_SEED)
model_sgd = Net()


In [8]:
# MADGRAD optimizer for the first model: learning rate 0.001, momentum 0.9,
# no weight decay.
optimizer_madgrad = MADGRAD(
    params=model_madgrad.parameters(),
    lr=0.001,
    momentum=0.9,
    weight_decay=0,
)

In [9]:
# Adam optimizer for the second model; only the learning rate is set
# explicitly, everything else uses torch.optim.Adam's defaults.
optimizer_adam = optim.Adam(
    model_adam.parameters(),
    lr=0.001,
)

In [10]:
# Stochastic gradient descent with momentum 0.9 for the third model.
optimizer_sgd = optim.SGD(
    model_sgd.parameters(),
    lr=0.001,
    momentum=0.9,
)


In [11]:
# Run training once per (optimizer, model) pair. Each call logs progress to
# stdout and returns a list with one loss value per epoch.
losses_madgrad = train_model(optimizer=optimizer_madgrad, model=model_madgrad)
losses_adam = train_model(optimizer=optimizer_adam, model=model_adam)
losses_sgd = train_model(optimizer=optimizer_sgd, model=model_sgd)



[1,   100] loss: 0.865
[1,   200] loss: 0.311
[1,   300] loss: 0.224
[1,   400] loss: 0.195
[1,   500] loss: 0.190
[1,   600] loss: 0.164
[1,   700] loss: 0.141
[1,   800] loss: 0.148
[1,   900] loss: 0.152
[2,   100] loss: 0.116
[2,   200] loss: 0.132
[2,   300] loss: 0.120
[2,   400] loss: 0.119
[2,   500] loss: 0.103
[2,   600] loss: 0.116
[2,   700] loss: 0.107
[2,   800] loss: 0.110
[2,   900] loss: 0.100
[3,   100] loss: 0.102
[3,   200] loss: 0.089
[3,   300] loss: 0.085
[3,   400] loss: 0.076
[3,   500] loss: 0.087
[3,   600] loss: 0.083
[3,   700] loss: 0.092
[3,   800] loss: 0.083
[3,   900] loss: 0.084
[4,   100] loss: 0.076
[4,   200] loss: 0.077
[4,   300] loss: 0.078
[4,   400] loss: 0.077
[4,   500] loss: 0.074
[4,   600] loss: 0.066
[4,   700] loss: 0.087
[4,   800] loss: 0.066
[4,   900] loss: 0.075
[5,   100] loss: 0.067
[5,   200] loss: 0.058
[5,   300] loss: 0.064
[5,   400] loss: 0.057
[5,   500] loss: 0.068
[5,   600] loss: 0.055
[5,   700] loss: 0.074
[5,   800] 

[19,   800] loss: 0.022
[19,   900] loss: 0.020
[20,   100] loss: 0.014
[20,   200] loss: 0.017
[20,   300] loss: 0.017
[20,   400] loss: 0.021
[20,   500] loss: 0.020
[20,   600] loss: 0.016
[20,   700] loss: 0.016
[20,   800] loss: 0.024
[20,   900] loss: 0.021
[1,   100] loss: 2.201
[1,   200] loss: 1.459
[1,   300] loss: 0.804
[1,   400] loss: 0.586
[1,   500] loss: 0.527
[1,   600] loss: 0.456
[1,   700] loss: 0.422
[1,   800] loss: 0.412
[1,   900] loss: 0.394
[2,   100] loss: 0.353
[2,   200] loss: 0.337
[2,   300] loss: 0.331
[2,   400] loss: 0.325
[2,   500] loss: 0.300
[2,   600] loss: 0.302
[2,   700] loss: 0.301
[2,   800] loss: 0.281
[2,   900] loss: 0.275
[3,   100] loss: 0.264
[3,   200] loss: 0.259
[3,   300] loss: 0.250
[3,   400] loss: 0.249
[3,   500] loss: 0.236
[3,   600] loss: 0.244
[3,   700] loss: 0.225
[3,   800] loss: 0.222
[3,   900] loss: 0.228
[4,   100] loss: 0.223
[4,   200] loss: 0.221
[4,   300] loss: 0.217
[4,   400] loss: 0.194
[4,   500] loss: 0.194


In [None]:
# Plot the losses
import matplotlib.pyplot as plt

# One x-value per recorded epoch, derived from each history's length. A fixed
# range(1, 301) does not match the 20-epoch default of train_model and makes
# plot() raise "x and y must have same first dimension".
fig, ax = plt.subplots()
for label, history in (
    ('MADGRAD', losses_madgrad),
    ('Adam', losses_adam),
    ('SGD', losses_sgd),
):
    ax.plot(range(1, len(history) + 1), history, label=label)
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.set_title('Training Loss Comparison')
ax.legend()
plt.show()