## Imports

In [177]:
# Standard library imports
import os
import random

# Data analysis packages
import numpy as np
import pandas as pd
from scipy.signal import correlate, correlation_lags
from scipy.ndimage import gaussian_filter1d

# Visualization packages
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import FancyArrow, Patch, Circle
import matplotlib.colors as mcolors
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.lines import Line2D

# Image processing packages
import imageio
from PIL import Image, ImageSequence

# Braingeneers packages for analysis
import braingeneers
import braingeneers.data.datasets_electrophysiology as ephys
from braingeneers.analysis.analysis import SpikeData, read_phy_files, load_spike_data, burst_detection, randomize_raster

In [116]:
def generate_random_overlap_spike_data(total_duration_ms, num_propagations, rows, cols, overlap_duration_ms, break_duration_ms, seed=None):
    """Generate synthetic spike trains in which activity sweeps across a grid column by column.

    The recording is split into ``num_propagations`` equally long segments
    separated by pauses of ``break_duration_ms``. Inside a segment each grid
    column gets its own firing window: consecutive windows start
    ``segment / cols - overlap_duration_ms`` apart and each lasts
    ``segment / cols + overlap_duration_ms``, so neighbouring columns overlap
    in time. In every window a random, non-empty subset of the column's
    neurons emits an evenly spaced burst. Each neuron is additionally forced
    to take part in one randomly chosen propagation.

    Channel ``i`` is placed at position ``(i // cols, i % cols)``; a "column"
    is the set of channels sharing the same second coordinate.

    Args:
        total_duration_ms: Recording length passed to ``SpikeData`` as ``length``.
        num_propagations: Number of propagation sweeps to generate (>= 1).
        rows: Number of grid rows (>= 1).
        cols: Number of grid columns (>= 1).
        overlap_duration_ms: Temporal overlap between adjacent column windows.
        break_duration_ms: Pause inserted between consecutive propagations.
        seed: Optional seed for a private random generator. When ``None``
            (default) the module-level ``random`` state is used, exactly as
            before, so existing callers are unaffected.

    Returns:
        SpikeData: object built from the generated trains, with
        ``N = rows * cols``, ``length = total_duration_ms`` and a
        ``neuron_data['positions']`` mapping of channel index to position.

    Raises:
        ValueError: if ``rows``, ``cols`` or ``num_propagations`` is below 1,
            or if the breaks consume the whole recording.
    """
    if rows < 1 or cols < 1:
        raise ValueError("rows and cols must both be at least 1")
    if num_propagations < 1:
        raise ValueError("num_propagations must be at least 1")

    N_channels = rows * cols
    adjusted_total_duration_ms = total_duration_ms - (break_duration_ms * (num_propagations - 1))
    if adjusted_total_duration_ms <= 0:
        raise ValueError(
            "break_duration_ms * (num_propagations - 1) must be smaller than total_duration_ms"
        )
    segment_duration_ms = adjusted_total_duration_ms / num_propagations

    # Loop invariants: nominal window length of one column and burst size
    # (one spike per 10 ms of nominal window, truncated towards zero).
    # NOTE(review): for windows shorter than 10 ms the burst size is 0 and
    # selected neurons receive no spikes - confirm whether that is intended.
    column_window_ms = segment_duration_ms / cols
    spikes_per_burst = int(column_window_ms / 10)

    # Fall back to the shared module-level generator when no seed is given so
    # that callers relying on a prior random.seed(...) keep working.
    rng = random if seed is None else random.Random(seed)

    # Initialize positions and neuron_data
    positions = [(x, y) for x in range(rows) for y in range(cols)]
    neuron_data = {'positions': {i: {'position': positions[i]} for i in range(N_channels)}}

    # Initialize train with empty lists for each channel
    train_patterned_overlap = [[] for _ in range(N_channels)]

    # Pre-assign one guaranteed propagation to every neuron
    preassigned_firings = {neuron_index: rng.randint(0, num_propagations - 1) for neuron_index in range(N_channels)}

    # Generate spike times with specified overlap, breaks, and random firing within columns
    for propagation in range(num_propagations):
        segment_start = propagation * (segment_duration_ms + break_duration_ms)

        for col in range(cols):
            column_members = range(col, N_channels, cols)
            spike_time_start = segment_start + (col * (column_window_ms - overlap_duration_ms))
            spike_time_end = spike_time_start + column_window_ms + overlap_duration_ms

            # Randomly decide how many, then which, neurons in the column fire
            num_firing = rng.randint(1, rows)
            firing_neurons_in_col = rng.sample(column_members, num_firing)

            # Ensure pre-assigned neurons for this propagation are included
            for neuron_index in column_members:
                if preassigned_firings[neuron_index] == propagation and neuron_index not in firing_neurons_in_col:
                    firing_neurons_in_col.append(neuron_index)

            # All firing neurons of a column share the same evenly spaced burst
            spike_times = np.linspace(spike_time_start, spike_time_end, spikes_per_burst)
            for neuron_index in firing_neurons_in_col:
                train_patterned_overlap[neuron_index].extend(spike_times)

    # Convert lists to numpy arrays
    train_patterned_overlap = [np.array(times) for times in train_patterned_overlap]

    # Create the SpikeData object
    spike_data_patterned_random_firing = SpikeData(train=train_patterned_overlap, N=N_channels, length=total_duration_ms, neuron_data=neuron_data)

    return spike_data_patterned_random_firing

# Function to get neuron positions
def get_neuron_positions(spike_data):
    """Return the stored neuron coordinates as an array with one row per neuron.

    Reads every entry of ``spike_data.neuron_data['positions']`` and takes
    elements 0 and 1 of its ``'position'`` value. Column 0 of the result holds
    element 0, column 1 holds element 1; rows follow the dict's iteration order.
    """
    entries = list(spike_data.neuron_data['positions'].values())
    first_coords = [entry['position'][0] for entry in entries]
    second_coords = [entry['position'][1] for entry in entries]
    # Stack as two rows, then transpose to get (n_neurons, 2)
    return np.array([first_coords, second_coords]).T

def calculate_mean_firing_rates(spike_data):
    """Compute one mean firing rate per spike train in ``spike_data``.

    Each rate is the number of spikes in a train divided by
    ``spike_data.length / 1000``. The division by 1000 assumes ``length`` is
    expressed in milliseconds, which makes the result spikes per second.

    Returns:
        numpy.ndarray: rates in the same order as ``spike_data.train``.
    """
    return np.array([
        len(unit_spikes) / (spike_data.length / 1000)
        for unit_spikes in spike_data.train
    ])

def plot_raster(sd):
    """Display a raster plot with one vertical tick per spike.

    Unit indices and spike times come from ``sd.idces_times()``. Times are
    divided by 1000 before plotting (assumed to be milliseconds, giving a
    seconds axis). The figure is shown immediately via ``plt.show()``.
    """
    unit_indices, spike_times = sd.idces_times()

    # figsize is (width, height) in inches
    _, axis = plt.subplots(figsize=(10, 6))
    axis.scatter(spike_times / 1000, unit_indices, marker='|', s=1)
    axis.set(xlabel="Time(s)", ylabel='Unit #')

    plt.show()

In [180]:
# def firing_plotter(sd, global_min_rate, global_max_rate):
#     # Custom colormap from pale red/pink to dark red/black
#     colors = ["#ffcccb", "#ff6961", "#ff5c5c", "#ff1c00", "#bf0000", "#800000", "#400000", "#000000"]
#     cmap_name = "custom_red_black"
#     n_bins = 100  # More bins will give us a finer gradient
#     cm = LinearSegmentedColormap.from_list(cmap_name, colors, N=n_bins)

#     firing_rates = calculate_mean_firing_rates(sd)
#     neuron_x, neuron_y = [], []
#     # Extract neuron positions
#     for i, neuron in sd.neuron_data['positions'].items():
#         neuron_x.append(neuron['position'][0])
#         neuron_y.append(neuron['position'][1])
    
#     plt.figure(figsize=(8, 6))
#     # Plotting a small pale blue square at each neuron position
#     plt.scatter(neuron_x, neuron_y, s=10, c='#add8e6', marker='s', alpha=0.6)  # s is the size, c is the color
#     scatter = plt.scatter(neuron_x, neuron_y, s=firing_rates*1.5, c=firing_rates, alpha=0.6, cmap=cm, vmin=global_min_rate, vmax=75)

#     plt.title("Neuron Firing Rates")
#     plt.xlabel("X Position")
#     plt.ylabel("Y Position")
#     plt.colorbar(scatter, label='Firing Rate (Hz)')
#     plt.gca().invert_yaxis()
#     # plt.show()

def firing_plotter(sd, global_min_rate, global_max_rate):
    """Draw a spatial firing-rate map of ``sd`` on a new pyplot figure.

    Every neuron position is marked with a small pale-blue square and overlaid
    with a circle whose area and colour both scale with the neuron's mean
    firing rate. The figure is left open as the current pyplot figure (it is
    neither shown nor saved here) so the caller can save or display it.

    Args:
        sd: object exposing ``neuron_data['positions']`` and whatever
            ``calculate_mean_firing_rates`` reads.
        global_min_rate: lower bound of the colour scale (``vmin``).
        global_max_rate: currently unused.
            NOTE(review): the colour scale's upper bound is fixed at 75
            regardless of this argument - confirm whether that is intended.
    """
    # Gradient from pale pink through red to black, discretised into 100 levels
    palette = ["#ffcccb", "#ff6961", "#ff5c5c", "#ff1c00", "#bf0000", "#800000", "#400000", "#000000"]
    rate_cmap = LinearSegmentedColormap.from_list("custom_red_black", palette, N=100)

    rates = calculate_mean_firing_rates(sd)

    # Split stored positions into the two plotted coordinates
    layout = sd.neuron_data['positions']
    xs = [entry['position'][0] for entry in layout.values()]
    ys = [entry['position'][1] for entry in layout.values()]

    plt.figure(figsize=(8, 6))
    # Background marker so that silent neurons remain visible
    plt.scatter(xs, ys, s=10, c='#add8e6', marker='s', alpha=0.6)
    rate_markers = plt.scatter(xs, ys, s=rates*1.5, c=rates, alpha=0.6, cmap=rate_cmap, vmin=global_min_rate, vmax=75)

    # Empty scatters exist only to provide marker-size entries for the legend
    for example_rate in (20, 40, 60):
        plt.scatter([], [], s=example_rate*1.5, c='gray', alpha=0.6, label=f'{example_rate} Hz')

    plt.title("Animation of Firing Rates for Synthetic Wave-Propagation Data")
    plt.xlabel("X Position")
    plt.ylabel("Y Position")
    plt.colorbar(rate_markers, label='Firing Rate (Hz)')
    plt.gca().invert_yaxis()

    # Size legend
    plt.legend(scatterpoints=1, frameon=False, labelspacing=1, title='Firing Rate', loc='upper right')

In [193]:
# random_overlap_spike_data = generate_random_overlap_spike_data(
#     total_duration_ms=60 * 1000,  # 1 minute
#     num_propagations=5,
#     rows=16,
#     cols=16,
#     overlap_duration_ms=160,
#     break_duration_ms=6000  # 6 seconds
# )


# Active configuration: 10 propagations across a 16 x 16 grid in a 1-minute recording.
random_overlap_spike_data = generate_random_overlap_spike_data(
    total_duration_ms=60 * 1000,  # 1 minute
    num_propagations=10,
    rows=16,
    cols=16,
    overlap_duration_ms=30,
    break_duration_ms=5000  # 5 seconds
)


# random_overlap_spike_data = generate_random_overlap_spike_data(
#     total_duration_ms=240 * 1000,  # 4 minutes
#     num_propagations=5,
#     rows=32,
#     cols=32,
#     overlap_duration_ms=160,
#     break_duration_ms=6000  # 6 seconds
# )

In [191]:
def create_animation(sd, frames_per_second, output_dir='firing_maps', animation_file='firing_animation.gif',
                     base_path="/workspaces/human_hippocampus/dev/other/thomas/thesis/"):
    """Render ``sd`` as an animated GIF of firing-rate maps.

    The recording is cut into consecutive windows of ``1 / frames_per_second``
    seconds. Each window is drawn with ``firing_plotter``, saved as
    ``frame_<n>.png`` inside ``<base_path>/<output_dir>`` and finally all
    frames are combined into ``<base_path>/<animation_file>``, played back at
    ``frames_per_second`` and looping forever.

    Args:
        sd: recording to animate; must provide ``length`` (treated as
            milliseconds) and ``subtime(start_ms, end_ms)``.
        frames_per_second: number of frames per second of recording; also the
            playback rate of the GIF. Must be positive.
        output_dir: sub-directory of ``base_path`` receiving the PNG frames.
        animation_file: file name of the GIF, relative to ``base_path``.
        base_path: root directory for all output. Defaults to the location
            that was previously hard-coded.

    Raises:
        ValueError: if ``frames_per_second`` is not positive or the recording
            is too short to yield a single frame (previously a bare
            ``IndexError`` surfaced after the loop).
    """
    if frames_per_second <= 0:
        raise ValueError("frames_per_second must be positive")

    full_output_dir = os.path.join(base_path, output_dir)
    os.makedirs(full_output_dir, exist_ok=True)

    total_recording_time_s = sd.length / 1000.0  # Convert ms to seconds
    total_frames = int(frames_per_second * total_recording_time_s)
    if total_frames < 1:
        raise ValueError("recording is too short to produce a frame at this frame rate")

    # Colour-scale bounds shared by all frames, computed once.
    # NOTE(review): these are rates over the whole recording, while each frame
    # plots rates of a short window - confirm this is the intended reference.
    overall_rates = calculate_mean_firing_rates(sd)
    g_min_rate = np.min(overall_rates)
    g_max_rate = np.max(overall_rates)

    frame_duration_ms = 1000 / frames_per_second
    animation_path = os.path.join(base_path, animation_file)

    images = []
    try:
        for frame_number in range(total_frames):
            start_time_ms = (frame_number / frames_per_second) * 1000
            end_time_ms = ((frame_number + 1) / frames_per_second) * 1000
            sd_timestep = sd.subtime(start_time_ms, end_time_ms)

            firing_plotter(sd_timestep, global_min_rate=g_min_rate, global_max_rate=g_max_rate)

            filename = os.path.join(full_output_dir, f"frame_{frame_number}.png")
            try:
                plt.savefig(filename)
            finally:
                # Always release the figure, even if saving failed
                plt.close()

            images.append(Image.open(filename))

        # Save frames to a GIF
        images[0].save(animation_path, save_all=True, append_images=images[1:], optimize=False, duration=int(frame_duration_ms), loop=0)
    finally:
        # Image.open keeps the underlying file open; release every handle
        for image in images:
            image.close()

    print(f"Animation saved to {animation_path}")

In [194]:
# Render the synthetic recording at 4 frames per second using the default output directory and GIF name.
create_animation(random_overlap_spike_data, 4)

  for i, neuron in sd.neuron_data['positions'].items():
