# Advanced Topics in Embodied Learning and Vision: Video Learning Demo
##### 2025-02-06, Chris Hoang

### Setup and imports

1. Create a new conda environment (remember to modify your Jupyter setup to activate this conda environment)
```
conda create -n "video" python=3.10.0
```
2. Install torch, torchvision, decord, ipykernel
```
pip install torch==2.2.0 torchvision==0.17 --index-url https://download.pytorch.org/whl/cu118
pip install decord
pip install ipykernel
```
3. Copy BDD100K example videos

```
scp -r /scratch/ch3451/evl/video-learning-demo/videos <YOUR_DIR>
```

In [3]:
import gc
import os
import time

from PIL import Image
import numpy as np
import torch
from torch.utils.data import Dataset
import torchvision
from torchvision import transforms
import torchvision.transforms.functional as tvF
from decord import VideoReader, cpu

In [7]:
class BDD100KDataset(Dataset):
    """Dataset that samples pairs of temporally separated frames from videos.

    Every file found in ``root_dir`` is treated as one video. An item consists
    of ``repeat_sample`` frame pairs drawn from a single video; the second
    frame of each pair lies between ``delta_t[0]`` and ``delta_t[1]`` frames
    (inclusive) after the first.

    Args:
        root_dir: Directory whose entries are the video files.
        delta_t: Two-element sequence ``[min_gap, max_gap]`` giving the
            inclusive range of the frame gap inside a pair.
        repeat_sample: Number of frame pairs drawn per video. ``None`` (or 0)
            means 1.
        crop_scale: ``(min, max)`` area fraction passed as ``scale`` to
            ``RandomResizedCrop``. When ``None`` no cropping is applied.
        decode_resolution: Optional ``(height, width)`` at which the video is
            decoded. When ``None`` the video is decoded at the reader's
            default resolution.
    """

    # Upper bound on how many videos are tried for one item before giving up,
    # so a directory full of unreadable files cannot recurse/loop forever.
    _MAX_DECODE_ATTEMPTS = 10

    def __init__(self,
                 root_dir,
                 delta_t,
                 repeat_sample=None,
                 crop_scale=None,
                 decode_resolution=None,
                 ):
        super().__init__()
        self.root_dir = root_dir
        self.repeat_sample = repeat_sample or 1
        self.delta_t = delta_t
        self.decode_resolution = decode_resolution
        self.crop_scale = crop_scale

        self.video_paths = sorted([os.path.join(root_dir, f) for f in os.listdir(root_dir)])
        self._dataset_len = len(self.video_paths)

        # Build the augmentation pipeline once instead of on every call, and
        # honour the requested crop_scale rather than a hard-coded range.
        if crop_scale is not None:
            self._random_crop = transforms.Compose([
                transforms.RandomResizedCrop(
                    (224, 224),
                    scale=tuple(crop_scale),
                    interpolation=transforms.InterpolationMode.BILINEAR,
                ),
                transforms.ToTensor()
            ])
        else:
            self._random_crop = None

    def transform(self, x, y):
        """Convert a pair of PIL frames into tensors.

        Implemented as a method (not a closure created in ``__init__``) so the
        dataset can be pickled for spawn-based DataLoader workers.

        Returns:
            ``[x, y]`` as tensors when no crop is configured; otherwise
            ``[x1, x2, y1, y2]`` where ``x1``/``x2`` are two independent
            random crops of ``x`` and ``y1``/``y2`` two of ``y``.
        """
        if self._random_crop is None:
            return [tvF.to_tensor(x), tvF.to_tensor(y)]
        random_crop = self._random_crop
        x1 = random_crop(x)
        x2 = random_crop(x)
        y1 = random_crop(y)
        y2 = random_crop(y)
        return [x1, x2, y1, y2]

    def __len__(self):
        """Return the number of videos found in ``root_dir``."""
        return self._dataset_len

    def __getitem__(self, idx):
        """Return the augmented frame pairs sampled from video ``idx``.

        If the video cannot be opened, is too short, or fails to decode, the
        error is printed and a different randomly chosen video is tried.

        Returns:
            A list of tensors (2 without cropping, 4 with cropping), each
            stacked along a new leading dimension of size ``repeat_sample``.

        Raises:
            IndexError: If ``idx`` is outside the list of videos.
            RuntimeError: If ``_MAX_DECODE_ATTEMPTS`` videos in a row fail.
        """
        last_error = None
        for _ in range(self._MAX_DECODE_ATTEMPTS):
            # Resolved outside the try so an invalid idx surfaces as IndexError.
            video_path = self.video_paths[idx]
            try:
                return self._load_sample(video_path)
            except Exception as e:
                # If failure, decode from a different randomly sampled video
                print(f"Error reading video {self.video_paths[idx]}: {e}")
                last_error = e
                idx = np.random.randint(0, self._dataset_len)
        raise RuntimeError(
            f"Failed to read a video after {self._MAX_DECODE_ATTEMPTS} attempts"
        ) from last_error

    def _load_sample(self, video_path):
        """Decode and augment ``repeat_sample`` frame pairs from one video."""
        # Retrieve CPU worker to use for Decord VideoReader
        worker_info = torch.utils.data.get_worker_info()
        cpuid = 0 if worker_info is None else int(worker_info.id)

        # Potentially decode into lower resolution
        reader_kwargs = {}
        if self.decode_resolution is not None:
            h, w = self.decode_resolution
            reader_kwargs = {"width": w, "height": h}
        vr = VideoReader(video_path, num_threads=0, ctx=cpu(cpuid), **reader_kwargs)
        vr_len = len(vr)

        min_gap, max_gap = self.delta_t[0], self.delta_t[1]
        if vr_len <= max_gap:
            raise ValueError(
                f"video has {vr_len} frames, need more than {max_gap}"
            )

        # Get random frame indices as well as future frame indices to decode
        starts = np.random.randint(0, vr_len - max_gap, size=self.repeat_sample)
        gaps = np.random.randint(min_gap, max_gap + 1, size=self.repeat_sample)
        # Interleave as [start_0, future_0, start_1, future_1, ...]
        frame_ids = np.array([index for i, gap in zip(starts, gaps) for index in (i, i + gap)])

        # Sort frame indices to decode frame in-order, which is faster
        sort_indexes = np.argsort(frame_ids).astype(np.int32)
        unsort_indexes = np.argsort(sort_indexes).astype(np.int32)

        imgs = vr.get_batch(list(frame_ids[sort_indexes])).asnumpy()[unsort_indexes]
        # Drop the reader right away so its buffers are released
        del vr
        gc.collect()

        ls = []
        # Augment pairs of frames at a time
        for i in range(self.repeat_sample):
            img1 = tvF.to_pil_image(imgs[2 * i])
            img2 = tvF.to_pil_image(imgs[2 * i + 1])
            aug_img = self.transform(img1, img2)
            if not ls:
                ls = [[] for _ in aug_img]
            for bucket, view in zip(ls, aug_img):
                bucket.append(view)

        return [torch.stack(views, dim=0) for views in ls]

#### Sample data from BDD videos, high resolution

In [8]:
# Full-resolution dataset: 4 frame pairs per video, each pair exactly
# 5 frames apart (delta_t min == max), with no random cropping.
dataset = BDD100KDataset(
    root_dir='/scratch/ch3451/evl/video-learning-demo/videos',
    delta_t=[5, 5],
    repeat_sample=4,
    crop_scale=None,
)

In [9]:
# Sample from the last video, concatenate its frame tensors along dim 0,
# and write them out as a 2-column image grid.
example = torch.cat(dataset[-1], dim=0)
grid = torchvision.utils.make_grid(
    example.mul(255.).to(torch.uint8),
    nrow=2,
    padding=2,
    normalize=False,
    pad_value=0,
)
torchvision.io.write_png(grid, 'high-resolution.png')

#### Using no CPU workers for dataloading

In [10]:
# Baseline loader: num_workers=0 means samples are loaded in the main process.
sampler = torch.utils.data.RandomSampler(data_source=dataset)
data_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=1,
    sampler=sampler,
    num_workers=0,
    pin_memory=True,
    drop_last=True,
)

In [11]:
# Time one full pass over the loader. perf_counter() is a monotonic,
# high-resolution clock meant for measuring intervals, unlike time.time(),
# which can jump if the system clock is adjusted.
start_time = time.perf_counter()
for batch in data_loader:
    pass  # only the loading cost is being measured
end_time = time.perf_counter()
print(f'Total time: {end_time - start_time}')

Total time: 12.720113277435303


#### Using 12 CPU workers to speed up dataloading

In [12]:
# Same dataset and sampler as before, now loaded by 12 worker processes.
data_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=1,
    sampler=sampler,
    num_workers=12,
    pin_memory=True,
    drop_last=True,
)

In [13]:
# Time one full pass over the loader. perf_counter() is a monotonic,
# high-resolution clock meant for measuring intervals, unlike time.time(),
# which can jump if the system clock is adjusted.
start_time = time.perf_counter()
for batch in data_loader:
    pass  # only the loading cost is being measured
end_time = time.perf_counter()
print(f'Total time: {end_time - start_time}')

Total time: 2.587550401687622


#### Decoding to a lower resolution to speed up dataloading

In [14]:
# Same sampling settings, but frames are decoded at 180x320 (height, width).
dataset = BDD100KDataset(
    root_dir='/scratch/ch3451/evl/video-learning-demo/videos',
    delta_t=[5, 5],
    repeat_sample=4,
    decode_resolution=(180, 320),
)

In [15]:
# Rebuild the sampler and the 12-worker loader around the low-resolution dataset.
sampler = torch.utils.data.RandomSampler(data_source=dataset)
data_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=1,
    sampler=sampler,
    num_workers=12,
    pin_memory=True,
    drop_last=True,
)

In [16]:
# Time one full pass over the loader. perf_counter() is a monotonic,
# high-resolution clock meant for measuring intervals, unlike time.time(),
# which can jump if the system clock is adjusted.
start_time = time.perf_counter()
for batch in data_loader:
    pass  # only the loading cost is being measured
end_time = time.perf_counter()
print(f'Total time: {end_time - start_time}')

Total time: 1.5015265941619873


In [10]:
# Sample from the last video of the low-resolution dataset and save the
# concatenated frames as a 4-column image grid.
example = torch.cat(dataset[-1], dim=0)
grid = torchvision.utils.make_grid(
    example.mul(255.).to(torch.uint8),
    nrow=4,
    padding=2,
    normalize=False,
    pad_value=0,
)
torchvision.io.write_png(grid, 'low-resolution.png')

#### Using RandomResizedCrop data augmentations

In [11]:
# Enable the random-resized-crop augmentation by passing a crop_scale.
dataset = BDD100KDataset(
    root_dir='/scratch/ch3451/evl/video-learning-demo/videos',
    delta_t=[5, 5],
    repeat_sample=4,
    crop_scale=(0.2, 0.4),
)

In [12]:
# Sample from the last video of the cropped dataset and save the
# concatenated crops as a 4-column image grid.
example = torch.cat(dataset[-1], dim=0)
grid = torchvision.utils.make_grid(
    example.mul(255.).to(torch.uint8),
    nrow=4,
    padding=2,
    normalize=False,
    pad_value=0,
)
torchvision.io.write_png(grid, 'random-crop.png')

### Additional resources

decord: https://github.com/dmlc/decord/blob/master/examples/video_reader.ipynb