# Imports

In [1]:
import numpy as np                     
import pandas as pd
import random as rand
from scipy.signal import correlate, correlation_lags
from scipy.ndimage import gaussian_filter1d

import matplotlib.pyplot as plt       
from matplotlib.patches import FancyArrow, Patch, Circle
import matplotlib.colors as mcolors
from matplotlib.colors import Normalize
import matplotlib.patches as patches
from matplotlib.lines import Line2D

import braingeneers                  
from braingeneers.analysis.analysis import SpikeData, read_phy_files, load_spike_data, burst_detection, randomize_raster
import braingeneers.data.datasets_electrophysiology as ephys
from multiprocessing import Pool
from tqdm import tqdm

# Data

In [2]:
sd = read_phy_files('/workspaces/human_hippocampus/data/ephys/2023-04-02-e-hc328_unperturbed/derived/kilosort2/hc3.28_hckcr1_chip16835_plated34.2_rec4.2_curated.zip')

  sd = read_phy_files('/workspaces/human_hippocampus/data/ephys/2023-04-02-e-hc328_unperturbed/derived/kilosort2/hc3.28_hckcr1_chip16835_plated34.2_rec4.2_curated.zip')


# Code

In [9]:
def calculate_mean_firing_rates(spike_data):
    # Compute mean firing rates for each neuron
    firing_rates = [len(train) / spike_data.length for train in spike_data.train]
    return firing_rates

def get_neuron_positions(spike_data):
    # Extract neuron positions from spike_data
    neuron_x = []
    neuron_y = []
    for neuron in spike_data.neuron_data[0].values():
        neuron_x.append(neuron['position'][0])
        neuron_y.append(neuron['position'][1])
    neuron_positions = np.array([neuron_x, neuron_y]).T
    return neuron_positions

def precalculate_distances_angles(neuron_positions):
    # Vectorized calculation of distances
    diff = neuron_positions[:, np.newaxis, :] - neuron_positions[np.newaxis, :, :]
    distances = np.sqrt(np.sum(diff**2, axis=2))
    
    # Vectorized calculation of angles
    angles = np.arctan2(diff[..., 1], diff[..., 0]) % (2 * np.pi)
    
    return distances, angles

# def calculate_event_ranks(spike_data, precision=5):
#     # Flatten the list of spikes, rounding spike times, and sort by time
#     all_spikes = [(neuron_id, round(spike_time, precision)) for neuron_id, spikes in enumerate(spike_data.train) for spike_time in spikes]
#     all_spikes_sorted = sorted(all_spikes, key=lambda x: x[1])
#     print(len(all_spikes_sorted))

#     # Assign ranks based on sorted order
#     ranks = {spike: rank for rank, spike in enumerate(all_spikes_sorted)}
#     return ranks

# def find_neuron_id_by_rank(event_ranks, target_rank):
#     for (neuron_id, spike_time), rank in event_ranks.items():
#         if rank == target_rank:
#             return neuron_id, spike_time  # Return both neuron_id and spike_time for completeness
#     return None, None

def create_reverse_rank_lookup(event_ranks):
    """
    Create a reverse lookup table for event ranks.
    
    Parameters:
    - event_ranks: A dictionary mapping (neuron_id, spike_time) to event rank.
    
    Returns:
    - A dictionary mapping event rank to (neuron_id, spike_time).
    """
    reverse_lookup = {rank: (neuron_id, spike_time) for (neuron_id, spike_time), rank in event_ranks.items()}
    return reverse_lookup 

def calculate_event_ranks(spike_data, precision=5):
    # Flatten the list of spikes, rounding spike times, and sort by time
    # Include the original index of each spike for uniqueness
    all_spikes = [(neuron_id, round(spike_time, precision), idx) 
                  for neuron_id, spikes in enumerate(spike_data.train) 
                  for idx, spike_time in enumerate(spikes)]
                  
    # Sort by neuron_id, then rounded spike time, then original index
    all_spikes_sorted = sorted(all_spikes, key=lambda x: (x[1], x[0], x[2]))
    
    # Generate a dictionary with event rank as key, (neuron_id, spike_time) as value
    ranks = {rank: (neuron_id, spike_time) for rank, (neuron_id, spike_time, _) in enumerate(all_spikes_sorted)}
    
    print(f"Total unique events: {len(ranks)}")
    
    return ranks

# def calculate_event_ranks2(spike_data, precision=3):
#     all_spikes = [(neuron_id, round(spike_time, precision)) for neuron_id, spikes in enumerate(spike_data.train) for spike_time in spikes]
#     all_spikes_sorted = sorted(all_spikes, key=lambda x: x[1])

#     # Use a list of tuples to maintain original order for duplicates
#     ranks = [(spike, rank) for rank, spike in enumerate(all_spikes_sorted)]
    
#     # Optionally convert to a dictionary if you need to map each spike to its rank uniquely
#     # This step would assign the last occurrence of a duplicate spike its rank
#     ranks_dict = {spike: rank for spike, rank in ranks}

#     return ranks  # or ranks_dict if you prefer the dictionary format

# def calculate_event_ranks3(spike_data, precision=3):
#     all_spikes = [((neuron_id, round(spike_time, precision)), idx) for idx, (neuron_id, spikes) in enumerate(spike_data.train) for spike_time in spikes]
#     all_spikes_sorted = sorted(all_spikes, key=lambda x: (x[0][1], x[1]))  # Sort by rounded spike time, then by original index

#     ranks = {spike: rank for rank, (spike, _) in enumerate(all_spikes_sorted)}
#     return ranks

def precompute_close_neurons(distances, window_size=17.5):
    close_neurons = {}
    for i in range(len(distances)):
        close_neurons[i] = [j for j in range(len(distances)) if i != j and distances[i, j] < window_size]
    return close_neurons

In [36]:
def create_histograms_for_events(spike_data, event_ranks, spatial_range=(82, 1092), time_window_rank=30, bins=6):
    total_events = len(event_ranks)

    histograms_for_each_event = {}

    distances = precalculate_distances_angles(get_neuron_positions(spike_data))[0]
    angles = precalculate_distances_angles(get_neuron_positions(spike_data))[1]

    distance_bins = np.linspace(0, np.max(distances), bins+1)
    angle_bins = np.linspace(0, 2*np.pi, bins+1)

    print_every_n = max(total_events // 100, 1)  # Update progress every 10% or at least once
    
    # Iterate through each event
    for current_event_id in range(total_events):
        
        if current_event_id % print_every_n == 0:
            print(f"Processing event {current_event_id + 1}/{total_events}...")
        # count = 60
        # Determine the rank window of interest
        start_rank = max(0, current_event_id - time_window_rank)
        end_rank = min(total_events, current_event_id + time_window_rank + 1)

        event_distances, event_angles = [], []

        # Only consider events within the rank window
        for other_event_id in range(start_rank, end_rank):
            if other_event_id == current_event_id:
                continue

            current_neuron_id = event_ranks[current_event_id][0][0]
            other_neuron_id = event_ranks[other_event_id][0][0]

            print("Current neuron id and other neuron id")
            print(current_neuron_id, other_neuron_id)

            # Calculate distance and angle between the two neurons
            distance = distances[current_neuron_id, other_neuron_id]
            angle = angles[current_neuron_id, other_neuron_id]

            print(distance)
            if spatial_range[0] < distance < spatial_range[1]:
                event_distances.append(distance)
                event_angles.append(angle)
            
        # Create histograms for the event
        distance_hist, _ = np.histogram(event_distances, bins=distance_bins)
        angle_hist, _ = np.histogram(event_angles, bins=angle_bins)

        histograms_for_each_event[current_event_id] = {
            'distance': distance_hist,
            'angle': angle_hist
        }

    print("Processing complete")
    return histograms_for_each_event



In [37]:
event_ranks = calculate_event_ranks2(sd, precision=3)
neuron_positions = get_neuron_positions(sd)
create_histograms_for_events(sd, event_ranks, spatial_range=(82, 1092), time_window_rank=30, bins=6)

1
Processing event 1/113477...
Current neuron id and other neuron id
(124, 15.2) (89, 16.4)


  for neuron in spike_data.neuron_data[0].values():


IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices

In [4]:
event_ranks = calculate_event_ranks(sd, precision=3)
reverse_lookup = create_reverse_rank_lookup(event_ranks)

113477


In [6]:
reverse_lookup

{0: (124, 15.2),
 1: (89, 16.4),
 2: (1, 19.2),
 3: (134, 19.25),
 4: (66, 21.15),
 5: (68, 23.45),
 6: (41, 25.2),
 7: (3, 37.4),
 8: (29, 42.85),
 9: (34, 43.05),
 10: (72, 51.75),
 11: (10, 61.5),
 12: (116, 62.5),
 13: (106, 64.55),
 14: (136, 67.65),
 15: (84, 69.6),
 16: (67, 69.75),
 17: (6, 71.55),
 18: (134, 76.55),
 19: (51, 80.75),
 20: (11, 81.4),
 21: (114, 85.85),
 22: (52, 92.0),
 23: (112, 95.7),
 24: (118, 95.8),
 25: (103, 98.3),
 26: (50, 114.8),
 27: (131, 125.9),
 28: (2, 126.4),
 29: (93, 128.5),
 30: (21, 144.6),
 31: (45, 144.65),
 32: (13, 144.9),
 33: (36, 150.05),
 34: (97, 151.1),
 35: (15, 151.3),
 36: (14, 157.7),
 37: (82, 161.35),
 38: (29, 173.1),
 39: (51, 179.65),
 40: (11, 182.45),
 41: (6, 184.15),
 42: (131, 192.15),
 43: (48, 197.1),
 44: (12, 200.3),
 45: (136, 200.35),
 46: (136, 203.1),
 47: (135, 208.2),
 48: (28, 219.3),
 49: (19, 222.8),
 50: (134, 238.2),
 51: (117, 248.8),
 52: (5, 249.45),
 53: (21, 254.6),
 54: (45, 254.7),
 55: (13, 254

In [7]:
def calculate_event_ranks(spike_data, precision=5):
    # Flatten the list of spikes, rounding spike times, and sort by time
    # Include the original index of each spike for uniqueness
    all_spikes = [(neuron_id, round(spike_time, precision), idx) 
                  for neuron_id, spikes in enumerate(spike_data.train) 
                  for idx, spike_time in enumerate(spikes)]
                  
    # Sort by neuron_id, then rounded spike time, then original index
    all_spikes_sorted = sorted(all_spikes, key=lambda x: (x[1], x[0], x[2]))
    
    # Generate a dictionary with event rank as key, (neuron_id, spike_time) as value
    ranks = {rank: (neuron_id, spike_time) for rank, (neuron_id, spike_time, _) in enumerate(all_spikes_sorted)}
    
    print(f"Total unique events: {len(ranks)}")
    
    return ranks

In [11]:
event_ranks = calculate_event_ranks(sd, precision=3)

Total unique events: 113477


In [12]:
len(event_ranks)

113477

In [38]:
event_ranks = calculate_event_ranks2(sd, precision=3)

In [41]:
event_ranks[1]

((89, 16.4), 1)