In [None]:
# Import necessary Libraries
import torch 
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.nn as nn 
import torch.optim as optim
import numpy as np 
import matplotlib.pyplot as plt 
import time
import argparse

In [None]:
# Building the network using depth-wise separable convolutions, so that no specific initialization of keys, queries, and values is needed as in an encoder.
# Depth-wise separable convolutions take care of all of these things.

# Residual connection, a.k.a. skip connection (similar to ResNet): the input is added back onto the wrapped module's output in the forward method.
class Residual(nn.Module):
  """Wrap a module so that its input is added to its output.

  Args:
    fn: the module (or callable) to wrap; its output must be addable to its input.
  """

  def __init__(self,fn):
    super().__init__()
    # Stored as an attribute so a wrapped nn.Module is registered as a sub-module.
    self.fn = fn

  def forward(self,x):
    """Return ``fn(x) + x``."""
    transformed = self.fn(x)
    return transformed + x

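# ConvMixer model: a patch embedding followed by depth-wise separable convolution blocks (convolutions are used here instead of attention).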
def ConvMixer(dim,depth,kernel_size=5,patch_size=2,n_classes=10,in_channels=3):
  """Build a ConvMixer classifier as a single ``nn.Sequential``.

  Args:
    dim: number of channels used by every convolution after the patch embedding.
    depth: number of mixer blocks stacked after the patch embedding.
    kernel_size: kernel size of the depth-wise convolution inside each block.
    patch_size: kernel size and stride of the patch-embedding convolution.
    n_classes: number of outputs of the final linear layer.
    in_channels: number of channels of the input images (defaults to 3, the
      value that used to be hard-coded).

  Returns:
    An ``nn.Sequential`` whose last layers are global average pooling,
    flatten and ``nn.Linear(dim, n_classes)``.
  """
  def mixer_block():
    # Depth-wise convolution (groups=dim) + GELU + BatchNorm inside a residual
    # connection, followed by a point-wise (1x1) convolution + GELU + BatchNorm.
    return nn.Sequential(
        Residual(nn.Sequential(
            nn.Conv2d(dim,dim,kernel_size,groups=dim,padding='same'),
            nn.GELU(),
            nn.BatchNorm2d(dim))),
        nn.Conv2d(dim,dim,kernel_size=1),
        nn.GELU(),
        nn.BatchNorm2d(dim)
    )

  return nn.Sequential(
      # Patch embedding: a convolution whose stride equals its kernel size, so
      # each patch_size x patch_size patch is mapped to one dim-channel vector.
      nn.Conv2d(in_channels,dim,kernel_size=patch_size,stride=patch_size),
      nn.GELU(),
      nn.BatchNorm2d(dim),
      # `depth` mixer blocks; every block gets its own freshly created layers.
      *[mixer_block() for _ in range(depth)],
      # Classification head: average over all spatial positions, then a linear layer.
      nn.AdaptiveAvgPool2d((1,1)),
      nn.Flatten(),
      nn.Linear(dim,n_classes)
  )

In [None]:
# Getting data 

# Mean and standard deviation handed to transforms.Normalize for both splits.
cifar10_mean = (0.4914,0.4822,0.4465)
cifar10_std = (0.2471,0.2435,0.2616)

# Training pipeline: random augmentations first, then tensor conversion,
# normalisation and finally random erasing.
train_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomResizedCrop(size=32,scale=(0.75,1.0),ratio=(1.0,1.0)),
    transforms.RandAugment(num_ops=1,magnitude=8),
    transforms.ColorJitter(brightness=0.1,contrast=0.1,saturation=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar10_mean,std=cifar10_std),
    transforms.RandomErasing(p=0.25),
])

# Test pipeline: no augmentation, only tensor conversion and normalisation.
test_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar10_mean,std=cifar10_std),
])

# Settings shared by the two splits.
dataset_kwargs = dict(root='./data',download=True)
loader_kwargs = dict(batch_size=128,num_workers=2)

# Only the training loader shuffles its samples.
trainset = torchvision.datasets.CIFAR10(train=True,transform=train_transforms,**dataset_kwargs)
trainloader = torch.utils.data.DataLoader(trainset,shuffle=True,**loader_kwargs)

testset = torchvision.datasets.CIFAR10(train=False,transform=test_transforms,**dataset_kwargs)
testloader = torch.utils.data.DataLoader(testset,**loader_kwargs)



Files already downloaded and verified
Files already downloaded and verified


In [None]:
# Training hyper-parameters.
epochs = 24
depth = 10
hdim = 256
psize = 2
conv_ks = 5
clip_norm = True

def lr_scheduler(t):
  """Piecewise-linear learning-rate schedule evaluated at (fractional) epoch ``t``.

  The rate rises linearly from 0 to 0.01 over the first 2/5 of the epochs,
  falls to 0.01/20 at 4/5 of the epochs and reaches 0 at the last epoch.
  ``epochs`` is read when the function is called, not when it is defined.
  """
  knots = [0,epochs*2//5,epochs*4//5,epochs]
  rates = [0,0.01,0.01/20,0]
  # np.interp is given a one-element list, so unpack the single result.
  return np.interp([t],knots,rates)[0]

In [None]:
# Build the network from the hyper-parameters defined in the configuration cell.
model = ConvMixer(dim=hdim,depth=depth,kernel_size=conv_ks,patch_size=psize,n_classes=10)
# Wrap for data-parallel execution on GPU 0, then move the wrapper to the GPU.
model = nn.DataParallel(model,device_ids=[0])
model = model.cuda()

# AdamW over all parameters; the learning rate given here is overwritten on
# every training batch by lr_scheduler.
optimiz = optim.AdamW(params=model.parameters(),lr=0.01,weight_decay=0.01)
criterion = nn.CrossEntropyLoss()
# Gradient scaler used together with autocast in the training loop.
scaler = torch.cuda.amp.GradScaler()

In [None]:
from tqdm import tqdm

# Train for `epochs` epochs with mixed precision and evaluate on the test set
# after every epoch.
for ep in range(epochs):
  start = time.time()
  train_loss = 0
  processed = 0
  correct = 0

  # Switch to training mode once per epoch; the evaluation pass below turns it off.
  model.train()
  pbar = tqdm(trainloader)
  for batch_idx,(data,target) in enumerate(pbar):
    data,target = data.cuda(),target.cuda()
    # The schedule is evaluated at a fractional epoch, so the rate changes on every batch.
    lr = lr_scheduler(ep+(batch_idx+1)/len(trainloader))
    optimiz.param_groups[0].update(lr=lr)
    optimiz.zero_grad()

    with torch.cuda.amp.autocast():
      output = model(data)
      loss = criterion(output,target)

    scaler.scale(loss).backward()
    if clip_norm:
      # Unscale first so that the clipping threshold applies to the real gradient norm.
      scaler.unscale_(optimiz)
      nn.utils.clip_grad_norm_(model.parameters(),1.0)
    scaler.step(optimiz)
    scaler.update()

    # Read the loss from the GPU once and reuse it for the running sum and the progress bar.
    batch_loss = loss.item()
    train_loss += batch_loss

    pred = output.argmax(dim=1,keepdim=True)
    correct += pred.eq(target.view_as(pred)).sum().item()
    processed += len(data)

    pbar.set_description(desc=f"Loss={batch_loss} Batch_id={batch_idx} train-acc={100*correct/processed:0.2f}")

  # Evaluation pass over the test set.
  model.eval()
  m = 0          # number of test samples seen
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data,target in testloader:
      data,target = data.cuda(),target.cuda()
      with torch.cuda.amp.autocast():
        output = model(data)
      # criterion uses the default reduction and returns the mean over the batch,
      # so weight it by the batch size to accumulate a per-sample sum.
      test_loss += criterion(output,target).item()*target.size(0)
      pred = output.argmax(dim=1,keepdim=True)
      correct += pred.eq(target.view_as(pred)).sum().item()
      m += target.size(0)

  # Average loss per test sample (previously a sum of batch means was divided
  # by the number of samples, which under-reported the loss).
  test_loss = test_loss/m
  print(f"Test-Loss: {test_loss} val-accuracy: {correct/m, 100.*correct/m}")


  print(f'ConvMixer: Time: {time.time() - start:.1f}, lr: {lr:.6f}')


Loss=0.9519866704940796 Batch_id=390 train-acc=60.62: 100%|██████████| 391/391 [01:14<00:00,  5.28it/s]


Test-Loss: 0.007870556640625 val-accuracy: (0.6507, 65.07)
ConvMixer: Time: 77.7, lr: 0.001111


Loss=0.8901956677436829 Batch_id=390 train-acc=65.13: 100%|██████████| 391/391 [01:15<00:00,  5.17it/s]


Test-Loss: 0.00811865234375 val-accuracy: (0.6667, 66.67)
ConvMixer: Time: 79.0, lr: 0.002222


Loss=0.7415332198143005 Batch_id=390 train-acc=69.11: 100%|██████████| 391/391 [01:16<00:00,  5.14it/s]


Test-Loss: 0.006064208984375 val-accuracy: (0.7419, 74.19)
ConvMixer: Time: 79.5, lr: 0.003333


Loss=0.859330952167511 Batch_id=390 train-acc=72.83: 100%|██████████| 391/391 [01:16<00:00,  5.12it/s]


Test-Loss: 0.0051286865234375 val-accuracy: (0.7801, 78.01)
ConvMixer: Time: 79.8, lr: 0.004444


Loss=0.6765726804733276 Batch_id=390 train-acc=75.50: 100%|██████████| 391/391 [01:16<00:00,  5.11it/s]


Test-Loss: 0.004757421875 val-accuracy: (0.7946, 79.46)
ConvMixer: Time: 79.9, lr: 0.005556


Loss=0.64264976978302 Batch_id=390 train-acc=76.87: 100%|██████████| 391/391 [01:16<00:00,  5.12it/s]


Test-Loss: 0.005041845703125 val-accuracy: (0.7875, 78.75)
ConvMixer: Time: 79.8, lr: 0.006667


Loss=0.5921646356582642 Batch_id=390 train-acc=78.23: 100%|██████████| 391/391 [01:15<00:00,  5.19it/s]


Test-Loss: 0.0043831787109375 val-accuracy: (0.8099, 80.99)
ConvMixer: Time: 79.6, lr: 0.007778


Loss=0.9091554880142212 Batch_id=390 train-acc=79.11: 100%|██████████| 391/391 [01:14<00:00,  5.23it/s]


Test-Loss: 0.0042706787109375 val-accuracy: (0.815, 81.5)
ConvMixer: Time: 79.3, lr: 0.008889


Loss=0.7849539518356323 Batch_id=390 train-acc=79.90: 100%|██████████| 391/391 [01:14<00:00,  5.23it/s]


Test-Loss: 0.00443883056640625 val-accuracy: (0.8174, 81.74)
ConvMixer: Time: 78.8, lr: 0.010000


Loss=0.46423134207725525 Batch_id=390 train-acc=80.54: 100%|██████████| 391/391 [01:10<00:00,  5.55it/s]


Test-Loss: 0.0037575927734375 val-accuracy: (0.8368, 83.68)
ConvMixer: Time: 73.5, lr: 0.009050


Loss=0.5162789821624756 Batch_id=390 train-acc=82.51: 100%|██████████| 391/391 [01:09<00:00,  5.64it/s]


Test-Loss: 0.0034214111328125 val-accuracy: (0.8523, 85.23)
ConvMixer: Time: 73.5, lr: 0.008100


Loss=0.5699135065078735 Batch_id=390 train-acc=83.94: 100%|██████████| 391/391 [01:09<00:00,  5.66it/s]


Test-Loss: 0.00348258056640625 val-accuracy: (0.8509, 85.09)
ConvMixer: Time: 72.1, lr: 0.007150


Loss=0.4112052917480469 Batch_id=390 train-acc=84.99: 100%|██████████| 391/391 [01:08<00:00,  5.67it/s]


Test-Loss: 0.00281409912109375 val-accuracy: (0.8753, 87.53)
ConvMixer: Time: 73.2, lr: 0.006200


Loss=0.31334465742111206 Batch_id=390 train-acc=86.12: 100%|██████████| 391/391 [01:09<00:00,  5.66it/s]


Test-Loss: 0.00288087158203125 val-accuracy: (0.8765, 87.65)
ConvMixer: Time: 72.1, lr: 0.005250


Loss=0.4022236466407776 Batch_id=390 train-acc=87.22: 100%|██████████| 391/391 [01:09<00:00,  5.66it/s]


Test-Loss: 0.00269033203125 val-accuracy: (0.8837, 88.37)
ConvMixer: Time: 72.9, lr: 0.004300


Loss=0.2903965413570404 Batch_id=390 train-acc=88.63: 100%|██████████| 391/391 [01:09<00:00,  5.66it/s]


Test-Loss: 0.0025605712890625 val-accuracy: (0.8913, 89.13)
ConvMixer: Time: 72.2, lr: 0.003350


Loss=0.24843725562095642 Batch_id=390 train-acc=89.67: 100%|██████████| 391/391 [01:10<00:00,  5.55it/s]


Test-Loss: 0.0023769775390625 val-accuracy: (0.8994, 89.94)
ConvMixer: Time: 73.5, lr: 0.002400


Loss=0.3716115653514862 Batch_id=390 train-acc=91.01: 100%|██████████| 391/391 [01:08<00:00,  5.74it/s]


Test-Loss: 0.002169146728515625 val-accuracy: (0.9106, 91.06)
ConvMixer: Time: 71.2, lr: 0.001450


Loss=0.2372910976409912 Batch_id=390 train-acc=92.31: 100%|██████████| 391/391 [01:08<00:00,  5.73it/s]


Test-Loss: 0.00210303955078125 val-accuracy: (0.9155, 91.55)
ConvMixer: Time: 71.8, lr: 0.000500


Loss=0.19016499817371368 Batch_id=390 train-acc=92.94: 100%|██████████| 391/391 [01:08<00:00,  5.72it/s]


Test-Loss: 0.002056024169921875 val-accuracy: (0.9173, 91.73)
ConvMixer: Time: 71.4, lr: 0.000400


Loss=0.1579231172800064 Batch_id=390 train-acc=93.19: 100%|██████████| 391/391 [01:07<00:00,  5.82it/s]


Test-Loss: 0.00205999755859375 val-accuracy: (0.9165, 91.65)
ConvMixer: Time: 71.2, lr: 0.000300


Loss=0.1432209014892578 Batch_id=390 train-acc=93.55: 100%|██████████| 391/391 [01:08<00:00,  5.73it/s]


Test-Loss: 0.002034747314453125 val-accuracy: (0.9177, 91.77)
ConvMixer: Time: 71.2, lr: 0.000200


Loss=0.2985072731971741 Batch_id=390 train-acc=93.48: 100%|██████████| 391/391 [01:06<00:00,  5.90it/s]


Test-Loss: 0.002022528076171875 val-accuracy: (0.9193, 91.93)
ConvMixer: Time: 69.2, lr: 0.000100


Loss=0.09510357677936554 Batch_id=390 train-acc=93.80: 100%|██████████| 391/391 [01:07<00:00,  5.79it/s]


Test-Loss: 0.0020044677734375 val-accuracy: (0.9191, 91.91)
ConvMixer: Time: 71.1, lr: 0.000000


In [None]:
# Print a completion marker once execution reaches this final cell.
print("Done!")

Done!
