MLP mixer proposes a way to use just mlps for vision

Its not better but it is competetive (at large scale) and could be researched upon ->  due to speed of infernece
plus has better tolerence to pixel shuffling 

It uses channel mixing and token mixing

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms,datasets

In [2]:
class MlpBlock(nn.Module):
    """
    Standard MLP block == two linear layers and a GELU nonlinearity.
    The first layer expands the dimension to mlp_dim, then shrinks back.
    fc-glu-fc
    """
    def __init__(self, in_features, mlp_dim):
        super(MlpBlock, self).__init__()
        self.fc1 = nn.Linear(in_features, mlp_dim)
        self.fc2 = nn.Linear(mlp_dim, in_features)
        
    def forward(self, x):
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        return x

In [3]:
class MixerBlock(nn.Module):
    """
    A single Mixer block that separately mixes tokens and channels.
    It first applies token mixing (across patches) and then channel mixing (within features).
    part 1 then 2 in the arch diagram
    """
    def __init__(self, num_tokens, hidden_dim, tokens_mlp_dim, channels_mlp_dim,drop_path=0.1):
        super(MixerBlock, self).__init__()
        self.norm1 = nn.LayerNorm(hidden_dim)
        # (1) MLP applied to the token dimension (each channel separately)
        self.token_mixing = MlpBlock(num_tokens, tokens_mlp_dim)
        
        self.norm2 = nn.LayerNorm(hidden_dim)
        # (2) MLP applied to the channel dimension (each token separately)
        self.channel_mixing = MlpBlock(hidden_dim, channels_mlp_dim)
        self.drop_path = drop_path

                
    # def forward(self, x):
    #     # x: (batch, num_tokens, hidden_dim)

    #     """
    #     # Token mixing
    #     """
    #     y = self.norm1(x)
    #     y = y.transpose(1, 2)  # (B, hidden_dim, num_tokens)
    #     y = self.token_mixing(y)
    #     y = y.transpose(1, 2)  # back to (B, num_tokens, hidden_dim)
    #     x = x + y  # skip connection
    #     """
    #     # Channel mixing
    #     """
    #     y = self.norm2(x)
    #     y = self.channel_mixing(y)
    #     return x + y  # skip connection

    def forward(self, x):
        # Token mixing
        if self.training and torch.rand(1).item() < self.drop_path:
            y = 0
        else:
            y = self.norm1(x)
            y = y.transpose(1, 2)
            y = self.token_mixing(y)
            y = y.transpose(1, 2)
            if self.drop_path > 0:
                y = y / (1 - self.drop_path)
        x = x + y
        
        # Channel mixing
        if self.training and torch.rand(1).item() < self.drop_path:
            y = 0
        else:
            y = self.norm2(x)
            y = self.channel_mixing(y)
            if self.drop_path > 0:
                y = y / (1 - self.drop_path)
        return x + y


In [4]:
class MlpMixer(nn.Module):
    """
    The full MLP-Mixer network.
    Given an input image, it splits it into patches via a convolution ("stem"),
    then processes the resulting tokens with several Mixer blocks,

    applies a final layer norm-> global average pooling-> linear classifier.
    
    Rn using CIFAR-100 (32×32 images), we set a small patch size (4×4) as imagent is too big
    """
    def __init__(self, num_classes, num_blocks, patch_size, hidden_dim,
                 tokens_mlp_dim, channels_mlp_dim, image_size=32, in_channels=3):
        super(MlpMixer, self).__init__()

        
        self.patch_size = patch_size
        """
        # (1) The stem Conv2d splits the image into non-overlapping patches. 
        #     START OF THE PAPER WE SPLIT THE IMAGE INTO NON OVERLAPPING PATCH
        #     CNN does the same sort of thing
        """
        self.stem = nn.Conv2d(in_channels, hidden_dim, kernel_size=patch_size, stride=patch_size)
        self.num_tokens = (image_size // patch_size) ** 2
        

        """
        # (2) makes them into list of mixer block
            
        """
        # Create a list of Mixer blocks.
        self.mixer_blocks = nn.ModuleList([
            MixerBlock(num_tokens=self.num_tokens, hidden_dim=hidden_dim,
                       tokens_mlp_dim=tokens_mlp_dim, channels_mlp_dim=channels_mlp_dim)
            for _ in range(num_blocks)
        ])
        # Final layer normalization before classifiing
        self.norm = nn.LayerNorm(hidden_dim)

        # zero initialize the weights as in paper given
        self.head = nn.Linear(hidden_dim, num_classes)
        nn.init.zeros_(self.head.weight)
        if self.head.bias is not None:
            nn.init.zeros_(self.head.bias)
        
    def forward(self, x):
        # x: (batch, in_channels, image_size, image_size)
        x = self.stem(x)  # → (B, hidden_dim, H', W') where H' = image_size/patch_size
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)  # → (B, num_tokens, hidden_dim)
        for block in self.mixer_blocks:
            x = block(x)
        x = self.norm(x)
        x = x.mean(dim=1)  # global average pooling
        x = self.head(x)
        return x

In [5]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)
    avg_loss = running_loss / len(dataloader.dataset)
    return avg_loss

In [6]:
def evaluate(model, dataloader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

In [7]:

if __name__ == '__main__':
    # Data augmentation and normalization for training;
    # simple normalization for now , ask proff for better ones for smaller dataset/ larger
    transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
    ])
    """
    Images are normalized using mean and standard deviation
    Data Augmentation: Random Horizontal Flip
    """


    # Load the Food-101 dataset
    train_dataset = datasets.Food101(root='./data', split='train', transform=transform, download=True)
    test_dataset = datasets.Food101(root='./data', split='test', transform=transform, download=True)

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=2)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Set up
    # We choose a small hidden_dim  but still get overfitting
# model = MlpMixer(
#         num_classes=100,     # 100 classes
#         num_blocks=8,
#         patch_size=4,
#         hidden_dim=256,      # Hidden channel size
#         tokens_mlp_dim=128,   # Hidden dimension for token-mixing MLP
#         channels_mlp_dim=512, # Hidden dimension for channel-mixing MLP
#         image_size=32,       # CIFAR 32x32
#         in_channels=3        # (RGB)
#     )
# model.to(device)

model = MlpMixer(
    num_classes=101,             # Food-101 has 101 classes
    num_blocks=8,               # More Mixer blocks for larger images
    patch_size=8,               # 16x16 patch for 224x224 input
    hidden_dim=128,              # Hidden dimension size
    tokens_mlp_dim=256,         # Token mixing MLP size
    channels_mlp_dim=512,       # Channel mixing MLP size
    image_size=224,              # Updated input size
    in_channels=3
)
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.3e-2, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)

epochs=50

print("Starting training on FOOD-101")

for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * inputs.size(0)
        train_acc = evaluate(model, train_loader, device)
        scheduler.step()
        avg_loss = running_loss / len(train_loader.dataset)
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}, Train Accuracy: {train_acc * 100:.2f}%")

print("Training complete :)")


Starting training on FOOD-101


Epoch [1/50], Loss: 4.2919, Test Accuracy: 7.99%
Epoch [2/50], Loss: 4.0319, Test Accuracy: 11.97%
Epoch [3/50], Loss: 3.8506, Test Accuracy: 14.63%
Epoch [4/50], Loss: 3.7238, Test Accuracy: 16.52%
Epoch [5/50], Loss: 3.6250, Test Accuracy: 18.15%
Epoch [6/50], Loss: 3.5340, Test Accuracy: 20.48%
Epoch [7/50], Loss: 3.4350, Test Accuracy: 21.96%
Epoch [8/50], Loss: 3.3497, Test Accuracy: 23.68%
Epoch [9/50], Loss: 3.2571, Test Accuracy: 25.14%
Epoch [10/50], Loss: 3.1602, Test Accuracy: 28.87%
Epoch [11/50], Loss: 3.0509, Test Accuracy: 31.07%
Epoch [12/50], Loss: 2.9512, Test Accuracy: 33.67%
Epoch [13/50], Loss: 2.8386, Test Accuracy: 36.35%
Epoch [14/50], Loss: 2.7121, Test Accuracy: 39.20%
Epoch [15/50], Loss: 2.5822, Test Accuracy: 41.70%
Epoch [16/50], Loss: 2.4541, Test Accuracy: 46.69%
Epoch [17/50], Loss: 2.3205, Test Accuracy: 50.39%
Epoch [18/50], Loss: 2.1790, Test Accuracy: 55.00%
Epoch [19/50], Loss: 2.0324, Test Accuracy: 57.97%
Epoch [20/50], Loss: 1.8838, Test Accurac

In [1]:
# ✅ Evaluation Loop
def evaluate(model, loader, criterion):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            with autocast():
                outputs = model(inputs)
                loss = criterion(outputs, labels)

            total_loss += loss.item() * inputs.size(0)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    acc = 100 * correct / total
    avg_loss = total_loss / total
    return avg_loss, acc
val_loss, val_acc = evaluate(model, test_loader, criterion)

NameError: name 'model' is not defined

Things we think we should fix :

Have better regularization as being overfit

maybe change in LR

have larger data set 

Dropout

Stochastic Depth ??

RandAugment ??

Mixup

Weight Decay (L2 regularization)