In [None]:
import torch
import torchvision
import torch.nn.functional as F
from torchvision import transforms
from torch import optim
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

from config import get_config
from model import MLPMixer, PatchEmbedding, Transformation1, Transformation2, MixerLayer

In [None]:
# Hyperparameters
config = get_config()

# Seed the RNGs so Restart & Run All reproduces the same weight initialisation
# and data-shuffling order (CUDA kernels may still add some nondeterminism).
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"device: {device}")

In [None]:
# ToTensor scales pixels to [0, 1]; Normalize with mean=std=0.5 per channel then
# maps them to [-1, 1] via (x - 0.5) / 0.5.
channel_mean = [0.5] * config["in_channels"]
channel_std = [0.5] * config["in_channels"]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
])

In [None]:
# Select the dataset here. config["in_channels"] (and config["image_size"]) must
# agree with it: MNIST is 1 channel at 28x28, CIFAR10 is 3 channels at 32x32.
DATASET = "CIFAR10"  # or "MNIST"
DATASETS = {
    "MNIST": (torchvision.datasets.MNIST, "./mnist"),
    "CIFAR10": (torchvision.datasets.CIFAR10, "./data"),
}
dataset_cls, data_root = DATASETS[DATASET]
trainset = dataset_cls(root=data_root, train=True, download=True, transform=transform)
testset = dataset_cls(root=data_root, train=False, download=True, transform=transform)

# Fail fast if the config does not match the chosen dataset. A 1-value Normalize
# would otherwise silently broadcast over 3-channel images.
sample_image, _ = trainset[0]
assert sample_image.shape[0] == config["in_channels"], (
    f"{DATASET} images have {sample_image.shape[0]} channels but "
    f"config['in_channels'] = {config['in_channels']}"
)

EVAL_BATCH_SIZE = 32  # evaluation batch size; does not affect the accuracy value
train_loader = DataLoader(trainset, batch_size=config["batch_size"], shuffle=True, num_workers=2)
test_loader = DataLoader(testset, batch_size=EVAL_BATCH_SIZE, shuffle=False, num_workers=2)


## Visualize a sample of the training images

In [None]:
# Show the first 100 raw training images in a 10x10 grid.
# MNIST images are 2-D (grayscale) and need a gray colormap; CIFAR10 images are
# HxWx3 and are shown with the default (RGB) rendering.
fig, axes = plt.subplots(10, 10, figsize=(10, 10))
for idx, ax in enumerate(axes.flat):  # idx runs 0..99, so image 0 is included
    image = trainset.data[idx]
    ax.imshow(image, cmap="gray" if image.ndim == 2 else None)
    ax.axis("off")
fig.suptitle(f"First 100 training images ({type(trainset).__name__})")
plt.show()

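Expected tensor shapes printed by the shape-check cell below (batch size 64, patch size 4, embedding dimension 128). Transformation1 is applied twice: once to move to the channel-first layout and once to move back.

## MNIST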

- Input shape: `torch.Size([64, 1, 28, 28])`
- Patch Embedding output shape: `torch.Size([64, 128, 7, 7])`
- T2 transformation output shape: `torch.Size([64, 49, 128])`
- T1 transformation output shape: `torch.Size([64, 128, 49])`
- T1 transformation output shape: `torch.Size([64, 49, 128])`
- Mixer Layer output shape: `torch.Size([64, 49, 128])`

## CIFAR10

- Input shape: `torch.Size([64, 3, 32, 32])`
- Patch Embedding output shape: `torch.Size([64, 128, 8, 8])`
- T2 transformation output shape: `torch.Size([64, 64, 128])`
- T1 transformation output shape: `torch.Size([64, 128, 64])`
- T1 transformation output shape: `torch.Size([64, 64, 128])`
- Mixer Layer output shape: `torch.Size([64, 64, 128])`

In [None]:
# Shape sanity check: run a random batch through each building block and compare
# the printed shapes with the expected shapes listed above.
# The input size comes from `config`, so the check matches whichever dataset the
# notebook is configured for (MNIST: 1x28x28 -> 49 patches, CIFAR10: 3x32x32 -> 64).
# NOTE(review): assumes config["image_size"] is a single int (square images) — confirm.
CHECK_BATCH_SIZE = 64
CHECK_PATCH_SIZE = 4
CHECK_EMBED_DIM = 128
CHECK_MLP_DIM = 256  # used for both of MixerLayer's last two arguments

check_channels = config["in_channels"]
check_image_size = config["image_size"]
check_grid = check_image_size // CHECK_PATCH_SIZE
check_num_patches = check_grid ** 2

X = torch.rand(CHECK_BATCH_SIZE, check_channels, check_image_size, check_image_size)

pe = PatchEmbedding(check_channels, CHECK_EMBED_DIM, CHECK_PATCH_SIZE)
t1 = Transformation1()
t2 = Transformation2()
ml = MixerLayer(CHECK_EMBED_DIM, check_num_patches, CHECK_MLP_DIM, CHECK_MLP_DIM)

print(f"Input shape: {X.shape}")
y1 = pe(X)
print(f"Patch Embedding output shape: {y1.shape}")
y2 = t2(y1)
print(f"T2 transformation output shape: {y2.shape}")
y3 = t1(y2)
print(f"T1 transformation output shape: {y3.shape}")
y4 = t1(y3)
print(f"T1 transformation output shape: {y4.shape}")
y5 = ml(y4)
print(f"Mixer Layer output shape: {y5.shape}")

# Turn the documented expected shapes into checks so a regression fails loudly.
tokens_shape = (CHECK_BATCH_SIZE, check_num_patches, CHECK_EMBED_DIM)
assert y1.shape == (CHECK_BATCH_SIZE, CHECK_EMBED_DIM, check_grid, check_grid)
assert y2.shape == tokens_shape
assert y3.shape == (CHECK_BATCH_SIZE, CHECK_EMBED_DIM, check_num_patches)
assert y4.shape == tokens_shape
assert y5.shape == tokens_shape

In [None]:
# Build the MLP-Mixer from the config and move it to the selected device.
# NOTE(review): embedding_dim and channel_intermediate_dim both read
# config["channel_dim"], and patch_size=2 differs from the shape check's 4 —
# confirm both are intentional.
mixer_kwargs = dict(
    in_channels=config["in_channels"],
    image_size=config["image_size"],
    patch_size=2,
    num_classes=10,
    embedding_dim=config["channel_dim"],
    depth=config["depth"],
    token_intermediate_dim=config["token_dim"],
    channel_intermediate_dim=config["channel_dim"],
)
model = MLPMixer(**mixer_kwargs).to(device)

# Cross-entropy loss and Adam over all model parameters.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=config["learning_rate"])

In [None]:
# Get accuracy on training & test to see how good our model is
def get_accuracy(loader, model, device=None):
    """Return the classification accuracy of ``model`` over ``loader`` as a float in [0, 1].

    Args:
        loader: iterable yielding ``(inputs, targets)`` batches, e.g. a DataLoader.
        model: classifier whose output holds class scores along dim 1.
        device: device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        Fraction of samples whose argmax prediction equals the target.

    Raises:
        ValueError: if ``loader`` yields no samples.

    The model is evaluated in eval mode and its previous train/eval mode is
    restored afterwards, even if an exception is raised.
    """
    if device is None:
        device = next((p.device for p in model.parameters()), torch.device("cpu"))

    was_training = model.training
    model.eval()
    num_correct = 0
    num_samples = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device)
                predictions = model(x).argmax(dim=1)
                # .item() keeps the counters as Python ints, not device tensors.
                num_correct += (predictions == y).sum().item()
                num_samples += predictions.size(0)
    finally:
        # Restore the caller's mode instead of unconditionally forcing train().
        model.train(was_training)

    if num_samples == 0:
        raise ValueError("loader yielded no samples; accuracy is undefined")
    return num_correct / num_samples

In [None]:
# Train for config["num_epochs"] epochs. The progress bar shows the current
# epoch and the latest batch loss, and a one-line mean-loss summary is printed
# per epoch (the bar itself disappears because leave=False).
num_epochs = config["num_epochs"]
for epoch in range(num_epochs):
    model.train()
    loop = tqdm(train_loader, leave=False, desc=f"Epoch {epoch + 1}/{num_epochs}")
    running_loss = 0.0
    for images, targets in loop:
        images = images.to(device)
        targets = targets.to(device)
        logits = model(images)
        loss = criterion(logits, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        loop.set_postfix(loss=batch_loss)

    mean_loss = running_loss / max(len(train_loader), 1)
    tqdm.write(f"Epoch {epoch + 1}/{num_epochs}: mean train loss {mean_loss:.4f}")

In [None]:
# Report accuracy (as a percentage) on both splits; a large gap signals overfitting.
train_accuracy = get_accuracy(train_loader, model)
print(f"Accuracy on training set: {train_accuracy*100:.2f}")
test_accuracy = get_accuracy(test_loader, model)
print(f"Accuracy on test set: {test_accuracy*100:.2f}")

## MNIST

Accuracy on training set: 99.54

Accuracy on test set: 97.70

## CIFAR10

Accuracy on training set: 92.30

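Accuracy on test set: 59.81

The large gap between training and test accuracy on CIFAR10 (92.30% vs. 59.81%) suggests substantial overfitting, unlike MNIST, where the two are close.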