In [1]:
# Imports 
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
import cv2
import os
import json
import math
import torch.nn.init as init

# Seed every RNG the notebook uses (Python, NumPy, PyTorch) so results are reproducible.
# random.seed alone does not cover torch: the weight initialisation of freshly built
# modules (e.g. MaskTransformer()) and dropout masks are drawn from torch's generator.
SEED = 100
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Import TuSimple loader
import sys
sys.path.insert(0,'../resources/')
from tusimple import TuSimple
from mask_transformer import MaskTransformer
from vit import ViT

  warn(f"Failed to load image Python extension: {e}")


In [2]:
# ROOT DIRECTORIES
# NOTE(review): assumes the notebook runs one directory below the project root,
# since the root is taken as the parent of the current working directory.
root_dir = os.path.dirname(os.getcwd())
annotated_dir = os.path.join(root_dir, 'datasets/tusimple/train_set/annotations')
clips_dir = os.path.join(root_dir, 'datasets/tusimple/train_set/')

# os.listdir returns entries in arbitrary, filesystem-dependent order. Sort them so the
# flattened annotation list (and any seeded subset sampling done on it later) is
# reproducible across machines.
annotated = sorted(os.listdir(annotated_dir))

# Load the ground truth: each annotation file holds one JSON object per line.
# All records from all files are collected into a single flat list.
# - `with` closes every file handle deterministically (the original leaked them).
# - Blank lines are skipped instead of crashing json.loads.
annotations = []
for gt_file in annotated:
    path = os.path.join(annotated_dir, gt_file)
    with open(path, encoding='utf-8') as fh:
        annotations.extend(json.loads(line) for line in fh if line.strip())

In [3]:
# Wrap the flattened annotations in the TuSimple dataset, resizing frames to 640x640.
# NOTE(review): subset_size / val_size look like fractions (use 2% of the data, hold out
# 10% of it for validation) -- confirm against tusimple.py.
dataset = TuSimple(
    train_annotations=annotations,
    train_img_dir=clips_dir,
    resize_to=(640, 640),
    subset_size=0.02,
    val_size=0.1,
)

# Produce the train / validation splits, then drop our reference to the full dataset
# object so its memory can be reclaimed (always do this after splitting).
train_set, validation_set = dataset.train_val_split()
del dataset

In [6]:
# Segmenter pipeline class (ViT + Masks transformer end-to-end)
class Segmenter(nn.Module):
    """End-to-end segmentation model: ViT encoder followed by a mask-transformer decoder.

    Args:
        encoder: backbone module. It must expose ``patch_size`` and accept
            ``return_features=True`` in its forward call.
        mask_trans: decoder mapping the encoder features to per-class mask scores.
            It must return a 4-D tensor (batch, n_classes, h, w), which the 2-D
            bilinear interpolation in ``forward`` requires.
        image_size: (H, W) that the masks are upsampled to. Pass ``None`` to use the
            spatial size of each input image instead.
        out_channels: how many times the predicted label map is replicated along the
            channel axis. The default of 3 matches the ground-truth tensor's channel
            count, as in the original design.
    """

    def __init__(self, encoder, mask_trans, image_size=(640, 640), out_channels=3):
        super().__init__()
        self.patch_size = encoder.patch_size
        self.encoder = encoder
        self.decoder = mask_trans
        self.image_size = image_size
        self.out_channels = out_channels

    # Forward pass of the pipeline
    def forward(self, im):
        """Predict a per-pixel class-index map for ``im``.

        Returns:
            Float tensor of class indices with shape (B, out_channels, H, W). When
            B == 1 the batch dim is squeezed away, giving (out_channels, H, W).
            The output comes from argmax, which is not differentiable, so it cannot
            drive a training loss. Compute losses on the interpolated scores instead.
        """
        if self.image_size is None:
            H, W = im.shape[-2:]
        else:
            H, W = self.image_size

        # Backbone features, then per-class patch-level mask scores.
        features = self.encoder(im, return_features=True)
        masks = self.decoder(features)

        # Upsample the patch-level class scores to pixel resolution, then take the
        # most likely class per pixel: (B, n_classes, H, W) -> (B, H, W).
        masks = F.interpolate(masks, size=(H, W), mode="bilinear")
        predicted_masks = torch.argmax(masks, dim=1).float()

        # Replicate the label map along a new channel axis: (B, H, W) -> (B, C, H, W).
        # expand() returns a view without copying memory; call .contiguous() before
        # any in-place write to the result.
        predicted_masks = predicted_masks.unsqueeze(1).expand(-1, self.out_channels, -1, -1)

        # Kept for backward compatibility: a single-image batch comes back as (C, H, W).
        return predicted_masks.squeeze(0)

In [7]:
# Build the end-to-end model: a ViT encoder initialised from the checkpoint at
# pre_trained_path, followed by the mask-transformer decoder.
encoder = ViT(image_size=640, patch_size=16, num_classes=10, dim=768, depth=12, heads=12,
              mlp_dim=3072, dropout=0.1, load_pre=True,
              pre_trained_path='../pre-trained/jx_vit_base_p16_224-80ecf9dd.pth')
decoder = MaskTransformer()
segm = Segmenter(encoder, decoder)

# Smoke-test one forward pass on a training image. This is inference only, so autograd
# is disabled: no activation graph is recorded, which keeps peak memory down.
# NOTE(review): segm is still in train mode, so dropout (0.1) is active and the output is
# stochastic. Call segm.eval() first if a deterministic check is wanted, and switch back
# with segm.train() before training.
with torch.no_grad():
    test = segm(train_set[0][0])

Succesfully created ViT with pre-trained weights...!


In [8]:
# Inspect the output dimensions of the smoke-test prediction.
test.size()

torch.Size([3, 640, 640])