In [1]:
import torch
import torch.nn as nn

In [2]:
from torch.utils.data import DataLoader

In [3]:
import torchvision

In [4]:
# Load the MNIST training split (downloaded into ./data on the first run);
# ToTensor() converts every image to a float tensor.
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)

In [5]:
class MnistModel(nn.Module):
    """CNN + single-head self-attention classifier for square images.

    Two conv -> ReLU -> 2x2 max-pool stages turn an ``x_len`` x ``x_len`` image
    into a 64-channel feature map of side ``l = x_len // 4``. Each channel's map
    is flattened to a vector of length ``l * l`` and self-attention mixes
    information *across channels*: the 64 channels are the sequence and the
    ``l * l`` spatial positions are the embedding. A three-layer MLP head maps
    the flattened result to one logit per class.

    Args:
        x_len: Height/width of the square input image. Defaults to 28 (MNIST).
            Must be at least 4 so the map is non-empty after both max-pools.
        in_channels: Number of input image channels. Defaults to 1 (grayscale).
        num_classes: Number of output logits. Defaults to 10.

    Raises:
        ValueError: if ``x_len < 4``.
    """

    def __init__(
        self,
        x_len: int = 28,
        in_channels: int = 1,
        num_classes: int = 10,
    ):
        # Convolution + attention
        super().__init__()
        if x_len < 4:
            raise ValueError(
                f"x_len must be at least 4 (two 2x2 max-pools), got {x_len}"
            )
        # Side length after the two stride-2 max-pools. Each pool floors, so this is x_len // 4.
        pooled_len = x_len // 4
        # [batch, in_channels, x_len, x_len] -> [batch, 32, x_len, x_len]
        self.conv = nn.Conv2d(
            in_channels=in_channels, out_channels=32, kernel_size=3, stride=1, padding=1
        )
        # ReLU is stateless, so this one instance is reused after both convolutions.
        self.relu = nn.ReLU()
        # [batch, 32, x_len, x_len] -> [batch, 32, x_len // 2, x_len // 2]
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # [batch, 32, x_len // 2, x_len // 2] -> [batch, 64, x_len // 2, x_len // 2]
        self.conv2 = nn.Conv2d(
            in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1
        )
        # [batch, 64, x_len // 2, x_len // 2] -> [batch, 64, l, l] with l = x_len // 4
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # [batch, 64, l, l] -> [batch, 64, l * l]
        self.flatten_before_attn = nn.Flatten(start_dim=-2)
        # Self-attention over the 64 channel "tokens", each embedded in l * l dims:
        # [batch, 64, l * l] -> [batch, 64, l * l]
        self.attn = nn.MultiheadAttention(
            embed_dim=pooled_len ** 2, num_heads=1, batch_first=True
        )
        # [batch, 64, l * l] -> [batch, 64 * l * l]
        self.flatten_after_attn = nn.Flatten(start_dim=-2)
        flattened_last_dim = 64 * pooled_len ** 2
        # NOTE(review): for x_len=28 this head is 3136 -> 6272 -> 3136, i.e. ~39M
        # parameters (~150 MB checkpoint); a much narrower head would likely suffice.
        self.fc1 = nn.Linear(flattened_last_dim, flattened_last_dim * 2)
        self.activation1 = nn.ReLU()
        self.fc2 = nn.Linear(flattened_last_dim * 2, flattened_last_dim)
        self.activation2 = nn.ReLU()
        self.fc3 = nn.Linear(flattened_last_dim, num_classes)

    def forward(self, x):
        """Compute class logits.

        Args:
            x: image batch, expected shape [batch, in_channels, x_len, x_len].

        Returns:
            Logits of shape [batch, num_classes].
        """
        x = self.conv(x)
        x = self.relu(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.maxpool2(x)
        x = self.flatten_before_attn(x)
        # The attention weights are discarded, so don't ask MultiheadAttention to return them.
        x, _ = self.attn(x, x, x, need_weights=False)
        x = self.flatten_after_attn(x)
        x = self.fc1(x)
        x = self.activation1(x)
        x = self.fc2(x)
        x = self.activation2(x)
        x = self.fc3(x)
        return x
        

In [6]:
# Prefer Apple's Metal backend (the original target), then CUDA, and fall back
# to CPU so the notebook also runs on machines without an MPS device.
device = (
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

In [7]:
import wandb

In [8]:
# Start the wandb run. The hyperparameters below are read back via wandb.config
# in later cells.
wandb.init(
    project="demo",
    name="mnist-demo",
    tags=["demo"],
    config={
        "lr": 1e-4,
        "epoch": 4,
        "batch_size": 128,
        "weight_decay": 1e-5,
        "seed": 42,
    },
)
# Seed torch's global RNG before the model and DataLoader are created, so the
# weight initialization and the shuffling order are repeatable across runs.
torch.manual_seed(wandb.config.seed)
# Report the minimum logged loss in the run summary.
wandb.define_metric("loss", summary="min")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


In [None]:
# Build the classifier for 28x28 MNIST images and move its parameters to the selected device.
model = MnistModel(x_len=28).to(device)

In [None]:
# Smoke test: run the first training image through the untrained model as a
# batch of one ([None] adds the leading batch dim). The result is the logits tensor.
model(train_dataset[0][0][None].to(device))

tensor([[-0.0032, -0.0035, -0.0056, -0.0138, -0.0096, -0.0110,  0.0192,  0.0166,
          0.0048,  0.0017]], device='mps:0', grad_fn=<LinearBackward0>)

In [None]:
# Adam over every model parameter. The learning rate and weight_decay (which
# torch.optim.Adam applies as an L2 term on the gradients) come from the wandb run config.
optim = torch.optim.Adam(params=model.parameters(), weight_decay=wandb.config.weight_decay, lr=wandb.config.lr)

In [None]:
# Reshuffled mini-batches of the training split; the batch size is taken from the run config.
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=wandb.config.batch_size)

In [None]:
import time

In [None]:
# Wall-clock start time for the duration/throughput summary computed after training.
# NOTE: because this sits in its own cell, any pause before running the next cell is counted too.
start = time.time()

In [None]:
# CrossEntropyLoss expects raw logits, which is what MnistModel returns.
loss_fn = nn.CrossEntropyLoss()
# Put the model (back) into training mode. This cell may be re-run after the
# evaluation cell has switched it to eval mode.
model.train()
for ep in range(wandb.config.epoch):
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optim.zero_grad()
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optim.step()
        # One loss point per optimizer step.
        wandb.log({"loss": loss.item()})

# Save the trained weights locally, then log the file to wandb under the name "result".
torch.save(model.state_dict(), "result.pth")
wandb.log_model(path="result.pth", name="result")

In [None]:
# Throughput bookkeeping. `start` was taken in an earlier cell, so this window
# also covers saving/logging the model and any pause between cells.
end = time.time()
steps = wandb.config.epoch * len(train_loader)  # optimizer steps = epochs x batches per epoch
wandb.summary.update(
    {
        "duration": end - start,
        "steps": steps,
        "steps_per_second": steps / (end - start),
    }
)

In [None]:
# Held-out MNIST test split, with the same ToTensor() preprocessing as training.
eval_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)

In [None]:
# Switch to evaluation mode. train(False) is exactly what model.eval() does, and
# it returns the module, whose repr is displayed as the cell output.
model.train(False)

MnistModel(
  (conv): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu): ReLU()
  (maxpool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (maxpool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (flatten_before_attn): Flatten(start_dim=-2, end_dim=-1)
  (attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=49, out_features=49, bias=True)
  )
  (flatten_after_attn): Flatten(start_dim=-2, end_dim=-1)
  (fc1): Linear(in_features=3136, out_features=6272, bias=True)
  (activation1): ReLU()
  (fc2): Linear(in_features=6272, out_features=3136, bias=True)
  (activation2): ReLU()
  (fc3): Linear(in_features=3136, out_features=10, bias=True)
)

In [None]:
# Test batches in a fixed order; shuffling is unnecessary for measuring accuracy.
eval_loader = DataLoader(dataset=eval_dataset, batch_size=wandb.config.batch_size, shuffle=False)

In [None]:
# Count top-1 correct predictions over the test split.
total = 0
correct = 0

# Make the cell self-contained: evaluate in eval mode even if the previous cell was skipped.
model.eval()
# Pure inference: no_grad() stops autograd from recording a graph for every
# batch, which would otherwise waste memory and time.
with torch.no_grad():
    for x, y in eval_loader:
        x, y = x.to(device), y.to(device)
        y_pred = model(x)
        # Index of the largest logit along the class dimension = predicted label.
        predicted = y_pred.argmax(dim=1)
        total += y.size(0)
        correct += (predicted == y).sum().item()

In [None]:
# Top-1 accuracy on the test split (displayed as the cell output).
correct / total

0.9619

In [None]:
# Record the test-set accuracy in the wandb run summary.
wandb.summary.update({"acc": correct / total})

In [None]:
# End the wandb run. The output below shows the run's pending files being uploaded before it closes.
wandb.finish()

VBox(children=(Label(value='150.332 MB of 150.336 MB uploaded\r'), FloatProgress(value=0.9999741814968328, max…

0,1
loss,█▅▄▃▂▂▂▂▂▂▂▂▂▁▂▂▂▁▁▁▁▁▂▁▁▂▁▁▁▁▁▁▁▁▂▁▁▁▁▁

0,1
acc,0.9619
duration,88.63494
steps,1876.0
steps_per_second,21.16547
