In [5]:
import os
import pickle

import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.gridspec as gs
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm, trange

from lib.data import conditions as cond
from lib.data import tables as tab
from lib import bootstrap as bs
from lib.features import correlation as corr
from lib.features import decomposition as de

In [2]:
# Open the database handle used by every query in this notebook.
# The connection URL is read from the SQLALCHEMY_ENGINE_URL environment
# variable, so no credentials are written into the notebook itself; a missing
# variable raises KeyError on this line.
# NOTE(review): `tab.sa` is presumably the sqlalchemy package re-exported by
# lib.data.tables — confirm.
engine = tab.sa.create_engine(os.environ['SQLALCHEMY_ENGINE_URL'])
# Session factory bound to the engine created above.
Session = tab.sa.orm.sessionmaker(bind=engine)
# One session object, shared by the sampler and query cells below.
session = Session()

In [3]:
# Callable that yields the mouse IDs to use for one bootstrap replicate.
sample_mouseids = bs.get_mouseid_sampler(session, cond.CellType.pyr)

# Fetch the CellType.pyr mouse IDs once and reuse them below. The lookup used
# to be repeated for `sample_cellids` and again for every day in
# `sample_averagetr`; materialising it a single time avoids the redundant
# calls and guarantees that both dictionaries share exactly the same keys.
pyr_mouseids = list(bs.get_mouseids_by_celltype(session, cond.CellType.pyr))

# mouse ID -> zero-argument callable returning that mouse's day-one cell IDs.
# `mid=mid` freezes the current mouse ID as a default argument; without it
# every lambda would see the last `mid` of the comprehension (late binding).
sample_cellids = {
    mid: (lambda mid=mid: bs.get_day_one_cell_ids(session, mid))
    for mid in pyr_mouseids
}

# day -> mouse ID -> zero-argument sampler for that mouse and day, built on
# top of the matching cell-ID callable from `sample_cellids`.
sample_averagetr = {
    day: {
        mid: bs.get_averagetrace_given_cellid_sampler(session, sample_cellids[mid], mid, day)
        for mid in pyr_mouseids
    }
    for day in (1, 7)
}

In [6]:
# Basis functions describing the structure of a trial.
# NOTE(review): 390 is presumably the number of samples per trial and the
# durations are presumably in seconds — confirm against
# lib.features.decomposition.TrialBasisFunctions.
trial_structure = de.TrialBasisFunctions(
    390,
    tone_duration=1.0,
    delay_duration=1.5,
    reward_duration=2.5,
)

# Label -> basis vector that each cell's trace is correlated against.
# The tone and delay bases are summed into a single combined component.
trial_bases = dict([
    ('tone + delay', trial_structure.tone + trial_structure.delay),
    ('reward', trial_structure.reward),
])

In [7]:
# Build a long-format list of records: one row per
# (bootstrap sample, mouse, trial component, day label, cell statistic).
# Day labels are '1', '7' and 'delta' (day 7 minus day 1).
#
# NOTE(review): the diagnostic cells further down this notebook display NaN
# and ~1e36 values in the sampled traces and all-NaN correlations, so the
# 'spearman' values collected here should be validated before use.
bootstrap_records = []
for i in trange(10, desc='Bootstrap samples', leave=True):
    mouseids = sample_mouseids()
    for mid in tqdm(mouseids, desc='Mice', leave=False):
        for label, basis in trial_bases.items():
            # The samplers are invoked separately for each day and each
            # trial component, exactly as before.
            sp_stats_d1 = corr.vectorized_spearman_corr(sample_averagetr[1][mid](), basis)
            sp_stats_d7 = corr.vectorized_spearman_corr(sample_averagetr[7][mid](), basis)
            # Cell-by-cell change in tuning. Assumes sample_averagetr returns
            # cells in same order.
            sp_stats_delta = sp_stats_d7 - sp_stats_d1

            # Emit the three groups in the fixed order day 1, day 7, delta.
            for day_label, sp_stats in (
                ('1', sp_stats_d1),
                ('7', sp_stats_d7),
                ('delta', sp_stats_delta),
            ):
                bootstrap_records.extend(
                    {
                        'bs_sample': i,
                        'mouse_id': mid,
                        'day': day_label,
                        'trial_component': label,
                        'spearman': sp_stat,
                    }
                    for sp_stat in sp_stats
                )

HBox(children=(FloatProgress(value=0.0, description='Bootstrap samples', max=10.0, style=ProgressStyle(descrip…

HBox(children=(FloatProgress(value=0.0, description='Mice', max=6.0, style=ProgressStyle(description_width='in…

HBox(children=(FloatProgress(value=0.0, description='Mice', max=6.0, style=ProgressStyle(description_width='in…

HBox(children=(FloatProgress(value=0.0, description='Mice', max=6.0, style=ProgressStyle(description_width='in…

HBox(children=(FloatProgress(value=0.0, description='Mice', max=6.0, style=ProgressStyle(description_width='in…

HBox(children=(FloatProgress(value=0.0, description='Mice', max=6.0, style=ProgressStyle(description_width='in…

HBox(children=(FloatProgress(value=0.0, description='Mice', max=6.0, style=ProgressStyle(description_width='in…

HBox(children=(FloatProgress(value=0.0, description='Mice', max=6.0, style=ProgressStyle(description_width='in…

HBox(children=(FloatProgress(value=0.0, description='Mice', max=6.0, style=ProgressStyle(description_width='in…

HBox(children=(FloatProgress(value=0.0, description='Mice', max=6.0, style=ProgressStyle(description_width='in…

HBox(children=(FloatProgress(value=0.0, description='Mice', max=6.0, style=ProgressStyle(description_width='in…




In [22]:
# Scratch check: stage a throw-away SessionTrace row and read its `trace`
# attribute back through a query (the query triggers an autoflush of the
# pending row).
#
# The probe is always rolled back. Without that, a failed flush (e.g. the
# foreign-key IntegrityError shown in this cell's recorded output) leaves the
# shared session unusable until rollback() is called, and a probe row that did
# flush successfully would stay pending and be written by any later commit.
# NOTE(review): rollback() also discards any other uncommitted work on the
# shared session; the visible cells only read from it.
probe_row = tab.SessionTrace(cell_id=1, day=90, trace=np.array([1, 2, 3, 4]))
session.add(probe_row)
try:
    probe_trace = (
        session.query(tab.SessionTrace)
        .filter(tab.SessionTrace.cell_id==1, tab.SessionTrace.day==90)
        .first()
        .trace
    )
finally:
    session.rollback()
probe_trace


InvalidRequestError: This Session's transaction has been rolled back due to a previous exception during flush. To begin a new transaction with this Session, first issue Session.rollback(). Original exception was: (raised as a result of Query-invoked autoflush; consider using a session.no_autoflush block if this flush is occurring prematurely)
(MySQLdb._exceptions.IntegrityError) (1452, 'Cannot add or update a child row: a foreign key constraint fails (`main`.`sessiontraces`, CONSTRAINT `sessiontraces_ibfk_1` FOREIGN KEY (`cell_id`) REFERENCES `cells` (`id`))')
[SQL: INSERT INTO `SessionTraces` (cell_id, day, trace) VALUES (%s, %s, %s)]
[parameters: (99999, 90, b'\x00\x00\x80?\x00\x00\x00@\x00\x00@@\x00\x00\x80@')]
(Background on this error at: http://sqlalche.me/e/13/gkpj) (Background on this error at: http://sqlalche.me/e/13/7s2a)

In [23]:
# Clear the session's failed transaction. The InvalidRequestError recorded in
# the preceding cell's output states that a new transaction cannot begin on
# this session until Session.rollback() has been issued.
session.rollback()

  util.warn(


In [14]:
# Debug peek: draw one day-1 sample for a single mouse and show its first row.
# NOTE(review): `mid` is whatever value the bootstrap loop left bound, so this
# cell fails on a fresh kernel unless that loop has already run.
day1_sample = sample_averagetr[1][mid]()
day1_sample[0, :]

array([-8.66271304e+36,             nan, -6.90503423e+36, -1.42227825e+36,
       -1.04707214e+26,  5.81825453e+34,  4.26285391e+26,             nan,
        3.17870541e+27,             nan,  2.24654834e+29, -5.40253757e+32,
        2.95857583e+28,  3.39742837e+21,  4.16529131e+31,  3.27743316e+35,
       -1.85641924e+34,  9.64419186e+35,  3.76292150e+36, -1.49746553e+36,
        1.56346474e+19,  4.72612222e+35, -1.07946987e+34,             nan,
        4.67615275e+32,  1.36240047e+36,  2.47017119e+34,  7.35565458e+33,
       -3.32881695e+36, -7.66798351e+34, -2.43239293e+35,  1.19937277e+32,
       -4.92215464e+13,  4.60685804e+32, -3.03265289e+28,             nan,
       -2.25714550e+23,  6.45144137e+30,  9.70047644e+31,  8.39378809e+33,
       -1.50420678e+35, -3.56144957e+35,  1.57026607e+36, -3.90182965e+35,
       -1.27023951e+29, -8.14857978e+35,  2.16218104e+28,  3.67750132e+33,
       -5.68825427e+31, -1.90063367e+34,  1.06785284e+32,  5.58533022e+36,
       -1.90346742e+28, -

In [11]:
# Debug check: correlate one day-1 sample against a single trial basis.
# NOTE(review): `mid` and `basis` are leftovers from the bootstrap loop (its
# last mouse and last trial component); this cell fails on a fresh kernel
# unless that loop has already run.
day1_traces = sample_averagetr[1][mid]()
corr.vectorized_spearman_corr(day1_traces, basis)

array([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, na

In [8]:
# Convert the accumulated bootstrap records (a list of dicts) into one
# long-format DataFrame; the dict keys become the columns.
bs_df = pd.DataFrame(data=bootstrap_records)

In [9]:
# Preview the first five bootstrap records.
# NOTE(review): the recorded output shows the 'spearman' column empty for
# these rows, which matches the all-NaN correlations seen in the debug cells.
bs_df.head(n=5)

Unnamed: 0,bs_sample,mouse_id,day,trial_component,spearman
0,0,CL184,1,tone + delay,
1,0,CL184,1,tone + delay,
2,0,CL184,1,tone + delay,
3,0,CL184,1,tone + delay,
4,0,CL184,1,tone + delay,


In [5]:
def foo(val):
    """Return a zero-argument function that prints ``'foo' + val``.

    The concatenated string is computed once, when ``foo`` is called, and is
    captured by the returned closure, so each returned function keeps its own
    message regardless of later calls to ``foo``.
    """
    message = 'foo' + val

    def bar():
        # Print the message captured when this closure was created.
        print(message)

    return bar

# Closure sanity check: build two independent printers, then call them in the
# order bar, baz, bar to confirm that creating the second closure does not
# change what the first one prints.
foobar, foobaz = foo('bar'), foo('baz')
for printer in (foobar, foobaz, foobar):
    printer()

foobar
foobaz
foobar


In [4]:
# Spot check: look up the day-one cell-ID callable for mouse 'CL174' and
# invoke it to list that mouse's cell IDs.
cl174_cellids = sample_cellids['CL174']
cl174_cellids()

[945,
 946,
 947,
 948,
 949,
 950,
 951,
 952,
 953,
 954,
 955,
 956,
 957,
 958,
 959,
 960,
 961,
 962,
 963,
 964,
 965,
 966,
 967,
 968,
 969,
 970,
 971,
 972,
 973,
 974,
 975,
 976,
 977,
 978,
 979,
 980,
 981,
 982,
 983,
 984,
 985,
 986,
 987,
 988,
 989,
 990,
 991,
 992,
 993,
 994,
 995,
 996,
 997,
 998,
 999,
 1000,
 1001,
 1002,
 1003,
 1004,
 1005,
 1006,
 1007,
 1008,
 1009,
 1010,
 1011,
 1012,
 1013,
 1014,
 1015,
 1016,
 1017,
 1018,
 1019,
 1020,
 1021,
 1022,
 1023,
 1024,
 1025,
 1026,
 1027,
 1028,
 1029,
 1030,
 1031,
 1032,
 1033,
 1034,
 1035,
 1036,
 1037,
 1038,
 1039,
 1040,
 1041,
 1042,
 1043,
 1044,
 1045,
 1046,
 1047,
 1048,
 1049,
 1050,
 1051,
 1052,
 1053,
 1054,
 1055,
 1056,
 1057,
 1058,
 1059,
 1060,
 1061,
 1062,
 1063,
 1064,
 1065,
 1066,
 1067,
 1068,
 1069,
 1070,
 1071,
 1072,
 1073,
 1074,
 1075,
 1076,
 1077,
 1078,
 1079,
 1080,
 1081,
 1082,
 1083,
 1084,
 1085,
 1086,
 1087,
 1088,
 1089,
 1090,
 1091,
 1092,
 1093,
 1094,
 1095

In [5]:
# Spot check: fetch the first stored SessionTrace for cell 999 on day 1 and
# display its raw trace array. Raises AttributeError if no such row exists,
# because first() then returns None.
trace_row = (
    session.query(tab.SessionTrace)
    .filter(tab.SessionTrace.day == 1)
    .filter(tab.SessionTrace.cell_id == 999)
    .first()
)
trace_row.trace

array([-7.8816150e+28, -1.3762600e+00,  1.2028697e-14, ...,
       -1.3949311e+00, -2.6179599e-03, -1.4392928e+00], dtype=float32)

In [11]:
# Introspection helper: list the attributes of the SessionTrace model class
# (the recorded output includes `cell`, `cell_id`, `day`, `metadata`, `trace`).
# NOTE(review): leftover exploration cell; nothing later in the visible
# notebook uses its result.
dir(tab.SessionTrace)

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__mapper__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__table__',
 '__tablename__',
 '__weakref__',
 '_decl_class_registry',
 '_sa_class_manager',
 'cell',
 'cell_id',
 'day',
 'metadata',
 'trace']