In [3]:
# Imports 
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
import cv2
import os
import json
import math
import torch.nn.init as init
from torchvision.models import ViT_B_16_Weights
from torch.utils.data import DataLoader
import matplotlib as plt
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torchmetrics import F1Score,JaccardIndex
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.optim as optim
import matplotlib.pyplot as plt
import torchvision.transforms as transforms

# Import TuSimple loader
import sys
sys.path.insert(0,'../resources/')
from tusimple import TuSimple
from mlp_decoder import DecoderMLP
import segnet_backbone as cnn
import utils

# Set seed for randomize functions (Ez reproduction of results)
random.seed(100)

In [4]:
# ROOT DIRECTORIES
root_dir = os.path.dirname(os.getcwd())
annotated_dir = os.path.join(root_dir,'datasets/tusimple/train_set/annotations')
clips_dir = os.path.join(root_dir,'datasets/tusimple/train_set/')
annotated_dir_test = os.path.join(root_dir,'datasets/tusimple/test_set/annotations/')
test_clips_dir = os.path.join(root_dir,'datasets/tusimple/test_set/')


annotated = os.listdir(annotated_dir)

test_annotated = os.listdir(annotated_dir_test)

In [5]:
# Get path directories for clips and annotations for the TUSimple training  dataset + ground truth dictionary
annotations = list()
for gt_file in annotated:
    path = os.path.join(annotated_dir,gt_file)
    json_gt = [json.loads(line) for line in open(path)]
    annotations.append(json_gt)
    
annotations = [a for f in annotations for a in f]

# Load dataset / Calculate pos weight
dataset = TuSimple(train_annotations = annotations, train_img_dir = clips_dir, resize_to = (448,448), subset_size = 0.001,val_size= 0.1)
train_set, validation_set = dataset.train_val_split()
del dataset

# Lane weight
pos_weight = utils.calculate_class_weight(train_set)

In [6]:
model = cnn.SegNet()
print(f'Number of trainable parameters : {model.count_parameters()}')
model.load_weights('../models/best_segnet.pth')

Number of trainable parameters : 29443587
Loaded state dict succesfully!


In [7]:
img_tns,gt = validation_set[0]

print(img_tns.unsqueeze(0).shape)

model.eval()
test_pred, test_features = model.predict(img_tns.unsqueeze(0))


torch.Size([1, 3, 448, 448])


In [8]:
test_features.shape

torch.Size([1, 64, 448, 448])

In [9]:
# Channel Attention module
class TransformerChannelAttention(nn.Module):
    def __init__(self, dim, heads, mlp_dim, depth, in_channels):
        super(TransformerChannelAttention, self).__init__()
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=mlp_dim),
            num_layers=depth
        )
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(in_channels, in_channels)
        
        # define the pooling layer to reduce spatial dimensions
        self.pool = nn.MaxPool2d(kernel_size=4, stride=4)
        # define the unpooling layer to recover original spatial dimensions
        self.unpool = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)
        
    def forward(self, x):
        batch_size, num_channels, height, width = x.shape
        
        # Pass through the pooling layer to reduce spatial dimensions for computational reasons
        x = self.pool(x)
        height = height//4
        width = width//4

        
        # Reformat the inputs to the appropriate format for the transformer
        x = x.reshape(height*width, batch_size, num_channels)
        x = x.permute(1, 0, 2)   # shape: (batch_size, seq_length, num_channels)

        # Perform channel wise attention to the reduced feature maps
        x = self.transformer(x)  # shape: (batch_size, seq_length, num_channels)
        x = x.permute(1, 0, 2)  # swap back seq_length and batch_size dimensions
        x = x.reshape(batch_size, num_channels, height, width)  # shape: (batch_size, num_channels, height, width)
        # Reshape back to original spatial dimensions
        x = self.unpool(x)
        height,width = height * 4, width*4
        
        # Average pooling to get one value per channel and calculate channel attention weights 
        attn = self.avg_pool(x)  # shape: (batch_size, num_channels, 1, 1)
        attn = attn.view(batch_size, num_channels)  # shape: (batch_size, num_channels)
        attn = self.fc(attn)  # shape: (batch_size, num_channels)
        attn = attn.unsqueeze(-1).unsqueeze(-1)  # shape: (batch_size, num_channels, 1, 1)
        
        # Weight the transformed features based on channel attention and keep best features
        trans_features = (x * attn).sum(dim=1, keepdim=True)  # shape: (batch_size, 1, height, width)

        # Pass through sigmoid to get attention-weighted probabilities
        mask_probs = torch.sigmoid(trans_features)
        return trans_features, mask_probs
    
    def count_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

In [10]:
ca_module = TransformerChannelAttention(dim = 64 ,heads=8,mlp_dim=2048,depth=3,in_channels=64)

In [11]:
logits,probs = ca_module(test_features)

tensor([[[[-0.0210]],

         [[ 0.0364]],

         [[-0.0477]],

         [[-0.0623]],

         [[ 0.0154]],

         [[ 0.1016]],

         [[-0.0029]],

         [[-0.0082]],

         [[ 0.1099]],

         [[ 0.0209]],

         [[ 0.0811]],

         [[ 0.0278]],

         [[-0.0859]],

         [[ 0.1085]],

         [[ 0.0083]],

         [[ 0.0291]],

         [[-0.0712]],

         [[ 0.0882]],

         [[ 0.1091]],

         [[-0.0731]],

         [[ 0.0474]],

         [[ 0.1186]],

         [[-0.0240]],

         [[ 0.0022]],

         [[ 0.0183]],

         [[ 0.0715]],

         [[-0.0595]],

         [[-0.1111]],

         [[-0.1076]],

         [[ 0.0951]],

         [[-0.0507]],

         [[-0.0190]],

         [[ 0.0089]],

         [[ 0.0918]],

         [[ 0.0891]],

         [[ 0.0026]],

         [[ 0.0657]],

         [[-0.0582]],

         [[ 0.0768]],

         [[ 0.1118]],

         [[-0.0830]],

         [[ 0.0653]],

         [[-0.0554]],

         [[

In [22]:
probs

tensor([[[[1.0000e+00, 1.0000e+00, 1.0000e+00,  ..., 2.6450e-12,
           8.7872e-13, 8.7872e-13],
          [1.0000e+00, 1.0000e+00, 1.0000e+00,  ..., 2.6450e-12,
           8.7872e-13, 8.7872e-13],
          [1.0000e+00, 1.0000e+00, 1.0000e+00,  ..., 1.5414e-10,
           7.2772e-11, 7.2772e-11],
          ...,
          [9.9799e-01, 9.9799e-01, 9.6293e-01,  ..., 9.9999e-01,
           1.0000e+00, 1.0000e+00],
          [6.3819e-01, 6.3819e-01, 1.9970e-01,  ..., 9.9999e-01,
           1.0000e+00, 1.0000e+00],
          [6.3819e-01, 6.3819e-01, 1.9970e-01,  ..., 9.9999e-01,
           1.0000e+00, 1.0000e+00]]]], grad_fn=<SigmoidBackward0>)

In [11]:
class attention_pipe(nn.Module):
    def __init__(self, feat_extractor, ca_module, image_size = (448,448)):
        super().__init__()
        self.cnn = feat_extractor
        self.ca_transformer = ca_module
        self.image_size = image_size
        self.lane_threshold = 0.5
        self.activation = nn.Sigmoid()
        
    # Forward pass of the pipeline
    def forward(self, im):
        H, W = self.image_size
        
        # CNN branch for feature extraction
        x,_ = self.cnn(im)
        
        # Transform standardized feature maps using the ViT
        logits, probs = self.ca_transformer(x)
        
        return logits, probs
        
        
    # Count pipeline trainable parameters
    def count_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
    
    # Load trained model
    def load_weights(self,path): 
        self.load_state_dict(torch.load(path,map_location=torch.device('cpu')))

In [12]:
pipe = attention_pipe(model,ca_module)

In [13]:
pipe(img_tns.unsqueeze(0))

(tensor([[[[-1.0384, -1.0384, -0.9430,  ...,  0.0671,  0.1409,  0.1409],
           [-1.0384, -1.0384, -0.9430,  ...,  0.0671,  0.1409,  0.1409],
           [-0.8334, -0.8334, -0.7689,  ...,  0.2270,  0.3021,  0.3021],
           ...,
           [ 0.2861,  0.2861,  0.2942,  ..., -0.6637, -0.7097, -0.7097],
           [ 0.3824,  0.3824,  0.3845,  ..., -0.8456, -0.9243, -0.9243],
           [ 0.3824,  0.3824,  0.3845,  ..., -0.8456, -0.9243, -0.9243]]]],
        grad_fn=<SumBackward1>),
 tensor([[[[0.2615, 0.2615, 0.2803,  ..., 0.5168, 0.5352, 0.5352],
           [0.2615, 0.2615, 0.2803,  ..., 0.5168, 0.5352, 0.5352],
           [0.3029, 0.3029, 0.3167,  ..., 0.5565, 0.5750, 0.5750],
           ...,
           [0.5710, 0.5710, 0.5730,  ..., 0.3399, 0.3297, 0.3297],
           [0.5945, 0.5945, 0.5950,  ..., 0.3004, 0.2841, 0.2841],
           [0.5945, 0.5945, 0.5950,  ..., 0.3004, 0.2841, 0.2841]]]],
        grad_fn=<SigmoidBackward0>))