In [1]:
# Imports 
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
import cv2
import os
import json
import math
import torch.nn.init as init
import torch.optim as optim
from torch.utils.data import DataLoader

# Set seed for randomize functions (Ez reproduction of results)
random.seed(100)

# Import TuSimple loader
import sys
sys.path.insert(0,'../resources/')
from tusimple import TuSimple
from mask_transformer import MaskTransformer
from vit import ViT
import utils

  warn(f"Failed to load image Python extension: {e}")


In [2]:
# ROOT DIRECTORIES
root_dir = os.path.dirname(os.getcwd())
annotated_dir = os.path.join(root_dir,'datasets/tusimple/train_set/annotations')
clips_dir = os.path.join(root_dir,'datasets/tusimple/train_set/')
annotated = os.listdir(annotated_dir)
    
# Get path directories for clips and annotations for the TUSimple dataset + ground truth dictionary
annotations = list()
for gt_file in annotated:
    path = os.path.join(annotated_dir,gt_file)
    json_gt = [json.loads(line) for line in open(path)]
    annotations.append(json_gt)
    
annotations = [a for f in annotations for a in f]

In [9]:
dataset = TuSimple(train_annotations = annotations, train_img_dir = clips_dir, resize_to = (640,640), subset_size = 0.01,val_size= 0.1)

# Create train and validation splits / Always use del dataset to free memory after this
train_set, validation_set = dataset.train_val_split()
del dataset

In [10]:
# Custom training function for the transformer pipeline with schedule and SGD optimizer
def train(model, train_loader, num_epochs=10, lr=0.01, momentum=0.9, weight_decay=1e-4, lr_scheduler=True):
    # Set up loss function and optimizer
    criterion =  nn.BCEWithLogitsLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)

    # Set up learning rate scheduler
    if lr_scheduler:
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

    # Set up device (GPU or CPU)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # Train the model
    for epoch in range(num_epochs):
        # Train for one epoch
        model.train()
        train_loss = 0
        train_acc = 0
        train_iou = 0
        
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            
            inputs, targets = inputs.to(device), targets.to(device)
                   
            optimizer.zero_grad()
            outputs = model(inputs)
            outputs.requires_grad = True
            
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            
            
            train_loss += loss.item() * inputs.size(0)
            train_acc += utils.accuracy(pred = outputs.detach(), target = targets)
            train_iou += utils.mean_iou (pred= outputs.detach(), target = targets)
            
        train_loss /= len(train_loader)
        train_acc /= len(train_loader)
        train_iou /= len(train_loader)
        
     # Print progress
        if lr_scheduler:
            print('Epoch: {} - Train Loss: {:.4f} - Train_Acc: {:.3f} - Train_mIoU: {:.3f} - Learning Rate: {:.6f}'.format(epoch+1, train_loss, train_acc, train_iou, scheduler.get_last_lr()[0]))
            scheduler.step()
        else:
            print('Epoch: {} - Train Loss: {:.4f}'.format(epoch+1, train_loss))

In [11]:
# Segmenter pipeline class (ViT + Masks transformer end-to-end)
class Segmenter(nn.Module):
    def __init__(self,encoder, mask_trans, image_size = (640,640)):
        super().__init__()
        self.patch_size = encoder.patch_size
        self.encoder = encoder
        self.decoder = mask_trans
        self.image_size = image_size
        
    # Forward pass of the pipeline
    def forward(self, im):
        H, W = self.image_size
        
        # Pass through the pre-trained vit backbone
        x = self.encoder(im, return_features=True)
        
        # Pass through the masks transformer
        masks = self.decoder(x)

        # Interpolate patch level class annotatations to pixel level and transform to original image size
        masks = F.interpolate(masks, size=(H, W), mode="bilinear")
        predicted_masks = torch.argmax(masks, dim=1).float()
        
        # expand the output mask tensor along the channel dimension to match the ground truth tensor (maybe needs removal)
        predicted_masks = predicted_masks.unsqueeze(0).expand(3, -1, -1, -1).transpose(0, 1).squeeze(0)

        return predicted_masks
    
    # Count pipeline trainable parameters
    def count_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
    


In [12]:
# train_loader = DataLoader(train_set, batch_size=2, shuffle= True) 
# for batch_idx, (inputs, targets) in enumerate(train_loader):
#     print(targets.shape)
#     print(len(targets))
#     print(len(inputs))
#     print(inputs.shape)
#     print(utils.accuracy(targets,targets))
#     print(utils.mean_iou(targets,targets))
#     break

torch.Size([2, 3, 640, 640])
2
2
torch.Size([2, 3, 640, 640])
1.0
1.0


In [13]:
train_loader = DataLoader(train_set, batch_size=2, shuffle= True) 

encoder = ViT(image_size=640, patch_size=16, num_classes=2, dim=768, depth=12, heads=12, 
            mlp_dim=3072, dropout=0.1,load_pre= True, pre_trained_path= '../pre-trained/jx_vit_base_p16_224-80ecf9dd.pth')
encoder.freeze_all_but_some(['pos_embedding','norm.weight','norm.bias'])
decoder = MaskTransformer()
model = Segmenter(encoder, decoder)
print(model.count_parameters())


train(model, train_loader, 1)


Succesfully created ViT with pre-trained weights...!
14767108
Epoch: 1 - Train Loss: 1.7430 - Train_Acc: 0.712 - Train_mIoU: 0.018 - Learning Rate: 0.010000


In [22]:
model.eval()
img_tens, gt = validation_set[0]
test = model(img_tens)

In [23]:
test.shape

torch.Size([3, 640, 640])

In [24]:
predicted_mask = utils.toImagearr(test)
base_img = utils.toImagearr(img_tens)
utils.disp_img(image = base_img, name = 'Original Image')
utils.disp_img(image = predicted_mask, name = 'Predicted Mask')

In [21]:
test.unique()

tensor([0.])