## Med64 QQ Export D3 Data ##

Prior name of file: `med64_qq_fr_nwb.ipynb`

To use with CoLab notebooks, upload the data files manually. 
This will downsample the data and export in a file for uploading to observable.

This version works with the Spyking Circus, NWB, YAML Parameter files

In [2]:
import os, re
import yaml
import csv
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.io
from scipy import signal, stats

import numpy as np
import matplotlib.pyplot as plt
import statsmodels.api as sm

import pandas as pd
from functools import partial
from collections.abc import Iterable


# meappy module
from meappy.med64_data import *
from meappy.parameter_yaml import USER, USER_PATHS


In [3]:
from matplotlib import cm
cmap = cm.get_cmap('tab20')

In [4]:
import warnings
warnings.filterwarnings('ignore')

In [5]:
# def find_expt_files(yaml_file):    
#     with open(yaml_file, 'r') as file:
#         med64_data = yaml.safe_load(file)

#     # pprint(med64_data, sort_dicts=False, indent=4)
    
#     paths = med64_data['paths']
#     slice_path = os.path.join(paths['base'], paths['protocol'], paths['slice'])
    
#     slice_name = paths['slice']
#     tx_file = os.path.join(slice_path, paths['treatment'])
#     unit_file = os.path.join(slice_path, paths['unit'])

#     #  expt_list[0] now replaced with slice_name
#     # tx_files now tx_file without s
#     return  slice_name, tx_file, unit_file

In [6]:
def find_better_expt_files(data_dir):
    files = os.listdir(data_dir)

    re_units = re.compile(r'^EBM_betterResults.*')
    re_tx = re.compile(r'^treatmentinfo.*')
    re_expr_id = re.compile(r"([0-9]{3}_[0-9]{2}h[0-9]{2}m[0-9]{2}s)\.['csv|mat']{3}")

    unit_files = dict()
    tx_files = dict()

    for file in files:
        expr_id = re_expr_id.findall(file)
        if expr_id:
            units_file = re_units.findall(file)
            if units_file:
                unit_files[expr_id[0]] = units_file[0]
            tx_file = re_tx.findall(file)
            if tx_file:
                tx_files[expr_id[0]] = tx_file[0]
    expt_list = sorted(list(unit_files.keys()))

    print('FOUND Data FILES:')
    print('Experiment IDs: \n\t' + '\n\t'.join(expt_list))
    print('Units Files: \n\t' + '\n\t'.join(unit_files.values()))
    print('Treatment Files: \n\t' + '\n\t'.join(tx_files.values()))
    
    return expt_list, tx_files, unit_files


In [7]:
def mean_firing_rates(units_data, tx_times):
    """
    get firing rate of a units in Hz
    depricated
    """
    num_units = len(units_data)
    start_time_sec = list(tx_times.values())[0]
    end_time_sec = list(tx_times.values())[-1]
    duration_s = end_time_sec - start_time_sec
    
    mean_firing_rate_hz = dict()
    all_unit_activity_count = 0
    for (id, data) in units_data.items():
        timestamps = data['timestamps']
        activity_count = timestamps.shape[0]
        mean_firing_rate_hz[id] = activity_count / duration_s
        all_unit_activity_count += activity_count
        
    all_unit_mean_fr_hz = all_unit_activity_count / duration_s / num_units
    return all_unit_mean_fr_hz, mean_firing_rate_hz
   

## TODO 
### check below for A more accurate version of get_tx_ranges  
* Use this when I have time to modify the d3 code to display it.
* Also, check if there are timestamps beyond the "TTX off" and allow for arbitrary end time to last range


In [8]:
# THIS IS ACCURATE VERSION DONT DELETE

# import operator

# def get_tx_ranges(tx_times):
#     tx_ranges = {}
#     sorted_tx_times = sorted(tx_times.items(), key=operator.itemgetter(1))
#     prev_tx = None
#     prev_time = None
#     for tx, time in sorted_tx_times:
#         if (prev_tx != None) and (prev_time != time):
#             tx_ranges[prev_tx.strip()] = [prev_time, time]
#         prev_tx = tx
#         prev_time = time
#     return tx_ranges

# get_tx_ranges(tx_times)

In [9]:
def get_tx_ranges(tx_times):
    all_times = sorted(list(tx_times.values()))
    last_time_index = len(all_times) - 1
    tx_range = dict()
    for tx, time in tx_times.items():
        time_index = all_times.index(time)
        if time_index == last_time_index:
            break
        next_time = all_times[time_index + 1]
        if time >= next_time:
            continue
        tx_range[tx.strip()] = [time, next_time]
    return tx_range

In [10]:
def get_mean_firing_by_tx(units_data, tx_times):
    """
    get mean firing rate of units in Hz for each treatment (tx)
    """
    tx_range = get_tx_ranges(tx_times)
    tx_mean_firing_rate_hz = dict()
    for tx in tx_range.keys():
        tx_mean_firing_rate_hz[tx] = {}
    for (unit_id, data) in units_data.items():
        for (tx, [start_time, end_time]) in tx_range.items():
            timestamps = data['timestamps']
            duration = end_time - start_time
            activity_count = [ts for ts in units_data[unit_id]['timestamps'] if 
                 (ts > start_time) and (ts < end_time)]
            tx_mean_firing_rate_hz[tx][unit_id] = len(activity_count)/duration
    return tx_mean_firing_rate_hz

In [11]:
def get_firing_rate_df(df_d3, mean_firing_by_tx):
    """
    create new dataframe with firing rate data formatted to qq dataframe
    """
    firing_rate_df = pd.DataFrame(columns = (df_d3.columns.tolist() + ['Firing Rate']))
    df_index = 0
    for tx in mean_firing_by_tx.keys():
        for unit in mean_firing_by_tx[tx].keys():
            fr = mean_firing_by_tx[tx][unit]
            nan_cols = firing_rate_df.columns.shape[0] - 4
            firing_rate_df.loc[df_index] = ([np.nan, unit, tx] + [np.nan] * nan_cols + [fr])
            df_index += 1
    return firing_rate_df

In [12]:
#  next line is to create from raw file input
def get_mean_firing_rates(unit_files, tx_files):
    units_data, tx_times = get_expt_data(unit_files, tx_files)    
    mean_firing_by_tx = get_mean_firing_by_tx(units_data, tx_times)
    return mean_firing_by_tx

### Continue with QQ calc

### Speed up get_timestamp_...
by using lookup in a table generated by array functions instead of loops

In [13]:
def get_timestamp_tx(ts, tx_range):
    for tx, times in tx_range.items():
        if (ts >= times[0]) & (ts < times[1]):
            return tx
    return 'No Tx'
        
def get_timestamp_begin(ts, tx_range):
    for tx, times in tx_range.items():
        if (ts >= times[0]) & (ts < times[1]):
            return times[0]
    return np.nan   
    
def get_timestamp_end(ts, tx_range):
    for tx, times in tx_range.items():
        if (ts >= times[0]) & (ts < times[1]):
            return times[1]
    return np.nan   


In [14]:
def build_df(units_data, tx_times):
    tx_range = get_tx_ranges(tx_times)
    time_tx = partial(get_timestamp_tx, tx_range = tx_range)
    time_begin = partial(get_timestamp_begin, tx_range = tx_range)
    time_end = partial(get_timestamp_end, tx_range = tx_range)

    df = None
    # for loop over units
    for unit_id in list(units_data.keys()):
        timestamps = units_data[unit_id]['timestamps']

        unit_df = pd.DataFrame(timestamps, columns=['timestamp'])
        unit_df['unit'] = unit_id
        if isinstance(df, pd.DataFrame):
            df = df.append(unit_df, ignore_index=True)
        else:
            df = unit_df
    
    # end loop over units
    df['tx'] = df['timestamp'].map(lambda x: time_tx(x))
    df['begin'] = df['timestamp'].map(lambda x: time_begin(x))
    df['end'] = df['timestamp'].map(lambda x: time_end(x))

    ts_groups = df.groupby(by = ['unit', 'tx'])
    df['group_idx'] = ts_groups['timestamp'].cumcount()+1
    
    ts_groups = df.groupby(by = ['unit', 'tx'])
    df['cum_dist'] = ts_groups['group_idx'].apply(lambda df: df/df.count())
    
    return df


In [15]:
def add_anchors(df):
    anchors_begin = df[['unit', 'tx', 'begin', 'end']].drop_duplicates()
    anchors_begin.dropna(inplace=True)
    anchors_end = anchors_begin.copy()
    
    anchors_begin['cum_dist'] = 1
    anchors_begin['timestamp'] = anchors_begin['end']

    anchors_end['cum_dist'] = 0
    anchors_end['timestamp'] = anchors_end['begin']
    
    return pd.concat([df, anchors_begin, anchors_end], axis=0, ignore_index=True)

## Export QQ data for d3 Linked Brushing Cross-filtering    

In [16]:
def data_to_qq(unit_files, tx_files):
    units_data, tx_times = get_expt_data(unit_files, tx_files)

    df = build_df(units_data, tx_times)
    df = add_anchors(df)

    df = df[df['tx'] != 'No Tx'] # drop times outside treatment time ranges
    df['x_plot'] = (df['timestamp']-df['begin'])/(df['end']-df['begin'])
    
    # sort zero anchors to beginning of dataframe
    df.sort_values(by=['unit', 'x_plot'], inplace=True)
    
    return df

In [17]:
def get_max_samples(df, max_d3_rows):
    """Returns the max samples per unit per treatment for d3 observable total_rows
    """
    max_samples = round(max_d3_rows / df.unit.unique().size / df.tx.unique().size)
    return max_samples

In [18]:
def downsample_data(df, max_d3_rows):
    """downsamples the data in each unit so that the max total size of data will
    work in d3 plot. If more units, then data per unit is larger.
    """
    max_samples = get_max_samples(df, max_d3_rows)
    df_downsampled = None

    downsample_groups = df.groupby(by=['unit', 'tx'])

    for by, grp in downsample_groups:
    #     if (by[0] == 5): # & (by[1] == 'DAMGO 500nM On'):
        grp.sort_values(by='x_plot', inplace=True)
        samples = grp.shape[0]
        downsample_rate = round(np.floor(samples / max_samples))
        downsample_rate = 1 if downsample_rate < 1 else downsample_rate
        if isinstance(df_downsampled, pd.DataFrame):
            df_downsampled = df_downsampled.append(grp.iloc[:-1:downsample_rate])
        else:
            df_downsampled = grp.iloc[:-1:downsample_rate] # don't add anchor here
        # add final anchor now, because downsampling could have excluded it
        df_downsampled = df_downsampled.append(grp.iloc[-1:])
#         print(downsample_rate, samples)
#         print(grp.iloc[::downsample_rate])
#         print(grp.iloc[::downsample_rate].shape)
#         print(df_downsampled.shape)
    return df_downsampled

In [19]:
def hypoactive_unit_tx(df, activity_threshold):
    # Depricated for activity_gaps_unit_tx
    """return list of tuples with (unit, tx) for low unit
    activity that will not have enough points on d3 lines"""
    threshold_fraction = 0.5
    group_sizes = df.groupby(by=['unit', 'tx'])['timestamp']
    return [i for i, v in (group_sizes.count() < (threshold_fraction * activity_threshold)).items() if v]

# interp_units = hypoactive_unit_tx(df_downsampled, get_max_samples(df, max_d3_rows))

In [20]:
# group_activity = df.groupby(by=['unit', 'tx'])[['timestamp', 'begin', 'end']]
# num_segments = int(get_max_samples(df, max_d3_rows)/2)

 # activity_gaps_unit_tx(df_downsampled, get_max_samples(df, max_d3_rows)):
    
def groups_with_gaps(group_activity, num_segments): 
    gap_groups = []
    for i, v in group_activity:
#     if i[0] < 3:
#         print(i)
        spikes = v.timestamp.values
        begin = v.iloc[0,:]['begin']
        end = v.iloc[0,:]['end']
        segments = np.linspace(begin, end, num_segments)
        prev_seg = segments[0]
        is_gap = False
        for seg in segments[1:]:
            gap_activity = [s for s in spikes if ((s > prev_seg) & (s < seg))]
#             print(len(gap_activity))
            if len(gap_activity) == 0:
                is_gap = True
        if is_gap:
            gap_groups.append(i)
    return gap_groups
    
# groups_with_gaps(group_activity, num_segments)expt_id

In [21]:
def activity_gaps_unit_tx(df, d3_sample_threshold):
    """return list of tuples with (unit, tx) for unit with gaps in
    activity that will not have enough points on d3 lines"""
    num_segments = int(get_max_samples(df, max_d3_rows)/2)
    group_activity = df.groupby(by=['unit', 'tx'])[['timestamp', 'begin', 'end']]
    return groups_with_gaps(group_activity, num_segments)

In [22]:
# df_tmp = df_line[['cum_dist', 'x_plot']].to_numpy()

def euclidist(xy_array):
    """Takes array of x,y coord and calculates the euclidean distance between 
    successive pairs.
    returns array of distances. first is zero, to maintain same length as coord array.
    """
#     min_dist = 1.5 / len(xy_array)
    dist = np.array([0])
    prev_row = xy_array[0]
    for row in xy_array[1:]:
        x_diff = row[0] - prev_row[0]
        y_diff = row[1] - prev_row[1]
        dist = np.append(dist, np.sqrt(x_diff**2 + y_diff**2))
        prev_row = row
#     num_new_points = np.floor(np.array(dist / min_dist))
    return dist

def num_new_points(df, dist_array):
    """returns number of points to be added before each index of a
    distance array.
    Assumes total distance of 2 (between sqrt(2) and 2) for qq plot.
    Adjust this number to change density of new points added.
    """
    expected_total_dist = 2 #1.5
    min_dist = expected_total_dist / get_max_samples(df, max_d3_rows)
    return np.floor(np.array(dist_array / min_dist))

# interp_input = np.column_stack((df_tmp, num_new_points(euclidist(df_tmp))))
# interp_input

In [23]:
def interp_points(xy_array, num_new_points):
    x0, y0 = xy_array[0]
    new_points = None
    iterator = enumerate(num_new_points)
    for i, n in iterator:
        if (i != 0):
            if (n > 0):
                x0, y0 = xy_array[i - 1]
                x1, y1 = xy_array[i]
                # divide new points along longest of x or axes
                run = x0 - x1
                rise = y0 - y1
                if ((rise/run) < 1):
                    _x = np.linspace(x0, x1, num=int(n + 2))
                    _x = _x[1:-1]  # remove end points because they're not new points
                    _y = np.interp(_x, [x0, x1], [y0, y1])
                else:
                    _y = np.linspace(y0, y1, num=int(n + 2))
                    _y = _y[1:-1]  # remove end points because they're not new points
                    _x = np.interp(_y, [y0, y1], [x0, x1])
                interp_coords = np.column_stack((_x, _y))
                if not isinstance(new_points, np.ndarray):
                    new_points = interp_coords
                else:
                    new_points = np.vstack((new_points, interp_coords))
#     print("interpolating " + str(len(new_points)) + " new points")                    
    return new_points 

# interp_points(df_tmp, num_new_points(euclidist(df_tmp)))

In [24]:
def get_interp_points_df(df_downsampled, interp_units):
    """Takes dataframe of processed points for d3 of qq plots.
    returns another df with same columns but interpolated 
    data points to fill in gaps in raw data. Interpolated points
    aid in selecting in interactive d3 charts.
    """
    tmp_downsampled_df = df_downsampled.iloc[0, :].copy()
    col_names = df_downsampled.columns
    # print(col_names)
    # print(col_names.to_list().index('cum_dist'))

    interp_points_for_df = None

    for unit, tx in interp_units:
#         print("Unit " + str(unit) + " " + tx)
        df_line = df_downsampled[(df_downsampled.unit==unit) & (df_downsampled.tx==tx)]
        xy_array = df_line[['x_plot', 'cum_dist']].to_numpy()
    #     line_euclidist = euclidist(df_line[['cum_dist', 'x_plot']].to_numpy())
        new_points = interp_points(xy_array, num_new_points(df_downsampled, euclidist(xy_array)))
        if isinstance(new_points, Iterable):
        #     print("orig num points " + str(len(xy_array)))
        #     print(new_points)
            base_row = df_line.iloc[0:1,:].copy()
            base_row['timestamp'] = np.nan
            collect_rows = base_row  # remove this first dimension holding row after loop
        #     print(base_row.values)
            for x, y in new_points:
                x_index = col_names.to_list().index('x_plot')
                y_index = col_names.to_list().index('cum_dist')
                new_row = base_row.values
                new_row[0][x_index] = x
                new_row[0][y_index] = y 
                collect_rows = np.vstack((collect_rows, new_row))
            collect_rows = collect_rows[1:]
            if not isinstance(interp_points_for_df, np.ndarray):
                interp_points_for_df = collect_rows
            else:
                interp_points_for_df = np.vstack((interp_points_for_df, collect_rows))
        
    print("total interpolated points " + str(interp_points_for_df.shape))      
    interp_df = pd.DataFrame(data = interp_points_for_df, columns = col_names)

    return interp_df

## to do   
- add new dataframe of interp points
- resort dataframe with interp points
- test in d3 observable
- try exporting without padding NaN. it might just work!

### Process final array to export for d3

In [25]:
def clean_column_header(df):
    """Removes column headers not needed for export.
    Changes name of x-value header."""
    df_d3 = df.copy()
    df_d3.drop(["begin", "end", "group_idx"], axis=1, inplace=True)
    df_d3.rename(columns={"x_plot": "x_idx"}, inplace=True)
    return df_d3

In [26]:
# ADD NaN columns. Inefficient. unneeded?
# create row with full features for csv export
# this takes a second
# are all the NaN necessary? I think javascript removes them now...

def pad_data(df):
    """This pads the data with empty columns all filled with NaN values.
    This is the format currently used by the d3 chart."""
    df_d3 = df.copy()
    tx_list = df_d3.tx.unique()

    for tx in tx_list: 
        df_d3[tx] = df_d3[['tx', 'cum_dist']].apply(
            lambda df, tx: df['cum_dist'] if (df['tx']==tx) else np.nan,
            axis=1, tx=tx)
    return df_d3

In [27]:
def export_d3_data(df_d3, expt_id):
    print(expt_id)
    save_csv_filepath_dir = '/Users/walter/Data/margolis/observable/'
#     save_csv_filepath = save_csv_filepath_dir + 'qq_for_d3_10K_interp_' + expt_id + '.csv'

    save_csv_filepath = save_csv_filepath_dir + 'd3_cum_fr_10K_interp_' + expt_id + '.csv'

    df_d3.to_csv(save_csv_filepath, index = False)

    print('\tSaved ' + save_csv_filepath)

## TODO
- Editing export here

```
expt_id to slice_id
then tuples (slice_id, unit_filename)

first confirm input for get_mean_firing_rates() as id or filepath for "expt_id"
```

In [28]:
def export_from_filelist(expt_list, datafile_args, interp=True):
    """
    expt_list is list of tuples with [(slice_id, unit_filepath), ...]
    """
    for slice_id, unit_file in expt_list:
        df = data_to_qq(unit_file, datafile_args)
        df_downsampled = downsample_data(df, max_d3_rows)
        if interp:
            interp_units = activity_gaps_unit_tx(df, get_max_samples(df, max_d3_rows))
            ### HERE interp
#             interp_units = hypoactive_unit_tx(df_downsampled, get_max_samples(df, max_d3_rows))
            interp_points_df = get_interp_points_df(df_downsampled, interp_units)
            df_d3 = df_downsampled.append(interp_points_df)
            # sort interpolated date into order of full dataframe
            df_d3.sort_values(by=['unit', 'x_plot'], inplace=True)
        df_d3 = clean_column_header(df_d3)
        df_d3 = pad_data(df_d3)

        mean_firing_by_tx = get_mean_firing_rates(unit_file, datafile_args) # get_mean_firing_by_tx(units_data, tx_times)
        firing_rate_df = get_firing_rate_df(df_d3, mean_firing_by_tx)
        df_d3 = firing_rate_df.append(df_d3)
        export_d3_data(df_d3, slice_id)

# Start Main() here

In [29]:
max_d3_rows = 10000

DATA_DIR = '/Users/walter/Data/med64/experiment/VTA_NMDA/20211005_17h33m55s/' # walter local
SLICE_PARAMS_FILE = 'slice_parameters.yaml'
yaml_file = os.path.join(DATA_DIR, SLICE_PARAMS_FILE)

expt_list, tx_files, unit_files = find_expt_files(yaml_file)

export_from_filelist([(expt_list, unit_files)], tx_files, expt_list)


total interpolated points (2385, 8)
20211005_17h33m55s
	Saved /Users/walter/Data/margolis/observable/d3_cum_fr_10K_interp_20211005_17h33m55s.csv


In [145]:
print(expt_list)
print(unit_files)

20211005_17h33m55s
/Users/walter/Data/med64/experiment/VTA_NMDA/20211005_17h33m55s/20211005_17h33m55s_units_ts.mat


## TODO
- previous run with earlier files used this kind of export list  

`expt_list[3:]`
```
output: ['825_16h06m19s', '827_12h06m26s', '827_13h30m23s', '827_15h57m53s']
```

In [29]:
def file_attachment_list(expt_list):
#     save_csv_filepath_dir = '/Users/walter/Data/margolis/observable/'
    attach_list = "["
    for expt_id in expt_list:
        filename = "FileAttachment(\'d3_cum_fr_10K_interp_" + expt_id + ".csv\')"
        attach_list += filename + ",\n"
    attach_list += "]"
#     return attach_list
    print(attach_list)

In [30]:
file_attachment_list(expt_list)

[FileAttachment('d3_cum_fr_10K_interp_2.csv'),
FileAttachment('d3_cum_fr_10K_interp_0.csv'),
FileAttachment('d3_cum_fr_10K_interp_2.csv'),
FileAttachment('d3_cum_fr_10K_interp_1.csv'),
FileAttachment('d3_cum_fr_10K_interp_1.csv'),
FileAttachment('d3_cum_fr_10K_interp_0.csv'),
FileAttachment('d3_cum_fr_10K_interp_0.csv'),
FileAttachment('d3_cum_fr_10K_interp_5.csv'),
FileAttachment('d3_cum_fr_10K_interp__.csv'),
FileAttachment('d3_cum_fr_10K_interp_1.csv'),
FileAttachment('d3_cum_fr_10K_interp_7.csv'),
FileAttachment('d3_cum_fr_10K_interp_h.csv'),
FileAttachment('d3_cum_fr_10K_interp_3.csv'),
FileAttachment('d3_cum_fr_10K_interp_3.csv'),
FileAttachment('d3_cum_fr_10K_interp_m.csv'),
FileAttachment('d3_cum_fr_10K_interp_5.csv'),
FileAttachment('d3_cum_fr_10K_interp_5.csv'),
FileAttachment('d3_cum_fr_10K_interp_s.csv'),
]


# Use observable.py module

In [30]:
from observable import main

main()

groups_with_gaps() --> len(segments): 31
groups_with_gaps() --> len(spikes): 45
groups_with_gaps() --> len(segments): 31
groups_with_gaps() --> len(spikes): 4
groups_with_gaps() --> len(segments): 31
groups_with_gaps() --> len(spikes): 4
groups_with_gaps() --> len(segments): 31
groups_with_gaps() --> len(spikes): 290
groups_with_gaps() --> len(segments): 31
groups_with_gaps() --> len(spikes): 8
groups_with_gaps() --> len(segments): 31
groups_with_gaps() --> len(spikes): 4
groups_with_gaps() --> len(segments): 31
groups_with_gaps() --> len(spikes): 145
groups_with_gaps() --> len(segments): 31
groups_with_gaps() --> len(spikes): 48
groups_with_gaps() --> len(segments): 31
groups_with_gaps() --> len(spikes): 30
groups_with_gaps() --> len(segments): 31
groups_with_gaps() --> len(spikes): 102
groups_with_gaps() --> len(segments): 31
groups_with_gaps() --> len(spikes): 49
groups_with_gaps() --> len(segments): 31
groups_with_gaps() --> len(spikes): 33
groups_with_gaps() --> len(segments): 31


# Debugging Code Below
- Use these cells to see inside functions defined above for debugging

## Find local data filenames ##

In [49]:
# define global file path

# DATA_DIR = '/Users/walter/Data/med64/experiment/HB_139_DAMGO/825_12h24m37s'
DATA_DIR = '/Users/walter/Data/med64/experiment/VTA_NMDA/20211005_17h33m55s/' # walter local

SLICE_PARAMS_FILE = 'slice_parameters.yaml'
PRODUCT_DIR = '/Users/walter/Data/med64/product/HB_139_DAMGO/'


yaml_file = os.path.join(DATA_DIR, SLICE_PARAMS_FILE)

## TODO: Which param file to use?

In [52]:
# expt_list, tx_files, unit_files = find_better_expt_files(DATA_DIR)
splice_params_filepath = os.path.join(DATA_DIR, SLICE_PARAMS_FILE)
expt_list, tx_files, unit_files = find_expt_files(splice_params_filepath)

# datafile_args = dict(data_dir = DATA_DIR, unit_files = unit_files, tx_files = tx_files)

In [53]:
print(yaml_file)
print(splice_params_filepath)

/Users/walter/Data/med64/experiment/HB_139_DAMGO/825_12h24m37s/slice_parameters.yaml
/Users/walter/Data/med64/experiment/VTA_NMDA/20211005_17h33m55s/slice_parameters.yaml


```python
# Original code for older files
expt_list, tx_files, unit_files = find_expt_files(DATA_DIR)
datafile_args = dict(data_dir = DATA_DIR, unit_files = unit_files, tx_files = tx_files)
```

## QQ Plot ##
- get timestamps for tx
- rank timestamps as lin_space from 0 to  1
- rank time in tx as lin_space from start to finish
- plot (ranks of tx, ranks of timestamps )

In [54]:
print(tx_files)
print(unit_files)

units_data, tx_times = get_expt_data(unit_files, tx_files)

/Users/walter/Data/med64/experiment/VTA_NMDA/20211005_17h33m55s/20211005_17h33m55s_treatments.csv
/Users/walter/Data/med64/experiment/VTA_NMDA/20211005_17h33m55s/20211005_17h33m55s_units_ts.mat


### Firing Rate Calc

In [55]:
 
expt_mean_fr, unit_mean_fr = mean_firing_rates(units_data, tx_times)

## Get_TX_Ranges   
- get_tx_rangesget_tx_ranges

In [33]:
# tx_range_tmp = get_tx_ranges(tx_times)
# tx_range_tmp

{'Baseline': [6.0, 645.0],
 'NMDA': [645.0, 1281.0],
 'NMDA Apamin': [1281.0, 1900.0],
 'TTX': [1900.0, 2550.0]}

In [57]:
### KEEP ###
# Keep this for constant used in num_new_points()

max_d3_rows = 10000
MAX_SAMPLES = 32 # default when no df for get_max_samples(df, max_d3_rows)
print(MAX_SAMPLES)

32


## Run by steps

In [58]:
# df = data_to_qq(expt_id, datafile_args)
df = data_to_qq(unit_files, tx_files)

df_downsampled = downsample_data(df, max_d3_rows)

In [90]:
# interp_units = activity_gaps_unit_tx(df, get_max_samples(df, max_d3_rows))
# interp_units

[(0, 'Baseline'),
 (0, 'NMDA'),
 (0, 'NMDA Apamin'),
 (1, 'Baseline'),
 (1, 'NMDA'),
 (1, 'NMDA Apamin'),
 (2, 'Baseline'),
 (4, 'NMDA'),
 (4, 'NMDA Apamin'),
 (5, 'Baseline'),
 (5, 'NMDA Apamin'),
 (6, 'NMDA Apamin'),
 (7, 'NMDA Apamin'),
 (8, 'NMDA Apamin'),
 (9, 'NMDA Apamin'),
 (10, 'Baseline'),
 (12, 'NMDA'),
 (12, 'NMDA Apamin'),
 (13, 'NMDA'),
 (14, 'NMDA'),
 (15, 'NMDA Apamin'),
 (16, 'NMDA'),
 (16, 'NMDA Apamin'),
 (18, 'Baseline'),
 (18, 'NMDA'),
 (19, 'Baseline'),
 (20, 'Baseline'),
 (21, 'Baseline'),
 (22, 'NMDA Apamin'),
 (26, 'Baseline'),
 (27, 'NMDA'),
 (28, 'NMDA'),
 (28, 'NMDA Apamin'),
 (30, 'NMDA'),
 (30, 'NMDA Apamin'),
 (32, 'NMDA'),
 (32, 'NMDA Apamin'),
 (33, 'NMDA Apamin'),
 (34, 'NMDA'),
 (34, 'NMDA Apamin'),
 (35, 'NMDA'),
 (35, 'NMDA Apamin'),
 (36, 'NMDA'),
 (36, 'NMDA Apamin'),
 (37, 'NMDA'),
 (37, 'NMDA Apamin'),
 (38, 'Baseline'),
 (39, 'NMDA'),
 (39, 'NMDA Apamin'),
 (41, 'Baseline'),
 (42, 'NMDA'),
 (42, 'NMDA Apamin'),
 (43, 'Baseline'),
 (46, 'NMDA Ap

In [60]:
tx_filter_str = 'Baseline'  # 'AMPA blocker DNQX'
    # Baseline', 'NMDA Apamin', 'NMDA'

In [61]:
# troubleshooting to verify integrity of data at intermediate steps
df[(df.unit==2) & (df.tx==tx_filter_str)]
df_downsampled[(df_downsampled.unit==2) & (df_downsampled.tx==tx_filter_str)]

Unnamed: 0,timestamp,unit,tx,begin,end,group_idx,cum_dist,x_plot
54071,6.000000,2,Baseline,6.0,645.0,,0.000000,0.000000
344,81.046925,2,Baseline,6.0,645.0,2.0,0.013986,0.117444
346,131.508075,2,Baseline,6.0,645.0,4.0,0.027972,0.196413
348,182.912075,2,Baseline,6.0,645.0,6.0,0.041958,0.276858
350,207.410925,2,Baseline,6.0,645.0,8.0,0.055944,0.315197
...,...,...,...,...,...,...,...,...
478,538.672575,2,Baseline,6.0,645.0,136.0,0.951049,0.833603
480,542.278375,2,Baseline,6.0,645.0,138.0,0.965035,0.839246
482,543.579525,2,Baseline,6.0,645.0,140.0,0.979021,0.841283
484,550.752200,2,Baseline,6.0,645.0,142.0,0.993007,0.852507


In [62]:
# interp_units = hypoactive_unit_tx(df_downsampled, get_max_samples(df, max_d3_rows))
max_d3_samples = get_max_samples(df, max_d3_rows)
interp_units = activity_gaps_unit_tx(df, max_d3_samples)
interp_points_df = get_interp_points_df(df_downsampled, interp_units)

df_d3 = df_downsampled.append(interp_points_df)

# sort interpolated date into order of full dataframe
df_d3.sort_values(by=['unit', 'x_plot'], inplace=True)

total interpolated points (2385, 8)


In [63]:
df_d3 = clean_column_header(df_d3)
df_d3 = pad_data(df_d3)

In [64]:
df.tx.unique()

array(['Baseline', 'NMDA', 'NMDA Apamin'], dtype=object)

In [65]:
df_d3[(df_d3.unit==2) & (df_d3.tx==tx_filter_str)] # DAMGO 500nM On

Unnamed: 0,timestamp,unit,tx,cum_dist,x_idx,Baseline,NMDA,NMDA Apamin
54071,6.0,2,Baseline,0.0,0.0,0.000000,,
255,,2,Baseline,0.003497,0.029361,0.003497,,
256,,2,Baseline,0.006993,0.058722,0.006993,,
257,,2,Baseline,0.01049,0.088083,0.010490,,
344,81.046925,2,Baseline,0.013986,0.117444,0.013986,,
...,...,...,...,...,...,...,...,...
268,,2,Baseline,0.994406,0.882006,0.994406,,
269,,2,Baseline,0.995804,0.911504,0.995804,,
270,,2,Baseline,0.997203,0.941003,0.997203,,
271,,2,Baseline,0.998601,0.970501,0.998601,,


In [66]:
mean_firing_by_tx = get_mean_firing_by_tx(units_data, tx_times)
firing_rate_df = get_firing_rate_df(df_d3, mean_firing_by_tx)
firing_rate_df.append(df_d3).head()

Unnamed: 0,timestamp,unit,tx,cum_dist,x_idx,Baseline,NMDA,NMDA Apamin,Firing Rate
0,,0,Baseline,,,,,,0.067293
1,,1,Baseline,,,,,,0.450704
2,,2,Baseline,,,,,,0.223787
3,,3,Baseline,,,,,,0.156495
4,,4,Baseline,,,,,,0.72144


## TODO
- change export_from_filelist(expt_list...) to take single string unit_file instead of list
- change unit_files and tx_files to singular since they are string value now.