### Define & Load

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import sys
from matplotlib.backends.backend_pdf import PdfPages
import os
import pandas as pd
sys.path.append(r'D:\Neural-Pipeline\source')
from process_openephys.MemSaccPlot import NeuralDataLoader
########### Zarya Trial Settings #############

angles = np.zeros(20)
eccs = np.zeros(20)

eccs[[0, 2, 4, 6]] = 5
eccs[[1, 3, 5, 7, 10, 13, 16, 19]] = 10
eccs[[8, 11, 14, 17]] = 3.5
eccs[[9, 12]] = 6.72
eccs[[15, 18]] = 5.83

angles[[0, 1]] = 0
angles[[2, 3]] = 90
angles[[4, 5]] = 180
angles[[6, 7]] = 270
angles[[8, 9, 10]] = 42
angles[[11, 12, 13]] = 138
angles[[14, 15, 16]] = 211
angles[[17, 18, 19]] = 329   

########### Plot Settings #############

align_info_list = [{ 'event': 'SACC', 'pre': -0.8, 'post': 0}]
align_info = { 'event': 'SACC', 'pre': -0.8, 'post': 0}
electrode_id = None
kilosort = False

###### variable changed by blocks #############


data_struct = "kilosort"  # Options: "continuous", "spike_time", "kilosort"

match data_struct:
    case "continuous":
        base_dir = r"D:\20250724\2025-07-24_13-46-22\Record Node 110\experiment3\recording1"

    case "spike_time":
        electrode_id = 3
        base_dir = r"D:\20250724\2025-07-24_13-46-22\Record Node 111\experiment3\recording1"
    case "kilosort":
        kilosort = True
        datasets = [
            # {
            #     'base_dir': r"D:\20250718",
            #     'kilo_dir': r"D:\20250718\kilosort4",
            #     'sacc_file': 1410
            # },
            # {
            #     'base_dir': r"D:\Neural-Pipeline\data\20250724_6836",
            #     'kilo_dir': r"D:\Neural-Pipeline\data\20250724_6836\kilosort4",  # Adjust path as needed [1346. 1410. 1412. 1546. 1548. 1550. 1626.]
            #     'sacc_file': 1346
            # },
            {
                'base_dir': r"D:\Neural-Pipeline\data\20250724",
                'kilo_dir': r"D:\Neural-Pipeline\data\20250724\kilosort4",  # Adjust path as needed
                'sacc_file': 1702
            }
            # {
            #     'base_dir': r"D:\Neural-Pipeline\data\20250801",
            #     'kilo_dir': r"D:\Neural-Pipeline\data\20250801\kilosort4",  # Adjust path as needed
            #     'sacc_file': 1350
            # }
        ]

    case _:
        raise ValueError(f"Invalid data_struct: {data_struct}. Must be 'continuous', 'spike_time', or 'kilosort'")


### Plot TDI / Tuning curve

In [None]:
def process_kilosort_data(base_dir, align_info, compare_position, pdf, save_to_pdf, kilo_dir, sacc_file):
    """Compute TDI for every kilosort cluster and plot the ones that pass the criteria.

    A single NeuralDataLoader is built for the recording; for each row of the
    cluster-info table the cluster's spikes are selected, handed to the loader
    via ``current_cluster_spikes`` and scored with ``avg_firing_rate``. Clusters
    with a non-NaN TDI above 0.45 and a mean firing rate above 2 are plotted via
    ``plot_results``. Any exception raised while handling a cluster is printed
    and the loop moves on to the next cluster.

    Args:
        base_dir: Recording directory given to NeuralDataLoader.
        align_info: Alignment dict forwarded to ``avg_firing_rate`` and the plots.
        compare_position: Forwarded to ``plot_results``.
        pdf: Open PdfPages-like object forwarded to ``plot_results``.
        save_to_pdf: Forwarded to ``plot_results``.
        kilo_dir: Kilosort output directory given to NeuralDataLoader.
        sacc_file: Block identifier given to NeuralDataLoader.

    Returns:
        ``(tdi_all, data)``: ``tdi_all`` has one slot per cluster-info row, NaN
        unless that cluster passed the criteria; ``data`` is the loader.
    """
    data = NeuralDataLoader(base_dir, kilo_dir=kilo_dir, sacc_file=sacc_file)
    data.parse_ttl_events()

    cluster_table = pd.read_csv(data.cluster_info_file, sep='\t')
    cluster_ids = cluster_table['cluster_id'].values
    label_by_cluster = dict(zip(cluster_table['cluster_id'], cluster_table['KSLabel']))
    n_clusters = len(cluster_ids)

    tdi_all = np.full(n_clusters, np.nan)

    for position, cluster_id in enumerate(cluster_ids, start=1):
        label = label_by_cluster[cluster_id]
        print(f"Processing cluster {cluster_id} ({label}) ({position}/{n_clusters})")

        try:
            # The loader scores whatever spikes are stored in current_cluster_spikes.
            data.current_cluster_spikes = data.spike_times[data.cluster_id == cluster_id]

            avg_firing_rates, _, tdi = data.avg_firing_rate(align_info)
            mean_fr = np.nanmean(list(avg_firing_rates.values()))

            print(f"  Cluster {cluster_id}: TDI={tdi:.3f}, Mean FR={mean_fr:.3f}")

            passes = (not np.isnan(tdi)) and mean_fr > 2 and tdi > 0.45
            if not passes:
                print(f"  ✗ Cluster {cluster_id} failed criteria (TDI={tdi:.3f}, FR={mean_fr:.3f})")
                continue

            tdi_all[position - 1] = tdi

            # Depth is appended to the title only when the loader has one for this cluster.
            has_depth = data.cluster_depth is not None and cluster_id in data.cluster_depth
            depth_str = f", Depth: {data.cluster_depth[cluster_id]:.0f}" if has_depth else ""

            # Add cluster type to title
            if label == 'good':
                cluster_type = "Good Unit"
            else:
                cluster_type = f"MUA ({label})"

            title = f"Cluster {cluster_id} ({cluster_type}) - TDI: {tdi:.2f}{depth_str}, Firing Rate: {mean_fr:.2f}"
            plot_results(data, align_info, avg_firing_rates, compare_position, title, pdf, save_to_pdf)

        except Exception as e:
            print(f"Error processing cluster {cluster_id}: {e}")

    return tdi_all, data

def process_continuous_data(base_dir, align_info, compare_position, pdf, save_to_pdf, electrode_id):
    """Threshold spikes on channels 380-383 of a continuous recording and plot tuned ones.

    For each channel, spikes are detected with ``get_spike_times`` (threshold
    -65). Channels whose overall rate (spike count divided by the span between
    the first and last TTL timestamp) is at least 1 are scored with
    ``avg_firing_rate``; those with a non-NaN TDI above 0.45 and a mean firing
    rate above 2 are plotted via ``plot_results``. Exceptions raised for a
    channel are printed and the loop continues.

    Args:
        base_dir: Recording directory given to NeuralDataLoader.
        align_info: Alignment dict forwarded to ``avg_firing_rate`` and the plots.
        compare_position: Forwarded to ``plot_results``.
        pdf: Open PdfPages-like object forwarded to ``plot_results``.
        save_to_pdf: Forwarded to ``plot_results``.
        electrode_id: Forwarded unchanged to ``load_continuous_data``.

    Returns:
        Array with one slot per processed channel, NaN unless the channel
        passed the criteria.
    """
    data = NeuralDataLoader(base_dir)
    data.load_continuous_data(electrode_id)
    data.parse_ttl_events()

    channels_to_process = range(380, 384)  # or [channel] if specific
    n_channels = len(channels_to_process)
    tdi_all = np.full(n_channels, np.nan)

    for slot, channel in enumerate(channels_to_process):
        print(f"Processing channel {channel} ({slot+1}/{n_channels})")

        try:
            data.get_spike_times(channel, threshold=-65)
            n_spikes = len(data.spike_times)
            ttl_span = data.ttl_timestamps[-1] - data.ttl_timestamps[0]
            total_firing = n_spikes / ttl_span

            # Skip nearly silent channels before the more expensive scoring step.
            if not (total_firing >= 1):
                continue

            avg_firing_rates, _, tdi = data.avg_firing_rate(align_info)
            mean_fr = np.nanmean(list(avg_firing_rates.values()))

            passes = (not np.isnan(tdi)) and mean_fr > 2 and tdi > 0.45
            if passes:
                tdi_all[slot] = tdi
                title = f"Channel {channel} - TDI: {tdi:.2f}, Firing Rate: {mean_fr:.2f}"
                plot_results(data, align_info, avg_firing_rates, compare_position, title, pdf, save_to_pdf)

        except Exception as e:
            print(f"Error processing channel {channel}: {e}")

    return tdi_all

def process_electrode_data(base_dir, align_info, compare_position, pdf, save_to_pdf):
    """Score electrodes 1..96 one at a time and plot the ones that pass the criteria.

    A fresh NeuralDataLoader is created per electrode. Electrodes with a non-NaN
    TDI above 0.45 and a mean firing rate above 2 are plotted via
    ``plot_results``; the title depth is computed as ``10722 - electrode_id * 40``.
    Exceptions raised while scoring or plotting an electrode are printed and the
    loop continues (loader construction and TTL parsing are not covered by that
    handler).

    Args:
        base_dir: Recording directory given to NeuralDataLoader.
        align_info: Alignment dict forwarded to ``avg_firing_rate`` and the plots.
        compare_position: Forwarded to ``plot_results``.
        pdf: Open PdfPages-like object forwarded to ``plot_results``.
        save_to_pdf: Forwarded to ``plot_results``.

    Returns:
        Array of 192 values, NaN except where an electrode passed the criteria.
    """
    # NOTE(review): 192 slots are allocated but only indices 0..95 are ever written
    # by the loop below — confirm whether the extra half is intentional.
    tdi_all = np.full(192, np.nan)

    for electrode_id in range(1, 97):
        slot = electrode_id - 1
        data = NeuralDataLoader(base_dir, electrode_id)
        data.parse_ttl_events()

        try:
            avg_firing_rates, _, tdi = data.avg_firing_rate(align_info)
            mean_fr = np.nanmean(list(avg_firing_rates.values()))

            passes = (not np.isnan(tdi)) and mean_fr > 2 and tdi > 0.45
            if not passes:
                continue

            tdi_all[slot] = tdi
            title = f"TDI: {tdi:.2f}, Depth: {10722 - electrode_id * 40}, Firing Rate: {mean_fr:.2f}"
            plot_results(data, align_info, avg_firing_rates, compare_position, title, pdf, save_to_pdf)

        except Exception as e:
            print(f"Error processing electrode {electrode_id}: {e}")

    return tdi_all

def plot_results(data, align_info, avg_firing_rates, compare_position, title, pdf, save_to_pdf):
    """Draw a two-panel summary figure (heatmap on the left, PSTH on the right).

    Each panel is made the current axes before the matching loader plot method
    is called. The heatmap uses the notebook-level ``eccs`` and ``angles`` arrays.

    Args:
        data: Loader object providing ``plot_avg_firing_rate_heatmap`` and ``plot_psth``.
        align_info: Alignment dict forwarded to the heatmap plot.
        avg_firing_rates: Firing rates forwarded to both plots.
        compare_position: Forwarded to ``plot_psth``.
        title: Figure-level title.
        pdf: Object with a ``savefig`` method; used only when ``save_to_pdf`` is truthy.
        save_to_pdf: If truthy, save the current figure to ``pdf`` and close it;
            otherwise show it.
    """
    fig, (heatmap_ax, psth_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel
    plt.sca(heatmap_ax)
    data.plot_avg_firing_rate_heatmap(align_info, avg_firing_rates, eccs, angles)
    heatmap_ax.set_title("Heatmap")

    # Right panel
    plt.sca(psth_ax)
    data.plot_psth(avg_firing_rates, compare_position)

    fig.suptitle(title, fontsize=14)

    if not save_to_pdf:
        plt.show()
        return

    # Both calls act on pyplot's current figure.
    pdf.savefig()
    plt.close()


In [None]:
# Run the TDI / tuning-curve analysis for every configured dataset and write one
# PDF of summary figures per dataset into that dataset's base directory.
# The processing function is chosen from the notebook-level `kilosort` flag and
# `electrode_id` set in the configuration cell.
for dataset in datasets:
    base_dir = dataset['base_dir']
    kilo_dir = dataset['kilo_dir']
    sacc_file = dataset['sacc_file']
    
    print(f"\n{'='*50}")
    print(f"Processing dataset: {base_dir}")
    print(f"Sacc file: {sacc_file}")
    print(f"{'='*50}")
    
    save_to_pdf = True
    pdf_path = os.path.join(base_dir, f"memory_saccade_heatmaps_sacc{sacc_file}.pdf")
    compare_position = 3

    # The context manager closes the PDF even if processing raises.
    with PdfPages(pdf_path) as pdf:
        if kilosort:
            tdi_all, data = process_kilosort_data(base_dir, align_info, compare_position, pdf, save_to_pdf, kilo_dir, sacc_file)
        elif electrode_id is None:
            tdi_all = process_continuous_data(base_dir, align_info, compare_position, pdf, save_to_pdf, electrode_id)
        else:
            # process_electrode_data takes no electrode_id argument (it loops over
            # the electrodes itself); passing one raised TypeError.
            tdi_all = process_electrode_data(base_dir, align_info, compare_position, pdf, save_to_pdf)
    
    print(f"Completed processing {base_dir}")
    print(f"Results saved to: {pdf_path}")
   

### Plot firing rate by depth

In [None]:
# For each configured dataset, build a loader and run its all-units analysis.
# The unpacked outputs are notebook globals and are overwritten on every
# iteration, so later cells see the values of the last dataset only.
for dataset in datasets:
    base_dir, kilo_dir, sacc_file = dataset['base_dir'], dataset['kilo_dir'], dataset['sacc_file']

    print(
        f"\n{'='*50}",
        f"Processing dataset: {base_dir}",
        f"Sacc file: {sacc_file}",
        f"{'='*50}",
        sep="\n",
    )

    # Load data and process everything in one call
    data = NeuralDataLoader(
        base_dir,
        kilo_dir=kilo_dir,
        sacc_file=sacc_file,
        tdi_threshold=0.45,
    )

    # Process all units and create comprehensive heatmap
    results = data.process_all_units_comprehensive(align_info)
    (
        tdi_all,
        fr_all,
        depth_all,
        cluster_data_all,
        top_matrix,
        bottom_matrix,
        unit_info,
        time_bins,
    ) = results

    print(f"Completed processing {base_dir}")

In [None]:
# Quick plot of TDI over depth
# Reads tdi_all, depth_all, base_dir and sacc_file as left behind by the dataset
# loop in the previous cell, so only the last processed dataset is plotted.
plt.figure(figsize=(10, 8))

# Scatter plot: TDI on the x-axis, depth on the y-axis, points colored by TDI
plt.scatter(tdi_all, depth_all, alpha=0.6, s=50, c=tdi_all, cmap='viridis')

# Add threshold line
# NOTE(review): the line is drawn at 0.55 while the loader in the previous cell is
# built with tdi_threshold = 0.45 — confirm which value is intended.
plt.axvline(x=0.55, color='red', linestyle='--', linewidth=2, alpha=0.8, label='TDI threshold = 0.55')

# Add colorbar
cbar = plt.colorbar(label='TDI Value')

# Labels and formatting
plt.xlabel('TDI (Tuning Depth Index)', fontsize=12)
plt.ylabel('Depth (μm)', fontsize=12)
plt.title(f'TDI vs Depth Distribution\n{os.path.basename(base_dir)} - {len(tdi_all)} units', fontsize=14)
plt.grid(True, alpha=0.3)

# Invert y-axis so depth=0 is at bottom, max depth at top (currently disabled)
# plt.gca().invert_yaxis()

plt.tight_layout()

# Save figure next to the recording, then display it
tdi_plot_path = os.path.join(base_dir, f"tdi_vs_depth_sacc{sacc_file}.png")
plt.savefig(tdi_plot_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"TDI vs Depth plot saved to: {tdi_plot_path}")



### Plot TDI by depth

In [2]:
# For each configured dataset, build a loader and run its comprehensive TDI
# heatmap routine. `results` is overwritten on every iteration, so the next cell
# sees the output for the last dataset only.
for dataset in datasets:
    base_dir, kilo_dir, sacc_file = dataset['base_dir'], dataset['kilo_dir'], dataset['sacc_file']

    print(
        f"\n{'='*50}",
        f"Processing dataset: {base_dir}",
        f"Sacc file: {sacc_file}",
        f"{'='*50}",
        sep="\n",
    )

    # Load data and process everything in one call
    data = NeuralDataLoader(
        base_dir,
        kilo_dir=kilo_dir,
        sacc_file=sacc_file,
        tdi_threshold=0.45,
    )

    # Process all units and create comprehensive heatmap
    results = data.create_comprehensive_tdi_heatmap(align_info)

    print(f"Completed processing {base_dir}")


Processing dataset: D:\Neural-Pipeline\data\20250724
Sacc file: 1702
Selected base directory: D:\Neural-Pipeline\data\20250724
Loading kilosort data...
Filtering TTL data for block 1702
Parsed 176 trials, 84 good trials.
Processing TDI of cluster: 0
Processing TDI of cluster: 1
Processing TDI of cluster: 2
Processing TDI of cluster: 3
Processing TDI of cluster: 4
Processing TDI of cluster: 5
Processing TDI of cluster: 6
Processing TDI of cluster: 7
Processing TDI of cluster: 8
Processing TDI of cluster: 9
Processing TDI of cluster: 10
Processing TDI of cluster: 11
Processing TDI of cluster: 14
Processing TDI of cluster: 12
Processing TDI of cluster: 13
Processing TDI of cluster: 15
Processing TDI of cluster: 16
Processing TDI of cluster: 17
Processing TDI of cluster: 18
Processing TDI of cluster: 20
Processing TDI of cluster: 19
Processing TDI of cluster: 21
Processing TDI of cluster: 22
Processing TDI of cluster: 23
Processing TDI of cluster: 25
Processing TDI of cluster: 28
Processi

KeyboardInterrupt: 

In [None]:
# Unpack the output of create_comprehensive_tdi_heatmap from the previous cell
# and print the first entry of the first element.
# NOTE(review): assumes `results` is a 3-item sequence — confirm against the loader.
# If the previous cell did not finish (its saved output ends in KeyboardInterrupt),
# `results` may still hold the 8-item output of process_all_units_comprehensive
# from the earlier cell, and this unpack would raise ValueError.
tdi_baseline_corrected, valid_units, time_bins = results
print(tdi_baseline_corrected[0])