In [1]:
import torch
import torch.nn as nn
import numpy as np
from einops.layers.torch import Rearrange

In [2]:
class MLPBlock(nn.Module):
    """Two-layer feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    The block operates on the last dimension of its input and returns a
    tensor whose last dimension is again ``dim``; leading dimensions pass
    through untouched.

    Args:
        dim: Size of the input and output feature dimension.
        hidden_dim: Size of the intermediate feature dimension.
        dropout: Dropout probability used after the activation and after the
            output projection.
    """

    def __init__(self, dim, hidden_dim, dropout = 0):
        super().__init__()
        # First half: project dim -> hidden_dim, activate, regularise.
        self.linear1 = nn.Linear(dim, hidden_dim)
        self.gelu = nn.GELU()
        self.dropout1 = nn.Dropout(dropout)
        # Second half: project hidden_dim -> dim, regularise.
        self.linear2 = nn.Linear(hidden_dim, dim)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the feed-forward block to ``x`` (last dimension must be ``dim``)."""
        hidden = self.linear1(x)
        hidden = self.gelu(hidden)
        hidden = self.dropout1(hidden)
        projected = self.linear2(hidden)
        return self.dropout2(projected)

In [3]:
class MixerBlock(nn.Module):
    """One Mixer layer: a token-mixing MLP followed by a channel-mixing MLP.

    Both sub-layers are preceded by a LayerNorm over the last dimension and
    wrapped in a residual (skip) connection. The input is laid out as
    ``(b, n, d)`` with ``n == num_patch`` and ``d == dim``; the output has
    the same shape.

    Args:
        dim: Size of the channel dimension ``d``.
        num_patch: Size of the token dimension ``n``.
        token_dim: Hidden width of the token-mixing MLP.
        channel_dim: Hidden width of the channel-mixing MLP.
        dropout: Dropout probability forwarded to both MLP blocks.
    """

    def __init__(self, dim, num_patch, token_dim, channel_dim, dropout = 0.):
        super().__init__()

        # Token mixing: swap the last two axes so the MLP's linear layers run
        # along the token axis, then swap back.
        token_layers = [
            nn.LayerNorm(dim),
            Rearrange('b n d -> b d n'),              # (b, n, d) -> (b, d, n)
            MLPBlock(num_patch, token_dim, dropout),  # shape stays (b, d, n)
            Rearrange('b d n -> b n d'),              # (b, d, n) -> (b, n, d)
        ]
        self.token_mix = nn.Sequential(*token_layers)

        # Channel mixing: the MLP runs directly along the channel axis.
        channel_layers = [
            nn.LayerNorm(dim),
            MLPBlock(dim, channel_dim, dropout),      # shape stays (b, n, d)
        ]
        self.channel_mix = nn.Sequential(*channel_layers)

    def forward(self, x):
        """Apply token mixing then channel mixing, each with a skip connection."""
        after_tokens = x + self.token_mix(x)
        after_channels = after_tokens + self.channel_mix(after_tokens)
        return after_channels
    

In [4]:
class MLPMixer(nn.Module):
    """Image classifier built from a patch embedding and a stack of MixerBlocks.

    The image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches by a strided convolution, each patch is embedded into ``dim``
    channels, the resulting token sequence is passed through ``depth``
    MixerBlocks, layer-normalised, averaged over the tokens and finally
    projected to ``num_classes`` logits.

    Args:
        in_channels: Number of channels of the input image.
        dim: Embedding (channel) dimension of every token.
        num_classes: Number of output logits.
        patch_size: Side length of a patch; also the stride of the embedding
            convolution.
        image_size: Side length of the input image. The patch count is
            computed as ``(image_size // patch_size) ** 2``, i.e. for a
            square image.
        depth: Number of MixerBlocks.
        token_dim: Hidden width of each token-mixing MLP.
        channel_dim: Hidden width of each channel-mixing MLP.
        dropout: Dropout probability forwarded to every MixerBlock. The
            default of 0 reproduces the behaviour of the model without
            dropout.

    Raises:
        AssertionError: If ``image_size`` is not divisible by ``patch_size``.
    """

    def __init__(self,
                 in_channels = 3,
                 dim = 512,
                 num_classes = 1000,
                 patch_size = 16,
                 image_size = 224,
                 depth = 8,
                 token_dim = 256,
                 channel_dim = 2048,
                 dropout = 0.):
        super().__init__()

        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        self.num_patch = (image_size // patch_size) ** 2

        # kernel_size == stride == patch_size, so each output position of the
        # convolution sees exactly one patch; the Rearrange then flattens the
        # spatial grid into a token axis.
        self.to_patch_embedding = nn.Sequential(
            nn.Conv2d(in_channels, dim, patch_size, patch_size),
            Rearrange('b c h w -> b (h w) c'),
        )

        # The dropout rate is passed through so that it can be configured
        # from the top-level model instead of being fixed at the block default.
        self.mixer_blocks = nn.ModuleList([
            MixerBlock(dim, self.num_patch, token_dim, channel_dim, dropout)
            for _ in range(depth)
        ])

        self.layer_norm = nn.LayerNorm(dim)
        self.mlp_head = nn.Sequential(
            nn.Linear(dim, num_classes)
        )

    def forward(self, x):
        """Return class logits for a batch of images.

        Shape comments below use the default configuration with batch size 1.
        """
        ### x shape: [1, 3, 224, 224]
        x = self.to_patch_embedding(x)
        ### x shape: [1, 196, 512]

        for mixer_block in self.mixer_blocks:
            x = mixer_block(x)

        ### x shape: [1, 196, 512]
        x = self.layer_norm(x)

        # Average over the token axis.
        x = x.mean(dim=1)
        ### x shape: [1, 512]

        return self.mlp_head(x) ### [1, num_classes]

In [5]:
# Smoke test: build the default-sized model, report its size and push one
# dummy image through it.
img = torch.ones([1, 3, 224, 224])

model = MLPMixer(
    in_channels=3,
    image_size=224,
    patch_size=16,
    num_classes=1000,
    dim=512,
    depth=8,
    token_dim=256,
    channel_dim=2048,
)

# Number of trainable parameters, expressed in millions.
parameters = sum(np.prod(p.size()) for p in model.parameters() if p.requires_grad) / 1_000_000
print('Trainable Parameters: %.3fM' % parameters)

out_img = model(img)
print("Shape of out :", out_img.shape)  # [batch_size, num_classes]

Trainable Parameters: 18.528M
Shape of out : torch.Size([1, 1000])
