Some info about the data:
- the neural data is recorded at 30000 Hz
- the behavioral data is recorded at 40 Hz (we refer to the behavioral sampling times as frames)
- Threatening stimuli are usually an air puff delivered simultaneously with an auditory threat (sometimes only the auditory threat is given). Threats are always delivered at the same place: the end of the arena directly opposite the shelter.
- The video data is recorded at 1024x1024 pixels (~10 pixels per cm). All positional information is in pixels.

In [1]:
import numpy as np
import os
import polars as pl
import dill as pickle
import socket
import tkinter as tk
from tkinter import filedialog

In [7]:
"""The paths to all the sessions"""
# base path: identifies where you have mapped ceph onto your computer
# all_paths: a list of the paths to specific experimental sessions 
# NB: some paths are commented out! This is because we are migrating some of our data between servers right now. It should be available on ceph in a few days.

def get_computer_specific_paths():
    """Ask the user, via a folder-picker dialog, where Branco lab ceph is mounted.

    Returns:
        The directory chosen in the dialog (e.g. '/Volumes/branco').

    Raises:
        RuntimeError: if the dialog is closed without choosing a directory. In that
            case every later ``os.path.join(base_path, ...)`` would silently become
            a relative path and fail far from the real cause.
    """
    root = tk.Tk()
    root.withdraw()  # hide the empty main window; only the dialog should appear
    try:
        selected = filedialog.askdirectory(
            parent=root,
            title="Select Directory in which you have mounted Branco lab ceph",
        )
    finally:
        # Without this the hidden Tk root stays alive for the rest of the session.
        root.destroy()
    if not selected:
        raise RuntimeError("No directory selected for the ceph mount point.")
    return selected

# Opens a folder picker asking where ceph is mounted; every path below is relative to it.
base_path = get_computer_specific_paths()

# Session folders relative to base_path. A session is chosen later by its position
# in this list; commented-out entries do not count towards that position.
# NOTE: when re-enabling entries, keep the trailing commas -- a missing comma makes
# Python silently concatenate two adjacent path strings into one bogus path.
all_paths = ['Jasmine_Laurence/Experimental_Data/JAL004/JAL004_flip_rotated_2023_08_28T09_36_04',
             'Jasmine_Laurence/Experimental_Data/JAL004/004_flip_2023_09_03T12_04_16',
             'Jasmine_Laurence/Experimental_Data/JAL004/004_flip_puff2_2023_09_11T09_32_25',
             'Jasmine_Laurence/Experimental_Data/JAL004/004_flipppuf19sept_2023_09_19T14_10_56',
             'Jasmine_Laurence/Experimental_Data/JAL005/005_baseline_2023_09_05T07_48_58',
             'Jasmine_Laurence/Experimental_Data/JAL005/005_flip1_2023_09_08T07_36_54',
             'Jasmine_Laurence/Experimental_Data/JAL005/005_flippuff3_2023_09_21T11_11_13',
            #  'Jasmine_Laurence/Experimental_Data/JAL006/JAL006_barrier_flip3_2024_03_18T11_53_29',
            #  'Jasmine_Laurence/Experimental_Data/JAL006/JAL006_shelter_barrier_flip_3_2024_03_21T11_20_34',
            #  'Jasmine_Laurence/Experimental_Data/JAL006/JAL006_shelter_barrier_flip_5_2024_03_25T11_05_33',
            #  'Jasmine_Laurence/Experimental_Data/JAL006/JAL006_shelter_barrier_flip_6_2024_03_28T10_54_20',
             'Jasmine_Laurence/Experimental_Data/JAL007/JAL007_barrierflip2_2024_03_12T11_18_26',
             'Jasmine_Laurence/Experimental_Data/JAL007/JAL007_shelter_barrier_flip_5_2024_03_22T11_15_43',
            #  'Jasmine_Laurence/Experimental_Data/JAL007/JAL007_shelter_barrier_flip_8_2024_04_09T10_07_45',
            #  'Jasmine_Laurence/Experimental_Data/JAL007/JAL007_shelter_barrier_flip_9_2024_04_16T11_13_05',
            #  'Jasmine_Laurence/Experimental_Data/JAL007/JAL007_shelter_barrier_flip_100_2024_04_23T09_59_40',
            #  'Jasmine_Laurence/Experimental_Data/JAL007/JAL007_tinnybarrier1_2024_04_30T10_57_04',
            #  'Jasmine_Laurence/Experimental_Data/JAL008/JAL008_shelter_barrier_flip_1_2024_04_25T11_27_42',
            #  'Jasmine_Laurence/Experimental_Data/JAL008/JAL008_shelter_barrier_flip_2_2024_04_29T12_14_54',
            #  'Jasmine_Laurence/Experimental_Data/JAL008/JAL008_shelter_tiny_barrier_flip_1_2024_05_03T10_02_35',
            #  'Jasmine_Laurence/Experimental_Data/JAL008/JAL008_shelter_barrier_flip_3_2024_05_07T10_16_26',
            #  'Jasmine_Laurence/Experimental_Data/JAL008/JAL008_shelter_barrier_flip_4_2024_05_10T11_47_47',
            #  'Jasmine_Laurence/Experimental_Data/JAL008/JAL008_shelter_barrier_flip_5_2024_05_14T10_18_03',
            #  'Jasmine_Laurence/Experimental_Data/JAL008/JAL008_shelter_tiny_flip_2_2024_05_21T11_10_19'
             ]


In [8]:
# Echo the selected ceph mount point so it is recorded in the cell output.
print(f"{base_path}")

/Volumes/branco


In [9]:
"""Choose an experiment to load"""

# 0-based position in all_paths of the session to analyse (negative values count
# from the end). With the list above, idx = 4 selects the JAL005 baseline session.
# Entries are commented in and out of all_paths as data migrates, so fail with a
# message that states the current number of sessions rather than a bare IndexError.
idx = 4
if not (-len(all_paths) <= idx < len(all_paths)):
    raise IndexError(
        f"idx={idx} is out of range: all_paths currently lists {len(all_paths)} sessions"
    )
experiment_path = all_paths[idx]

In [None]:
"""Get data paths and session metadata for a given session"""
# session object has a lot of information about the recording. The following could be useful to you:
# session.audio.onset_frames is a list (sometimes it's a list of lists sorry!) of times in behavioral frames of when the threat was delivered
# session.shelter_location gives you the xy position of the top left and bottom right corners of the shelter
# session.barrier_time gives you the time in minutes when the barrier was first placed in the arena (you probably want everything before this time)

with open(os.path.join(base_path, experiment_path, "processed_data", "metadata"), "rb") as dill_file: 
    session = pickle.load(dill_file)

"""Produces ModuleNotFound error: no module behave_analysis"""

In [12]:
""""Load time x neurons matrix"""
# This gives the firing rate of each putative single unit ('good' cluster) at each
# frame (behavioral sampling timepoint): rows are frames, columns are clusters.
# Built with os.path.join throughout (like the other loading cells) rather than
# concatenating "/" by hand.
frame_by_cluster_matrix = np.load(
    os.path.join(
        base_path,
        experiment_path,
        "processed_data",
        "frame_by_good_cluster_matrix.npy",
    )
)

In [15]:
"""Load spikes dataframe
An alternative way to load in neural data that gives you individual spike times for both single and multiunit clusters
"""

# spike_df: polars dataframe with one row per spike, covering both single-unit and
# multi-unit clusters. Columns worth using:
#   spike_aligned_to_frame          behavioral frame (40 Hz sample) on which the spike was recorded
#   spike_clusters                  cluster the spike belongs to
#   cluster_group                   spike-sorting label of that cluster: 'good' (putative
#                                   single unit), 'mua' (multi-unit activity) or 'noise'
# Columns to ignore (alternative clocks):
#   aligned_spike_times             spike times in the behavioral computer clock
#   aligned_spike_times_in_samples  spike times in the behavioral computer clock
#   spike_times                     spike time in the recording computer clock
spike_df = pl.read_csv(
    os.path.join(
        base_path,
        experiment_path,
        "processed_data",
        "Processed_efizz_data",
    )
)

In [16]:
"""Load behavioral dataframe"""

# Columns of the behavioral dataframe (sampled at 40 Hz):
#   frames            behavioral frame number; matches the rows of the time x neurons
#                     matrix and spike_df['spike_aligned_to_frame']
#   hdir              the mouse's head direction
#   hsa               the mouse's head-shelter angle
#   mouse_x_position  x position (pixels)
#   mouse_y_position  y position (pixels)
#   speed
#   OutofshelterIdx   bool: True while the mouse is outside the shelter
#   EscapePeriod      bool: True while the mouse is performing an escape
#   homingPeriod      bool: True while the mouse is performing a homing run
#   shelter           bool: True when the shelter is present in the arena
#   barrier_present   bool: True when the barrier is in the arena; to analyse only
#                     barrier-free time, use the frames before this first becomes True
#   barrier_flipped   bool: True once the barrier has been rotated by 180 degrees
# Columns to ignore:
#   h_preflipbar_a    angle between the head and the open side of the barrier before flipping
#   h_postflipbar_a   angle between the head and the open side of the barrier after flipping
#   h_bar_centre_a    angle between the head and the barrier centre (also the arena centre)
# NOTE(review): angle units are not stated here; the head() output (hdir ~3.08)
# suggests radians -- confirm before binning angles.
video_df = pl.read_csv(
    os.path.join(
        base_path,
        experiment_path,
        "processed_data",
        "full_video_dataframe.csv",
    )
)

In [None]:
# Keep frames where the mouse is outside the shelter, is neither escaping nor on a
# homing run, the shelter is present and no barrier is in the arena.
filtered_video_df = video_df.filter(
    pl.col("OutofshelterIdx")
    & ~pl.col("EscapePeriod")
    & pl.col("shelter")
    & ~pl.col("barrier_present")
    & ~pl.col("homingPeriod")
)

# Wide table: one row per row (frame) of the time x neurons matrix, one column per
# cluster. Polars names the columns column_0, column_1, ...; those names become the
# 'cluster' labels after unpivoting.
frame_by_cluster_df = pl.DataFrame(frame_by_cluster_matrix, orient="row")
cluster_columns = frame_by_cluster_df.columns

# Label matrix row i as frame i. Force Int64 so the join key matches
# video_df['frames'] on every platform (np.arange can default to int32).
# NOTE(review): the head() output printed in this notebook starts at frame 1; confirm
# that matrix row 0 really is frame 0, otherwise this join is shifted by one frame.
frame_by_cluster_df = frame_by_cluster_df.with_columns(
    pl.Series("frames", np.arange(frame_by_cluster_df.height, dtype=np.int64))
)

# Drop unwanted frames while the table is still wide, then unpivot to long format:
# one row per (frame, cluster) holding that cluster's firing rate.
frame_by_cluster_df = frame_by_cluster_df.join(
    filtered_video_df.select("frames"), on="frames", how="semi"
).unpivot(
    index="frames",
    on=cluster_columns,
    variable_name="cluster",
    value_name="firing_rate",
)

# Select the behavioral columns *before* joining. video_df also carries many unused
# columns (e.g. head_randP_0 ... head_randP_163) that would otherwise be copied once
# per cluster for every frame of the long table.
behavior_columns = ["frames", "hdir", "hsa", "mouse_x_position", "mouse_y_position", "speed"]
df = filtered_video_df.select(behavior_columns).join(frame_by_cluster_df, on="frames")

# Final column order: behavior first, then cluster id and its firing rate.
df = df.select(behavior_columns + ["cluster", "firing_rate"])

In [None]:
# Number of (frame, cluster) rows, followed by a preview of the first five rows.
print(df.height)
df.head(5)

121037648


frames,hdir,mouse_x_position,mouse_y_position,speed,OutofshelterIdx,EscapePeriod,shelter,barrier_present,barrier_flipped,hsa,h_preflipbar_a,h_postflipbar_a,h_bar_centre_a,head_randP_0,head_randP_1,head_randP_2,head_randP_3,head_randP_4,head_randP_5,head_randP_6,head_randP_7,head_randP_8,head_randP_9,head_randP_10,head_randP_11,head_randP_12,head_randP_13,head_randP_14,head_randP_15,head_randP_16,head_randP_17,head_randP_18,head_randP_19,head_randP_20,head_randP_21,head_randP_22,…,head_randP_128,head_randP_129,head_randP_130,head_randP_131,head_randP_132,head_randP_133,head_randP_134,head_randP_135,head_randP_136,head_randP_137,head_randP_138,head_randP_139,head_randP_140,head_randP_141,head_randP_142,head_randP_143,head_randP_144,head_randP_145,head_randP_146,head_randP_147,head_randP_148,head_randP_149,head_randP_150,head_randP_151,head_randP_152,head_randP_153,head_randP_154,head_randP_155,head_randP_156,head_randP_157,head_randP_158,head_randP_159,head_randP_160,head_randP_161,head_randP_162,head_randP_163,homingPeriod
i64,f64,f64,f64,f64,bool,bool,bool,bool,bool,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,…,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,bool
1,3.084343,402.004479,233.14482,9.4593e-11,True,False,False,False,False,1.837256,2.650954,2.652756,2.651856,-1.398054,-1.901932,-2.265017,-2.48365,-2.617617,-2.705195,-0.329358,-0.580166,-1.273919,-2.224812,-2.614168,-2.769816,-2.849516,-2.897372,-2.929155,-2.951757,0.087011,0.100397,0.135626,0.466782,3.076388,-3.138186,-3.118831,…,1.723985,1.844362,1.958716,2.064777,2.161312,2.247995,2.325151,2.393492,1.287719,1.388567,1.494619,1.603766,1.713492,1.821202,1.924573,2.021821,2.111818,2.194056,2.268528,2.335571,1.411515,1.507701,1.60617,1.705064,1.802476,1.896664,1.986222,2.070165,2.147941,2.219368,1.608141,1.698147,1.787032,1.87346,1.95631,2.034743,False
2,3.084343,402.004479,233.14482,9.5114e-11,True,False,False,False,False,1.837256,2.650954,2.652756,2.651856,-1.398054,-1.901932,-2.265017,-2.48365,-2.617617,-2.705195,-0.329358,-0.580166,-1.273919,-2.224812,-2.614168,-2.769816,-2.849516,-2.897372,-2.929155,-2.951757,0.087011,0.100397,0.135626,0.466782,3.076388,-3.138186,-3.118831,…,1.723985,1.844362,1.958716,2.064777,2.161312,2.247995,2.325151,2.393492,1.287719,1.388567,1.494619,1.603766,1.713492,1.821202,1.924573,2.021821,2.111818,2.194056,2.268528,2.335571,1.411515,1.507701,1.60617,1.705064,1.802476,1.896664,1.986222,2.070165,2.147941,2.219368,1.608141,1.698147,1.787032,1.87346,1.95631,2.034743,False
3,3.084343,402.004479,233.14482,9.6938e-11,True,False,False,False,False,1.837256,2.650954,2.652756,2.651856,-1.398054,-1.901932,-2.265017,-2.48365,-2.617617,-2.705195,-0.329358,-0.580166,-1.273919,-2.224812,-2.614168,-2.769816,-2.849516,-2.897372,-2.929155,-2.951757,0.087011,0.100397,0.135626,0.466782,3.076388,-3.138186,-3.118831,…,1.723985,1.844362,1.958716,2.064777,2.161312,2.247995,2.325151,2.393492,1.287719,1.388567,1.494619,1.603766,1.713492,1.821202,1.924573,2.021821,2.111818,2.194056,2.268528,2.335571,1.411515,1.507701,1.60617,1.705064,1.802476,1.896664,1.986222,2.070165,2.147941,2.219368,1.608141,1.698147,1.787032,1.87346,1.95631,2.034743,False
4,3.084343,402.004479,233.14482,1.0079e-10,True,False,False,False,False,1.837256,2.650954,2.652756,2.651856,-1.398054,-1.901932,-2.265017,-2.48365,-2.617617,-2.705195,-0.329358,-0.580166,-1.273919,-2.224812,-2.614168,-2.769816,-2.849516,-2.897372,-2.929155,-2.951757,0.087011,0.100397,0.135626,0.466782,3.076388,-3.138186,-3.118831,…,1.723985,1.844362,1.958716,2.064777,2.161312,2.247995,2.325151,2.393492,1.287719,1.388567,1.494619,1.603766,1.713492,1.821202,1.924573,2.021821,2.111818,2.194056,2.268528,2.335571,1.411515,1.507701,1.60617,1.705064,1.802476,1.896664,1.986222,2.070165,2.147941,2.219368,1.608141,1.698147,1.787032,1.87346,1.95631,2.034743,False
5,3.084343,402.004479,233.14482,1.077e-10,True,False,False,False,False,1.837256,2.650954,2.652756,2.651856,-1.398054,-1.901932,-2.265017,-2.48365,-2.617617,-2.705195,-0.329358,-0.580166,-1.273919,-2.224812,-2.614168,-2.769816,-2.849516,-2.897372,-2.929155,-2.951757,0.087011,0.100397,0.135626,0.466782,3.076388,-3.138186,-3.118831,…,1.723985,1.844362,1.958716,2.064777,2.161312,2.247995,2.325151,2.393492,1.287719,1.388567,1.494619,1.603766,1.713492,1.821202,1.924573,2.021821,2.111818,2.194056,2.268528,2.335571,1.411515,1.507701,1.60617,1.705064,1.802476,1.896664,1.986222,2.070165,2.147941,2.219368,1.608141,1.698147,1.787032,1.87346,1.95631,2.034743,False


In [None]:
# Sanity check: print (min, max) of each behavioral/neural column, then the number
# of distinct clusters.
for column_name in ("hdir", "hsa", "mouse_x_position", "mouse_y_position", "speed", "firing_rate"):
    column_values = df[column_name]
    print(np.min(column_values), np.max(column_values))
print(len(np.unique(df["cluster"])))

In [None]:
import polars as pl
import numpy as np

def _mutual_information_bits(firing_rates, headings, theta_edges, n_bins_z):
    """Return the plug-in mutual information (bits) between one neuron's firing rates
    and the simultaneous headings (1-D float arrays of equal length)."""
    # Samples with NaN/inf in either variable cannot be binned; drop them.
    valid = np.isfinite(firing_rates) & np.isfinite(headings)
    firing_rates = firing_rates[valid]
    # Wrap headings into [0, 2π) so angles given e.g. in (-π, π] fall into real bins
    # instead of producing negative bin indices (which makes np.bincount raise).
    headings = np.mod(headings[valid], 2 * np.pi)
    n_samples = firing_rates.size
    if n_samples == 0:
        return float("nan")

    # Step 1: z-score the firing rate. A constant rate carries no information.
    std_r = firing_rates.std()
    if std_r == 0:
        return 0.0
    z_scores = (firing_rates - firing_rates.mean()) / std_r
    z_min, z_max = z_scores.min(), z_scores.max()
    if z_min == z_max:  # defensive: degenerate z-score range
        return 0.0

    # Steps 2-3: discretise. np.digitize returns 1..n_bins+1; shift to 0-based and
    # fold values lying exactly on the top edge into the last bin.
    n_bins_theta = theta_edges.size - 1
    z_edges = np.linspace(z_min, z_max, n_bins_z + 1)
    z_bins = np.minimum(np.digitize(z_scores, z_edges) - 1, n_bins_z - 1)
    theta_bins = np.minimum(np.digitize(headings, theta_edges) - 1, n_bins_theta - 1)

    # Steps 4-6: joint histogram with one bincount over flattened (z, theta) cells;
    # marginals are its row/column sums.
    joint_counts = np.bincount(
        z_bins * n_bins_theta + theta_bins, minlength=n_bins_z * n_bins_theta
    ).reshape(n_bins_z, n_bins_theta)
    p_joint = joint_counts / n_samples
    p_z = joint_counts.sum(axis=1) / n_samples
    p_theta = joint_counts.sum(axis=0) / n_samples

    # Step 7: I = sum p(z,θ) log2[p(z,θ) / (p(z) p(θ))]. Empty cells contribute 0, and
    # a non-empty cell guarantees p(z) > 0 and p(θ) > 0, so no division by zero.
    occupied = joint_counts > 0
    independent = np.outer(p_z, p_theta)
    return float(
        np.sum(p_joint[occupied] * np.log2(p_joint[occupied] / independent[occupied]))
    )


def compute_mutual_information(
    df: pl.DataFrame,
    *,
    neuron_col: str = "neuron_id",
    rate_col: str = "firing_rate",
    heading_col: str = "heading",
    n_bins_z: int = 100,
    n_bins_theta: int = 100,
) -> pl.DataFrame:
    """
    Compute the mutual information I(R_i; theta), in bits, between each neuron's
    firing rate and heading.

    For every neuron the firing rate is z-scored, the z-scores are split into
    ``n_bins_z`` equal-width bins spanning their [min, max], the heading into
    ``n_bins_theta`` equal-width bins over [0, 2π), and the plug-in estimate
        I = sum_{z,theta} p(z,theta) * log2[ p(z,theta) / (p(z) * p(theta)) ]
    is computed from the joint histogram.

    Parameters
    ----------
    df : long-format Polars DataFrame with one row per (sample, neuron).
    neuron_col, rate_col, heading_col : names of the columns holding the neuron id,
        its firing rate and the heading angle in radians. Defaults match the schema
        ['neuron_id', 'firing_rate', 'heading']; other columns are ignored.
        Headings outside [0, 2π) are wrapped into that range.
    n_bins_z, n_bins_theta : number of bins for the z-scored rate and the heading.

    Returns
    -------
    Polars DataFrame with columns [neuron_col, 'mutual_information'], one row per
    neuron, sorted by neuron id. MI is 0.0 for a neuron whose firing rate is
    constant and NaN for a neuron with no finite (rate, heading) samples.

    Raises
    ------
    ValueError : if either bin count is smaller than 1.

    Notes
    -----
    Rows with a null neuron id, or a non-finite rate or heading, are ignored.
    The plug-in estimator is biased upwards when there are few samples per
    (z, theta) cell, so compare values against a shuffled-heading baseline.
    """
    if n_bins_z < 1 or n_bins_theta < 1:
        raise ValueError("n_bins_z and n_bins_theta must both be at least 1")

    # Equal-width heading bins covering [0, 2π), shared by all neurons.
    theta_edges = np.linspace(0, 2 * np.pi, n_bins_theta + 1)

    data = df.select(neuron_col, rate_col, heading_col).filter(
        pl.col(neuron_col).is_not_null()
    )
    schema = {neuron_col: data.schema[neuron_col], "mutual_information": pl.Float64}

    neuron_ids = data.get_column(neuron_col).to_numpy()
    firing_rates = data.get_column(rate_col).to_numpy().astype(np.float64)
    headings = data.get_column(heading_col).to_numpy().astype(np.float64)

    unique_ids, group_of_row = np.unique(neuron_ids, return_inverse=True)
    if unique_ids.size == 0:
        return pl.DataFrame(schema=schema)

    # Group rows by neuron once (stable sort + split) instead of re-filtering the
    # whole frame for every neuron, which scales as rows x neurons.
    order = np.argsort(group_of_row, kind="stable")
    split_points = np.cumsum(np.bincount(group_of_row))[:-1]
    rate_groups = np.split(firing_rates[order], split_points)
    heading_groups = np.split(headings[order], split_points)

    mi_values = [
        _mutual_information_bits(rates, heads, theta_edges, n_bins_z)
        for rates, heads in zip(rate_groups, heading_groups)
    ]
    return pl.DataFrame(
        {neuron_col: unique_ids.tolist(), "mutual_information": mi_values},
        schema=schema,
    )


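# ---------------------------------------------------------------------------
# Example usage:
# ---------------------------------------------------------------------------
# compute_mutual_information expects a long-format Polars DataFrame with one row
# per (frame, neuron) and the columns 'neuron_id', 'firing_rate' and 'heading'.
# Headings are binned over [0, 2π), so they must be in radians (check the range
# printed above; negative headings should be wrapped first, e.g. `% (2 * np.pi)`).
# The df built above names these columns 'cluster' and 'hdir', so rename first:
#
# mi_results_df = compute_mutual_information(
#     df.rename({"cluster": "neuron_id", "hdir": "heading"})
# )
# print(mi_results_df)
#
# This produces a Polars DataFrame of the form:
# ┌───────────┬────────────────────┐
# │ neuron_id ┆ mutual_information │
# │ ---       ┆ ---                │
# │ str       ┆ f64                │
# ╞═══════════╪════════════════════╡
# │ column_0  ┆ 1.234567           │
# │ column_1  ┆ 0.765432           │
# │ …         ┆ …                  │
# └───────────┴────────────────────┘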


0.0 400.0
