<font size=7> Basic Analysis

This notebook contains the code used to run some basic sanity checks on a selected dataset.

<font color="red">

**To Do**
* remove the unnecessary imports
* for text summary, check that Tom's code produces same result as Aris code. You only have to check this for one file.

# Import stuff

In [None]:
# Progress message printed when this first cell runs.
print("\tLoading analysis source code...")

In [1]:
import fcm
import os
import re
import glob
import random
import numpy as np
import scipy
import scipy.io as sio
import scipy.ndimage as ndimage
from scipy.ndimage import gaussian_filter1d
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.patches as patches
from matplotlib.patches import Patch, Circle
from matplotlib.lines import Line2D
from PIL import Image
import ipywidgets as ipw
from ipywidgets import interact, interactive, fixed, interact_manual  # package for interactive widgets 
import braingeneers                                                   # Braingeneers code
from braingeneers.analysis.analysis import SpikeData, read_phy_files
import braingeneers.data.datasets_electrophysiology as ephys
from IPython.display import HTML, display, Javascript, clear_output

# Create text Summary

In [2]:
# Collect the entries of the ephys data directory, dropping the first one returned.
# NOTE(review): os.listdir() returns entries in arbitrary order, so `[1:]` removes an
# unspecified entry — confirm which entry was meant to be excluded.
data_folders = os.listdir("/home/jovyan/data/ephys")[1:]

## <font color="gray">Helper function: `analyze_spike_data`

In [None]:
def analyze_spike_data(sd):
    """Print a short text summary of one recording.

    Prints the spike count, the recording length in seconds (time of the last
    spike), the number of units, the average per-unit firing rate and the
    coefficient of variation (CV) of the interspike intervals shorter than
    100 ms.

    :param sd: spike-data object providing ``idces_times()``, ``rates()`` and
               ``interspike_intervals()``. Spike times are treated as
               milliseconds (they are divided by 1000 to obtain seconds).
    :return: None; the summary is written to stdout.

    Quantities that cannot be computed (no spikes, no units, or no interval
    under 100 ms) are reported as ``nan`` instead of raising.
    """
    idces_control, times_control = sd.idces_times()
    n_neurons_control = len(sd.rates())
    n_spikes = len(idces_control)

    # The recording length is taken to be the time of the last spike.
    duration_s = times_control[-1] / 1000 if n_spikes else 0.0

    print("Number of spikes: ", n_spikes)
    print("Length: ", int(duration_s), "seconds")
    print("Number of Neurons: ", n_neurons_control)

    if duration_s > 0 and n_neurons_control > 0:
        avg_rate_control = float(n_spikes / duration_s / n_neurons_control)
    else:
        avg_rate_control = float("nan")
    print("Average Firing Rate: ", round(avg_rate_control, 2))

    # Keep only ISIs below 100 ms; longer gaps likely come from units that are
    # not following a periodic firing pattern.
    short_isis = []
    for unit_isis in sd.interspike_intervals():
        unit_isis = np.asarray(unit_isis)
        short_isis.append(unit_isis[unit_isis < 100])
    isis = np.concatenate(short_isis) if short_isis else np.array([])

    cv = float("nan")
    if isis.size > 0:
        isi_mean = isis.mean()
        if isi_mean != 0:
            # Population standard deviation (ddof=0) divided by the mean.
            cv = float(isis.std() / isi_mean)
    print("Coefficient of Variation: ", round(cv, 3))

## Main code

In [3]:
def GetTextSummary(folder_name):
    """Print a text summary for every curated recording of one dataset folder.

    Looks in ``/home/jovyan/data/ephys/<folder_name>/derived/kilosort2/`` for
    files ending in ``_curated.zip``, loads each with ``read_phy_files`` and
    prints the output of ``analyze_spike_data`` for it. Files that cannot be
    loaded are reported with a warning and skipped.

    :param folder_name: name of the dataset folder under the ephys data directory.
    :return: None; everything is written to stdout.
    """
    path = f"/home/jovyan/data/ephys/{folder_name}/derived/kilosort2/"
    file_extension = "_curated.zip"
    spike_data_objects = {}  # filename -> loaded spike data object

    def trailing_number(name):
        # Last run of digits in the file name, used only for ordering the output.
        # Using the whole number (not just its final digit) keeps e.g. "_2_" and
        # "_12_" recordings distinct and correctly ordered.
        numbers = re.findall(r"\d+", name)
        return int(numbers[-1]) if numbers else -1

    for filename in os.listdir(path):
        if not filename.endswith(file_extension):
            continue
        file_path = os.path.join(path, filename)
        try:
            sd = read_phy_files(file_path)
        except Exception as err:
            # Best effort: report why the file failed and carry on with the rest.
            print(f"WARNING: Unable to Read < {filename} > ({err})")
            continue
        sd.original_file = filename
        # Key by the full filename so two recordings can never overwrite each other.
        spike_data_objects[filename] = sd

    print("-----------------------------")
    ordered_names = sorted(spike_data_objects, key=lambda name: (trailing_number(name), name))
    for filename in ordered_names:
        sd_object = spike_data_objects[filename]
        print(f"Filename: {sd_object.original_file}:")
        analyze_spike_data(sd_object)
        print("-----------------------------")

# Create a Plot

## <font color="gray"> helper code

Get firing rates of individual neurons:

In [None]:
def calculate_mean_firing_rates(spike_data):
    """Return the mean firing rate of every unit as a NumPy array.

    Each rate is the unit's spike count divided by ``spike_data.length / 1000``,
    i.e. spikes per second assuming ``length`` is expressed in milliseconds.

    :param spike_data: object exposing ``train`` (one spike-time sequence per
                       unit) and ``length``.
    :return: 1-D ``np.ndarray`` with one rate per entry of ``spike_data.train``.
    """
    return np.array([
        len(unit_train) / (spike_data.length / 1000)  # ms -> s
        for unit_train in spike_data.train
    ])

Interspike Interval of spikedata

In [None]:
def ISI(sd, max_isi=100):
    """Pool the short interspike intervals of every unit into one flat list.

    :param sd: object whose ``interspike_intervals()`` returns one array of
               intervals per unit.
    :param max_isi: exclusive upper bound; only intervals strictly below it are
                    kept. Defaults to 100 (ms), since longer gaps likely come
                    from units not following a periodic firing pattern.
    :return: list of plain Python numbers, unit after unit in original order.
    """
    isis = []
    for unit_isis in sd.interspike_intervals():
        unit_isis = np.asarray(unit_isis)
        # extend() keeps the pooling linear; `isis = isis + ...` re-copied the
        # whole list for every unit.
        isis.extend(unit_isis[unit_isis < max_isi].tolist())
    return isis

Interspike Interval of individual Neural units

In [None]:
def IndivISI(sd, neuron):
    """Return the interspike intervals below 100 of a single unit.

    :param sd: object whose ``interspike_intervals()`` returns one sequence of
               intervals per unit.
    :param neuron: index of the unit to select from that result.
    :return: list of that unit's intervals that are strictly less than 100,
             in their original order.
    """
    unit_intervals = sd.interspike_intervals()[neuron]
    return [interval for interval in unit_intervals if interval < 100]

## Main function

Get data filenames:

In [13]:
# Build the list of curated recordings available across all dataset folders.
# Pure Python (os/glob) is used instead of `!ls` so that a folder without a
# matching file contributes nothing, rather than an `ls` error message, and so
# that folder names containing spaces or shell characters are handled safely.
ephys_root = "/home/jovyan/data/ephys"

# Hidden entries are skipped, mirroring what a plain `ls` lists.
data_folders = sorted(name for name in os.listdir(ephys_root) if not name.startswith("."))

filenames = []
for folder in data_folders:
    pattern = os.path.join(ephys_root, folder, "derived", "kilosort2", "*curated*")
    filenames += sorted(glob.glob(pattern))

Main plotting function:

In [None]:
def createPlots(filename):
    """Load a curated recording and draw its overview figures.

    Figure 1 (2x2 grid):
      * histogram of the pooled interspike intervals (``ISI``),
      * histogram of the per-unit mean firing rates,
      * scatter of unit positions with marker size scaled by firing rate,
      * spike raster of the first 30 seconds with the smoothed population
        rate on a secondary y-axis.

    Figure 2 (2x4 grid): ISI histogram of each of the first eight units. If the
    recording has fewer than eight units, the remaining axes stay empty and the
    figure title says so.

    :param filename: path to a zip of phy files, passed to ``read_phy_files``.
    :return: None; the figures are created through matplotlib.
    """
    # NOTE(review): the loaded recording is published as the module-level `sd`,
    # presumably so that later cells can inspect it — confirm this is still needed.
    global sd
    sd = read_phy_files(filename)
    firing_rates = calculate_mean_firing_rates(sd)
    seconds = 30  # seconds of raster to display

    # Position of every unit, used for the layout scatter plot
    positions = [neuron["position"] for neuron in sd.neuron_data[0].values()]
    neuron_x = [pos[0] for pos in positions]
    neuron_y = [pos[1] for pos in positions]

    # Plot main figure --------------------------------------------------------------------
    figs, plots = plt.subplots(nrows=2, ncols=2, figsize=(12, 12))
    figs.suptitle(f"Plots of recording: {filename}", ha="center")

    # ISI histogram subplot
    plots[0, 0].hist(ISI(sd), bins=50)
    plots[0, 0].set_title("Interspike Interval of Recording")
    plots[0, 0].set_xlabel("Time bin(ms)")
    plots[0, 0].set_ylabel("ISI count")

    # Firing-rate histogram subplot; rates are spikes per second, hence Hz
    plots[0, 1].hist(firing_rates)
    plots[0, 1].set_title("Average Firing Rate for Neural Units")
    plots[0, 1].set_xlabel("Firing Rate (Hz)")
    plots[0, 1].set_ylabel("Unit Count")

    # Unit layout subplot, marker area proportional to firing rate
    plots[1, 0].scatter(neuron_x, neuron_y, s=firing_rates*100, c="red", alpha=0.3)
    plots[1, 0].set_title("Neuron Firing Rate Across MEA")
    plots[1, 0].set_xlabel("um")
    plots[1, 0].set_ylabel("um")

    # Raster with the smoothed population rate overlaid
    idces, times = sd.idces_times()

    pop_rate = sd.binned(bin_size=1)  # in ms
    sigma = 5  # width of the Gaussian used to smooth the population rate
    pop_rate_smooth = gaussian_filter1d(pop_rate.astype(float), sigma=sigma)
    t = np.linspace(0, sd.length, pop_rate.shape[0])/1000

    raster_ax = plots[1, 1]
    raster_ax.scatter(times/1000, idces, marker='|', s=1)
    rate_ax = raster_ax.twinx()  # second y-axis sharing the time axis
    rate_ax.plot(t, pop_rate_smooth, c='r')

    raster_ax.set_xlim(0, seconds)
    raster_ax.set_title("Spike Raster Analysis")
    raster_ax.set_xlabel("Time(s)")
    raster_ax.set_ylabel("Unit #")
    rate_ax.set_ylabel("Firing Rate")

    # Plot second figure ------------------------------------------------------------------
    figs2, axs = plt.subplots(nrows=2, ncols=4, figsize=(30, 10))
    second_title = f"Interspike Interval of Individual Neural Units of File {filename}"
    if sd.N < 8:
        # Warn in the title when there are not enough units to fill the grid
        second_title += f"\n Note: Neuron Count Under 8 ({sd.N})"
    figs2.suptitle(second_title)

    # axs.flat walks the grid row by row: units 0-3 on the first row, 4-7 on the second
    for unit, ax in enumerate(axs.flat):
        if unit >= sd.N:
            break
        ax.hist(IndivISI(sd, unit))
        ax.set_title(f"Interspike Interval of Neural Unit {unit}")
        ax.set_xlabel("Time bin(ms)")
        ax.set_ylabel("ISI count")
    

# Deeper Analysis

## <Font color='grey'>Helper Functions

In [14]:
def correlation(sd, bin_size=1, sigma=5):
    """Pairwise correlation matrix of the Gaussian-smoothed spike raster.

    The raster returned by ``sd.raster`` is converted to float, blurred with a
    1-D Gaussian along its last axis and passed to ``np.corrcoef``, which
    treats every row as one variable.

    :param sd: object exposing ``raster(bin_size=...)``.
    :param bin_size: raster bin size handed to ``sd.raster`` (default 1, in ms).
    :param sigma: standard deviation of the Gaussian blur, in bins (default 5).
    :return: correlation matrix as returned by ``np.corrcoef``.
    """
    dense_raster = sd.raster(bin_size=bin_size)
    blurred = gaussian_filter1d(dense_raster.astype(float), sigma=sigma)
    return np.corrcoef(blurred)

## Main Function

In [26]:
def DeeperAnalysis(filename):
    """Show pairwise-unit matrices and the functional connectivity map of a recording.

    Loads the recording with ``read_phy_files``, draws the spike time tiling
    (STTC) matrix and the correlation matrix side by side, each with its own
    colorbar, and finally calls ``fcm.FCM_Plotter`` for the functional
    connectivity map over the whole recording length.

    :param filename: path to a zip of phy files.
    :return: None.
    """
    sd = read_phy_files(filename)

    # (mosaic key, matrix, title) for each side-by-side panel
    panels = (
        ("A", sd.spike_time_tilings(), "STTC"),
        ("B", correlation(sd), "Correlation"),
    )

    fig, plots = plt.subplot_mosaic("AB", figsize=(12, 10))

    for key, matrix, title in panels:
        ax = plots[key]
        image = ax.imshow(matrix)
        ax.set_title(title)
        ax.set_xlabel("unit")
        ax.set_ylabel("unit")
        fig.colorbar(image, ax=ax, shrink=0.3)

    # Functional connectivity map, from t=0 to the end of the recording (seconds)
    fcm.FCM_Plotter(filename, 0, sd.length/1000, "Functional Connectivity Map", saved="no")
    

# Load Data

`read_phy_files` is the function we use to load data. The version currently shipped with braingeneers (as of 9/12/23) raises an error, so we use an older version of the function instead.

In [None]:
import io
import zipfile
from typing import List, Tuple

def read_phy_files(path: str, fs=20000.0):
    """
    Load a zip of phy (spike sorting) output files into a SpikeData object.

    :param path: a s3 or local path to a zip of phy files.
    :param fs: sampling rate in Hz used to convert spike sample indices to
            milliseconds. It is overridden by the ``sample_rate`` entry of
            ``params.py`` inside the zip when that entry is present.
    :return: SpikeData built from one list of spike times (ms) per kept cluster, with
            neuron_data = {0: neuron_dict}
            metadata = {0: config_dict}
            neuron_attributes = [NeuronAttributes, ...] (one per kept cluster)
            neuron_dict = {index: {"cluster_id": id, "channel": c, "position": (x, y),
                            "amplitudes": [a0, a1, an], "template": [t0, t1, tn],
                            "neighbor_channels": [c0, c1, cn],
                            "neighbor_positions": [(x0, y0), (x1, y1), (xn,yn)],
                            "neighbor_templates": [[t00, t01, t0n], [tn0, tn1, tnn]}}
            config_dict = {chn: pos}
    :raises AssertionError: if ``path`` does not end in "zip" or the zip has no ``params.py``.

    When ``cluster_info.tsv`` is present, clusters whose group is "noise" are
    dropped and each unit's label is its group; otherwise every cluster is kept
    and the label is None.
    """
    assert path[-3:] == 'zip', 'Only zip files supported!'
    import braingeneers.utils.smart_open_braingeneers as smart_open
    with smart_open.open(path, 'rb') as f0:
        # Read the whole archive into memory so zipfile can seek in it
        f = io.BytesIO(f0.read())

        with zipfile.ZipFile(f, 'r') as f_zip:
            assert 'params.py' in f_zip.namelist(), "Wrong spike sorting output."
            with io.TextIOWrapper(f_zip.open('params.py'), encoding='utf-8') as params:
                for line in params:
                    if "sample_rate" in line:
                        fs = float(line.split()[-1])
            clusters = np.load(f_zip.open('spike_clusters.npy')).squeeze()
            channels = np.load(f_zip.open('channel_map.npy')).squeeze()
            # Whitened templates, (cluster_id, samples, channel_id); loaded once
            templates_w = np.load(f_zip.open('templates.npy'))
            wmi = np.load(f_zip.open('whitening_mat_inv.npy'))
            spike_templates = np.load(f_zip.open('spike_templates.npy')).squeeze()
            spike_times = np.load(f_zip.open('spike_times.npy')).squeeze() / fs * 1e3  # in ms
            positions = np.load(f_zip.open('channel_positions.npy'))
            amplitudes = np.load(f_zip.open("amplitudes.npy")).squeeze()
            # Stays None when the archive carries no curation table, so the label
            # lookup below can fall back instead of hitting an undefined name.
            cluster_info = None
            if 'cluster_info.tsv' in f_zip.namelist():
                cluster_info = pd.read_csv(f_zip.open('cluster_info.tsv'), sep='\t')
                cluster_id = np.array(cluster_info['cluster_id'])
                # select clusters using curation label, remove units labeled as "noise"
                labeled_clusters = cluster_id[cluster_info['group'] != "noise"]
            else:
                labeled_clusters = np.unique(clusters)

    # Group spike times and amplitudes per cluster, then keep only the selected clusters
    df = pd.DataFrame({"clusters": clusters, "spikeTimes": spike_times, "amplitudes": amplitudes})
    cluster_agg = df.groupby("clusters").agg({"spikeTimes": lambda x: list(x),
                                              "amplitudes": lambda x: list(x)})
    cluster_agg = cluster_agg[cluster_agg.index.isin(labeled_clusters)]

    # Map each cluster id to a template index (the last spike of a cluster wins)
    cls_temp = dict(zip(clusters, spike_templates))
    neuron_dict = dict.fromkeys(np.arange(len(labeled_clusters)), None)

    # un-whiten the templates before finding the best channel
    templates = np.dot(templates_w, wmi)

    neuron_attributes = []
    for i in range(len(labeled_clusters)):
        c = labeled_clusters[i]
        temp = templates[cls_temp[c]]
        # Peak-to-peak amplitude of the template on every channel
        amp = np.max(temp, axis=0) - np.min(temp, axis=0)
        sorted_idx = [ind for _, ind in sorted(zip(amp, np.arange(len(amp))))]
        # The 12 channels with the largest amplitude, best channel first
        nbgh_chan_idx = sorted_idx[::-1][:12]
        nbgh_temps = temp.transpose()[nbgh_chan_idx]
        best_chan_temp = nbgh_temps[0]
        nbgh_channels = channels[nbgh_chan_idx]
        nbgh_postions = [tuple(positions[idx]) for idx in nbgh_chan_idx]
        best_channel = nbgh_channels[0]
        best_position = nbgh_postions[0]
        cls_amp = cluster_agg["amplitudes"][c]
        neuron_dict[i] = {"cluster_id": c, "channel": best_channel, "position": best_position,
                          "amplitudes": cls_amp, "template": best_chan_temp,
                          "neighbor_channels": nbgh_channels, "neighbor_positions": nbgh_postions,
                          "neighbor_templates": nbgh_temps}
        if cluster_info is not None:
            label = cluster_info['group'][cluster_info['cluster_id'] == c].values[0]
        else:
            # No curation table in the archive, so there is no group label to report
            label = None
        neuron_attributes.append(
            NeuronAttributes(
                cluster_id=c,
                channel=best_channel,
                position=best_position,
                amplitudes=cls_amp,
                template=best_chan_temp,
                templates=temp.T,
                label=label,
                neighbor_channels=nbgh_channels,
                neighbor_positions=nbgh_postions,
                neighbor_templates=[temp.T[n] for n in nbgh_chan_idx]
            )
        )

    config_dict = dict(zip(channels, positions))
    neuron_data = {0: neuron_dict}
    metadata = {0: config_dict}
    spikedata = SpikeData(list(cluster_agg["spikeTimes"]), neuron_data=neuron_data, metadata=metadata, neuron_attributes=neuron_attributes)
    return spikedata

class NeuronAttributes:
    """Attribute container describing one sorted unit.

    The ten fields annotated below are mandatory keyword arguments; any other
    keyword argument is stored as an extra attribute. Positional arguments are
    accepted but ignored.
    """
    cluster_id: int
    channel: np.ndarray
    position: Tuple[float, float]
    amplitudes: List[float]
    template: np.ndarray
    templates: np.ndarray
    label: str

    # These lists are the same length and correspond to each other
    neighbor_channels: np.ndarray
    neighbor_positions: List[Tuple[float, float]]
    neighbor_templates: List[np.ndarray]

    def __init__(self, *args, **kwargs):
        """Store the mandatory fields, then every remaining keyword argument.

        :raises KeyError: for the first mandatory field (in the order below)
                          that is missing from ``kwargs``.
        """
        mandatory_fields = (
            "cluster_id",
            "channel",
            "position",
            "amplitudes",
            "template",
            "templates",
            "label",
            "neighbor_channels",
            "neighbor_positions",
            "neighbor_templates",
        )
        for field in mandatory_fields:
            setattr(self, field, kwargs.pop(field))
        # Whatever is left over becomes an additional attribute
        for extra_name, extra_value in kwargs.items():
            setattr(self, extra_name, extra_value)

    def add_attribute(self, key, value):
        """Set attribute ``key`` to ``value`` on this instance."""
        setattr(self, key, value)

    def list_attributes(self):
        """Return the names of all non-dunder, non-callable attributes, as ordered by ``dir``."""
        names = []
        for attr in dir(self):
            if attr.startswith('__'):
                continue
            if callable(getattr(self, attr)):
                continue
            names.append(attr)
        return names


## <font color='brown'>test main function

In [28]:
#file = '/home/jovyan/data/ephys/2023-05-09-e-hc52_18763/derived/kilosort2/hc52_18763_rec05092023_12_curated.zip'
#sd = read_phy_files(file)

In [29]:
#fcm.FCM_Plotter(file, 0, sd.length/1000, "plot", saved="no")

In [30]:
#interact_manual( DeeperAnalysis, filename=filenames)

In [None]:
# Final progress message, printed when this last cell runs.
print("\tDone!")