In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import torchvision
import matplotlib.pyplot as plt
from ema_pytorch import EMA
from torchinfo import summary
from ml_zoo.datamodules import MNISTDataModule

In [4]:
dm = MNISTDataModule(
    data_dir="data",
    dataset_params={
        "download": True,
        "transform": torchvision.transforms.Compose(
            [
                torchvision.transforms.Resize((32, 32)),
                torchvision.transforms.ToTensor(),
            ]
        ),
    },
    loader_params={
        "batch_size": 128,
        "num_workers": 2,
    },
)
dm.prepare_data()
dm.setup()
trian_loader = dm.train_dataloader()
test_loader = dm.test_dataloader()

In [18]:
class Classifer(nn.Module):
    def __init__(self):
        super(Classifer, self).__init__()
        self.backbone = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.Conv2d(1, 8, 3, 1, 1),
            nn.GELU(),
            nn.MaxPool2d(2, 2),
            nn.BatchNorm2d(8),
            nn.Conv2d(8, 16, 3, 1, 1),
            nn.GELU(),
            nn.MaxPool2d(2, 2),
            nn.BatchNorm2d(16),
            nn.Conv2d(16, 16, 3, 1, 1),
            nn.GELU(),
            nn.MaxPool2d(2, 2),
            nn.BatchNorm2d(16),
            nn.Conv2d(16, 16, 3, 1, 1),
            nn.GELU(),
            nn.MaxPool2d(2, 2)
        )
        
        self.classifier = nn.Sequential(
            nn.GELU(),
            nn.BatchNorm1d(16 * 2 * 2),
            nn.Linear(16 * 2 * 2, 10),
        )

    def forward(self, x):
        x = self.backbone(x).view(x.size(0), -1)
        x = self.classifier(x)
        return x


model = Classifer().to("mps")
summary(
    model,
    input_data=torch.randn(1, 1, 32, 32, device="mps", requires_grad=False),
    depth=2,
)

Layer (type:depth-idx)                   Output Shape              Param #
Classifer                                [1, 10]                   --
├─Sequential: 1-1                        [1, 16, 2, 2]             --
│    └─BatchNorm2d: 2-1                  [1, 1, 32, 32]            2
│    └─Conv2d: 2-2                       [1, 8, 32, 32]            80
│    └─GELU: 2-3                         [1, 8, 32, 32]            --
│    └─MaxPool2d: 2-4                    [1, 8, 16, 16]            --
│    └─BatchNorm2d: 2-5                  [1, 8, 16, 16]            16
│    └─Conv2d: 2-6                       [1, 16, 16, 16]           1,168
│    └─GELU: 2-7                         [1, 16, 16, 16]           --
│    └─MaxPool2d: 2-8                    [1, 16, 8, 8]             --
│    └─BatchNorm2d: 2-9                  [1, 16, 8, 8]             32
│    └─Conv2d: 2-10                      [1, 16, 8, 8]             2,320
│    └─GELU: 2-11                        [1, 16, 8, 8]             --
│    └─Max

In [19]:
ema = EMA(model, beta=0.9999, update_after_step=100, update_every=10)
optimizer = torch.optim.Adam(model.parameters(), lr=4e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
criterion = nn.CrossEntropyLoss()

In [20]:
roll_loss = 0
test_loss = 0
test_acc = 0
for epoch in range(10):
    model.train()
    pbar = tqdm(trian_loader, desc=f"Epoch {epoch+1}")
    for img, label in pbar:
        img, label = img.to("mps"), label.to("mps")
        optimizer.zero_grad()
        output = model(img)
        
        loss = criterion(output, label)
            
        loss.backward()
        optimizer.step()
        roll_loss = roll_loss * 0.9 + loss.item() * 0.1

        ema.update()
        pbar.set_postfix_str(f"loss: {roll_loss:.4f}, test_loss: {test_loss:.4f}, test_acc: {test_acc:.4f}")

    model.eval()
    test_loss = 0
    test_acc = 0
    with torch.no_grad():
        for img, label in tqdm(test_loader, desc="Testing", leave=True):
            img, label = img.to("mps"), label.to("mps")
            output = model(img)
            test_loss += criterion(output, label)
            test_acc += (output.argmax(1) == label).float().mean()

    test_loss /= len(test_loader)
    test_loss = test_loss.item()
    test_acc /= len(test_loader)
    test_acc = test_acc.item()

Epoch 1: 100%|██████████| 469/469 [00:08<00:00, 56.53it/s, loss: 0.1224, test_loss: 0.0000, test_acc: 0.0000]
Testing: 100%|██████████| 79/79 [00:01<00:00, 45.10it/s]
Epoch 2: 100%|██████████| 469/469 [00:06<00:00, 71.00it/s, loss: 0.0630, test_loss: 0.1456, test_acc: 0.9712]
Testing: 100%|██████████| 79/79 [00:01<00:00, 47.03it/s]
Epoch 3: 100%|██████████| 469/469 [00:06<00:00, 71.45it/s, loss: 0.0466, test_loss: 0.0758, test_acc: 0.9823]
Testing: 100%|██████████| 79/79 [00:01<00:00, 46.98it/s]
Epoch 4: 100%|██████████| 469/469 [00:06<00:00, 69.54it/s, loss: 0.0406, test_loss: 0.0573, test_acc: 0.9839]
Testing: 100%|██████████| 79/79 [00:01<00:00, 46.80it/s]
Epoch 5: 100%|██████████| 469/469 [00:06<00:00, 73.09it/s, loss: 0.0433, test_loss: 0.0474, test_acc: 0.9857]
Testing: 100%|██████████| 79/79 [00:01<00:00, 45.71it/s]
Epoch 6: 100%|██████████| 469/469 [00:06<00:00, 72.39it/s, loss: 0.0812, test_loss: 0.0504, test_acc: 0.9845]
Testing: 100%|██████████| 79/79 [00:01<00:00, 46.11it/s

KeyboardInterrupt: 