In [8]:
import numpy as np
import pandas as pd
import ast
import re

In [1]:
# Load the filtered per-channel counts/times table from CSV.
# NOTE(review): the column schema is not visible here; later cells read the
# "p_spike_1.0_*", "p_poly_1.0_*" and "p_sw_1.0_*" count/times columns from it.
df = pd.read_csv("P1/P1_S6/softpriors_counts_and_times_per_channel_filtered.csv")

In [2]:
def parse_spike_times(x):
    """Normalise one cell of a ``*_times`` column to a plain list of spike times.

    Parameters
    ----------
    x : list, tuple, str, number, or missing value
        * list / tuple -- already parsed; returned as a new list.
        * None / NaN / pd.NA -- no spikes; returns ``[]``.
        * str -- evaluated with ``ast.literal_eval``. A list/tuple literal
          becomes a list, a single numeric literal becomes a one-element
          list, and a blank string becomes ``[]``.
        * int / float (including numpy scalars) -- a single spike; returns ``[x]``.

    Returns
    -------
    list
        The spike times contained in ``x``.

    Raises
    ------
    ValueError
        If ``x`` has an unsupported type, or is a string whose literal is
        neither a sequence nor a number.
    """
    # Containers must be handled before pd.isna(): on a list pd.isna() returns
    # an element-wise array, and ``if`` on that array raises for len > 1.
    # Checking first also makes re-running the conversion cell safe.
    if isinstance(x, (list, tuple)):
        return list(x)

    if pd.isna(x):
        return []  # missing value

    if isinstance(x, str):
        text = x.strip()
        if not text:
            return []
        parsed = ast.literal_eval(text)
        if isinstance(parsed, (list, tuple)):
            return list(parsed)
        if isinstance(parsed, (int, float)):
            # A lone number such as "5.0" is one spike, not a bare scalar.
            return [parsed]
        raise ValueError(f"Unsupported spike-time literal: {x!r}")

    if isinstance(x, (int, float, np.integer, np.floating)):
        return [x]  # single spike

    raise ValueError(f"Unknown type: {type(x)}")

# Normalise every spike-time column so that each cell holds a list of times
time_columns = ("p_spike_1.0_times", "p_poly_1.0_times", "p_sw_1.0_times")
for time_col in time_columns:
    df[time_col] = df[time_col].apply(parse_spike_times)

In [3]:
def merge_spikes(row):
    """Combine the three spike categories of one row into a time-ordered record.

    Reads the ``p_spike``, ``p_poly`` and ``p_sw`` ``*_1.0_count`` and
    ``*_1.0_times`` entries of ``row``.

    Returns a ``pd.Series`` with:
      * ``total_spike_count`` -- sum of the three count entries,
      * ``all_spike_times``   -- every time from the three lists, ascending,
      * ``all_spike_types``   -- the category label of each time, same order.
    """
    count_sum = row["p_spike_1.0_count"] + row["p_poly_1.0_count"] + row["p_sw_1.0_count"]

    spike_t = row["p_spike_1.0_times"]
    poly_t = row["p_poly_1.0_times"]
    sw_t = row["p_sw_1.0_times"]

    times = spike_t + poly_t + sw_t
    labels = ["p_spike"] * len(spike_t) + ["p_poly"] * len(poly_t) + ["p_sw"] * len(sw_t)

    # Stable ordering by time only: equal times keep the
    # spike -> poly -> sw concatenation order.
    order = sorted(range(len(times)), key=lambda i: times[i])

    return pd.Series({
        "total_spike_count": count_sum,
        "all_spike_times": [times[i] for i in order],
        "all_spike_types": [labels[i] for i in order],
    })

In [6]:
# Merge the three spike categories row by row and attach the results to df.
merged_spikes = df.apply(merge_spikes, axis=1)
df["total_spike_count"] = merged_spikes["total_spike_count"]
df["all_spike_times"] = merged_spikes["all_spike_times"]
df["all_spike_types"] = merged_spikes["all_spike_types"]

# merge_spikes() returns only the three merged fields, but the MNI-lookup step
# that reads this file back indexes it by the "Channels" column. Write that
# identifier column alongside the merged fields whenever df provides it.
# NOTE(review): assumes the input table carries a "Channels" column - confirm.
if "Channels" in df.columns:
    merged_out = pd.concat([df[["Channels"]], merged_spikes], axis=1)
else:
    merged_out = merged_spikes
merged_out.to_csv("P1/P1_S6/merged_spikes.csv")

In [11]:
# Inputs: the merged spike table and the per-channel MNI coordinate table
spikes_df = pd.read_csv('P1/P1_S6/merged_spikes.csv')
coords_df = pd.read_csv('P1/P1_S6/mni_coordinates.csv')

# The MNI column may come back from CSV as text; evaluate such cells as Python
# literals and leave any non-string cell untouched.
coords_df['MNI'] = coords_df['MNI'].apply(
    lambda raw: raw if not isinstance(raw, str) else ast.literal_eval(raw)
)

# Lookup table: channel name -> MNI coordinates
mni_dict = {
    channel: coords
    for channel, coords in zip(coords_df['Channel'], coords_df['MNI'])
}

def get_mni_coords(channel_pair):
    """Look up the MNI coordinates of both electrodes in an "A-B" channel label.

    Parameters
    ----------
    channel_pair : str
        Label such as ``"LA1-LA2"``. Whitespace around either electrode name
        is ignored.

    Returns
    -------
    pd.Series
        Two elements, ``[mni_first, mni_second]``, taken from ``mni_dict``.
        An element is ``None`` when that electrode is not in ``mni_dict``.
        Both are ``None`` when ``channel_pair`` is not a string or does not
        consist of exactly two hyphen-separated names, so one malformed label
        does not abort a whole ``apply``.
    """
    if not isinstance(channel_pair, str):
        return pd.Series([None, None])

    # Strip each name so "LA1 - LA2" resolves the same as "LA1-LA2".
    names = [part.strip() for part in channel_pair.split('-')]
    if len(names) != 2:
        return pd.Series([None, None])

    first, second = names
    return pd.Series([mni_dict.get(first), mni_dict.get(second)])

# Look up both electrodes' MNI coordinates and store them as two new columns
mni_cols = ['MNI Channel 1', 'MNI Channel 2']
spikes_df[mni_cols] = spikes_df['Channels'].apply(get_mni_coords)

# Place the MNI columns directly after 'Channels'. The remaining columns are
# selected by name rather than by slicing off "the last two", so the order is
# correct even if the MNI columns already existed somewhere else in the frame.
cols = [c for c in spikes_df.columns if c not in mni_cols]
channels_index = cols.index('Channels')
new_order = cols[:channels_index + 1] + mni_cols + cols[channels_index + 1:]
spikes_df = spikes_df[new_order]

# Save next to the other session files: the following step reads
# "P1/P1_S6/spikes_with_mni.csv", not a file in the working directory.
spikes_df.to_csv('P1/P1_S6/spikes_with_mni.csv', index=False)

   Unnamed: 0   Channels               MNI Channel 1  \
0           0    LA1-LA2  [-18.631, -3.602, -24.235]   
1           1  LA10-LA11  [-53.653, -3.209, -26.017]   
2           2  LA11-LA12   [-57.567, -3.09, -26.314]   
3           3    LA2-LA3  [-22.428, -3.563, -24.289]   
4           4    LA3-LA4   [-26.25, -3.548, -24.342]   

                MNI Channel 2  total_spike_count  \
0  [-22.428, -3.563, -24.289]                 10   
1   [-57.567, -3.09, -26.314]                 14   
2   [-61.42, -2.926, -26.466]                 17   
3   [-26.25, -3.548, -24.342]                 17   
4  [-30.096, -3.545, -24.449]                 13   

                                     all_spike_times  \
0  [445.78, 1992.76, 2007.49, 2011.7, 2077.905, 3...   
1  [444.84, 1996.3, 3513.75, 3613.995, 3625.58, 3...   
2  [3496.66, 3513.755, 3614.005, 4304.55, 4322.59...   
3  [1091.765, 2009.38, 2501.5, 2943.875, 3252.86,...   
4  [2538.88, 3615.435, 3631.22, 4304.255, 4969.26...   

             

In [13]:
mni_diff_df = pd.read_csv("P1/P1_S6/spikes_with_mni.csv")


def _parse_mni_cell(cell):
    """Evaluate a stringified coordinate triple; any non-string cell becomes None."""
    return ast.literal_eval(cell) if isinstance(cell, str) else None


# Channels without a known MNI position are written as empty cells and read
# back as NaN, which ast.literal_eval cannot parse. Those rows get NaN
# coordinates so their differences and distance come out as NaN.
_MISSING_XYZ = [float("nan")] * 3

for mni_col, xyz_cols in (
    ('MNI Channel 1', ['x1', 'y1', 'z1']),
    ('MNI Channel 2', ['x2', 'y2', 'z2']),
):
    parsed = mni_diff_df[mni_col].apply(_parse_mni_cell)
    mni_diff_df[mni_col] = parsed
    # Extract x, y, z into their own columns
    mni_diff_df[xyz_cols] = pd.DataFrame(
        [xyz if isinstance(xyz, (list, tuple)) else _MISSING_XYZ for xyz in parsed],
        index=mni_diff_df.index,
    )

# Component-wise differences (electrode 2 minus electrode 1)
mni_diff_df['dx'] = mni_diff_df['x2'] - mni_diff_df['x1']
mni_diff_df['dy'] = mni_diff_df['y2'] - mni_diff_df['y1']
mni_diff_df['dz'] = mni_diff_df['z2'] - mni_diff_df['z1']

# Euclidean distance between the two electrodes
mni_diff_df['distance'] = ((mni_diff_df['dx']**2 + mni_diff_df['dy']**2 + mni_diff_df['dz']**2)**0.5)

mni_diff_df.to_csv('P1/P1_S6/mni_diff_with_mni.csv')