In [None]:
from load_data import load_h5_to_dict
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import h5py

# Make the project-local sta_analysis modules importable. The relative path is
# resolved against the current working directory (where the notebook runs).
module_path = os.path.abspath(
    os.path.join('..', 'Git/RFAnalysis/sta_analysis/sta_analysis')
)
if module_path not in sys.path:
    # appended (not prepended), so already-importable modules keep priority
    sys.path.append(module_path)

from receptive_field_filter import *
from receptive_field_analysis import *
from util_plotting import *
from correlations import *

config_file = "config.toml"
# `toml` was used here without ever being imported. Prefer the stdlib TOML parser
# (Python >= 3.11; it requires a binary file handle) and fall back to the
# third-party `toml` package on older interpreters.
try:
    import tomllib
except ModuleNotFoundError:
    tomllib = None

if tomllib is not None:
    with open(config_file, "rb") as f:
        config = tomllib.load(f)
else:
    import toml
    with open(config_file, "r") as f:
        config = toml.load(f)

# Window settings for the spike-triggered calculation (currently unused below).
start = config["st_calculation"]["start"]
end = config["st_calculation"]["end"]

h5_dir = 'data/h5'
led_dir = 'data/led'
# h5_file = '2025-04-11T13-17-58_RecID-762_I24607_KG_bars025-83p3ms_15moff-15mon-15moff_5p7V_DIV25_HS259_spikesonly.h5'
# led_file = 'detailed_sync_2025-04-11T13-17-58_RecID-762_I24607_KG_bars025-83p3ms_15moff-15mon-15moff_5p7V_DIV25_HS259_only-de_2025-06-11_11-50-00.npz'

# led_file = 'detailed_sync_2024-10-02T13-19-22_RecID-078_I23857_IG_poisson025-62p5ms_15moff-15mon-15moff_5p7V_DIV23_HS31_only-de_2025-06-11_11-50-00.npz'
# h5_file = '2024-10-02T13-19-22_RecID-078_I23857_IG_poisson025-62p5ms_15moff-15mon-15moff_5p7V_DIV23_HS31_spikesonly.h5'

h5_file = '2025-03-17T12-19-24_RecID-584_I24605_KG_poisson025-62p5ms_15moff-15mon-15moff_5p5V_DIV14_HS257_spikesonly.h5'
led_file = 'detailed_sync_2025-03-17T12-19-24_RecID-584_I24605_KG_poisson025-62p5ms_15moff-15mon-15moff_5p5V_DIV14_HS257_only-de_2025-06-11_11-50-00.npz'

h5_filename = os.path.join(h5_dir, h5_file)
led_filename = os.path.join(led_dir, led_file)

data_dict = load_h5_to_dict(h5_filename)
led_data = np.load(led_filename)
# in s; the raw timestamps are presumably in microseconds (hence / 1e6) — TODO confirm
led_timestamps = led_data['timestamps'] / 1e6

## TODO: the loading code above still needs to be reworked to pair h5 and LED files by RecID.
## Once reworked, it is supposed to return the data_dict of a file and its led_data.


def process_stimulus(stimulus):
    """Unpack raw stimulus bytes into 16x16 LED frames with the corrected column mapping.

    Each row of ``stimulus`` holds one pattern as 32 bytes (256 bits, one per LED).
    The bits are unpacked most-significant-bit first and arranged as a 16x16 grid.
    The pattern index is then moved to the last axis. The recorded column order does
    not match the physical grid, so the columns are reversed separately within the
    left half (columns 0-7) and the right half (columns 8-15).

    Returns:
        tuple: ``(patterns, n_patterns)``. ``patterns`` has shape (16, 16, N) and
        ``n_patterns`` is N.
    """
    n_rows = stimulus.shape[0]
    bits = np.unpackbits(stimulus.astype(np.uint8), axis=1)
    frames = bits.reshape(n_rows, 16, 16).transpose((1, 2, 0))  # (16, 16, N)

    # Reverse column order inside each 8-column half, then rejoin the halves side by side.
    halves = [np.flip(frames[:, cols, :], axis=1) for cols in (slice(0, 8), slice(8, 16))]
    remapped = np.concatenate(halves, axis=1)

    return remapped, remapped.shape[2]


def get_timeframes(h5_data, led_data, dt=0):
    """Derive the stimulus-phase bin edges and start/end times from LED sync timestamps.

    The original version did not parse (missing ``:`` and an unclosed bracket). It also
    returned an undefined ``spike_train``; per-frame spike counts come from
    ``get_spike_train(spike_times, bin_edges)`` instead.

    Parameters:
        h5_data: Currently unused. Kept so the call signature stays unchanged.
        led_data: Mapping with a ``'timestamps'`` entry. Values are divided by 1e6
            to get seconds (raw unit presumably microseconds — TODO confirm).
        dt: Duration of one stimulus frame in seconds. If 0 (the default), it is
            estimated as the mean interval between consecutive timestamps.

    Returns:
        tuple: ``(bin_edges, stim_start, stim_end)``. ``bin_edges`` is the timestamps
        followed by ``timestamps[-1] + dt`` (one bin per frame). ``stim_start`` is
        the first timestamp and ``stim_end`` is the end of the last frame, in seconds.

    Raises:
        ValueError: If there are no timestamps, or if there is only one and ``dt``
            was not given (no interval to estimate it from).
    """
    timestamps = np.asarray(led_data['timestamps']) / 1e6  # in s
    if timestamps.size == 0:
        raise ValueError("led_data['timestamps'] is empty")

    if dt == 0:
        if timestamps.size < 2:
            raise ValueError("cannot estimate dt from a single timestamp; pass dt explicitly")
        print("Warning: dt was not specified")
        dt = np.mean(np.diff(timestamps))

    stim_start = timestamps[0]
    stim_end = timestamps[-1] + dt
    # The last frame needs an explicit closing edge so every frame gets a bin.
    bin_edges = np.append(timestamps, stim_end)

    return bin_edges, stim_start, stim_end

def get_spike_train(spike_times, bin_edges):
    """Count spikes falling into each bin (e.g. each stimulus frame).

    Fixes the ``np.historgram`` typo, which raised AttributeError on every call.

    Parameters:
        spike_times: Spike times, in the same unit as ``bin_edges``.
        bin_edges: Monotonically increasing bin edges.

    Returns:
        ndarray: Integer counts of length ``len(bin_edges) - 1``. As with
        ``numpy.histogram``, every bin is half-open ``[a, b)`` except the last, which
        also includes its right edge. Spikes outside the edges are not counted.
    """
    counts, _ = np.histogram(spike_times, bins=bin_edges)
    return counts

def get_spike_counts(spike_times, stim_start, stim_end):
    """Split a unit's spikes into those inside and outside the stimulus window.

    The window is half-open: ``stim_start <= t < stim_end``.

    Parameters:
        spike_times: NumPy array of spike times (same unit as the window bounds).
        stim_start: Start of the stimulus window.
        stim_end: End of the stimulus window (exclusive).

    Returns:
        tuple: ``(n_in, n_out)``, the number of spikes inside and outside the window.
    """
    during = np.logical_and(spike_times >= stim_start, spike_times < stim_end)
    outside = np.logical_or(spike_times < stim_start, spike_times >= stim_end)
    return during.sum(), outside.sum()

def get_firing_rates(n_in, n_out, stim_start, stim_end, rec_end):
    """Return the firing rates (Hz) inside and outside the stimulus window.

    The original signature was truncated. This uses the same formulas as the analysis
    loop: the stimulus duration is ``stim_end - stim_start``, and the no-stimulus
    duration is ``rec_end`` minus that duration. ``rec_end`` is the recording length
    in seconds; the loop uses the last spike time as a stand-in.

    Parameters:
        n_in: Number of spikes inside the stimulus window.
        n_out: Number of spikes outside the stimulus window.
        stim_start: Stimulus window start in seconds.
        stim_end: Stimulus window end in seconds.
        rec_end: Recording end (length) in seconds.

    Returns:
        tuple: ``(rate_in, rate_out)`` in Hz.

    Raises:
        ValueError: If the stimulus or the no-stimulus duration is not positive.
    """
    stim_duration = stim_end - stim_start
    if stim_duration <= 0:
        raise ValueError(f"stimulus duration must be positive, got {stim_duration}")
    no_stim_duration = rec_end - stim_duration
    if no_stim_duration <= 0:
        raise ValueError(f"no-stimulus duration must be positive, got {no_stim_duration}")
    return n_in / stim_duration, n_out / no_stim_duration

def filters():
    if  

# NOTE(review): `stimulus` is the stimulus pattern stack handed to RFAnalysis below,
# presumably the (16, 16, N) output of process_stimulus. It is not defined anywhere in
# the visible code. TODO: load it together with led_data before running this loop.

# The stimulus timing is identical for every unit, so derive it once up front.
# Assumes one LED timestamp per stimulus frame. The frame duration is estimated as the
# mean interval between consecutive timestamps.
frame_dt = np.mean(np.diff(led_timestamps))
stim_start = led_timestamps[0]
stim_end = led_timestamps[-1] + frame_dt
stim_duration = stim_end - stim_start
# One histogram bin per stimulus frame; the last bin closes at stim_end.
bin_edges = np.append(led_timestamps, stim_end)

for key, unit_spike_times in data_dict.items():
    spike_times = np.asarray(unit_spike_times)  # in seconds

    in_window = (spike_times >= stim_start) & (spike_times < stim_end)
    n_spikes_in_window = int(np.count_nonzero(in_window))
    n_spikes_outside_window = spike_times.size - n_spikes_in_window

    # The STA needs more than a single spike during stimulation.
    if n_spikes_in_window <= 1:
        print(f"Skipping {key}: at most one spike in stimulus period\n")
        continue

    # The last spike time stands in for the recording length (the true recording end
    # is not available here), so the no-stimulus duration is only an estimate. Guard
    # against it being non-positive before dividing by it.
    rec_end = spike_times.max()
    no_stim_duration = rec_end - stim_duration
    if no_stim_duration <= 0:
        print(f"{key} skipped. (no-stimulus duration could not be estimated)\n")
        continue

    firing_rate_outside = n_spikes_outside_window / no_stim_duration  # firing rate in Hz
    firing_rate_inside = n_spikes_in_window / stim_duration  # firing rate in Hz

    # Skip insufficient recordings for now; the thresholds below are ARBITRARY.

    # Recordings are about 2700 s; at least ~500 spikes in total means about 0.18 Hz.
    if (n_spikes_outside_window + n_spikes_in_window) / rec_end < 0.18:
        print(f"{key} skipped. (barely active)\n")
        continue

    # Units whose firing rate barely changes with(out) stimulation are ignored for now.
    if abs(firing_rate_outside - firing_rate_inside) < 0.1:
        print(f"{key} skipped. (unchanging)\n")
        continue

    # The check above uses abs(), so inhibited units can still reach this point. STC does
    # not get enough data yet and only the STA is used, so inhibited units are ignored.
    if firing_rate_outside - firing_rate_inside > 0:
        print(f"{key} skipped. (inhibited)\n")
        continue

    print(f"{key} has {n_spikes_outside_window} outside and {n_spikes_in_window} spikes in stimulus window "
          f"(from {stim_start:.2f}s to {stim_end:.2f}s)\n")

    print(f"Firing Rate without Stimulus: \t{round(firing_rate_outside, 3)} Hz\n"
          f"Firing Rate with Stimulus: \t{round(firing_rate_inside, 3)} Hz\n"
          f"Stimulus Duration: {round(stim_duration, 3)} s, No-Stimulus Duration: {round(no_stim_duration, 3)} s\n")

    # Spike count per stimulus frame (presumably one entry per stimulus pattern).
    spike_train = np.histogram(spike_times, bins=bin_edges)[0]

    filter_empty = np.zeros((16, 16))  # placeholder
    analysis = RFAnalysis(stimulus, spike_train, filter_empty, 'CL')
    analysis.calc_sta(center=True)
    analysis.calc_rta(center=True)

    # STC is disabled: there are not enough spikes for it here.
    # analysis.calc_stc(whiten = True, delta = True)
    # analysis.sta = analysis.sta - np.mean(analysis.sta)  # centering is now done in RFAnalysis
    # analysis.rta -= np.mean(analysis.rta)
    analysis.plot_sta_lags(key)
    # analysis.plot_eigenvals_stc(key)
    # analysis.plot_eigenvecs_stc(key)

    # RTA plot disabled for now: the RTA is usually around 0.
    # vlim = (-3, 3)
    # plt.figure(figsize=(5, 4))
    # plt.title(f'RTA - {key}')
    # plt.imshow(analysis.rta, cmap='seismic', vmin=vlim[0], vmax=vlim[1])
    # plt.colorbar(label='RTA value')
    # plt.xticks([]); plt.yticks([])
    # plt.show()

    plt.figure(figsize=(8, 3))
    plt.vlines(led_timestamps, 1, 1.1, colors='red', alpha=0.5, label='With Stimulus')
    plt.eventplot(spike_times, lineoffsets=0.5, colors='black', linewidths=0.1)
    plt.xlabel('Time in s')
    plt.title(f"Stimulus vs spikes: {key}")
    plt.legend()
    plt.show()