In [2]:
# Imports 
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
import cv2
import os
import json

# Seed every RNG this notebook relies on so results are easy to reproduce.
# Seeding only Python's `random` leaves NumPy and PyTorch (weight init, dropout) unseeded.
SEED = 100
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Import TuSimple loader
import sys
sys.path.insert(0,'../resources/')
from tusimple import TuSimple

  warn(f"Failed to load image Python extension: {e}")


In [3]:
# ROOT DIRECTORIES
root_dir = os.path.dirname(os.getcwd())
annotated_dir = os.path.join(root_dir, 'datasets/tusimple/train_set/annotations')
clips_dir = os.path.join(root_dir, 'datasets/tusimple/train_set/')

# os.listdir returns entries in arbitrary order; sort them so the annotation order
# (and therefore any seeded subset drawn from it) is the same on every run and machine.
annotated = sorted(os.listdir(annotated_dir))

# Build the flat list of ground-truth records for the TuSimple dataset.
# Every annotation file is read line by line, one JSON document per line.
annotations = list()
for gt_file in annotated:
    path = os.path.join(annotated_dir, gt_file)
    # The context manager closes each file as soon as it has been parsed.
    with open(path) as gt_handle:
        # Skip blank lines (e.g. a trailing newline) instead of crashing in json.loads.
        annotations.extend(json.loads(line) for line in gt_handle if line.strip())

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


# Patch embedding class
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping square patches and embed each patch.

    The embedding is a single ``nn.Conv2d`` whose kernel size and stride both
    equal ``patch_size``, so every output position corresponds to one patch.

    Args:
        image_size: ``(height, width)`` of the expected input images.
        patch_size: Side length of each square patch, in pixels.
        embed_dim: Size of the embedding produced for every patch.
        channels: Number of input image channels.

    Raises:
        ValueError: If either image dimension is not divisible by ``patch_size``.
    """

    def __init__(self, image_size, patch_size, embed_dim, channels):
        super().__init__()

        self.image_size = image_size
        if image_size[0] % patch_size != 0 or image_size[1] % patch_size != 0:
            raise ValueError("image dimensions must be divisible by the patch size")
        self.grid_size = image_size[0] // patch_size, image_size[1] // patch_size
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.patch_size = patch_size

        self.proj = nn.Conv2d(
            channels, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, im):
        """Embed an image (or batch of images) into a sequence of patch tokens.

        Args:
            im: Tensor of shape ``(B, C, H, W)``, or a single unbatched image of
                shape ``(C, H, W)`` which is treated as a batch of one.

        Returns:
            Tensor of shape ``(B, num_patches, embed_dim)``; patches are ordered
            row by row over the patch grid.

        Raises:
            ValueError: If ``im`` is neither 3- nor 4-dimensional.
        """
        if im.dim() == 3:
            # Without a batch axis, flatten(2)/transpose(1, 2) below would act on the
            # wrong axes and return a (embed_dim, W', H') tensor instead of tokens.
            im = im.unsqueeze(0)
        elif im.dim() != 4:
            raise ValueError(
                f"expected a (C, H, W) or (B, C, H, W) tensor, got shape {tuple(im.shape)}"
            )

        # (B, embed_dim, H', W') -> (B, embed_dim, H'*W') -> (B, H'*W', embed_dim)
        return self.proj(im).flatten(2).transpose(1, 2)
    
    
    
# B-16 ViT Class
class ViT(nn.Module):
    """Vision Transformer encoder (defaults follow the ViT-B/16 sizes) without a class token.

    The image is cut into ``patch_size`` x ``patch_size`` patches, each patch is
    embedded, learned positional embeddings are added, and the token sequence is
    processed by a stack of ``nn.TransformerEncoderLayer`` blocks.

    Args:
        image_size: Side length of the (square) input images the positional
            embeddings are created for.
        patch_size: Side length of each square patch.
        num_classes: Output size of the optional classification head.
        dim: Token embedding size.
        depth: Number of transformer encoder layers.
        heads: Number of attention heads per layer.
        mlp_dim: Hidden size of the feed-forward block in each layer.
        dropout: Dropout probability used after the positional embedding and
            inside the encoder layers.
    """

    def __init__(self, image_size=224, patch_size=16, num_classes=1000, dim=768, depth=12, heads=12, mlp_dim=3072, dropout=0.1):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_classes = num_classes
        self.dim = dim
        self.depth = depth
        self.heads = heads
        self.mlp_dim = mlp_dim
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(dim)

        # Number of patches on the square patch grid
        self.num_patches = (image_size // patch_size) ** 2

        # Patch embedding layer (RGB input)
        self.patch_embedding = PatchEmbedding((self.image_size, self.image_size), self.patch_size, self.dim, 3)

        # Learned positional embeddings, one per patch, truncated-normal initialised
        self.pos_embedding = nn.Parameter(torch.empty(1, self.num_patches, dim))
        nn.init.trunc_normal_(self.pos_embedding, std=0.02)

        # Transformer layers. batch_first=True is required because tokens are laid out
        # as (batch, sequence, dim); with the default (sequence, batch, dim) layout a
        # (1, N, dim) input is read as N sequences of length 1 and attention never
        # mixes information between patches.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=heads,
            dim_feedforward=mlp_dim,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)

        # MLP head for classification
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, x, return_features = True):
        """Encode a batch of images.

        Args:
            x: Image tensor of shape ``(B, 3, H, W)``, or a single unbatched
                image ``(3, H, W)`` which is treated as a batch of one.
            return_features: When True (default) return the per-patch features;
                otherwise mean-pool the tokens and apply the classification head.

        Returns:
            ``(B, num_patches, dim)`` features when ``return_features`` is True,
            else ``(B, num_classes)`` logits.

        Raises:
            ValueError: If ``x`` is neither 3- nor 4-dimensional.
        """
        if x.dim() == 3:
            x = x.unsqueeze(0)
        elif x.dim() != 4:
            raise ValueError(
                f"expected a (C, H, W) or (B, C, H, W) tensor, got shape {tuple(x.shape)}"
            )

        # Patch grid of this particular input (a strided conv floors the division)
        grid_hw = (x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size)

        # Patch embedding already yields (B, N, dim); flattening/transposing it again
        # and reshaping to (1, -1, dim) would scramble tokens and pin the batch to 1.
        x = self.patch_embedding(x)

        # Add positional embeddings. They are interpolated on the fly only when the
        # input grid differs from the stored one, so the registered parameter is never
        # replaced during a forward pass (which would detach it from the optimizer).
        x = x + self._interpolate_pos_embeds(grid_hw)
        x = self.dropout(x)

        # Apply the transformer layers
        x = self.transformer(x)

        # Apply layer normalization before returning the transformed features
        x = self.norm(x)

        if return_features:
            return x

        # Average over the sequence dimension (all transformed patch tokens).
        # NOTE: Probably should be removed on segmentation tasks
        x = x.mean(dim=1)

        # Classification head
        return self.mlp_head(x)

    def _interpolate_pos_embeds(self, grid_hw):
        """Return positional embeddings resampled to a ``(rows, cols)`` patch grid.

        The stored embeddings are viewed as a square 2-D grid and resampled
        bilinearly, so neighbouring patches keep neighbouring embeddings. The
        stored parameter itself is left untouched.
        """
        pos = self.pos_embedding
        old_count = pos.shape[1]
        old_side = int(round(old_count ** 0.5))
        if old_side * old_side != old_count:
            raise ValueError(
                f"cannot view {old_count} positional embeddings as a square grid"
            )

        grid_hw = (int(grid_hw[0]), int(grid_hw[1]))
        if grid_hw == (old_side, old_side):
            return pos

        # (1, N, dim) -> (1, dim, side, side) so interpolation runs over the spatial grid
        grid = pos.reshape(1, old_side, old_side, self.dim).permute(0, 3, 1, 2)
        grid = F.interpolate(grid, size=grid_hw, mode="bilinear", align_corners=False)
        return grid.permute(0, 2, 3, 1).reshape(1, grid_hw[0] * grid_hw[1], self.dim)

    # Resize pos embeddings functionality for tuning the ViT to accept resized images
    def resize_pos_embeds(self, new_image_size):
        """Permanently resize the positional embeddings for a new square image size.

        Call this once (e.g. before fine-tuning at a new resolution) and create the
        optimizer afterwards: when a resize actually happens the parameter object
        is replaced. If the embeddings already match, nothing is changed.

        Args:
            new_image_size: Side length of the new (square) input images.
        """
        new_side = new_image_size // self.patch_size

        if new_side * new_side != self.pos_embedding.shape[1]:
            with torch.no_grad():
                resized = self._interpolate_pos_embeds((new_side, new_side)).clone()
            self.pos_embedding = nn.Parameter(
                resized, requires_grad=self.pos_embedding.requires_grad
            )

        # Keep the bookkeeping attributes consistent with the stored embeddings
        self.image_size = new_image_size
        self.num_patches = new_side * new_side
    

In [5]:
# Build the TuSimple dataset from the parsed annotations and the clips directory
dataset = TuSimple(
    train_annotations=annotations,
    train_img_dir=clips_dir,
    resize_to=(640, 640),
    subset_size=0.05,
)

# First sample of the dataset (presumably an image tensor and its ground truth)
img_tns, gt = dataset[0]

In [6]:
# Create the ViT with ViT-B/16 sizes, for 640x640 inputs and a 10-class head
model = ViT(
    image_size=640,
    patch_size=16,
    num_classes=10,
    dim=768,
    depth=12,
    heads=12,
    mlp_dim=3072,
    dropout=0.1,
)

In [7]:
# Smoke test: run the first sample through the model and show the output shape
test = model(img_tns)
test.shape

torch.Size([1, 1600, 768])