# Read CBC data and create targets

In [None]:
import pandas as pd
import numpy as np
import h5py
import scipy.io as sio
from matplotlib import pyplot as plt

In [None]:
import sys
import os

In [None]:
# Absolute location of the shared 'pythoncode' folder next to this directory.
pythoncodepath = os.path.abspath(os.path.join('..', 'pythoncode'))
# Put it first on the module search path so the project modules imported
# below are resolved from there.
sys.path.insert(0, pythoncodepath)

import importhelper
importhelper.addfolders2path(pythoncodepath)

import data_utils
import math_utils

# Meaning of indexes used in DataSet

In [None]:
# Lookup table: cluster index used in the data set -> cell type label.
# The labels are listed in index order, starting at 0.
cell_types = dict(enumerate([
    'NO',
    'BC1', 'BC2', 'BC3a', 'BC3b', 'BC4',
    'BC5t', 'BC5o', 'BC5i', 'BCX', 'BC6',
    'BC7', 'BC8', 'BC9', 'BCR',
]))

In [None]:
# Lookup table: drug index used in the data set -> drug label.
# The labels are listed in index order, starting at 0.
drug_types = dict(enumerate([
    "no",
    'Gbz + TPMPA',
    'Strychnine',
    'Gbz',
    'TPMPA',
    'lAP4',
]))

# Read data submission 1

## Load no drug data

In [None]:
# Location of the submission-1 recordings without drug application.
sub1_no_drug_file = os.path.join(
    '..', 'experimental_data', 'data_iGluSnFR', 'cbc_data',
    'submission1', 'FrankeEtAl_BCs_2017_v1.mat',
)

# Everything needed is copied into NumPy arrays inside the `with` block,
# because the file handle is closed once the block is left.
with h5py.File(sub1_no_drug_file, 'r') as sub1_NoDrugsdata_raw:
    # Time axis of the chirp responses and the response traces themselves.
    release_time = np.array(sub1_NoDrugsdata_raw['chirp_time']).flatten()
    sub1_no_drug_lchirp_traces = np.array(sub1_NoDrugsdata_raw['lchirp_avg'])

    # Cluster index per trace. Non-finite entries (NaN) and 0 both mean
    # "no cluster", so map the non-finite ones onto 0.
    sub1_cluster_idx = np.array(sub1_NoDrugsdata_raw['cluster_idx']).flatten()
    sub1_cluster_idx[~np.isfinite(sub1_cluster_idx)] = 0

    # Stimulus as stored in the file, as a table with 'Time' and 'Stim'.
    sub1_stim = pd.DataFrame({
        'Time': np.array(sub1_NoDrugsdata_raw['chirp_stim_time']).flatten(),
        'Stim': np.array(sub1_NoDrugsdata_raw['chirp_stim']).flatten(),
    })

# Short summary: total number of traces and fraction with a cluster index > 0.
print('Number of traces: {}'.format(sub1_cluster_idx.size))
print('Percentage that is clusterd: {:.1%}'.format(
    (sub1_cluster_idx > 0).sum() / sub1_cluster_idx.size))

## Load drug data

In [None]:
# Location of the submission-1 recordings with drug application.
sub1_drug_file = os.path.join(
    '..', 'experimental_data', 'data_iGluSnFR', 'cbc_data',
    'submission1', 'FrankeEtAl_BCs_2017_drugdata.mat',
)

# This file is read with scipy.io.loadmat (the no-drug file uses h5py).
sub1_Drugsdata_raw = sio.loadmat(sub1_drug_file)

# The traces are transposed on load; the merge step further below checks
# that their second dimension matches the submission-2 drug traces.
sub1_drug_lchirp_traces = np.transpose(sub1_Drugsdata_raw['lchirp_drug_avg'])

# Drug index per trace; non-finite entries (NaN) are mapped onto 0.
sub1_drug_idxs = sub1_Drugsdata_raw['drug'].flatten()
sub1_drug_idxs[~np.isfinite(sub1_drug_idxs)] = 0

# The raw dictionary is no longer needed once the arrays are extracted.
del sub1_Drugsdata_raw

# Read data submission 2

In [None]:
# Location of the submission-2 (additional) data: folder first, then file.
sub2_file = os.path.join(
    os.path.join('..', 'experimental_data', 'data_iGluSnFR', 'cbc_data', 'submission2'),
    'Franke2017_additional_data.hdf5',
)

In [None]:
# Copy the four submission-2 data sets into NumPy arrays while the file is
# open. The data sets are read in the order in which the keys are listed.
with h5py.File(sub2_file, 'r') as sub2_raw:
    (
        sub2_cluster_idx,
        sub2_drug_idxs,
        sub2_no_drug_lchirp_traces,
        sub2_strychnine_lchirp_traces,
    ) = (
        np.array(sub2_raw[key])
        for key in (
            'cluster_idx',
            'drug_idx',
            'no_drug_lchirp_traces',
            'strychnine_lchirp_traces',
        )
    )

# Merge submission 1 and 2

In [None]:
# Stacking along the first axis requires both submissions to agree in
# their second dimension.
assert sub1_no_drug_lchirp_traces.shape[1] == sub2_no_drug_lchirp_traces.shape[1]
no_drug_lchirp_traces = np.concatenate(
    (sub1_no_drug_lchirp_traces, sub2_no_drug_lchirp_traces), axis=0)
print(no_drug_lchirp_traces.shape)

In [None]:
# Stacking along the first axis requires both submissions to agree in
# their second dimension. Submission 2 contributes its strychnine traces.
assert sub1_drug_lchirp_traces.shape[1] == sub2_strychnine_lchirp_traces.shape[1]
drug_lchirp_traces = np.concatenate(
    (sub1_drug_lchirp_traces, sub2_strychnine_lchirp_traces), axis=0)
print(drug_lchirp_traces.shape)

In [None]:
# Drug index per trace: submission 1 first, then submission 2, in the same
# order as the stacked trace arrays.
drug_idxs = np.concatenate((sub1_drug_idxs, sub2_drug_idxs), axis=0)
print(drug_idxs.shape)

In [None]:
# Cluster index per trace: submission 1 first, then submission 2, in the
# same order as the stacked trace arrays.
cluster_idx = np.concatenate((sub1_cluster_idx, sub2_cluster_idx), axis=0)
print(cluster_idx.shape)

# Load stimuli

In [None]:
import pandas as pd

# Get recorded stimulus (noisy)
# Recorded stimulus (noisy), shipped with the experimental data.
stim_rec_file = os.path.join(
    '..', 'experimental_data', 'data_iGluSnFR', 'cbc_data', 'Franke2017_recorded_stimulus.csv')
stim_rec = pd.read_csv(stim_rec_file)

# Corrected stimulus, read from the folder of preprocessed data.
stim_corrected_file = os.path.join(
    'data_preprocessed', 'Franke2017_stimulus_time_and_amp_corrected.csv')
stim_corrected = pd.read_csv(stim_corrected_file)

# Plot cluster distribution

In [None]:
# Horizontal histogram of the cluster indexes, with one labelled tick per
# known cell type.
plt.figure(num=1, figsize=(12, 3))
ax = plt.subplot(111)

ax.hist(cluster_idx, bins=50, orientation='horizontal')

# Ticks sit at the cluster indexes and show the matching cell type label.
ax.set_yticks(list(cell_types))
ax.set_yticklabels(list(cell_types.values()))
ax.set_xlabel('Num. of traces')

plt.tight_layout()
plt.show()

# Plot drug data distribution.

In [None]:
# One horizontal bar per drug index that occurs in the data; the bar length
# is the number of traces, shown on a logarithmic x-axis.
plt.figure(num=1, figsize=(12, 2))
ax = plt.subplot(111)

ax.semilogx()
found_drug_idxs, counts = np.unique(drug_idxs, return_counts=True)

for drug_idx, count in zip(found_drug_idxs, counts):
    ax.plot([0, count], [drug_idx, drug_idx], c='k', lw=4)

# Label every bar with the name of its drug.
ax.set_yticks(found_drug_idxs)
ax.set_yticklabels([drug_types[idx] for idx in found_drug_idxs])

# The lower x-limit is fixed at 100; the upper limit is left automatic.
ax.set_xlim(100, None)
ax.set_xlabel('Num. of traces')

plt.tight_layout()
plt.show()

# Plot release for CBC clusters

In [None]:
def plot_mean_and_std(ax, plot_t, plot_mean, plot_std):
    '''Draw a mean trace together with a shaded band of one standard deviation.

    Parameters:

    ax : object providing `plot` and `fill_between` (e.g. a matplotlib axis)
        Axis to draw on.

    plot_t : array-like
        x-values.

    plot_mean : array-like
        Mean trace, drawn as a black line.

    plot_std : array-like
        Standard deviation; the band spans mean - std to mean + std.

    '''
    # Mean trace; clipping is disabled so the line may extend past the axis.
    ax.plot(plot_t, plot_mean, 'k-', clip_on=False)

    # Semi-transparent red band between the lower and upper bound.
    lower_bound = plot_mean - plot_std
    upper_bound = plot_mean + plot_std
    ax.fill_between(plot_t, lower_bound, upper_bound, color='r', alpha=0.3, clip_on=False)

In [None]:
def _first_index_at_or_after(values, threshold):
    '''Return the position of the first entry of `values` that is >= `threshold`.

    If no entry qualifies, len(values) is returned, so the result can always
    be used as a slice bound.
    '''
    hits = np.where(np.asarray(values) >= threshold)[0]
    return int(hits[0]) if hits.size > 0 else len(values)


def plot_release(release_df, title, stims=None, zoom_windows=((1.9, 2.2), (4.8, 5.3))):
    '''Plot release data.
    Creates a figure with one wide panel showing the full trace, followed by
    one narrow panel per zoom window.

    Parameters:

    release_df : DataFrame with columns 'Time', 'mean' and 'std'
        Release data to plot.

    title : str
        Title of Data

    stims : DataFrame or list of DataFrames with columns 'Time' and 'Stim'
        Stimulus or stimuli to be plotted into the zoom panels, each on a
        secondary y-axis after being passed through math_utils.normalize.
        If None, no stimulus is plotted.

    zoom_windows : sequence of (t0, t1) pairs
        Time windows shown in the zoom panels. Each window covers the samples
        from the first one with Time >= t0 up to, but excluding, the first
        one with Time >= t1. A bound beyond the last sample is clamped to the
        end of the data.

    '''

    # Get data.
    plot_t = release_df['Time']
    plot_mean = release_df['mean']
    plot_std = release_df['std']

    # Accept a single stimulus as well as a collection of stimuli.
    if stims is None:
        stim_list = []
    elif isinstance(stims, (list, tuple)):
        stim_list = list(stims)
    else:
        stim_list = [stims]

    # Layout: the overview panel is three columns wide, each zoom panel one.
    zoom_windows = list(zoom_windows)
    n_main_cols = 3
    grid_shape = (1, n_main_cols + len(zoom_windows))

    # Plot.
    plt.figure(figsize=(15,2))
    main_ax = plt.subplot2grid(grid_shape, (0, 0), colspan=n_main_cols)
    zoom_axes = [
        plt.subplot2grid(grid_shape, (0, n_main_cols + i_win), colspan=1)
        for i_win in range(len(zoom_windows))
    ]

    main_ax.set_title(title, loc='left')
    plot_mean_and_std(main_ax, plot_t, plot_mean, plot_std)

    for (t0, t1), ax in zip(zoom_windows, zoom_axes):

        idx0 = _first_index_at_or_after(plot_t, t0)
        idx1 = _first_index_at_or_after(plot_t, t1)

        plot_mean_and_std(ax, plot_t[idx0:idx1], plot_mean[idx0:idx1], plot_std[idx0:idx1])

        if stims is None:
            continue

        # Stimuli get their own y-axis; they are normalized independently of
        # the release trace.
        stim_ax = ax.twinx()
        for stim in stim_list:
            # Each stimulus has its own time base, so look the window up again.
            stim_idx0 = _first_index_at_or_after(stim['Time'], t0)
            stim_idx1 = _first_index_at_or_after(stim['Time'], t1)

            # Nothing to draw if the stimulus has no sample in this window.
            if stim_idx1 <= stim_idx0:
                continue

            stim_ax.plot(
                stim['Time'][stim_idx0:stim_idx1],
                math_utils.normalize(stim['Stim'][stim_idx0:stim_idx1]),
            )

    plt.tight_layout()
    plt.show()

## Drug data

In [None]:
# Sort the drug traces by drug and cell type:
#   drug_traces_sorted[drug label][cell type label] is either a dict with the
#   keys 'Time', 'mean', 'std' and 'traces', or None if no usable trace
#   exists for that combination.
drug_traces_sorted = {}

# Rows that contain non-finite values are excluded. This mask does not
# depend on the drug or the cell type, so it is computed once up front
# instead of once per combination.
finite_trace_rows = np.all(np.isfinite(drug_lchirp_traces), axis=1)

for drug_idx, drug_type in drug_types.items():

    drug_traces_sorted[drug_type] = {}

    for cell_idx, cell_type in cell_types.items():

        # Traces of this cell type, recorded with this drug, fully finite.
        idxs = (cluster_idx == cell_idx) & (drug_idxs == drug_idx) & finite_trace_rows
        n_selected = np.sum(idxs)

        if n_selected > 0:

            # Get data.
            traces = drug_lchirp_traces[idxs, :]

            drug_traces_sorted[drug_type][cell_type] = {}
            drug_traces_sorted[drug_type][cell_type]['Time'] = release_time
            drug_traces_sorted[drug_type][cell_type]['mean'] = np.mean(traces, 0)
            drug_traces_sorted[drug_type][cell_type]['std'] = np.std(traces, 0)

            # Plot. The DataFrame is built before 'traces' is added, so it
            # only holds the 'Time', 'mean' and 'std' columns.
            plot_release(
                release_df=pd.DataFrame(drug_traces_sorted[drug_type][cell_type]),
                title=cell_type + ' with ' + drug_type + '; n=' + str(n_selected),
                stims=[stim_rec, stim_corrected]
            )

            drug_traces_sorted[drug_type][cell_type]['traces'] = traces

        else:
            # Mark combinations without any usable trace.
            drug_traces_sorted[drug_type][cell_type] = None

In [None]:
# Store the sorted drug traces in the folder of preprocessed data.
data_utils.save_var(
    drug_traces_sorted,
    os.path.join('data_preprocessed', 'drug_traces_sorted.pkl'),
)

## No drug Data

In [None]:
# Sort the no-drug traces by cell type:
#   no_drug_traces_sorted[cell type label] is a dict with the keys
#   'Time', 'mean', 'std' and 'traces'.
no_drug_traces_sorted = {}

for cell_idx, cell_type in cell_types.items():

    # Select all traces assigned to this cell type.
    idxs = cluster_idx == cell_idx
    n = np.sum(idxs)
    traces = no_drug_lchirp_traces[idxs, :]

    # Get data. The mean is passed through math_utils.normalize, the
    # standard deviation is stored as computed.
    entry = {
        'Time': release_time,
        'mean': math_utils.normalize(np.mean(traces, 0)),
        'std': np.std(traces, 0),
    }
    no_drug_traces_sorted[cell_type] = entry

    # Plot. The DataFrame is built before 'traces' is added, so it only
    # holds the 'Time', 'mean' and 'std' columns.
    plot_release(
        release_df=pd.DataFrame(entry),
        title=cell_type + '; n=' + str(n),
        stims=[stim_rec, stim_corrected],
    )

    entry['traces'] = traces

In [None]:
# Store the sorted no-drug traces in the folder of preprocessed data.
data_utils.save_var(
    no_drug_traces_sorted,
    os.path.join('data_preprocessed', 'no_drug_traces_sorted.pkl'),
)

# Save selection to files

In [None]:
# (cell type, drug) combinations that are exported to CSV files.
# Both selected cell types are exported for the same drug.
save_selection = {
    (selected_cell_type, 'Strychnine')
    for selected_cell_type in ('BC5o', 'BC3a')
}

In [None]:
# Export the normalized mean release of every selected (cell type, drug)
# combination to its own CSV file with the columns 'Time' and 'mean'.
for (cell_type, drug_type) in save_selection:

    # The no-drug condition is labelled "no" in drug_types, while 'NO' is
    # the label of a cell type. Compare case-insensitively so that either
    # spelling selects the no-drug traces.
    if drug_type.lower() == 'no':
        data = no_drug_traces_sorted[cell_type]
    else:
        data = drug_traces_sorted[drug_type][cell_type]

    # Combinations without any usable trace are stored as None.
    if data is None:
        raise ValueError(
            f'No traces available for cell type {cell_type!r} with drug {drug_type!r}.')

    file_name = os.path.join('data_preprocessed', f'Franke2017_Release_{cell_type}_{drug_type}.csv')
    print(file_name)

    # Normalize: subtract the mean over all samples with Time <= 1, then
    # scale the maximum to one.
    # NOTE: the augmented assignments work in place when data['mean'] is a
    # NumPy array, so the array stored in the sorted-traces dictionary is
    # modified as well and later cells see the normalized values.
    mean_release = data['mean']
    mean_release -= np.mean(mean_release[data['Time'] <= 1])
    mean_release /= mean_release.max()

    dataframe = pd.DataFrame({
        'Time': data['Time'],
        'mean': mean_release,
    })

    dataframe.to_csv(file_name, index=False)

# Sanity check - Compare with stimulus

We expect the response to be delayed relative to the stimulus.
The exact delay is hard to pin down, but we know the mouse retina is relatively slow and iGluSnFR adds some delay as well. So it should definitely be more than 10 ms but certainly less than 100 ms.

In [None]:
# Overlay the corrected stimulus with the no-drug and the strychnine mean
# release, zoomed into one short time window per cell type.
plt.figure(num=1, figsize=(12, 3))

# Left panel: BC5o.
bc5o_no_drug = no_drug_traces_sorted['BC5o']
bc5o_strychnine = drug_traces_sorted['Strychnine']['BC5o']

plt.subplot(121)
plt.plot(stim_corrected['Time'], stim_corrected['Stim'])
plt.plot(bc5o_no_drug['Time'], bc5o_no_drug['mean'])
plt.plot(bc5o_strychnine['Time'], bc5o_strychnine['mean'])
plt.xlim([1.95, 2.2])

# Right panel: BC3a.
bc3a_no_drug = no_drug_traces_sorted['BC3a']
bc3a_strychnine = drug_traces_sorted['Strychnine']['BC3a']

plt.subplot(122)
plt.plot(stim_corrected['Time'], stim_corrected['Stim'])
plt.plot(bc3a_no_drug['Time'], bc3a_no_drug['mean'])
plt.plot(bc3a_strychnine['Time'], bc3a_strychnine['mean'])
plt.xlim([4.85, 5.2])