# CvT Implementation

In [1]:
import torch
from torch import nn, einsum
import torch.optim as optim
from torchvision import datasets
from torchvision import transforms
from einops.layers.torch import Rearrange
from einops import rearrange
from einops import repeat
from torchvision.transforms import v2

# Convolutional Transformer Modules

## Helper Modules

In [2]:
device = 
(
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
device

In [3]:
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.nn = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        ).to(device)
    def forward(self, x):
        return self.nn(x)

In [4]:
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        # allows to pass variable # of params to fn
        x = fn(x, **kwargs) + x
        return x

In [5]:
class Norm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim).to(device)
        self.fn = fn
    def forward(self, x, **kwargs):
        # apply layer norm prior to fn
        x = self.norm(x)
        return self.fn(x, **kwargs)

## CvT Specific Helper Modules

In [6]:
class ConvEmbedding(nn.Module):
    def __init__(self, image_size, in_channels, kernel_size, stride, size, padding, dim):
        super().__init__()
        self.conv = nn.Sequential(nn.Conv2d(in_channels, dim, kernel_size, stride, padding),
            Rearrange('b c h w -> b (h w) c', h = image_size//size, w = image_size//size),
            nn.LayerNorm(dim)).to(device)
    def forward(self, x):
        x = self.conv(x)
        return x

In [7]:
class ConvProj(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding = 0):
        super().__init__()
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, groups=in_channels, padding=padding).to(device)
        self.bn = nn.BatchNorm2d(in_channels).to(device)
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1).to(device)
    def forward(self, input):
        x = self.depthwise(input)
        x = self.bn(x)
        x = self.pointwise(x)
        return x


In [8]:
class ConvAttention(nn.Module):
     def __init__(self, dim, img_size, heads = 8, dim_head = 64, kernel_size=3, q_stride=1, k_stride=1, v_stride=1, dropout = 0.,
                 last_stage=False):

        super().__init__()
        self.last_stage = last_stage
        self.img_size = img_size
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        # map to q,k,v
        self.heads = heads
        self.scale = dim_head ** -0.5
        pad = (kernel_size - q_stride)//2
        self.to_q = ConvProj(dim, inner_dim, kernel_size, pad).to(device)
        self.to_k = ConvProj(dim, inner_dim, kernel_size, pad).to(device)
        self.to_v = ConvProj(dim, inner_dim, kernel_size, pad).to(device)

        # use if outputting, else use identity
        self.out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)).to(device) if project_out else nn.Identity().to(device)

     def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        if self.last_stage:
            cls_token = x[:, 0]
            x = x[:, 1:]
            cls_token = rearrange(cls_token.unsqueeze(1), 'b n (h d) -> b h n d', h = h)
        x = rearrange(x, 'b (l w) n -> b n l w', l=self.img_size, w=self.img_size)
        q = self.to_q(x)
        q = rearrange(q, 'b (h d) l w -> b h (l w) d', h=h)

        v = self.to_v(x)
        v = rearrange(v, 'b (h d) l w -> b h (l w) d', h=h)

        k = self.to_k(x)
        k = rearrange(k, 'b (h d) l w -> b h (l w) d', h=h)

        if self.last_stage:
            q = torch.cat((cls_token, q), dim=2)
            v = torch.cat((cls_token, v), dim=2)
            k = torch.cat((cls_token, k), dim=2)


        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = dots.softmax(dim=-1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out =  self.out(out)
        return out


In [9]:
class Transformer(nn.Module):
    def __init__(self, dim, img_size, depth, heads, dim_heads, mlp_dim, dropout=0., last_stage=False):
        super().__init__()
        self.modules = []
        for _ in range(depth):
            self.modules.append(nn.ModuleList([
                Norm(dim, ConvAttention(dim, img_size, heads=heads, dim_head=dim_heads, dropout=dropout, last_stage=last_stage)).to(device),
                Norm(dim, FeedForward(dim, mlp_dim, dropout=dropout)).to(device)
            ]))
    def forward(self, x):
        for attn, ff in self.modules:
            x = attn(x) + x
            x = ff(x) + x
        return x

## CvT Class

In [10]:
class CvT(nn.Module):
    def __init__(self, image_size, in_channels, num_classes, dim=64):
        super().__init__()

        self.dim = dim
        # Stage 1: Conv, Transformer
        trans = Transformer(dim, image_size//4, 1, 1, dim_heads=self.dim, mlp_dim=dim*4).to(device)
        self.stage_1_embed = ConvEmbedding(image_size, in_channels, 7, 4, 4, 2, self.dim).to(device)
        self.stage_1_transformer = nn.Sequential(trans,
                                                  Rearrange('b (h w) c -> b c h w', h = image_size//4, w = image_size//4)).to(device)
        # Stage 2: Conv, Transformer
        in_channels = dim
        scale = 3//1
        dim = dim * scale
        self.stage_2_embed = ConvEmbedding(image_size, in_channels, 3, 2, 8, 1, dim).to(device)
        trans = Transformer(dim, image_size//8, 2, 3, dim_heads=self.dim, mlp_dim=dim*4).to(device)
        self.stage_2_transformer = nn.Sequential(trans,
                                                  Rearrange('b (h w) c -> b c h w', h = image_size//8, w = image_size//8)).to(device)
        # Stage 3: Conv, Transformer, FFN
        in_channels = dim
        scale = 6//3
        dim = scale * dim
        self.stage_3_embed = ConvEmbedding(image_size, in_channels, 3, 2, 16, 1, dim).to(device)
        self.stage_3_transformer = nn.Sequential(Transformer(dim, image_size//16, 10, 6, dim_heads=self.dim, mlp_dim=dim*4, last_stage=True)).to(device)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim)).to(device)

        self.mlp = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        ).to(device)


    def forward(self, img):

        xs = self.stage_1_embed(img)
        xs = self.stage_1_transformer(xs)

        xs = self.stage_2_embed(xs)
        xs = self.stage_2_transformer(xs)

        xs = self.stage_3_embed(xs)
        b, n, _ = xs.shape
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
        xs = torch.cat((cls_tokens, xs), dim=1)
        xs = self.stage_3_transformer(xs)
        xs = xs[:, 0]

        xs = self.mlp(xs)
        return xs

# Model Training


## Load and Transform Data

In [11]:
transform = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True), v2.Resize((224, 224))])

In [12]:
food101_train_data = datasets.Food101('/', split="train", download=True, transform=transform)
food101_test_data = datasets.Food101('/', split="test", transform=transform)

train_loader = torch.utils.data.DataLoader(food101_train_data,
                                          batch_size=32,
                                          shuffle=True)

food101_train_data.__getitem__(2)[0].shape

torch.Size([3, 224, 224])

In [13]:
cvt = CvT(224, 3, 101)
cvt.eval()

CvT(
  (stage_1_embed): ConvEmbedding(
    (conv): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(4, 4), padding=(2, 2))
      (1): Rearrange('b c h w -> b (h w) c', h=56, w=56)
      (2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    )
  )
  (stage_1_transformer): Sequential(
    (0): Transformer()
    (1): Rearrange('b (h w) c -> b c h w', h=56, w=56)
  )
  (stage_2_embed): ConvEmbedding(
    (conv): Sequential(
      (0): Conv2d(64, 192, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): Rearrange('b c h w -> b (h w) c', h=28, w=28)
      (2): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
    )
  )
  (stage_2_transformer): Sequential(
    (0): Transformer()
    (1): Rearrange('b (h w) c -> b c h w', h=28, w=28)
  )
  (stage_3_embed): ConvEmbedding(
    (conv): Sequential(
      (0): Conv2d(192, 384, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): Rearrange('b c h w -> b (h w) c', h=14, w=14)
      (2): LayerNorm((384,), e

In [18]:
def train_loop(loss_fn, optim, trainloader, net):
    running_loss = 0
    for i, data in enumerate(trainloader, 0):
        
        inputs, labels = data
        inputs = inputs
        labels = labels
        # zero parameter gradients
        optim.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optim.step()

        # print statistics
        running_loss += loss.item()
        if i % 100 == 99:    # print every 100 mini-batches
            print(f'loss: {running_loss / 100:.3f}')
            running_loss = 0.0
        print(loss)

def train(epochs, trainloader, net, loss_fn=nn.CrossEntropyLoss(), optimizer=optim.Adam(cvt.parameters(),lr=0.001)):
    net.train()
    for i in range(epochs):
        train_loop(loss_fn, optimizer, trainloader, net)
    print("Done")


In [19]:
optimizer = torch.optim.Adam(cvt.parameters(), lr=1e-3)
train(1, train_loader, cvt, optimizer=optimizer)

# Save model
PATH = "cvt_impl.pt"
torch.save(cvt.state_dict(), PATH)

tensor(4.5432, grad_fn=<NllLossBackward0>)
tensor(4.6857, grad_fn=<NllLossBackward0>)
tensor(4.6629, grad_fn=<NllLossBackward0>)
tensor(4.6905, grad_fn=<NllLossBackward0>)
tensor(4.8349, grad_fn=<NllLossBackward0>)
tensor(4.5938, grad_fn=<NllLossBackward0>)
tensor(4.5737, grad_fn=<NllLossBackward0>)
tensor(4.8321, grad_fn=<NllLossBackward0>)
tensor(4.7636, grad_fn=<NllLossBackward0>)
tensor(5.1685, grad_fn=<NllLossBackward0>)
tensor(4.8618, grad_fn=<NllLossBackward0>)
tensor(4.7726, grad_fn=<NllLossBackward0>)
tensor(4.5561, grad_fn=<NllLossBackward0>)
tensor(4.6430, grad_fn=<NllLossBackward0>)
tensor(4.6156, grad_fn=<NllLossBackward0>)
tensor(4.6963, grad_fn=<NllLossBackward0>)
tensor(4.8112, grad_fn=<NllLossBackward0>)
tensor(4.6130, grad_fn=<NllLossBackward0>)
tensor(4.6943, grad_fn=<NllLossBackward0>)
tensor(4.5599, grad_fn=<NllLossBackward0>)
tensor(4.5730, grad_fn=<NllLossBackward0>)
tensor(4.8356, grad_fn=<NllLossBackward0>)
tensor(4.5910, grad_fn=<NllLossBackward0>)
tensor(4.59

In [24]:
def test(model, loss_fn, test_loader):
    model.eval()
    running_loss = 0
    correct = 0
    total = 0
    for i, data in enumerate(test_loader, 1):
        inputs, labels = data
        inputs = inputs
        labels = labels
        pred = model(inputs)
        loss = loss_fn(pred, labels)

        running_loss += loss.item()
        _, predicted = pred.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    avg_loss = running_loss / len(test_loader)
    accuracy = 100. * correct / total

    print('Test Loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format(
        avg_loss, correct, total, accuracy))


In [None]:
test_loader = torch.utils.data.DataLoader(food101_test_data,
                                          batch_size=64,
                                          shuffle=True)

test(cvt, nn.CrossEntropyLoss(), test_loader)