In [1]:
import torch
import torch.nn as nn
import pytorch_lightning as pl

from tqdm import tqdm
from torchinfo import summary

import utils.data as data

# MLP Mixer

In [2]:
class MLP(nn.Module):
    """Two-layer feed-forward block: Linear -> GELU -> Linear.

    The hidden layer is ``expansion`` times wider than the input, and the
    output has the same width ``dim`` as the input.

    Args:
        dim: size of the last axis of the input (and of the output).
        expansion: integer width multiplier for the hidden layer.
    """

    def __init__(self, dim, expansion):
        super().__init__()
        hidden = dim * expansion
        # Attribute names and creation order are kept stable: they define the
        # state_dict keys and the order in which random init draws from the RNG.
        self.lin1 = nn.Linear(dim, hidden)
        self.act = nn.GELU()
        self.lin2 = nn.Linear(hidden, dim)

    def forward(self, x):
        """Apply Linear -> GELU -> Linear along the last axis of ``x``."""
        return self.lin2(self.act(self.lin1(x)))

In [3]:
class MixerLayer(nn.Module):
    """One Mixer block: patch (token) mixing followed by channel mixing.

    Input and output are laid out as (B, P, C). Each half is pre-normalised
    with LayerNorm over the channel axis and wrapped in a residual skip.

    Args:
        num_patches: number of patches P (the width mixed by the patch MLP).
        num_channels: feature size C per patch (the width mixed by the
            channel MLP and normalised by both LayerNorms).
        expansion: hidden-width multiplier passed to both MLPs.
    """

    def __init__(self, num_patches, num_channels, expansion):
        super().__init__()
        # Keep attribute names/order: they fix state_dict keys and init RNG order.
        self.norm1 = nn.LayerNorm(num_channels)
        self.by_patch = MLP(num_patches, expansion)
        self.by_channel = MLP(num_channels, expansion)
        self.norm2 = nn.LayerNorm(num_channels)

    def forward(self, x):
        # Patch mixing: normalise channels, then run the patch MLP along the
        # patch axis by temporarily moving it last: (B, P, C) -> (B, C, P).
        mixed = self.by_patch(self.norm1(x).transpose(1, 2))
        # Back to (B, P, C) and add the residual.
        x = mixed.transpose(1, 2) + x

        # Channel mixing on (B, P, C) with its own residual.
        return self.by_channel(self.norm2(x)) + x

In [4]:
class MLPMixer(pl.LightningModule):
    """MLP-Mixer image classifier wrapped as a Lightning module.

    The image is cut into a square grid of ``num_patches`` non-overlapping
    square patches. Each patch is flattened and linearly projected to
    ``num_channels`` features, then passed through ``depth`` MixerLayers.
    The result is averaged over patches and classified by a final linear layer.

    Args:
        img_sz: side length in pixels of the square input image.
        img_channels: number of input image channels.
        num_classes: number of output logits.
        depth: number of MixerLayer blocks.
        num_patches: total number of patches; must be a perfect square.
        num_channels: hidden feature size per patch.
        expansion: hidden-width multiplier of the MLPs inside each MixerLayer.
        lr: Adam learning rate (defaults to the previously hard-coded 2e-4).

    Raises:
        ValueError: if ``img_sz`` cannot be split into a square grid of
            ``num_patches`` patches.
    """

    def __init__(
        self,
        img_sz,
        img_channels,
        num_classes,
        depth,
        num_patches,
        num_channels,
        expansion,
        lr=2e-4,
    ):
        super().__init__()
        # Store constructor args in self.hparams (and in checkpoints) so the
        # model can be rebuilt later, e.g. via load_from_checkpoint.
        self.save_hyperparameters()

        self.img_sz = img_sz
        self.img_channels = img_channels

        self.num_patches = num_patches
        self.num_channels = num_channels
        self.lr = lr

        # Patches form a (patches_per_side x patches_per_side) grid. Patch size
        # and projection width are both derived from this one value. Previously
        # they were computed independently, so they could disagree and fail later
        # in forward with an opaque shape error. Leftover pixels that do not
        # fill a whole patch are dropped by unfold in forward.
        patches_per_side = round(num_patches ** 0.5) if num_patches > 0 else 0
        patch_sz = img_sz // patches_per_side if patches_per_side else 0
        if (
            patches_per_side ** 2 != num_patches
            or patch_sz < 1
            or img_sz // patch_sz != patches_per_side
        ):
            raise ValueError(
                f"cannot split a {img_sz}x{img_sz} image into "
                f"{num_patches} square patches"
            )
        self.patch_sz = patch_sz

        # Length of one flattened patch: channels x patch height x patch width.
        inp_channels = img_channels * self.patch_sz * self.patch_sz

        self.per_patch = nn.Linear(inp_channels, num_channels)

        self.mixer_layers = nn.ModuleList(
            [MixerLayer(num_patches, num_channels, expansion) for _ in range(depth)]
        )

        self.classifier = nn.Linear(num_channels, num_classes)

        self.loss = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return class logits for a batch of images.

        Args:
            x: images laid out as (B, C, H, W), or a single (C, H, W) image,
                which is treated as a batch of one.

        Returns:
            Logits of shape (B, num_classes).
        """
        if x.dim() == 3:
            x = x.unsqueeze(0)

        bs = x.shape[0]

        # Cut into non-overlapping patches -> (B, 1, H/p, W/p, C, p, p).
        # Operate on x itself rather than x.data so the input stays in the
        # autograd graph (needed for e.g. input-gradient attribution).
        x = (
            x.unfold(1, self.img_channels, self.img_channels)
            .unfold(2, self.patch_sz, self.patch_sz)
            .unfold(3, self.patch_sz, self.patch_sz)
        )

        # Flatten each patch -> (B, num_patches, C * p * p).
        x = x.reshape(bs, -1, self.img_channels * self.patch_sz * self.patch_sz)

        # Per-patch linear projection -> (B, num_patches, num_channels).
        x = self.per_patch(x)

        for layer in self.mixer_layers:
            x = layer(x)

        # Average over the patch axis, then classify.
        x = x.mean(1)

        return self.classifier(x)

    def training_step(self, xb, batch_idx):
        """Compute cross-entropy on one (inputs, labels) batch and log it."""
        inp, labels = xb
        loss = self.loss(self(inp), labels)

        # Record the training loss in the logger, not only in the progress bar.
        self.log("train_loss", loss)

        return loss

    def validation_step(self, xb, batch_idx):
        """Log validation loss and accuracy for one batch; return the accuracy."""
        inp, labels = xb
        out = self(inp)
        loss = self.loss(out, labels)

        labels_hat = torch.argmax(out, dim=1)
        val_acc = torch.sum(labels == labels_hat).item() / (len(labels) * 1.0)

        self.log("val_loss", loss, prog_bar=True, on_step=True, on_epoch=True)
        self.log("val_acc", val_acc, prog_bar=True, on_step=True, on_epoch=True)

        return val_acc

    def configure_optimizers(self):
        """Adam over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

# Training and evaluation

In [5]:
# Configuration: 28x28 single-channel inputs split into a 2x2 grid of patches.
IMG_SZ = 28
IMG_CHANNELS = 1
NUM_CLASSES = 10
NUM_PATCHES = 4
NUM_CHANNELS = 128
DEPTH = 10
EXPANSION = 4
EPOCHS = 5
SEED = 42
USE_GPU = torch.cuda.is_available()

# Seed the Python/NumPy/torch RNGs before building the model so weight init
# (and any torch-RNG-driven shuffling) is repeatable across runs.
pl.seed_everything(SEED)

model = MLPMixer(
    IMG_SZ,
    img_channels=IMG_CHANNELS,
    num_classes=NUM_CLASSES,
    depth=DEPTH,
    num_patches=NUM_PATCHES,
    num_channels=NUM_CHANNELS,
    expansion=EXPANSION,
)

print(summary(model, input_size=(1, IMG_CHANNELS, IMG_SZ, IMG_SZ)))

trainer = pl.Trainer(
    default_root_dir="logs",
    gpus=(1 if USE_GPU else 0),
    max_epochs=EPOCHS,
    logger=pl.loggers.TensorBoardLogger("logs/", name="mlp_mixer", version=0),
    # Native 16-bit AMP is GPU-only in Lightning 1.x; use full precision on CPU.
    precision=(16 if USE_GPU else 32),
)

# Loaders are passed positionally: the training-loader keyword was renamed
# (train_dataloader -> train_dataloaders) across Lightning 1.x releases.
trainer.fit(model, data.train_dl, data.val_dl)

# Recompute validation accuracy over the whole validation set.
model = model.eval()
device = model.device  # wherever Lightning left the weights after fit
batch_outs = []
ys = []

with torch.no_grad():
    for x, y in tqdm(data.val_dl):
        out = model(x.to(device))
        # Collect logits on CPU so they can be compared with the CPU labels.
        batch_outs.append(out.reshape(x.shape[0], -1).cpu())
        ys.append(y)

outs = torch.cat(batch_outs, dim=0)
labels = torch.cat(ys, dim=0)

labels_hat = torch.argmax(outs, dim=1)
val_acc = torch.sum(labels == labels_hat).item() / (len(labels) * 1.0)

print("\nValidation Accuracy: ", round(100*val_acc, 4))

ize (MB): 5.40
Estimated Total Size (MB): 5.90

  | Name         | Type             | Params
--------------------------------------------------
0 | per_patch    | Linear           | 25.2 K
1 | mixer_layers | ModuleList       | 1.3 M 
2 | classifier   | Linear           | 1.3 K 
3 | loss         | CrossEntropyLoss | 0     
--------------------------------------------------
1.4 M     Trainable params
0         Non-trainable params
1.4 M     Total params
5.401     Total estimated model params size (MB)
Epoch 0:  86%|████████▌ | 938/1095 [00:44<00:07, 20.86it/s, loss=0.105, v_num=0, val_loss_epoch=2.390, val_acc_epoch=0.0703]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/157 [00:00<?, ?it/s][A
Epoch 0:  86%|████████▌ | 940/1095 [00:45<00:07, 20.73it/s, loss=0.105, v_num=0, val_loss_epoch=2.390, val_acc_epoch=0.0703]
Epoch 0:  86%|████████▋ | 945/1095 [00:45<00:07, 20.80it/s, loss=0.105, v_num=0, val_loss_epoch=2.390, val_acc_epoch=0.0703]
Epoch 0:  87%|████████▋ | 952/1