In [1]:
# Imports 
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
import cv2
import os
import json
import math
import torch.nn.init as init
from einops import rearrange


# Set seed for randomize functions (Ez reproduction of results)
random.seed(100)

# Import TuSimple loader
import sys
sys.path.insert(0,'../resources/')
from tusimple import TuSimple

  warn(f"Failed to load image Python extension: {e}")


In [4]:
test = torch.randn((1,576,768))

In [11]:
# Masks transformer class
class MaskTransformer(nn.Module):
    def __init__(self, image_size = (640,640) ,n_classes = 2, patch_size = 16, depth = 6 ,heads = 8, dim_enc = 768, dim_dec = 512, mlp_dim = 1024, dropout = 0.1):
        super(MaskTransformer, self).__init__()
        self.dim = dim_enc
        self.patch_size = patch_size
        self.depth = depth
        self.class_n = n_classes
        self.mlp_dim = mlp_dim
        self.dropout = dropout
        self.d_model = dim_dec
        self.scale = self.d_model ** -0.5
        self.att_heads = heads
        self.image_size = image_size
        
        # Define the transformer blocks
        self.transformer_blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(dim_dec, heads, mlp_dim, dropout)
            for _ in range(self.depth)
            ])
        
        # Learnable Class embedding parameter
        self.cls_emb = nn.Parameter(torch.randn(1, n_classes,dim_dec))
        
        # Projection layers for patch embeddings and class embeddings
        self.proj_dec = nn.Linear(dim_enc,dim_dec)
        self.proj_patch = nn.Parameter(self.scale * torch.randn(dim_dec, dim_dec))
        self.proj_classes = nn.Parameter(self.scale * torch.randn(dim_dec, dim_dec))
        
        # Normalization layers
        self.decoder_norm = nn.LayerNorm(dim_dec)
        self.mask_norm = nn.LayerNorm(n_classes)
        
        
        # Initialize weights from a random normal distribution for all layers and the class embedding parameter
        self.apply(self.init_weights)
        init.normal_(self.cls_emb, std=0.02)
    
    # Init weights method
    @staticmethod
    def init_weights(module):
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            nn.init.kaiming_normal_(module.weight, mode='fan_in')
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.LayerNorm):
            nn.init.constant_(module.weight, 1)
            nn.init.constant_(module.bias, 0)
    
    def forward(self, x):
        H, W = self.image_size
        GS = H // self.patch_size

        # Project embeddings to mask transformer dim size and expand class embedding(by adding the batch dim) to match these 
        x = self.proj_dec(x)
        cls_emb = self.cls_emb.expand(x.size(0), -1, -1)
        
        # Add the learnable class embedding to the patch embeddings and pass through the transformer blocks
        x = torch.cat((x, cls_emb), 1)
        for blk in self.transformer_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # Split output tensor into patch embeddings and the transformer patch level class embeddings
        patches, cls_seg_feat = x[:, : -self.class_n], x[:, -self.class_n :]
        patches = patches @ self.proj_patch
        cls_seg_feat = cls_seg_feat @ self.proj_classes

        # Perform L2 Normalizations over the two tensors
        patches = patches / patches.norm(dim=-1, keepdim=True)
        cls_seg_feat = cls_seg_feat / cls_seg_feat.norm(dim=-1, keepdim=True)

        # 1. Calculate patch level class scores(as per dot product) by between the normalized patch tensors and the normalized class embeddings
        # 2. Reshape the output from (batch,number of patches, classes) to (batch size, classes, height, width)
        masks = patches @ cls_seg_feat.transpose(1, 2)
        masks = self.mask_norm(masks)
        masks = rearrange(masks, "b (h w) n -> b n h w", h=int(GS))

        return masks       


In [12]:
model = MaskTransformer(image_size=(384,384),n_classes=1, dim_dec= 256)

In [13]:
model(test).shape

torch.Size([1, 1, 24, 24])

In [14]:
model(test)

tensor([[[[-5.6671e-08, -7.3756e-07,  2.7428e-08, -4.7659e-08, -3.9736e-08,
            2.5518e-08, -2.2287e-08, -3.3461e-07, -3.7178e-07, -9.1061e-07,
            3.0523e-09, -3.9563e-08, -1.7503e-08, -1.0877e-07, -1.0289e-07,
            9.1736e-08, -9.2361e-08, -2.2415e-07,  4.1944e-07, -4.2552e-08,
           -4.0916e-07,  1.9419e-07, -8.4540e-07,  4.8652e-08],
          [-1.9286e-07,  4.4176e-07,  9.4099e-07, -1.7902e-08,  8.0209e-08,
            1.0372e-07, -1.5885e-06,  6.3817e-08, -1.8562e-07,  5.9067e-08,
            3.2308e-07, -3.3988e-07,  8.6113e-08,  1.0192e-06, -1.3037e-07,
            9.3491e-08, -1.1120e-07,  8.7559e-08, -7.0695e-09,  7.3977e-08,
           -2.3963e-07, -2.5006e-07, -5.3607e-08,  8.0832e-07],
          [-8.4898e-08,  1.0473e-07, -1.6218e-08, -1.6943e-07,  8.9449e-10,
            8.4814e-07,  1.3022e-08,  4.6543e-07, -2.4182e-07,  1.9075e-07,
            5.7920e-08,  7.2318e-07, -3.3806e-07,  1.6051e-06, -6.4804e-07,
           -2.8490e-07,  5.3466e-08,