# Analysis example
This notebook shows you how the different utilities of the recording class can be used effectively to make spike train analysis super easy! <br>
In this first analysis we want to plot tuning curves for full field flash responses consisting of flashes at 6 different wavelengths. <br>
To do this, we need to count the number of spikes collected for each wavelength.<br>
First, we load the data.

In [None]:
# --- Imports --------------------------------------------------------------
# `reload` is used further down to re-import edited project modules without
# restarting the kernel.
from importlib import reload  
import panel as pn
# Load panel's 'tabulator' extension.
pn.extension('tabulator')
import numpy as np
import plotly.express as px
# Project-local module that provides the Recording class (Overview.Recording).
import Overview
import polars as pl
pn.extension()
# Load panel's 'plotly' extension.
pn.extension('plotly')
import colour_template
import spiketrain_plots
# NOTE(review): moving_bars and chirps are not used in the cells below.
import moving_bars
import chirps
import stimulus_spikes
import plotly.graph_objs as go
import matplotlib.pyplot as plt
import pandas as pd
# NOTE(review): presumably registers the "scatter_template" plotly template used below — confirm.
import plotly_templates
import spiketrains

In [None]:
# Location of the sorted recording overview. Keep the path in a single
# constant so that loading and saving refer to the same recording.
RECORDING_PATH = r"D:\Zebrafish_14_11_23\ks_sorted\overview"
recording = Overview.Recording.load(RECORDING_PATH)

In [None]:
# Presumably renders an overview of the dataframes held by the recording — see Overview.Recording.show_df.
recording.show_df()

Let's have a look at the stimuli in this recording:

In [None]:
# Table of the stimuli played in this recording.
recording.dataframes["stimulus_df"]

We are going to analyse the full field flash (FFF) stimulus. This stimulus was played twice, as stimulus indices 0 and 5. <br>
Let us check how similar the responses to the two presentations were by plotting an overview.

In [None]:
# The FFF stimulus was presented twice (stimulus indices 0 and 5). Collect the
# spikes of all cells for both presentations so the two can be compared.
spikes_df = recording.get_spikes_triggered([[0], [5]], [["all"]])
spikes_df

In [None]:
%matplotlib widget
# Establish the colour_template and load the correct stimulus:
CT = colour_template.Colour_template()
CT.pick_stimulus("FFF_6_MC")
# Calculate the mean value of the stimuli across stimulus presentations:
flash_durations = stimulus_spikes.mean_trigger_times(recording.stimulus_df, [0])
# Plot the cells and spikes:
fig, ax = spiketrain_plots.whole_stimulus(spikes_df, stacked=True, height=10, index="stimulus_index")
#fig = CT.add_stimulus_to_plot(fig, flash_durations, names=False)

We can see that the second stimulus presentation triggered much stronger responses than the first one. <br>

## Quality check
We might want to check the quality of the responses of the cells before we continue to analyse the tuning curves.

In [None]:
import binarizer

We could feed both stimulus presentations into the quality control; this gives us the quality of each cell's response across both presentations:

In [None]:
# Spikes of every cell for all FFF presentations, returned as a polars frame.
spikes_df = recording.get_spikes_triggered([["FFF"]], [["all"]], pandas=False)

# Binarisation parameters:
# - bin width of 0.001 (1 ms, assuming times are in seconds);
# - total length = sum of the mean flash durations;
# - repeat count = the largest nr_repeats of the two FFF presentations.
bin_width = 0.001
stimulus_length = np.sum(flash_durations)
max_repeats = np.max(recording.stimulus_df.loc[[0, 5], "nr_repeats"])
binary_df = binarizer.timestamps_to_binary_multi(spikes_df, bin_width, stimulus_length, max_repeats)

In [None]:
# Quality index (QI) values computed from the binarised responses.
qis = binarizer.calc_qis(binary_df)

In [None]:
# Inspect the QI values.
qis

Let's create a new dataframe which we can use for analysis. First, we get the indices of the cells that spiked in both instances:

In [None]:
# Cell indices present in the binarised data. These are paired position-by-position
# with `qis` when the QI Series is built, so their order must be deterministic.
# polars' `unique()` does not guarantee order by default, hence `maintain_order=True`.
# NOTE(review): confirm that binarizer.calc_qis returns QIs in first-appearance
# order of `cell_index` in `binary_df`.
cell_ids = binary_df["cell_index"].unique(maintain_order=True).to_numpy()
cell_ids

We create a new dataframe, which only contains those cells:

In [None]:
# New analysis dataframe restricted to the cells in `cell_ids`, for the FFF stimulus only.
recording.dataframes["fff_analysis"] = recording.extract_df_subset(cell_ids.tolist(), stimulus_name=["FFF"])

We can create a new Series which contains the QI values and then add it to the dataframe in the recording:

In [None]:
# QI values as a Series indexed by cell_index and named "qi" (the column
# sorted on further below). pandas raises if len(qis) != len(cell_ids).
qi_df = pd.Series(index=pd.Index(cell_ids, name="cell_index"), data=qis, name="qi")

In [None]:
# Add the QI values as a column of the "fff_analysis" dataframe (presumably matched on cell_index).
recording.add_column(qi_df, "fff_analysis")

In [None]:
# Inspect the analysis dataframe including the new "qi" column.
recording.dataframes["fff_analysis"]

Let's plot the best 50 cells (every second entry of the 100 highest QIs) to get an idea of how they behave:

In [None]:
# Every second row among the 100 highest-QI rows of the analysis dataframe.
selected_cells = (
    recording.dataframes["fff_analysis"]
    .sort_values("qi", ascending=False)
    .iloc[0:100:2]["cell_index"]
    .tolist()
)

spikes_df = recording.get_spikes_triggered([["FFF"]], [selected_cells])

In [None]:
# Stacked spike-train plot of the selected cells, with the FFF stimulus drawn on top.
fig, ax = spiketrain_plots.whole_stimulus(spikes_df, stacked=True, height=20)
fig = CT.add_stimulus_to_plot(fig, flash_durations, names=False)

Let's look at the worst 50 cells:

In [None]:
# Every second row among the 100 rows at the bottom of the QI ranking.
selected_cells = (
    recording.dataframes["fff_analysis"]
    .sort_values("qi", ascending=False)
    .iloc[-100::2]["cell_index"]
    .tolist()
)
spikes_df = recording.get_spikes_triggered([["FFF"]], [selected_cells])

In [None]:
# Same stacked plot with stimulus overlay, now for the lowest-QI selection.
fig, ax = spiketrain_plots.whole_stimulus(spikes_df, stacked=True, height=20)
fig = CT.add_stimulus_to_plot(fig, flash_durations, names=False)

Let's calculate the tuning curve for the best cells (the same 50 cells selected above): <br>
We can simply calculate the average histogram (as plotted in the figure) and then find the peaks:

In [None]:
from scipy.signal import find_peaks
import histograms

In [None]:
# Same selection as above: every second row of the 100 highest-QI rows.
selected_cells = (
    recording.dataframes["fff_analysis"]
    .sort_values("qi", ascending=False)
    .iloc[0:100:2]["cell_index"]
    .tolist()
)
spikes_df = recording.get_spikes_triggered([["FFF"]], [selected_cells])

# PSTH of the selected cells in 0.05-wide bins, up to the summed flash durations.
psth, bins = histograms.psth(spikes_df, bin_size=0.05, end=np.sum(flash_durations))
# Average over cells, then scale by the bin width.
# NOTE(review): the stated intent is spikes per second, which would require
# dividing by the bin width (psth / (n_cells * 0.05)), not multiplying.
# The find_peaks height threshold in the next cell is tuned to the current
# scaling, so both must be changed together.
psth = (psth / len(selected_cells)) * 0.05


In [None]:
# Peaks of the mean PSTH, at least 1/0.05 = 20 bins apart (1 s if times are in
# seconds) and with a height of at least 2.
# NOTE(review): the height threshold is expressed in whatever units the previous
# cell produced — re-tune it if that scaling changes.
peaks, properties = find_peaks(psth, distance=1/0.05, height=2)

In [None]:
import Opsins
reload(Opsins)
# Zebrafish opsin overview figure as the background for the tuning curves.
tmpl = Opsins.Opsin_template()
fig = tmpl.plot_overview(["Zebrafish"])
fig.update_layout(height=600, width=600, template="scatter_template")

# The code assumes the detected peaks alternate ON, OFF, ON, OFF, ...:
# even entries are ON responses and odd entries are OFF responses.
peak_wavelengths = CT.wavelengths[:peaks.shape[0]]
on_heights = properties["peak_heights"][::2]
off_heights = properties["peak_heights"][1::2]

# Normalise each curve by its OWN maximum, consistent with the
# spike-count tuning curves computed further below.
fig.add_trace(go.Scatter(x=peak_wavelengths[::2], y=on_heights / np.max(on_heights), mode="lines", line=dict(color="yellow"), name="ON Response"))
fig.add_trace(go.Scatter(x=peak_wavelengths[1::2], y=off_heights / np.max(off_heights), mode="lines", line=dict(color="black"), name="OFF Response"))
fig.update_yaxes(title="Normalised peak response")
fig.update_xaxes(title="Wavelength")
fig.show()

A slightly more reliable way is to count the spikes within specific time windows: <br>
Here, we can use the flexibility of the recording object to extract exactly these windows from the recording: <br>
First, we create new trigger signals that fit our time windows, then we extract the spikes as we want:

In [None]:
# reminder, we already have a specific cell_df for this stimulus:
recording.dataframes["fff_analysis"]

In [None]:
# Let's look at the stimuli (the split triggers below are derived from these):
recording.dataframes["stimulus_df"]

Let's create a new trigger signal by halving the existing triggers twice. <br>
This means the new trigger signals are effectively every 1 second (4 s / 4).

In [None]:
# Split the triggers of both FFF presentations (stimulus indices 0 and 5).
# NOTE(review): the 1 s spacing described above depends on split_triggers'
# default number of splits — confirm.
fff_stim_df = recording.split_triggers(stimulus_indices=[0,5])
fff_stim_df

In [None]:
# Register the split-trigger table so get_spikes_triggered can use it via stimulus_df="fff_tuning_curves".
recording.dataframes["fff_tuning_curves"] = fff_stim_df

In [None]:
# Inspect the registered split-trigger table.
recording.dataframes["fff_tuning_curves"]

In [None]:
# Spikes of the "fff_analysis" cells for stimuli 0 and 5, triggered on the split
# triggers, returned as a polars frame.
tuning_df = recording.get_spikes_triggered([[0], [5]],[["all"]],  stimulus_df="fff_tuning_curves", cell_df="fff_analysis", pandas=False)

In [None]:
# Spike counts per split trigger (grouped by "trigger" only, so presumably pooled
# over cells and both presentations), stored in a column named "count".
trigger_df = spiketrains.count_spikes(tuning_df, columns="trigger", name="count")

In [None]:
# Split-trigger indices of the ON onsets: with 1 s triggers, every 4th trigger
# (0, 4, ..., 20) starts a new flash.
on_trigger = np.arange(0,24,4) # New on every 4 seconds

In [None]:
# Split-trigger indices of the OFF onsets: 2 triggers after each ON onset (2, 6, ..., 22).
off_trigger  = np.arange(2,24,4) # New off every 4 seconds, starting after 2 seconds

In [None]:
# ON-window spike counts, ordered by trigger so they line up with the wavelength order.
# The explicit sort is needed because count_spikes' output order is not guaranteed
# (polars group-bys are unordered by default).
# NOTE(review): a window with no spikes is presumably absent from trigger_df,
# which would shorten this array — confirm.
tuning_on = (
    trigger_df
    .filter(pl.col("trigger").is_in(on_trigger.tolist()))
    .sort("trigger")["count"]
    .to_numpy()
)

In [None]:
# OFF-window spike counts, ordered by trigger so they line up with the wavelength order.
# The explicit sort is needed because count_spikes' output order is not guaranteed
# (polars group-bys are unordered by default).
# NOTE(review): a window with no spikes is presumably absent from trigger_df,
# which would shorten this array — confirm.
tuning_off = (
    trigger_df
    .filter(pl.col("trigger").is_in(off_trigger.tolist()))
    .sort("trigger")["count"]
    .to_numpy()
)

In [None]:
# Scale the ON curve so its maximum is 1. Re-running is harmless while the
# maximum is non-zero, because the maximum is already 1.
tuning_on = tuning_on/np.max(tuning_on) #Normalization

In [None]:
# Scale the OFF curve so its maximum is 1, independently of the ON curve.
tuning_off = tuning_off/np.max(tuning_off) # Normalization

In [None]:
import Opsins
reload(Opsins)
# Zebrafish opsin overview figure as the background for the tuning curves.
tmpl = Opsins.Opsin_template()
fig = tmpl.plot_overview(["Zebrafish"])
fig.update_layout(height=600, width=600, template="scatter_template")

# One x value per flash: every second entry of the template's wavelength list
# (the same slicing the peak-based plot uses for the ON curve).
flash_wavelengths = CT.wavelengths[::2]
fig.add_trace(go.Scatter(x=flash_wavelengths, y=tuning_on, mode="lines", line=dict(color="yellow"), name="ON Response"))
fig.add_trace(go.Scatter(x=flash_wavelengths, y=tuning_off, mode="lines", line=dict(color="black"), name="OFF Response"))
# Both curves were scaled to a maximum of 1 above.
fig.update_yaxes(title="Normalised spike count")
fig.update_xaxes(title="Wavelength")
fig.show()

As we can see, the results differ slightly (most notably for the ON response). This is because we have now considered the sum of spikes, rather than just the peak spike frequency.

# Save results
Let's save one tuning curve per cell:

In [None]:
# Spike counts per (stimulus_index, cell_index, trigger), stored in a column named "count".
trigger_df_cell = spiketrains.count_spikes(tuning_df,["stimulus_index", "cell_index", "trigger"], name="count")

In [None]:
# Inspect the per-cell counts.
trigger_df_cell

In [None]:
# Preview the per-cell spike counts of the ON windows only; the next cell
# collects these into one ON tuning curve per (stimulus, cell).
trigger_df_cell.filter(pl.col("trigger").is_in(on_trigger.tolist()))

In [None]:
# ON tuning curve per (stimulus_index, cell_index), as a list of ON-window counts.
# Sorting by trigger first makes each list follow trigger (wavelength) order,
# because polars keeps the row order within each group.
# maintain_order=True makes the group order deterministic.
# NOTE(review): windows without spikes are presumably missing, which would
# shorten the list for that cell — confirm.
tuning_on_cell = (
    trigger_df_cell
    .filter(pl.col("trigger").is_in(on_trigger.tolist()))
    .sort("trigger")
    .group_by("stimulus_index", "cell_index", maintain_order=True)
    .agg(pl.col("count").alias("tuning_on"))
    .to_pandas()
)

In [None]:
# OFF tuning curve per (stimulus_index, cell_index), as a list of OFF-window counts.
# Sorting by trigger first makes each list follow trigger (wavelength) order,
# because polars keeps the row order within each group.
# maintain_order=True makes the group order deterministic.
# NOTE(review): windows without spikes are presumably missing, which would
# shorten the list for that cell — confirm.
tuning_off_cell = (
    trigger_df_cell
    .filter(pl.col("trigger").is_in(off_trigger.tolist()))
    .sort("trigger")
    .group_by("stimulus_index", "cell_index", maintain_order=True)
    .agg(pl.col("count").alias("tuning_off"))
    .to_pandas()
)

In [None]:
# Quick check: index names of the (stimulus_index, cell_index)-indexed Series built from tuning_off_cell.
tuning_off_cell.set_index(["stimulus_index", "cell_index"]).squeeze().index.names

We end up with two pandas dataframes, which we can add to our previously created fff dataframe:

In [None]:
# Add the OFF tuning curves to "fff_analysis". After set_index only the
# "tuning_off" column remains, so squeeze() turns it into a Series.
recording.add_column(tuning_off_cell.set_index(["stimulus_index", "cell_index"]).squeeze(), dataframe="fff_analysis")

In [None]:
# Add the ON tuning curves to "fff_analysis". After set_index only the
# "tuning_on" column remains, so squeeze() turns it into a Series.
recording.add_column(tuning_on_cell.set_index(["stimulus_index", "cell_index"]).squeeze(), dataframe="fff_analysis")

In [None]:
# Inspect the analysis dataframe including the tuning-curve columns.
recording.dataframes["fff_analysis"]

In [None]:
# Save to the same location the recording was loaded from. A different
# recording's folder here would risk overwriting that recording's overview.
recording.save(r"D:\Zebrafish_14_11_23\ks_sorted\overview")

In [None]:
# Re-import spiketrains to pick up edits to the module.
reload(spiketrains)

In [None]:
# Polars copy of the pandas spikes_df (at this point in run order: the FFF
# spikes of the top-QI selection from the PSTH section).
test = pl.from_pandas(spikes_df)

In [None]:
# Collect the "times_triggered" values into arrays grouped by "trigger", output
# column "spikes_split". Argument meaning is inferred from this call — confirm
# against spiketrains.collect_as_arrays.
single_trig = spiketrains.collect_as_arrays(test, ["trigger"], "times_triggered", "spikes_split")

In [None]:
# Extract the per-trigger spike arrays as a numpy array.
# NOTE(review): rebinds single_trig, so this cell fails if re-run on its own output.
single_trig = single_trig["spikes_split"].to_numpy()

In [None]:
# Subtract each trigger's start offset (0, 4, ..., 20) so every response starts at 0.
# NOTE(review): relies on numpy broadcasting over an object array and on the
# rows being in trigger order; re-running the cell subtracts the offsets again.
single_trig = single_trig - np.array([[0, 4, 8, 12, 16, 20]])

In [None]:
# Inspect the shifted per-trigger spike times.
single_trig

In [None]:
from plotly.subplots import make_subplots
# One subplot row per trigger: histogram of the shifted spike times
# (bins of 0.01 from 0 to 4), coloured with every second template colour.
fig = make_subplots(rows=6, cols=1, shared_xaxes=True)
edges = np.arange(0,4.01, 0.01)
for idx, response in enumerate(single_trig[0]):
    hist, bins = np.histogram(response, bins=edges)
    trace = go.Scatter(x=bins[:-1], y=hist, mode="lines", line=dict(color=CT.colours[::2][idx]))
    fig.add_trace(trace, row=idx + 1, col=1)

In [None]:
# Style the subplot figure, shade x = 2 to 4, label the time axis and open it in a browser tab.
fig.update_layout(template="scatter_template")
fig.add_vrect(x0=2, x1=4, fillcolor="grey", opacity=0.2, line_width=0)
fig.update_xaxes(title="Time in s")
fig.show(renderer="browser")


In [None]:
# All per-trigger histograms overlaid in a single panel (bins of 0.01 from 0 to 4),
# with x = 2 to 4 shaded grey, as in the subplot version above.
fig = go.Figure()
for idx, response in enumerate(single_trig[0]):
    hist, bins = np.histogram(response, bins=np.arange(0,4.01, 0.01))
    fig.add_trace(go.Scatter(x=bins[:-1], y=hist, mode="lines", line=dict(color=CT.colours[::2][idx])))
fig.update_layout(template="scatter_template")
fig.add_vrect(x0=2, x1=4, fillcolor="grey", opacity=0.2, line_width=0)
fig.update_xaxes(title="Time in s")
fig.show(renderer="browser")

In [None]:
# Per-trigger PSTH of `test` in 0.01 bins up to 4.
# NOTE(review): later calls unpack psth_by_index as (hist, bins); here the whole
# return value is stored in `df`.
df = histograms.psth_by_index(test, 0.01, index="trigger", window_end=4)

In [None]:
# Inspect the per-trigger PSTH result.
df

In [None]:
# Number of distinct triggers in `test`.
test["trigger"].n_unique()

In [None]:
# Re-import the helper modules to pick up edits.
reload(spiketrains)
reload(histograms)

In [None]:
# Look up the stimulus indices used in the cells below.
recording.stimulus_df

In [None]:
# Spikes of all cells for stimulus index 11 (assumed to be the contrast-step
# stimulus — confirm in stimulus_df). Argument order is (stimuli, cells), as in the
# other get_spikes_triggered calls. The previous ([["all"]], [[11]]) selected cell 11
# across all stimuli.
c_steps = recording.get_spikes_triggered([[11]], [["all"]])

In [None]:
# Preview the alignment offsets (0, 4, ..., 36) used for the contrast steps below.
np.arange(0,40, 4)

In [None]:
# Align the contrast-step spikes per trigger using the offsets 0, 4, ..., 36
# (presumably subtracting each trigger's offset — confirm in spiketrains.align_on_condition).
csteps_trigger = spiketrains.align_on_condition(c_steps, "trigger", np.arange(0,40, 4))

In [None]:
reload(histograms)
# Bin the aligned contrast-step spikes per trigger in 0.01 bins. This must be
# csteps_trigger (the aligned spikes from the previous cell), not the FFF
# per-trigger counts in trigger_df.
hist_csteps, bins_csteps = histograms.psth_by_index(csteps_trigger, 0.01, index="trigger", to_bin="aligned_times")

In [None]:
# Normalise all contrast-step traces by the global maximum, so the largest bin is 1.
hist_csteps = hist_csteps/np.max(hist_csteps)

In [None]:
# Switch the colour template to the contrast-step stimulus for the plot below.
CT.pick_stimulus("Contrast_Step")

In [None]:
# One trace per contrast-step trigger, coloured with every second entry of the
# reversed contrast-step colour list; x = 2 to 4 shaded grey.
fig = go.Figure()
for idx, cell in enumerate(hist_csteps):
    fig.add_trace(go.Scatter(x=bins_csteps[:-1], y=cell, mode="lines", line=dict(color=np.flipud(CT.colours)[1::2][idx])))
fig.update_layout(template="scatter_template")
fig.add_vrect(x0=2, x1=4, fillcolor="grey", opacity=0.2, line_width=0)
fig.update_xaxes(title="Time in s")
fig.show(renderer="browser")

In [None]:
# Spikes of all cells for the second FFF presentation (stimulus index 5).
# Argument order is (stimuli, cells), as in the other get_spikes_triggered calls;
# the previous ([["all"]], [[5]]) selected cell 5 across all stimuli.
fff_df = recording.get_spikes_triggered([[5]], [["all"]])

In [None]:
# Align the FFF spikes per trigger using the offsets 0, 4, ..., 20.
fff_trigger = spiketrains.align_on_condition(fff_df, "trigger", np.arange(0,24, 4))

In [None]:
# Per-trigger PSTH of the aligned FFF spikes in 0.01 bins.
hist_fff, bins_fff = histograms.psth_by_index(fff_trigger, 0.01, index="trigger", to_bin="aligned_times")

In [None]:
# Switch the colour template back to the FFF_6_MC stimulus.
CT.pick_stimulus("FFF_6_MC")

In [None]:
# Normalise all FFF traces by the global maximum, so the largest bin is 1.
hist_fff = hist_fff/np.max(hist_fff)

In [None]:
# Add one trace per FFF trigger to `fig`. In run order `fig` is still the
# contrast-step figure, so the FFF and contrast-step traces end up in one plot.
# NOTE(review): create a new go.Figure() first if a separate FFF plot is intended.
for idx, cell in enumerate(hist_fff):
    fig.add_trace(go.Scatter(x=bins_fff[:-1], y=cell, mode="lines", line=dict(color=CT.colours[::2][idx])))

In [None]:
# Open the current figure in a browser tab.
fig.show(renderer="browser")

In [None]:
# Spikes of all cells for stimulus index 8.
spikes_df = recording.get_spikes_triggered([[8]], [["all"]])

In [None]:
# Overview plot of stimulus 8 (not stacked).
fig, ax = spiketrain_plots.whole_stimulus(spikes_df, stacked=False, height=20)