In [3]:
# Imports 
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
import cv2
import os
import json
import math
import torch.nn.init as init
from torchvision.models import ViT_B_16_Weights
from torch.utils.data import DataLoader
import matplotlib as plt
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torchmetrics import F1Score,JaccardIndex
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.optim as optim
import matplotlib.pyplot as plt
import torchvision.transforms as transforms

# Import TuSimple loader
import sys
sys.path.insert(0,'../resources/')
from tusimple import TuSimple
from mlp_decoder import DecoderMLP
import segnet_backbone as cnn
import utils

# Seed every random number generator this notebook imports (PyTorch, NumPy and
# Python's `random`) so results are easy to reproduce after a kernel restart.
# Seeding only `random` leaves torch / numpy draws non-deterministic.
SEED = 100
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

In [4]:
# ROOT DIRECTORIES
# The project root is taken to be the parent of the current working directory;
# every dataset path below is derived from it.
root_dir = os.path.dirname(os.getcwd())
annotated_dir = os.path.join(root_dir,'datasets/tusimple/train_set/annotations')
clips_dir = os.path.join(root_dir,'datasets/tusimple/train_set/')
annotated_dir_test = os.path.join(root_dir,'datasets/tusimple/test_set/annotations/')
test_clips_dir = os.path.join(root_dir,'datasets/tusimple/test_set/')

# os.listdir returns entries in arbitrary order, so sort the listings to make
# the order in which annotation files are processed reproducible across runs
# and machines.
annotated = sorted(os.listdir(annotated_dir))
test_annotated = sorted(os.listdir(annotated_dir_test))

In [5]:
# Build the ground-truth list for the TuSimple training set. Every annotation
# file is parsed line by line, each non-empty line being one JSON document.
annotations = list()
for gt_file in annotated:
    path = os.path.join(annotated_dir,gt_file)
    # The context manager closes the file handle as soon as it has been parsed
    # (a bare open() inside a comprehension leaks the handle).
    with open(path) as gt_handle:
        # Blank lines (e.g. a trailing newline) are skipped instead of crashing json.loads.
        json_gt = [json.loads(line) for line in gt_handle if line.strip()]
    # Extend in place so `annotations` stays a flat list of records.
    annotations.extend(json_gt)

# Load dataset / Calculate pos weight
# NOTE(review): subset_size / val_size look like fractions of the data — confirm in TuSimple.
dataset = TuSimple(train_annotations = annotations, train_img_dir = clips_dir, resize_to = (448,448), subset_size = 0.001,val_size= 0.1)
train_set, validation_set = dataset.train_val_split()
# Only the two splits are used from here on; drop the reference to the wrapper object.
del dataset

# Lane weight
pos_weight = utils.calculate_class_weight(train_set)

In [6]:
# Build the SegNet backbone and report its size before restoring the checkpoint.
model = cnn.SegNet()
n_trainable = model.count_parameters()
print(f'Number of trainable parameters : {n_trainable}')

# Restore the saved checkpoint into the freshly built network.
model.load_weights('../models/best_segnet.pth')

Number of trainable parameters : 29443587
Loaded state dict succesfully!


In [7]:
# Switch the network to evaluation mode before running it on a sample.
model.eval()

# Take the first validation sample and add a leading batch axis for predict().
img_tns, gt = validation_set[0]
single_batch = img_tns.unsqueeze(0)
test_pred, test_features = model.predict(single_batch)


In [8]:
# Inspect the shape of the feature tensor returned by model.predict.
test_features.shape

torch.Size([1, 64, 448, 448])

In [37]:
# Channel Attention module
class TransformerChannelAttention(nn.Module):
    """Transformer-based attention head that collapses a feature map to one channel.

    The input feature map is max-pooled, every pooled spatial location is fed to a
    transformer encoder as one token whose embedding is that location's channel
    vector, and the result is resized back to the input resolution. A per-channel
    weight, computed by a linear layer from the globally average-pooled features,
    is then used to reduce the channels to a single map.

    Args:
        dim: ``d_model`` of the transformer encoder. Tokens are channel vectors,
            so this must equal the number of channels of the input.
        heads: Number of attention heads (``nhead``).
        mlp_dim: Hidden size of the encoder feed-forward block (``dim_feedforward``).
        depth: Number of stacked encoder layers.
        in_channels: Number of channels of the input; size of the linear layer
            that produces the channel weights.
        pool_factor: Kernel size and stride of the max-pool applied before the
            transformer. Defaults to 4.
    """

    def __init__(self, dim, heads, mlp_dim, depth, in_channels, pool_factor=4):
        super(TransformerChannelAttention, self).__init__()
        self.pool_factor = pool_factor
        # batch_first=True: forward() feeds tokens as (batch, seq, channels). With
        # the default (False) the encoder would treat the batch axis as the
        # sequence and attend across samples instead of across locations.
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=dim,
                nhead=heads,
                dim_feedforward=mlp_dim,
                batch_first=True,
            ),
            num_layers=depth
        )
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(in_channels, in_channels)

        # Pooling layer that reduces the spatial dimensions (shorter token sequence).
        self.pool = nn.MaxPool2d(kernel_size=pool_factor, stride=pool_factor)

    def forward(self, x):
        """Run the attention head.

        Args:
            x: Tensor of shape (batch_size, num_channels, height, width).

        Returns:
            Tuple ``(trans_features, mask_probs)``, both of shape
            (batch_size, 1, height, width); ``mask_probs`` is the sigmoid of
            ``trans_features``.
        """
        batch_size, num_channels, height, width = x.shape

        # Reduce the spatial dimensions for computational reasons. The pooled size
        # is read from the tensor instead of being recomputed by hand.
        x = self.pool(x)
        pooled_height, pooled_width = x.shape[-2:]

        # One token per pooled location: (B, C, h, w) -> (B, C, h*w) -> (B, h*w, C).
        # A plain .view() to the target shape would only reinterpret memory and
        # mix channel and spatial values, so the axes are moved explicitly.
        tokens = x.flatten(2).transpose(1, 2)
        tokens = self.transformer(tokens)  # shape: (batch_size, seq_length, num_channels)

        # Undo the token layout: (B, h*w, C) -> (B, C, h*w) -> (B, C, h, w).
        # reshape() copes with the non-contiguous result of transpose().
        x = tokens.transpose(1, 2).reshape(batch_size, num_channels, pooled_height, pooled_width)

        # Resize back to the exact input resolution, which also covers heights and
        # widths that are not a multiple of the pooling factor.
        x = F.interpolate(x, size=(height, width), mode='bilinear', align_corners=False)

        # Average pooling gives one value per channel; the linear layer turns
        # these into channel attention weights.
        attn = self.avg_pool(x)  # shape: (batch_size, num_channels, 1, 1)
        attn = attn.view(batch_size, num_channels)  # shape: (batch_size, num_channels)
        attn = self.fc(attn)  # shape: (batch_size, num_channels)
        attn = attn.unsqueeze(-1).unsqueeze(-1)  # shape: (batch_size, num_channels, 1, 1)

        # Weight the transformed features by channel attention and sum over channels.
        trans_features = (x * attn).sum(dim=1, keepdim=True)  # shape: (batch_size, 1, height, width)

        # Sigmoid turns the attention-weighted map into probabilities.
        mask_probs = torch.sigmoid(trans_features)
        return trans_features, mask_probs

In [38]:
# Instantiate the channel-attention head; dim and in_channels are both 64.
ca_module = TransformerChannelAttention(
    dim=64,
    heads=8,
    mlp_dim=2048,
    depth=3,
    in_channels=64,
)

In [39]:
# The module's forward returns a tuple: (trans_features, mask_probs).
test = ca_module(test_features)

In [40]:
# ca_module returns the tuple (trans_features, mask_probs); a tuple has no
# .unique(), so inspect the sigmoid probabilities (second element).
test[1].unique()

tensor([0.1033, 0.1091, 0.1092,  ..., 0.8657, 0.8695, 0.8696],
       grad_fn=<Unique2Backward0>)