In [None]:
import os
import json
from PIL import Image
import numpy as np
import pandas as pd
import cv2
import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
import torch.optim as optim
from torch.functional import F
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
from safetensors.torch import save_file, load_file
from sklearn.preprocessing import LabelEncoder
from decord import VideoReader, cpu
from tqdm import tqdm

In [None]:
class TrainDataset(Dataset):
    """Frame-level dataset that samples still frames from every video in a directory.

    Each video is sampled at ``disc_freq`` samples per second: local sample ``k`` of a
    video is the frame at ``k / disc_freq`` seconds. Every sample is paired with the
    encoded ``violation`` value of the ``targets`` row whose ``id`` equals the
    lower-cased video name (text before the first '.') and whose ``time`` equals the
    1-based second the frame falls in.

    Args:
        vid_dir: Directory with the videos. Every regular file not ending in ``.txt``
            is opened as a video.
        targets: DataFrame with at least ``id``, ``time`` and ``violation`` columns.
            It is copied, so the caller's DataFrame is left untouched.
        targets_encoder: Optional already-fitted ``LabelEncoder`` for ``violation``
            (e.g. ``train_dataset.targets_encoder`` when building a validation set).
            When omitted, a new encoder is fitted on ``targets['violation']``.
        disc_freq: Sampling rate in samples per second. Must be positive and must not
            exceed any video's average FPS.
        frame_size: ``(width, height)`` every frame is resized to (cv2 ``dsize`` order).

    Raises:
        ValueError: if ``disc_freq`` is not positive or exceeds a video's average FPS.
    """

    # ImageNet channel statistics (RGB order), applied after ToTensor scales to [0, 1].
    _MEAN = [0.485, 0.456, 0.406]
    _STD = [0.229, 0.224, 0.225]

    def __init__(self, vid_dir, targets, targets_encoder=None, disc_freq=5, frame_size=(736, 416)):
        if disc_freq <= 0:
            raise ValueError(f"disc_freq must be positive, got {disc_freq}")

        self.videos_names = []
        self.videos = []
        self.videos_len = []
        self.fps = []
        self.videos_samples_len = []

        # Sorted so global sample indices are reproducible (os.listdir order is arbitrary).
        for filename in sorted(os.listdir(vid_dir)):
            filepath = os.path.join(vid_dir, filename)
            if filename.endswith('.txt') or not os.path.isfile(filepath):
                continue

            # The reader is built from an open file object and used after the handle is
            # closed, so decord presumably reads the file contents up front.
            with open(filepath, 'rb') as f:
                video_reader = VideoReader(f, ctx=cpu(0))

            n_frames = len(video_reader)
            fps = video_reader.get_avg_fps()
            if disc_freq > fps:
                raise ValueError(
                    f"disc_freq={disc_freq} exceeds the average FPS ({fps}) of {filename}"
                )

            self.videos_names.append(filename)
            self.videos.append(video_reader)
            self.videos_len.append(n_frames)
            self.fps.append(fps)
            # Whole samples that fit into the video's duration (n_frames / fps seconds).
            self.videos_samples_len.append(int(n_frames / fps * disc_freq))

        # Entry i is one past the last global sample index belonging to video i.
        self._sample_ends = np.cumsum(self.videos_samples_len)
        self.disc_freq = disc_freq
        self.frame_size = tuple(frame_size)

        # Copy so label encoding never mutates the DataFrame the caller passed in.
        self.targets = targets.copy()
        if targets_encoder is not None:
            # Reuse a fitted encoder; only the label column is encoded.
            self.targets_encoder = targets_encoder
            self.targets['violation'] = self.targets_encoder.transform(self.targets['violation'])
        else:
            self.targets_encoder = LabelEncoder()
            self.targets['violation'] = self.targets_encoder.fit_transform(self.targets['violation'])

        # (id, time) -> encoded labels, built once so __getitem__ does not scan the frame.
        # Lists keep duplicates visible so ambiguous rows are still reported.
        self._labels = {}
        for vid_id, sec, label in zip(self.targets['id'], self.targets['time'], self.targets['violation']):
            self._labels.setdefault((vid_id, sec), []).append(label)

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=self._MEAN, std=self._STD),
        ])

    def __len__(self):
        """Total number of sampled frames across all videos."""
        return sum(self.videos_samples_len)

    def __getitem__(self, idx):
        """Return ``(frame, label)`` for global sample ``idx``.

        ``frame`` is the resized, ImageNet-normalized frame tensor produced by
        ``ToTensor``/``Normalize``; ``label`` is the encoded ``violation`` value.

        Raises:
            IndexError: if ``idx`` is outside ``[0, len(self))``.
            LookupError: if ``targets`` does not contain exactly one row for the frame.
        """
        if not 0 <= idx < len(self):
            raise IndexError(f"sample index {idx} out of range for dataset of size {len(self)}")

        # First video whose cumulative end exceeds idx; side='right' skips videos that
        # contribute zero samples (their end equals the previous video's end).
        vid_idx = int(np.searchsorted(self._sample_ends, idx, side='right'))
        local_idx = idx - (int(self._sample_ends[vid_idx - 1]) if vid_idx > 0 else 0)

        # Timestamp of this sample in seconds, and the frame shown at that time.
        t = local_idx / self.disc_freq
        frame_idx = int(t * self.fps[vid_idx])
        # 0-based second index: times in (k, k + 1] map to k and t == 0 maps to 0; the
        # epsilon keeps exact whole-second timestamps (despite float error) in the
        # earlier second.
        frame_sec = int(t - 1e-8)

        frame = self.videos[vid_idx][frame_idx].asnumpy()
        frame = cv2.resize(frame, self.frame_size, interpolation=cv2.INTER_CUBIC)
        # decord's VideoReader already yields RGB frames, matching PIL and the RGB
        # ImageNet statistics; a BGR->RGB conversion here would swap them to BGR.
        frame = self.transform(Image.fromarray(frame))

        video_id = self.videos_names[vid_idx].split('.')[0].lower()
        label_time = frame_sec + 1  # targets use 1-based seconds
        labels = self._labels.get((video_id, label_time), [])
        if len(labels) != 1:
            raise LookupError(
                f"expected exactly one target row for id={video_id!r}, time={label_time}; "
                f"found {len(labels)}"
            )

        return frame, labels[0]

In [None]:
# Build the training frame dataset, sampling one frame every 10 seconds
# (disc_freq is expressed in samples per second).
dataset = TrainDataset(
    vid_dir='../data/train_videos/',
    targets=pd.read_csv('../data/train_targets.csv'),
    disc_freq=0.1,
)

In [None]:
# Smoke test: decode every sample once so broken videos or missing target rows surface
# now rather than mid-training. This is a full pass, so show progress.
loader = DataLoader(dataset, batch_size=16)
for frame, label in tqdm(loader, desc='checking dataset'):
    # Each batch is a stack of 3-channel frames with one label per frame.
    assert frame.ndim == 4 and frame.shape[1] == 3, f"unexpected frame batch shape {tuple(frame.shape)}"
    assert len(label) == len(frame), "frame/label batch sizes differ"

In [None]:
# Drop the references to the decord VideoReader objects (presumably to free their
# decoder memory). Other attributes such as `dataset.targets_encoder` stay usable,
# but indexing `dataset` (and iterating `loader`) will fail from here on.
dataset.videos = None