In [1]:
import torch
from flytracker.utils import FourArenasQRCodeMask
from torch.utils.data import DataLoader
from itertools import takewhile
import matplotlib.pyplot as plt
from torchvision.transforms.functional import rgb_to_grayscale, to_tensor
import cv2 as cv
from flytracker.tracking import blob_detector_localization
import numpy as np

from flytracker.tracking import kmeans_torch

%load_ext autoreload
%autoreload 2

In [2]:
torch.cuda.is_available()

True

In [3]:
class VideoDataset(torch.utils.data.IterableDataset):
    def __init__(self, path, mask):
        super().__init__()
        self.capture = cv.VideoCapture(path)
        self.mask = torch.tensor(mask, dtype=torch.bool)
        
    def __iter__(self):
        return self

    def __next__(self) -> torch.Tensor:
        # Loading image
        succes, image = self.capture.read()
        if succes is False:
            raise StopIteration
            
        image = torch.tensor(image)  
        image = torch.movedim(torch.tensor(image), -1, 0) # first axis needs to be channels
        image = rgb_to_grayscale(image).squeeze() 
        image = torch.where(self.mask, image, torch.tensor(255, dtype=torch.uint8))
        return image

In [6]:
mask = FourArenasQRCodeMask().mask
path = "/home/gert-jan/Documents/flyTracker/data/movies/4arenas_QR.h264"

dataset = VideoDataset(path, mask)
loader = DataLoader(dataset, batch_size=1, pin_memory=True)

In [7]:
%%time
for batch_idx, batch in enumerate(loader):
    batch = batch.cuda(non_blocking=True)
    if batch_idx % 100 == 0:
        print(f"Loaded {batch_idx}, {batch.device}")
    if batch_idx == 1000:
        break

  image = torch.movedim(torch.tensor(image), -1, 0) # first axis needs to be channels


Loaded 0, cuda:0
Loaded 100, cuda:0
Loaded 200, cuda:0
Loaded 300, cuda:0
Loaded 400, cuda:0
Loaded 500, cuda:0
Loaded 600, cuda:0
Loaded 700, cuda:0
Loaded 800, cuda:0
Loaded 900, cuda:0
Loaded 1000, cuda:0
CPU times: user 59.9 s, sys: 75.7 ms, total: 1min
Wall time: 3.54 s


In [11]:
def initialize(loader, n_frames=100):
    """Find flies using blob detector and
    calulate number of flies."""
    n_blobs = []
    
    for frame_idx, frame in enumerate(loader):
        locations = blob_detector_localization(frame.numpy().squeeze())
        n_blobs.append(locations.shape[0])
        if len(n_blobs) >= n_frames:
            n_flies = int(np.median(n_blobs))
            if n_blobs[-1] == n_flies:
                break
    # pluse on cause the next one is the first
    initial_frame = frame_idx + 1
    locations = torch.tensor(locations[:, [1, 0]], dtype=torch.float32).to('cuda')
    return n_flies, locations, initial_frame

In [12]:
def localize(loader, init, n_frames=900, threshold=120):
    """Find flies using blob detector and
    calulate number of flies."""
    data = [init]
    for frame_idx, frame in enumerate(loader):
        frame = frame.cuda(non_blocking=True)
        fly_pixels = torch.nonzero(frame.squeeze() < threshold).type(torch.float32)
        data.append(kmeans_torch(fly_pixels, data[-1]))
        
        if frame_idx == n_frames:
            break
    return data

In [14]:
%%time
mask = FourArenasQRCodeMask().mask
path = "/home/gert-jan/Documents/flyTracker/data/movies/4arenas_QR.h264"

dataset = VideoDataset(path, mask)
loader = DataLoader(dataset, batch_size=1, pin_memory=True)
n_flies, initial_locations, initial_frame = initialize(loader, 100)
locations = localize(loader, initial_locations, )

  image = torch.movedim(torch.tensor(image), -1, 0) # first axis needs to be channels


CPU times: user 1min 30s, sys: 114 ms, total: 1min 30s
Wall time: 5.48 s


In [21]:
from flytracker.tracker import run
from itertools import accumulate
from flytracker.tracking import hungarian

In [34]:

len(list(accumulate(locations, func=lambda x, y: hungarian(x.cpu(), y.cpu()))))

902

In [40]:
locs = torch.stack(locations, axis=0).cpu().numpy()

In [37]:
locs.shape

(902, 40, 2)

In [43]:
%%time
np.concatenate(list(accumulate(locs, func=lambda x, y: hungarian(y, x))), axis=0)

CPU times: user 56.9 ms, sys: 12 µs, total: 56.9 ms
Wall time: 56.5 ms


array([[890.16785, 414.62177],
       [872.3016 , 474.09927],
       [862.12463, 458.7108 ],
       ...,
       [205.92154, 565.0982 ],
       [310.24493, 290.38773],
       [214.65854, 546.1219 ]], dtype=float32)

In [None]:
lambda x, y: hungarian(x.cpu(), y.cpu()), 