MLP mixer proposes a way to use just mlps for vision

Its not better but it is competetive (at large scale) and could be researched upon ->  due to speed of infernece
plus has better tolerence to pixel shuffling

It uses channel mixing and token mixing

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms

In [2]:
class MlpBlock(nn.Module):
    """
    Standard MLP block == two linear layers and a GELU nonlinearity.
    The first layer expands the dimension to mlp_dim, then shrinks back.
    fc-glu-fc
    """
    def __init__(self, in_features, mlp_dim):
        super(MlpBlock, self).__init__()
        self.fc1 = nn.Linear(in_features, mlp_dim)
        self.fc2 = nn.Linear(mlp_dim, in_features)

    def forward(self, x):
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        return x

In [3]:
class MixerBlock(nn.Module):
    """
    A single Mixer block that separately mixes tokens and channels.
    It first applies token mixing (across patches) and then channel mixing (within features).
    part 1 then 2 in the arch diagram
    """
    def __init__(self, num_tokens, hidden_dim, tokens_mlp_dim, channels_mlp_dim,drop_path=0.1):
        super(MixerBlock, self).__init__()
        self.norm1 = nn.LayerNorm(hidden_dim)
        # (1) MLP applied to the token dimension (each channel separately)
        self.token_mixing = MlpBlock(num_tokens, tokens_mlp_dim)

        self.norm2 = nn.LayerNorm(hidden_dim)
        # (2) MLP applied to the channel dimension (each token separately)
        self.channel_mixing = MlpBlock(hidden_dim, channels_mlp_dim)
        self.drop_path = drop_path


    # def forward(self, x):
    #     # x: (batch, num_tokens, hidden_dim)

    #     """
    #     # Token mixing
    #     """
    #     y = self.norm1(x)
    #     y = y.transpose(1, 2)  # (B, hidden_dim, num_tokens)
    #     y = self.token_mixing(y)
    #     y = y.transpose(1, 2)  # back to (B, num_tokens, hidden_dim)
    #     x = x + y  # skip connection
    #     """
    #     # Channel mixing
    #     """
    #     y = self.norm2(x)
    #     y = self.channel_mixing(y)
    #     return x + y  # skip connection

    def forward(self, x):
        # Token mixing
        if self.training and torch.rand(1).item() < self.drop_path:
            y = 0
        else:
            y = self.norm1(x)
            y = y.transpose(1, 2)
            y = self.token_mixing(y)
            y = y.transpose(1, 2)
            if self.drop_path > 0:
                y = y / (1 - self.drop_path)
        x = x + y

        # Channel mixing
        if self.training and torch.rand(1).item() < self.drop_path:
            y = 0
        else:
            y = self.norm2(x)
            y = self.channel_mixing(y)
            if self.drop_path > 0:
                y = y / (1 - self.drop_path)
        return x + y


In [4]:
class MlpMixer(nn.Module):
    """
    The full MLP-Mixer network.
    Given an input image, it splits it into patches via a convolution ("stem"),
    then processes the resulting tokens with several Mixer blocks,

    applies a final layer norm-> global average pooling-> linear classifier.

    Rn using CIFAR-100 (32×32 images), we set a small patch size (4×4) as imagent is too big
    """
    def __init__(self, num_classes, num_blocks, patch_size, hidden_dim,
                 tokens_mlp_dim, channels_mlp_dim, image_size=32, in_channels=3):
        super(MlpMixer, self).__init__()


        self.patch_size = patch_size
        """
        # (1) The stem Conv2d splits the image into non-overlapping patches.
        #     START OF THE PAPER WE SPLIT THE IMAGE INTO NON OVERLAPPING PATCH
        #     CNN does the same sort of thing
        """
        self.stem = nn.Conv2d(in_channels, hidden_dim, kernel_size=patch_size, stride=patch_size)
        self.num_tokens = (image_size // patch_size) ** 2


        """
        # (2) makes them into list of mixer block

        """
        # Create a list of Mixer blocks.
        self.mixer_blocks = nn.ModuleList([
            MixerBlock(num_tokens=self.num_tokens, hidden_dim=hidden_dim,
                       tokens_mlp_dim=tokens_mlp_dim, channels_mlp_dim=channels_mlp_dim)
            for _ in range(num_blocks)
        ])
        # Final layer normalization before classifiing
        self.norm = nn.LayerNorm(hidden_dim)

        # zero initialize the weights as in paper given
        self.head = nn.Linear(hidden_dim, num_classes)
        nn.init.zeros_(self.head.weight)
        if self.head.bias is not None:
            nn.init.zeros_(self.head.bias)

    def forward(self, x):
        # x: (batch, in_channels, image_size, image_size)
        x = self.stem(x)  # → (B, hidden_dim, H', W') where H' = image_size/patch_size
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)  # → (B, num_tokens, hidden_dim)
        for block in self.mixer_blocks:
            x = block(x)
        x = self.norm(x)
        x = x.mean(dim=1)  # global average pooling
        x = self.head(x)
        return x

In [5]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)
    avg_loss = running_loss / len(dataloader.dataset)
    return avg_loss

In [6]:
def evaluate(model, dataloader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

In [22]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split

class TransformSubset(torch.utils.data.Dataset):
    """A dataset wrapper that applies a transform to a subset."""
    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __getitem__(self, index):
        x, y = self.subset[index]
        if self.transform:
            x = self.transform(x)
        return x, y

    def __len__(self):
        return len(self.subset)

if __name__ == '__main__':
    # Normalization stats for CIFAR-10
    mean = [0.4914, 0.4822, 0.4465]
    std = [0.2470, 0.2435, 0.2616]

    # Define transforms
    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])

    # Load the training dataset without transforms initially
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=None)

    # Split into training and validation sets
    train_size = int(0.8 * len(trainset))  # 80% for training
    val_size = len(trainset) - train_size  # 20% for validation
    train_subset, val_subset = random_split(trainset, [train_size, val_size])

    # Apply transforms to each subset
    train_dataset = TransformSubset(train_subset, transform=transform_train)
    val_dataset = TransformSubset(val_subset, transform=transform_test)

    # Create data loaders
    trainloader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
    valloader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=2)

    # Load test set
    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=transform_test)
    testloader = DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


100%|██████████| 170M/170M [00:03<00:00, 45.3MB/s]


In [None]:
import copy
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR

# Define the model for CIFAR-10
model = MlpMixer(
    num_classes=10,           # CIFAR-10 has 10 classes
    num_blocks=2,
    patch_size=4,
    hidden_dim=512,
    tokens_mlp_dim=256,
    channels_mlp_dim=256,
    image_size=32,
    in_channels=3
)
model.to(device)

# Optimizer
optimizer = optim.AdamW(model.parameters(), lr=3e-3, weight_decay=0.05)

# Warmup + Cosine Annealing Scheduler
num_epochs = 100
warmup_epochs = 5

warmup = LinearLR(optimizer, start_factor=1e-5, total_iters=warmup_epochs)
cosine = CosineAnnealingLR(optimizer, T_max=num_epochs - warmup_epochs)
scheduler = SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[warmup_epochs])

# Loss function with label smoothing for better generalization
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# Training loop
print("Starting training on CIFAR-10")
best_val_acc = 0
best_model_state = None

for epoch in range(num_epochs):
    train_loss = train_epoch(model, trainloader, criterion, optimizer, device)
    val_acc = evaluate(model, valloader, device)
    scheduler.step()

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_model_state = copy.deepcopy(model.state_dict())

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}, Val Accuracy: {val_acc * 100:.2f}%")

# Load best model and evaluate on test set
model.load_state_dict(best_model_state)
test_acc = evaluate(model, testloader, device)
print(f"\n✅ Final Test Accuracy: {test_acc * 100:.2f}%")

Starting training on CIFAR-10
Epoch [1/100], Loss: 2.3023, Val Accuracy: 18.97%
Epoch [2/100], Loss: 1.8501, Val Accuracy: 45.36%
Epoch [3/100], Loss: 1.6670, Val Accuracy: 51.71%
Epoch [4/100], Loss: 1.5779, Val Accuracy: 56.55%
Epoch [5/100], Loss: 1.5101, Val Accuracy: 59.20%
Epoch [6/100], Loss: 1.4851, Val Accuracy: 60.90%
Epoch [7/100], Loss: 1.4287, Val Accuracy: 63.29%
Epoch [8/100], Loss: 1.3874, Val Accuracy: 65.76%
Epoch [9/100], Loss: 1.3541, Val Accuracy: 66.32%
Epoch [10/100], Loss: 1.3288, Val Accuracy: 66.11%
