In [20]:
# Imports 
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
import cv2
import os
import json
import math
import torch.nn.init as init
import torch.optim as optim
from torch.utils.data import DataLoader

# Set seed for randomize functions (Ez reproduction of results)
random.seed(100)

# Import TuSimple loader
import sys
sys.path.insert(0,'../resources/')
from tusimple import TuSimple
from mask_transformer import MaskTransformer
from vit import ViT
import utils

In [2]:
# ROOT DIRECTORIES
root_dir = os.path.dirname(os.getcwd())
annotated_dir = os.path.join(root_dir,'datasets/tusimple/train_set/annotations')
clips_dir = os.path.join(root_dir,'datasets/tusimple/train_set/')
annotated = os.listdir(annotated_dir)
    
# Get path directories for clips and annotations for the TUSimple dataset + ground truth dictionary
annotations = list()
for gt_file in annotated:
    path = os.path.join(annotated_dir,gt_file)
    json_gt = [json.loads(line) for line in open(path)]
    annotations.append(json_gt)
    
annotations = [a for f in annotations for a in f]

In [3]:
dataset = TuSimple(train_annotations = annotations, train_img_dir = clips_dir, resize_to = (640,640), subset_size = 0.01,val_size= 0.1)

# Create train and validation splits / Always use del dataset to free memory after this
train_set, validation_set = dataset.train_val_split()
del dataset

In [4]:
# Segmenter pipeline class (ViT + Masks transformer end-to-end)
class Segmenter(nn.Module):
    def __init__(self,encoder, mask_trans, image_size = (640,640)):
        super().__init__()
        self.patch_size = encoder.patch_size
        self.encoder = encoder
        self.decoder = mask_trans
        self.image_size = image_size
        
    # Forward pass of the pipeline
    def forward(self, im):
        H, W = self.image_size
        
        # Pass through the pre-trained vit backbone
        x = self.encoder(im, return_features=True)
        
        # Pass through the masks transformer
        masks = self.decoder(x)

        # Interpolate patch level class annotatations to pixel level and transform to original image size
        masks = F.interpolate(masks, size=(H, W), mode="bilinear")
        predicted_masks = torch.argmax(masks, dim=1).float()
        
        # expand the output mask tensor along the channel dimension to match the ground truth tensor (maybe needs removal)
        predicted_masks = predicted_masks.unsqueeze(0).expand(3, -1, -1, -1).transpose(0, 1).squeeze(0)

        return predicted_masks

In [5]:
train_loader = DataLoader(train_set, batch_size=2, shuffle= True)

encoder = ViT(image_size=640, patch_size=16, num_classes=2, dim=768, depth=12, heads=12, 
            mlp_dim=3072, dropout=0.1,load_pre= True, pre_trained_path= '../pre-trained/jx_vit_base_p16_224-80ecf9dd.pth')
decoder = MaskTransformer()


Succesfully created ViT with pre-trained weights...!


In [6]:
model = Segmenter(encoder,decoder)

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate the Segmenter Transformer model and move it to the device
model = Segmenter(encoder, decoder).to(device)

# Define the loss function and the optimizer
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Loop through the training dataset and perform the training
for epoch in range(1):
    running_loss = 0.0
    for i, (inputs, targets) in enumerate(train_loader):
        # Move the input images and the target masks to the same device as the model
        inputs, targets = inputs.to(device), targets["gt_tensor"].to(device)
        
        # Zero the gradients of the optimizer
        optimizer.zero_grad()
        
        # Forward pass the input images through the model to get the predictions
        outputs = model(inputs)
        outputs.requires_grad = True
        
        # Compute the loss between the predictions and the target masks
        loss = criterion(outputs, targets)

        # Backward propagate the loss through the model to compute the gradients
        loss.backward()

        # Update the model parameters using the optimizer
        optimizer.step()
        
        # Update the running loss
        running_loss += loss.item() * inputs.size(0)
    
    #Print the average loss for the epoch
    epoch_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{10} - Loss: {epoch_loss:.4f}")

Epoch 1/10 - Loss: 1.9964


In [8]:
model.eval()
test = model(train_set[0][0])

In [10]:
test.shape

torch.Size([3, 640, 640])

In [None]:
predicted_mask = utils.toImagearr(test)
base_img = utils.toImagearr(train_set[0][0])
utils.disp_img(image = base_img, name = 'Original Image')
utils.disp_img(image = predicted_mask, name = 'Predicted Mask')