In [44]:
%load_ext autoreload 
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [45]:
import numpy as np
import pandas as pd
import h5py
from pathlib import Path

import matplotlib.pyplot as plt
import flammkuchen as fl

import random
from tqdm import tqdm

In [46]:
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cpu


In [47]:
from megabouts_helper import labels_cat, color

In [48]:
import matplotlib.gridspec as gridspec
from cycler import cycler

from megabouts.tracking_data import TrackingConfig, FullTrackingData
from megabouts.pipeline import FullTrackingPipeline
from megabouts.utils import (
    bouts_category_name,
    bouts_category_name_short,
    bouts_category_color,
    cmp_bouts,
)

In [49]:
def compute_angle_between_vect_tail(v1, v2):
    dot = np.einsum('ijk,ijk->ij',[v1,v1,v2],[v2,v1,v2])
    cos_= dot[0,:]
    sin_= np.cross(v1,v2)
    angle_= np.arctan2(sin_,cos_)
    return angle_


def compute_angle_between_vect(u,v):
    u = u/np.linalg.norm(u)
    v = v/np.linalg.norm(v)
    return np.arctan2(u[0]*v[1]-u[1]*v[0],u[0]*v[0]+u[1]*v[1])

In [50]:
def compute_body_angle(head_x, head_y, body_x, body_y):
    """
    Computes the angle between two points in 2D space.
    
    Parameters:
    head_x, head_y: Coordinates of the first point (head).
    body_x, body_y: Coordinates of the second point (body).
    
    Returns:
    angles_radians: The angle in radians.
    angles_degrees: The angle in degrees.
    """
    # Calculate the differences in the x and y coordinates
    delta_x = body_x - head_x
    delta_y = body_y - head_y

    # Calculate the angle using numpy's arctan2
    angles_radians = np.arctan2(delta_y, delta_x)

    # Convert the angle from radians to degrees
    angles_degrees = np.degrees(angles_radians)

    return angles_radians, angles_degrees


In [51]:
### compute fin and body angles
def fin_preprocess(df, mid_headx, mid_heady, body_x, body_y):
    ##Fin angle computatright
    #Fin angle computatright
    right_fin_tip_x =  df['right_fin_tip'].values[:, 0].astype('float')
    right_fin_tip_y =  df['right_fin_tip'].values[:, 1].astype('float')
    right_fin_base_x =  df['right_fin_base'].values[:, 0].astype('float')
    right_fin_base_y =  df['right_fin_base'].values[:, 1].astype('float')

    left_fin_tip_x =  df['left_fin_tip'].values[:, 0].astype('float')
    left_fin_tip_y =  df['left_fin_tip'].values[:, 1].astype('float')
    left_fin_base_x =   df['left_fin_base'].values[:, 0].astype('float')
    left_fin_base_y =   df['left_fin_base'].values[:, 1].astype('float')

    # lets make all the vectors
    a = left_fin_base_x-left_fin_tip_x
    b = left_fin_base_y-left_fin_tip_y
    left_fin_vect = np.array([b,-a])

    a = right_fin_base_x-right_fin_tip_x 
    b = right_fin_base_y-right_fin_tip_y
    right_fin_vect = np.array([-b,a])

    body_vect = np.vstack((mid_headx -body_x , mid_heady - body_y)) 

    ## Compute angles between vectors
    left_fin_angle =  compute_angle_between_vect(left_fin_vect, body_vect)
    right_fin_angle =  compute_angle_between_vect(right_fin_vect, body_vect)

    #nan movement artifacts
    left_fin_angle = left_fin_angle - left_fin_angle[0]
    right_fin_angle = right_fin_angle - right_fin_angle[0]
    left_fin_angle[abs(np.diff(left_fin_angle, prepend=[0])) >= 2] = 0 #np.nan #np.pi
    right_fin_angle[abs(np.diff(right_fin_angle, prepend=[0])) >= 2] = 0 #np.nan #np.pi
    
    return left_fin_vect, right_fin_vect, left_fin_angle, right_fin_angle


# Load Bouts

In [52]:
# master_path = Path(r'\\portulab.synology.me\data\Kata\testdata\Raw_Data')

# fish_paths = list(master_path.glob('*f[0-9]*'))
# fish_paths, len(fish_paths)


In [53]:
## Analysed for paper

# master_path = Path(r"\\portulab.synology.me\data\Kata\Data\230307_visstim_2D") #rectangular arena # start from fish 1
# master_path = Path(r"\\portulab.synology.me\data\Kata\Data\22042024_visstim_2D_round")
master_path = Path(r"\\portulab.synology.me\data\Kata\Data\22042024_visstim_2D_2") #rectangular arena
# master_path = Path(r"\\portulab.synology.me\data\Kata\Data\13052024_visstim_2D_round")
# master_path = Path(r"\\portulab.synology.me\data\Kata\Data\14052024_visstim_2D_round")

In [54]:
fish_paths = list(master_path.glob('*f[0-9]*'))
fish_paths, len(fish_paths)

([WindowsPath('//portulab.synology.me/data/Kata/Data/22042024_visstim_2D_2/240422_f0'),
  WindowsPath('//portulab.synology.me/data/Kata/Data/22042024_visstim_2D_2/240422_f0_1'),
  WindowsPath('//portulab.synology.me/data/Kata/Data/22042024_visstim_2D_2/240422_f0_2'),
  WindowsPath('//portulab.synology.me/data/Kata/Data/22042024_visstim_2D_2/240422_f1_1'),
  WindowsPath('//portulab.synology.me/data/Kata/Data/22042024_visstim_2D_2/240422_f2'),
  WindowsPath('//portulab.synology.me/data/Kata/Data/22042024_visstim_2D_2/240422_f3'),
  WindowsPath('//portulab.synology.me/data/Kata/Data/22042024_visstim_2D_2/240422_f4'),
  WindowsPath('//portulab.synology.me/data/Kata/Data/22042024_visstim_2D_2/240422_f5'),
  WindowsPath('//portulab.synology.me/data/Kata/Data/22042024_visstim_2D_2/240422_f6'),
  WindowsPath('//portulab.synology.me/data/Kata/Data/22042024_visstim_2D_2/240422_f7'),
  WindowsPath('//portulab.synology.me/data/Kata/Data/22042024_visstim_2D_2/240422_f8'),
  WindowsPath('//portulab.

In [55]:
fish= 3
fish_path = fish_paths[fish]
fish_id =  fish_paths[fish].name#[:-13]
exp_name = Path(fish_paths[fish]).parts[-2]
# exp_name = 'testfish'
exp_name, fish_id

('22042024_visstim_2D_2', '240422_f1_1')

In [56]:
# out_path = Path(r'\\portulab.synology.me\data\Kata\testdata\Processed_Data')

In [57]:
## Analysed for paper

# out_path = Path(r"\\portulab.synology.me\data\Kata\Processed_Data\230307_visstim_2D_")
# out_path = Path(r"\\portulab.synology.me\data\Kata\Processed_Data\22042024_visstim_2D_round_")
out_path = Path(r"\\portulab.synology.me\data\Kata\Processed_Data\22042024_visstim_2D_2_")
# out_path = Path(r"\\portulab.synology.me\data\Kata\Processed_Data\13052024_visstim_2D_round_")
# out_path = Path(r"\\portulab.synology.me\data\Kata\Processed_Data\14052024_visstim_2D_round_")


### Load DLC

In [58]:
fps=200
mm_per_unit = 1/70
N_seg = 10

In [59]:

# for ind, fish_path in enumerate(tqdm(fish_paths[1:])):
for ind, fish_path in enumerate(tqdm(fish_paths)):
    fish_id =  fish_path.name
    print ('Working on fish {}'.format(fish_id))
    df = pd.read_csv(out_path/ '{}_DLC_mod.csv'.format(fish_id), header=[0, 1])
    print(f'{df.shape[0]/(fps*60)} minutes at {fps} fps')
    print('working on {} frames'.format(df.shape[0]))
    
    #Extract angles
    body_x = df.body.values[:, 0].astype('float')
    body_y = df.body.values[:, 1].astype('float')
    
    # Compute head and tail coordinates and convert to mm
    tail_x_col = [f'tail_{i}' for i in range(N_seg)]
    tail_y_col = [f'tail_{i}' for i in range(N_seg)]
    tail_x = np.array([df[x].iloc[:, 0].values.astype('float') for x in tail_x_col]) * mm_per_unit
    tail_y = np.array([df[x].iloc[:, 1].values.astype('float') for x in tail_y_col]) * mm_per_unit
    head_x = df.mid_head.values[:, 0].astype('float') * mm_per_unit
    head_y = df.mid_head.values[:, 1].astype('float') * mm_per_unit

    # compute body angle
    body_x_ = np.asarray(df.body.values[:, 0].astype('float'))
    body_y_ = np.asarray(df.body.values[:, 1].astype('float'))
    head_x_ = np.asarray(df.mid_head.values[:, 0].astype('float'))
    head_y_ = np.asarray(df.mid_head.values[:, 1].astype('float'))
    body_angle, angles_degrees = compute_body_angle(head_x_, head_y_, body_x_, body_y_)

    # compute fin angles
    left_fin_vect, right_fin_vect, left_fin_angle, right_fin_angle = fin_preprocess(df, head_x_, head_y_, body_x_, body_y_)

    # Load eye stuff
    eye_angles = fl.load(fish_path/'eye_angles.h5')['eye_angles'] #for the hdf5 way of saving dict needs ['eye_angles']
    vergence = fl.load(fish_path/'eye_rot.h5')['eye_rot']
    rotation_eye = fl.load(fish_path/'eye_verg.h5')['eye_verg']
    eye_coords = fl.load(fish_path/'eye_coords.h5')['eye_coords']
    
    left_eye_angle = np.rad2deg(eye_angles[:,0])
    right_eye_angle = np.rad2deg(eye_angles[:,1])
    rotation_eye = np.rad2deg(rotation_eye)


    ### MB pipeline
    # Load data and set tracking configuration
    tracking_cfg = TrackingConfig(fps=fps, tracking="full_tracking")

    # Create FullTrackingData object
    tracking_data = FullTrackingData.from_keypoints(
        head_x=head_x, head_y=head_y, tail_x=tail_x.T, tail_y=tail_y.T)

    print (head_x.shape, head_y.shape, tail_x.shape, tail_y.shape)

    pipeline = FullTrackingPipeline(tracking_cfg, exclude_CS=True)
    pipeline.segmentation_cfg.threshold = 20
    pipeline.tail_preprocessing_cfg.tail_speed_filter_ms = 50

    ethogram, bouts, segments, tail, traj = pipeline.run(tracking_data)

    ### Save object
    megabouts_res = dict({
    'segments_on': np.asarray(segments.onset),
    'segments_off': np.asarray(segments.offset),
    'cluster_n_vector': np.asarray(ethogram.df[("bout", "cat")].values),
    'clusters':  np.asarray(bouts.df.label.category),
    'laterality' :np.asarray(bouts.df.label.sign),
    'proba' :np.asarray(bouts.df.label.proba),
    
    'clean_data_tail':np.asarray(ethogram.df["tail_angle"].values),
    'body_angle' :np.asarray(body_angle),
    'body_angle_rad' :np.asarray(angles_degrees),
    'head_angle_mb' :np.asarray(traj.yaw_smooth),
    'duration' : np.asarray(bouts.df.location.offset - bouts.df.location.onset),
    'bouts_df': bouts.df, 
    'ethogram_df': ethogram.df,

    'fin_angles': np.asarray([left_fin_angle, right_fin_angle]),
    'eye_angles': np.asarray([left_eye_angle, right_eye_angle]), 
    'vergence': np.asarray(vergence),
    'rotation': np.asarray(rotation_eye), 
    
    })
    fl.save(out_path/'{}_megabouts_res.h5'.format(fish_id), megabouts_res)

    


  0%|                                                                                           | 0/12 [00:00<?, ?it/s]

Working on fish 240422_f0
9.354083333333334 minutes at 200 fps
working on 112249 frames
(112249,) (112249,) (10, 112249) (10, 112249)


  net.load_state_dict(torch.load(transformer_weights_path,map_location=torch.device(self.device)))
  'cluster_n_vector': np.asarray(ethogram.df[("bout", "cat")].values),
  8%|██████▉                                                                            | 1/12 [00:36<06:42, 36.56s/it]

Working on fish 240422_f0_1
9.003 minutes at 200 fps
working on 108036 frames
(108036,) (108036,) (10, 108036) (10, 108036)


  net.load_state_dict(torch.load(transformer_weights_path,map_location=torch.device(self.device)))
  'cluster_n_vector': np.asarray(ethogram.df[("bout", "cat")].values),
 17%|█████████████▊                                                                     | 2/12 [01:09<05:46, 34.68s/it]

Working on fish 240422_f0_2
8.945666666666666 minutes at 200 fps
working on 107348 frames
(107348,) (107348,) (10, 107348) (10, 107348)


  net.load_state_dict(torch.load(transformer_weights_path,map_location=torch.device(self.device)))
  'cluster_n_vector': np.asarray(ethogram.df[("bout", "cat")].values),
 25%|████████████████████▊                                                              | 3/12 [01:43<05:06, 34.08s/it]

Working on fish 240422_f1_1
8.896666666666667 minutes at 200 fps
working on 106760 frames
(106760,) (106760,) (10, 106760) (10, 106760)


  net.load_state_dict(torch.load(transformer_weights_path,map_location=torch.device(self.device)))
  'cluster_n_vector': np.asarray(ethogram.df[("bout", "cat")].values),
 33%|███████████████████████████▋                                                       | 4/12 [02:16<04:30, 33.83s/it]

Working on fish 240422_f2
8.878916666666667 minutes at 200 fps
working on 106547 frames
(106547,) (106547,) (10, 106547) (10, 106547)


  net.load_state_dict(torch.load(transformer_weights_path,map_location=torch.device(self.device)))
  'cluster_n_vector': np.asarray(ethogram.df[("bout", "cat")].values),
 42%|██████████████████████████████████▌                                                | 5/12 [02:50<03:55, 33.66s/it]

Working on fish 240422_f3
7.849333333333333 minutes at 200 fps
working on 94192 frames
(94192,) (94192,) (10, 94192) (10, 94192)


  net.load_state_dict(torch.load(transformer_weights_path,map_location=torch.device(self.device)))
  'cluster_n_vector': np.asarray(ethogram.df[("bout", "cat")].values),
 50%|█████████████████████████████████████████▌                                         | 6/12 [03:18<03:11, 31.91s/it]

Working on fish 240422_f4
9.144166666666667 minutes at 200 fps
working on 109730 frames
(109730,) (109730,) (10, 109730) (10, 109730)


  net.load_state_dict(torch.load(transformer_weights_path,map_location=torch.device(self.device)))
  'cluster_n_vector': np.asarray(ethogram.df[("bout", "cat")].values),
 58%|████████████████████████████████████████████████▍                                  | 7/12 [03:52<02:42, 32.45s/it]

Working on fish 240422_f5
6.742416666666666 minutes at 200 fps
working on 80909 frames
(80909,) (80909,) (10, 80909) (10, 80909)


  net.load_state_dict(torch.load(transformer_weights_path,map_location=torch.device(self.device)))
  'cluster_n_vector': np.asarray(ethogram.df[("bout", "cat")].values),
 67%|███████████████████████████████████████████████████████▎                           | 8/12 [04:17<02:00, 30.07s/it]

Working on fish 240422_f6
7.639666666666667 minutes at 200 fps
working on 91676 frames
(91676,) (91676,) (10, 91676) (10, 91676)


  net.load_state_dict(torch.load(transformer_weights_path,map_location=torch.device(self.device)))
  'cluster_n_vector': np.asarray(ethogram.df[("bout", "cat")].values),
 75%|██████████████████████████████████████████████████████████████▎                    | 9/12 [04:45<01:28, 29.46s/it]

Working on fish 240422_f7
8.353333333333333 minutes at 200 fps
working on 100240 frames
(100240,) (100240,) (10, 100240) (10, 100240)


  net.load_state_dict(torch.load(transformer_weights_path,map_location=torch.device(self.device)))
  'cluster_n_vector': np.asarray(ethogram.df[("bout", "cat")].values),
 83%|████████████████████████████████████████████████████████████████████▎             | 10/12 [05:16<01:00, 30.15s/it]

Working on fish 240422_f8
8.89475 minutes at 200 fps
working on 106737 frames
(106737,) (106737,) (10, 106737) (10, 106737)


  net.load_state_dict(torch.load(transformer_weights_path,map_location=torch.device(self.device)))
 83%|████████████████████████████████████████████████████████████████████▎             | 10/12 [05:48<01:09, 34.81s/it]


ValueError: could not broadcast input array from shape (10,0) into shape (10,40)

In [None]:
print ('done')

In [None]:
megabouts_res.keys()

In [None]:
np.unique(megabouts_res['clusters'])