In [7]:
# Imports 
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
import cv2
import os
import json
import math
import torch.nn.init as init
import torch.optim as optim
from torch.utils.data import DataLoader
from torchmetrics import F1Score,JaccardIndex
from torch_poly_lr_decay import PolynomialLRDecay
import matplotlib.pyplot as plt
import torchvision.transforms as transforms

# Import TuSimple loader
import sys
sys.path.insert(0,'../resources/')
from tusimple import TuSimple
from vit import ViT
from mlp_decoder import DecoderMLP
from segnet_backbone import SegNet

In [2]:
# End-to-End pipeline (CNN + ViT + MLP)
class Pipeline(nn.Module):
    """End-to-end lane segmentation model: CNN feature extractor -> ViT -> patch-level MLP head.

    The CNN's feature maps are standardized per channel, transformed by the ViT, classified by
    the MLP head, and the resulting logits are bilinearly resized to ``image_size``.

    Args:
        feat_extractor: backbone module; its forward must return a 2-tuple whose first item is
            the feature map used here (the second item is discarded).
        vit_variant: transformer module applied to the standardized feature maps.
        mlp_head: head module applied to the transformer output. Its output must be a 4-D
            ``(N, C, h, w)`` tensor, since it is resized with ``F.interpolate(mode="bilinear")``.
        image_size: ``(H, W)`` spatial size of the returned logits.
    """

    def __init__(self, feat_extractor, vit_variant, mlp_head, image_size = (448,448)):
        super().__init__()
        self.cnn = feat_extractor
        self.transformer = vit_variant
        self.mlp = mlp_head
        self.image_size = image_size
        # Not used inside this class (forward returns raw logits). Presumably kept for callers that
        # turn logits into a binary lane mask -- NOTE(review): confirm usage elsewhere.
        self.lane_threshold = 0.5
        self.activation = nn.Sigmoid()

    def forward(self, im: torch.Tensor) -> torch.Tensor:
        """Run the full pipeline on ``im`` and return logits (no sigmoid applied) of size ``image_size``."""
        H, W = self.image_size

        # CNN branch for feature extraction; the backbone's second output is not needed here
        x, _ = self.cnn(im)

        # Standardize feature maps to prevent exploding gradients and help the ViT perform better
        x = self.standarize_layer(x)

        # Transform standardized feature maps using the ViT
        x = self.transformer(x)

        # Perform patch level classification (0 for background / 1 for lane)
        x = self.mlp(x)

        # Interpolate patch-level class logits to pixel level at the target image size.
        # align_corners=False is PyTorch's default; stated explicitly for clarity.
        logits = F.interpolate(x, size=(H, W), mode="bilinear", align_corners=False)

        return logits

    def standarize_layer(self, featmaps: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        """Standardize each channel of ``featmaps`` to zero mean and unit variance.

        Statistics are computed jointly over dims 0, 2 and 3 (batch and both spatial dims), so
        the result for one sample depends on the other samples in the same batch.

        Args:
            featmaps: 4-D tensor; dim 1 is treated as the channel dim.
            eps: lower bound for the per-channel std, so that constant channels (std == 0)
                map to zeros instead of NaN.

        Returns:
            Tensor with the same shape as ``featmaps``.
        """
        mean = torch.mean(featmaps, dim=[0, 2, 3], keepdim=True)
        std = torch.std(featmaps, dim=[0, 2, 3], keepdim=True)

        # Clamp instead of adding eps: channels with std > eps are normalized exactly as before,
        # only degenerate (near-)constant channels are affected.
        return (featmaps - mean) / std.clamp(min=eps)

    def count_parameters(self) -> int:
        """Return the number of trainable (``requires_grad``) parameters in the whole pipeline."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def load_weights(self, path) -> None:
        """Load a saved ``state_dict`` from ``path`` into this pipeline (tensors are mapped to CPU first)."""
        # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints
        # (on recent PyTorch versions consider weights_only=True).
        self.load_state_dict(torch.load(path, map_location=torch.device('cpu')))

In [5]:
# Random 3-channel stand-in image (batch of 1, 448x448) used only to smoke-test the forward pass
test_img = torch.rand((1, 3, 448, 448))
test_img.shape

torch.Size([1, 3, 448, 448])

In [8]:
# Initialize SegNet
cnn = SegNet()
print(f'Number of trainable parameters for SegNet : {cnn.count_parameters()}')

# Initialize ViT Tiny
vit_tiny = ViT(image_size=448, patch_size=16, num_classes=1, dim=192, depth=6, heads=3, 
                      mlp_dim=768, dropout=0.1,load_pre= False)
print(f'Number of trainable parameters for ViT : {vit_tiny.count_parameters()}')

# Initialize MLP
patch_classifier = DecoderMLP(n_classes = 1, d_encoder = 192, image_size=(448,448))
print(f'Number of trainable parameters for MLP : {patch_classifier.count_parameters()}')

Number of trainable parameters for SegNet : 29443587
Number of trainable parameters for ViT : 2869440
Number of trainable parameters for MLP : 233537


In [9]:
# Assemble the end-to-end model from the three components built above
model = Pipeline(
    feat_extractor=cnn,
    vit_variant=vit_tiny,
    mlp_head=patch_classifier,
    image_size=(448, 448),
)
print(f'Number of trainable parameters for Pipeline : {model.count_parameters()}')

Number of trainable parameters for Pipeline : 32546564


In [10]:
# Smoke-test the forward pass on the dummy input. no_grad avoids building an autograd graph
# (and retaining every intermediate activation) for a pass whose result is only inspected.
with torch.no_grad():
    test = model(test_img)

# Logits must come back at the pipeline's target resolution, one map per input image
assert tuple(test.shape[-2:]) == tuple(model.image_size), test.shape
assert test.shape[0] == test_img.shape[0], test.shape
test.shape



torch.Size([1, 1, 448, 448])