In [1]:
import torch.nn as nn
import torch
import math
import matplotlib.pyplot as plt
from pathlib import Path
import pprint
import numpy as np
import pandas as pd
import os
from sklearn.decomposition import PCA
from dataset import *
import random
from train import *

In [2]:
from allensdk.core.brain_observatory_cache import BrainObservatoryCache

In [3]:
output_dir = '.'
boc =  BrainObservatoryCache(
    manifest_file=str(Path(output_dir) / 'brain_observatory_manifest.json'))
targeted_structures = boc.get_all_targeted_structures()
print("all targeted structures: " + str(targeted_structures))

all targeted structures: ['VISal', 'VISam', 'VISl', 'VISp', 'VISpm', 'VISrl']


In [4]:
# Download a list of all cre driver lines 
cre_lines = boc.get_all_cre_lines()
print("all cre lines:\n")
pprint.pprint(cre_lines)
cre_lines_to_use = [
    'Cux2-CreERT2',
    'Emx1-IRES-Cre',
    'Fezf2-CreER',
    'Nr5a1-Cre',
    'Ntsr1-Cre_GN220',
    'Rbp4-Cre_KL100',
    'Rorb-IRES2-Cre',
    'Scnn1a-Tg3-Cre',
    'Slc17a7-IRES2-Cre',
    'Tlx3-Cre_PL56',
]

all cre lines:

['Cux2-CreERT2',
 'Emx1-IRES-Cre',
 'Fezf2-CreER',
 'Nr5a1-Cre',
 'Ntsr1-Cre_GN220',
 'Pvalb-IRES-Cre',
 'Rbp4-Cre_KL100',
 'Rorb-IRES2-Cre',
 'Scnn1a-Tg3-Cre',
 'Slc17a7-IRES2-Cre',
 'Sst-IRES-Cre',
 'Tlx3-Cre_PL56',
 'Vip-IRES-Cre']


Download Experiments for a Container

An experiment container is a group of experiments. Each experiment has a different stimulus protocol. For example, one experiment protocol contains the static gratings stimulus and another has the natural scenes stimulus. The BrainObservatoryCache helps you find out which experiment associated with a container has the stimuli you are interested in. First, let's see what experiments are available for a single container.

The session_type field indicates which experimental protocol was used. If you just want to find which experiment contains the static gratings stimulus, you can do the following:

In [5]:
# import allensdk.brain_observatory.stimulus_info as stim_info

# # pick one of the cux2 experiment containers
# cux2_ec_id = cux2_ecs[-1]['id']

# # Find the experiment with the static static gratings stimulus
# exp = boc.get_ophys_experiments(experiment_container_ids=[cux2_ec_id], 
#                                 stimuli=[stim_info.STATIC_GRATINGS])[0]
# print("Experiment with static gratings:")
# pprint.pprint(exp)

Now we can download the NWB file for this experiment.

In [5]:
exps = get_exps(boc, cre_lines=cre_lines_to_use, targeted_structures=None, session_types=['three_session_B'])
len(exps)


336

In [27]:
data_set = boc.get_ophys_experiment_data(exps[19]['id'])
cids = data_set.get_cell_specimen_ids()
len(cids)

214

In [6]:
pre = 15
post = 7
dataset = prep_dataset(boc, exps[22:23], pre, post, data_type='dff')

In [None]:
data2plot = 100000
ax = plot_traces(dataset['model_input'][data2plot].numpy(), np.arange(45), input_type='pca', figsize=(2,10))
ax.set_title(f"image#{dataset['model_labels'][data2plot].numpy()}")
ax.axvline(x = 30, color = 'r', linestyle = '--')

test code for training

In [68]:
rand_idx = np.random.permutation(len(dataset['model_input']))
num_training_sample = int(len(dataset['model_input'])*0.7)
train_idx = rand_idx[:num_training_sample]
test_idx = rand_idx[num_training_sample:]
train_dataset = {'model_input' :[dataset['model_input'][i] for i in train_idx],
                    'model_labels':[dataset['model_labels'][i] for i in train_idx]
                }
test_dataset = {'model_input' :[dataset['model_input'][i] for i in test_idx],
                    'model_labels':[dataset['model_labels'][i] for i in test_idx]
                }

In [67]:
print(len(rand_idx))
print(len(train_idx))
print(len(test_idx))

1987300
1391110
596190


In [70]:
batch_idx = np.arange(0, 128)
batch_idx

array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
        78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,
        91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103,
       104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116,
       117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127])

In [75]:
batch_X = [train_dataset['model_input'][i].T for i in batch_idx]
batch_X = torch.stack(batch_X)
batch_X.shape

torch.Size([128, 45, 50])

In [80]:
batch_Y = [train_dataset['model_labels'][i] for i in batch_idx]
batch_Y = torch.stack(batch_Y)
batch_Y.shape

torch.Size([128])

In [30]:
batch_idx = random.sample(list(range(5950)), 128)
batch_idx

[5392,
 280,
 2858,
 4998,
 2488,
 544,
 2658,
 2992,
 3287,
 2665,
 5464,
 2447,
 5117,
 1199,
 179,
 4683,
 5048,
 5072,
 2750,
 4756,
 431,
 3397,
 2035,
 875,
 2408,
 80,
 2253,
 658,
 4391,
 2309,
 2297,
 4121,
 2008,
 4337,
 4389,
 5081,
 3817,
 2329,
 2730,
 3309,
 5128,
 3948,
 2264,
 3244,
 5848,
 4533,
 830,
 4504,
 4124,
 4157,
 5186,
 4882,
 2756,
 3005,
 5316,
 5288,
 331,
 5523,
 4764,
 953,
 5537,
 2726,
 3158,
 5263,
 3088,
 702,
 99,
 1916,
 5106,
 2117,
 4418,
 475,
 1985,
 3688,
 5627,
 3809,
 4687,
 5582,
 5353,
 5269,
 3293,
 4572,
 4352,
 3728,
 423,
 3349,
 180,
 5645,
 5690,
 3146,
 5152,
 3659,
 5708,
 5891,
 5376,
 1164,
 3667,
 2273,
 1318,
 2721,
 1952,
 4090,
 5404,
 4586,
 3665,
 1454,
 977,
 92,
 865,
 3872,
 5548,
 2531,
 1444,
 3978,
 5677,
 3803,
 2416,
 2023,
 3780,
 1370,
 2170,
 3868,
 654,
 3105,
 4505,
 5362,
 3595,
 4943]