In [1]:
import torch
import torchvision.transforms.functional as F
from torch.utils.data import Dataset, DataLoader
import cv2
import os
import pandas as pd
from sklearn.preprocessing import LabelEncoder
import tqdm
from decord import VideoReader, cpu

In [2]:
class TrainDataset(Dataset):
    """Frame-level dataset that draws labelled frames from a directory of videos.

    Every file in ``vid_dir`` that does not end in ``.txt`` is opened with a
    Decord ``VideoReader``. Each video contributes
    ``int(n_frames / fps * disc_freq)`` samples, i.e. ``disc_freq`` evenly
    spaced frames per second of footage. Samples of consecutive videos are laid
    out back to back in one flat index space.

    Args:
        vid_dir: Directory containing the video files.
        targets: DataFrame with (at least) the columns ``id``, ``time`` and
            ``violation``. It is copied, so the caller's frame is not modified.
        targets_encoder: Optional already-fitted encoder exposing
            ``transform``; it is applied to the ``violation`` column. When
            omitted, a fresh ``LabelEncoder`` is fitted on that column.
        transform: Optional callable applied to the frame tensor.
        disc_freq: Number of frames sampled per second of video. Must not
            exceed the average FPS of any video.
    """

    def __init__(self, vid_dir, targets, targets_encoder=None, transform=None, disc_freq=10):
        self.videos_names = []
        self.videos = []
        self.videos_len = []
        self.fps = []
        self.videos_samples_len = []

        # Preload video metadata and prepare decoders. The listing is sorted so
        # that the index -> (video, frame) mapping is the same on every run
        # (os.listdir returns entries in arbitrary order).
        for filename in sorted(os.listdir(vid_dir)):
            if filename.endswith('.txt'):
                continue
            filepath = os.path.join(vid_dir, filename)
            video_reader = VideoReader(filepath, ctx=cpu(0))  # Decord VideoReader for fast random access
            fps = video_reader.get_avg_fps()
            assert disc_freq <= fps, (
                f"disc_freq={disc_freq} exceeds the average FPS ({fps}) of {filename!r}"
            )
            self.videos_names.append(filename)
            self.videos.append(video_reader)
            self.videos_len.append(len(video_reader))
            self.fps.append(fps)
            self.videos_samples_len.append(int(self.videos_len[-1] / fps * disc_freq))

        self.disc_freq = disc_freq
        self.transform = transform

        # Work on a copy: encoding the labels in place would silently rewrite
        # the caller's DataFrame and make re-running the cell non-idempotent.
        self.targets = targets.copy()

        # Encode only the label column, with the supplied encoder if there is one.
        if targets_encoder is not None:
            self.targets_encoder = targets_encoder
            self.targets['violation'] = self.targets_encoder.transform(self.targets['violation'])
        else:
            self.targets_encoder = LabelEncoder()
            self.targets['violation'] = self.targets_encoder.fit_transform(self.targets['violation'])

    def __len__(self):
        """Return the total number of sampled frames over all videos."""
        return sum(self.videos_samples_len)

    def _locate(self, idx):
        """Map a flat sample index to ``(video index, sample index inside that video)``."""
        total = len(self)
        if idx < 0:
            idx += total  # support Python-style negative indexing
        if not 0 <= idx < total:
            raise IndexError(f"sample index out of range for dataset of size {total}")
        for vid_idx, n_samples in enumerate(self.videos_samples_len):
            if idx < n_samples:
                return vid_idx, idx
            idx -= n_samples
        # Unreachable: the range check above guarantees a match.
        raise IndexError("sample index out of range")

    def __getitem__(self, idx):
        """Return ``(frame, label)`` for the flat sample index ``idx``.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
            AssertionError: If ``targets`` does not hold exactly one row for
                the video id / second of the requested frame.
        """
        vid_idx, idx = self._locate(idx)

        # Timestamp of the sample is idx / disc_freq seconds. The small epsilon
        # makes a sample that falls exactly on a whole second count towards
        # the second that just ended.
        frame_sec = int(idx / self.disc_freq - 1e-8)
        frame_idx = int(idx / self.disc_freq * self.fps[vid_idx])

        # Use Decord to fetch the frame efficiently. Decord already decodes to
        # RGB, so no colour conversion is applied (a BGR->RGB swap here would
        # actually turn the frame into BGR).
        video_reader = self.videos[vid_idx]
        frame = video_reader[frame_idx].asnumpy()

        frame = torch.tensor(frame)

        # Apply transformations if specified
        if self.transform:
            frame = self.transform(frame)

        # Look up the label: ids are the lower-cased file name up to its first
        # dot, and the `time` column is matched against frame_sec + 1.
        video_id = self.videos_names[vid_idx].split('.')[0].lower()
        row_mask = (self.targets['id'] == video_id) & (self.targets['time'] == frame_sec + 1)
        label = self.targets.loc[row_mask, 'violation']
        assert len(label) == 1, (
            f"expected exactly one target row for id={video_id!r}, time={frame_sec + 1}, found {len(label)}"
        )
        label = label.iloc[0]

        return frame, label

In [3]:
class CustomTransform:
    """Resize an (H, W, C) image tensor and normalise it per channel.

    The result is a float tensor of shape ``(3, size[0], size[1])``.

    Args:
        size: Target ``(height, width)`` passed to the bilinear resize.
        mean: Per-channel mean subtracted after resizing. The defaults are the
            ImageNet statistics, which are expressed on a 0-1 intensity scale.
        std: Per-channel standard deviation used for scaling (same scale as
            ``mean``).
    """

    def __init__(self, size=(224, 224), mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        self.size = size
        self.mean = torch.tensor(mean, dtype=torch.float32).view(3, 1, 1)
        self.std = torch.tensor(std, dtype=torch.float32).view(3, 1, 1)

    def __call__(self, tensor):
        """Transform one image.

        Args:
            tensor: Image tensor laid out as (H, W, C) with three channels.
                Integer tensors are treated as 0-255 intensities and rescaled
                to 0-1; floating-point tensors are assumed to be on the 0-1
                scale already and are left unscaled.

        Returns:
            The resized, normalised (C, H, W) float tensor.
        """
        # Reorder axes from (H, W, C) to (C, H, W)
        tensor = tensor.permute(2, 0, 1)

        # Bring integer images onto the 0-1 scale that mean/std refer to.
        # Converting to float before resizing also avoids integer rounding in
        # the interpolation and keeps it working on PyTorch builds without
        # uint8 bilinear kernels.
        if not tensor.is_floating_point():
            tensor = tensor.to(torch.float32) / 255.0

        # Resize tensor (interpolate expects a leading batch dimension)
        tensor = torch.nn.functional.interpolate(
            tensor.unsqueeze(0), size=self.size, mode='bilinear', align_corners=False
        ).squeeze(0)

        # Normalize tensor
        tensor = (tensor - self.mean) / self.std

        return tensor

In [4]:
# Preprocessing for training frames: resize to 224x224 and normalise.
train_transform = CustomTransform(size=(224, 224))

# Labelled training dataset, sampling 5 frames per second of video.
dataset = TrainDataset(
    '../data/train_videos/',
    pd.read_csv('../data/train_targets.csv'),
    transform=train_transform,
    disc_freq=5,
)

In [5]:
# Report the dataset size, then walk through every batch once without using
# it; the tqdm bar shows how fast frames can be loaded.
print(len(dataset))
loader = DataLoader(dataset, batch_size=16)
for i in tqdm.tqdm(loader):
    pass

15000


100%|██████████| 938/938 [03:10<00:00,  4.92it/s]


In [6]:
class TestDataset(Dataset):
    """Frame-level dataset that draws unlabelled frames from a directory of videos.

    Every file in ``vid_dir`` that does not end in ``.txt`` is opened with a
    Decord ``VideoReader``. Each video contributes
    ``int(n_frames / fps * disc_freq)`` samples, i.e. ``disc_freq`` evenly
    spaced frames per second of footage. Samples of consecutive videos are laid
    out back to back in one flat index space.

    Args:
        vid_dir: Directory containing the video files.
        transform: Optional callable applied to the frame tensor.
        disc_freq: Number of frames sampled per second of video. Must not
            exceed the average FPS of any video.
    """

    def __init__(self, vid_dir, transform=None, disc_freq=10):
        self.videos_names = []
        self.videos = []
        self.videos_len = []
        self.fps = []
        self.videos_samples_len = []

        # Preload video metadata and prepare decoders. The listing is sorted so
        # that the index -> (video, frame) mapping is the same on every run
        # (os.listdir returns entries in arbitrary order).
        for filename in sorted(os.listdir(vid_dir)):
            if filename.endswith('.txt'):
                continue
            filepath = os.path.join(vid_dir, filename)
            video_reader = VideoReader(filepath, ctx=cpu(0))  # Decord VideoReader for fast random access
            fps = video_reader.get_avg_fps()
            assert disc_freq <= fps, (
                f"disc_freq={disc_freq} exceeds the average FPS ({fps}) of {filename!r}"
            )
            self.videos_names.append(filename)
            self.videos.append(video_reader)
            self.videos_len.append(len(video_reader))
            self.fps.append(fps)
            self.videos_samples_len.append(int(self.videos_len[-1] / fps * disc_freq))

        self.disc_freq = disc_freq
        self.transform = transform

    def __len__(self):
        """Return the total number of sampled frames over all videos."""
        return sum(self.videos_samples_len)

    def _locate(self, idx):
        """Map a flat sample index to ``(video index, sample index inside that video)``."""
        total = len(self)
        if idx < 0:
            idx += total  # support Python-style negative indexing
        if not 0 <= idx < total:
            raise IndexError(f"sample index out of range for dataset of size {total}")
        for vid_idx, n_samples in enumerate(self.videos_samples_len):
            if idx < n_samples:
                return vid_idx, idx
            idx -= n_samples
        # Unreachable: the range check above guarantees a match.
        raise IndexError("sample index out of range")

    def __getitem__(self, idx):
        """Return ``(frame, (video file name, frame_sec))`` for sample ``idx``.

        ``frame_sec`` is the zero-based whole second the sample is assigned to.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        vid_idx, idx = self._locate(idx)

        # Timestamp of the sample is idx / disc_freq seconds. The small epsilon
        # makes a sample that falls exactly on a whole second count towards
        # the second that just ended.
        frame_sec = int(idx / self.disc_freq - 1e-8)
        frame_idx = int(idx / self.disc_freq * self.fps[vid_idx])

        # Use Decord to fetch the frame efficiently. Decord already decodes to
        # RGB, so no colour conversion is applied (a BGR->RGB swap here would
        # actually turn the frame into BGR).
        video_reader = self.videos[vid_idx]
        frame = video_reader[frame_idx].asnumpy()

        frame = torch.tensor(frame)

        # Apply transformations if specified
        if self.transform:
            frame = self.transform(frame)

        return frame, (self.videos_names[vid_idx], frame_sec)

In [7]:
# Preprocessing for validation frames: resize to 224x224 and normalise.
test_transform = CustomTransform(size=(224, 224))

# Unlabelled validation dataset, sampling 5 frames per second of video.
# NOTE: this rebinds `dataset`, which previously held the training dataset.
dataset = TestDataset(
    '../data/val_videos/',
    transform=test_transform,
    disc_freq=5,
)

In [8]:
# Report the dataset size, then walk through every batch once without using
# it; the tqdm bar shows how fast frames can be loaded.
print(len(dataset))
loader = DataLoader(dataset, batch_size=16)
for i in tqdm.tqdm(loader):
    pass

6000


100%|██████████| 375/375 [01:15<00:00,  4.97it/s]
