<center>
    <h1> Machine Learning for Computer Vision</h1>
    <h2> Project Work </h2>
</center>

<br>

**Student**: Matteo Donati <br>
**Student ID**: 0001032227 <br>
**E-mail**: <a href="mailto:matteo.donati10@studio.unibo.it">matteo.donati10@studio.unibo.it</a>

<br>

This is an unofficial implementation of the paper by Guo et al., (2021)<sup>[[1]](#references)</sup>.

## Libraries

In [1]:
!pip install -U "git+https://github.com/facebookresearch/fvcore"

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/facebookresearch/fvcore
  Cloning https://github.com/facebookresearch/fvcore to /tmp/pip-req-build-pb1qi_9f
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/fvcore /tmp/pip-req-build-pb1qi_9f
  Resolved https://github.com/facebookresearch/fvcore to commit fd5043ff8b2e6790f5bd7c9632695c68986cc658
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting yacs>=0.1.6
  Downloading yacs-0.1.8-py3-none-any.whl (14 kB)
Collecting iopath>=0.1.7
  Downloading iopath-0.1.10.tar.gz (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 KB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting portalocker
  Downloading portalocker-2.6.0-py2.py3-none-any.whl (15 kB)
Building wheels for collected packages: fvcore, iopath
  Building wheel

In [2]:
import numpy as np
import random
import torch
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms as T
import multiprocessing as mp
from matplotlib import pyplot as plt
from fvcore.nn import FlopCountAnalysis, flop_count_str
import pickle
from tqdm import tqdm
from sklearn.metrics import classification_report

In [3]:
def set_reproducibility(seed):
  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)

set_reproducibility(0)

## Data Loading and Preprocessing

In [4]:
# Retrieves the CIFAR10 dataset.
def get_cifar10(train_transform = None, test_transform = None):

  # Downloading the training set.
  train_dataset =  datasets.CIFAR10(
    root = "CIFAR10",
    train = True,
    transform = train_transform,
    download = True
  )

  # Downloading the test set.
  test_dataset = datasets.CIFAR10(
    root = "CIFAR10",
    train = False,
    transform = test_transform,
    download = True
  )

  # Returning the two sets.
  return train_dataset, test_dataset

# Fuction that creates train, validation and test data loaders.
def create_data_loaders(train_transform, 
                        test_transform, 
                        img_size = 224, 
                        split = (0.8, 0.2), 
                        batch_size = 32, 
                        num_workers = 1):

  # Retrieving CIFAR10.
  train_dataset, test_dataset = get_cifar10(train_transform, test_transform)

  # Splitting train_dataset to create a validation set.
  train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, 
                                                             (int(len(train_dataset) * split[0]), 
                                                              int(len(train_dataset) * split[1])))

  # Train data loader.
  train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size = batch_size,
    shuffle = True,
    num_workers = num_workers,
    pin_memory = True,
    drop_last = True,
    sampler = None
  )

  # Validation data loader.
  val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size = batch_size,
    shuffle = False,
    num_workers = num_workers,
    pin_memory = True,
    drop_last = False,
    sampler = None
  )

  # Test data loader.
  test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size = batch_size,
    shuffle = False,
    num_workers = num_workers,
    pin_memory = True,
    drop_last = False,
    sampler = None
  )

  # Returning the three data loaders.
  return train_loader, val_loader, test_loader

In [5]:
# Training transformations.
train_transform = T.Compose([
    T.RandomCrop(32, padding = 4),
    T.RandomHorizontalFlip(),
    T.RandomRotation(10),
    T.Resize(224),
    T.ToTensor(),
    T.Normalize(mean = [0.491, 0.482, 0.447], std = [0.247, 0.243, 0.262])
])

# Test transformations.
test_transform = T.Compose([
    T.Resize(224),
    T.ToTensor(),
    T.Normalize(mean = [0.491, 0.482, 0.447], std = [0.247, 0.243, 0.262])
])

# Creating data loaders.
train_loader, val_loader, test_loader = create_data_loaders(
    train_transform,
    test_transform,
    img_size = 224,
    batch_size = 32,
    num_workers = mp.cpu_count()
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to CIFAR10/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting CIFAR10/cifar-10-python.tar.gz to CIFAR10
Files already downloaded and verified


## Models Definition and Training

### Utilities

In [6]:
# Choosing between GPU and CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [7]:
# Stem module of the CMT architecture.
class Stem(nn.Module):
    
  # Constructor. Requires number of input and output channels.
  def __init__(self, in_channels, out_channels):

    super().__init__()

    # Conv 3 x 3, stride 2.
    self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = 2, padding = 1)
    self.bn1 = nn.BatchNorm2d(out_channels)

    # Conv 3 x 3. 
    self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size = 3, stride = 1, padding = 1)
    self.bn2 = nn.BatchNorm2d(out_channels)

    # Conv 3 x 3.
    self.conv3 = nn.Conv2d(out_channels, out_channels, kernel_size = 3, stride = 1, padding = 1)
    self.bn3 = nn.BatchNorm2d(out_channels)

    # GELU activation.
    self.gelu = nn.GELU()

  # Forward pass.
  def forward(self, x):
    x = self.conv1(x)
    x = self.gelu(x)
    x = self.bn1(x)
    x = self.conv2(x)
    x = self.gelu(x)
    x = self.bn2(x)
    x = self.conv3(x)
    x = self.gelu(x)
    y = self.bn3(x)
    return y

In [8]:
# Patch embedding module of the CMT architecture.
class PatchEmbedding(nn.Module):

  # Constructor. Requires number of input and output channels.
  def __init__(self, in_channels, out_channels):

    super().__init__()

    # Conv 2 x 2, stride 2.
    self.conv = nn.Conv2d(in_channels, out_channels, kernel_size = 2, stride = 2, padding = 0)

  # Forward pass.
  def forward(self, x):
    x = self.conv(x)
    _, c, h, w = x.size()
    y = torch.nn.functional.layer_norm(x, (c, h, w))
    return y

In [9]:
# Local Perception Unit: LPU(X) = DWConv(X) + X.
class LPU(nn.Module):

  # Constructor.
  def __init__(self, in_channels, out_channels):

    super().__init__()

    # Depthwise convolution.
    self.dwconv = nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = 1, padding = 1, groups = in_channels)

  # Forward pass.
  def forward(self, x):
    y = self.dwconv(x) + x
    return y

In [10]:
# Lightweight Multi-Head-Self-Attention.
class LMHSA(nn.Module):
    
  # Constructor.
  def __init__(self, input_size, channels, d_k, d_v, stride, heads):

    super().__init__()

    # Depthwise convolutions.
    self.dwconv = nn.Conv2d(channels, channels, kernel_size = stride, stride = stride, groups = channels)

    # Projection matrices.
    self.fc_q = nn.Linear(channels, heads * d_k)
    self.fc_k = nn.Linear(channels, heads * d_k)
    self.fc_v = nn.Linear(channels, heads * d_v)
    self.fc_o = nn.Linear(heads * d_k, channels)

    self.channels = channels
    self.d_k = d_k
    self.d_v = d_v
    self.stride = stride
    self.heads = heads

    # Relative position bias to each self-attention module.
    self.B = nn.Parameter(torch.randn(1, self.heads, input_size ** 2, (input_size // stride) ** 2), requires_grad = True)
    
  # Forward pass.
  def forward(self, x):

    # Extracting shape from input signal.
    b, c, h, w = x.shape

    # Reshaping and permuting x. Final shape is (b, h * w, c).
    x_reshape = x.view(b, c, h * w).permute(0, 2, 1)

    # Layer norm.
    x_reshape = torch.nn.functional.layer_norm(x_reshape, (b, h * w, c))

    # Getting queries by applying the fc_q linear projection.
    q = self.fc_q(x_reshape)

    # Reshaping and permuting the queries. Final shape is (b, heads, n = h * w, d_k). 
    q = q.view(b, h * w, self.heads, self.d_k).permute(0, 2, 1, 3).contiguous()

    # Applying depthwise conv to x.
    k = self.dwconv(x)

    # Extracting shape of keys.
    k_b, k_c, k_h, k_w = k.shape

    # Reshaping and permuting keys. Final shape is (k_b, k_h * k_w, k_c).
    k = k.view(k_b, k_c, k_h * k_w).permute(0, 2, 1).contiguous()

    # Projecting through fc_k.
    k = self.fc_k(k)

    # Reshaping and permuting the keys. Final shape is (k_b, heads, k_h * k_w, d_k).
    k = k.view(k_b, k_h * k_w, self.heads, self.d_k).permute(0, 2, 1, 3).contiguous()

    # Applying depthwise conv to x.
    v = self.dwconv(x)

    # Extracting shape of values.
    v_b, v_c, v_h, v_w = v.shape

    # Reshaping and permuting values. Final shape is (v_b, v_h * v_w, v_c).
    v = v.view(v_b, v_c, v_h * v_w).permute(0, 2, 1).contiguous()

    # Projecting through fc_v.
    v = self.fc_v(v)

    # Reshaping and permuting the keys. Final shape is (v_b, heads, v_h * v_w, d_v).
    v = v.view(v_b, v_h * v_w, self.heads, self.d_v).permute(0, 2, 1, 3).contiguous()

    # Computing softmax((Q K') / sqrt(d_k) + B). Final shape is (b, heads, n = h * w, k_h * k_w).
    attention = torch.einsum("... i d, ... j d -> ... i j", q, k) * (self.d_k ** -0.5)
    attention = attention + self.B
    attention = torch.softmax(attention, dim = -1)

    # Applying attention scores to values by taking the dot product.
    tmp = torch.matmul(attention, v).permute(0, 2, 1, 3)

    # Permuting the result. Final shape is (b, n = h * w, heads, d_v).
    tmp = tmp.contiguous().view(b, h * w, self.heads * self.d_v)

    # Projecting using fc_o and reshaping. Final shape is (b, c, h, w).
    tmp = self.fc_o(tmp).view(b, self.channels, h, w)

    # Returning tmp + x (skip connection).
    return tmp + x

In [11]:
# Standard Multi-Head-Self-Attention.
class MHSA(nn.Module):
    
  # Constructor.
  def __init__(self, input_size, channels, d_k, d_v, stride, heads):

    super().__init__()

    # Projection matrices.
    self.fc_q = nn.Linear(channels, heads * d_k)
    self.fc_k = nn.Linear(channels, heads * d_k)
    self.fc_v = nn.Linear(channels, heads * d_v)
    self.fc_o = nn.Linear(heads * d_k, channels)

    self.channels = channels
    self.d_k = d_k
    self.d_v = d_v
    self.stride = stride
    self.heads = heads
    
  # Forward pass.
  def forward(self, x):

    # Extracting shape from input signal.
    b, c, h, w = x.shape

    # Reshaping and permuting x. Final shape is (b, h * w, c).
    x_reshape = x.view(b, c, h * w).permute(0, 2, 1)

    # Layer norm.
    x_reshape = torch.nn.functional.layer_norm(x_reshape, (b, h * w, c))

    # Getting queries by applying the fc_q linear projection.
    q = self.fc_q(x_reshape)

    # Reshaping and permuting the queries. Final shape is (b, heads, n = h * w, d_k). 
    q = q.view(b, h * w, self.heads, self.d_k).permute(0, 2, 1, 3).contiguous()

    # Projecting through fc_k.
    k = self.fc_k(x_reshape)

    # Reshaping and permuting the keys. Final shape is (b, heads, h * w, d_k).
    k = k.view(b, h * w, self.heads, self.d_k).permute(0, 2, 1, 3).contiguous()

    # Projecting through fc_v.
    v = self.fc_v(x_reshape)

    # Reshaping and permuting the keys. Final shape is (b, heads, h * w, d_v).
    v = v.view(b, h * w, self.heads, self.d_v).permute(0, 2, 1, 3).contiguous()

    # Computing softmax((Q K) / sqrt(d_k)).
    attention = torch.einsum("... i d, ... j d -> ... i j", q, k) * (self.d_k ** -0.5)
    attention = torch.softmax(attention, dim = -1)

    # Applying attention scores to values by taking the dot product.
    tmp = torch.matmul(attention, v).permute(0, 2, 1, 3)

    # Permuting the result. Final shape is (b, n = h * w, heads, d_v).
    tmp = tmp.contiguous().view(b, h * w, self.heads * self.d_v)

    # Projecting using fc_o and reshaping. Final shape is (b, c, h, w).
    tmp = self.fc_o(tmp).view(b, self.channels, h, w)

    # Returning tmp + x (skip connection).
    return tmp + x

In [12]:
# Inverted Residual Feed-forward Network: IRFFN(X) = Conv(F(Conv(X))), F(X) = DWConv(X) + X.
class IRFFN(nn.Module):

  # Constructor.
  def __init__(self, in_channels, R):

    super().__init__()

    # Number of channels after expansion.
    out_channels = int(in_channels * R)

    # Conv 1 x 1.
    self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size = 1)
    self.bn1 = nn.BatchNorm2d(out_channels)

    # Depthwise convolution.
    self.dwconv = nn.Conv2d(out_channels, out_channels, kernel_size = 3, stride = 1, padding = 1, groups = out_channels)
    self.bn2 = nn.BatchNorm2d(out_channels)

    # Conv 1 x 1.
    self.conv2 = nn.Conv2d(out_channels, in_channels, kernel_size = 1)
    self.bn3 = nn.BatchNorm2d(in_channels)

    # GELU activation.
    self.gelu = nn.GELU()

  # Forward pass.
  def forward(self, x):
    tmp = self.conv1(x)
    tmp = self.bn1(tmp)
    tmp = self.gelu(tmp)
    tmp = self.dwconv(tmp)
    tmp = self.bn2(tmp)
    tmp = self.gelu(tmp)
    tmp = self.conv2(tmp)
    tmp = self.bn3(tmp)
    y = x + tmp
    return y

In [13]:
# CMT block of the CMT architecture.
class CMTBlock(nn.Module):

  # Constructor. By default it loads the CMT-Ti configuration.
  def __init__(self, img_size, k, d_k, d_v, num_heads, R = 3.6, in_channels = 46, attention_type = "light"):

    super().__init__()

    # Local Perception Unit.
    self.lpu = LPU(in_channels, in_channels)

    # Lightweight MHSA.
    if attention_type == "light":
      self.mhsa = LMHSA(img_size, in_channels, d_k, d_v, k, num_heads)
    
    # Lightweight MHSA.
    elif attention_type == "standard":
      self.mhsa = MHSA(img_size, in_channels, d_k, d_v, k, num_heads)
    
    # No attention.
    else:
      self.mhsa = None

    # Inverted Residual FFN.
    self.irffn = IRFFN(in_channels, R)

  # Forward pass.
  def forward(self, x):
    x = self.lpu(x)
    if self.mhsa != None: x = self.mhsa(x)
    y = self.irffn(x)
    return y

In [14]:
# CMT architecture.
class CMT(nn.Module):

  # Constructor. By default it loads the CMT-Ti configuration.
  def __init__(self,
    in_channels = 3,
    stem_channels = 16,
    cmt_channels = [46, 92, 184, 368],
    patch_channels = [46, 92, 184, 368],
    block_layers = [2, 2, 10, 2],
    R = 3.6,
    img_size = 224,
    num_class = 10,
    attention_type = "light"
  ):

    super(CMT, self).__init__()

    # Stem layer
    self.stem = Stem(in_channels, stem_channels)

    # Patch Aggregation Layer
    self.pe1 = PatchEmbedding(stem_channels, patch_channels[0])
    self.pe2 = PatchEmbedding(patch_channels[0], patch_channels[1])
    self.pe3 = PatchEmbedding(patch_channels[1], patch_channels[2])
    self.pe4 = PatchEmbedding(patch_channels[2], patch_channels[3])

    # Stage 1 CMT blocks.
    stage1 = [CMTBlock(img_size = img_size // 4, 
                       k = 8, 
                       d_k = cmt_channels[0], 
                       d_v = cmt_channels[0], 
                       num_heads = 1, 
                       R = R, 
                       in_channels = patch_channels[0],
                       attention_type = attention_type) for _ in range(block_layers[0])]
    self.stage1 = nn.Sequential(*stage1)

    # Stage 2 CMT blocks.
    stage2 = [CMTBlock(img_size = img_size // 8,
                       k = 4,
                       d_k = cmt_channels[1] // 2,
                       d_v = cmt_channels[1] // 2,
                       num_heads = 2,
                       R = R,
                       in_channels = patch_channels[1],
                       attention_type = attention_type) for _ in range(block_layers[1])]
    self.stage2 = nn.Sequential(*stage2)

    # Stage 3 CMT blocks.
    stage3 = [CMTBlock(img_size = img_size // 16,
                       k = 2,
                       d_k = cmt_channels[2] // 4,
                       d_v = cmt_channels[2] // 4,
                       num_heads = 4,
                       R = R,
                       in_channels = patch_channels[2],
                       attention_type = attention_type) for _ in range(block_layers[2])]
    self.stage3 = nn.Sequential(*stage3)

    # Stage 4 CMT blocks.
    stage4 = [CMTBlock(img_size = img_size // 32,
                       k = 1,
                       d_k = cmt_channels[3] // 8,
                       d_v = cmt_channels[3] // 8,
                       num_heads = 8,
                       R = R,
                       in_channels = patch_channels[3],
                       attention_type = attention_type) for _ in range(block_layers[3])]
    self.stage4 = nn.Sequential(*stage4)

    # Global average pooling.
    self.avg_pool = nn.AdaptiveAvgPool2d(1)

    # Projection layer.
    self.projection = nn.Sequential(
      nn.Conv2d(cmt_channels[3], 1280, kernel_size = 1),
      nn.ReLU(inplace = True),
    )

    # Classifier.
    self.classifier = nn.Linear(1280, num_class)

  # Forward pass.
  def forward(self, x):
    x = self.stem(x)
    x = self.pe1(x)
    x = self.stage1(x)
    x = self.pe2(x)
    x = self.stage2(x)
    x = self.pe3(x)
    x = self.stage3(x)
    x = self.pe4(x)
    x = self.stage4(x)
    x = self.avg_pool(x)
    x = self.projection(x)
    x = torch.flatten(x, 1)
    y = self.classifier(x)
    return y

In [15]:
# Trains the model.
def train(model, optimizer, scheduler, loss_fn, train_loader, val_loader, epochs, device = "cuda", history = None):

  # Metrics.
  if history == None: 
    history = {"train_loss": [], "train_accuracy": [], "val_loss": [], "val_accuracy": [], "lr": []}

  # Iterating over epochs.
  for epoch in range(epochs):

    print("Epoch %2d/%2d:" % (epoch + 1, epochs))

    # Training.
    model.train()

    # Metrics
    train_loss = 0.0
    num_train_correct = 0
    num_train_examples = 0

    # Iterating over mini-batches.
    for batch in tqdm(train_loader, desc = "Training", position = 0):

      # Gradient reset.
      optimizer.zero_grad()

      # Moving x and y to GPU.
      x = batch[0].to(device)
      y = batch[1].to(device)

      # Predictions.
      yhat = model(x)

      # Loss.
      loss = loss_fn(yhat, y)

      # Computing the gradient.
      loss.backward()

      # Updating the parameters.
      optimizer.step()

      # Updating the metrics.
      train_loss += loss.item() * x.shape[0]
      num_train_correct += (torch.argmax(yhat, 1) == y).sum().item()
      num_train_examples += x.shape[0]

    # Computing the epoch's metrics.
    train_accuracy = num_train_correct / num_train_examples
    train_loss = train_loss / len(train_loader.dataset)

    # Saving learning rate.
    lr = scheduler.get_last_lr()[0]

    # Updating the learning rate.
    scheduler.step()

    # Validation.
    model.eval()

    # Metrics.
    val_loss = 0.0
    num_val_correct = 0
    num_val_examples = 0

    with torch.no_grad():

      # Iterating over mini-batches.
      for batch in tqdm(val_loader, desc = "Validation", position = 0):

        # Moving x and y to GPU.
        x = batch[0].to(device)
        y = batch[1].to(device)

        # Predictions.
        yhat = model(x)

        # Loss.
        loss = loss_fn(yhat, y)

        # Updating the metrics.
        val_loss += loss.item() * x.shape[0]
        num_val_correct += (torch.argmax(yhat, 1) == y).sum().item()
        num_val_examples += y.shape[0]

      # Computing the epoch's metrics.
      val_accuracy = num_val_correct / num_val_examples
      val_loss = val_loss / len(val_loader.dataset)

    print("{}train_loss: {:.4f}, train_accuracy: {:.4f}, val_loss: {:.4f}, val_accuracy: {:.4f}, lr: {:.4e}{}".format("\n" if epoch == epochs - 1 else "",
                                                                                                                      train_loss, 
                                                                                                                      train_accuracy, 
                                                                                                                      val_loss, 
                                                                                                                      val_accuracy,
                                                                                                                      lr,
                                                                                                                      "\n" if epoch != epochs - 1 else ""))

    # Appending the epoch's metrics to history lists.
    history["train_loss"].append(train_loss)
    history["val_loss"].append(val_loss)
    history["train_accuracy"].append(train_accuracy)
    history["val_accuracy"].append(val_accuracy)
    history["lr"].append(lr)

  # Returning history.
  return history

In [16]:
# Loads a checkpoint.
def load_checkpoint(path, model, optimizer, scheduler):

  # Loading checkpoint.
  checkpoint = torch.load(path)
  model.load_state_dict(checkpoint["model"])
  optimizer.load_state_dict(checkpoint["optimizer"])
  scheduler.load_state_dict(checkpoint["scheduler"])

  # Returning checkpoint entries.
  return model, optimizer, scheduler

### Lightweight MHSA model ($m_1$)

In [17]:
# Defining the model that uses LMHSA.
m_1 = CMT(
  in_channels = 3,
  stem_channels = 16,
  cmt_channels = [46, 92, 184, 368],
  patch_channels = [46, 92, 184, 368],
  block_layers = [2, 2, 10, 2],
  R = 3.6,
  img_size = 224,
  num_class = 10,
  attention_type = "light")

# Printing number of parameters.
print(f"m_1 has {sum(p.numel() for p in m_1.parameters())} parameters.")

m_1 has 9014696 parameters.


In [18]:
# Computing number of FLOPs.
flops = FlopCountAnalysis(m_1, (torch.randn((1, 3, 224, 224)),))

# Printing number of FLOPs per layer.
print(flop_count_str(flops))

CMT(
  #params: 9.01M, #flops: 1.31G
  (stem): Stem(
    #params: 5.18K, #flops: 66.23M
    (conv1): Conv2d(
      3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)
      #params: 0.45K, #flops: 5.42M
    )
    (bn1): BatchNorm2d(
      16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
      #params: 32, #flops: 1M
    )
    (conv2): Conv2d(
      16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
      #params: 2.32K, #flops: 28.9M
    )
    (bn2): BatchNorm2d(
      16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
      #params: 32, #flops: 1M
    )
    (conv3): Conv2d(
      16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
      #params: 2.32K, #flops: 28.9M
    )
    (bn3): BatchNorm2d(
      16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
      #params: 32, #flops: 1M
    )
    (gelu): GELU(approximate='none')
  )
  (pe1): PatchEmbedding(
    #params: 2.99K, #flops: 9.81M
    (conv): Conv2d(
      16, 46,

In [19]:
# Loading model to device.
m_1.to(device)

CMT(
  (stem): Stem(
    (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn3): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (gelu): GELU(approximate='none')
  )
  (pe1): PatchEmbedding(
    (conv): Conv2d(16, 46, kernel_size=(2, 2), stride=(2, 2))
  )
  (pe2): PatchEmbedding(
    (conv): Conv2d(46, 92, kernel_size=(2, 2), stride=(2, 2))
  )
  (pe3): PatchEmbedding(
    (conv): Conv2d(92, 184, kernel_size=(2, 2), stride=(2, 2))
  )
  (pe4): PatchEmbedding(
    (conv): Conv2d(184, 368, kernel_size=(2, 2), stride=(2, 2))
  )
  (stage1): Sequential(
    (0): CMTBlock(
      (lpu): LPU(
 

In [20]:
# Epochs.
EPOCHS = 25

# Initial learning rate.
LR = 6e-5

# Weight decay.
WD = 1e-5

# Loss function.
loss_fn = nn.CrossEntropyLoss()

# Optimizer.
optimizer = torch.optim.AdamW(m_1.parameters(), lr = LR, weight_decay = WD)

# Scheduler.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, EPOCHS)

In [21]:
# Training.
history = train(m_1, optimizer, scheduler, loss_fn, train_loader, val_loader, EPOCHS, device)

# Creating a checkpoint.
checkpoint = {
  "model": m_1.state_dict(),
  "optimizer": optimizer.state_dict(),
  "scheduler": scheduler.state_dict()
}

# Saving the model.
torch.save(checkpoint, "checkpoint1.pt")

# Saving history.
with open("m_1_history.pkl", "wb") as f:
  pickle.dump(history, f)

Epoch  1/25:


Training: 100%|██████████| 1250/1250 [08:42<00:00,  2.39it/s]
Validation: 100%|██████████| 313/313 [00:36<00:00,  8.66it/s]


train_loss: 1.4515, train_accuracy: 0.4688, val_loss: 1.1542, val_accuracy: 0.5838, lr: 6.0000e-05

Epoch  2/25:


Training: 100%|██████████| 1250/1250 [08:37<00:00,  2.42it/s]
Validation: 100%|██████████| 313/313 [00:36<00:00,  8.66it/s]


train_loss: 1.0376, train_accuracy: 0.6309, val_loss: 0.9275, val_accuracy: 0.6715, lr: 5.9763e-05

Epoch  3/25:


Training: 100%|██████████| 1250/1250 [08:37<00:00,  2.42it/s]
Validation: 100%|██████████| 313/313 [00:36<00:00,  8.64it/s]


train_loss: 0.8631, train_accuracy: 0.6942, val_loss: 0.7936, val_accuracy: 0.7194, lr: 5.9057e-05

Epoch  4/25:


Training: 100%|██████████| 1250/1250 [08:37<00:00,  2.42it/s]
Validation: 100%|██████████| 313/313 [00:36<00:00,  8.65it/s]


train_loss: 0.7377, train_accuracy: 0.7417, val_loss: 0.7036, val_accuracy: 0.7554, lr: 5.7893e-05

Epoch  5/25:


Training: 100%|██████████| 1250/1250 [08:39<00:00,  2.41it/s]
Validation: 100%|██████████| 313/313 [00:36<00:00,  8.65it/s]


train_loss: 0.6417, train_accuracy: 0.7767, val_loss: 0.6473, val_accuracy: 0.7782, lr: 5.6289e-05

Epoch  6/25:


Training: 100%|██████████| 1250/1250 [08:38<00:00,  2.41it/s]
Validation: 100%|██████████| 313/313 [00:36<00:00,  8.65it/s]


train_loss: 0.5791, train_accuracy: 0.7991, val_loss: 0.6013, val_accuracy: 0.7889, lr: 5.4271e-05

Epoch  7/25:


Training: 100%|██████████| 1250/1250 [08:37<00:00,  2.42it/s]
Validation: 100%|██████████| 313/313 [00:36<00:00,  8.66it/s]


train_loss: 0.5228, train_accuracy: 0.8179, val_loss: 0.5615, val_accuracy: 0.8057, lr: 5.1869e-05

Epoch  8/25:


Training: 100%|██████████| 1250/1250 [08:38<00:00,  2.41it/s]
Validation: 100%|██████████| 313/313 [00:36<00:00,  8.57it/s]


train_loss: 0.4731, train_accuracy: 0.8344, val_loss: 0.5351, val_accuracy: 0.8160, lr: 4.9123e-05

Epoch  9/25:


Training: 100%|██████████| 1250/1250 [08:38<00:00,  2.41it/s]
Validation: 100%|██████████| 313/313 [00:36<00:00,  8.65it/s]


train_loss: 0.4419, train_accuracy: 0.8462, val_loss: 0.4994, val_accuracy: 0.8280, lr: 4.6075e-05

Epoch 10/25:


Training: 100%|██████████| 1250/1250 [08:38<00:00,  2.41it/s]
Validation: 100%|██████████| 313/313 [00:36<00:00,  8.63it/s]


train_loss: 0.3999, train_accuracy: 0.8595, val_loss: 0.4820, val_accuracy: 0.8350, lr: 4.2773e-05

Epoch 11/25:


Training: 100%|██████████| 1250/1250 [08:37<00:00,  2.42it/s]
Validation: 100%|██████████| 313/313 [00:36<00:00,  8.57it/s]


train_loss: 0.3712, train_accuracy: 0.8706, val_loss: 0.4877, val_accuracy: 0.8309, lr: 3.9271e-05

Epoch 12/25:


Training: 100%|██████████| 1250/1250 [08:37<00:00,  2.42it/s]
Validation: 100%|██████████| 313/313 [00:36<00:00,  8.64it/s]


train_loss: 0.3380, train_accuracy: 0.8829, val_loss: 0.4560, val_accuracy: 0.8423, lr: 3.5621e-05

Epoch 13/25:


Training: 100%|██████████| 1250/1250 [08:37<00:00,  2.42it/s]
Validation: 100%|██████████| 313/313 [00:36<00:00,  8.59it/s]


train_loss: 0.3091, train_accuracy: 0.8931, val_loss: 0.4518, val_accuracy: 0.8464, lr: 3.1884e-05

Epoch 14/25:


Training: 100%|██████████| 1250/1250 [08:37<00:00,  2.42it/s]
Validation: 100%|██████████| 313/313 [00:36<00:00,  8.64it/s]


train_loss: 0.2885, train_accuracy: 0.8992, val_loss: 0.4382, val_accuracy: 0.8529, lr: 2.8116e-05

Epoch 15/25:


Training: 100%|██████████| 1250/1250 [08:36<00:00,  2.42it/s]
Validation: 100%|██████████| 313/313 [00:36<00:00,  8.65it/s]


train_loss: 0.2623, train_accuracy: 0.9079, val_loss: 0.4370, val_accuracy: 0.8551, lr: 2.4379e-05

Epoch 16/25:


Training: 100%|██████████| 1250/1250 [08:38<00:00,  2.41it/s]
Validation: 100%|██████████| 313/313 [00:36<00:00,  8.65it/s]


train_loss: 0.2391, train_accuracy: 0.9169, val_loss: 0.4273, val_accuracy: 0.8592, lr: 2.0729e-05

Epoch 17/25:


Training: 100%|██████████| 1250/1250 [08:37<00:00,  2.42it/s]
Validation: 100%|██████████| 313/313 [00:36<00:00,  8.54it/s]


train_loss: 0.2244, train_accuracy: 0.9221, val_loss: 0.4191, val_accuracy: 0.8650, lr: 1.7227e-05

Epoch 18/25:


Training: 100%|██████████| 1250/1250 [08:37<00:00,  2.42it/s]
Validation: 100%|██████████| 313/313 [00:36<00:00,  8.64it/s]


train_loss: 0.2112, train_accuracy: 0.9274, val_loss: 0.4243, val_accuracy: 0.8606, lr: 1.3925e-05

Epoch 19/25:


Training: 100%|██████████| 1250/1250 [08:37<00:00,  2.41it/s]
Validation: 100%|██████████| 313/313 [00:36<00:00,  8.65it/s]


train_loss: 0.1922, train_accuracy: 0.9327, val_loss: 0.4147, val_accuracy: 0.8643, lr: 1.0877e-05

Epoch 20/25:


Training: 100%|██████████| 1250/1250 [08:37<00:00,  2.42it/s]
Validation: 100%|██████████| 313/313 [00:36<00:00,  8.57it/s]


train_loss: 0.1859, train_accuracy: 0.9361, val_loss: 0.4148, val_accuracy: 0.8673, lr: 8.1309e-06

Epoch 21/25:


Training: 100%|██████████| 1250/1250 [08:38<00:00,  2.41it/s]
Validation: 100%|██████████| 313/313 [00:36<00:00,  8.61it/s]


train_loss: 0.1757, train_accuracy: 0.9388, val_loss: 0.3977, val_accuracy: 0.8719, lr: 5.7295e-06

Epoch 22/25:


Training: 100%|██████████| 1250/1250 [08:38<00:00,  2.41it/s]
Validation: 100%|██████████| 313/313 [00:36<00:00,  8.65it/s]


train_loss: 0.1641, train_accuracy: 0.9436, val_loss: 0.3936, val_accuracy: 0.8737, lr: 3.7108e-06

Epoch 23/25:


Training: 100%|██████████| 1250/1250 [08:38<00:00,  2.41it/s]
Validation: 100%|██████████| 313/313 [00:36<00:00,  8.64it/s]


train_loss: 0.1607, train_accuracy: 0.9447, val_loss: 0.4077, val_accuracy: 0.8685, lr: 2.1067e-06

Epoch 24/25:


Training: 100%|██████████| 1250/1250 [08:38<00:00,  2.41it/s]
Validation: 100%|██████████| 313/313 [00:36<00:00,  8.61it/s]


train_loss: 0.1605, train_accuracy: 0.9447, val_loss: 0.3986, val_accuracy: 0.8729, lr: 9.4251e-07

Epoch 25/25:


Training: 100%|██████████| 1250/1250 [08:38<00:00,  2.41it/s]
Validation: 100%|██████████| 313/313 [00:36<00:00,  8.66it/s]



train_loss: 0.1560, train_accuracy: 0.9467, val_loss: 0.3914, val_accuracy: 0.8730, lr: 2.3656e-07


### Standard MHSA model ($m_2$)

In [None]:
# Defining the model that uses standard MHSA.
m_2 = CMT(
  in_channels = 3,
  stem_channels = 16,
  cmt_channels = [46, 92, 184, 368],
  patch_channels = [46, 92, 184, 368],
  block_layers = [2, 2, 10, 2],
  R = 3.6,
  img_size = 224,
  num_class = 10,
  attention_type = "standard")

# Printing number of parameters.
print(f"m_2 has {sum(p.numel() for p in m_2.parameters())} parameters.")

m_2 has 8111348 parameters.


In [None]:
# Computing number of FLOPs.
flops = FlopCountAnalysis(m_2, (torch.randn((1, 3, 224, 224)),))

# Printing number of FLOPs per layer.
print(flop_count_str(flops))

CMT(
  #params: 8.11M, #flops: 3.56G
  (stem): Stem(
    #params: 5.18K, #flops: 66.23M
    (conv1): Conv2d(
      3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)
      #params: 0.45K, #flops: 5.42M
    )
    (bn1): BatchNorm2d(
      16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
      #params: 32, #flops: 1M
    )
    (conv2): Conv2d(
      16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
      #params: 2.32K, #flops: 28.9M
    )
    (bn2): BatchNorm2d(
      16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
      #params: 32, #flops: 1M
    )
    (conv3): Conv2d(
      16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
      #params: 2.32K, #flops: 28.9M
    )
    (bn3): BatchNorm2d(
      16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
      #params: 32, #flops: 1M
    )
    (gelu): GELU(approximate='none')
  )
  (pe1): PatchEmbedding(
    #params: 2.99K, #flops: 9.81M
    (conv): Conv2d(
      16, 46,

In [None]:
# Loading model to device.
m_2.to(device)

In [None]:
# Epochs.
EPOCHS = 25

# Initial learning rate.
LR = 6e-5

# Weight decay.
WD = 1e-5

# Loss function.
loss_fn = nn.CrossEntropyLoss()

# Optimizer.
optimizer = torch.optim.AdamW(m_2.parameters(), lr = LR, weight_decay = WD)

# Scheduler.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, EPOCHS)

In [None]:
# Training.
history = train(m_2, optimizer, scheduler, loss_fn, train_loader, val_loader, EPOCHS, device)

# Creating a checkpoint.
checkpoint = {
  "model": m_2.state_dict(),
  "optimizer": optimizer.state_dict(),
  "scheduler": scheduler.state_dict()
}

# Saving the model.
torch.save(checkpoint, "checkpoint2.pt")

# Saving history.
with open("m_2_history.pkl", "wb") as f:
  pickle.dump(history, f)

### No attention model ($m_3$)

In [None]:
# Defining the model that does not use attention.
m_3 = CMT(
  in_channels = 3,
  stem_channels = 16,
  cmt_channels = [46, 92, 184, 368],
  patch_channels = [46, 92, 184, 368],
  block_layers = [2, 2, 10, 2],
  R = 3.6,
  img_size = 224,
  num_class = 10,
  attention_type = None)

# Printing number of parameters.
print(f"m_3 has {sum(p.numel() for p in m_3.parameters())} parameters.")

m_3 has 5577668 parameters.


In [None]:
# Computing number of FLOPs.
flops = FlopCountAnalysis(m_3, (torch.randn((1, 3, 224, 224)),))

# Printing number of FLOPs per layer.
print(flop_count_str(flops))

CMT(
  #params: 5.58M, #flops: 0.95G
  (stem): Stem(
    #params: 5.18K, #flops: 66.23M
    (conv1): Conv2d(
      3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)
      #params: 0.45K, #flops: 5.42M
    )
    (bn1): BatchNorm2d(
      16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
      #params: 32, #flops: 1M
    )
    (conv2): Conv2d(
      16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
      #params: 2.32K, #flops: 28.9M
    )
    (bn2): BatchNorm2d(
      16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
      #params: 32, #flops: 1M
    )
    (conv3): Conv2d(
      16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
      #params: 2.32K, #flops: 28.9M
    )
    (bn3): BatchNorm2d(
      16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
      #params: 32, #flops: 1M
    )
    (gelu): GELU(approximate='none')
  )
  (pe1): PatchEmbedding(
    #params: 2.99K, #flops: 9.81M
    (conv): Conv2d(
      16, 46,

In [None]:
# Loading model to device.
m_3.to(device)

CMT(
  (stem): Stem(
    (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn3): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (gelu): GELU(approximate='none')
  )
  (pe1): PatchEmbedding(
    (conv): Conv2d(16, 46, kernel_size=(2, 2), stride=(2, 2))
  )
  (pe2): PatchEmbedding(
    (conv): Conv2d(46, 92, kernel_size=(2, 2), stride=(2, 2))
  )
  (pe3): PatchEmbedding(
    (conv): Conv2d(92, 184, kernel_size=(2, 2), stride=(2, 2))
  )
  (pe4): PatchEmbedding(
    (conv): Conv2d(184, 368, kernel_size=(2, 2), stride=(2, 2))
  )
  (stage1): Sequential(
    (0): CMTBlock(
      (lpu): LPU(
 

In [None]:
# Epochs.
EPOCHS = 25

# Initial learning rate.
LR = 6e-5

# Weight decay.
WD = 1e-5

# Loss function.
loss_fn = nn.CrossEntropyLoss()

# Optimizer.
optimizer = torch.optim.AdamW(m_3.parameters(), lr = LR, weight_decay = WD)

# Scheduler.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, EPOCHS)

In [None]:
# Training.
history = train(m_3, optimizer, scheduler, loss_fn, train_loader, val_loader, EPOCHS, device)

# Creating a checkpoint.
checkpoint = {
  "model": m_3.state_dict(),
  "optimizer": optimizer.state_dict(),
  "scheduler": scheduler.state_dict()
}

# Saving the model.
torch.save(checkpoint, "checkpoint3.pt")

# Saving history.
with open("m_3_history.pkl", "wb") as f:
  pickle.dump(history, f)

Epoch  1/25:


Training: 100%|██████████| 1250/1250 [04:19<00:00,  4.82it/s]
Validation: 100%|██████████| 313/313 [00:22<00:00, 13.90it/s]


train_loss: 1.4455, train_accuracy: 0.4697, val_loss: 1.1578, val_accuracy: 0.5822, lr: 6.0000e-05

Epoch  2/25:


Training: 100%|██████████| 1250/1250 [04:18<00:00,  4.84it/s]
Validation: 100%|██████████| 313/313 [00:22<00:00, 13.97it/s]


train_loss: 1.0345, train_accuracy: 0.6299, val_loss: 0.8865, val_accuracy: 0.6867, lr: 5.9763e-05

Epoch  3/25:


Training: 100%|██████████| 1250/1250 [04:18<00:00,  4.84it/s]
Validation: 100%|██████████| 313/313 [00:22<00:00, 14.11it/s]


train_loss: 0.8454, train_accuracy: 0.7043, val_loss: 0.7933, val_accuracy: 0.7216, lr: 5.9057e-05

Epoch  4/25:


Training: 100%|██████████| 1250/1250 [04:17<00:00,  4.85it/s]
Validation: 100%|██████████| 313/313 [00:22<00:00, 14.02it/s]


train_loss: 0.7174, train_accuracy: 0.7488, val_loss: 0.6873, val_accuracy: 0.7622, lr: 5.7893e-05

Epoch  5/25:


Training: 100%|██████████| 1250/1250 [04:17<00:00,  4.85it/s]
Validation: 100%|██████████| 313/313 [00:22<00:00, 13.95it/s]


train_loss: 0.6265, train_accuracy: 0.7802, val_loss: 0.6346, val_accuracy: 0.7755, lr: 5.6289e-05

Epoch  6/25:


Training: 100%|██████████| 1250/1250 [04:18<00:00,  4.84it/s]
Validation: 100%|██████████| 313/313 [00:22<00:00, 14.09it/s]


train_loss: 0.5575, train_accuracy: 0.8048, val_loss: 0.5747, val_accuracy: 0.7994, lr: 5.4271e-05

Epoch  7/25:


Training: 100%|██████████| 1250/1250 [04:17<00:00,  4.85it/s]
Validation: 100%|██████████| 313/313 [00:22<00:00, 14.07it/s]


train_loss: 0.5010, train_accuracy: 0.8254, val_loss: 0.5544, val_accuracy: 0.8106, lr: 5.1869e-05

Epoch  8/25:


Training: 100%|██████████| 1250/1250 [04:17<00:00,  4.85it/s]
Validation: 100%|██████████| 313/313 [00:22<00:00, 13.82it/s]


train_loss: 0.4543, train_accuracy: 0.8415, val_loss: 0.5243, val_accuracy: 0.8196, lr: 4.9123e-05

Epoch  9/25:


Training: 100%|██████████| 1250/1250 [04:17<00:00,  4.85it/s]
Validation: 100%|██████████| 313/313 [00:22<00:00, 13.86it/s]


train_loss: 0.4225, train_accuracy: 0.8514, val_loss: 0.4961, val_accuracy: 0.8307, lr: 4.6075e-05

Epoch 10/25:


Training: 100%|██████████| 1250/1250 [04:18<00:00,  4.84it/s]
Validation: 100%|██████████| 313/313 [00:25<00:00, 12.36it/s]


train_loss: 0.3794, train_accuracy: 0.8666, val_loss: 0.4683, val_accuracy: 0.8390, lr: 4.2773e-05

Epoch 11/25:


Training: 100%|██████████| 1250/1250 [04:18<00:00,  4.84it/s]
Validation: 100%|██████████| 313/313 [00:22<00:00, 13.73it/s]


train_loss: 0.3533, train_accuracy: 0.8777, val_loss: 0.4724, val_accuracy: 0.8383, lr: 3.9271e-05

Epoch 12/25:


Training: 100%|██████████| 1250/1250 [04:18<00:00,  4.84it/s]
Validation: 100%|██████████| 313/313 [00:22<00:00, 13.75it/s]


train_loss: 0.3220, train_accuracy: 0.8886, val_loss: 0.4399, val_accuracy: 0.8485, lr: 3.5621e-05

Epoch 13/25:


Training: 100%|██████████| 1250/1250 [04:18<00:00,  4.83it/s]
Validation: 100%|██████████| 313/313 [00:22<00:00, 13.82it/s]


train_loss: 0.2952, train_accuracy: 0.8983, val_loss: 0.4360, val_accuracy: 0.8555, lr: 3.1884e-05

Epoch 14/25:


Training: 100%|██████████| 1250/1250 [04:18<00:00,  4.84it/s]
Validation: 100%|██████████| 313/313 [00:22<00:00, 13.78it/s]


train_loss: 0.2689, train_accuracy: 0.9067, val_loss: 0.4268, val_accuracy: 0.8555, lr: 2.8116e-05

Epoch 15/25:


Training: 100%|██████████| 1250/1250 [04:18<00:00,  4.83it/s]
Validation: 100%|██████████| 313/313 [00:22<00:00, 13.77it/s]


train_loss: 0.2494, train_accuracy: 0.9115, val_loss: 0.4342, val_accuracy: 0.8593, lr: 2.4379e-05

Epoch 16/25:


Training: 100%|██████████| 1250/1250 [04:18<00:00,  4.84it/s]
Validation: 100%|██████████| 313/313 [00:22<00:00, 13.82it/s]


train_loss: 0.2282, train_accuracy: 0.9206, val_loss: 0.4103, val_accuracy: 0.8649, lr: 2.0729e-05

Epoch 17/25:


Training: 100%|██████████| 1250/1250 [04:18<00:00,  4.83it/s]
Validation: 100%|██████████| 313/313 [00:22<00:00, 13.72it/s]


train_loss: 0.2089, train_accuracy: 0.9265, val_loss: 0.4104, val_accuracy: 0.8653, lr: 1.7227e-05

Epoch 18/25:


Training: 100%|██████████| 1250/1250 [04:18<00:00,  4.84it/s]
Validation: 100%|██████████| 313/313 [00:22<00:00, 13.64it/s]


train_loss: 0.1928, train_accuracy: 0.9336, val_loss: 0.4156, val_accuracy: 0.8672, lr: 1.3925e-05

Epoch 19/25:


Training: 100%|██████████| 1250/1250 [04:18<00:00,  4.84it/s]
Validation: 100%|██████████| 313/313 [00:22<00:00, 13.88it/s]


train_loss: 0.1801, train_accuracy: 0.9369, val_loss: 0.3981, val_accuracy: 0.8664, lr: 1.0877e-05

Epoch 20/25:


Training: 100%|██████████| 1250/1250 [04:18<00:00,  4.84it/s]
Validation: 100%|██████████| 313/313 [00:22<00:00, 14.01it/s]


train_loss: 0.1709, train_accuracy: 0.9403, val_loss: 0.4011, val_accuracy: 0.8702, lr: 8.1309e-06

Epoch 21/25:


Training: 100%|██████████| 1250/1250 [04:18<00:00,  4.84it/s]
Validation: 100%|██████████| 313/313 [00:22<00:00, 13.89it/s]


train_loss: 0.1583, train_accuracy: 0.9453, val_loss: 0.4071, val_accuracy: 0.8720, lr: 5.7295e-06

Epoch 22/25:


Training: 100%|██████████| 1250/1250 [04:18<00:00,  4.84it/s]
Validation: 100%|██████████| 313/313 [00:22<00:00, 13.87it/s]


train_loss: 0.1492, train_accuracy: 0.9478, val_loss: 0.3964, val_accuracy: 0.8739, lr: 3.7108e-06

Epoch 23/25:


Training: 100%|██████████| 1250/1250 [04:18<00:00,  4.84it/s]
Validation: 100%|██████████| 313/313 [00:23<00:00, 13.34it/s]


train_loss: 0.1492, train_accuracy: 0.9481, val_loss: 0.3997, val_accuracy: 0.8753, lr: 2.1067e-06

Epoch 24/25:


Training: 100%|██████████| 1250/1250 [04:18<00:00,  4.84it/s]
Validation: 100%|██████████| 313/313 [00:23<00:00, 13.60it/s]


train_loss: 0.1419, train_accuracy: 0.9503, val_loss: 0.3954, val_accuracy: 0.8752, lr: 9.4251e-07

Epoch 25/25:


Training: 100%|██████████| 1250/1250 [04:18<00:00,  4.84it/s]
Validation: 100%|██████████| 313/313 [00:23<00:00, 13.51it/s]



train_loss: 0.1392, train_accuracy: 0.9524, val_loss: 0.3864, val_accuracy: 0.8779, lr: 2.3656e-07


### ResNet-18 ($m_4$)

In [None]:
# Downloading ResNet-18.
m_4 = torch.hub.load("pytorch/vision:v0.10.0", "resnet18", weights = None)

Downloading: "https://github.com/pytorch/vision/zipball/v0.10.0" to /root/.cache/torch/hub/v0.10.0.zip


In [None]:
# Printing number of parameters.
print(f"m_4 has {sum(p.numel() for p in m_4.parameters())} parameters.")

m_4 has 11689512 parameters.


In [None]:
# Computing number of FLOPs.
flops = FlopCountAnalysis(m_4, (torch.randn((1, 3, 224, 224)),))

# Printing number of FLOPs per layer.
print(flop_count_str(flops))

ResNet(
  #params: 11.69M, #flops: 1.83G
  (conv1): Conv2d(
    3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    #params: 9.41K, #flops: 0.12G
  )
  (bn1): BatchNorm2d(
    64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
    #params: 0.13K, #flops: 4.01M
  )
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    #params: 0.15M, #flops: 0.47G
    (0): BasicBlock(
      #params: 73.98K, #flops: 0.23G
      (conv1): Conv2d(
        64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
        #params: 36.86K, #flops: 0.12G
      )
      (bn1): BatchNorm2d(
        64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
        #params: 0.13K, #flops: 1M
      )
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(
        64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
        #params: 36.86K, #flops: 0.12G


In [None]:
# Loading model to device.
m_4.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [None]:
# Epochs.
EPOCHS = 25

# Initial learning rate.
LR = 6e-5

# Weight decay.
WD = 1e-5

# Loss function.
loss_fn = nn.CrossEntropyLoss()

# Optimizer.
optimizer = torch.optim.AdamW(m_4.parameters(), lr = LR, weight_decay = WD)

# Scheduler.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, EPOCHS)

In [None]:
# Training.
history = train(m_4, optimizer, scheduler, loss_fn, train_loader, val_loader, EPOCHS, device)

# Creating a checkpoint.
checkpoint = {
  "model": m_4.state_dict(),
  "optimizer": optimizer.state_dict(),
  "scheduler": scheduler.state_dict()
}

# Saving the model.
torch.save(checkpoint, "checkpoint4.pt")

# Saving history.
with open("m_4_history.pkl", "wb") as f:
  pickle.dump(history, f)

Epoch  1/25:


Training: 100%|██████████| 1250/1250 [02:12<00:00,  9.45it/s]
Validation: 100%|██████████| 313/313 [00:20<00:00, 15.26it/s]


train_loss: 1.7029, train_accuracy: 0.4316, val_loss: 1.2856, val_accuracy: 0.5347, lr: 6.0000e-05

Epoch  2/25:


Training: 100%|██████████| 1250/1250 [02:06<00:00,  9.90it/s]
Validation: 100%|██████████| 313/313 [00:20<00:00, 15.19it/s]


train_loss: 1.1560, train_accuracy: 0.5877, val_loss: 1.1000, val_accuracy: 0.6051, lr: 5.9763e-05

Epoch  3/25:


Training: 100%|██████████| 1250/1250 [02:05<00:00,  9.98it/s]
Validation: 100%|██████████| 313/313 [00:23<00:00, 13.55it/s]


train_loss: 0.9575, train_accuracy: 0.6609, val_loss: 0.9275, val_accuracy: 0.6746, lr: 5.9057e-05

Epoch  4/25:


Training: 100%|██████████| 1250/1250 [02:04<00:00, 10.04it/s]
Validation: 100%|██████████| 313/313 [00:20<00:00, 15.47it/s]


train_loss: 0.8337, train_accuracy: 0.7075, val_loss: 0.8515, val_accuracy: 0.7041, lr: 5.7893e-05

Epoch  5/25:


Training: 100%|██████████| 1250/1250 [02:06<00:00,  9.90it/s]
Validation: 100%|██████████| 313/313 [00:21<00:00, 14.88it/s]


train_loss: 0.7347, train_accuracy: 0.7440, val_loss: 0.7346, val_accuracy: 0.7446, lr: 5.6289e-05

Epoch  6/25:


Training: 100%|██████████| 1250/1250 [02:05<00:00,  9.95it/s]
Validation: 100%|██████████| 313/313 [00:21<00:00, 14.32it/s]


train_loss: 0.6619, train_accuracy: 0.7701, val_loss: 0.7150, val_accuracy: 0.7569, lr: 5.4271e-05

Epoch  7/25:


Training: 100%|██████████| 1250/1250 [02:06<00:00,  9.85it/s]
Validation: 100%|██████████| 313/313 [00:21<00:00, 14.90it/s]


train_loss: 0.6100, train_accuracy: 0.7879, val_loss: 0.6422, val_accuracy: 0.7813, lr: 5.1869e-05

Epoch  8/25:


Training: 100%|██████████| 1250/1250 [02:06<00:00,  9.87it/s]
Validation: 100%|██████████| 313/313 [00:20<00:00, 14.92it/s]


train_loss: 0.5594, train_accuracy: 0.8070, val_loss: 0.6199, val_accuracy: 0.7877, lr: 4.9123e-05

Epoch  9/25:


Training: 100%|██████████| 1250/1250 [02:07<00:00,  9.84it/s]
Validation: 100%|██████████| 313/313 [00:20<00:00, 15.14it/s]


train_loss: 0.5222, train_accuracy: 0.8189, val_loss: 0.5733, val_accuracy: 0.8062, lr: 4.6075e-05

Epoch 10/25:


Training: 100%|██████████| 1250/1250 [02:04<00:00, 10.02it/s]
Validation: 100%|██████████| 313/313 [00:20<00:00, 15.36it/s]


train_loss: 0.4903, train_accuracy: 0.8304, val_loss: 0.5667, val_accuracy: 0.8036, lr: 4.2773e-05

Epoch 11/25:


Training: 100%|██████████| 1250/1250 [02:05<00:00,  9.94it/s]
Validation: 100%|██████████| 313/313 [00:20<00:00, 15.19it/s]


train_loss: 0.4535, train_accuracy: 0.8446, val_loss: 0.5141, val_accuracy: 0.8220, lr: 3.9271e-05

Epoch 12/25:


Training: 100%|██████████| 1250/1250 [02:04<00:00, 10.02it/s]
Validation: 100%|██████████| 313/313 [00:20<00:00, 15.48it/s]


train_loss: 0.4243, train_accuracy: 0.8532, val_loss: 0.5124, val_accuracy: 0.8276, lr: 3.5621e-05

Epoch 13/25:


Training: 100%|██████████| 1250/1250 [02:05<00:00,  9.95it/s]
Validation: 100%|██████████| 313/313 [00:20<00:00, 15.44it/s]


train_loss: 0.3945, train_accuracy: 0.8645, val_loss: 0.4849, val_accuracy: 0.8360, lr: 3.1884e-05

Epoch 14/25:


Training: 100%|██████████| 1250/1250 [02:04<00:00, 10.08it/s]
Validation: 100%|██████████| 313/313 [00:22<00:00, 13.89it/s]


train_loss: 0.3730, train_accuracy: 0.8711, val_loss: 0.4713, val_accuracy: 0.8410, lr: 2.8116e-05

Epoch 15/25:


Training: 100%|██████████| 1250/1250 [02:04<00:00, 10.05it/s]
Validation: 100%|██████████| 313/313 [00:20<00:00, 15.45it/s]


train_loss: 0.3440, train_accuracy: 0.8810, val_loss: 0.4633, val_accuracy: 0.8412, lr: 2.4379e-05

Epoch 16/25:


Training: 100%|██████████| 1250/1250 [02:05<00:00,  9.97it/s]
Validation: 100%|██████████| 313/313 [00:20<00:00, 15.33it/s]


train_loss: 0.3225, train_accuracy: 0.8901, val_loss: 0.4303, val_accuracy: 0.8560, lr: 2.0729e-05

Epoch 17/25:


Training: 100%|██████████| 1250/1250 [02:04<00:00, 10.06it/s]
Validation: 100%|██████████| 313/313 [00:20<00:00, 15.37it/s]


train_loss: 0.3008, train_accuracy: 0.8976, val_loss: 0.4270, val_accuracy: 0.8513, lr: 1.7227e-05

Epoch 18/25:


Training: 100%|██████████| 1250/1250 [02:05<00:00,  9.96it/s]
Validation: 100%|██████████| 313/313 [00:20<00:00, 15.29it/s]


train_loss: 0.2870, train_accuracy: 0.9002, val_loss: 0.4249, val_accuracy: 0.8566, lr: 1.3925e-05

Epoch 19/25:


Training: 100%|██████████| 1250/1250 [02:04<00:00, 10.06it/s]
Validation: 100%|██████████| 313/313 [00:20<00:00, 15.39it/s]


train_loss: 0.2683, train_accuracy: 0.9090, val_loss: 0.4175, val_accuracy: 0.8552, lr: 1.0877e-05

Epoch 20/25:


Training: 100%|██████████| 1250/1250 [02:05<00:00,  9.97it/s]
Validation: 100%|██████████| 313/313 [00:20<00:00, 15.57it/s]


train_loss: 0.2569, train_accuracy: 0.9116, val_loss: 0.4041, val_accuracy: 0.8637, lr: 8.1309e-06

Epoch 21/25:


Training: 100%|██████████| 1250/1250 [02:04<00:00, 10.07it/s]
Validation: 100%|██████████| 313/313 [00:22<00:00, 14.04it/s]


train_loss: 0.2475, train_accuracy: 0.9168, val_loss: 0.4130, val_accuracy: 0.8608, lr: 5.7295e-06

Epoch 22/25:


Training: 100%|██████████| 1250/1250 [02:03<00:00, 10.09it/s]
Validation: 100%|██████████| 313/313 [00:20<00:00, 15.24it/s]


train_loss: 0.2348, train_accuracy: 0.9215, val_loss: 0.4017, val_accuracy: 0.8611, lr: 3.7108e-06

Epoch 23/25:


Training: 100%|██████████| 1250/1250 [02:05<00:00,  9.95it/s]
Validation: 100%|██████████| 313/313 [00:20<00:00, 15.52it/s]


train_loss: 0.2289, train_accuracy: 0.9232, val_loss: 0.4034, val_accuracy: 0.8645, lr: 2.1067e-06

Epoch 24/25:


Training: 100%|██████████| 1250/1250 [02:04<00:00, 10.08it/s]
Validation: 100%|██████████| 313/313 [00:20<00:00, 15.57it/s]


train_loss: 0.2227, train_accuracy: 0.9247, val_loss: 0.3964, val_accuracy: 0.8686, lr: 9.4251e-07

Epoch 25/25:


Training: 100%|██████████| 1250/1250 [02:05<00:00, 10.00it/s]
Validation: 100%|██████████| 313/313 [00:19<00:00, 15.70it/s]



train_loss: 0.2221, train_accuracy: 0.9247, val_loss: 0.3975, val_accuracy: 0.8678, lr: 2.3656e-07


## Models Testing

In [22]:
# Tests a model.
def test(model, test_loader):
  
  # Tetsing.
  model.eval()

  # Predictions and ground truth.
  y_pred = []
  y_true = []

  with torch.no_grad():

    # Iterating over mini-batches.
    for batch in tqdm(test_loader, desc = "Testing", position = 0):

      # Moving x and y to GPU.
      x = batch[0].to(device)
      y = batch[1].to(device)

      # Predictions.
      yhat = model(x)

      # Updating variables.
      y_pred.extend(torch.argmax(yhat, 1).tolist())
      y_true.extend(y.tolist())

  # Returning pred and true.
  return y_pred, y_true

### Lightweight MHSA model ($m_1$)

In [23]:
# Testing the model.
y_pred, y_true = test(m_1, test_loader)

Testing: 100%|██████████| 313/313 [00:36<00:00,  8.61it/s]


In [24]:
# Classification report.
print(classification_report(y_true, y_pred))

              precision    recall  f1-score   support

           0       0.89      0.91      0.90      1000
           1       0.95      0.95      0.95      1000
           2       0.85      0.85      0.85      1000
           3       0.78      0.77      0.77      1000
           4       0.85      0.89      0.87      1000
           5       0.83      0.80      0.82      1000
           6       0.92      0.92      0.92      1000
           7       0.94      0.92      0.93      1000
           8       0.94      0.94      0.94      1000
           9       0.93      0.94      0.93      1000

    accuracy                           0.89     10000
   macro avg       0.89      0.89      0.89     10000
weighted avg       0.89      0.89      0.89     10000



### Standard MHSA model ($m_2$)

In [None]:
# Testing the model.
y_pred, y_true = test(m_2, test_loader)

In [None]:
# Classification report.
print(classification_report(y_true, y_pred))

### No attention model ($m_3$) 

In [None]:
# Testing the model.
y_pred, y_true = test(m_3, test_loader)

Testing: 100%|██████████| 313/313 [00:28<00:00, 11.04it/s]


In [None]:
# Classification report.
print(classification_report(y_true, y_pred))

              precision    recall  f1-score   support

           0       0.90      0.90      0.90      1000
           1       0.95      0.96      0.95      1000
           2       0.87      0.86      0.87      1000
           3       0.75      0.78      0.77      1000
           4       0.88      0.89      0.88      1000
           5       0.85      0.80      0.82      1000
           6       0.92      0.91      0.92      1000
           7       0.92      0.92      0.92      1000
           8       0.92      0.95      0.94      1000
           9       0.94      0.94      0.94      1000

    accuracy                           0.89     10000
   macro avg       0.89      0.89      0.89     10000
weighted avg       0.89      0.89      0.89     10000



### ResNet-18 ($m_4$)

In [None]:
# Testing the model.
y_pred, y_true = test(m_4, test_loader)

Testing: 100%|██████████| 313/313 [00:18<00:00, 17.35it/s]


In [None]:
# Classification report.
print(classification_report(y_true, y_pred))

              precision    recall  f1-score   support

           0       0.86      0.91      0.89      1000
           1       0.95      0.95      0.95      1000
           2       0.82      0.84      0.83      1000
           3       0.74      0.77      0.76      1000
           4       0.88      0.86      0.87      1000
           5       0.83      0.79      0.81      1000
           6       0.92      0.90      0.91      1000
           7       0.92      0.89      0.91      1000
           8       0.92      0.94      0.93      1000
           9       0.92      0.93      0.93      1000

    accuracy                           0.88     10000
   macro avg       0.88      0.88      0.88     10000
weighted avg       0.88      0.88      0.88     10000



## Results

In [None]:
# Plots losses and accuracies.
def plot_training(histories, m_names):

  # Creating the figure and axes.
  fig, (ax1, ax2) = plt.subplots(1, 2, figsize = (10, 4))

  # Iterating over models.
  for name in m_names:

    # Computing the x axis array.
    x = np.linspace(1, len(histories[name]["val_loss"]), len(histories[name]["val_loss"]), dtype = int)

    # Plotting.
    ax1.plot(x, histories[name]["val_loss"], label = r"{}".format(name))
    ax2.plot(x, histories[name]["val_accuracy"], label = r"{}".format(name))

  ax1.set_ylabel("Loss")
  ax1.set_xlabel("Epoch")
  ax1.set_xticks(x)
  ax1.legend()

  ax2.set_ylabel("Accuracy")
  ax2.set_xlabel("Epoch")
  ax2.set_xticks(x)
  ax2.legend()

  plt.show()

In [None]:
# Training histories.
histories = {}

# Names.
m_names = ["m_1", "m_2", "m_3", "m_4"]

# Reading histories.
for name in m_names:
  with open(f"{name}_history.pkl", "rb") as f: 
    histories[name] = pickle.load(f)

# Printing training loss and accuracy.
plot_training(histories, m_names)

## References <a name="references"></a>

1. Jianyuan Guo, Kai Han, Han Wu, Chang Xu, Yehui Tang, Chunjing Xu and Yunhe Wang. CMT: Convolutional Neural Networks Meet Vision Transformers, 2021. [https://arxiv.org/abs/2107.06263](https://arxiv.org/abs/2107.06263).