#**Neural Networks Project**<br>
**Implementing the paper [Pay Attention to MLPs](https://arxiv.org/pdf/2105.08050v1.pdf)**<br>

by *Daniel Caliman* (2122749, calimandaniel5@gmail.com) and *Nikolas Jochens* (2118698, nj@andaco.de)

*For the current project, we will user the CIFAR-10 Dataset. The dataset is a collection of images that are commonly used to train machine learning and computer vision algorithms. It is one of the most widely used datasets for machine learning research. The CIFAR-10 dataset contains 60,000 32x32 color images in 10 different classes. The 10 different classes represent airplanes, cars, birds, cats, deer, dogs, frogs, horses, ships, and trucks. There are 6,000 images of each class.*

In [48]:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch
import torch.nn as nn
from typing import Optional
import einops
from tqdm import tqdm
import numpy as np
from PIL import Image

In [49]:
# hyperparameters for vision and language models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 64

n_epochs = 30
img_size = 32
n_classes = 10
lr = 0.001

tensor_transforms = transforms.Compose([transforms.ToTensor()])
print(device) # check if GPU is available

cuda


In [50]:
train_data = datasets.CIFAR10("data/", train=True, download=True, transform=tensor_transforms)
validation_data = datasets.CIFAR10("data/", train=False, download=True, transform=tensor_transforms)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
validation_loader = DataLoader(validation_data, batch_size=batch_size, shuffle=False)
x, y = next(iter(train_loader))
print(len(train_data), x.shape, y.shape)

Files already downloaded and verified
Files already downloaded and verified
50000 torch.Size([64, 3, 32, 32]) torch.Size([64])


In [51]:
class SpacialGatingUnit(nn.Module):

    def __init__(self, d_z: int, seq_len: int):
        super().__init__()
        self.norm = nn.LayerNorm([d_z // 2])
        self.weight = nn.Parameter(torch.zeros(seq_len, seq_len).uniform_(-0.01, 0.01), requires_grad=True)
        self.bias = nn.Parameter(torch.ones(seq_len), requires_grad=True)

    def forward(self, z: torch.Tensor, mask: Optional[torch.Tensor] = None):
        seq_len = z.shape[0]
        z1, z2 = torch.chunk(z, 2, dim=-1)
        if mask is not None:
            assert mask.shape[0] == 1 or mask.shape[0] == seq_len
            assert mask.shape[1] == seq_len
            assert mask.shape[2] == 1
            mask = mask[:, :, 0]

        z2 = self.norm(z2)
        weight = self.weight[:seq_len, :seq_len]
        if mask is not None:
            weight = weight * mask
        z2 = torch.einsum('ij,jbd->ibd', weight, z2) + self.bias[:seq_len, None, None]
        return z1 * z2

In [52]:
class GMLPBlock(nn.Module):
    def __init__(self, d_model: int, d_ffn: int, seq_len: int):
        super().__init__()
        self.norm = nn.LayerNorm([d_model])
        self.activation = nn.GELU()
        self.proj1 = nn.Linear(d_model, d_ffn)
        self.sgu = SpacialGatingUnit(d_ffn, seq_len)
        self.proj2 = nn.Linear(d_ffn // 2, d_model)
        self.size = d_model

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
        shortcut = x
        x = self.norm(x)
        z = self.activation(self.proj1(x))
        z = self.sgu(z, mask)
        z = self.proj2(z)

        return z + shortcut


In [53]:
class gMLP(nn.Module):
    def __init__(self, seq_len=256, d_model=256, d_ffn=512, n_layers=6):
        super().__init__()
        self.blocks = nn.Sequential(
            *[GMLPBlock(d_model, d_ffn, seq_len) for _ in range(n_layers)]
        )

    def forward(self, x):
        return self.blocks(x)

In [54]:
class gMLPVisionModel(nn.Module):
    def __init__(self, in_channels=3, image_size=256, patch_size=4, d_model=32, d_ffn=64, n_layers=6, n_classes=1000):
        super().__init__()
        assert image_size % patch_size == 0, "image size must be divisible by patch size!!"
        n_patches = (image_size // patch_size) ** 2
        self.patch_embedding = nn.Conv2d(in_channels, d_model, kernel_size=patch_size, stride=patch_size)
        self.gmlp = gMLP(n_patches, d_model, d_ffn, n_layers)
        self.fc_out = nn.Linear(d_model, n_classes)

    def forward(self, x):
        x = self.patch_embedding(x)
        x = einops.rearrange(x, "b c h w -> b (h w) c")
        x = self.gmlp(x)
        x = x.mean(1)
        out = self.fc_out(x)
        return out

In [55]:
gmlp_vm = gMLPVisionModel(n_classes=n_classes, image_size = 32).to(device)
inp = torch.randn(1, 3, img_size, img_size).to(device)
out = gmlp_vm(inp)
print(out.shape)
del inp, out

torch.Size([1, 10])


In [56]:
optimizer_vm = torch.optim.Adam(gmlp_vm.parameters(), lr=lr)
loss_fn_vm = nn.CrossEntropyLoss()
def get_accuracy(preds, y):
    preds = preds.argmax(dim=1, keepdim=True)
    correct = preds.squeeze(1).eq(y)
    acc = correct.sum() / torch.FloatTensor([y.shape[0]]).to(device)
    return acc

In [57]:
def loop_vm(net, loader, is_train):
    net.train(is_train)
    losses = []
    accs = []
    pbar = tqdm(loader, total=len(loader))
    for x, y in pbar:
        x = x.to(device)
        y = y.to(device)
        with torch.set_grad_enabled(is_train):
            preds = net(x)
            loss = loss_fn_vm(preds, y)
            acc = get_accuracy(preds, y)
            losses.append(loss.item())
            accs.append(acc.item())
        if is_train:
            optimizer_vm.zero_grad()
            loss.backward()
            optimizer_vm.step()
        pbar.set_description(f'epoch={epoch}, train={int(is_train)}, loss={np.mean(losses):.4f}, acc={np.mean(accs):.4f}')

In [58]:
for epoch in range(n_epochs):
    loop_vm(gmlp_vm, train_loader, True)
    loop_vm(gmlp_vm, validation_loader, False)

epoch=0, train=1, loss=1.9171, acc=0.2753: 100%|██████████| 782/782 [00:36<00:00, 21.35it/s]
epoch=0, train=0, loss=1.7428, acc=0.3398: 100%|██████████| 157/157 [00:05<00:00, 31.04it/s]
epoch=1, train=1, loss=1.6547, acc=0.3826: 100%|██████████| 782/782 [00:32<00:00, 24.41it/s]
epoch=1, train=0, loss=1.5855, acc=0.4174: 100%|██████████| 157/157 [00:04<00:00, 36.90it/s]
epoch=2, train=1, loss=1.5374, acc=0.4368: 100%|██████████| 782/782 [00:26<00:00, 29.45it/s]
epoch=2, train=0, loss=1.5457, acc=0.4324: 100%|██████████| 157/157 [00:03<00:00, 51.53it/s]
epoch=3, train=1, loss=1.4659, acc=0.4650: 100%|██████████| 782/782 [00:27<00:00, 28.77it/s]
epoch=3, train=0, loss=1.4561, acc=0.4593: 100%|██████████| 157/157 [00:04<00:00, 35.93it/s]
epoch=4, train=1, loss=1.4255, acc=0.4831: 100%|██████████| 782/782 [00:27<00:00, 28.34it/s]
epoch=4, train=0, loss=1.3946, acc=0.4874: 100%|██████████| 157/157 [00:02<00:00, 64.57it/s]
epoch=5, train=1, loss=1.3838, acc=0.4961: 100%|██████████| 782/782 [0

KeyboardInterrupt: 

In [None]:
@torch.no_grad()
def recognize_img(net, img):
    net.eval()
    img = Image.open(img).convert("RGB")
    img = tensor_transforms(img).to(device)
    pred = net(img.unsqueeze(0))
    pred = pred.argmax(dim=1)
    return train_data.classes[pred.item()]

In [None]:
out = recognize_img(gmlp_vm, 'dog.jpg')
print(out)

bird
