Here we test for multiple interactions at once

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pathlib

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import scipy.io

from keller_zlatic_vnc.data_processing import count_unique_subjs_per_transition
from keller_zlatic_vnc.data_processing import extract_transitions
from keller_zlatic_vnc.data_processing import generate_transition_dff_table
from keller_zlatic_vnc.data_processing import read_raw_transitions_from_excel
from keller_zlatic_vnc.data_processing import recode_beh
from keller_zlatic_vnc.linear_modeling import one_hot_from_table
from keller_zlatic_vnc.linear_modeling import reference_one_hot_to_beh

from janelia_core.stats.regression import grouped_linear_regression_ols_estimator
from janelia_core.stats.regression import grouped_linear_regression_acm_stats
from janelia_core.stats.regression import visualize_coefficient_stats



In [3]:
%matplotlib notebook

## Options for analysis

In [4]:
# Type of cells we fit models to
cell_type = 'a00c' 

# If we fit data to perturbations targeted at 'A4', 'A9' or 'both'
manip_type = 'both'

# Define the cutoff time we use to define quiet behaviors following stimulation
cut_off_time = 3.656 #3.656 #9.0034

# Specify if we predict dff 'before' or 'after' the manipulation
period = 'after'

# Specify specific transitions we want to test significance for
#interactions = ['FF']
interactions = ['BB']

## Location of data

In [5]:
data_folder = r'/Users/williambishop/Desktop/extracted_dff_v2'
transition_file = 'transition_list.xlsx'

a00c_a4_act_data_file = 'A00c_activity_A4.mat'
a00c_a9_act_data_file = 'A00c_activity_A9.mat'

basin_a4_act_data_file = 'Basin_activity_A4.mat'
basin_a9_act_data_file = 'Basin_activity_A9.mat'

handle_a4_act_data_file = 'Handle_activity_A4.mat'
handle_a9_act_data_file = 'Handle_activity_A9.mat'

## Specify some parameters we use in the code below

In [6]:
if cell_type == 'a00c':
    a4_act_file = a00c_a4_act_data_file
    a9_act_file = a00c_a9_act_data_file
elif cell_type == 'basin':
    a4_act_file = basin_a4_act_data_file
    a9_act_file = basin_a9_act_data_file
elif cell_type == 'handle':
    a4_act_file = handle_a4_act_data_file
    a9_act_file = handle_a9_act_data_file
else:
    raise(ValueError('The cell type ' + cell_type + ' is not recogonized.'))

## Load data

In [7]:
# Read in raw transitions
raw_trans = read_raw_transitions_from_excel(pathlib.Path(data_folder) / transition_file)

# Read in activity
a4_act = scipy.io.loadmat(pathlib.Path(data_folder) / a4_act_file, squeeze_me=True)
a9_act = scipy.io.loadmat(pathlib.Path(data_folder) / a9_act_file, squeeze_me=True)

# Correct mistake in labeling if we need to
if cell_type == 'basin' or cell_type == 'handle':
    ind = np.argwhere(a4_act['newTransitions'] == '0824L2CL')[1][0]
    a4_act['newTransitions'][ind] = '0824L2-2CL'

# Recode behavioral annotations
raw_trans = recode_beh(raw_trans, 'Beh Before')
raw_trans = recode_beh(raw_trans, 'Beh After')

# Extract transitions
trans = extract_transitions(raw_trans, cut_off_time)

# Generate table of data 
a4table = generate_transition_dff_table(act_data=a4_act, trans=trans)
a9table = generate_transition_dff_table(act_data=a9_act, trans=trans)

# Put the tables together
a4table['man_tgt'] = 'A4'
a9table['man_tgt'] = 'A9'
data = a4table.append(a9table, ignore_index=True)

## Down select for manipulation target

In [8]:
if manip_type == 'A4':
    print('Analyzing only A4 manipulation events.')
    data = data[data['man_tgt'] == 'A4']
elif manip_type == 'A9':
    print('Analyzing only A9 manipulation events.')
    data = data[data['man_tgt'] == 'A9']
else:
    print('Analyzing all manipulation events.')

Analyzing all manipulation events.


## Remove trials that have no behavior of interest

In [9]:
beh_before = list(set([inter[0] for inter in interactions] + ['Q'])) 
beh_after = list(set([inter[1] for inter in interactions] + ['Q'])) 

keep_behs = set(interactions + ['Q' + b for b in beh_after] + [b + 'Q' for b in beh_before])

keep_rows = data.apply(lambda row: row['beh_before'] + row['beh_after'] in keep_behs, axis = 1)
data = data[keep_rows]

## Look at number of subjects we have for each type of transition

## Pull out $\Delta F/F$

In [11]:
if period == 'before':
    dff = data['dff_before'].to_numpy()
elif period == 'after':
    dff = data['dff_after'].to_numpy()
else:
    raise(ValueError('The period ' + ' period is not recogonized.'))

## Find grouping of data by subject

In [12]:
unique_ids = data['subject_id'].unique()
g = np.zeros(len(data))
for u_i, u_id in enumerate(unique_ids):
    g[data['subject_id'] == u_id] = u_i

In [13]:
# Do original encoding of all one hot vars
beh_before = list(set([inter[0] for inter in interactions] + ['Q'])) 
beh_after = list(set([inter[1] for inter in interactions] + ['Q'])) 

one_hot_data, one_hot_vars = one_hot_from_table(data, beh_before=beh_before, beh_after=beh_after, 
                                         enc_subjects=False, enc_beh_interactions=True)

# Reference one hot variables to quiet state
one_hot_data_ref, one_hot_vars_ref = reference_one_hot_to_beh(one_hot_data=one_hot_data,
                                                              one_hot_vars=one_hot_vars,
                                                              beh='Q',
                                                              remove_interaction_term=True)


# Include only interaction terms the user has specified
int_start_ind = len(beh_before) + len(beh_after) - 2
all_interaction_terms = [var[-2:] for var in one_hot_vars_ref[int_start_ind:]]
del_inds = [i + int_start_ind for i in range(len(all_interaction_terms))
            if all_interaction_terms[i] not in set(interactions)]

one_hot_data_ref = np.delete(one_hot_data_ref, del_inds, 1)
one_hot_vars_ref = [one_hot_vars_ref[i] for i in range(len(one_hot_vars_ref)) 
                    if i not in set(del_inds)]

# Add constant term
one_hot_data_ref = np.concatenate([one_hot_data_ref, np.ones([one_hot_data_ref.shape[0], 1])], axis=1)
one_hot_vars_ref = one_hot_vars_ref + ['ref']

beta, acm, n_gprs = grouped_linear_regression_ols_estimator(x=one_hot_data_ref, y=dff, g=g)
stats = grouped_linear_regression_acm_stats(beta=beta, acm=acm, n_grps=n_gprs, alpha=.05)

interact_ind: [7]
del_inds: [1, 3, 7]


LinAlgError: Singular matrix

## Visualize results

In [None]:
visualize_coefficient_stats(var_strs=one_hot_vars_ref, 
                            theta=beta, c_ints=stats['c_ints'],
                            sig=stats['non_zero'])
plt.ylabel('$\Delta F/F$')
plt.tight_layout()