In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import psytrack as psy
import seaborn as sns
from allensdk.brain_observatory.behavior.behavior_project_cache import VisualBehaviorNeuropixelsProjectCache
from matplotlib.patches import Patch

# Apply seaborn's 'ticks' style; this is set once here, before any figure is created.
sns.set_style('ticks')

In [7]:
# Open the project cache from a local directory and pull the behavior session table.
# NOTE(review): absolute, machine-specific mount point — adjust when running elsewhere.
drive_dir = "/Volumes/Brain2024"
cache = VisualBehaviorNeuropixelsProjectCache.from_local_cache(
    cache_dir=drive_dir,
    use_static_cache=True,
)
behavior_sessions = cache.get_behavior_session_table()

# Keep sessions whose session_type contains 'EPHYS' and whose genotype is 'wt/wt'.
is_ephys = behavior_sessions['session_type'].str.contains('EPHYS')
is_wildtype = behavior_sessions['genotype'] == 'wt/wt'
subset_sessions = behavior_sessions[is_ephys & is_wildtype]

In [8]:
# Ecephys session table, restricted to rows with genotype 'wt/wt' and to the two
# columns used below.
sessions = cache.get_ecephys_session_table()
wildtype_rows = sessions['genotype'] == 'wt/wt'
ephys_subset = sessions.loc[wildtype_rows, ['genotype', 'behavior_session_id']]

# Behavior session ids referenced by those ecephys sessions.
included = ephys_subset['behavior_session_id'].tolist()

In [9]:
# Rows of `subset_sessions` whose index value appears in `included`.
in_ephys_table = subset_sessions.index.isin(included)
subset_sessions_ephys = subset_sessions.loc[in_ephys_table]


In [10]:
# Number of rows (sessions) in the filtered behavior session table.
subset_sessions.shape[0]

54

In [None]:
# For every behavior session in `subset_sessions`: rebuild the regressor table, attach the
# difference between two rows of the saved weight matrix, export the change-trial rows to
# CSV, and save two diagnostic figures (difference histogram, weight heatmap).
#
# NOTE(review): `get_metrics`, `format_Xy` and `strategy_list` are not imported or defined
# in the visible cells; they must already exist in the kernel namespace. TODO: import/define
# them explicitly so the notebook survives Restart & Run All.
for bsid in subset_sessions.index:

    print(f'\n Processing {bsid}')

    # Named `fit_dir` rather than `dir` so the builtin `dir()` is not shadowed.
    fit_dir = f'local_modelfit/ephys/{bsid}'
    os.makedirs(fit_dir, exist_ok=True)

    plot_dir = f'plots/modelfit/ephys/{bsid}'
    os.makedirs(plot_dir, exist_ok=True)

    # Build the per-stimulus table for this session from licks and active stimulus presentations.
    behavior_session = cache.get_behavior_session(bsid)
    licks = behavior_session.licks
    licks_annot = get_metrics.annotate_licks(licks, behavior_session)
    stimulus_presentations = behavior_session.stimulus_presentations
    stimulus_presentations = stimulus_presentations[stimulus_presentations['active']]
    stim_table = get_metrics.annotate_bouts(licks_annot, stimulus_presentations)

    # Options handed to format_Xy.build_regressor unchanged.
    format_options = {
        'preprocess': 2,
        'timing_params': [-5, 4],
        'num_cv_folds': 10,
    }
    psydata = format_Xy.build_regressor(stim_table, format_options, behavior_session)
    df = psydata['df']

    # Previously saved weight matrix for this session; raises if the file is missing.
    wMode = np.load(f'{fit_dir}/wMode.npy')

    # NOTE(review): rows 3 and 4 are hard-coded; presumably they pick two specific
    # strategies in the row order of wMode — confirm against how wMode was written.
    # pandas raises ValueError if wMode's column count differs from len(df).
    w_diff = wMode[3, :] - wMode[4, :]
    df['diff'] = w_diff

    # Rows flagged as `change`, keeping only the weight difference and the lick flag.
    df_change = df.loc[df['change'], ['diff', 'licked']]
    df_change.to_csv(f'{fit_dir}/weight_diff_changed_frames.csv')

    # Fig 1: distribution of the weight difference over all rows of df.
    fig, ax = plt.subplots(figsize=(10, 6))
    sns.histplot(data=df, x='diff', ax=ax)
    ax.set_xlabel('Weight Difference', fontsize=12)
    ax.set_ylabel('Count', fontsize=12)
    plt.tight_layout()
    # Save before show(): on some backends the figure is torn down by show(), which
    # would leave an empty file if savefig ran afterwards.
    fig.savefig(f'{plot_dir}/diff_dist.pdf')
    plt.show()
    # Release the figure explicitly so it cannot accumulate across loop iterations.
    plt.close(fig)

    # Fig 2: heatmap of the full weight matrix (rows labelled with the sorted strategy names).
    fig, ax = plt.subplots(figsize=(20, 6))  # Adjust figure size as needed
    sns.heatmap(wMode, cmap='viridis', cbar=True, ax=ax)
    strategy_list_sorted = sorted(strategy_list)
    # +0.5 puts each tick in the middle of its heatmap row.
    ax.set_yticks(np.arange(len(strategy_list)) + 0.5)
    ax.set_yticklabels(strategy_list_sorted, rotation=0, va='center')

    # One x tick every 500 columns of wMode.
    total_ticks = wMode.shape[1]
    tick_locations = np.arange(0, total_ticks, 500)
    ax.set_xticks(tick_locations)
    ax.set_xticklabels(tick_locations)

    ax.set_xlabel('Time (ticks)', fontsize=12)
    ax.set_ylabel('Strategy', fontsize=12)
    ax.set_title('Weight Matrix', fontsize=14)

    cbar = ax.collections[0].colorbar
    cbar.set_label('Weight', rotation=270, labelpad=20)

    plt.tight_layout()
    fig.savefig(f'{plot_dir}/weight_heatmap.pdf')
    plt.show()
    plt.close(fig)

    # Sessions that also appear in the ecephys table get a second copy of the CSV
    # next to their plots.
    if bsid in included:
        df_change.to_csv(f'{plot_dir}/weight_diff_changed_frames.csv')


    
