# Import modules.

In [None]:
%%capture
%load_ext autoreload
%autoreload 2


import sys
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd 
import holoviews as hv
import csv
import math

from dask.distributed import Client, LocalCluster
from holoviews.operation.datashader import datashade, regrid
from holoviews.util import Dynamic
from IPython.core.display import display

# Set the OpenMP / MKL / OpenBLAS thread-count environment variables to "1"
# for this process (and anything that inherits its environment).
os.environ.update(
    {
        "OMP_NUM_THREADS": "1",
        "MKL_NUM_THREADS": "1",
        "OPENBLAS_NUM_THREADS": "1",
    }
)

# Specify data location.

In [None]:
# Basic parameters for the recording analysed by this notebook.

# Identifiers that are combined into the input and output file names below.
mouseID = "A17"
date = "D7"
session = "S4"

# Locations: animal folder -> Miniscope recording folder -> session folder.
mpath = "/N/project/Cortical_Calcium_Image/Miniscope data/06.2022_Second_group/AA017_985237_D7"
ipath = os.path.join(mpath, "2022_06_12/16_38_34/Miniscope_2")
dpath = os.path.join(ipath, session)
behavior = os.path.join(ipath, "timeStamps.csv")

# Minian source location (added to sys.path later) and the saved Minian dataset.
minian_path = "."
minian_ds_path = os.path.join(dpath, "minian")

# Number of dask workers; the MINIAN_NWORKERS environment variable overrides the default of 4.
n_workers = int(os.getenv("MINIAN_NWORKERS", 4))

# Behaviour table, read from "<mouseID>_<date>_behavior_ms.csv" in the animal folder.
behavior_data = pd.read_csv(
    os.path.join(mpath, f"{mouseID}_{date}_behavior_ms.csv"),
    sep=',',
)

In [None]:
%%capture
sys.path.insert(0, minian_path)
from minian.utilities import (
    open_minian,
    TaskAnnotation,
)

In [None]:
# Make the session output directory absolute so later saves do not depend on the working directory.
dpath = os.path.abspath(dpath)
# Initialise the holoviews notebook extension with the bokeh backend (width=75).
hv.notebook_extension("bokeh", width=75)

# Load data

In [None]:
# Start a local dask cluster: `n_workers` workers (see the parameter cell),
# 2 threads and a 5GB memory limit per worker, dashboard served on port 8787.
cluster = LocalCluster(
    n_workers=n_workers,
    memory_limit="5GB",
    resources={"MEM": 1},  # NOTE(review): presumably a custom per-worker resource used by Minian task annotations - confirm
    threads_per_worker=2,
    dashboard_address=":8787",
)
# Register Minian's TaskAnnotation plugin on the scheduler before the client is created.
annt_plugin = TaskAnnotation()
cluster.scheduler.add_plugin(annt_plugin)
# Client connected to the cluster created above.
client = Client(cluster)

## Data structure of Minian outputs
Here is the output of Minian. It is an xarray Dataset that contains coordinate labels for different aspects of the data. The way you typically reference the data is the same as you would a Python dictionary. That is, using data['key']. For example, data['A']. See more here: https://xarray.pydata.org/en/stable/generated/xarray.Dataset.html#xarray.Dataset. 

The most common variables you will be accessing are A, S, and sometimes C. 
data['A'] contains the spatial footprints of the detected neurons. 
data['S'] contains the deconvolved spikes.
data['C'] contains the modeled calcium traces. 

Each of these is arranged such that its first dimension is unit_id. Note that `.sel` selects by unit_id label, not by position: to access the spatial footprint of the neuron labelled 0 you would call data['A'].sel(unit_id=0), for the neuron labelled 1, data['A'].sel(unit_id=1), etc.

In [None]:
# print(minian_ds_path)
# Open the saved Minian output stored under `minian_ds_path`.
data = open_minian(minian_ds_path)
# Display the dataset summary as the cell output.
data

In [None]:
# Display C selected for every unit_id label present in the dataset.
data['C'].sel(unit_id=data['unit_id'])

# Store/open data in different formats

## Converting the data into other formats is relatively simple. See below for examples.

In [None]:
# Matlab-readable: write the whole dataset to a netCDF file in the session folder.
data.to_netcdf(os.path.join(dpath, "minian_dataset.nc"))

In [None]:
# Pandas DataFrame: convert the C array (modeled calcium traces) to pandas.
df = data['C'].to_pandas()
# Display it as the cell output.
df

In [None]:
# Write the C table to "minian_<mouseID><date><session>_C.csv" in the session folder.
df.to_csv(
    os.path.join(dpath, f"minian_{mouseID}{date}{session}_C.csv")
)

In [None]:
# Convert the S array (deconvolved spikes) to pandas and display it.
df_S = data['S'].to_pandas()
df_S

In [None]:
# Write the S table to "minian_<mouseID><date><session>_S.csv" in the session folder.
df_S.to_csv(
    os.path.join(dpath, f"minian_{mouseID}{date}{session}_S.csv")
)

# Plotting data

## Calcium traces and spikes
Below are different functions for plotting the modeled traces and spikes for individual neurons and groups of neurons. The features of these plots can be modified by the user.

In [None]:
# # Plot individual traces.
# def plot_trace(data, neuron):
#     fig, C_ax = plt.subplots(figsize=(12,6))
#     S_ax = C_ax.twinx()
#     C_ax.plot(data['C'].sel(unit_id=neuron), color='royalblue')
#     C_ax.set_ylabel("C [modeled activity, A.U.]", color='royalblue')
#     C_ax.set_xlabel('Frame #')
#     C_ax.set_title(f'Neuron #{neuron}')

#     S_ax.plot(data['S'].sel(unit_id=neuron), color='r', alpha=0.5)
#     S_ax.set_ylabel("S [modeled activity, A.U.]", color='r')

In [None]:
# neuron_number = 0
# plot_trace(data, neuron_number)

In [None]:
# Display the behaviour table loaded in the parameter cell.
behavior_data

In [None]:
# Plot traces from multiple neurons
def plot_multiple_traces(data, neurons_to_plot, behavior_data, start_frame, session,shift_amount=0.4):
    """Plot peak-normalised C traces of several neurons, stacked vertically, against time.

    The plotted window depends on ``session``: for ``"S1"`` it is the behaviour rows
    with a time stamp of at most 900000 ms (the first 15 min); for any other session
    it is the rows from 2700000 ms (45 min) onward, paired with the last
    ``len(window)`` frames of ``data``.  Behavioural events are drawn as vertical
    lines: 'reinforcement' in green, 'IALP' in blue, 'ALP' in translucent red.
    The figure is saved as "<mouseID>_<date>_<session>_trace_ms.pdf" in ``dpath``
    (``mouseID``, ``date`` and ``dpath`` are notebook globals).

    Parameters
    ----------
    data : Minian dataset; 'C' is read, indexed by 'unit_id' and 'frame'.
    neurons_to_plot : iterable of unit_id labels, drawn bottom to top.
    behavior_data : DataFrame with the columns 'Time Stamp (ms)', 'reinforcement',
        'IALP' and 'ALP'.
    start_frame : unused; kept so existing calls keep working.
    session : session label; only ``"S1"`` is treated specially.
    shift_amount : vertical offset between consecutive traces.

    Notes
    -----
    The number of selected frames has to equal the number of selected behaviour
    rows, otherwise matplotlib rejects the x/y pair.
    """
    time_col = 'Time Stamp (ms)'
    neurons_to_plot = list(neurons_to_plot)
    shifts = [shift_amount * i for i in range(len(neurons_to_plot))]

    if session == "S1":
        fifteen = behavior_data.loc[behavior_data[time_col] <= 900000]
        neuron_data = data.sel(frame=slice(0, len(fifteen) - 1))
    else:
        fifteen = behavior_data.loc[behavior_data[time_col] >= 2700000]
        # Frame labels need not start at 0, so offset by the first label before
        # taking the last len(fifteen) frames (label slices are inclusive).
        n_frames = data.sizes['frame']
        first_frame = data['frame'][0].values
        begin = n_frames - len(fifteen) + first_frame
        end = n_frames + first_frame - 1
        neuron_data = data.sel(frame=slice(begin, end))

    times = fifteen[time_col]
    # Figure height grows with the number of stacked traces (empty list -> base height).
    y = (shifts[-1] if shifts else 0) + 10
    fig, ax = plt.subplots(figsize=(30.4, y))

    for shift, neuron in zip(shifts, neurons_to_plot):
        # Load the trace once so the peak and the plot use the same values.
        trace = neuron_data['C'].sel(unit_id=neuron).compute()
        peak = float(trace.max())
        # Normalise out of place: an in-place `/=` on a selection can write through
        # to the dataset.  An all-zero (or non-finite) trace is left unscaled
        # instead of turning into NaN.
        if np.isfinite(peak) and peak != 0:
            trace = trace / peak
        # x in axes coordinates, y in data coordinates: the label sits just left of
        # the axes whatever the time range is.  (A data x of -1 is far off the
        # figure when the time stamps start at 2700000 ms.)
        ax.text(-0.002, shift, str(neuron), transform=ax.get_yaxis_transform(), ha='right')
        ax.plot(times, trace + shift)

    # Event markers are independent of the neuron: draw them once so the
    # translucent 'ALP' lines are not stacked into opaque ones.
    event_styles = (
        ('reinforcement', y - 8, "green", None),
        ('IALP', y - 9, "blue", None),
        ('ALP', y - 9, "red", 0.5),
    )
    for column, top, colour, alpha in event_styles:
        event_times = fifteen.loc[fifteen[column] == 1][time_col]
        ax.vlines(event_times, 0, top, color=colour, alpha=alpha)

    ax.set_xlabel('Time Stamp (ms)')
    ax.set_ylabel('Neurons')
    ax.set_yticks([])
    fig.savefig(os.path.join(dpath, mouseID+'_'+date+'_'+session+"_trace_ms.pdf"))

In [None]:
# Plot every detected unit: collect all unit_id labels and pass them on
# (the 0 is the `start_frame` argument).
all_neurons=list(np.array(data['unit_id']))
plot_multiple_traces(data, all_neurons,behavior_data,0,session)

## Spatial footprints
The same applies to spatial footprints. The functions can be customized by us or by the user.

In [None]:
# Plot spatial footprints
def plot_footprints(data, neurons=None):
    """Show the summed spatial footprints of the selected neurons as one image.

    Parameters
    ----------
    data : Minian dataset; 'A' is read and indexed by 'unit_id'.
    neurons : a single unit_id (Python or NumPy integer), an iterable of unit_ids,
        or None to use every unit in ``data['A']``.
    """
    footprints = data['A']
    if neurons is None:
        neurons = footprints['unit_id'].values
    elif isinstance(neurons, (int, np.integer)):
        # Also accept NumPy integers, e.g. a value taken from the unit_id coordinate.
        neurons = [neurons]

    # Size the canvas from the array's own dimensions rather than from the unit
    # labelled 0, which is not guaranteed to exist.
    image_shape = tuple(
        size for dim, size in footprints.sizes.items() if dim != 'unit_id'
    )
    fov = np.zeros(image_shape, dtype=footprints.dtype)
    for neuron in neurons:
        fov += np.asarray(footprints.sel(unit_id=neuron))

    fig, ax = plt.subplots(figsize=(12,12))
    ax.imshow(fov, origin='lower')
    ax.axis('image')

In [None]:
# plot_footprints(data)

In [None]:
# plot_footprints(data, good_looking_neurons)

# Basic analysis
Here are some basic descriptive statistics of the detected neurons. 

## Mean amplitudes
It is a good idea to use the S matrix, rather than the C matrix. This is because the S matrix contains deconvolved spikes whereas the C matrix is the modeled calcium trace, which includes the decay portion of the calcium signal. 

In [None]:
# Histogram of mean amplitudes of all detected neurons.
def histogram_mean_amplitude(data, nbins=50):
    """Draw a histogram of the per-neuron mean of S and return those means.

    Parameters
    ----------
    data : Minian dataset; ``data['S']`` is averaged over its 'frame' dimension.
    nbins : number of histogram bins.

    Returns
    -------
    numpy.ndarray with the averaged values.
    """
    per_neuron_mean = data['S'].mean(dim='frame')

    _, axis = plt.subplots(figsize=(12, 6))
    axis.hist(per_neuron_mean, bins=nbins)
    axis.set_xlabel('Mean amplitudes [A.U.]')
    axis.set_ylabel('# of neurons')

    return np.asarray(per_neuron_mean)

In [None]:
# Draw the histogram (50 bins) and keep the per-neuron means it returns.
mean_amplitudes = histogram_mean_amplitude(data, nbins=50)

In [None]:
# Display the array of per-neuron mean amplitudes.
mean_amplitudes

## Event frequencies
It is even more important to use the S matrix here because the C matrix will include time bins where the calcium signal is decaying.

In [None]:
# Histogram of event frequencies of all detected neurons.
def histogram_event_freq(data, nbins=50):
    """Draw a histogram of the fraction of positive S samples per row of S.

    Parameters
    ----------
    data : Minian dataset; ``data['S']`` is converted to a NumPy array and
        iterated along its first axis (one row per neuron, per the notebook text).
    nbins : number of histogram bins.

    Returns
    -------
    numpy.ndarray with one fraction per row.
    """
    spikes = np.asarray(data['S'])

    event_freq = []
    for row in spikes:
        active_samples = np.sum(row > 0)
        event_freq.append(active_samples / len(row))

    _, axis = plt.subplots(figsize=(12, 6))
    axis.hist(event_freq, bins=nbins)
    axis.set_xlabel('Event frequency [proportion of active frames]')
    axis.set_ylabel('# of neurons')

    return np.asarray(event_freq)

In [None]:
# Draw the histogram with the default bin count and keep the returned fractions.
event_freq = histogram_event_freq(data)

In [None]:
# Display the array of per-neuron event frequencies.
event_freq

# Digital Dataset(DDS)


## AUC of S per minute

### merge file

In [None]:
# S as a DataFrame, transposed relative to what to_pandas() returns
# (presumably one row per frame and one column per unit_id - confirm against the output).
S_df = data['S'].to_pandas().T
S_df

In [None]:
# S_df[2] # unit_id: int

In [None]:
# Attach the behaviour columns to the S table.
if session == "S1":
    # Match each S row (by index) to the behaviour row with the same 'Frame Number'.
    new_pd = pd.merge(
        S_df,
        behavior_data,
        how='left',
        left_index=True,
        right_on='Frame Number',
    )
else:
    # Pair the S rows with the last len(S_df) behaviour rows by position:
    # both tables get a fresh 0..n-1 index and are glued side by side.
    temp_pd = behavior_data.tail(len(S_df)).reset_index(drop=True)
    S_df = S_df.reset_index(drop=True)
    new_pd = pd.concat([S_df, temp_pd], axis=1)

In [None]:
# Display the combined S + behaviour table.
new_pd

###  Check frame interval

In [None]:
# Show the rows whose time stamp is more than 100 ms after the previous row's.
new_pd.loc[new_pd['Time Stamp (ms)'].diff()>100]

## AUC of S per min

In [None]:
# Assign every row to its minute: time stamp divided by 60000 ms, rounded down.
# (math.floor raises on NaN, so every row needs a time stamp.)
new_pd['min']=new_pd['Time Stamp (ms)'].map(lambda x: math.floor(x/60000))
# pandas display option: row count used when a long frame is shown truncated.
pd.set_option('display.min_rows',50)
# Sum every column within each minute; for the unit columns this is the summed S per minute.
res=new_pd.groupby('min').sum()
# Display the table including the new 'min' column.
new_pd

In [None]:
# Drop the summed bookkeeping columns and write the per-minute table to
# "<mouseID>_<date>_<session>_AUC_of_S_per_minute.csv" in the session folder.
res.drop(
    columns=['Frame Number', 'Time Stamp (ms)', 'Buffer Index']
).to_csv(
    os.path.join(dpath, f"{mouseID}_{date}_{session}_AUC_of_S_per_minute.csv")
)

In [None]:
# Display the per-minute sums.
res

## AUC of S per Event

In [None]:
column_name = new_pd.columns.values  #get column label
# Print all labels, then all but the last 7.
# NOTE(review): the cells below treat everything except the last 7 columns as unit
# columns; presumably the 7 are the behaviour columns plus 'min' - confirm when the
# behaviour file layout changes.
print(column_name)
print(column_name[:-7])

In [None]:
# Number the S events of every unit.
#
# An "event" is a run of consecutive rows in which the unit's S value is non-zero.
# For each unit column a "<unit>_mark" column is inserted right after it: rows
# inside the k-th event hold k (1, 2, 3, ...), rows between events hold 0.
# The labelling is done with array operations by row position rather than by
# writing single cells through `.loc` inside `iterrows`, which is far slower and
# touches several rows at once when index labels repeat.
n_units = column_name.size - 7  # every column except the last 7 is a unit column

# Kept from the original loop: after the cell has run, mark[i] is the number of
# events of unit i and flag[i] is 1 if unit i is still active in the last row.
mark = np.zeros(n_units)
flag = np.zeros(n_units)

for i, unit in enumerate(column_name[:n_units]):
    active = (new_pd[unit] != 0).to_numpy()

    # An event starts where a row is active and the previous row is not
    # (the row before the first one counts as inactive).
    previously_active = np.zeros_like(active)
    previously_active[1:] = active[:-1]
    onset = active & ~previously_active

    # Running count of onsets inside events, 0 everywhere else.
    event_id = np.where(active, np.cumsum(onset), 0)

    new_pd.insert(new_pd.columns.get_loc(unit) + 1, str(unit) + '_mark', event_id)

    if active.size:
        mark[i] = event_id.max()
        flag[i] = active[-1]



In [None]:
# Display the table with the added "<unit>_mark" columns.
new_pd

In [None]:
# # new_pd=new_pd.drop(columns=['124_mark','123_mark','122_mark','121_mark','120_mark','119_mark','118_mark','117_mark','116_mark'])
# new_pd.to_csv(os.path.join(dpath, mouseID+'_'+date+'_'+session+"test.csv"))
# print(column_name)
# Print the current column labels to check the inserted "<unit>_mark" columns.
print(new_pd.columns.values)

In [None]:
# AUC of S per event.
#
# For every unit, add a "<unit>_sum" column right after its "<unit>_mark" column:
# each row holds the summed S of the event that row belongs to (rows between
# events, mark 0, get the sum of their own values).
# groupby(...).transform('sum') writes the group total onto every row of the
# group directly, replacing the earlier groupby + merge + placeholder column +
# DataFrame.update sequence.
for unit in column_name[:column_name.size - 7]:  # unit columns: all but the last 7
    mark_name = str(unit) + '_mark'
    event_sum = new_pd[unit].groupby(new_pd[mark_name]).transform('sum')
    new_pd.insert(
        new_pd.columns.get_loc(mark_name) + 1,
        str(unit) + '_sum',
        event_sum.to_numpy(),  # positional insert: rows are already in new_pd order
    )

# Time elapsed since the previous row, placed right after the time stamp column.
new_pd.insert(
    new_pd.columns.get_loc('Time Stamp (ms)') + 1,
    'Time Interval (ms)',
    new_pd['Time Stamp (ms)'].diff(),
)
new_pd.to_csv(os.path.join(dpath, mouseID+'_'+date+'_'+session+"_AUC_of_S_per_event.csv"))

In [None]:
# print(new_pd[2].groupby(new_pd['2_mark']).count())

## Frequency

In [None]:
# Number of S events that start in each minute, per unit.
df_min = new_pd['min']

# The per-unit event-number columns (unit columns are all but the last 7 labels).
col = [str(i) + '_mark' for i in column_name[:-7]]
marks = new_pd[col]

# Step of the event number from one row to the next.  The row before the first
# one counts as 0, so an event that is already running in the first row is
# counted too (a plain diff() leaves NaN there and would drop it).
frequency_df = marks - marks.shift(1, fill_value=0)
# Negative steps are event ends (the number falls back to 0), not onsets.
frequency_df = frequency_df.clip(lower=0)

# Count the onsets (positive steps) within every minute.
frequency_df2 = (frequency_df > 0).groupby(df_min).sum()

# Turn "<unit>_mark" back into the integer unit id by cutting off the literal
# suffix.  str.rstrip('_mark') strips a *set of characters*, not a suffix, and
# only works as long as no id ends in one of those characters.
suffix = '_mark'
frequency_df2.columns = pd.Index(
    [name[:-len(suffix)] for name in frequency_df2.columns]
).astype('int64')

frequency_df2.to_csv(os.path.join(dpath, mouseID+'_'+date+'_'+session+"_Number_of_S_events_per_min.csv"))
frequency_df2.columns.values

# Frame check