In [10]:
import os
import sys

from typing import List

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch

import nvidia.dali as dali
from nvidia.dali.plugin.pytorch import DALIGenericIterator

In [2]:
# Project root. Set the DFDC_BASE_DIR environment variable to run the notebook
# on another machine; the default keeps the original author's location.
BASE_DIR = os.environ.get('DFDC_BASE_DIR', '/home/dmitry/projects/dfdc')
SRC_DIR = os.path.join(BASE_DIR, 'src')
DATA_DIR = os.path.join(BASE_DIR, 'data/dfdc-videos')

# Make the project's `src` directory importable. Insert only when it is not
# already the first entry, so re-running this cell does not pile up duplicates.
if sys.path[:1] != [SRC_DIR]:
    sys.path.insert(0, SRC_DIR)

In [3]:
from dataset.utils import read_labels

In [4]:
class VideoPipe(dali.pipeline.Pipeline):
    """DALI pipeline that reads fixed-length frame sequences from video files.

    Args:
        filenames: paths of the video files handed to the reader.
        seq_len: number of frames in each sequence the reader returns.
        stride: accepted but currently not forwarded to the reader.
            NOTE(review): this may have been meant as the reader's `stride`
            or `step` option -- confirm the intent before wiring it through.
        batch_size: forwarded to the base ``Pipeline`` constructor.
        num_threads: forwarded to the base ``Pipeline`` constructor.
        device_id: forwarded to the base ``Pipeline`` constructor.
    """

    def __init__(self, filenames: List[str], seq_len=30, stride=10,
                 batch_size=1, num_threads=1, device_id=0):
        # The seed is fixed at 3, as in the original notebook.
        super().__init__(batch_size, num_threads, device_id, seed=3)

        reader_options = dict(
            device='gpu',
            filenames=filenames,
            sequence_length=seq_len,
            shard_id=0,
            num_shards=1,
        )
        self.input = dali.ops.VideoReader(**reader_options)

    def define_graph(self):
        """Return the reader's output; the operator is registered as 'reader'."""
        return self.input(name='reader')
    
    
def get_file_list(df: pd.DataFrame, start: int, end: int,
                  base_dir: str = DATA_DIR) -> List[str]:
    """Build the file paths for the rows ``start:end`` (positional) of ``df``.

    Each path is ``base_dir / <row.dir> / <row index label>``.

    Args:
        df: label table; must expose a ``dir`` value on every row.
        start: first row position (inclusive), as used by ``iloc``.
        end: last row position (exclusive), as used by ``iloc``.
        base_dir: directory the per-row paths are joined onto.

    Returns:
        A plain list with one path per selected row.
    """
    def _row_to_path(row):
        # With axis=1 every row arrives as a Series; `row.name` is that
        # Series' name (the row's index label), not a column called "name".
        # Presumably the index of `df` holds the video file names -- confirm
        # against `read_labels`.
        return os.path.join(base_dir, row.dir, row.name)

    selected_rows = df.iloc[start:end]
    paths = selected_rows.apply(_row_to_path, axis=1)
    return paths.values.tolist()


def build_data_iter(files: List[str]):
    """Create a PyTorch-facing DALI iterator over the given video files.

    A ``VideoPipe`` with default settings is built for ``files`` and wrapped
    in a ``DALIGenericIterator`` whose single output is exposed as 'images'.

    Args:
        files: paths of the video files to read.

    Returns:
        The ``DALIGenericIterator`` wrapping the built pipeline.
    """
    pipe = VideoPipe(files)
    pipe.build()
    # The iterator's size is the number of samples per epoch. The reader emits
    # frame sequences, and a file does not necessarily map to exactly one
    # sequence, so ask the reader for its epoch size instead of using
    # len(files).
    epoch_size = pipe.epoch_size('reader')
    return DALIGenericIterator([pipe], ['images'], epoch_size)

In [6]:
# Load the label table with the project helper. It is used below as a pandas
# DataFrame whose rows expose a `dir` attribute (see get_file_list).
df = read_labels(DATA_DIR)

In [7]:
# Take the first 100 rows of the label table, turn them into file paths and
# build a DALI iterator over those files.
files = get_file_list(df, 0, 100)
dali_iter = build_data_iter(files)

In [8]:
# Pull one batch from the iterator and keep its 'images' entry for inspection.
# Each yielded item is walked as a collection of dict-like outputs; `image`
# ends up holding the 'images' value of the last one.
for i, data in enumerate(dali_iter):
    for d in data:
        image = d['images']
    break  # only the first batch is needed

In [9]:
# Inspect the shape of the fetched batch and the device it lives on.
image.shape, image.device

(torch.Size([1, 30, 1080, 1920, 3]), device(type='cuda', index=0))

In [22]:
def prepare_imgs(sample, mean=(104, 117, 123)):
    """Turn a channels-last frame stack into a mean-subtracted, channels-first float tensor.

    Args:
        sample: 4-D tensor laid out as (frames, height, width, channels).
            It is never modified, even when it is already float32.
        mean: per-channel values subtracted from every pixel; must broadcast
            against the last (channel) axis of ``sample``.
            NOTE(review): the defaults resemble commonly used BGR channel
            means -- confirm they match the channel order of ``sample``.

    Returns:
        A tuple ``(imgs, scale)``:
        ``imgs`` is a float32 tensor of shape (frames, channels, height,
        width) on the same device as ``sample``;
        ``scale`` is the integer tensor ``[w, h, w, h]``, created without a
        device argument (so it is not moved to ``sample``'s device).

    Raises:
        ValueError: if ``sample`` does not have exactly four dimensions.
    """
    _, h, w, _ = sample.shape

    mean_t = torch.tensor(mean, dtype=torch.float32, device=sample.device)
    # Subtract out of place: `.float()` returns the input tensor itself when
    # it is already float32, so an in-place `-=` would silently overwrite the
    # caller's data.
    imgs = (sample.float() - mean_t).permute(0, 3, 1, 2)

    scale = torch.tensor([w, h, w, h])
    return imgs, scale

In [25]:
# Sanity check: drop the batch dimension and look at the shape of the
# prepared frames.
prepare_imgs(image[0])[0].shape

torch.Size([30, 3, 1080, 1920])

In [21]:
# Scratch check: transpose(3, 1) only swaps dims 1 and 3, which is not the
# same reordering as the permute(0, 3, 1, 2) used in prepare_imgs.
torch.rand(1,2,3,4).transpose(3,1).shape

torch.Size([1, 4, 3, 2])

In [None]:
# (Unfinished scratch cell: a dangling `torch.` expression was removed here
# because it is a SyntaxError when the cell is run.)