In [2]:
import numpy as np
import cv2
from flytracker.annotating import (
    parse_data,
    setup_loader,
    setup_writer,
    add_frame_info,
    write_ID,
)

from itertools import count

In [3]:
# %% Settings
movie_loc = "../data/testing_data/bruno/seq_1.mp4"
output_loc = "annotated_video.mp4"
df_loc = "../tests/bruno/df_new.hdf"
mapping_folder = "../data/distortion_maps/"
touching_distance = 12

In [4]:
data = parse_data(df_loc)
initial_frame = data[0, 0, 0]
# plus 1 for intiial frame since we plot (n-1, n)
loader, image_size = setup_loader(
    movie_loc, mapping_folder, initial_frame=(initial_frame + 1)
)
writer = setup_writer(output_loc, image_size, fps=30)
mask = np.zeros((*image_size[::-1], 3), dtype=np.uint8)  # TODO: Check different shapes


max_frames = 1000
length = 1

In [11]:
[data[data[:, :, 4] == arena] for arena in np.unique(data[:, :, 4])]

[array([[    99,      0,    869,    456,      0],
        [    99,      1,    974,    342,      0],
        [    99,      2,    958,    336,      0],
        ...,
        [108053,      7,    968,    369,      0],
        [108053,      8,    895,    306,      0],
        [108053,      9,    785,    436,      0]]),
 array([[    99,     10,    465,    819,      1],
        [    99,     11,    540,    633,      1],
        [    99,     12,    552,    627,      1],
        ...,
        [108053,     17,    536,    824,      1],
        [108053,     18,    591,    632,      1],
        [108053,     19,    522,    811,      1]]),
 array([[    99,     20,    501,    214,      2],
        [    99,     21,    374,    322,      2],
        [    99,     22,    451,    232,      2],
        ...,
        [108053,     27,    378,    241,      2],
        [108053,     28,    441,    475,      2],
        [108053,     29,    370,    325,      2]]),
 array([[    99,     30,    772,    614,      3],
     

In [8]:
local_data = data[50:10]
print(local_data.shape)

(50, 40, 5)


In [53]:
%%timeit
fly_locs = [list(zip(local_data[:, fly_idx, 2], local_data[:, fly_idx, 3])) for fly_idx in np.arange(40)]

379 µs ± 2.55 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [56]:
local_data.shape

(50, 40, 5)

In [58]:
from seaborn import color_palette

In [72]:
n_arenas 
palette = [tuple(np.array(palette[0]) * 255) for color in color_palette("Paired")]
palette = [palette for _ in np.arange(n_arenas)]

In [73]:
palette

[(166.0, 206.0, 227.0),
 (166.0, 206.0, 227.0),
 (166.0, 206.0, 227.0),
 (166.0, 206.0, 227.0),
 (166.0, 206.0, 227.0),
 (166.0, 206.0, 227.0),
 (166.0, 206.0, 227.0),
 (166.0, 206.0, 227.0),
 (166.0, 206.0, 227.0),
 (166.0, 206.0, 227.0),
 (166.0, 206.0, 227.0),
 (166.0, 206.0, 227.0)]

In [None]:
color_fn = lambda fly_idx, arena: tuple(
    color * 255 for color in palette[idx % len(palette)]
)

In [76]:
data.shape

(107955, 40, 5)

In [117]:
original_mapping = data[0, :, [1, 4]].T
ordering = np.argsort(original_mapping[:, 1])[:, None]
new_ID = np.array([np.argmax(ordering == old_ID) for old_ID in np.arange(40)])[:, None]
full_mapping = np.concatenate([original_mapping, ordering, new_ID], axis=1)

print(full_mapping[np.argsort(original_mapping[:, 1]), :])

[[ 0  0  0  0]
 [25  0 21  1]
 [23  0 28  2]
 [22  0 29  3]
 [ 5  0  6  4]
 [ 6  0 20  5]
 [20  0 13  6]
 [ 8  0 11  7]
 [11  0 18  8]
 [12  0 17  9]
 [27  1  9 10]
 [18  1  4 11]
 [17  1  2 12]
 [16  1 14 13]
 [15  1 19 14]
 [19  1  3 15]
 [14  1 15 16]
 [ 2  1 23 17]
 [ 4  1  5 18]
 [ 3  1 22 19]
 [13  2 16 20]
 [30  2 26 21]
 [29  2 24 22]
 [28  2 10 23]
 [ 1  2 25 24]
 [21  2 30 25]
 [ 7  2  8 26]
 [ 9  2 12 27]
 [10  2 27 28]
 [24  2  1 29]
 [26  3  7 30]
 [38  3 37 31]
 [31  3 38 32]
 [32  3 31 33]
 [33  3 32 34]
 [34  3 33 35]
 [35  3 34 36]
 [36  3 35 37]
 [37  3 36 38]
 [39  3 39 39]]


In [4]:
import pandas as pd

In [4]:
df = pd.read_hdf(df_loc, key="df")
df = df.sort_values(by=["frame", "ID"])

#original_mapping = df.query(f"frame == {df.frame.min()}")[["ID, arena"]]


#data[0, :, [1, 4]].T
#ordering = np.argsort(original_mapping[:, 1])[:, None]
#new_ID = np.array([np.argmax(ordering == old_ID) for old_ID in np.arange(40)])[:, None]
#full_mapping = np.concatenate([original_mapping, ordering, new_ID], axis=1)



In [132]:
df['ID'] = np.concatenate([new_ID.squeeze() for _ in np.arange(df.frame.unique().size)], axis=0)

In [134]:
df = df.sort_values(by=["frame", "ID"])

In [138]:
df.query('arena == 0').ID.unique()

array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

In [139]:
df.query('arena == 1').ID.unique()

array([10, 11, 12, 13, 14, 15, 16, 17, 18, 19])

In [99]:
df = pd.read_hdf(df_loc, key="df")
df = df.sort_values(by = ["frame", "arena"])

n_flies = df.ID.unique().size
n_frames = df.frame.unique().size

df["ID"] = np.tile(np.arange(n_flies), n_frames)
df = df.sort_values(by=["frame", "ID"])

In [122]:
df = pd.read_hdf(df_loc, key="df")
df = df.sort_values(by=["arena"])

In [123]:
df

Unnamed: 0,frame,ID,x,y,arena
0,99.0,0.0,869.101074,455.739838,0.0
2790208,69854.0,8.0,741.777778,369.422222,0.0
2790211,69854.0,11.0,754.416667,399.250000,0.0
2790212,69854.0,12.0,964.739130,269.739130,0.0
1018652,25565.0,12.0,980.951220,331.243902,0.0
...,...,...,...,...,...
914438,22959.0,38.0,817.526316,839.368421,3.0
914439,22959.0,39.0,837.095238,578.380952,3.0
2946519,73761.0,39.0,769.256410,732.205128,3.0
914431,22959.0,31.0,810.320000,592.480000,3.0


In [124]:
arena_means = np.stack([np.mean(arena_df[["x", "y"]].to_numpy(), axis=0) for _, arena_df in df.query(f"frame == {df.frame.min()}").groupby("arena")], axis=0)
arena_means = np.around(arena_means, decimals=-2)

In [125]:
new_arenas = np.array([np.argmax(np.lexsort((arena_means[:, 0], arena_means[:, 1])) == idx) for idx in np.arange(arena_means.shape[0])])

In [126]:
#df = df.drop(labels="arena", axis=1)
#df = df.drop(labels="arena", axis=1)
df["arena"] = np.repeat(new_arenas, int(df.shape[0] / arena_means.shape[0]))


In [127]:
df = df.sort_values(by=["frame", "arena"])


In [128]:
df[:40]

Unnamed: 0,frame,ID,x,y,arena
29,99.0,29.0,368.687866,359.141541,0
28,99.0,28.0,412.586395,397.629883,0
24,99.0,24.0,516.853821,237.85675,0
21,99.0,21.0,376.755707,279.859833,0
30,99.0,30.0,536.85907,232.687378,0
1,99.0,1.0,500.701599,213.673691,0
13,99.0,13.0,521.86261,218.176682,0
10,99.0,10.0,509.940399,226.678665,0
9,99.0,9.0,450.666779,231.777039,0
7,99.0,7.0,373.817535,322.457062,0


In [19]:
df.arena.unique()

array([1, 2, 0, 3])