In [None]:
import torch
from torch import nn
from torch.nn import functional as F
import torchvision
import torchvision.transforms.v2
from torchinfo import summary
from tqdm import tqdm
from ema_pytorch import EMA

In [None]:
# Seed the global torch RNG so loader shuffling (and, when cells are run in
# order, the weight initialisation in later cells) is reproducible.
torch.manual_seed(0)

BATCH_SIZE = 64

# One shared preprocessing pipeline for train and validation.
# v2.ToImage + v2.ToDtype(float32, scale=True) is the transforms-v2 replacement
# for ToTensor: PIL image -> float32 CHW tensor with values in [0, 1]
# (requires torchvision >= 0.16). CIFAR-10 images are already 32x32; Resize is
# kept so a different dataset can be swapped in without touching the model.
transform = torchvision.transforms.v2.Compose(
    [
        torchvision.transforms.v2.Resize((32, 32)),
        torchvision.transforms.v2.ToImage(),
        torchvision.transforms.v2.ToDtype(torch.float32, scale=True),
    ]
)

train_dataset = torchvision.datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)
val_dataset = torchvision.datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=transform,
)

train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, persistent_workers=True
)
# The validation loader is iterated once per epoch by evaluate(); persistent
# workers avoid re-spawning worker processes on every pass.
val_loader = torch.utils.data.DataLoader(
    dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, persistent_workers=True
)

In [None]:
@torch.no_grad()
def evaluate(net=None, loader=None, device="mps"):
    """Compute sample-weighted accuracy and mean cross-entropy of a classifier.

    Args:
        net: module returning logits of shape (batch, num_classes). Defaults to
            the notebook-global ``model``; pass e.g. ``ema.ema_model`` to score
            another set of weights.
        loader: iterable of ``(inputs, targets)`` batches, targets being class
            indices. Defaults to the notebook-global ``val_loader``.
        device: device each batch is moved to before the forward pass.

    Returns:
        dict with ``"accuracy"`` (fraction of correct argmax predictions) and
        ``"loss"`` (mean per-sample cross-entropy). Both are averaged over
        samples, not batches, so a short final batch is not over-weighted.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    if net is None:
        net = model
    if loader is None:
        loader = val_loader

    was_training = net.training
    net.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    try:
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = net(x)
            # Sum (not mean) per batch so the final division weights every sample equally.
            loss_sum += F.cross_entropy(logits, y, reduction="sum").item()
            correct += (logits.argmax(dim=1) == y).sum().item()
            total += y.numel()
    finally:
        # Restore whatever mode the caller had, even if the loop raised.
        net.train(was_training)

    if total == 0:
        raise ValueError("evaluate() received a loader with no samples")
    return {
        "accuracy": correct / total,
        "loss": loss_sum / total,
    }

In [None]:
class DepthwiseSeparableConv2d(nn.Module):
    """Depthwise conv -> pointwise (1x1) conv -> GroupNorm -> GELU.

    Stride is 1 and padding is ``kernel_size // 2``, so height and width are
    preserved for odd kernel sizes.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the pointwise conv.
        kernel_size: depthwise kernel size; must be odd (default 3).
        num_groups: GroupNorm groups; must divide ``out_channels`` (default 4,
            GroupNorm itself raises ValueError otherwise).

    Raises:
        ValueError: if ``kernel_size`` is even.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, num_groups=4):
        super().__init__()
        if kernel_size % 2 != 1:
            raise ValueError(f"kernel_size must be odd to preserve spatial size, got {kernel_size}")
        # Module order is unchanged from the original so state_dict keys
        # (layers.0 ... layers.3) stay compatible.
        self.layers = nn.Sequential(
            # groups=in_channels: one spatial filter per input channel.
            nn.Conv2d(
                in_channels,
                in_channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                groups=in_channels,
            ),
            # 1x1 conv mixes channels and sets the output width.
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            nn.GroupNorm(num_groups, out_channels),
            nn.GELU(),
        )

    def forward(self, x):
        return self.layers(x)

class Block(nn.Module):
    """A channel-projecting conv followed by two residual conv refinements.

    ``conv1`` maps ``in_channels`` to ``out_channels``; ``conv2`` and ``conv3``
    keep the width and are each added back onto their own input.
    """

    _RESIDUAL_KEYS = ("conv2", "conv3")

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Insertion order (conv1, conv2, conv3) fixes both the parameter
        # initialisation order and the state_dict key names.
        stages = {"conv1": DepthwiseSeparableConv2d(in_channels, out_channels)}
        for key in self._RESIDUAL_KEYS:
            stages[key] = DepthwiseSeparableConv2d(out_channels, out_channels)
        self.layers = nn.ModuleDict(stages)

    def forward(self, x):
        h = self.layers["conv1"](x)
        for key in self._RESIDUAL_KEYS:
            h = h + self.layers[key](h)
        return h


class Model(nn.Module):
    """Image classifier: depthwise-separable conv backbone plus a linear head.

    The backbone applies ``MaxPool2d(2)`` five times. A 32x32 input therefore
    reaches the head as a 256-channel 1x1 map.

    Args:
        num_classes: number of output logits (default 10).
        pool_size: output size of the adaptive average pool before the linear
            layer. The default of 1 is global average pooling.

            The original head used 2. On a 1x1 map, ``AdaptiveAvgPool2d(2)``
            only tiles each feature four times, which quadruples the head
            weights without adding information. It also weakens dropout,
            because every feature gets four independently dropped copies.
            A value of 2 only carries spatial information for inputs of
            64x64 or larger. Pass ``pool_size=2`` to reproduce the old head;
            note that ``dropout`` may need retuning when switching.
        dropout: drop probability applied before the linear layer.
    """

    def __init__(self, num_classes=10, pool_size=1, dropout=0.6):
        super().__init__()
        self.layers = nn.ModuleDict(
            {
                "backbone": nn.Sequential(
                    # 3 groups over 3 channels: each input channel is normalised on its own.
                    nn.GroupNorm(3, 3),
                    Block(3, 64),
                    nn.MaxPool2d(2),
                    Block(64, 128),
                    nn.MaxPool2d(2),
                    Block(128, 256),
                    nn.MaxPool2d(2),
                    Block(256, 256),
                    nn.MaxPool2d(2),
                    Block(256, 256),
                    nn.MaxPool2d(2),
                    Block(256, 256),
                ),
                "classifier": nn.Sequential(
                    nn.AdaptiveAvgPool2d(pool_size),
                    nn.Flatten(),
                    nn.Dropout(dropout),
                    nn.Linear(256 * pool_size * pool_size, num_classes),
                ),
            }
        )

    def forward(self, x):
        x = self.layers["backbone"](x)
        x = self.layers["classifier"](x)
        return x


model = Model()
# Layer-by-layer summary for a single 3-channel 32x32 input, taken while the
# model is still on its construction device.
print(summary(model, input_size=(1, 3, 32, 32)))
# All training/evaluation cells feed batches on the "mps" (Apple Metal) device.
model = model.to("mps")

In [None]:
LEARNING_RATE = 1e-3

# Model's head ends in a plain Linear layer (no softmax), so the loss is
# applied directly to the raw logits.
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
loss_fn = nn.CrossEntropyLoss()

In [None]:
# Exponential-moving-average copy of the model weights (ema_pytorch).
# beta is the decay; update_after_step / update_every control when EMA
# updates begin and how often an update() call is applied — see the
# ema_pytorch README for the exact warm-up schedule.
EMA_KWARGS = dict(beta=0.9999, update_after_step=1000, update_every=10)
ema = EMA(model, **EMA_KWARGS)

In [None]:
# Training loop.
# - The progress bar shows the validation metrics of the weights at the start
#   of the current epoch.
# - After every epoch the model is re-evaluated and the result is written
#   below the bar, so the final epoch's weights are scored too.
# NOTE(review): the EMA weights (ema.ema_model) are updated but never evaluated here.
NUM_EPOCHS = 10

model.train()
display_loss = None  # running average of the training loss, for display only
metrics = evaluate()  # baseline before any training
for epoch in range(NUM_EPOCHS):
    pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}")
    for x, y in pbar:
        x, y = x.to("mps"), y.to("mps")
        y_hat = model(x)
        loss = loss_fn(y_hat, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        step_loss = loss.item()
        # Seed with the first observed loss instead of 0 so the displayed
        # average is not biased low during the first few steps.
        if display_loss is None:
            display_loss = step_loss
        else:
            display_loss = display_loss * 0.9 + step_loss * 0.1
        ema.update()

        pbar.set_postfix_str(f"loss: {display_loss:.4f}, val_loss: {metrics['loss']:.4f}, val_acc: {metrics['accuracy']:.4f}")

    # The bar is closed once iteration ends, so report post-epoch metrics with tqdm.write.
    metrics = evaluate()
    tqdm.write(f"Epoch {epoch + 1}: val_loss: {metrics['loss']:.4f}, val_acc: {metrics['accuracy']:.4f}")
