In [None]:
import json
from pathlib import Path

import matplotlib
%matplotlib widget
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
import flammkuchen as fl
from scipy.signal import detrend
from skimage.filters import threshold_otsu

from bouter.utilities import reliability
from split_dataset import SplitDataset

In [None]:
# Select one fish directory from the experiment folder and read its ID from
# the metadata file (the ID is used in output file names further down).
master = Path(r"\\Funes\Shared\experiments\E0040_motions_cardinal\v13_cw_ccw\2p\rf\fixed")
FISH_INDEX = 5  # position of the fish to analyse in the *sorted* directory list

# Path.glob yields entries in filesystem order, which is not guaranteed;
# sorting makes FISH_INDEX pick the same fish on every run/machine.
fish_list = sorted(master.glob("*f[0-9]*"))
fish_dir = fish_list[FISH_INDEX]

# Fall back to an empty id when the metadata is missing or malformed, but say
# why instead of silently swallowing every error (including NameErrors).
try:
    metadata_path = next(fish_dir.glob("*metadata.json"))
    with open(metadata_path) as metadata_file:
        metadata = json.load(metadata_file)
    fish_id = str(metadata['general']['fish_id'])
except (StopIteration, OSError, KeyError, TypeError, json.JSONDecodeError) as err:
    print(f"Could not read fish_id from metadata ({err!r}); using an empty id.")
    fish_id = ""

print(fish_dir)
print(fish_id)

In [None]:
# Load the ROI traces; rows are ROIs, columns are frames.
# NOTE(review): the last two frames are dropped — presumably incomplete
# trailing frames; confirm.
traces = fl.load(fish_dir / "traces.h5")['traces'][:, :-2]
fs = 3  # frames per second; t below is frame index / fs
N_REPS = 3  # number of stimulus repetitions in the recording (unrelated to fs)

num_traces, len_rec = traces.shape
t = np.arange(len_rec) / fs  # time axis of the full recording
print("num_traces: ", num_traces)
print("len_rec: ", len_rec)
# Frames per repetition; leftover frames (len_rec % N_REPS) are ignored downstream.
new_len_rec = len_rec // N_REPS

In [None]:
# Z-score every ROI independently: subtract that ROI's mean over time and
# divide by its standard deviation over time (NaN-aware).
# Statistics are taken along axis=1 (frames); without `axis`, np.nanmean /
# np.nanstd reduce over *all* elements and every ROI would share one scale.
roi_mean = np.nanmean(traces, axis=1, keepdims=True)
roi_sd = np.nanstd(traces, axis=1, keepdims=True)
# Constant traces (sd == 0) would divide 0/0; keep them at 0 instead.
roi_sd = np.where(roi_sd > 0, roi_sd, 1.0)
norm_traces = (traces - roi_mean) / roi_sd

In [None]:
# Replace NaNs with 0 (the mean of the z-scored data); NaNs would otherwise
# contaminate detrend's least-squares line fit for the whole ROI.
norm_traces[np.isnan(norm_traces)] = 0
# Remove a linear trend from every ROI along the frame axis.
corrected_traces = detrend(norm_traces, axis=1, overwrite_data=False, type='linear')

# Side-by-side heatmaps (ROI x frame) before and after detrending.
fig, ax = plt.subplots(2, 1, figsize=(10, 15))
for panel, data, title in zip(
    ax,
    (norm_traces, corrected_traces),
    ("z-scored traces", "z-scored + linearly detrended traces"),
):
    panel.imshow(data, aspect="auto")
    panel.set_title(title)
    panel.set_xlabel("frame")
    panel.set_ylabel("ROI")

In [None]:
# Split the recording into consecutive, equally long repetitions.
# Resulting arrays have shape (n_rois, n_blocks, frames_per_block);
# trial_traces holds the raw traces, trial_traces_corrected the detrended ones.
n_blocks = 3  # number of stimulus repetitions in the recording
n_used = n_blocks * new_len_rec  # frames beyond the last whole block are dropped

# C-order reshape: [roi, b, j] == traces[roi, b * new_len_rec + j].
# np.array(..., dtype=float) copies, so the trial arrays never alias the inputs.
trial_traces = np.array(
    traces[:, :n_used].reshape(num_traces, n_blocks, new_len_rec), dtype=float
)
trial_traces_corrected = np.array(
    corrected_traces[:, :n_used].reshape(num_traces, n_blocks, new_len_rec), dtype=float
)

# Mean over repetitions of the detrended traces, shape (n_rois, frames_per_block).
avg_traces = np.nanmean(trial_traces_corrected, axis=1)
print(np.shape(trial_traces))

In [None]:
n_blocks = trial_traces.shape[1]  # number of stimulus repetitions
dt = 1 / fs  # frame interval; derived from fs instead of the approximate 0.33


def _block_reliability(trial_arr, frame_dt):
    """Label trial-split traces and compute per-ROI reliability + Otsu threshold.

    Parameters
    ----------
    trial_arr : np.ndarray
        Array of shape (n_rois, n_blocks, n_frames), as built by the split cell.
    frame_dt : float
        Time between frames, used for the 't' coordinate.

    Returns
    -------
    arr_xr : xr.DataArray
        `trial_arr` with dims ('roi', 'block', 't').
    rel : array-like
        Output of `bouter.utilities.reliability` on the (t, block, roi) array.
    thresh : float
        Otsu threshold over the finite reliability values.
    thresh_3 : float
        `thresh` rounded to 3 decimals (used in figure titles).
    """
    arr_xr = xr.DataArray(
        data=trial_arr,
        dims=['roi', 'block', 't'],
        coords={
            'roi': np.arange(trial_arr.shape[0]),
            'block': np.arange(trial_arr.shape[1]),
            't': np.arange(trial_arr.shape[2]) * frame_dt,
        },
    )
    # Explicit (t, block, roi) ordering — the same array np.swapaxes(..., 0, 2)
    # produced, without routing a DataArray through numpy's wrapping machinery.
    # NOTE(review): axis order expected by reliability() assumed from the
    # original call; confirm against bouter's docs.
    rel = reliability(arr_xr.transpose('t', 'block', 'roi').values)
    # threshold_otsu cannot histogram NaN/inf; ignore ROIs with undefined reliability.
    rel_values = np.asarray(rel)
    thresh = threshold_otsu(rel_values[np.isfinite(rel_values)])
    return arr_xr, rel, thresh, np.round(thresh, 3)


traces_xr, reliability_arr, rel_thresh, rel_thresh_3 = _block_reliability(trial_traces, dt)
print("Reliability threshold: ", rel_thresh)

traces_xr_det, reliability_arr_det, rel_thresh_det, rel_thresh_3_det = _block_reliability(
    trial_traces_corrected, dt
)
print("Reliability threshold (detrended): ", rel_thresh_det)
print(np.shape(reliability_arr_det))

In [None]:
# Visualise the reliability distributions of raw vs detrended traces; the red
# dashed line marks each Otsu threshold.
fig, ax = plt.subplots(1, 2, figsize=(8, 5))
panels = (
    (reliability_arr, rel_thresh,
     "Reliability threshold: " + str(rel_thresh_3)),
    (reliability_arr_det, rel_thresh_det,
     "Detrend, Reliability threshold: " + str(rel_thresh_3_det)),
)
for panel, (rel_values, thresh, title) in zip(ax, panels):
    panel.hist(rel_values, bins=50, density=True)
    panel.axvline(thresh, c='red', ls='--')
    panel.set_xlim([-1, 1])
    panel.set_xlabel('Average correlation between reps')
    panel.set_ylabel('Density')
    panel.set_title(title)
# Lay out once, after both panels are populated.
fig.tight_layout()

file_name = "reliability index " + fish_id
fig.savefig(str(fish_dir / file_name))

In [None]:
# Load the merged ROI label stack; its first axis indexes the imaging planes.
roi_map = fl.load(fish_dir / "merged_rois.h5")['stack']
roi_shape = np.shape(roi_map)
num_planes = roi_shape[0]

print("num planes:", num_planes)
print("num traces:", num_traces)
print(roi_shape)


In [None]:
# Paint each ROI's detrended reliability into the label stack so it can be
# shown anatomically. Label k (1 <= k <= num_traces) gets reliability_arr_det[k - 1];
# all other voxels (labels < 1 or > num_traces) stay 0.
# One vectorised lookup replaces a full-stack np.where scan per ROI.
roi_map_rel = np.zeros_like(roi_map, dtype=float)
roi_rel_values = np.asarray(reliability_arr_det, dtype=float)
labelled = (roi_map >= 1) & (roi_map <= num_traces)
# assumes integer-valued labels (as the original `== i + 1` matching implies) — TODO confirm
roi_map_rel[labelled] = roi_rel_values[np.asarray(roi_map[labelled], dtype=int) - 1]
print(np.unique(roi_map_rel))


In [None]:
# Show the reliability map plane by plane, N_COLS planes per row.
# Voxels whose label is < 1 are masked out; colours span reliability -1..1.
N_COLS = 4
n_rows = max(1, int(np.ceil(num_planes / N_COLS)))
# squeeze=False keeps `ax` 2-D even for a single row; the grid grows with
# num_planes instead of failing with IndexError beyond 16 planes.
fig, ax = plt.subplots(n_rows, N_COLS, figsize=(12, 3 * n_rows), squeeze=False)

im = None
for i, panel in enumerate(ax.flat):
    panel.axis('off')
    if i >= num_planes:
        continue  # unused grid cell
    roi_layer = np.ma.masked_where(roi_map[i] < 1, roi_map_rel[i])
    # axes=(1, 0) rotates the plane 90 degrees clockwise for display.
    roi_layer = np.rot90(roi_layer, k=1, axes=(1, 0))
    im = panel.imshow(roi_layer, cmap="coolwarm", vmin=-1, vmax=1)
    panel.set_title("z" + str(i))

if im is not None:
    fig.colorbar(im, ax=ax.ravel().tolist(), shrink=0.6, label="reliability")

plt.show()
file_name = 'rois_reliability_index_' + fish_id + '.jpg'
fig.savefig(str(fish_dir / file_name), dpi=300)