In [6]:
import torch
from flytracker.utils import FourArenasQRCodeMask
from torch.utils.data import DataLoader
from itertools import takewhile
import matplotlib.pyplot as plt
from torchvision.transforms.functional import rgb_to_grayscale, to_tensor
import cv2 as cv
from flytracker.tracking import blob_detector_localization
import numpy as np

from flytracker.tracking import kmeans_torch
from torchvision.io import VideoReader
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
torch.cuda.is_available()

True

In [11]:
class VideoDataset(torch.utils.data.IterableDataset):
    def __init__(self, path, mask):
        super().__init__()
        self.capture = cv.VideoCapture(path)
        self.mask = torch.tensor(mask, dtype=torch.bool)
        
    def __iter__(self):
        return self

    def __next__(self) -> torch.Tensor:
        # Loading image
        succes, image = self.capture.read()
        if succes is False:
            raise StopIteration
            
        image = torch.tensor(image)  
        image = torch.movedim(torch.tensor(image), -1, 0) # first axis needs to be channels
        image = rgb_to_grayscale(image).squeeze() 
        image = torch.where(self.mask, image, torch.tensor(255, dtype=torch.uint8))
        return image

In [12]:
mask = FourArenasQRCodeMask().mask
path = "/home/gert-jan/Documents/flyTracker/data/movies/4arenas_QR.h264"

dataset = VideoDataset(path, mask)
loader = DataLoader(dataset, batch_size=1, pin_memory=True)

In [13]:
%%time
for batch_idx, batch in enumerate(loader):
    batch = batch.cuda(non_blocking=True)
    if batch_idx % 100 == 0:
        print(f"Loaded {batch_idx}, {batch.device}")
    if batch_idx == 1000:
        break

  image = torch.movedim(torch.tensor(image), -1, 0) # first axis needs to be channels


Loaded 0, cuda:0
Loaded 100, cuda:0
Loaded 200, cuda:0
Loaded 300, cuda:0
Loaded 400, cuda:0
Loaded 500, cuda:0
Loaded 600, cuda:0
Loaded 700, cuda:0
Loaded 800, cuda:0
Loaded 900, cuda:0
Loaded 1000, cuda:0
CPU times: user 48 s, sys: 67.7 ms, total: 48 s
Wall time: 3.16 s


In [77]:
loader.dataset.mask

tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])

In [33]:
batch.shape

torch.Size([1, 1080, 1280])

In [63]:
reader = VideoReader(path)

In [64]:
reader.get_metadata()

{'video': {'duration': [-7686143364045.646], 'fps': [25.0]},
 'audio': {'duration': [], 'framerate': []}}

In [65]:
reader.next()

AttributeError: 'VideoReader' object has no attribute 'next'

In [8]:
for idx, frame in enumerate(reader):
    print(frame)
    if idx ==1:
        break

{'data': tensor([[[255, 255, 255,  ..., 255, 255, 255],
         [255, 255, 255,  ..., 255, 255, 255],
         [255, 255, 255,  ..., 255, 255, 255],
         ...,
         [251, 251, 251,  ..., 239, 239, 239],
         [246, 246, 246,  ..., 228, 228, 228],
         [255, 255, 255,  ..., 253, 253, 253]],

        [[255, 255, 255,  ..., 255, 255, 255],
         [255, 255, 255,  ..., 255, 255, 255],
         [255, 255, 255,  ..., 255, 255, 255],
         ...,
         [255, 255, 255,  ..., 239, 239, 239],
         [248, 248, 248,  ..., 228, 228, 228],
         [255, 255, 255,  ..., 253, 253, 253]],

        [[255, 255, 255,  ..., 255, 255, 255],
         [255, 255, 255,  ..., 255, 255, 255],
         [255, 255, 255,  ..., 255, 255, 255],
         ...,
         [250, 250, 250,  ..., 236, 236, 236],
         [245, 245, 245,  ..., 225, 225, 225],
         [254, 254, 254,  ..., 250, 250, 250]]], dtype=torch.uint8), 'pts': 0.0}
{'data': tensor([[[255, 255, 255,  ..., 255, 255, 255],
         

In [15]:
frame['data'].shape

torch.Size([3, 1080, 1280])

In [13]:
frame['pts']

0.04

In [18]:
next(reader)

{'data': tensor([[[255, 255, 255,  ..., 255, 255, 255],
          [255, 255, 255,  ..., 255, 255, 255],
          [255, 255, 255,  ..., 255, 255, 255],
          ...,
          [249, 249, 249,  ..., 238, 238, 238],
          [243, 243, 243,  ..., 227, 227, 227],
          [253, 253, 253,  ..., 252, 252, 252]],
 
         [[255, 255, 255,  ..., 255, 255, 255],
          [255, 255, 255,  ..., 255, 255, 255],
          [255, 255, 255,  ..., 255, 255, 255],
          ...,
          [254, 254, 254,  ..., 238, 238, 238],
          [245, 245, 245,  ..., 227, 227, 227],
          [255, 255, 255,  ..., 252, 252, 252]],
 
         [[255, 255, 255,  ..., 255, 255, 255],
          [255, 255, 255,  ..., 255, 255, 255],
          [255, 255, 255,  ..., 255, 255, 255],
          ...,
          [248, 248, 248,  ..., 235, 235, 235],
          [242, 242, 242,  ..., 224, 224, 224],
          [252, 252, 252,  ..., 249, 249, 249]]], dtype=torch.uint8),
 'pts': 0.16}