<center>
    <h1> Machine Learning for Computer Vision</h1>
    <h2> Project Work </h2>
</center>

<br>

**Student**: Matteo Donati <br>
**Student ID**: 0001032227 <br>
**E-mail**: <a href="mailto:matteo.donati10@studio.unibo.it">matteo.donati10@studio.unibo.it</a>

<br>

This is an unofficial implementation of the paper by Guo et al., (2021)<sup>[[1]](#references)</sup>.

## Libraries and Utilities

In [1]:
import numpy as np
import random
import torch
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms as T
import multiprocessing as mp
from matplotlib import pyplot as plt
from torchsummary import summary
from tqdm import tqdm
import time

In [2]:
def set_reproducibility(seed):
  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)

set_reproducibility(0)

## Data Loading and Preprocessing

In [3]:
# Retrieves the CIFAR10 dataset.
def get_cifar10(train_transform = None, test_transform = None):

  # Downloading the training set.
  train_dataset =  datasets.CIFAR10(
    root = "CIFAR10",
    train = True,
    transform = train_transform,
    download = True
  )

  # Downloading the test set.
  test_dataset = datasets.CIFAR10(
    root = "CIFAR10",
    train = False,
    transform = test_transform,
    download = True
  )

  # Returning the two sets.
  return train_dataset, test_dataset

# Fuction that creates train, validation and test data loaders.
def create_data_loaders(train_transform, 
                        test_transform, 
                        img_size = 224, 
                        split = (0.8, 0.2), 
                        batch_size = 32, 
                        num_workers = 1):

  # Retrieving CIFAR10.
  train_dataset, test_dataset = get_cifar10(train_transform, test_transform)

  # Splitting train_dataset to create a validation set.
  train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, 
                                                             (int(len(train_dataset) * split[0]), 
                                                              int(len(train_dataset) * split[1])))

  # Train data loader.
  train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size = batch_size,
    shuffle = True,
    num_workers = num_workers,
    pin_memory = True,
    drop_last = True,
    sampler = None
  )

  # Validation data loader.
  val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size = batch_size,
    shuffle = False,
    num_workers = num_workers,
    pin_memory = True,
    drop_last = False,
    sampler = None
  )

  # Test data loader.
  test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size = batch_size,
    shuffle = False,
    num_workers = num_workers,
    pin_memory = True,
    drop_last = False,
    sampler = None
  )

  # Returning the three data loaders.
  return train_loader, val_loader, test_loader

In [4]:
# Training transformations.
train_transform = T.Compose([
    T.RandomCrop(32, padding = 4),
    T.RandomHorizontalFlip(),
    T.RandomRotation(10),
    T.Resize(224),
    T.ToTensor(),
    T.Normalize(mean = [0.491, 0.482, 0.447], std = [0.247, 0.243, 0.262])
])

# Test transformations.
test_transform = T.Compose([
    T.Resize(224),
    T.ToTensor(),
    T.Normalize(mean = [0.491, 0.482, 0.447], std = [0.247, 0.243, 0.262])
])

# Creating data loaders.
train_loader, val_loader, test_loader = create_data_loaders(
    train_transform,
    test_transform,
    img_size = 224,
    batch_size = 32,
    num_workers = mp.cpu_count()
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to CIFAR10/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting CIFAR10/cifar-10-python.tar.gz to CIFAR10
Files already downloaded and verified


## Model Definition

In [5]:
# Stem module of the CMT architecture.
class Stem(nn.Module):
    
  # Constructor. Requires number of input and output channels.
  def __init__(self, in_channels, out_channels):

    super().__init__()

    # Conv 3 x 3, stride 2.
    self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = 2, padding = 1)
    self.bn1 = nn.BatchNorm2d(out_channels)

    # Conv 3 x 3. 
    self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size = 3, stride = 1, padding = 1)
    self.bn2 = nn.BatchNorm2d(out_channels)

    # Conv 3 x 3.
    self.conv3 = nn.Conv2d(out_channels, out_channels, kernel_size = 3, stride = 1, padding = 1)
    self.bn3 = nn.BatchNorm2d(out_channels)

    # GELU activation.
    self.gelu = nn.GELU()

  # Forward pass.
  def forward(self, x):
    x = self.conv1(x)
    x = self.gelu(x)
    x = self.bn1(x)
    x = self.conv2(x)
    x = self.gelu(x)
    x = self.bn2(x)
    x = self.conv3(x)
    x = self.gelu(x)
    y = self.bn3(x)
    return y

In [6]:
# Patch embedding module of the CMT architecture.
class PatchEmbedding(nn.Module):

  # Constructor. Requires number of input and output channels.
  def __init__(self, in_channels, out_channels):

    super().__init__()

    # Conv 2 x 2, stride 2.
    self.conv = nn.Conv2d(in_channels, out_channels, kernel_size = 2, stride = 2, padding = 0)

  # Forward pass.
  def forward(self, x):
    x = self.conv(x)
    _, c, h, w = x.size()
    y = torch.nn.functional.layer_norm(x, (c, h, w))
    return y

In [7]:
# Local Perception Unit: LPU(X) = DWConv(X) + X.
class LPU(nn.Module):

  # Constructor.
  def __init__(self, in_channels, out_channels):

    super().__init__()

    # Depthwise convolution.
    self.dwconv = nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = 1, padding = 1, groups = in_channels)

  # Forward pass.
  def forward(self, x):
    y = self.dwconv(x) + x
    return y

# Lightweight Multi-Head-Self-Attention.
class LMHSA(nn.Module):
    
  # Constructor.
  def __init__(self, input_size, channels, d_k, d_v, stride, heads, dropout):

    super().__init__()

    # Depthwise convolutions.
    self.dwconv_k = nn.Conv2d(channels, channels, kernel_size = stride, stride = stride, groups = channels)
    self.dwconv_v = nn.Conv2d(channels, channels, kernel_size = stride, stride = stride, groups = channels)

    # Projection matrices.
    self.fc_q = nn.Linear(channels, heads * d_k)
    self.fc_k = nn.Linear(channels, heads * d_k)
    self.fc_v = nn.Linear(channels, heads * d_v)
    self.fc_o = nn.Linear(heads * d_k, channels)

    self.channels = channels
    self.d_k = d_k
    self.d_v = d_v
    self.stride = stride
    self.heads = heads
    self.dropout = dropout

    # Relative position bias to each self-attention module.
    self.B = nn.Parameter(torch.randn(1, self.heads, input_size ** 2, (input_size // stride) ** 2), requires_grad = True)
    
  # Forward pass.
  def forward(self, x):

    # Extracting shape from input signal.
    b, c, h, w = x.shape

    # Reshaping and permuting x. Final shape is (b, h * w, c).
    x_reshape = x.view(b, c, h * w).permute(0, 2, 1)

    # Layer norm.
    x_reshape = torch.nn.functional.layer_norm(x_reshape, (b, h * w, c))

    # Getting queries by applying the fc_q linear projection.
    q = self.fc_q(x_reshape)

    # Reshaping and permuting the queries. Final shape is (b, heads, n = h * w, d_k). 
    q = q.view(b, h * w, self.heads, self.d_k).permute(0, 2, 1, 3).contiguous()

    # Applying depthwise conv to x.
    k = self.dwconv_k(x)

    # Extracting shape of keys.
    k_b, k_c, k_h, k_w = k.shape

    # Reshaping and permuting keys. Final shape is (k_b, k_h * k_w, k_c).
    k = k.view(k_b, k_c, k_h * k_w).permute(0, 2, 1).contiguous()

    # Projecting through fc_k.
    k = self.fc_k(k)

    # Reshaping and permuting the keys. Final shape is (k_b, heads, k_h * k_w, d_k).
    k = k.view(k_b, k_h * k_w, self.heads, self.d_k).permute(0, 2, 1, 3).contiguous()

    # Applying depthwise conv to x.
    v = self.dwconv_v(x)

    # Extracting shape of values.
    v_b, v_c, v_h, v_w = v.shape

    # Reshaping and permuting values. Final shape is (v_b, v_h * v_w, v_c).
    v = v.view(v_b, v_c, v_h * v_w).permute(0, 2, 1).contiguous()

    # Projecting through fc_v.
    v = self.fc_v(v)

    # Reshaping and permuting the keys. Final shape is (v_b, heads, v_h * v_w, d_v).
    v = v.view(v_b, v_h * v_w, self.heads, self.d_v).permute(0, 2, 1, 3).contiguous()

    # Computing softmax((Q K') / sqrt(d_k) + B). Final shape is (b, heads, n = h * w, k_h * k_w).
    attention = torch.einsum("... i d, ... j d -> ... i j", q, k) * (self.d_k ** -0.5)

    attention = attention + self.B
    attention = torch.softmax(attention, dim = -1)

    # Applying attention scores to values by taking the dot product.
    tmp = torch.matmul(attention, v).permute(0, 2, 1, 3)

    # Permuting the result. Final shape is (b, n = h * w, heads, d_v).
    tmp = tmp.contiguous().view(b, h * w, self.heads * self.d_v)

    # Projecting using fc_o and reshaping. Final shape is (b, c, h, w).
    tmp = self.fc_o(tmp).view(b, self.channels, h, w)

    # Returning tmp + x (skip connection).
    return tmp + x

# Inverted Residual Feed-forward Network: IRFFN(X) = Conv(F(Conv(X))), F(X) = DWConv(X) + X.
class IRFFN(nn.Module):

  # Constructor.
  def __init__(self, in_channels, R):

    super().__init__()

    # Number of channels after expansion.
    out_channels = int(in_channels * R)

    # Conv 1 x 1.
    self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size = 1)
    self.bn1 = nn.BatchNorm2d(out_channels)

    # Depthwise convolution.
    self.dwconv = nn.Conv2d(out_channels, out_channels, kernel_size = 3, stride = 1, padding = 1, groups = out_channels)
    self.bn2 = nn.BatchNorm2d(out_channels)

    # Conv 1 x 1.
    self.conv2 = nn.Conv2d(out_channels, in_channels, kernel_size = 1)
    self.bn3 = nn.BatchNorm2d(in_channels)

    # GELU activation.
    self.gelu = nn.GELU()

  # Forward pass.
  def forward(self, x):
    tmp = self.conv1(x)
    tmp = self.bn1(tmp)
    tmp = self.gelu(tmp)
    tmp = self.dwconv(tmp)
    tmp = self.bn2(tmp)
    tmp = self.gelu(tmp)
    tmp = self.conv2(tmp)
    tmp = self.bn3(tmp)
    y = x + tmp
    return y

In [8]:
# CMT block of the CMT architecture.
class CMTBlock(nn.Module):

  # Constructor. By default it loads the CMT-Ti configuration.
  def __init__(self, img_size, k, d_k, d_v, num_heads, R = 3.6, in_channels = 46):

    super().__init__()

    # Local Perception Unit.
    self.lpu = LPU(in_channels, in_channels)

    # Lightweight MHSA.
    self.lmhsa = LMHSA(img_size, in_channels, d_k, d_v, k, num_heads, 0.0)

    # Inverted Residual FFN.
    self.irffn = IRFFN(in_channels, R)

  # Forward pass.
  def forward(self, x):
    x = self.lpu(x)
    x = self.lmhsa(x)
    y = self.irffn(x)
    return y

In [9]:
# CMT architecture.
class CMT(nn.Module):

  # Constructor. By default it loads the CMT-Ti configuration.
  def __init__(self,
    in_channels = 3,
    stem_channels = 16,
    cmt_channels = [46, 92, 184, 368],
    patch_channels = [46, 92, 184, 368],
    block_layers = [2, 2, 10, 2],
    R = 3.6,
    img_size = 224,
    num_class = 10
  ):

    super(CMT, self).__init__()

    # Stem layer
    self.stem = Stem(in_channels, stem_channels)

    # Patch Aggregation Layer
    self.pe1 = PatchEmbedding(stem_channels, patch_channels[0])
    self.pe2 = PatchEmbedding(patch_channels[0], patch_channels[1])
    self.pe3 = PatchEmbedding(patch_channels[1], patch_channels[2])
    self.pe4 = PatchEmbedding(patch_channels[2], patch_channels[3])

    # Stage 1 CMT blocks.
    stage1 = [CMTBlock(img_size = img_size // 4, 
                       k = 8, 
                       d_k = cmt_channels[0], 
                       d_v = cmt_channels[0], 
                       num_heads = 1, 
                       R = R, 
                       in_channels = patch_channels[0]) for _ in range(block_layers[0])]
    self.stage1 = nn.Sequential(*stage1)

    # Stage 2 CMT blocks.
    stage2 = [CMTBlock(img_size = img_size // 8,
                       k = 4,
                       d_k = cmt_channels[1] // 2,
                       d_v = cmt_channels[1] // 2,
                       num_heads = 2,
                       R = R,
                       in_channels = patch_channels[1]) for _ in range(block_layers[1])]
    self.stage2 = nn.Sequential(*stage2)

    # Stage 3 CMT blocks.
    stage3 = [CMTBlock(img_size = img_size // 16,
                       k = 2,
                       d_k = cmt_channels[2] // 4,
                       d_v = cmt_channels[2] // 4,
                       num_heads = 4,
                       R = R,
                       in_channels = patch_channels[2]) for _ in range(block_layers[2])]
    self.stage3 = nn.Sequential(*stage3)

    # Stage 4 CMT blocks.
    stage4 = [CMTBlock(img_size = img_size // 32,
                       k = 1,
                       d_k = cmt_channels[3] // 8,
                       d_v = cmt_channels[3] // 8,
                       num_heads = 8,
                       R = R,
                       in_channels = patch_channels[3]) for _ in range(block_layers[3])]
    self.stage4 = nn.Sequential(*stage4)

    # Global average pooling.
    self.avg_pool = nn.AdaptiveAvgPool2d(1)

    # Projection layer.
    self.projection = nn.Sequential(
      # nn.Linear(cmt_channels[3], 1280),
      nn.Conv2d(cmt_channels[3], 1280, kernel_size = 1),
      nn.ReLU(inplace = True),
    )

    # Classifier.
    self.classifier = nn.Linear(1280, num_class)

  # Forward pass.
  def forward(self, x):
    x = self.stem(x)
    x = self.pe1(x)
    x = self.stage1(x)
    x = self.pe2(x)
    x = self.stage2(x)
    x = self.pe3(x)
    x = self.stage3(x)
    x = self.pe4(x)
    x = self.stage4(x)
    x = self.avg_pool(x)
    x = self.projection(x)
    x = torch.flatten(x, 1)
    y = self.classifier(x)
    return y

In [10]:
# Defining the model.
model = CMT(
  in_channels = 3,
  stem_channels = 16,
  cmt_channels = [46, 92, 184, 368],
  patch_channels = [46, 92, 184, 368],
  block_layers = [2, 2, 10, 2],
  R = 3.6,
  img_size = 224,
  num_class = 10
)

In [11]:
# Choosing between GPU and CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Loading model to device.
model.to(device)

CMT(
  (stem): Stem(
    (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn3): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (gelu): GELU(approximate='none')
  )
  (pe1): PatchEmbedding(
    (conv): Conv2d(16, 46, kernel_size=(2, 2), stride=(2, 2))
  )
  (pe2): PatchEmbedding(
    (conv): Conv2d(46, 92, kernel_size=(2, 2), stride=(2, 2))
  )
  (pe3): PatchEmbedding(
    (conv): Conv2d(92, 184, kernel_size=(2, 2), stride=(2, 2))
  )
  (pe4): PatchEmbedding(
    (conv): Conv2d(184, 368, kernel_size=(2, 2), stride=(2, 2))
  )
  (stage1): Sequential(
    (0): CMTBlock(
      (lpu): LPU(
 

## Model Training

In [20]:
# Epochs.
EPOCHS = 4

# Initial learning rate.
LR = 6e-5

# Weight decay.
WD = 1e-5

# Loss function.
loss_fn = nn.CrossEntropyLoss()

# Optimizer.
optimizer = torch.optim.AdamW(model.parameters(), lr = LR, weight_decay = WD)

# Learning rate scheduler.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, EPOCHS)

In [21]:
# Trains the model.
def train(model, optimizer, loss_fn, scheduler, train_loader, val_loader, epochs, device = "cuda"):

  # Metrics.
  history = {"train_loss": [], "train_accuracy": [], "val_loss": [], "val_accuracy": []}

  # Iterating over epochs.
  for epoch in range(epochs):

    print("Epoch %2d/%2d:" % (epoch + 1, epochs))

    # Training.
    model.train()

    # Metrics
    train_loss = 0.0
    num_train_correct = 0
    num_train_examples = 0

    # Iterating over mini-batches.
    for batch in tqdm(train_loader, position = 0):

      # Gradient reset.
      optimizer.zero_grad()

      # Moving x and y to GPU.
      x = batch[0].to(device)
      y = batch[1].to(device)

      # Predictions.
      yhat = model(x)

      # Loss.
      loss = loss_fn(yhat, y)

      # Computing the gradient.
      loss.backward()

      # Updating the parameters.
      optimizer.step()

      # Updating the metrics.
      train_loss += loss.item() * x.shape[0]
      num_train_correct += (torch.argmax(yhat, 1) == y).sum().item()
      num_train_examples += x.shape[0]

    # Computing the epoch's metrics.
    train_accuracy = num_train_correct / num_train_examples
    train_loss = train_loss / len(train_loader.dataset)

    # Updating the learning rate.
    scheduler.step()

    # Validation.
    model.eval()

    # Metrics.
    val_loss = 0.0
    num_val_correct = 0
    num_val_examples = 0

    with torch.no_grad():

      # Iterating over mini-batches.
      for batch in tqdm(val_loader, position = 0):

        # Moving x and y to GPU.
        x = batch[0].to(device)
        y = batch[1].to(device)

        # Predictions.
        yhat = model(x)

        # Loss.
        loss = loss_fn(yhat, y)

        # Updating the metrics.
        val_loss += loss.item() * x.shape[0]
        num_val_correct += (torch.argmax(yhat, 1) == y).sum().item()
        num_val_examples += y.shape[0]

      # Computing the epoch's metrics.
      val_accuracy = num_val_correct / num_val_examples
      val_loss = val_loss / len(val_loader.dataset)

    print("{}train_loss: {:.4f}, train_accuracy: {:.4f}, val_loss: {:.4f}, val_accuracy: {:.4f}{}".format("\n" if epoch == epochs - 1 else "",
                                                                                                          train_loss, 
                                                                                                          train_accuracy, 
                                                                                                          val_loss, 
                                                                                                          val_accuracy, 
                                                                                                          "\n" if epoch != epochs - 1 else ""))

    # Appending the epoch's metrics to history lists.
    history["train_loss"].append(train_loss)
    history["val_loss"].append(val_loss)
    history["train_accuracy"].append(train_accuracy)
    history["val_accuracy"].append(val_accuracy)

  # Returning history.
  return history

In [22]:
history = train(model, optimizer, loss_fn, scheduler, train_loader, val_loader, EPOCHS)

Epoch  1/ 4:


100%|██████████| 1250/1250 [08:38<00:00,  2.41it/s]
100%|██████████| 313/313 [00:37<00:00,  8.39it/s]


loss: 1.4331, accuracy: 0.4726, val_loss: 1.1372, val_accuracy: 0.5929
Epoch  2/ 4:


100%|██████████| 1250/1250 [08:29<00:00,  2.45it/s]
100%|██████████| 313/313 [00:36<00:00,  8.68it/s]


loss: 1.0362, accuracy: 0.6312, val_loss: 0.9606, val_accuracy: 0.6584
Epoch  3/ 4:


100%|██████████| 1250/1250 [08:28<00:00,  2.46it/s]
100%|██████████| 313/313 [00:35<00:00,  8.71it/s]


loss: 0.8442, accuracy: 0.7007, val_loss: 0.8086, val_accuracy: 0.7166
Epoch  4/ 4:


100%|██████████| 1250/1250 [08:36<00:00,  2.42it/s]
100%|██████████| 313/313 [00:36<00:00,  8.68it/s]

loss: 0.7568, accuracy: 0.7320, val_loss: 0.7568, val_accuracy: 0.7373





## References <a name="references"></a>

1. Jianyuan Guo, Kai Han, Han Wu, Chang Xu, Yehui Tang, Chunjing Xu and Yunhe Wang. CMT: Convolutional Neural Networks Meet Vision Transformers, 2021. [https://arxiv.org/abs/2107.06263](https://arxiv.org/abs/2107.06263).