In [7]:
from torch import nn
from einops.layers.torch import Rearrange
import torch


class MixerBlock(nn.Module):
    def __init__(self, dim, patch):
        super().__init__()
        self.pre_layer_norm = nn.LayerNorm(dim)
        self.post_layer_norm = nn.LayerNorm(dim)
        
        self.token_mixer = nn.Sequential(
                            nn.Linear(patch, dim),
                            nn.GELU(),
                            nn.Dropout(0.1),
                            nn.Linear(dim, patch),
                            nn.Dropout(0.1)
                            )
        
        self.channel_mixer = nn.Sequential(
                            nn.Linear(dim, dim),
                            nn.GELU(),
                            nn.Dropout(0.1),
                            nn.Linear(dim, dim),
                            nn.Dropout(0.1)
                            )
    def forward(self, x):
        z =self.pre_layer_norm(x)
        y = self.token_mixer(z.transpose(1,2)).transpose(1,2)
        y = y + x
        post_ln = self.post_layer_norm(y)
        cm_out = self.channel_mixer(post_ln)+y
        return cm_out
    
    
class MLPMixer(nn.Module):
    def __init__(self,input_size, patch_size, dim = 512, img_channel=3, layers = 12, num_classes=12):
        super().__init__()
        patch = int(input_size[0]/patch_size[0] * input_size[1]/patch_size[1])
        patch_dim = img_channel * patch_size[0] * patch_size[1]
        self.embedding = nn.Sequential(
                                                Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size[0], p2 = patch_size[1]),
                                                nn.Linear(patch_dim, dim)
                                                )
        self.main_architecture = nn.Sequential(*[nn.Sequential(MixerBlock(dim,patch)) for _ in range(layers)])
        
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.classifier = nn.Linear(dim,num_classes)
    
    def forward(self,x):
        x = self.embedding(x)
        x = self.main_architecture(x)
        return self.classifier(self.pool(x.transpose(1,2)).squeeze(2))

In [8]:
import torchvision.transforms as transforms
import torchvision
transform = transforms.Compose([transforms.Pad(4),transforms.RandomHorizontalFlip(),transforms.RandomCrop(32),transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
test_transform  = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))])

train_dataset = torchvision.datasets.CIFAR10(root='../../data/', train=True,transform=transform, download=True)
test_dataset = torchvision.datasets.CIFAR10(root='../../data/', train=False,transform=test_transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=100, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=100, shuffle=False)

Files already downloaded and verified


In [9]:
model = MLPMixer(
        input_size = (256,256),
        patch_size = (1,256),
        dim = 512,
        layers = 12,
        num_classes = 10,   
        )
img = torch.randn(10, 3, 256, 256)
pred = model(img)
pred.size()

torch.Size([10, 10])

In [10]:
import time 
device = "cuda"

model = MLPMixer(
        input_size = (32,32),
        patch_size = (1,32),
        dim = 512,
        layers = 12,
        num_classes = 10,   
        ).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=lr_decay)
#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=lr_decay)
total_time = 0 
epochs = 56
for epoch in range(epochs):
  for i, (images, labels) in enumerate(train_loader):
    start_time = time.time()
    images = images.to(device)
    labels = labels.to(device)
    outputs = model(images)
    loss = criterion(outputs, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step() 
    if (i+1) %500 == 0:
      elapsed_time = time.time() - start_time
      total_time += elapsed_time
      print ("Epoch {}, Step {} Loss: {:.4f} time : {:.4f}sec".format(epoch+1, i+1, loss.item(),total_time))
 # scheduler.step()




#trainer.fit(model, train_dataloader=train_dl, val_dataloaders=val_dl)

Epoch 1, Step 500 Loss: 1.5347 time : 0.1748sec
Epoch 2, Step 500 Loss: 1.2596 time : 0.3491sec
Epoch 3, Step 500 Loss: 1.3192 time : 0.5247sec
Epoch 4, Step 500 Loss: 1.4859 time : 0.7010sec
Epoch 5, Step 500 Loss: 1.0644 time : 0.8779sec
Epoch 6, Step 500 Loss: 1.2136 time : 1.0528sec
Epoch 7, Step 500 Loss: 1.2327 time : 1.2290sec
Epoch 8, Step 500 Loss: 0.8495 time : 1.4095sec
Epoch 9, Step 500 Loss: 1.0462 time : 1.5856sec
Epoch 10, Step 500 Loss: 0.8641 time : 1.7606sec
Epoch 11, Step 500 Loss: 0.7262 time : 1.9349sec
Epoch 12, Step 500 Loss: 0.7857 time : 2.1098sec
Epoch 13, Step 500 Loss: 0.8486 time : 2.2858sec
Epoch 14, Step 500 Loss: 0.6858 time : 2.4617sec
Epoch 15, Step 500 Loss: 0.7438 time : 2.6381sec
Epoch 16, Step 500 Loss: 0.7348 time : 2.8130sec
Epoch 17, Step 500 Loss: 0.7044 time : 2.9877sec
Epoch 18, Step 500 Loss: 0.6365 time : 3.1628sec
Epoch 19, Step 500 Loss: 0.6478 time : 3.3389sec
Epoch 20, Step 500 Loss: 0.6265 time : 3.5148sec
Epoch 21, Step 500 Loss: 0.71

In [16]:
model.eval()
with torch.no_grad():
    total = 0
    correct =0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print('Accuracy ( test images ) : {} %'.format(100 * correct / total))

Accuracy ( test images ) : 90.1 %


In [17]:
from torchsummary import summary

model_stats = summary(model, (3, 32, 32))
summary_str = str(model_stats)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         Rearrange-1               [-1, 32, 96]               0
            Linear-2              [-1, 32, 512]          49,664
         LayerNorm-3              [-1, 32, 512]           1,024
            Linear-4             [-1, 512, 512]          16,896
              GELU-5             [-1, 512, 512]               0
           Dropout-6             [-1, 512, 512]               0
            Linear-7              [-1, 512, 32]          16,416
           Dropout-8              [-1, 512, 32]               0
         LayerNorm-9              [-1, 32, 512]           1,024
           Linear-10              [-1, 32, 512]         262,656
             GELU-11              [-1, 32, 512]               0
          Dropout-12              [-1, 32, 512]               0
           Linear-13              [-1, 32, 512]         262,656
          Dropout-14              [-1, 