In [18]:
# Imports 
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
import cv2
import os
import json
import math
import torch.nn.init as init
from einops import rearrange


# Seed every random number generator this notebook relies on so results can be reproduced.
# Seeding only Python's `random` is not enough: the model weights are drawn with
# torch.randn / torch.nn.init, which use PyTorch's own generator.
SEED = 100
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Import TuSimple loader
import sys
sys.path.insert(0,'../resources/')
from tusimple import TuSimple

In [19]:
# Count pipeline trainable parameters
def count_parameters(model):
    """Return how many scalar values in ``model`` are trainable.

    Only parameters with ``requires_grad`` set are counted; frozen
    parameters contribute nothing to the total.
    """
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total

In [20]:
# Masks transformer class
class MaskTransformer(nn.Module):
    """Transformer decoder that turns patch embeddings into per-class mask scores.

    The module appends ``n_classes`` learnable class tokens to the projected
    patch embeddings, runs the joint sequence through ``depth`` transformer
    encoder layers, and scores every patch against every class token with a
    cosine similarity.

    Args:
        image_size: ``(height, width)`` of the input image; only the height is
            used to derive the patch-grid height, the width is inferred from
            the number of patches.
        n_classes: number of class tokens / output mask channels.
        patch_size: side length of one patch in pixels.
        depth: number of transformer encoder layers.
        heads: number of attention heads per layer.
        dim_enc: feature size of the incoming patch embeddings.
        dim_dec: feature size used inside this decoder.
        mlp_dim: hidden size of the feed-forward part of each layer.
        dropout: dropout probability inside the transformer layers.
        pre_train: when True, load transformer-block weights from the
            checkpoint on disk and freeze them.
    """

    def __init__(self, image_size = (640,640) ,n_classes = 1, patch_size = 16, depth = 2 ,heads = 8, dim_enc = 768, dim_dec = 768, mlp_dim = 3072, dropout = 0.1, pre_train = False):
        super(MaskTransformer, self).__init__()
        self.dim = dim_enc
        self.patch_size = patch_size
        self.depth = depth
        self.class_n = n_classes
        self.mlp_dim = mlp_dim
        self.dropout = dropout
        self.d_model = dim_dec
        self.scale = self.d_model ** -0.5
        self.att_heads = heads
        self.image_size = image_size

        # Define the transformer blocks.
        # batch_first=True is required: forward() feeds (batch, tokens, dim) tensors.
        # With the default (batch_first=False) the layer would treat the batch axis
        # as the sequence axis and attend across samples instead of across patches.
        # The flag does not change parameter names, so checkpoints stay compatible.
        # NOTE(review): the checkpoint key names mapped below (norm1 / attn.qkv / mlp.fc1)
        # look like timm-style blocks, which are usually pre-norm with GELU, while
        # nn.TransformerEncoderLayer defaults to post-norm with ReLU - confirm against
        # the checkpoint's source and pass norm_first / activation if needed.
        self.transformer_blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(dim_dec, heads, mlp_dim, dropout, batch_first=True)
            for _ in range(self.depth)
        ])

        # Learnable class embedding parameter (one token per class)
        self.cls_emb = nn.Parameter(torch.randn(1, n_classes, dim_dec))

        # Projection layers for patch embeddings and class embeddings
        self.proj_dec = nn.Linear(dim_enc, dim_dec)
        self.proj_patch = nn.Parameter(self.scale * torch.randn(dim_dec, dim_dec))
        self.proj_classes = nn.Parameter(self.scale * torch.randn(dim_dec, dim_dec))

        # Normalization layers
        self.decoder_norm = nn.LayerNorm(dim_dec)
        # Kept for every n_classes so state dicts keep the same keys; forward() only
        # applies it when there is more than one class (see the comment there).
        self.mask_norm = nn.LayerNorm(n_classes)

        if pre_train:
            # Randomly initialise the decoder head first, THEN load the checkpoint,
            # so that nothing restored from disk is overwritten by random values.
            init.normal_(self.cls_emb, std=0.02)
            init.xavier_uniform_(self.proj_dec.weight)
            init.normal_(self.proj_classes, std=0.02)
            # NOTE(review): no std given here, so proj_patch is drawn with std=1.0 unlike
            # proj_classes; the later L2 normalisation hides the scale in the forward
            # pass - confirm this is intended.
            init.normal_(self.proj_patch)
            # NOTE(review): LayerNorm gains are drawn around 0 (std=0.02) rather than set
            # to 1 as init_weights does - confirm this is intended.
            self.decoder_norm.weight.data.normal_(mean=0.0, std=0.02)
            self.decoder_norm.bias.data.zero_()
            self.mask_norm.weight.data.normal_(mean=0.0, std=0.02)
            self.mask_norm.bias.data.zero_()
            self.load_pretrained_weights()
        else:
            self.apply(self.init_weights)

    # Init weights method
    @staticmethod
    def init_weights(module):
        """Kaiming-initialise Linear/Conv2d weights (zero bias) and reset LayerNorm to identity."""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            nn.init.kaiming_normal_(module.weight, mode='fan_in')
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.LayerNorm):
            nn.init.constant_(module.weight, 1)
            nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Compute mask scores for a batch of patch embeddings.

        Args:
            x: tensor of shape ``(batch, num_patches, dim_enc)``.

        Returns:
            Tensor of shape ``(batch, n_classes, grid_h, grid_w)`` where
            ``grid_h = image_height // patch_size`` and ``grid_w`` is inferred
            as ``num_patches // grid_h``.
        """
        grid_h = self.image_size[0] // self.patch_size

        # Project embeddings to the decoder width and broadcast the class tokens over the batch
        x = self.proj_dec(x)
        cls_emb = self.cls_emb.expand(x.size(0), -1, -1)

        # Append the class tokens after the patch tokens and run the transformer blocks
        x = torch.cat((x, cls_emb), 1)
        for blk in self.transformer_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # Split the sequence back into patch tokens and the trailing class tokens
        patches, cls_seg_feat = x[:, : -self.class_n], x[:, -self.class_n :]
        patches = patches @ self.proj_patch
        cls_seg_feat = cls_seg_feat @ self.proj_classes

        # L2-normalise both sets of features; F.normalize clamps the norm with a small
        # epsilon so an all-zero vector yields zeros instead of NaN.
        patches = F.normalize(patches, dim=-1)
        cls_seg_feat = F.normalize(cls_seg_feat, dim=-1)

        # Patch-level class scores: cosine similarity between every patch and every class token
        masks = patches @ cls_seg_feat.transpose(1, 2)

        # LayerNorm over the class axis is only meaningful with more than one class.
        # For a single class it normalises a size-1 axis, which always produces exactly
        # the bias: the mask would be constant and no gradient would reach the network.
        # NOTE(review): without it the single-class scores are raw cosine similarities
        # in [-1, 1]; add a learnable scale if the loss needs a wider logit range.
        if self.class_n > 1:
            masks = self.mask_norm(masks)

        # (batch, h*w, classes) -> (batch, classes, h, w); the grid width is inferred
        batch_size, _, n_cls = masks.shape
        masks = masks.transpose(1, 2).reshape(batch_size, n_cls, grid_h, -1)

        return masks

    # Load pre-trained weights method
    def load_pretrained_weights(self):
        """Load checkpoint weights into the model and freeze the transformer blocks that were loaded.

        Checkpoint keys listed in the mapping dict are renamed to this model's
        parameter names; any other checkpoint key is used only if this model has
        a state-dict entry with exactly the same name.
        """
        map_dict = self.generate_mapping_dict()
        # map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only machine.
        # NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
        pretrained = torch.load('../pre-trained/pretrained_mask_trans_2.pth', map_location='cpu')
        pretrained_weights = pretrained['model']
        model_state_dict = self.state_dict()

        # Create a new state dict with mapped keys, remembering which mapped
        # parameters the checkpoint actually provided.
        new_state_dict = {}
        loaded_block_names = []
        for key, tensor in pretrained_weights.items():
            if key in map_dict:
                new_state_dict[map_dict[key]] = tensor
                loaded_block_names.append(map_dict[key])
            elif key in model_state_dict:
                new_state_dict[key] = tensor

        # Load the mapped weights; strict=False leaves parameters without a checkpoint entry untouched
        self.load_state_dict(new_state_dict, strict= False)
        # Freeze only what was really loaded, so a randomly initialised layer is never frozen
        self.freeze_transformer_layers(loaded_block_names)
        print('Succesfully created Mask Transformer with pre-trained weights...!')
        print('Froze transformer layers for training..!')

    def freeze_transformer_layers(self, transformer_layers_dict):
        """Disable gradients for every parameter whose name is in ``transformer_layers_dict``.

        Args:
            transformer_layers_dict: iterable of parameter names as reported by
                ``named_parameters()``.
        """
        frozen_names = set(transformer_layers_dict)
        for name, param in self.named_parameters():
            if name in frozen_names:
                param.requires_grad = False

    # Generate the mapping dict renaming the pretrained weights layers names to the desired format
    def generate_mapping_dict(self):
        """Map checkpoint key names to this model's parameter names for every transformer block.

        Returns:
            dict from ``decoder.blocks.{i}.*`` checkpoint keys to the matching
            ``transformer_blocks.{i}.*`` keys, for ``i`` in ``range(self.depth)``.
        """
        # (checkpoint suffix, nn.TransformerEncoderLayer suffix)
        suffix_pairs = (
            ('norm1.bias', 'norm1.bias'),
            ('norm1.weight', 'norm1.weight'),
            ('norm2.bias', 'norm2.bias'),
            ('norm2.weight', 'norm2.weight'),
            ('mlp.fc1.bias', 'linear1.bias'),
            ('mlp.fc1.weight', 'linear1.weight'),
            ('mlp.fc2.bias', 'linear2.bias'),
            ('mlp.fc2.weight', 'linear2.weight'),
            ('attn.proj.bias', 'self_attn.out_proj.bias'),
            ('attn.proj.weight', 'self_attn.out_proj.weight'),
            ('attn.qkv.bias', 'self_attn.in_proj_bias'),
            ('attn.qkv.weight', 'self_attn.in_proj_weight'),
        )
        mapping = {}
        # Follow the configured depth instead of a hard-coded block count
        for i in range(self.depth):
            for source_suffix, target_suffix in suffix_pairs:
                mapping[f'decoder.blocks.{i}.{source_suffix}'] = f'transformer_blocks.{i}.{target_suffix}'
        return mapping

In [24]:
# Build the mask decoder with the pre-trained transformer blocks and report
# how many of its parameters remain trainable.
model = MaskTransformer(
    image_size=(640, 640),
    n_classes=1,
    dim_dec=768,
    depth=2,
    pre_train=True,
)
count_parameters(model)

Succesfully created Mask Transformer with pre-trained weights...!
Froze transformer layers for training..!


1772546

In [7]:
# ROOT DIRECTORIES
root_dir = os.path.dirname(os.getcwd())
annotated_dir = os.path.join(root_dir,'datasets/tusimple/train_set/annotations')
clips_dir = os.path.join(root_dir,'datasets/tusimple/train_set/')
# os.listdir returns entries in arbitrary order; sort them so the annotation order
# (and therefore any seeded subset / split built from it) is the same on every run.
annotated = sorted(os.listdir(annotated_dir))

# Collect the ground-truth records of every annotation file into one flat list.
# Each file holds one JSON document per line.
annotations = list()
for gt_file in annotated:
    path = os.path.join(annotated_dir, gt_file)
    # The context manager closes the file handle; blank lines are skipped so a
    # trailing newline cannot crash json.loads.
    with open(path) as gt_handle:
        annotations.extend(json.loads(line) for line in gt_handle if line.strip())

dataset = TuSimple(train_annotations = annotations, train_img_dir = clips_dir, resize_to = (640,640), subset_size = 0.002, val_size= 0.2)

# Create train and validation splits / Always use del dataset to free memory after this
train_set, validation_set = dataset.train_val_split()
del dataset


In [8]:
from vit import ViT
import utils

# Instantiate the ViT encoder with the pre-trained weights stored on disk.
encoder = ViT(
    image_size=640,
    patch_size=16,
    num_classes=1,
    dim=768,
    depth=12,
    heads=12,
    mlp_dim=3072,
    dropout=0.1,
    load_pre=True,
    pre_trained_path='../pre-trained/jx_vit_base_p16_224-80ecf9dd.pth',
)

Succesfully created ViT with pre-trained weights...!


In [9]:
# Take the first training sample: the image and its ground truth.
img, gt = train_set[0]

torch.Size([3, 640, 640])


In [10]:
# Add a leading batch axis to the single image, encode it, and show the embedding shape.
sample = encoder(torch.unsqueeze(img, 0))
sample.shape

torch.Size([1, 1600, 768])

In [11]:
# Decode the patch embeddings into a mask and show its shape.
# NOTE(review): execution counts are out of order - the cell that defines `model`
# is recorded as In [24] while this cell is In [11]; restart the kernel and run
# all cells top to bottom to confirm the recorded output.
mask = model(sample)
mask.shape

torch.Size([1, 1, 40, 40])