In [1]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from torchinfo import summary
import timm
from pytorch_models_imp.segformer import PatchOverlapEmbeddings, ReducedSelfAttention, MixFFN, SegformerBlock, mitb0

In [2]:
# --- Data configuration ------------------------------------------------------
# Each setting is named once so the train and test pipelines cannot drift apart.
DATA_ROOT = './data'       # directory handed to CIFAR10(root=...)
IMAGE_SIZE = (224, 224)    # target (height, width) passed to Resize
NUM_WORKERS = 5            # worker processes used by each DataLoader
batch_size = 128

# Preprocessing applied to every sample: convert to a tensor, resize to
# IMAGE_SIZE, then normalise each of the three channels as (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(IMAGE_SIZE),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# Training split: shuffled.
trainset = torchvision.datasets.CIFAR10(
    root=DATA_ROOT, train=True, download=True, transform=transform
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS
)

# Test split: same preprocessing, fixed order.
testset = torchvision.datasets.CIFAR10(
    root=DATA_ROOT, train=False, download=True, transform=transform
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=NUM_WORKERS
)

# Human-readable label names.
# NOTE(review): presumably index-aligned with the dataset's integer targets — confirm.
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Draw a single mini-batch from the training loader; `images` is the input
# used by the shape checks in the cells that follow.
images, labels = next(iter(trainloader))

# Batch size, channels, height and width, unpacked from the 4-D image batch.
N, C, H, W = images.shape

# Height and width integer-divided by 4 (the patch embedder below is built
# with stride=4).
H_reducted = H // 4
W_reducted = W // 4

### Overlapped patch embedder

In [4]:
EMBED_SIZE = 32

# Overlapping patch embedding: 7x7 patches taken with stride 4.
# EMBED_SIZE is passed instead of a repeated literal 32 so that the module and
# the shape assertion in the next cell (which uses EMBED_SIZE) cannot disagree.
# NOTE(review): image_sizes receives the already-reduced (H // 4, W // 4) size,
# while mitb0 further down is given the full (H, W) — confirm which one
# PatchOverlapEmbeddings expects.
patch_embedder = PatchOverlapEmbeddings(
    input_channels=C,
    image_sizes=(H_reducted, W_reducted),
    stride=4,
    patch_size=7,
    embed_size=EMBED_SIZE,
)

# Patch-grid metadata exposed by the embedder.
number_of_patches = patch_embedder.number_of_patches
patch_height_resolution = patch_embedder.patch_height_resolution
# NOTE(review): the attribute is read as `patch_width_resoultion` (sic); the
# spelling is kept because it may be the name the class actually defines —
# confirm against PatchOverlapEmbeddings.
patch_width_resolution = patch_embedder.patch_width_resoultion

In [5]:
# Run the sample batch through the patch embedder.
patch_embedder_out = patch_embedder(images)

# Expected layout: one embedding of width EMBED_SIZE per position of the
# reduced H_reducted x W_reducted grid, for each of the N images.
assert tuple(patch_embedder_out.shape) == (
    N,
    H_reducted * W_reducted,
    EMBED_SIZE,
)

### Reduced Self-Attention

In [6]:
# Attention hyper-parameters.
# EMBED_SIZE is deliberately not re-assigned here: it is defined once in the
# patch embedder cell above, and the attention layer consumes
# patch_embedder_out in the next cell, so both must share the same value.
NUM_HEADS = 2
# NOTE(review): presumably the spatial reduction ratio applied inside the
# attention — confirm against ReducedSelfAttention.
REDUCTION = 2
DROPOUT = 0.0

attention = ReducedSelfAttention(NUM_HEADS, EMBED_SIZE, REDUCTION, DROPOUT)

In [7]:
# Apply the attention layer to the patch embeddings; the reduced grid height
# and width are passed alongside the token sequence.
attention_out = attention(
    patch_embedder_out,
    H_reducted,
    W_reducted,
)

# The attention output must have the same shape as its input sequence.
assert tuple(attention_out.shape) == (
    N,
    H_reducted * W_reducted,
    EMBED_SIZE,
)

### Mix-FFN

In [8]:
# Expansion factor passed as MixFFN's second argument.
MLP_EXPANSION = 4

# Feed-forward module built with the same embedding width and dropout as the
# attention layer above.
mix_ffn = MixFFN(
    EMBED_SIZE,
    MLP_EXPANSION,
    DROPOUT,
)

In [9]:
# Feed the attention output through the Mix-FFN, again passing the reduced
# grid height and width.
out_mix_ffn = mix_ffn(
    attention_out,
    H_reducted,
    W_reducted,
)

# The feed-forward output must keep the (N, tokens, EMBED_SIZE) shape.
assert tuple(out_mix_ffn.shape) == (
    N,
    H_reducted * W_reducted,
    EMBED_SIZE,
)

### Segformer Block

In [10]:
# A full block, configured with the same hyper-parameters that were used for
# the stand-alone attention and Mix-FFN modules above.
segformer_block = SegformerBlock(
    NUM_HEADS,
    EMBED_SIZE,
    REDUCTION,
    MLP_EXPANSION,
    DROPOUT,
)

In [11]:
# Run the block directly on the patch embeddings.
out_segformer_block = segformer_block(
    patch_embedder_out,
    H_reducted,
    W_reducted,
)

# The block output must have the same shape as the embeddings it received.
assert tuple(out_segformer_block.shape) == (
    N,
    H_reducted * W_reducted,
    EMBED_SIZE,
)

### Segformer Encoder

In [64]:
# Build the encoder with mitb0, giving it the full (height, width) of the
# input images.
segformer_encoder = mitb0(
    (H, W),
)

# Forward the sample batch; the result is iterated stage by stage in the
# next cell.
hidden_states = segformer_encoder(images)

In [65]:
# Report the shape of every hidden state returned by the encoder, one line
# per entry, in the order the encoder produced them.
for indx, hidden_state in enumerate(hidden_states):
    message = f"Hidden state {indx} shape: {hidden_state.shape}"
    print(message)

Hidden state 0 shape: torch.Size([128, 3136, 32])
Hidden state 1 shape: torch.Size([128, 784, 64])
Hidden state 2 shape: torch.Size([128, 196, 160])
Hidden state 3 shape: torch.Size([128, 49, 256])


In [66]:
# Display the encoder's module tree: as the last expression of the cell, its
# repr is rendered as the cell output.
segformer_encoder

SegformerEncoder(
  (segformer_layers): ModuleList(
    (0): SegformerLayer(
      (layers): ModuleList(
        (layer_0): SegformerBlock(
          (attention): ReducedSelfAttention(
            (query): Linear(in_features=32, out_features=32, bias=True)
            (key): Linear(in_features=32, out_features=32, bias=True)
            (value): Linear(in_features=32, out_features=32, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (reductor): Conv2d(32, 32, kernel_size=(8, 8), stride=(8, 8))
            (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
            (proj): Linear(in_features=32, out_features=32, bias=True)
          )
          (mix_ffn): MixFFN(
            (fc1): Linear(in_features=32, out_features=256, bias=True)
            (fc2): Linear(in_features=256, out_features=32, bias=True)
            (position_conv): PositionDWConv(
              (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256)
  