In [1]:
%load_ext autoreload
%autoreload 2
import warnings
warnings.filterwarnings('ignore')


In [2]:
import os
import json

# Data Wrangling
import h5py
import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
import glob
import flammkuchen as fl
import shutil

import seaborn as sns
import matplotlib as mpl

from tqdm import tqdm

## Custom functions from helper file

In [3]:
from feature_extractuion_helper import smooth_trace, reshape_feature_array, get_eye_max, invert_tail, compute_directionality
from feature_extractuion_helper import compute_vigor, compute_bout_dur, get_vigor_stats, compute_oscillations
from feature_extractuion_helper import time_of_first_peak, compute_leading_fin, compute_corr_lag, calulate_frequency_with_peaks
from feature_extractuion_helper import extract_wavelet_features

## Read in Data

In [8]:
master_path = Path(r'\\portulab.synology.me\data\Kata\testdata\Raw_Data')
fish_paths = list(master_path.glob('*f[0-9]*'))
fish_paths

[WindowsPath('//portulab.synology.me/data/Kata/Data/230307_visstim_2D/105234_f0'),
 WindowsPath('//portulab.synology.me/data/Kata/Data/230307_visstim_2D/105646_f0'),
 WindowsPath('//portulab.synology.me/data/Kata/Data/230307_visstim_2D/110659_f0'),
 WindowsPath('//portulab.synology.me/data/Kata/Data/230307_visstim_2D/111701_f0'),
 WindowsPath('//portulab.synology.me/data/Kata/Data/230307_visstim_2D/112901_f1'),
 WindowsPath('//portulab.synology.me/data/Kata/Data/230307_visstim_2D/113916_f2'),
 WindowsPath('//portulab.synology.me/data/Kata/Data/230307_visstim_2D/114946_f2'),
 WindowsPath('//portulab.synology.me/data/Kata/Data/230307_visstim_2D/120328_f3'),
 WindowsPath('//portulab.synology.me/data/Kata/Data/230307_visstim_2D/120350_f3'),
 WindowsPath('//portulab.synology.me/data/Kata/Data/230307_visstim_2D/123953_f4'),
 WindowsPath('//portulab.synology.me/data/Kata/Data/230307_visstim_2D/125639_f5'),
 WindowsPath('//portulab.synology.me/data/Kata/Data/230307_visstim_2D/130839_f5'),
 Win

In [9]:
fish= 0
fish_id =  fish_paths[fish].name
exp_name = Path(fish_paths[fish]).parts[-2]
exp_name = 'testfish'
fish_id, exp_name


('105646_f0', '230307_visstim_2D')

In [10]:
out_path = Path(r'\\portulab.synology.me\data\Kata\testdata\Processed_Data')
save_data_path = out_path

# Pooled loop

In [18]:
for ind, fish_path in enumerate(tqdm(fish_paths[1:])):
    fish_id =  fish_path.name
    print ('Working on fish {}'.format(fish_id))

    # try:
    bout_data = fl.load(out_path/'{}_bout_data.h5'.format(fish_id))
    traces =  fl.load(out_path / '{}_tensor.h5'.format(fish_id))
    tail = traces[:, 0, :]
    l_fin = traces[:, 1, :]
    r_fin = traces[:, 2, :]

    indices = fl.load(out_path / '{}_indices.h5'.format(fish_id))

    t_feature_vector_array = fl.load(save_data_path/ '{}_t_feature_vector_array.h5'.format(fish_id))
    l_feature_vector_array = fl.load(save_data_path/ '{}_l_feature_vector_array.h5'.format(fish_id))
    r_feature_vector_array = fl.load(save_data_path/ '{}_r_feature_vector_array.h5'.format(fish_id))

    t_reshaped_array, t_peaks_a_array, t_peaks_i_array, t_valleys_a_array, t_valleys_i_array = reshape_feature_array(t_feature_vector_array)
    l_reshaped_array, l_peaks_a_array, l_peaks_i_array, l_valleys_a_array, l_valleys_i_array = reshape_feature_array(l_feature_vector_array)
    r_reshaped_array, r_peaks_a_array, r_peaks_i_array, r_valleys_a_array, r_valleys_i_array = reshape_feature_array(r_feature_vector_array)

    ipsi_contra_tensor = fl.load(Path(out_path / '{}_tensor_ipsi_contra.h5'.format(fish_id)))
    ipsi_fin_id= fl.load(Path(out_path/'{}_ipsi_fin_ids.h5'.format(fish_id)))
    leading_fin = fl.load(Path(out_path/ '{}_leading_fin.h5'.format(fish_id)))
    osc = fl.load(out_path/ '{}_oscillations.h5'.format(fish_id)) # [t_osc, l_osc, r_osc])
    laterality = fl.load(out_path/ '{}_bout_laterality.h5'.format(fish_id))
    freqs = fl.load(out_path/ '{}_tbf_output.h5'.format(fish_id))

    [duration_t, start_t, end_t] = fl.load( out_path / '{}_t_durations.h5'.format(fish_id)) 
    [duration_l, start_l, end_l] = fl.load( out_path / '{}_l_durations.h5'.format(fish_id)) 
    [duration_r, start_r, end_r] = fl.load( out_path / '{}_r_durations.h5'.format(fish_id))

    tail_periods_array = fl.load( out_path/ '{}_t_full_beats.h5'.format(fish_id))
    tail_half_beats_array = fl.load( out_path/ '{}_t_half_beats.h5'.format(fish_id))
    l_periods_array = fl.load( out_path/ '{}_l_full_beats.h5'.format(fish_id))
    l_half_beats_array = fl.load( out_path/ '{}_l_half_beats.h5'.format(fish_id))
    r_periods_array = fl.load( out_path/ '{}_r_full_beats.h5'.format(fish_id))
    r_half_beats_array = fl.load( out_path/ '{}_r_half_beats.h5'.format(fish_id))

    
    ### Creating the dataframe
    tails_ = bout_data['tailsums']
    n_bouts = tail.shape[0]
    
    # create dataframe
    df = pd.DataFrame()
    df['exp'] =  [exp_name] * n_bouts
    df['fish_id'] = [fish_id] * n_bouts
    df['bout_angle'] = bout_data['body_angles_delta'][indices]
    df['cluster'] = clusters = bout_data['cluster'][indices]
    df['frame_start'] =  bout_data['bout_times'][:,0][indices]
    df['frame_end'] = bout_data['bout_times'][:,1][indices]
    df['bout_id_orig'] = indices

    # Add all the filters
    df['mb_proba'] = bout_data['mb_proba'][indices]
    df['dlc_tracking_score'] =bout_data['dlc_filter'][indices]
    df['edge_tracking_score'] = bout_data['edge_filter'][indices]

    # Add the eyes
    left_eye_angles = bout_data['eye_angles'][:,0,:]
    right_eye_angles = bout_data['eye_angles'][:,1,:]
    eye_rotation = bout_data['eye_rotation'] 
    eye_vergence = bout_data['eye_vergence']
    max_eye_rot, max_eye_vergence = get_eye_max(eye_rotation, eye_vergence)
    df['max_eye_rot'] =max_eye_rot[indices]
    df['max_eye_vergence'] = max_eye_vergence[indices]

    vigor_stats = np.zeros((tail.shape[0], 2))
    

    for i in range(tail.shape[0]):
        vigor = compute_vigor(tail[i])
        max_vig, med_vig = get_vigor_stats(vigor)
        vigor_stats[i] = max_vig, med_vig
    print (vigor_stats)
        
    df['tail_peak_vigor'] = vigor_stats[:,0]

    l_fin_vigor_stats = np.zeros((l_fin.shape[0], 2))
    r_fin_vigor_stats = np.zeros((l_fin.shape[0], 2))
    
    for i in range(l_fin.shape[0]):
        l_fin_vigor = compute_vigor(l_fin[i]*-1)
        max_vig, med_vig = get_vigor_stats(l_fin_vigor)
        l_fin_vigor_stats[i] = max_vig, med_vig
    
        r_fin_vigor = compute_vigor(r_fin[i])
        max_vig, med_vig = get_vigor_stats(r_fin_vigor)
        r_fin_vigor_stats[i] = max_vig, med_vig

    df['l_fin_peak_vigor'] = l_fin_vigor_stats[:,0]
    df['r_fin_peak_vigor'] = r_fin_vigor_stats[:,0]
    
    corr, lags = compute_corr_lag(l_fin, r_fin)
    df['fin_fin_corr'] = corr
    df['fin_fin_lag'] = lags

    df['ipsi_fin'] = ipsi_fin_id
    df['tail_direction'] = laterality
    df['leading_fin'] = leading_fin

    l_max_amp = np.nanmax(l_peaks_a_array,axis=1)
    r_max_amp = np.nanmax(r_peaks_a_array,axis=1)
    df['l_max_amp'] = l_max_amp
    df['r_max_amp'] = r_max_amp
    
    mean_freqs = np.nanmean(freqs, axis=2)
    df['tail_freq'] = mean_freqs[0,:] 
    df['l_fin_freqs'] = mean_freqs[1,:]
    df['r_fin_freqs'] = mean_freqs[2,:]
    
    df['tail_osc'] = osc[0]
    df['l_osc'] = osc[1]
    df['r_osc'] = osc[2]
    
    df['tail_duration'] = duration_t
    df['t_start'] = start_t
    df['t_end'] = end_t
    
    df['l_fin_duration'] = duration_l
    df['l_start'] = start_l
    df['l_end'] = end_l
    
    df['r_fin_duration'] = duration_r
    df['r_start'] = start_r
    df['r_end'] = end_r
    
    df.to_csv(Path(out_path /'{}_features.csv'.format(fish_id)), index=False)  

# except:
#     print ('no')
#     pass


  0%|                                                                                           | 0/13 [00:00<?, ?it/s]

Working on fish 105646_f0
Reshaped array shape: (47, 4, 9)
Reshaped array shape: (47, 4, 9)
Reshaped array shape: (47, 4, 9)


  0%|                                                                                           | 0/13 [00:00<?, ?it/s]


TypeError: only integer scalar arrays can be converted to a scalar index