A notebook for generating tables of single-cell results broken out by segment for specific cell types.

In [1]:
# Automatically reload imported modules (e.g. janelia_core) when their source files change
%load_ext autoreload
%autoreload 2

In [2]:
import glob
import itertools
from pathlib import Path
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from janelia_core.stats.multiple_comparisons import apply_by
from janelia_core.stats.multiple_comparisons import apply_bonferroni
from janelia_core.stats.regression import visualize_coefficient_stats

### Parameters go here 

In [3]:
# Specify folders where raw results for all analyses are (one sub-folder per cell type)
results_root = r'/Volumes/bishoplab/projects/keller_vnc/results/single_cell/publication_results_v0'
base_folders = [results_root + '/' + sub_folder for sub_folder in ['a00c', 'basin', 'handle']]

# The glob (not regex) pattern for results files
rs_file_str = 'rs_*.pkl'

# Specify the results that will go in the table
# Determine how we will populate the table
man_tgt = 'A4' # None corresponds to both A4 and A9 manipulations

# Specify if we look at results for pooled turns or not
pool_turns = True

# Specify the type of test we want results for, and how we abbreviate them
test_types = {'SDAR': 'after_reporting',
              'DD': 'decision_dependence',
              'PD': 'prediction_dependence',
              'BR': 'before_reporting'}

# Segment IDs for each cell type, keyed by segment abbreviation. Each ID list is compared with `!=`
# against the 'cell_ids' parameter stored with a result (see `match`), so for list-valued parameters
# the order of IDs and any stray whitespace (e.g. 'A5L ', 'A5  ') matter. Presumably these lists mirror
# the IDs exactly as stored in the results files -- verify before editing them.
all_segment_ids = {
    'a00c': {'ant': ['antL', 'antR'],
             'mid': ['midL', 'midR'],
             'post': ['postL', 'postR']},
    'basin': {'A1': ['A1R', 'A1L', '1AL', '1AR'],
              'A2': ['A2L', 'A2R', '2AL', '2AR'],
              'A3': ['A3R', 'A3L', '3AL', '3AR'],
              'A4': ['A4R', 'A4L', '4AL', '4AR'],
              'A5': ['A5R', '5AL', '5AR', 'A5L '],
              'A6': ['A6L', 'A6R', '6AL', '6AR'],
              'A7': ['A7L', 'A7R', '7AL', '7AR'],
              'A8': ['A8L', 'A8R', '8AL', '8AR'],
              'A9': ['A9R', 'A9L', '9AL', '9AR'],
              'T1': ['T1L', '1TL', 'T1R'],
              'T2': ['T2R', 'T2L', '2TL', '2TR'],
              'T3': ['T3L', 'T3R', '3TL', '3TR']},
    'handle': {'A1': ['A1'],
               'A2': ['A2'],
               'A3': ['A3'],
               'A4': ['A4'],
               'A5': ['A5', 'A5  '],
               'A6': ['A6'],
               'A7': ['A7'],
               'A8': ['A8'],
               'A9': ['A9'],
               'T1': ['T1'],
               'T2': ['T2'],
               'T3': ['T3']},
}

# Specify the cell type we want results for: 'a00c', 'basin' or 'handle'
cell_type = 'basin'
segment_ids = all_segment_ids[cell_type]

# Generate table filter: one entry per (test type, segment), keyed '<test abbrev>_<cell type>_<segment abbrev>'.
# Each entry holds the parameter values a result must have to populate that table row.
tbl_filter = {}
for test_type_ab, test_type in test_types.items():
    for segment_id_ab, cell_ids in segment_ids.items():
        tbl_filter['_'.join([test_type_ab, cell_type, segment_id_ab])] = {'cell_type': cell_type,
                                                                          'man_tgt': man_tgt,
                                                                          'test_type': test_type,
                                                                          'cell_ids': cell_ids,
                                                                          'pool_turns': pool_turns}

# Specify if we show stats for the original models or the mean comparisons
stats_type = 'orig_fit' # 'orig_fit' or 'mean_cmp'

# Level to control for multiple comparisons at
mc_alpha = .05

## Load base results

In [4]:
# Collect every results file from every base folder. Folders are visited in order; within a
# folder the order is whatever glob returns.
all_files = [rs_path
             for folder in base_folders
             for rs_path in glob.glob(str(Path(folder) / rs_file_str))]

In [5]:
def _load_rs_file(path):
    """Unpickle and return the contents of a single results file.

    NOTE: pickle.load can execute arbitrary code, so only load results files from trusted locations.
    """
    with open(path, 'rb') as rs_fh:
        return pickle.load(rs_fh)


# One loaded result per file, in the same order as all_files
all_rs = [_load_rs_file(rs_path) for rs_path in all_files]

## Filter results

In [6]:
def match(rs_l, match_dict):
    """Check whether a loaded result's parameters agree with a filter.

    Args:
        rs_l: A loaded result; its parameters are read from ``rs_l['ps']``.
        match_dict: Maps parameter name -> required value. A required value of None means the
            stored parameter must also be None; any other value must compare equal (a `!=`
            comparison marks a mismatch). For list-valued parameters such as 'cell_ids', Python
            list equality is order-sensitive.

    Returns:
        True if no parameter mismatches, else False. Every key in ``match_dict`` is checked.
    """
    mismatches = [
        (rs_l['ps'][key] is not None) if expected is None else (rs_l['ps'][key] != expected)
        for key, expected in match_dict.items()
    ]
    return not any(mismatches)

In [7]:
# Map each table row key to the index (into all_rs) of the single result matching its filter.
# Rows whose filter matches no result are left out; more than one match is an error.
tbl_matches = dict()
for k, row_filter in tbl_filter.items():
    hit_inds = np.flatnonzero([match(rs_i, row_filter) for rs_i in all_rs])
    if len(hit_inds) > 1:
        raise(RuntimeError('Found multiple matches for key ' + k + '.'))
    if len(hit_inds) == 1:
        tbl_matches[k] = hit_inds[0]

## Get statistical results 

In [8]:
def _orig_fit_tbl_stats(rs_k):
    """Collect coefficients and p-values from the original model fit stored in one result.

    The last entry of each per-variable array is dropped for the table entries ('behs', 'beta',
    'p_vls'). Presumably this is an intercept/offset term -- TODO confirm against the fitting code.
    The full, untrimmed arrays (plus confidence intervals) are also kept as 'all_*'. They are not
    needed for the table but are useful for visualizing the stats of an original fit.
    """
    fit_rs = rs_k['rs']
    all_behs = fit_rs['one_hot_vars_ref']
    all_p_vls = fit_rs['init_fit_stats']['non_zero_p']
    all_beta = fit_rs['init_fit']['beta']
    all_c_ints = fit_rs['init_fit_stats']['c_ints']
    return {'behs': all_behs[0:-1], 'beta': all_beta[0:-1], 'p_vls': all_p_vls[0:-1],
            'all_behs': all_behs, 'all_p_vls': all_p_vls, 'all_beta': all_beta, 'all_c_ints': all_c_ints}


def _mean_cmp_tbl_stats(rs_k):
    """Collect p-values from the mean comparisons stored in one result.

    Mean comparisons carry no coefficient, so 'beta' is all ones. The sign applied to these
    p-values when building tables is therefore always positive.
    """
    cmp_stats = rs_k['rs']['cmp_stats']
    return {'behs': cmp_stats['cmp_vars'],
            'beta': np.ones(len(cmp_stats['cmp_vars'])),
            'p_vls': cmp_stats['cmp_p_vls']}


_TBL_STATS_FCNS = {'orig_fit': _orig_fit_tbl_stats, 'mean_cmp': _mean_cmp_tbl_stats}

# Validate up front so a bad stats_type is reported even when no results matched the table filter
# (previously the check only ran inside the loop, so it was silently skipped in that case).
if stats_type not in _TBL_STATS_FCNS:
    raise(ValueError('stats_type must be orig_fit or mean_cmp'))

# Per table row: variable names ('behs'), coefficients ('beta') and p-values ('p_vls')
tbl_stats = {k: _TBL_STATS_FCNS[stats_type](all_rs[rs_ind]) for k, rs_ind in tbl_matches.items()}


## Apply multiple comparisons adjustment

In [9]:
# Pool the p-values from every table row into one flat array so they can be adjusted jointly.
# map_back_inds[k] holds the positions of row k's p-values within that pooled array.
map_back_inds = dict()
p_vl_chunks = []
offset = 0
for key, key_stats in tbl_stats.items():
    n_key = len(key_stats['p_vls'])
    map_back_inds[key] = offset + np.arange(n_key)
    p_vl_chunks.append(key_stats['p_vls'])
    offset += n_key

all_p_vls = np.concatenate(p_vl_chunks)

In [10]:
# Adjust all pooled p-values jointly (presumably Benjamini-Yekutieli, per the apply_by name), then
# scatter the adjusted values back to their rows using the recorded positions.
_, adjusted_p_vls = apply_by(all_p_vls, mc_alpha)
for key, key_inds in map_back_inds.items():
    tbl_stats[key]['adjusted_p_vls'] = adjusted_p_vls[key_inds]

## Put results in tables

In [11]:
def create_p_vl_tbl(stats, p_vl_str):
    """Build a table of signed p-values, one row per entry of `stats` and one column per variable.

    Each cell holds ``sign(beta) * p`` for that row's variable, so the sign shows the direction of the
    coefficient. Variables a row does not include are NaN. Columns are the union of all rows'
    variables, sorted in descending order. If a row lists the same variable twice, its first
    occurrence is used.

    Args:
        stats: Maps row name -> dict with 'behs' (variable names), 'beta' (coefficients) and the
            p-value array named by `p_vl_str`, all aligned by position.
        p_vl_str: Key of the p-value array to tabulate (e.g. 'p_vls' or 'adjusted_p_vls').

    Returns:
        A pandas DataFrame indexed by row name with one float column per variable.
    """
    rows = list(stats.keys())
    all_behs = sorted(set(itertools.chain.from_iterable(stats[k]['behs'] for k in rows)), reverse=True)
    col_inds = {beh: j for j, beh in enumerate(all_behs)}

    # Fill a plain array and build the frame once. Writing cell-by-cell with chained indexing
    # (tbl[col][row] = v) does not update the frame under pandas Copy-on-Write.
    vls = np.full([len(rows), len(all_behs)], np.nan)
    for i, row in enumerate(rows):
        row_stats = stats[row]
        seen = set()
        for beh_i, beh in enumerate(row_stats['behs']):
            if beh in seen:
                continue  # keep only the first occurrence of a repeated variable
            seen.add(beh)
            sign = np.sign(row_stats['beta'][beh_i])
            vls[i, col_inds[beh]] = sign*row_stats[p_vl_str][beh_i]

    return pd.DataFrame(vls, index=rows, columns=all_behs)

In [12]:
# Signed p-value tables, before and after the multiple-comparisons adjustment
p_vl_tbl = create_p_vl_tbl(stats=tbl_stats, p_vl_str='p_vls')
adj_p_vl_tbl = create_p_vl_tbl(stats=tbl_stats, p_vl_str='adjusted_p_vls')

## Visualize tables

In [13]:
def style_negative(v):
    """Return a CSS rule coloring a cell red when its value is negative, else None.

    In the signed p-value tables, a negative value marks a negative coefficient.
    """
    return 'color:red;' if v < 0 else None
def fade_non_sig(v, alpha=None):
    """Return a CSS rule fading a cell whose absolute value exceeds `alpha`, else None.

    Args:
        v: A (possibly signed) p-value.
        alpha: Significance threshold. Defaults to the notebook-level ``mc_alpha``, looked up at call
            time so later changes to ``mc_alpha`` are still respected.

    NaN entries are not faded, because comparisons with NaN are False.
    """
    if alpha is None:
        alpha = mc_alpha
    return 'opacity: 10%;' if (np.abs(v) > alpha) else None

In [14]:
def _apply_elementwise(styler, func):
    """Apply the CSS-returning `func` to every cell of `styler` and return the new Styler.

    Styler.applymap is deprecated since pandas 2.1 in favor of Styler.map (and slated for removal).
    Use Styler.map when it exists and fall back to applymap on older pandas.
    """
    if hasattr(styler, 'map'):
        return styler.map(func)
    return styler.applymap(func)


def _style_p_vl_tbl(tbl):
    """Style a signed p-value table: negative values in red, entries with |p| > mc_alpha faded."""
    return _apply_elementwise(_apply_elementwise(tbl.style, style_negative), fade_non_sig)


styled_p_vl_tbl = _style_p_vl_tbl(p_vl_tbl)
styled_adj_p_vl_tbl = _style_p_vl_tbl(adj_p_vl_tbl)

In [15]:
# Unadjusted p-values signed by their coefficient: red = negative, faded = |p| > mc_alpha
styled_p_vl_tbl

Unnamed: 0,beh_before_TC,beh_before_F,beh_before_B,beh_after_TC,beh_after_F,beh_after_B
SDAR_basin_A1,0.044685,0.02908,0.000726,-0.873245,0.013814,0.825605
SDAR_basin_A2,0.217003,0.014824,0.148882,-0.890644,0.010778,0.182499
SDAR_basin_A3,0.117308,0.017827,0.025235,-0.088703,0.010409,0.746734
SDAR_basin_A4,0.011175,0.002096,0.086915,-0.825638,9.6e-05,0.138175
SDAR_basin_A6,0.373532,0.020175,0.0089,-0.956012,0.000222,0.797053
SDAR_basin_A7,-0.660129,0.507063,0.565461,0.801036,0.004397,0.31997
SDAR_basin_A8,0.766137,0.083431,0.327621,-0.805082,0.011158,-0.766133
SDAR_basin_A9,0.05302,0.026016,0.007201,-0.131653,0.197061,-0.076474
DD_basin_A1,0.015489,0.011889,0.000176,-0.315387,0.218819,-0.215467
DD_basin_A2,0.198829,0.011476,0.089535,-0.804219,0.28641,0.652355


In [16]:
# Same table after jointly adjusting all p-values with apply_by: red = negative, faded = |p| > mc_alpha
styled_adj_p_vl_tbl

Unnamed: 0,beh_before_TC,beh_before_F,beh_before_B,beh_after_TC,beh_after_F,beh_after_B
SDAR_basin_A1,0.550339,0.383431,0.023944,-1.0,0.221168,1.0
SDAR_basin_A2,1.0,0.230747,1.0,-1.0,0.191736,1.0
SDAR_basin_A3,1.0,0.262894,0.358003,-0.983796,0.188167,1.0
SDAR_basin_A4,0.192691,0.054638,0.974113,-1.0,0.004302,1.0
SDAR_basin_A6,1.0,0.289891,0.16353,-1.0,0.009203,1.0
SDAR_basin_A7,-1.0,1.0,1.0,1.0,0.104853,1.0
SDAR_basin_A8,1.0,0.954152,1.0,-1.0,0.192691,-1.0
SDAR_basin_A9,0.638952,0.36447,0.14159,-1.0,1.0,-0.883606
DD_basin_A1,0.234594,0.19888,0.007586,-1.0,1.0,-1.0
DD_basin_A2,1.0,0.194873,0.983796,-1.0,1.0,1.0
