In [None]:
from library import *
from fileio import *

In [None]:
from split import *
from feature import *
from evaluate import *
from train_test import *
from plot import *
from reference import *

In [None]:
import traceback

# Trace dictionary; the evaluation loop below indexes it as all_traces[nc][group][...].
all_traces = load_dict_from_pickle('/data1/candy/predict_fwd_rev/outputs/all_traces.pkl')

# Hyperparameters for time-point selection.
#   t = 8  : forward run starts (at all_trace_starts)
#   t = 16 : forward run ends   (at all_trace_ends)
fwd_start = 8
base_t_start, base_t_end = 6, 10       # window handed to get_baseline for the neural trace
slope_t_start, slope_t_end = 10, 14    # window handed to get_slope_TheilSen for the neural trace

verbose = False           # gates the feature-shape printouts and is forwarded to train_eval
perf_by_rev_masks = {}    # results keyed by reversal mask; filled by the loop below

# Inclusion thresholds and behaviour-window length, named so they can be tuned in one place.
MIN_ANIMALS = 6    # a neuron class needs strictly more distinct animals than this
MIN_TRIALS  = 50   # ... and strictly more valid trials than this
BEH_WINDOW  = 4    # time points on each side of fwd_start used for the behaviour features


def _select_trials(arr, valid_indices, keep=None):
    """Return the rows of `arr` at `valid_indices`, optionally narrowed by the boolean mask `keep`.

    `keep` is applied to the already-subset array (two-step indexing), so it must
    have been computed on data indexed by the same `valid_indices`.
    """
    subset = arr[valid_indices]
    return subset if keep is None else subset[keep]


# Train/evaluate four feature sets (neural, velocity, behaviour, combined) for every
# neuron class, once per reversal mask, and collect the results in perf_by_rev_masks.
for maskRev in ['none', 'shortRev', 'longRev']:
    perf = {}

    for group in ['1h_starved']: #['sparse_food', 'off_food', 'just_fed', 'fasted', '1h_starved', 'all']:
        perf[group] = {}

        for nc in all_traces.keys():
            try:
                traces   = all_traces[nc][group]
                arr_vel  = traces['beh'][0]
                arr_hc   = traces['beh'][1]
                arr_pump = traces['beh'][2]
                arr_neu  = traces['neu'][0]

                # Only the index list is needed; the per-trial flags are discarded.
                # NOTE(review): NaN screening looks at arr_neu only — the behaviour
                # arrays are not screened here; confirm they cannot contain NaNs
                # inside the feature windows.
                _flags, valid_indices = flag_nan_in_array(
                    arr_neu, t_start=fwd_start - BEH_WINDOW, t_end=slope_t_end)

                # NOTE(review): trials whose rev tag equals 1 belong to neither
                # subset ('shortRev' keeps tags < 1, 'longRev' keeps tags == 2) —
                # confirm this gap is intentional.
                valid_rev = traces['rev'][valid_indices]
                rev_masks = {
                    'shortRev': valid_rev < 1,
                    'longRev':  valid_rev == 2,
                }
                keep = rev_masks.get(maskRev)  # None for 'none' -> no extra filtering

                valid_vel  = _select_trials(arr_vel,          valid_indices, keep)
                valid_hc   = _select_trials(arr_hc,           valid_indices, keep)
                valid_pump = _select_trials(arr_pump,         valid_indices, keep)
                valid_neu  = _select_trials(arr_neu,          valid_indices, keep)
                labels     = _select_trials(traces['target'], valid_indices, keep)
                animal_id  = _select_trials(traces['animal'], valid_indices, keep)

                # Skip neuron classes with too few animals or trials to evaluate.
                n_animals = np.unique(animal_id).shape[0]
                n_trials  = valid_neu.shape[0]
                if n_animals <= MIN_ANIMALS or n_trials <= MIN_TRIALS:
                    if verbose:
                        print(f"Not enough data: maskRev={maskRev}, group={group}, nc={nc} "
                              f"(animals={n_animals}, trials={n_trials})")
                    continue

                # Neural features: baseline level and Theil-Sen slope in their own windows.
                baseline_neu = get_baseline(valid_neu, t_start=base_t_start, t_end=base_t_end)
                slope_neu    = get_slope_TheilSen(valid_neu, t_start=slope_t_start, t_end=slope_t_end)

                # Behaviour features: one window before and one after fwd_start.
                pre_win  = dict(t_start=fwd_start - BEH_WINDOW, t_end=fwd_start)
                post_win = dict(t_start=fwd_start,              t_end=fwd_start + BEH_WINDOW)

                prev_vel  = get_baseline(valid_vel, **pre_win)
                curr_vel  = get_baseline(valid_vel, **post_win)

                prev_hc   = get_slope_TheilSen(valid_hc, **pre_win)
                curr_hc   = get_slope_TheilSen(valid_hc, **post_win)

                prev_pump = get_baseline(valid_pump, **pre_win)
                curr_pump = get_baseline(valid_pump, **post_win)

                neu_parts = (baseline_neu, slope_neu)
                beh_parts = (prev_vel, curr_vel, prev_hc, curr_hc, prev_pump, curr_pump)

                feat_neu  = np.hstack(neu_parts)
                feat_vel  = np.hstack((prev_vel, curr_vel))
                feat_beh  = np.hstack(beh_parts)
                feat_comb = np.hstack(neu_parts + beh_parts)

                assert feat_comb.shape[0] == valid_neu.shape[0] == valid_vel.shape[0] == len(labels), \
                    'Matrix dimensions do not match up! Something went wrong with feature extraction.'

                if verbose:
                    print(f"Neural features shape:     {feat_neu.shape}")
                    print(f"Velocity features shape:   {feat_vel.shape}")
                    print(f"Behavioral features shape: {feat_beh.shape}")
                    print(f"Combined features shape:   {feat_comb.shape}")

                # Evaluate in a fixed order; results are stored one by one so that a
                # failure part-way through keeps the models already evaluated.
                feature_sets = {
                    'neu':  feat_neu,
                    'vel':  feat_vel,
                    'beh':  feat_beh,
                    'comb': feat_comb,
                }
                perf[group][nc] = {}
                for model_name, features in feature_sets.items():
                    perf[group][nc][model_name] = train_eval(features, labels, animal_id, verbose=verbose)

            except Exception:
                # Best effort: report the failure with its traceback and move on to
                # the next neuron class instead of aborting the whole sweep.
                print(f"⚠️  Skipped: maskRev={maskRev}, group={group}, nc={nc} due to error:")
                print(traceback.format_exc())

    perf_by_rev_masks[maskRev] = perf

save_dict_as_pickle(perf_by_rev_masks, '/data1/candy/predict_fwd_rev/outputs/perf_runStart_revMasks.pkl')

In [None]:
# fold_results = perf['sparse_food']['AVB']['comb']

# # f1_scores  = [fold['f1_score'] for fold in fold_results]
# aurocs     = [fold['auroc']    for fold in fold_results]
# auprcs     = [fold['auprc']    for fold in fold_results]

# # print_metric_summary("F1 Score", f1_scores)
# print_metric_summary("AUROC", aurocs)
# print_metric_summary("AUPRC", auprcs)

## Visualize performance across neuron classes

In [None]:
# Selection for the comparison plots in the following cells.
# `group` must be one of the groups the training loop evaluated; that loop currently
# runs on '1h_starved' only, so asking for any other group finds no results.
group  = '1h_starved'
mdl0   = 'beh'     # reference model (key written by the training loop)
mdl1   = 'comb'    # model compared against the reference
metric = 'auprc'

maskRev = 'none'
perf = perf_by_rev_masks[maskRev]

# Fail early with an actionable message instead of an opaque error further down.
if group not in perf:
    raise KeyError(
        f"group={group!r} was not evaluated for maskRev={maskRev!r}; "
        f"available groups: {sorted(perf)}"
    )

perf_mdl0, perf_mdl1, perf_gain = calc_perf_gain(perf, mdl0=mdl0, mdl1=mdl1, metric=metric, group=group)

In [None]:
# Strip plot of the reference (mdl0) model's scores.
# The metric label is derived from `metric` so the title cannot drift from the plotted data.
# NOTE(review): baseline=0.5 is the chance level for AUROC; for AUPRC the chance level
# is the positive-class prevalence — confirm 0.5 is the intended reference line.
create_strip_plot(perf_mdl0, title=f'{group}: {metric.upper()} - Behavior only', baseline=0.5)

In [None]:
# Strip plot of the comparison (mdl1) model's scores.
# The metric label is derived from `metric` so the title cannot drift from the plotted data.
# NOTE(review): baseline=0.5 is the chance level for AUROC; for AUPRC the chance level
# is the positive-class prevalence — confirm 0.5 is the intended reference line.
create_strip_plot(perf_mdl1, title=f'{group}: {metric.upper()} - Neuron & Behavior', baseline=0.5)

In [None]:
# Strip plot of the score difference returned by calc_perf_gain; zero marks "no change".
# The metric label is derived from `metric` so the title cannot drift from the plotted data.
create_strip_plot(perf_gain, title=f'{group}: Δ {metric.upper()}', baseline=0)