In [None]:
# Auto-reload imported modules (mode 2: all modules) so that edits to functions developed outside this notebook take effect without restarting the kernel
%load_ext autoreload
%autoreload 2

In [None]:
import json
import os
from datetime import datetime
from math import floor, sqrt
from math import floor  # duplicate of the import above; harmless

import h5py
import multiprocess as mp  # dill-based fork of multiprocessing; stdlib multiprocessing may fail to pickle functions defined in IPython
import numpy as np
import pandas as pd
import scipy
import scipy.sparse  # explicit: scipy.sparse.csc_matrix rebuilds the CNMF spatial components
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib.collections import PatchCollection
from matplotlib.patches import Rectangle
from scipy import ndimage
from statsmodels.nonparametric.smoothers_lowess import lowess

import datadoc_util
import labrotation.file_handling as fh
from labrotation import json_util
from nd2_to_caiman import np_arr_from_nd2

# Open traces file

In [None]:
whole_traces_h5_fpath = fh.open_file("Open traces h5 file!")
print(whole_traces_h5_fpath)

In [None]:
with h5py.File(whole_traces_h5_fpath, 'r') as hf:
    session_uuid = hf.attrs["uuid"]
    moco_intervals = hf["moco_intervals"][()]
    moco_flags = hf["moco_flags"][()]
    cnmf_intervals = hf["cnmf_intervals"][()]
    cnmf_flags = hf["cnmf_flags"][()]
    begin_end_frames = hf["begin_end_frames"][()]
    # spatial components: CNMF A field
    A_data = hf["spatial"]["data"][()]
    A_indices = hf["spatial"]["indices"][()]
    A_indptr = hf["spatial"]["indptr"][()]
    A_shape = hf["spatial"]["shape"][()]
    # temporal signals, i.e. neuron traces
    temporal = hf["traces"][()]

In [None]:
n_neurons, n_frames = temporal.shape

In [None]:
# convert spatial from sparse matrix into dense matrix of proper dimensions
# Rebuild the CNMF spatial components ("A") from their CSC sparse pieces as a dense ndarray.
# A_shape is expected to be (n_pixels, n_neurons): one column per neuron, pixels flattened
# (262144 = 512 x 512 for a 512x512 field of view).
# scipy.sparse.csc_matrix replaces the deprecated scipy.sparse.csc namespace; toarray() yields an ndarray directly.
spatial = scipy.sparse.csc_matrix((A_data, A_indices, A_indptr), shape=A_shape).toarray()
# Side length of the (assumed square) field of view, derived from the data instead of hard-coding 512
n_pixels = int(A_shape[0])
fov_side = int(round(np.sqrt(n_pixels)))
if fov_side * fov_side != n_pixels:
    raise ValueError(f"Spatial components have {n_pixels} pixels per neuron, which is not a square field of view")
spatial = np.reshape(spatial, (fov_side, fov_side, n_neurons))  # "unflatten" the pixel axis
# (side, side, n_neurons) -> (n_neurons, side, side): neuron index moves to the FRONT and the two pixel
# axes are swapped (presumably undoing CNMF's column-major pixel flattening — TODO confirm)
spatial = np.transpose(spatial, axes=[2, 1, 0])

In [None]:
res_x, res_y = spatial[0].shape

In [None]:
centers_x = np.zeros((n_neurons))
centers_y = np.zeros((n_neurons))
for i_neuron in range(n_neurons):
    xy = ndimage.measurements.center_of_mass(spatial[i_neuron].T) # transpose needed so that imshow and scatter plotting match 
    centers_x[i_neuron] = xy[0]
    centers_y[i_neuron] = xy[1]

In [None]:
fig = plt.figure(figsize=(16,16))
plt.imshow(spatial[0])
plt.scatter(centers_x[0], centers_y[0], color="lime")
plt.show()

In [None]:
fig = plt.figure(figsize=(16,16))
plt.scatter(centers_x, centers_y, color="red")
ax = plt.gca()
ax.invert_yaxis() 
plt.show()

# Create x values
Ideally, these would be read from the Nikon (nd2) metadata file instead of being generated as frame indices.

In [None]:
x_vals = np.arange(n_frames, dtype=np.float64)  # frame indices 0, 1, ..., n_frames-1 as floats

## Get important frames
Frames of interest: beginning of the seizure, SD waves, …

In [None]:
ddoc = datadoc_util.DataDocumentation()
ddoc.loadDataDoc()

In [None]:
segments_list = ddoc.getSegmentsForUUID(session_uuid, as_df=False)

In [None]:
segments_list

In [None]:
segments = ddoc.getSegmentsForUUID(session_uuid)

In [None]:
n_waves = len(segments[segments["interval_type"] == "sd_wave"])
print(f"Recording contains {n_waves} SD waves")

In [None]:
# sort by beginning frame of sd wave
sd_waves_list = sorted(list(filter(lambda seg: seg[0] == "sd_wave", segments_list)), key= lambda x: x[1])  

In [None]:
sd_waves_list

In [None]:
i_frame_begin = sd_waves_list[0][1]
i_frame_end = sd_waves_list[-1][2]

# Filter signal

In [None]:
# float64 output buffers with one row per neuron and one column per frame
lowess_filtered_traces = np.zeros_like(temporal, dtype=np.float64)
lowess_filtered_xvals = np.zeros_like(temporal, dtype=np.float64)

### Run lowess filtering in parallel

In [None]:
n_cpus_available = mp.cpu_count()  # number of CPUs reported by multiprocess
n_cpus_available

In [None]:
def get_lowess(row):
    """LOWESS-smooth a single trace against its frame index.

    The imports are done inside the function so that it is self-contained when executed in
    worker processes.

    :param row: 1D trace (one neuron).
    :return: array of shape (2, len(row)); [0] are the x values (frame indices) returned by lowess,
             [1] the smoothed trace.
    """
    from statsmodels.nonparametric.smoothers_lowess import lowess
    import numpy as np

    frame_indices = np.arange(len(row), dtype=np.float64)  # 0, 1, ..., len(row)-1
    smoothed_xy = lowess(row, frame_indices, frac=0.001, is_sorted=True)
    return smoothed_xy.T

In [None]:
# Smooth every neuron's trace in parallel, leaving two CPUs free (but always using at least one worker).
# The context manager terminates the pool afterwards, so worker processes are not leaked on every re-run.
# Iterating over `temporal` yields its rows (one per neuron) without the copy made by temporal[range(n_neurons)].
n_workers = max(1, mp.cpu_count() - 2)
with mp.Pool(processes=n_workers) as pool:
    lowess_results = pool.map(get_lowess, temporal)

### Create 2d np array of results

In [None]:
# each result is a (2, n_frames) array: row 0 = x values, row 1 = smoothed trace
for i_neuron, (smoothed_x, smoothed_y) in enumerate(lowess_results):
    lowess_filtered_xvals[i_neuron] = smoothed_x
    lowess_filtered_traces[i_neuron] = smoothed_y

# Calculate onset times

In [None]:
derivatives_lowess = np.zeros(temporal.shape, dtype=temporal.dtype)
derivatives_raw = np.zeros(temporal.shape, dtype=temporal.dtype)

In [None]:
# It turns out that the slow waves make symmetric first derivative a better signal to get onset as local maxima. For seizure onset, 
# this will not be the case (most likely)

# symmetric derivative: dy/dt[i] = (y[i+1] - y[i-1])/2h (h:= 0.5). 
# Not entirely valid, as y[i+h] should be the value instead, but 1/2h factor can be neglected
#for i_neuron in range(n_neurons):
#    derivatives[i_neuron][1:-1] = temporal[i_neuron][2:] - temporal[i_neuron][:-2]

# extended symmetric derivative: dy/dt[i] = (y[i+2] + y[i+1] - y[i-1] - y[i-2])/4h
#for i_neuron in range(n_neurons):
#    derivatives[i_neuron][2:-2] = temporal[i_neuron][4:] + temporal[i_neuron][3:-1] - temporal[i_neuron][1:-3] - temporal[i_neuron][:-4]

# super-extended symmetric derivative: dy/dt[i] = (y[i+4]+y[i+3]+y[i+2]+y[i+1] - y[i-1]-y[i-2]-y[i-3]-y[i-4])  
for i_neuron in range(n_neurons):
    derivatives_raw[i_neuron][4:-4] = temporal[i_neuron][8:] + temporal[i_neuron][7:-1] + temporal[i_neuron][6:-2] + temporal[i_neuron][5:-3] - temporal[i_neuron][3:-5] - temporal[i_neuron][2:-6] - temporal[i_neuron][1:-7] - temporal[i_neuron][:-8]

# asymmetric first derivative: dy/dt[i] = (y[i] - y[i-1])/h (h := 1)
#for i_neuron in range(n_neurons):
#    derivatives[i_neuron][1:] = temporal[i_neuron][1:] - temporal[i_neuron][:-1]

# second derivative: d^2y/dt^2[i] = (y[i+h] - 2y[i] + y[i-h])/h^2 (h:=1)
#for i_neuron in range(n_neurons):
#    derivatives[i_neuron][1:-1] = temporal[i_neuron][2:] - 2*temporal[i_neuron][1:-1] + temporal[i_neuron][:-2]

# lowess-filtered extended symmetric derivative:
for i_neuron in range(n_neurons):
    derivatives_lowess[i_neuron][2:-2] = lowess_filtered_traces[i_neuron][4:] + lowess_filtered_traces[i_neuron][3:-1] - lowess_filtered_traces[i_neuron][1:-3] - lowess_filtered_traces[i_neuron][:-4]

In [None]:
plot_example_trace = False  # set True to inspect neuron 0 around the SD waves
if plot_example_trace:
    fig = plt.figure(figsize=(14, 14))
    # smoothed trace, raw trace, then derivative of the smoothed trace
    for series in (lowess_filtered_traces[0], temporal[0], derivatives_lowess[0]):
        plt.plot(lowess_filtered_xvals[0], series)
    plt.xlim((i_frame_begin - 200, i_frame_end + 200))
    plt.show()

### Find, for each neuron, the `n_waves` largest well-separated maxima of the derivative in the time window

In [None]:
MIN_DELAY_REQUIRED = 20  # at least this many frames need to pass between the onset of two SD waves in one place
def get_n_maxima(derivatives, n_maxima, i_frame_begin, i_frame_end, force_2d: bool = False, set_beginning_dark: bool = False,
                 traces=None, min_delay: int = MIN_DELAY_REQUIRED, dark_offset: int = 15, min_window: int = 30):
    """Find, per neuron, the `n_maxima` largest derivative values that lie at least `min_delay` frames apart.

    Parameters
    ----------
    derivatives : 2D array
        One derivative trace per row (row i belongs to neuron i).
    n_maxima : int
        Number of maxima to find per neuron (n=1: largest maximum, ...).
    i_frame_begin, i_frame_end : int
        First and last frame (1-indexed, both inclusive) to consider for being a maximum.
    force_2d : bool
        If n_maxima == 1, still return one list per neuron ([[a], [b], ...]); if False: [a, b, ...].
    set_beginning_dark : bool
        If True, each neuron's window starts `dark_offset` frames after the frame in which its trace
        (`traces`) is minimal within [i_frame_begin, i_frame_end].
    traces : 2D array, optional
        Traces used for `set_beginning_dark`; defaults to the notebook-global `temporal`.
    min_delay : int
        Minimum separation (frames) between two maxima of the same neuron.
    dark_offset : int
        Frames added after the darkest frame (the original note says 15 frames ~ 1 s).
    min_window : int
        With `set_beginning_dark`, the shortened window must be longer than this (original note: ~2 s).

    Returns
    -------
    list
        Per neuron: the 0-based frame index of the maximum (n_maxima == 1 and not force_2d), otherwise a
        list of 0-based frame indices sorted in time.

    Raises
    ------
    ValueError
        If i_frame_begin < 1, a window is empty, or it holds fewer than n_maxima separated maxima
        (previously the search silently wrapped around to negative indices).
    AssertionError
        If `set_beginning_dark` shrinks a window to `min_window` frames or fewer.
    """
    if i_frame_begin < 1:
        raise ValueError(f"i_frame_begin is 1-indexed and must be >= 1, got {i_frame_begin}")
    if set_beginning_dark and traces is None:
        traces = temporal  # notebook-global raw traces (backward-compatible default)
    onsets = []
    for i_neuron, derivative in enumerate(derivatives):
        if set_beginning_dark:
            # limit time window beginning to the darkest frame in the original window (+ offset);
            # 1-indexed inclusive [i_frame_begin, i_frame_end] is the 0-based slice [i_frame_begin-1, i_frame_end)
            dark_window = traces[i_neuron][i_frame_begin - 1:i_frame_end]
            i_frame_begin_neuron = np.argmin(dark_window) + i_frame_begin + dark_offset
            assert i_frame_end - i_frame_begin_neuron > min_window, (
                f"Neuron {i_neuron}: search window shrank to {i_frame_end - i_frame_begin_neuron} frames")
        else:
            i_frame_begin_neuron = i_frame_begin
        start = i_frame_begin_neuron - 1  # 0-based index of the first frame in the window
        # positions within the window, largest derivative first
        candidates = np.argsort(derivative[start:i_frame_end])[::-1]
        if len(candidates) == 0:
            raise ValueError(f"Neuron {i_neuron}: empty search window [{i_frame_begin_neuron}, {i_frame_end}]")
        onset_frames = []
        for candidate in candidates:
            if len(onset_frames) >= n_maxima:
                break
            # ignore maxima too close to an already accepted (larger) maximum
            if all(abs(candidate - chosen) >= min_delay for chosen in onset_frames):
                onset_frames.append(candidate)
        if len(onset_frames) < n_maxima:
            raise ValueError(f"Neuron {i_neuron}: only {len(onset_frames)} maxima at least {min_delay} frames apart "
                             f"in [{i_frame_begin_neuron}, {i_frame_end}], {n_maxima} requested")
        # convert to whole-video (0-based) frame indices; maxima were found by magnitude, sort them in time
        onset_frames = sorted(frame + start for frame in onset_frames)
        if n_maxima == 1 and not force_2d:
            onsets.append(onset_frames[0])
        else:
            onsets.append(onset_frames)
    return onsets

### OPTIONAL: start each neuron's SD-wave search window shortly after its darkest (minimum-intensity) frame

In [None]:
set_beginning_dark: bool = True  # start each neuron's SD-wave search window just after its darkest frame

### Find onsets, ignoring maxima that are too close to each other

In [None]:
# force_2d=True: every entry is a list of wave onsets even if the recording has a single SD wave, so that
# later cells can always index onsets[i_neuron][i_wave] and onsets_np[:, i_wave]
onsets = get_n_maxima(derivatives_lowess, n_waves, i_frame_begin, i_frame_end, force_2d=True, set_beginning_dark=set_beginning_dark)

In [None]:
I_NEURON = 290  # neuron to inspect
assert 0 <= I_NEURON < n_neurons, f"I_NEURON must be in [0, {n_neurons})"

fig = plt.figure(figsize=(14,14))
plt.plot(lowess_filtered_xvals[I_NEURON], lowess_filtered_traces[I_NEURON])
plt.plot(lowess_filtered_xvals[I_NEURON], temporal[I_NEURON])
plt.plot(lowess_filtered_xvals[I_NEURON], derivatives_lowess[I_NEURON])
# detected onsets, one color per wave; np.atleast_1d also accepts scalar entries (one wave, force_2d=False)
vline_colors = ["red", "black", "orange"]  # cycled when there are more waves than colors
for i_wave, onset in enumerate(np.atleast_1d(onsets[I_NEURON])):
    plt.vlines(onset, ymin=-500, ymax=2000, color=vline_colors[i_wave % len(vline_colors)], linewidth=2)

# search window passed to get_n_maxima
plt.vlines([i_frame_begin, i_frame_end], ymin=-1000, ymax=4000, color="lime", linewidth=3)

# darkest frame of the raw trace inside the window (reference point of set_beginning_dark)
plt.vlines([np.argsort(temporal[I_NEURON][i_frame_begin-1:i_frame_end+1])[0]+i_frame_begin], ymin=-1500, ymax=4500, color="yellow",linewidth=3)

# TODO: mismatch in onsets[] as well as in sorted_indices[]! See plot...
plt.xlim((i_frame_begin - 200,i_frame_end + 200))
plt.show()

In [None]:
# per-neuron onset lists as an array: (n_neurons, n_waves) when each entry is a list of wave onsets
onsets_np = np.asarray(onsets)

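## Plot onset time for each neuron
Three views: neurons in their original order, sorted by the onset of the first SD wave, and sorted by the onset of the second SD wave. In each view, all waves are plotted in the same neuron order.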

In [None]:
onsets_np_sorted = np.argsort(onsets_np[:,0])

In [None]:
onsets_np_sorted_second = np.argsort(onsets_np[:,1])

In [None]:
fig = plt.figure(figsize=(12,12))
t_vals = [i_neuron for i_neuron in range(n_neurons)]
for i_wave in range(n_waves):
    plt.scatter(t_vals, onsets_np[:,i_wave], label=f"Wave {i_wave+1}")
plt.legend(fontsize=20)
plt.show()

In [None]:
fig = plt.figure(figsize=(12,12))
t_vals = [i_neuron for i_neuron in range(n_neurons)]
for i_wave in range(n_waves):
    plt.scatter(t_vals, onsets_np[onsets_np_sorted,i_wave], label=f"Wave {i_wave+1}")
plt.legend(fontsize=20)
plt.show()

In [None]:
fig = plt.figure(figsize=(12,12))
t_vals = [i_neuron for i_neuron in range(n_neurons)]
for i_wave in range(n_waves):
    plt.scatter(t_vals, onsets_np[onsets_np_sorted_second,i_wave], label=f"Wave {i_wave+1}")
plt.legend(fontsize=20)
plt.show()

## Assign cells to grid tiles

In [None]:
grid_shape = (8, 8)  # (n_cols, n_rows): tiles along x, tiles along y

In [None]:
def getGridTile(x, y, grid_shape, fov_width=None, fov_height=None):
    """Return the (col, row) grid tile (0-based ints) that contains the point (x, y).

    The field of view is split into grid_shape[0] columns along x and grid_shape[1] rows along y.

    :param x, y: point coordinates in pixels.
    :param grid_shape: (n_cols, n_rows).
    :param fov_width, fov_height: FOV size in pixels; default to the notebook-global res_x / res_y.
    :raises ValueError: if x or y is not finite (e.g. NaN center of mass) or lies outside the FOV.
    """
    width = res_x if fov_width is None else fov_width
    height = res_y if fov_height is None else fov_height
    n_cols, n_rows = grid_shape
    # NaN fails every comparison below, so it has to be rejected explicitly
    if not (np.isfinite(x) and np.isfinite(y)):
        raise ValueError(f"x or y coordinate is not finite: x: {x}, y: {y}")
    if x >= width or y >= height or x < 0 or y < 0:
        raise ValueError(f"x or y coordinate does not fit {width}x{height} FOV: x: {x}, y: {y}")
    # multiply before dividing so integer inputs map exactly; clamp guards float rounding just below the edge
    col = min(int(x * n_cols // width), n_cols - 1)
    row = min(int(y * n_rows // height), n_rows - 1)
    return (col, row)

In [None]:
neuron_rows = np.zeros(n_neurons, dtype=np.int16)
neuron_cols = np.zeros(n_neurons, dtype=np.int16)
for i_neuron in range(n_neurons):
    col, row = getGridTile(centers_x[i_neuron], centers_y[i_neuron], grid_shape)
    neuron_rows[i_neuron] = row
    neuron_cols[i_neuron] = col

In [None]:
neuron_tiles = neuron_rows*grid_shape[0] + neuron_cols

In [None]:
getGridTile(511,0, grid_shape)

## Create dataframe

In [None]:
# TODO: depending on n_waves, there might be 0, 1, 2, ... SD waves
onsets_dict = {"neuron_id": [i for i in range(n_neurons)], "onset1" : onsets_np[:,0], "onset2" : onsets_np[:,1], "x": centers_x, "y": centers_y, "row": neuron_rows, "col": neuron_cols, "tile": neuron_tiles}

In [None]:
onsets_df = pd.DataFrame(data=onsets_dict)

In [None]:
onsets_df["quantile1"] = pd.qcut(onsets_df["onset1"], 4, labels=False)

In [None]:
onsets_df["quantile2"] = pd.qcut(onsets_df["onset2"], 4, labels=False)

In [None]:
sns.set_theme(style="whitegrid")
f, ax = plt.subplots(figsize=(18, 18))
sns.despine(f, left=True, bottom=True)
sns.scatterplot(x="x", y="y",
                size="onset1", hue="quantile1",
                sizes=(8,80), linewidth=0,
                data=onsets_df, ax=ax)
ax.invert_yaxis()  # invert to match imshow() and in general, nd2 videos: (0, 0) is top left corner
plt.show()

In [None]:
sns.set_theme(style="whitegrid")
f, ax = plt.subplots(figsize=(18, 18))
sns.despine(f, left=True, bottom=True)
sns.scatterplot(x="x", y="y",
                hue="tile", size="onset2",
                sizes=(8,80), linewidth=0,
                data=onsets_df, ax=ax, palette="deep")
ax.invert_yaxis()  # invert to match imshow() and in general, nd2 videos: (0, 0) is top left corner
plt.show()

## Get mean/median onset per grid

In [None]:
# TODO: assign grid index to each neuron. Then, make average/median onset.
# Then, create pivot (of a new df, maybe, which contains grid row, column, and onset)

In [None]:
onset_grid = np.zeros(grid_shape, dtype=onsets_np.dtype)

In [None]:
median_onsets_df = onsets_df.groupby("tile", as_index=False).median()

In [None]:
for i in range(grid_shape[0]):  # go through cols
    for j in range(grid_shape[1]): # go through rows
        i_tile = i + j*grid_shape[0]
        if i_tile in median_onsets_df["tile"].values:
            onset_grid[i][j] = median_onsets_df[median_onsets_df["tile"] == i_tile]["onset1"].values[0]


In [None]:
median_onsets1_pivot = median_onsets_df.pivot("row", "col", "onset1")
median_onsets2_pivot = median_onsets_df.pivot("row", "col", "onset2")

In [None]:
median_onsets1_pivot

In [None]:
test_flag = False  # set True to check how seaborn's heatmap lays out a pivot table with known values
if test_flag:
    tile_ids = np.arange(64, dtype=float)
    row_np = tile_ids % 8
    col_np = tile_ids // 8
    test_dict = {"col": col_np, "row": row_np}
    test_df = pd.DataFrame(test_dict)
    test_df["val"] = 8*test_df["col"] + test_df["row"]
    # keyword arguments: positional DataFrame.pivot arguments were removed in pandas 2.0
    test_pivot = test_df.pivot(index="col", columns="row", values="val")
    f, ax = plt.subplots(figsize=(12, 12))
    sns.heatmap(test_pivot, annot=False, linewidths=.5, ax=ax)
    plt.show()

In [None]:
# median first-wave onset per grid tile (rows = grid rows along y, columns = grid columns along x)
f, ax = plt.subplots(figsize=(12, 12))
sns.heatmap(data=median_onsets1_pivot, ax=ax, linewidths=0.5, annot=False)
plt.show()

In [None]:
fig = plt.figure(figsize=(12,12))
ax = fig.gca(projection='3d')
surf=ax.plot_trisurf(onsets_df['x'], onsets_df['y'], onsets_df['onset1'], cmap=plt.cm.jet, linewidth=0.2)
#ax.view_init(10, 90)
plt.show()

In [None]:
fig = plt.figure(figsize=(12,12))
ax = fig.gca(projection='3d')
surf=ax.plot_trisurf(onsets_df['x'], onsets_df['y'], onsets_df['onset2'], cmap=plt.cm.jet, linewidth=0.2)
ax.view_init(30, -45)
plt.show()

# Seizure onset analysis

In [None]:
sz = segments[segments["interval_type"] == "sz"]  
assert len(sz) == 1# there should be only one
seizure_begin_frame = sz["frame_begin"].values[0]
seizure_end_frame = sz["frame_end"].values[0]

In [None]:
seizure_begin_frame

In [None]:
#small window:
small_window = False
extra_large_window = True
if small_window:
    sz_lower_limit = seizure_begin_frame-2
    sz_upper_limit = seizure_begin_frame+10
elif extra_large_window:
    sz_lower_limit = seizure_begin_frame - 5
    sz_upper_limit = seizure_begin_frame + 100
else:
    sz_lower_limit = seizure_begin_frame-2
    sz_upper_limit = seizure_begin_frame+20
    
# might be necessary to set limits manually:
manual_limits = False
if manual_limits:
    sz_lower_limit = seizure_begin_frame - 2
    sz_upper_limit = 5541

In [None]:
onsets_sz = get_n_maxima(derivatives_lowess, 1,sz_lower_limit-2, sz_upper_limit)
onsets_sz_np = np.array(onsets_sz)

In [None]:
np.where(onsets_sz > seizure_begin_frame+10)

In [None]:
fig = plt.figure(figsize=(12,12))
plt.scatter([i_neuron for i_neuron in range(n_neurons)], sorted(onsets_sz))
plt.show()

In [None]:
onsets_df["onset_sz"] = onsets_sz_np

In [None]:
onsets_df["quantile_sz"] = pd.qcut(onsets_df["onset_sz"], 4, labels=False)

In [None]:
fig = plt.figure(figsize=(16,12))
sns.histplot(data=onsets_df, x="onset_sz")
plt.show()

In [None]:
onsets_df

In [None]:
I_NEURON = 100  # neuron to inspect around the seizure onset

fig = plt.figure(figsize=(14, 14))
# smoothed trace, raw trace, then derivative of the smoothed trace
for series in (lowess_filtered_traces[I_NEURON], temporal[I_NEURON], derivatives_lowess[I_NEURON]):
    plt.plot(lowess_filtered_xvals[I_NEURON], series)
plt.vlines(onsets_sz[I_NEURON], ymin=-500, ymax=2000, color="red", linewidth=2)  # detected seizure onset
# TODO: mismatch in onsets[] as well as in sorted_indices[]! See plot...
plt.vlines([sz_lower_limit, sz_upper_limit], ymin=-500, ymax=2000, color="black", linewidth=2)  # search window
plt.xlim((seizure_begin_frame - 50, seizure_begin_frame + 200))
plt.show()

In [None]:
# Seizure onset as a surface over the FOV.
# Figure.gca() no longer accepts projection kwargs (removed in Matplotlib 3.6); add_subplot creates the 3D axes.
fig = plt.figure(figsize=(12,12))
ax = fig.add_subplot(projection='3d')
surf = ax.plot_trisurf(onsets_df['x'], onsets_df['y'], onsets_df['onset_sz'], cmap=plt.cm.jet, linewidth=0.2)
plt.show()

In [None]:
# neuron centers: color = seizure-onset quartile, marker size = seizure onset
sns.set_theme(style="whitegrid")
f, ax = plt.subplots(figsize=(18, 18))
sns.despine(f, left=True, bottom=True)
seizure_scatter_kwargs = dict(x="x", y="y", hue="quantile_sz", size="onset_sz", sizes=(10, 80), linewidth=0)
sns.scatterplot(data=onsets_df, ax=ax, **seizure_scatter_kwargs)
ax.invert_yaxis()  # invert to match imshow() and in general, nd2 videos: (0, 0) is top left corner
plt.show()

In [None]:
sz_onset_grid = np.zeros(grid_shape, dtype=onsets_sz_np.dtype)

In [None]:
mean_onsets_df = onsets_df.groupby("tile", as_index=False).mean()

In [None]:
for i in range(grid_shape[0]):  # go through cols
    for j in range(grid_shape[1]): # go through rows
        i_tile = i + j*grid_shape[0]
        if i_tile in mean_onsets_df["tile"].values:
            sz_onset_grid[i][j] = mean_onsets_df[mean_onsets_df["tile"] == i_tile]["onset_sz"].values[0]


In [None]:
mean_sz_onsets_pivot = mean_onsets_df.pivot("col", "row", "onset_sz")

In [None]:
# mean seizure onset per grid tile
f, ax = plt.subplots(figsize=(14, 14))
sns.heatmap(data=mean_sz_onsets_pivot, ax=ax, linewidths=0.5, annot=False)
plt.show()

In [None]:
# wave and seizure onsets, neurons ordered by their second-wave onset
fig = plt.figure(figsize=(12, 12))
t_vals = list(range(n_neurons))
neuron_order = onsets_np_sorted_second
for i_wave in range(n_waves):
    plt.scatter(t_vals, onsets_np[neuron_order, i_wave], label=f"Wave {i_wave+1}")
plt.scatter(t_vals, onsets_sz_np[neuron_order], label="Sz")
plt.legend(fontsize=20)
plt.show()

# Plot onset for each neuron

In [None]:
onsets_df

In [None]:
fig = plt.figure(figsize=(18,18))

for i_neuron in range(n_neurons):
    t_onset = onsets_df[onsets_df["neuron_id"] == i_neuron]["onset_sz"].values[0]
    tsteps = [t for t in range(t_onset-100, t_onset+100)]
    plt.plot(tsteps, lowess_filtered_traces[i_neuron][t_onset-100:t_onset+100])
plt.vlines([tsteps[100]], ymin=-2000, ymax=10000)
ax = plt.gca()
plt.show()

# Save to hdf5 file

In [None]:
export_folder, export_fname = os.path.split(whole_traces_h5_fpath)
# assuming file name was xy_traces.h5
export_fname = os.path.splitext(export_fname)[0][:-7] + "_grid.h5"
export_fpath = os.path.join(export_folder, export_fname)
print(f"Saving results to\n\t{export_fpath}")

In [None]:
onsets_df.to_hdf(export_fpath, key="uuid"+session_uuid)

# Get direction as single vector

In [None]:
onsets_df