In [1]:
import os
import json
from PIL import Image
import numpy as np
import pandas as pd
import cv2
import segmentation_models_pytorch as smp
from ultralytics import YOLO
import torch
import torch.nn as nn
import torch.optim as optim
from torch.functional import F
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.transforms import transforms
from safetensors.torch import save_file, load_file
from sklearn.preprocessing import LabelEncoder
from decord import VideoReader, cpu
from tqdm import tqdm

In [2]:
# base paths
ROOT_DIR = os.path.join(os.getcwd(), os.pardir)
DATA_DIR = os.path.join(ROOT_DIR, 'data')
TRAIN_DIR = os.path.join(DATA_DIR, 'unet_segmentation_train', 'road_marking')
TRAIN_VIDEOS_DIR = os.path.join(DATA_DIR, 'train_videos')
TRAIN_TARGETS_PATH = os.path.join(DATA_DIR, 'train_targets.csv')
UNET_CHECKPOINT_PATH = os.path.join(ROOT_DIR, 'checkpoints', 'unet_segment_resnet50.safetensors')
YOLO_SIGNS_CHECKPOINT_PATH = os.path.join(ROOT_DIR, 'checkpoints', 'road-signs-yolov8n.pt')
YOLO_TRAFFIC_LIGHTS_CHECKPOINT_PATH = os.path.join(ROOT_DIR, 'checkpoints', 'traffic-lights-yolov8n.pt')

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
LEARN_FROM_ZERO = False

In [3]:
class TrainDataset(Dataset):
    def __init__(self, vid_dir, targets, width, height, mask_encoder, 
                 targets_encoder=None, disc_freq=5, device='cpu'):
        self.vid_dir = vid_dir
        self.videos_names = []
        self.videos_len = []
        self.fps = []
        self.videos_samples_len = []
        self.width = width
        self.height = height
        self.device=device
        self.mask_encoder = mask_encoder
        
        # Preload video metadata and prepare decoders
        for filename in os.listdir(vid_dir):
            if not filename.endswith('.txt'):
                filepath = os.path.join(vid_dir, filename)
                self.videos_names.append(filename)

                with open(filepath, 'rb') as f:
                    video_reader = VideoReader(f, ctx=cpu(0))

                self.videos_len.append(len(video_reader))
                self.fps.append(video_reader.get_avg_fps())

                assert disc_freq <= self.fps[-1]
                self.videos_samples_len.append(int(self.videos_len[-1] / self.fps[-1] * disc_freq))
        
        self.disc_freq = disc_freq
        self.targets = targets

        # Encode targets if encoder provided
        if targets_encoder:
            self.targets_encoder = targets_encoder
            self.targets = self.targets_encoder.transform(self.targets)
        else:
            self.targets_encoder = LabelEncoder()
            self.targets['violation'] = self.targets_encoder.fit_transform(self.targets['violation'])
        
        self.video = [None, None]

    def __len__(self):
        return sum(self.videos_samples_len)

    def __getitem__(self, idx):
        # Determine which video this idx falls into
        for vid_idx in range(len(self.videos_names)):
            if idx >= self.videos_samples_len[vid_idx]:
                idx -= self.videos_samples_len[vid_idx]
            else:
                break

        if vid_idx != self.video[0]:
            if self.video[1]:
                self.video[1].seek(0)
            with open(os.path.join(self.vid_dir, self.videos_names[vid_idx]), 'rb') as f:
                self.video = [vid_idx, VideoReader(f, ctx=cpu(0))]
        
        # Calculate the frame index based on disc_freq
        frame_sec = int(idx / self.disc_freq - 1e-8)
        frame_idx = int(idx / self.disc_freq * self.fps[vid_idx])
        
        # Use Decord to fetch the frame efficiently
        video_reader = self.video[1]
        frame = video_reader[frame_idx].asnumpy()
        
        # Preprocess image
        frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_CUBIC)   
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame = Image.fromarray(frame)

        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                 std=[0.229, 0.224, 0.225])
        ])

        frame = transform(frame)
        frame = frame.to(self.device)

        with torch.no_grad():
            mask_transform = transforms.Resize((164, 164), interpolation=transforms.InterpolationMode.BICUBIC)
            mask = self.mask_encoder(frame.unsqueeze(0))
            mask = mask.squeeze(1)
            mask = mask_transform(mask)
            mask = mask.to(torch.uint8)
            mask = mask.squeeze(0)
            mask = mask.to(torch.float32)

        # Extract the label for the current frame
        label = self.targets['violation'][(self.targets['id'] == self.videos_names[vid_idx].split('.')[0].lower()) 
                                          & (self.targets['time'] == frame_sec + 1)]
        assert len(label) == 1
        label = torch.tensor(label.iloc[0])
        
        return frame, mask, label.to(self.device)

In [4]:
# load unet roadline segmentation model
unet_state = load_file(UNET_CHECKPOINT_PATH)
unet_model = smp.Unet(
    encoder_name='resnet50',
    in_channels=3,
    classes=1
).to(DEVICE)
unet_model.load_state_dict(unet_state)

<All keys matched successfully>

In [5]:
# load yolo8 detection models
yolo_signs = YOLO(YOLO_SIGNS_CHECKPOINT_PATH)
yolo_traffic_lights = YOLO(YOLO_TRAFFIC_LIGHTS_CHECKPOINT_PATH)

In [6]:
# load backbones for frames and mask
frame_backbone = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

In [7]:
class ViolationDetectionModel(nn.Module):
    def __init__(self, frame_backbone):
        super().__init__()
        self.frame_backbone = frame_backbone
        self.seq_frame_in = self._make_sequential(1000, 512)
        self.seq_mask_in = self._make_sequential(164 * 164, 2048)
        self.seq_combine = self._make_sequential(2048 + 512, 1024)
        self.fc1 = self._make_sequential(1024, 512)
        self.fc2 = self._make_sequential(512, 256)
        self.fc3 = self._make_sequential(256, 128)
        self.fc_out = self._make_sequential(128, 6)
        
    def forward(self, frame, mask):
        # flatten mask
        mask = torch.flatten(mask, start_dim=1)

        # pass frame and its mask through backbones
        frame_backbone_out = self.frame_backbone(frame)
        mask_out = self.seq_mask_in(mask)

        # combine outputs
        frame_out = self.seq_frame_in(frame_backbone_out)
        combined = self.seq_combine(torch.hstack([frame_out, mask_out]))

        # get class probs
        out = self.fc1(combined)
        out = self.fc2(out)
        out = self.fc3(out)
        out = self.fc_out(out)
        out = F.softmax(out, dim=-1)

        return out

    def _make_sequential(self, input_dim, output_dim):
        return nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.Dropout(p=0.3),
            nn.ReLU()
        )

In [8]:
targets = pd.read_csv(TRAIN_TARGETS_PATH)
dataset = TrainDataset(
    vid_dir=TRAIN_VIDEOS_DIR,
    targets=targets, 
    disc_freq=5,
    width=736,
    height=416,
    mask_encoder=unet_model,
    device=DEVICE
)

dataloader = DataLoader(
    dataset=dataset,
    batch_size=8
)

In [None]:
NUM_EPOCH = 10

model = ViolationDetectionModel(
    frame_backbone=frame_backbone
).to(DEVICE)

criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=2e-4)
for epoch in range(NUM_EPOCH):
    pbar = tqdm(dataloader)
    total_loss = 0
    accuracy = 0
    for i, batch in enumerate(pbar, start=1):
        frame, mask, target = batch
        out = model(frame, mask)
        loss = criterion(out, target)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        total_loss += loss
        accuracy += (torch.argmax(out, dim=-1) == target).sum()
        pbar.set_postfix({'loss': loss.item(), 'total_loss': total_loss.item() / i, 'acc': (accuracy.item() / 8) / i})

 56%|█████▌    | 1045/1875 [04:50<03:54,  3.55it/s, loss=-0.896, total_loss=-0.679, acc=0.848]