# Imports

In [1]:
import numpy as np                     
import pandas as pd
import random as rand
from scipy.signal import correlate, correlation_lags
from scipy.ndimage import gaussian_filter1d

import matplotlib.pyplot as plt       
from matplotlib.patches import FancyArrow, Patch, Circle
import matplotlib.colors as mcolors
from matplotlib.colors import Normalize
import matplotlib.patches as patches
from matplotlib.lines import Line2D

import braingeneers                  
from braingeneers.analysis.analysis import SpikeData, read_phy_files, load_spike_data, burst_detection, randomize_raster
import braingeneers.data.datasets_electrophysiology as ephys
from multiprocessing import Pool
from tqdm import tqdm

# Data

In [2]:
# Location of the curated Kilosort2 output for this recording (phy zip archive).
curated_phy_zip = '/workspaces/human_hippocampus/data/ephys/2023-04-02-e-hc328_unperturbed/derived/kilosort2/hc3.28_hckcr1_chip16835_plated34.2_rec4.2_curated.zip'

# Load the recording; `sd` is the spike-data object used by every later cell.
sd = read_phy_files(curated_phy_zip)

  sd = read_phy_files('/workspaces/human_hippocampus/data/ephys/2023-04-02-e-hc328_unperturbed/derived/kilosort2/hc3.28_hckcr1_chip16835_plated34.2_rec4.2_curated.zip')


# Code

In [3]:
def calculate_mean_firing_rates(spike_data):
    """
    Compute one mean firing rate per neuron.

    Parameters:
    - spike_data: Object exposing `train` (an iterable with one spike
      sequence per neuron) and `length` (presumably the recording
      duration; the returned rates are in spikes per unit of `length`).

    Returns:
    - A list, ordered like `spike_data.train`, where each entry is that
      neuron's spike count divided by `spike_data.length`.
    """
    rates = []
    for spikes in spike_data.train:
        spike_count = len(spikes)
        rates.append(spike_count / spike_data.length)
    return rates

def get_neuron_positions(spike_data):
    """
    Collect the (x, y) position of every neuron into one array.

    Parameters:
    - spike_data: Object whose `neuron_data[0]` is a mapping; each of its
      values must provide a 'position' entry whose first two items are
      used as the x and y coordinates.

    Returns:
    - A numpy array with one row per neuron and two columns (x, y), in the
      iteration order of `spike_data.neuron_data[0].values()`.
    """
    positions = [info['position'] for info in spike_data.neuron_data[0].values()]
    x_coords = [position[0] for position in positions]
    y_coords = [position[1] for position in positions]
    # Stack as two rows and transpose so each row of the result is one neuron.
    return np.array([x_coords, y_coords]).T

def precalculate_distances_angles(neuron_positions):
    """
    Compute all pairwise distances and directions between positions.

    Parameters:
    - neuron_positions: numpy array with one row per neuron; columns 0 and 1
      are used as x and y for the angle, and all columns contribute to the
      distance.

    Returns:
    - (distances, angles): two square arrays indexed [i, j].
      distances[i, j] is the Euclidean norm of position i minus position j;
      angles[i, j] is the arctan2 direction of that same offset, wrapped
      into [0, 2*pi).
    """
    # offsets[i, j, :] = position i - position j, via broadcasting.
    offsets = neuron_positions[:, None, :] - neuron_positions[None, :, :]

    squared_lengths = np.square(offsets).sum(axis=2)
    pairwise_distances = np.sqrt(squared_lengths)

    # arctan2 yields (-pi, pi]; the modulo shifts negatives into [0, 2*pi).
    signed_angles = np.arctan2(offsets[..., 1], offsets[..., 0])
    pairwise_angles = np.mod(signed_angles, 2 * np.pi)

    return pairwise_distances, pairwise_angles


def create_reverse_rank_lookup(event_ranks):
    """
    Invert an event-rank mapping.

    Parameters:
    - event_ranks: A dictionary whose keys are (neuron_id, spike_time) pairs
      and whose values are event ranks.

    Returns:
    - A dictionary mapping each event rank back to its
      (neuron_id, spike_time) pair. If a rank occurs more than once, the
      pair seen last in iteration order is kept.
    """
    rank_to_event = {}
    for event_key, rank in event_ranks.items():
        neuron_id, spike_time = event_key
        rank_to_event[rank] = (neuron_id, spike_time)
    return rank_to_event

def calculate_event_ranks(spike_data, precision=5):
    """
    Assign a global rank to every spike in the recording.

    Parameters:
    - spike_data: Object exposing `train`, an iterable with one sequence of
      spike times per neuron (the neuron id is its index in `train`).
    - precision: Number of decimal places spike times are rounded to before
      sorting.

    Returns:
    - A dictionary mapping event rank (0, 1, 2, ...) to
      (neuron_id, rounded_spike_time). Events are ordered by rounded spike
      time; ties are broken by neuron id and then by the spike's index
      within its own train.
    """
    # Flatten every train into (neuron_id, rounded_time, index_in_train).
    events = []
    for neuron_id, spikes in enumerate(spike_data.train):
        for spike_index, spike_time in enumerate(spikes):
            events.append((neuron_id, round(spike_time, precision), spike_index))

    # Order by rounded time first, then neuron id, then index within the train.
    events.sort(key=lambda event: (event[1], event[0], event[2]))

    # The position in the sorted list is the event's rank.
    ranks = {}
    for rank, (neuron_id, rounded_time, _) in enumerate(events):
        ranks[rank] = (neuron_id, rounded_time)

    print(f"Total unique events: {len(ranks)}")

    return ranks

def precompute_close_neurons(distances, window_size=17.5):
    """
    List, for every neuron, the other neurons closer than a threshold.

    Parameters:
    - distances: Square array indexed as distances[i, j].
    - window_size: Exclusive upper bound; neuron j counts as close to
      neuron i only when distances[i, j] < window_size.

    Returns:
    - A dictionary mapping each neuron index to the ascending list of
      indices of its close neighbours (the neuron itself is never listed).
    """
    neuron_count = len(distances)
    neighbours_by_neuron = {}
    for source in range(neuron_count):
        nearby = []
        for target in range(neuron_count):
            if target == source:
                continue
            if distances[source, target] < window_size:
                nearby.append(target)
        neighbours_by_neuron[source] = nearby
    return neighbours_by_neuron

In [4]:
def create_histograms_for_events(spike_data, event_ranks, spatial_range=(82, 1092), time_window_rank=30, bins=6):
    """
    Build, for every ranked event, histograms of the distances and angles
    from its neuron to the neurons of the events at neighbouring ranks.

    Parameters:
    - spike_data: Object accepted by get_neuron_positions.
    - event_ranks: A dictionary mapping event rank to (neuron_id, spike_time),
      as returned by calculate_event_ranks. Ranks must be the contiguous
      integers 0 .. len(event_ranks) - 1, because neighbours are looked up
      by rank.
    - spatial_range: (lower, upper) pair; a neighbour is counted only when
      lower < distance < upper (both bounds exclusive).
    - time_window_rank: Number of ranks considered before and after each
      event.
    - bins: Number of bins used for both histograms. Distance bins span
      0 to the largest pairwise distance; angle bins span 0 to 2*pi.

    Returns:
    - A dictionary mapping event rank to
      {'distance': counts, 'angle': counts}, where each `counts` is the
      array returned by np.histogram.
    """
    total_events = len(event_ranks)

    histograms_for_each_event = {}

    # Compute the pairwise matrices once and unpack both results; building
    # them separately for distances and angles doubles the work.
    distances, angles = precalculate_distances_angles(get_neuron_positions(spike_data))

    # Guard the reduction so a recording without neurons does not raise.
    max_distance = np.max(distances) if distances.size else 0.0
    distance_bins = np.linspace(0, max_distance, bins + 1)
    angle_bins = np.linspace(0, 2 * np.pi, bins + 1)

    # Neuron id of the event at each rank, so a whole rank window can be
    # looked up with a single array index instead of a Python loop.
    neuron_by_rank = np.fromiter(
        (event_ranks[rank][0] for rank in range(total_events)),
        dtype=np.intp,
        count=total_events,
    )

    print_every_n = max(total_events // 100, 1)  # Report progress about every 1% of events, or on every event if fewer than 100

    for current_event_id in event_ranks:
        if current_event_id % print_every_n == 0:
            print(f"Processing event {current_event_id + 1}/{total_events}...")

        start_rank = max(0, current_event_id - time_window_rank)
        end_rank = min(total_events, current_event_id + time_window_rank + 1)

        # Ranks inside the window, excluding the current event itself.
        window_ranks = np.arange(start_rank, end_rank)
        window_ranks = window_ranks[window_ranks != current_event_id]
        other_neuron_ids = neuron_by_rank[window_ranks]

        # Distance and angle from the current event's neuron to each neighbour.
        current_neuron_id = event_ranks[current_event_id][0]
        window_distances = distances[current_neuron_id, other_neuron_ids]
        window_angles = angles[current_neuron_id, other_neuron_ids]

        # Keep only neighbours strictly inside the spatial range.
        in_range = (window_distances > spatial_range[0]) & (window_distances < spatial_range[1])

        distance_hist, _ = np.histogram(window_distances[in_range], bins=distance_bins)
        angle_hist, _ = np.histogram(window_angles[in_range], bins=angle_bins)

        histograms_for_each_event[current_event_id] = {
            'distance': distance_hist,
            'angle': angle_hist
        }

    print("Processing complete")
    return histograms_for_each_event

In [5]:
# Rank every spike by time, then build the per-event distance/angle
# histograms over each event's rank neighbourhood (default parameters).
event_ranks = calculate_event_ranks(spike_data=sd)
histograms = create_histograms_for_events(spike_data=sd, event_ranks=event_ranks)

Total unique events: 113477
Processing event 1/113477...
Processing event 1135/113477...
Processing event 2269/113477...


  for neuron in spike_data.neuron_data[0].values():


Processing event 3403/113477...
Processing event 4537/113477...
Processing event 5671/113477...
Processing event 6805/113477...
Processing event 7939/113477...
Processing event 9073/113477...
Processing event 10207/113477...
Processing event 11341/113477...
Processing event 12475/113477...
Processing event 13609/113477...
Processing event 14743/113477...
Processing event 15877/113477...
Processing event 17011/113477...
Processing event 18145/113477...
Processing event 19279/113477...
Processing event 20413/113477...
Processing event 21547/113477...
Processing event 22681/113477...
Processing event 23815/113477...
Processing event 24949/113477...
Processing event 26083/113477...
Processing event 27217/113477...
Processing event 28351/113477...
Processing event 29485/113477...
Processing event 30619/113477...
Processing event 31753/113477...
Processing event 32887/113477...
Processing event 34021/113477...
Processing event 35155/113477...
Processing event 36289/113477...
Processing event

In [7]:
# Sanity check: expect one histogram entry per ranked event, i.e. the same
# number as the "Total unique events" count printed by calculate_event_ranks.
len(histograms)

113477