In [None]:
import pickle
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from collections import defaultdict
from scipy.stats import ttest_ind
from scipy.stats import f as fdistribution
from Helpers.finalproject import *

In [None]:
# Load the binned waveform dictionaries for the old and new recording sets and
# merge them into a single lookup (dict `|`: entries from the "new" file win on
# duplicate keys).
# NOTE: pickle.load can execute arbitrary code — only open trusted files.
pickle_paths = {"old": "binwvsrin_oldNaNs.pkl", "new": "binwvsrin_newNaNs.pkl"}
loaded = {}
for tag, path in pickle_paths.items():
    with open(path, "rb") as f:
        loaded[tag] = pickle.load(f)

binwvsrin_oldNaNs, binwvsrin_newNaNs = loaded["old"], loaded["new"]
binwvsrinNaNs = binwvsrin_oldNaNs | binwvsrin_newNaNs

In [None]:
# Peek at the first "new" waveform key, looked up in the merged dictionary.
onewave = [*binwvsrin_newNaNs][0]
binwvsrinNaNs[onewave]

In [None]:
fu = [*binwvsrin_newNaNs.items()][0]  # first (key, waveform) pair of the new recordings

In [None]:
len(binwvsrin_newNaNs)  # number of waveforms in the new recording set

In [None]:
# Sanity-check the feature definitions on a single waveform before the full extraction.
# The waveforms come from the "...NaNs" pickles, so NaN bins are ignored here:
# np.argmax would return the index of the first NaN, and np.polyfit cannot fit NaNs.
one_wave = [*binwvsrinNaNs][0]
testdf = binwvsrinNaNs[one_wave]
phase = np.arange(0.01, 1.00001, 0.02)  # 50 phase-bin centres: 0.01, 0.03, ..., 0.99
norm_currents = testdf['Normalized Current'].to_numpy(dtype=float)
valid = np.isfinite(norm_currents)      # mask of usable (non-NaN, finite) bins

peak_idx = np.nanargmax(norm_currents)  # raises ValueError if every bin is NaN
peak_phase = phase[peak_idx]
peak_amplitude = norm_currents[peak_idx]
# Linear trend of current over phase, fitted on valid bins only.
slope, intercept = np.polyfit(phase[valid], norm_currents[valid], 1)

In [None]:
# Build the per-waveform feature table from the merged waveforms, then aggregate it
# (helpers presumably come from the Helpers.finalproject star import).
features_df = extract_excitatory_features(binwvsrinNaNs)
summary_across_celltype = features_df.pipe(aggregate_features)

In [None]:
# Feature columns used jointly in the Hotelling's T² comparisons below.
feature_cols = [
    "Peak Gain",
    "Peak Phase",
    "Peak Amplitude",
]

In [None]:
display(features_df)  # rich HTML view of the extracted feature table

In [None]:
# Hotelling's T² on the joint feature vector, comparing Rin groups within each frequency bin.
feature_cols = [
    "Peak Gain",
    "Peak Phase",
    "Peak Amplitude",
]
between_rin_kwargs = {'compare_col': 'Rin group', 'freq_col': 'Freq Bin'}
hotttest_results_between = compute_hotellings_tests_grouped(
    features_df, feature_cols, **between_rin_kwargs
)
# hotttest_results_between.to_csv("hott_Brin",index=False)
hotttest_results_between

In [None]:
# Hotelling's T² comparing cell types, presumably run separately per Rin group and frequency bin.
within_rin_kwargs = {
    'group_col': 'Rin group',
    'compare_col': 'Cell Type',
    'freq_col': 'Freq Bin',
}
hotttest_results = compute_hotellings_tests(features_df, feature_cols, **within_rin_kwargs)
# hotttest_results.to_csv("hott_NrinBcell",index=False)

In [None]:
# Bar chart of the T² statistic per frequency bin for the between-Rin-group tests.
results_df = hotttest_results_between
stat_col = "T2_stat"
freq_col = "Freq Bin"
figsize = (10, 5)
bar_width = 0.2
ymax = 2100  # fixed y-limit; NOTE: bars taller than this are clipped

# Sort bins via str() so mixed or categorical labels still order deterministically.
freq_bins = sorted(results_df[freq_col].dropna().unique(), key=str)
subset = results_df

# Every bin gets an equal-width slot: room for the largest bar count in any bin plus a gap.
counts_per_bin = {fb: int((subset[freq_col] == fb).sum()) for fb in freq_bins}
max_count = max(counts_per_bin.values(), default=0)  # default avoids ValueError with no bins
slot_width = (max_count + 1) * bar_width

fig, ax = plt.subplots(figsize=figsize)
x_ticks, x_labels = [], []
for i, fb in enumerate(freq_bins):
    stats = subset.loc[subset[freq_col] == fb, stat_col].to_numpy()
    slot_start = i * slot_width
    x = slot_start + np.arange(len(stats)) * bar_width
    ax.bar(x, stats, width=bar_width, color='skyblue', edgecolor='black')

    # Centre the tick under this bin's bars. An empty bin uses its slot start, which
    # is in the same coordinates as the bars (the bare index `i` was not).
    x_ticks.append(x.mean() if len(x) > 0 else slot_start)
    x_labels.append(fb)

ax.set_ylim(0, ymax)
ax.set_xticks(x_ticks)
ax.set_xticklabels(x_labels, rotation=45)
ax.set_xlabel(freq_col)
ax.set_ylabel(stat_col)
ax.set_title("Between Rin groups")
fig.tight_layout()
plt.show()
    

In [None]:
# One bar chart per Rin group: T² statistic of the cell-type comparison per frequency bin.
group_col = 'Rin group'
results_df = hotttest_results
freq_col = 'Freq Bin'
stat_col = 'T2_stat'
title_prefix = "Stat Distribution"  # kept for reference; titles currently show only the group
figsize = (10, 5)
bar_width = 0.2
ymax = 2100  # same y-limit as the between-group figure; NOTE: taller bars are clipped

unique_groups = sorted(results_df[group_col].dropna().unique())
# Bin order is computed over all groups, so every figure uses the same bin order.
freq_bins = sorted(results_df[freq_col].dropna().unique(), key=str)

for group in unique_groups:
    subset = results_df[results_df[group_col] == group]

    # Equal-width slot per bin: room for the largest bar count in any bin plus a gap.
    counts_per_bin = {fb: int((subset[freq_col] == fb).sum()) for fb in freq_bins}
    max_count = max(counts_per_bin.values(), default=0)
    slot_width = (max_count + 1) * bar_width

    fig, ax = plt.subplots(figsize=figsize)
    x_ticks, x_labels = [], []
    for i, fb in enumerate(freq_bins):
        stats = subset.loc[subset[freq_col] == fb, stat_col].to_numpy()
        slot_start = i * slot_width
        x = slot_start + np.arange(len(stats)) * bar_width
        ax.bar(x, stats, width=bar_width, color='skyblue', edgecolor='black')

        # Tick under the bars; an empty bin uses its slot start (same coordinates as bars).
        x_ticks.append(x.mean() if len(x) > 0 else slot_start)
        x_labels.append(fb)

    ax.set_ylim(0, ymax)
    ax.set_xticks(x_ticks)
    ax.set_xticklabels(x_labels, rotation=45)
    ax.set_xlabel(freq_col)
    ax.set_ylabel(stat_col)
    ax.set_title(f"{group}")
    fig.tight_layout()
    plt.show()

In [None]:
# Split the feature table by Rin group. Labels are compared after stripping whitespace:
# the previous literal "Low Rin Secondary " (trailing space) silently selected nothing
# whenever the stored label had no trailing space. Stripping matches either form.
rin_labels = features_df["Rin group"].astype(str).str.strip()
primary = features_df[rin_labels == "primary"]
Hsecondary = features_df[rin_labels == "High Rin Secondary"]
Lsecondary = features_df[rin_labels == "Low Rin Secondary"]

In [None]:
def bonferroni_correction(df, alpha=0.05):
    """
    Apply Bonferroni correction to p-values in the DataFrame.

    Every row counts as one test, including rows whose p-value is NaN.
    NaN p-values are never marked significant.

    Parameters:
        df: pd.DataFrame
            Must contain a 'p_value' column. It is not modified.
        alpha: float
            Desired family-wise significance level.

    Returns:
        (df_out, corrected_alpha)
            df_out is a copy of df with new columns 'p_value_corrected'
            (p * n_tests, capped at 1) and 'significant' (bool).
            corrected_alpha is alpha / n_tests. For an empty df it is alpha
            itself, since there is nothing to correct.
    """
    df = df.copy()
    n_tests = len(df)
    if n_tests == 0:
        # Avoid ZeroDivisionError; return the same columns, just empty.
        df['p_value_corrected'] = pd.Series(dtype=float)
        df['significant'] = pd.Series(dtype=bool)
        return df, alpha

    corrected_alpha = alpha / n_tests
    # Scaling p by n and comparing with alpha is equivalent to comparing p with alpha / n.
    df['p_value_corrected'] = (df['p_value'] * n_tests).clip(upper=1.0)
    df['significant'] = df['p_value_corrected'] < alpha
    return df, corrected_alpha

In [None]:
def ttester(feature, groups=None, alpha=0.05):
    """
    Run per-frequency-bin cell-type t-tests for one feature in each Rin group,
    then Bonferroni-correct across all of the resulting tests together.

    Parameters:
        feature: str
            Feature column to test (e.g. "Peak Gain").
        groups: list of pd.DataFrame, optional
            Feature subsets to test. Defaults to [primary, Lsecondary, Hsecondary],
            the global subsets. This order decides the row order of the output,
            which num_significant's positional slices rely on.
        alpha: float
            Family-wise significance level passed to bonferroni_correction.

    Returns:
        (corrected_df, corrected_alpha), as returned by bonferroni_correction.
    """
    if groups is None:
        groups = [primary, Lsecondary, Hsecondary]

    per_group = [
        compute_ttests(
            group_df,
            feature,
            group_col='Freq Bin',
            compare_col='Cell Type',
            freq_col='Freq Bin'
        )
        for group_df in groups
    ]
    # Original per-group indices are kept (ignore_index=False); downstream code slices with iloc.
    ttest_results = pd.concat(per_group, ignore_index=False)

    corrected_df, corrected_alph = bonferroni_correction(ttest_results, alpha=alpha)
    return corrected_df, corrected_alph
# corrected_df.to_csv("phasetst")
corrected_df,corrected_alph = ttester("Peak Gain")

In [None]:
def num_significant(corrected_df, segments=None):
    """
    Print and return the fraction of significant tests in each row segment.

    Uses the 'significant' column added by bonferroni_correction. The previous
    version read column index 6, which was presumably that column.
    Rows are assigned to segments by position, so the row order must match what
    ttester produces: primary, then low-Rin secondary, then high-Rin secondary.

    The default segments fix an off-by-one in the old slices. Those used rows
    0:5 and 5:12 but divided both by 6, so the second fraction could exceed 1.
    Every segment is now reported as a fraction (mean of 'significant'); the last
    segment used to print a raw count.
    NOTE(review): the default row counts per segment are hard-coded. Confirm them
    against the actual compute_ttests output.

    Parameters:
        corrected_df: pd.DataFrame with a boolean 'significant' column.
        segments: optional list of (label, start, stop) positional row ranges.
            stop=None means "to the end".

    Returns:
        dict mapping label -> fraction significant (NaN for an empty segment).
    """
    if segments is None:
        segments = [
            ("primary slow", 0, 6),
            ("primary fast", 6, 12),
            ("low rin secondary slow", 12, 15),
            ("low rin secondary fast", 15, 18),
            ("high rin slow", 18, 19),
            ("high rin fast", 19, None),
        ]

    significant = corrected_df["significant"]
    fractions = {}
    for label, start, stop in segments:
        fractions[label] = significant.iloc[start:stop].mean()
        print(f"{label}:", fractions[label])
    return fractions

num_significant(corrected_df);  # trailing ';' keeps the cell's own output to the printed lines

# Plotting waveforms

In [None]:
from collections import Counter

# Count waveforms per generalized key:
# (frequency bin, generalized signal type, generalized cell type, Rin group, Rin).
key_counts_all = Counter()
for freq, signal_type, cell_type, cell, rin in binwvsrinNaNs.keys():
    freq_bin = get_freq_bin(freq)
    cell_type_gen = group_iSMN(cell_type)
    signal_type_gen = group_EI(signal_type)
    ringroup = grouprins(cell)
    key_counts_all[(freq_bin, signal_type_gen, cell_type_gen, ringroup, rin)] += 1

# Collect the distinct values of each key component, skipping unbinned frequencies.
freq_bins, signal_types, cell_types, fast_slow_types = set(), set(), set(), set()
# NOTE: this loop leaves `rin` bound to the last key's Rin value, and later cells
# read that global. Keep the loop variable name `rin`.
for freq_bin, signal_type, cell_type, fast_slow, rin in key_counts_all.keys():
    if freq_bin is None:
        continue
    freq_bins.add(freq_bin)
    signal_types.add(signal_type)
    cell_types.add(cell_type)
    fast_slow_types.add(fast_slow)

# Sets of strings iterate in hash order, which changes between kernel restarts
# (PYTHONHASHSEED). Sort them so subplot rows/columns and colours are reproducible.
freq_bins = sorted(freq_bins, key=str)
signal_types = sorted(signal_types, key=str)
cell_types = sorted(cell_types, key=str)
fast_slow_types = sorted(fast_slow_types, key=str)

In [None]:
def average_waveforms_for_key(counter_key,binned):
    '''
    Average every waveform in `binned` that matches a generalized counter key.

    Parameters
    ----------
    counter_key : tuple
        (freq_bin, signal_type, cell_type, rin_group, rin), e.g.
        ('15–25', 'Cell-attached', 'MiP', 'primary', rin). The first four
        elements are used for matching; the fifth (rin) is ignored.
    binned : dict
        Maps raw keys (freq, signal_type, cell_type, cell, rin) to DataFrames
        with a 'Normalized Current' column and one row per phase bin.

    Returns
    -------
    pd.DataFrame or None
        Columns 'Phase', 'Normalized Current' (mean across waveforms) and 'SEM',
        one row per phase bin. Returns None if no waveform matched.
        Matching waveforms with the wrong number of rows are skipped with a message.
    '''
    phase = np.arange(0.01, 1.00001, 0.02)  # 50 phase-bin centres: 0.01, 0.03, ..., 0.99

    freq_bin, signal_type, cell_type, rin_group, _rin = counter_key
    # Generalize the target once instead of on every iteration.
    target_cell_type = group_iSMN(cell_type)
    target_signal_type = group_EI(signal_type)

    dfs = []  # matching waveforms, each tagged with its phase bins
    for key, wave_df in binned.items():
        freq, st, ct, cl, _rn = key
        # get_freq_bin turns a float frequency into its bin; group_iSMN/group_EI
        # generalize cell and signal types; grouprins maps a cell to its Rin group.
        if (get_freq_bin(freq) == freq_bin
                and group_iSMN(ct) == target_cell_type
                and group_EI(st) == target_signal_type
                and rin_group == grouprins(cl)):
            # Shape check: assigning `phase` needs exactly one row per phase bin.
            if len(wave_df) != len(phase):
                print(f"Skipping {key}: expected {len(phase)} phase bins, got {len(wave_df)}")
                continue
            df = wave_df[['Normalized Current']].copy()
            df['Phase'] = phase
            dfs.append(df)

    if not dfs:
        print(f"No matching waveforms found for {counter_key}")
        return None

    combined = pd.concat(dfs)

    # Mean and SEM per phase bin in one pass (SEM is NaN when only one waveform matched).
    result = (
        combined.groupby('Phase', observed=False)['Normalized Current']
        .agg(['mean', 'sem'])
        .reset_index()
        .rename(columns={'mean': 'Normalized Current', 'sem': 'SEM'})
    )
    result['SEM'] = result['SEM'].astype(float)
    return result


In [None]:
signal_types=["Excitatory"]

# One colour per cell type. matplotlib.cm.get_cmap was deprecated in 3.7 and removed
# in 3.9; pyplot.get_cmap(name, lut) returns the same resampled colormap.
cmap = plt.get_cmap('tab20', len(cell_types))
cell_type_colors = {ctype: cmap(i) for i, ctype in enumerate(cell_types)}

# Grid of subplots: rows = groups in fast_slow_types, columns = frequency bins.
# squeeze=False keeps `axes` 2-D even with a single row or column, so axes[i, j] always works.
fig, axes = plt.subplots(len(fast_slow_types), len(freq_bins), figsize=(10, 12),
                         sharex=True, sharey=True, squeeze=False)

for j, freq_bin in enumerate(freq_bins):
    for i, speed in enumerate(fast_slow_types):

        ax = axes[i, j]

        for signal_type in signal_types:
            for cell_type in cell_types:
                # average_waveforms_for_key ignores the 5th key element, so pass None
                # instead of the `rin` value left over from an earlier cell's loop.
                counter_key = (freq_bin, signal_type, cell_type, speed, None)
                avg_waveform_df = average_waveforms_for_key(counter_key,binwvsrin_newNaNs)
                if avg_waveform_df is None or avg_waveform_df.empty:
                    continue

                label = f"{cell_type} ({signal_type})"
                ax.plot(avg_waveform_df['Phase'], avg_waveform_df['Normalized Current'],
                        label=label, color=cell_type_colors[cell_type])
                # Shaded band = mean ± SEM.
                ax.fill_between(avg_waveform_df['Phase'],
                                avg_waveform_df['Normalized Current'] - avg_waveform_df['SEM'],
                                avg_waveform_df['Normalized Current'] + avg_waveform_df['SEM'],
                                color=cell_type_colors[cell_type], alpha=0.3)

        # Column titles on the top row, row labels on the first column, x label on the bottom row.
        if i == 0:
            ax.set_title(f"{freq_bin} Hz")
        if j == 0:
            ax.set_ylabel(f"{speed}\nNormalized Current")
        if i == len(fast_slow_types) - 1:
            ax.set_xlabel("Phase")

# One shared figure legend: a coloured line per cell type.
handles, labels = [], []
for ctype in cell_types:
    handles.append(plt.Line2D([0], [0], color=cell_type_colors[ctype]))
    labels.append(ctype)
fig.legend(handles, labels, loc='upper right', fontsize='small')

fig.suptitle("Averaged Waveforms by Swimming Speed, Frequency Bin, and Cell Type", fontsize=16)
plt.tight_layout(rect=[0, 0, 0.95, 0.95])
plt.show()


In [None]:
# Make sure these are defined
signal_type = "Excitatory"
line_styles = {'fast': 'solid', 'slow': 'dashed', 'intermediate': 'dotted'}

# Assign colors per cell type
cmap = cm.get_cmap('tab20', len(cell_types))
cell_type_colors = {ctype: cmap(i) for i, ctype in enumerate(cell_types)}

# Plot: one subplot per frequency bin
fig, axes = plt.subplots(1, len(freq_bins), figsize=(6 * len(freq_bins), 5), sharey=True)

if len(freq_bins) == 1:
    axes = [axes]  # make iterable if only one freq_bin

for idx, freq_bin in enumerate(freq_bins):
    ax = axes[idx]

    for cell_type in cell_types:
        for speed in fast_slow_types:
            counter_key = (freq_bin, signal_type, cell_type, speed, rin)
            avg_waveform_df = average_waveforms_for_key(counter_key, binwvsrin_newNaNs)
            if avg_waveform_df is None or avg_waveform_df.empty:
                continue

            label = f"{cell_type} ({speed})"
            ax.plot(avg_waveform_df['Phase'], avg_waveform_df['Normalized Current'],
                    label=label,
                    color=cell_type_colors[cell_type],
                    linestyle=line_styles.get(speed, 'solid'),
                    linewidth=1.5)

            ax.fill_between(avg_waveform_df['Phase'],
                            avg_waveform_df['Normalized Current'] - avg_waveform_df['SEM'],
                            avg_waveform_df['Normalized Current'] + avg_waveform_df['SEM'],
                            color=cell_type_colors[cell_type], alpha=0.2)

    ax.set_title(f"{freq_bin} Hz Excitatory")
    ax.set_xlabel("Phase")
    if idx == 0:
        ax.set_ylabel("Normalized Current")
    ax.legend(fontsize='x-small', loc='best')

fig.suptitle("Excitatory Waveforms by Swim Frequency Bin", fontsize=16)
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()
