In [20]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from torchinfo import summary
import timm
from pytorch_models_imp.segformer import PatchOverlapEmbeddings, ReducedSelfAttention, MixFFN, SegformerBlock, mitb0, mitb1, SegformerDecoder, segformer_b2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Preprocessing applied to every CIFAR-10 sample: convert to a tensor,
# resize to 224x224, then normalize each of the three channels with
# mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

batch_size = 128

# Training split (shuffle=True).
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=5
)

# Test split (shuffle=False).
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=5
)

# Human-readable names for the ten CIFAR-10 classes.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Grab one mini-batch from the training loader; it is the fixed input for
# every shape check in the cells below.
images, labels = next(iter(trainloader))
N, C, H, W = images.shape

# Height and width after integer-dividing each side by 4.
H_reducted = H // 4
W_reducted = W // 4

### Overlapped Patch Embedder

In [4]:
# Embedding width shared by the embedder and by the shape assertions that follow.
EMBED_SIZE = 32

# NOTE(review): image_sizes is given the already-reduced (H // 4, W // 4) size
# rather than the raw (H, W) — confirm this is what PatchOverlapEmbeddings expects.
patch_embedder = PatchOverlapEmbeddings(
    input_channels=C,
    image_sizes=(H_reducted, W_reducted),
    stride=4,
    patch_size=7,
    # Use the constant instead of a second hard-coded 32 so that changing
    # EMBED_SIZE cannot desynchronize the model from the shape checks.
    embed_size=EMBED_SIZE,
)

# Attributes exposed by the embedder, kept for inspection.
number_of_patches = patch_embedder.number_of_patches
patch_height_resolution = patch_embedder.patch_height_resolution
# NOTE(review): the attribute name is spelled "resoultion" here; presumably it
# mirrors the spelling used inside PatchOverlapEmbeddings — verify before renaming.
patch_width_resolution = patch_embedder.patch_width_resoultion

In [5]:
# The embedder returns three values; only the first is shape-checked here.
patch_embedder_out, h_new, w_new = patch_embedder(images)
expected_embedding_shape = (N, H_reducted * W_reducted, EMBED_SIZE)
assert patch_embedder_out.shape == expected_embedding_shape

### Reduced Self Attention

In [6]:
# Settings for the attention layer under test.
DROPOUT = 0.0
REDUCTION = 2
NUM_HEADS = 2
EMBED_SIZE = 32

attention = ReducedSelfAttention(
    NUM_HEADS,
    EMBED_SIZE,
    REDUCTION,
    DROPOUT,
)

In [7]:
# Run attention on the embedded patches; the output is expected to have the
# same shape that was asserted for the embedder output.
attention_out = attention(patch_embedder_out, H_reducted, W_reducted)
expected_attention_shape = (N, H_reducted * W_reducted, EMBED_SIZE)
assert attention_out.shape == expected_attention_shape

### Mix-FFN

In [8]:
# Expansion value handed to MixFFN together with the embed size and dropout.
MLP_EXPANSION = 4

mix_ffn = MixFFN(
    EMBED_SIZE,
    MLP_EXPANSION,
    DROPOUT,
)

In [9]:
# Feed the attention output through the Mix-FFN and check the shape is unchanged.
out_mix_ffn = mix_ffn(attention_out, H_reducted, W_reducted)
expected_mix_ffn_shape = (N, H_reducted * W_reducted, EMBED_SIZE)
assert out_mix_ffn.shape == expected_mix_ffn_shape

### Segformer Block

In [10]:
# One Segformer block, configured with the same constants used for the
# stand-alone attention and Mix-FFN layers above.
segformer_block = SegformerBlock(
    NUM_HEADS,
    EMBED_SIZE,
    REDUCTION,
    MLP_EXPANSION,
    DROPOUT,
)

In [11]:
# The block consumes the embedder output directly; its output shape is
# expected to equal the shape asserted for that input.
out_segformer_block = segformer_block(patch_embedder_out, H_reducted, W_reducted)
expected_block_shape = (N, H_reducted * W_reducted, EMBED_SIZE)
assert out_segformer_block.shape == expected_block_shape

### Segformer Encoder

In [12]:
# Build the mitb0 encoder for the un-reduced (H, W) size of the sample batch.
encoder_input_size = (H, W)
segformer_encoder = mitb0(encoder_input_size)

# Forward pass over the sample batch; the result is iterated in the next cell.
hidden_states = segformer_encoder(images)

In [13]:
# Print the shape of every encoder output, in the order the encoder returned them.
for indx, hidden_state in enumerate(hidden_states):
    print(
        f"Hidden state {indx} shape: {hidden_state.shape}"
    )

Hidden state 0 shape: torch.Size([128, 32, 56, 56])
Hidden state 1 shape: torch.Size([128, 64, 28, 28])
Hidden state 2 shape: torch.Size([128, 160, 14, 14])
Hidden state 3 shape: torch.Size([128, 256, 7, 7])


### Segformer Decoder

In [17]:
# Decoder settings.
NUM_CLASSES = 1000
DROPOUT = 0.0
DECODER_HIDDEN_STATE = 256

# The decoder is configured from the encoder's own layer configuration.
layer_config = segformer_encoder.layer_configuration
segformer_decoder = SegformerDecoder(
    layer_config,
    DECODER_HIDDEN_STATE,
    NUM_CLASSES,
    DROPOUT,
)

In [18]:
# Decode the encoder outputs; the result is expected to have NUM_CLASSES
# channels at the reduced (H // 4, W // 4) size.
decoder_output = segformer_decoder(hidden_states)
expected_decoder_shape = (N, NUM_CLASSES, H_reducted, W_reducted)
assert decoder_output.shape == expected_decoder_shape

### Segformer

In [26]:
# Full model built by segformer_b2 for the (H, W) input size and NUM_CLASSES outputs.
segformer_full = segformer_b2(
    (H, W),
    NUM_CLASSES,
)

In [27]:
# End-to-end forward pass; the output is expected to have the same shape that
# was asserted for the stand-alone decoder output.
segmentation_out = segformer_full(images)
expected_segmentation_shape = (N, NUM_CLASSES, H_reducted, W_reducted)
assert segmentation_out.shape == expected_segmentation_shape