In [31]:
%load_ext autoreload
%autoreload 2
import warnings
warnings.filterwarnings('ignore')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


    Note: run in megabouts_dlc

In [32]:
import os
import json

# Data Wrangling
import h5py
import numpy as np
import pandas as pd
from pathlib import Path
import glob
import tables
import flammkuchen as fl

# Computation
from scipy.interpolate import interp1d

#custom functions
from datetime import datetime
import math
import seaborn as sns
import matplotlib.pyplot as plt
import itertools
from scipy.signal.signaltools import correlate

from tqdm import tqdm

In [33]:
def nanzscore(array, axis=0):
    return (array - np.nanmean(array, axis=axis))/np.nanstd(array, axis=axis)

def reduce_to_pi(ar):
    """Reduce angles to the -pi to pi range"""
    return np.mod(ar + np.pi, np.pi * 2) - np.pi

def compute_tailsum(tail_angle):
    pre_tailsum= np.zeros((tail_angle.shape[0], tail_angle.shape[1]))
    for segment in range(tail_angle.shape[1]):
        pre_tailsum[:,segment]= (tail_angle[:, segment] - tail_angle[:,0])

    tailsum= np.sum(pre_tailsum, axis=1)/pre_tailsum.shape[1]
    tailsum = reduce_to_pi(tailsum)
    return tailsum

def moving_average(x, w):
    return np.hstack(([0,0], np.convolve(x, np.ones(w), 'valid') / w))

In [34]:
from megabouts.utils import (
    bouts_category_name,
    bouts_category_name_short,
    bouts_category_color,
    cmp_bouts,
)

In [35]:
bouts_category_name

['approach_swim',
 'slow1',
 'slow2',
 'short_capture_swim',
 'long_capture_swim',
 'burst_swim',
 'J_turn',
 'high_angle_turn',
 'routine_turn',
 'spot_avoidance_turn',
 'O_bend',
 'long_latency_C_start',
 'short_latency_C_start']

In [46]:
from scipy.signal import savgol_filter
def smooth_trace(trace, wnd=5, poly=2):
    return savgol_filter(trace, wnd, poly)  # window size 5, polynomial order 2

# Set up paths

In [47]:
# master_path = Path(r'\\portulab.synology.me\data\Kata\testdata\Raw_Data')

# fish_paths = list(master_path.glob('*f[0-9]*'))
# fish_paths, len(fish_paths)

In [77]:
## Analysed for paper

# master_path = Path(r"\\portulab.synology.me\data\Kata\Data\230307_visstim_2D") #rectangular arena # start from fish 1
# master_path = Path(r"\\portulab.synology.me\data\Kata\Data\22042024_visstim_2D_round")
# master_path = Path(r"\\portulab.synology.me\data\Kata\Data\22042024_visstim_2D_2") #rectangular arena
# master_path = Path(r"\\portulab.synology.me\data\Kata\Data\13052024_visstim_2D_round")
master_path = Path(r"\\portulab.synology.me\data\Kata\Data\14052024_visstim_2D_round")

In [78]:
fish_paths = list(master_path.glob('*f[0-9]*'))
fish_paths, len(fish_paths)

([WindowsPath('//portulab.synology.me/data/Kata/Data/14052024_visstim_2D_round/240514_f0'),
  WindowsPath('//portulab.synology.me/data/Kata/Data/14052024_visstim_2D_round/240514_f1'),
  WindowsPath('//portulab.synology.me/data/Kata/Data/14052024_visstim_2D_round/240514_f2'),
  WindowsPath('//portulab.synology.me/data/Kata/Data/14052024_visstim_2D_round/240514_f3'),
  WindowsPath('//portulab.synology.me/data/Kata/Data/14052024_visstim_2D_round/240514_f4'),
  WindowsPath('//portulab.synology.me/data/Kata/Data/14052024_visstim_2D_round/240514_f5'),
  WindowsPath('//portulab.synology.me/data/Kata/Data/14052024_visstim_2D_round/240514_f6'),
  WindowsPath('//portulab.synology.me/data/Kata/Data/14052024_visstim_2D_round/240514_f7'),
  WindowsPath('//portulab.synology.me/data/Kata/Data/14052024_visstim_2D_round/240514_f8'),
  WindowsPath('//portulab.synology.me/data/Kata/Data/14052024_visstim_2D_round/240514_f9')],
 10)

In [79]:
# out_path = Path(r'\\portulab.synology.me\data\Kata\testdata\Processed_Data')

In [80]:
## Analysed for paper

# out_path = Path(r"\\portulab.synology.me\data\Kata\Processed_Data\230307_visstim_2D_")
# out_path = Path(r"\\portulab.synology.me\data\Kata\Processed_Data\22042024_visstim_2D_round_")
# out_path = Path(r"\\portulab.synology.me\data\Kata\Processed_Data\22042024_visstim_2D_2_")
# out_path = Path(r"\\portulab.synology.me\data\Kata\Processed_Data\13052024_visstim_2D_round_")
out_path = Path(r"\\portulab.synology.me\data\Kata\Processed_Data\14052024_visstim_2D_round_")

In [81]:
fish= 0
fish_id =  fish_paths[fish].name#[:-13]
exp_name = Path(fish_paths[fish]).parts[-2]
# exp_name = 'testfish'
exp_name, fish_id

('14052024_visstim_2D_round', '240514_f0')

In [82]:
print ("{} videos found".format(len(fish_paths)))

10 videos found


In [83]:
mb_thresh =0.6
plot =False
smooth_data = True
n_clust = 13

## Load and mask data

In [84]:
for ind, fish_path in enumerate(tqdm(fish_paths)):
    fish_id =  fish_path.name
    print ('Working on fish {}'.format(fish_id))

    data = fl.load(out_path/'{}_bout_data.h5'.format(fish_id))
    dlc_filter = data['dlc_filter']
    indices = np.arange(0, dlc_filter.shape[0],1)
    print (indices.shape)

    mask =  data['dlc_filter'] ==1
    mask_2 = data['edge_filter'][mask] == True 
    mask_3 = data['mb_proba'][mask][mask_2] >=mb_thresh

    indices = indices[mask][mask_2][mask_3]
    clusters = data['cluster'][:,0][mask][mask_2][mask_3]
    tail_vectors = data['tail_vectors'][mask][mask_2][mask_3]
    tailsums = data['tailsums'] [mask][mask_2][mask_3]
    l_fin = data['fin_angles'][:,0,:][mask][mask_2][mask_3]
    r_fin = data['fin_angles'][:,1,:][mask][mask_2][mask_3]
    tail_vectors = data['tail_vectors'] [mask][mask_2][mask_3]
    print (data['tail_vectors'].shape, data['laterality'].shape)
    laterality = data['laterality'][mask][mask_2][mask_3]
    print (tail_vectors.shape, np.unique(laterality))

    if smooth_data:
        tailsums = np.apply_along_axis(smooth_trace, 1, tailsums)
        l_fin = np.apply_along_axis(smooth_trace, 1, l_fin)
        r_fin = np.apply_along_axis(smooth_trace, 1, r_fin)

    if plot:
        fig, axes = plt.subplots(1,n_clust, figsize=(20, 5), sharex=True, sharey=True)
        axes= axes.ravel()
        for clust in range(n_clust):
            axes[clust].set_title(bouts_category_name[clust])
            axes[clust].plot(tailsums[clusters==clust].T, c=bouts_category_color[clust], alpha=0.3)
        plt.tight_layout()
    
    # Stack them along the new axis (axis=1) to form the shape (trials, 3, timepoints)
    data_combined = np.stack((tailsums, l_fin, r_fin), axis=1)
    print(data_combined.shape)
    
    ## save tensor 
    fl.save(out_path / '{}_indices.h5'.format(fish_id), indices)
    fl.save(out_path / '{}_tensor.h5'.format(fish_id), data_combined)
    fl.save(out_path / '{}_tail_tensor.h5'.format(fish_id), tail_vectors)
    fl.save(out_path/ '{}_bout_laterality.h5'.format(fish_id), laterality)
    print (indices.shape)


  0%|                                                                                           | 0/10 [00:00<?, ?it/s]

Working on fish 240514_f0
(507,)
(507, 50, 10) (507,)
(192, 50, 10) [-1.  1.]
(192, 3, 50)


 10%|████████▎                                                                          | 1/10 [00:00<00:04,  1.95it/s]

(192,)
Working on fish 240514_f1
(426,)
(426, 50, 10) (426,)
(41, 50, 10) [-1.  1.]
(41, 3, 50)


 20%|████████████████▌                                                                  | 2/10 [00:00<00:03,  2.40it/s]

(41,)
Working on fish 240514_f2
(122,)
(122, 50, 10) (122,)
(40, 50, 10) [-1.  1.]
(40, 3, 50)


 30%|████████████████████████▉                                                          | 3/10 [00:01<00:02,  2.75it/s]

(40,)
Working on fish 240514_f3
(595,)
(595, 50, 10) (595,)
(147, 50, 10) [-1.  1.]
(147, 3, 50)


 40%|█████████████████████████████████▏                                                 | 4/10 [00:01<00:02,  2.47it/s]

(147,)
Working on fish 240514_f4
(299,)
(299, 50, 10) (299,)
(272, 50, 10) [-1.  1.]
(272, 3, 50)


 50%|█████████████████████████████████████████▌                                         | 5/10 [00:02<00:02,  2.06it/s]

(272,)
Working on fish 240514_f5
(323,)
(323, 50, 10) (323,)
(175, 50, 10) [-1.  1.]
(175, 3, 50)


 60%|█████████████████████████████████████████████████▊                                 | 6/10 [00:02<00:01,  2.06it/s]

(175,)
Working on fish 240514_f6
(314,)
(314, 50, 10) (314,)
(277, 50, 10) [-1.  1.]
(277, 3, 50)


 70%|██████████████████████████████████████████████████████████                         | 7/10 [00:03<00:01,  1.89it/s]

(277,)
Working on fish 240514_f7
(330,)
(330, 50, 10) (330,)
(279, 50, 10) [-1.  1.]
(279, 3, 50)


 80%|██████████████████████████████████████████████████████████████████▍                | 8/10 [00:03<00:01,  1.80it/s]

(279,)
Working on fish 240514_f8
(618,)
(618, 50, 10) (618,)
(577, 50, 10) [-1.  1.]
(577, 3, 50)


 90%|██████████████████████████████████████████████████████████████████████████▋        | 9/10 [00:05<00:00,  1.42it/s]

(577,)
Working on fish 240514_f9
(632,)
(632, 50, 10) (632,)
(509, 50, 10) [-1.  1.]
(509, 3, 50)


100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:05<00:00,  1.67it/s]

(509,)





In [85]:
tail_vectors[0].shape

(50, 10)