In [None]:
from library import *
from bootstrap import *
from eta import *
from fileio import *

In [None]:
def _split_short_long(labels, traces, window, tags=None):
    """Collect short-run and long-run traces, each cut to the first `window` samples.

    Label 0 is treated as a short run, label 1 (medium) is skipped, and every
    other label is treated as a long run.  When `tags` is given, the tag of
    each kept trace is collected in a parallel list.

    Returns (short_traces, long_traces, short_tags, long_tags) as lists.
    """
    short, long_ = [], []
    short_tags, long_tags = [], []
    for i in range(len(labels)):
        label = labels[i]
        if label == 0:
            bucket, bucket_tags = short, short_tags
        elif label == 1:
            continue  # medium runs are not plotted
        else:
            bucket, bucket_tags = long_, long_tags
        bucket.append(traces[i][:window])
        if tags is not None:
            bucket_tags.append(tags[i])
    return short, long_, short_tags, long_tags


def _eta_mean_std(traces, n_bootstrap):
    """Return (mean, std) of `traces`: plain statistics when n_bootstrap == 0, bootstrapped otherwise."""
    if n_bootstrap == 0:
        return original_mean_std(traces)
    return bootstrap_mean_std(traces, n_bootstrap)


def _select_by_tag(traces, tags, wanted):
    """Return the rows of the array `traces` whose entry in `tags` equals `wanted`."""
    mask = np.array([tag == wanted for tag in tags], dtype=bool)
    return traces[mask]


def _plot_mean_band(ax, mean, std, color, label, linewidth, alpha):
    """Draw `mean` as a line with a mean +/- 2*std band; do nothing when `mean` is None."""
    if mean is None:
        return
    ax.plot(mean, color=color, linewidth=linewidth, label=label)
    ax.fill_between(range(len(mean)),
                    mean - 2*std,
                    mean + 2*std,
                    color=color, alpha=alpha)


def _format_time_axis(ax, t_before, t_after, sampling_rate):
    """Put x ticks at window start, event onset and window end, labelled in seconds."""
    ax.set_xticks(
        [0, t_before, t_before + t_after],
        [f'{-t_before/sampling_rate:.0f}', '0', f'{t_after/sampling_rate:.0f}'],
        fontsize=12,
    )


def plot_eta_run(
    trace_data, group_name, neuron_class,
    t_before = 8,
    t_after  = 16,
    thresh_short_FWD = 17,  # <= this threshold, fwd run is short
    thresh_long_FWD  = 24,  # >= this threshold, fwd run is long
    thresh_short_REV = 11,  # <= this threshold, reversal is short
    thresh_long_REV  = 13,  # >= this threshold, reversal is long
    n_bootstrap = 30,
    sampling_rate = 1.67,
    save_dir = '/data1/candy/predict_fwd_rev/bootstrap_eta_runs',
    close_fig = False,
):
    """Plot event-triggered averages for one group and save the figure as a PNG.

    The figure has two rows of three panels:
      * row 1 - behaviour: velocity, head curvature and pumping, short vs long runs;
      * row 2 - neural activity: all runs, runs following short reversals, and
        runs following long reversals, each split into short vs long runs.
    Every curve is a mean with a shaded mean +/- 2*std band.

    Parameters
    ----------
    trace_data : dict
        `trace_data[group_name]` must provide the keys 'target', 'beh', 'neu',
        'rev' and 'animal'.  'target' holds one label per run (0 = short,
        1 = medium, anything else = long); 'beh'[0..2] and 'neu'[0] are indexed
        by run; 'rev' holds one tag per run (0 = short reversal, 2 = long
        reversal are the values used here).
    group_name : str
        Key into `trace_data`; also used in the title and the file name.
    neuron_class : str
        Used in titles and the file name.
    t_before, t_after : int
        Number of samples shown before / after the event.  Each trace is cut to
        its first `t_before + t_after` samples.
    thresh_short_FWD, thresh_long_FWD, thresh_short_REV, thresh_long_REV : number
        Only used for legend / title text (converted to seconds).
    n_bootstrap : int
        0 uses `original_mean_std`; any other value is passed to
        `bootstrap_mean_std`.
    sampling_rate : float
        Samples per second, used to convert sample counts to seconds in labels.
    save_dir : str
        Directory the PNG is written to.
    close_fig : bool
        When True the figure is closed after saving, which avoids accumulating
        open figures when this is called in a loop (a closed figure is
        presumably no longer auto-displayed by the inline backend).

    Raises
    ------
    AssertionError
        If the numbers of run labels and reversal tags differ.
    """
    group       = trace_data[group_name]
    target_data = group['target']
    beh_data    = group['beh']
    neu_data    = group['neu']
    rev_tag     = group['rev']
    assert len(rev_tag) == len(target_data), "!!!There are different numbers of FWD labels and REV labels!!!"

    window = t_before + t_after
    # Count animals from the data that was passed in, not from a notebook global.
    n_animals = len(np.unique(group['animal']))
    short_label = f'Short runs < {thresh_short_FWD/sampling_rate:.0f}s'
    long_label  = f'Long runs > {thresh_long_FWD/sampling_rate:.0f}s'

    # Create figure
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    ax1, ax2, ax3 = axes[0]  # First row
    ax4, ax5, ax6 = axes[1]  # Second row
    fig.suptitle(f'{group_name}: {neuron_class} found in {n_animals} animals', fontweight='bold', fontsize=16)

    # ------------------------------------------------------------------ #
    # First row: velocity, head curvature, pumping
    behaviour_panels = (
        (ax1, 0, 'mm/s', 'Velocity'),
        (ax2, 1, 'rad',  'Head Curvature'),
        (ax3, 2, 'Hz',   'Pumping'),
    )
    for ax, channel, unit, title in behaviour_panels:
        short_traces, long_traces, _, _ = _split_short_long(target_data, beh_data[channel], window)
        short_mean, short_std = _eta_mean_std(short_traces, n_bootstrap)
        long_mean, long_std   = _eta_mean_std(long_traces, n_bootstrap)

        ax.axvline([t_before], color='black', linestyle='dotted')
        ax.axhline([0], color='black', linestyle='dotted')
        _plot_mean_band(ax, short_mean, short_std, 'green',
                        f'{short_label} (n={len(short_traces)})', linewidth=2, alpha=0.2)
        _plot_mean_band(ax, long_mean, long_std, 'blueviolet',
                        f'{long_label} (n={len(long_traces)})', linewidth=2, alpha=0.2)

        _format_time_axis(ax, t_before, t_after, sampling_rate)
        ax.set_ylabel(unit, fontsize=12)
        ax.set_title(title, fontsize=14)

    # ------------------------------------------------------------------ #
    # Second row: neural data
    short_traces, long_traces, short_tags, long_tags = _split_short_long(
        target_data, neu_data[0], window, tags=rev_tag
    )

    # Second row, first subplot: all FWD runs
    ax4.axvline([t_before], color='black', linestyle='dotted')
    short_mean, short_std = _eta_mean_std(short_traces, n_bootstrap)
    long_mean, long_std   = _eta_mean_std(long_traces, n_bootstrap)
    _plot_mean_band(ax4, short_mean, short_std, 'green',
                    f'{short_label} (n={len(short_traces)})', linewidth=3, alpha=0.3)
    _plot_mean_band(ax4, long_mean, long_std, 'blueviolet',
                    f'{long_label} (n={len(long_traces)})', linewidth=3, alpha=0.3)

    _format_time_axis(ax4, t_before, t_after, sampling_rate)
    ax4.set_xlabel('Time (seconds)', fontsize=12)
    ax4.set_ylabel('Normalized Neural Activity', fontsize=12)
    ax4.set_title(f'{neuron_class}: all FWD runs', fontsize=14)
    ax4.legend(fontsize=12)
    ax4.sharey(ax6)

    # Sub-group every run-length class by the duration of the preceding reversal.
    # The four statistics are computed in this fixed order (short runs first,
    # short reversals first) so any random resampling is consumed in the same
    # sequence on every call.
    short_traces = np.array(short_traces)
    short_after_short_rev = _select_by_tag(short_traces, short_tags, 0)
    short_after_long_rev  = _select_by_tag(short_traces, short_tags, 2)
    short_mean_shortRev, short_std_shortRev = _eta_mean_std(short_after_short_rev, n_bootstrap)
    short_mean_longRev,  short_std_longRev  = _eta_mean_std(short_after_long_rev, n_bootstrap)

    long_traces = np.array(long_traces)
    long_after_short_rev = _select_by_tag(long_traces, long_tags, 0)
    long_after_long_rev  = _select_by_tag(long_traces, long_tags, 2)
    long_mean_shortRev, long_std_shortRev = _eta_mean_std(long_after_short_rev, n_bootstrap)
    long_mean_longRev,  long_std_longRev  = _eta_mean_std(long_after_long_rev, n_bootstrap)

    # Second row, second subplot: FWD runs that follow a short reversal
    ax5.axvline([t_before], color='black', linestyle='dotted')
    _plot_mean_band(ax5, short_mean_shortRev, short_std_shortRev, 'olive',
                    f'Short runs (n={len(short_after_short_rev)})', linewidth=3, alpha=0.3)
    _plot_mean_band(ax5, long_mean_shortRev, long_std_shortRev, 'magenta',
                    f'Long runs (n={len(long_after_short_rev)})', linewidth=3, alpha=0.3)

    _format_time_axis(ax5, t_before, t_after, sampling_rate)
    ax5.set_xlabel('Time (seconds)', fontsize=12)
    ax5.set_title(f'{neuron_class}: FWD Runs following Reversals < {thresh_short_REV/sampling_rate:.0f}s', fontsize=14)
    ax5.legend(fontsize=12)
    ax5.sharey(ax6)

    # Second row, third subplot: FWD runs that follow a long reversal
    ax6.axvline([t_before], color='black', linestyle='dotted')
    _plot_mean_band(ax6, short_mean_longRev, short_std_longRev, 'olive',
                    f'Short runs (n={len(short_after_long_rev)})', linewidth=3, alpha=0.3)
    _plot_mean_band(ax6, long_mean_longRev, long_std_longRev, 'magenta',
                    f'Long runs (n={len(long_after_long_rev)})', linewidth=3, alpha=0.3)

    _format_time_axis(ax6, t_before, t_after, sampling_rate)
    ax6.set_xlabel('Time (seconds)', fontsize=12)
    ax6.set_title(f'{neuron_class}: FWD Runs following Reversals > {thresh_long_REV/sampling_rate:.0f}s', fontsize=14)
    ax6.legend(fontsize=12)

    fig.tight_layout()
    # A window with more samples after the event than before is a run-start ETA.
    suffix = 'start' if t_before < t_after else 'end'
    fig.savefig(f'{save_dir}/{neuron_class}_{group_name}_{suffix}.png')
    if close_fig:
        plt.close(fig)

In [None]:
# Confidence value passed to import_from_flv_utils for every neuron class.
confidence_threshold = 3

# Per-neuron-class ETA traces, keyed by neuron class name.
all_trace_starts = {}
all_trace_ends   = {}   # reserved for the run-end ETA, which is currently disabled

for nc in flv.fig4_neuron_classes:
    # Best-effort batch: a failure for one neuron class is reported and the
    # loop moves on to the next class.
    try:
        neuron_classes = [nc]

        # Load the data for this neuron class only, using the configured
        # confidence threshold (previously a literal 3 that ignored the variable).
        data = import_from_flv_utils(
            neuron_classes = neuron_classes,
            confidence_threshold = confidence_threshold
        )

        # Make ETA
        trace_start = {}
        # trace_end   = {}

        for group_name, list_outputs in data.items():
            trace_start[group_name] = {}
            # trace_end  [group_name] = {}

            # ETA by run start
            beh, neu, rev, target, animal_uid = get_run_start_eta(list_outputs, neuron_classes)
            trace_start[group_name]['beh'] = beh
            trace_start[group_name]['neu'] = neu
            trace_start[group_name]['rev'] = rev
            trace_start[group_name]['target'] = target
            trace_start[group_name]['animal'] = animal_uid
            plot_eta_run(trace_start, group_name, nc)

            # # ETA by run end (disabled)
            # beh, neu, target, animal_uid = get_run_end_eta(list_outputs, neuron_classes)
            # trace_end[group_name]['beh'] = beh
            # trace_end[group_name]['neu'] = neu
            # trace_end[group_name]['target'] = target
            # trace_end[group_name]['animal'] = animal_uid
            # plot_eta_run(trace_end, group_name, nc, t_before = 16, t_after = 8)

        print(f'{nc} finished successfully')
        all_trace_starts[nc] = trace_start
        # all_trace_ends  [nc] = trace_end

    except Exception as e:
        print(f'Error processing {nc}: {str(e)}')