In [1]:
import torch.nn as nn
import torch
import math
import matplotlib.pyplot as plt
from pathlib import Path
import pprint
import numpy as np
import pandas as pd
import os
from sklearn.decomposition import PCA
from dataset import *
import random
from train import *
from sklearn.cross_decomposition import CCA

In [2]:
from allensdk.core.brain_observatory_cache import BrainObservatoryCache

In [3]:
output_dir = '.'
boc =  BrainObservatoryCache(
    manifest_file=str(Path(output_dir) / 'brain_observatory_manifest.json'))
targeted_structures = boc.get_all_targeted_structures()
print("all targeted structures: " + str(targeted_structures))

all targeted structures: ['VISal', 'VISam', 'VISl', 'VISp', 'VISpm', 'VISrl']


In [4]:
# Download a list of all cre driver lines 
cre_lines = boc.get_all_cre_lines()
print("all cre lines:\n")
pprint.pprint(cre_lines)
cre_lines_to_use = [
    'Cux2-CreERT2',
    'Emx1-IRES-Cre',
    'Fezf2-CreER',
    'Nr5a1-Cre',
    'Ntsr1-Cre_GN220',
    'Rbp4-Cre_KL100',
    'Rorb-IRES2-Cre',
    'Scnn1a-Tg3-Cre',
    'Slc17a7-IRES2-Cre',
    'Tlx3-Cre_PL56',
]

all cre lines:

['Cux2-CreERT2',
 'Emx1-IRES-Cre',
 'Fezf2-CreER',
 'Nr5a1-Cre',
 'Ntsr1-Cre_GN220',
 'Pvalb-IRES-Cre',
 'Rbp4-Cre_KL100',
 'Rorb-IRES2-Cre',
 'Scnn1a-Tg3-Cre',
 'Slc17a7-IRES2-Cre',
 'Sst-IRES-Cre',
 'Tlx3-Cre_PL56',
 'Vip-IRES-Cre']


Download Experiments for a Container

An experiment container is a group of experiments. Each experiment has a different stimulus protocol. For example, one experiment protocol contains the static gratings stimulus and another has the natural scenes stimulus. The BrainObservatoryCache helps you find out which experiment associated with a container has the stimuli you are interested in. First, let's see what experiments are available for a single container.

The session_type field indicates which experimental protocol was used. If you just want to find which experiment contains the static gratings stimulus, you can do the following:

In [5]:
# import allensdk.brain_observatory.stimulus_info as stim_info

# # pick one of the cux2 experiment containers
# cux2_ec_id = cux2_ecs[-1]['id']

# # Find the experiment with the static static gratings stimulus
# exp = boc.get_ophys_experiments(experiment_container_ids=[cux2_ec_id], 
#                                 stimuli=[stim_info.STATIC_GRATINGS])[0]
# print("Experiment with static gratings:")
# pprint.pprint(exp)

Now we can download the NWB file for this experiment.

In [5]:
exps = get_exps(boc, cre_lines=cre_lines_to_use, targeted_structures=['VISp'], session_types=['three_session_B'])
len(exps)


94

In [13]:
numCell = []
for exp in exps:
    data_set = boc.get_ophys_experiment_data(exp['id'])
    cids = data_set.get_cell_specimen_ids()
    numCell.append(len(cids))
np.argmax(numCell)


84

In [30]:
numCell[4]

20

In [6]:
pre = 15
post = 7
dataset = prep_dataset(boc, exps[22:23], pre, post, data_type='dff')

In [None]:
data2plot = 100000
ax = plot_traces(dataset['model_input'][data2plot].numpy(), np.arange(45), input_type='pca', figsize=(2,10))
ax.set_title(f"image#{dataset['model_labels'][data2plot].numpy()}")
ax.axvline(x = 30, color = 'r', linestyle = '--')

test code for training

In [68]:
rand_idx = np.random.permutation(len(dataset['model_input']))
num_training_sample = int(len(dataset['model_input'])*0.7)
train_idx = rand_idx[:num_training_sample]
test_idx = rand_idx[num_training_sample:]
train_dataset = {'model_input' :[dataset['model_input'][i] for i in train_idx],
                    'model_labels':[dataset['model_labels'][i] for i in train_idx]
                }
test_dataset = {'model_input' :[dataset['model_input'][i] for i in test_idx],
                    'model_labels':[dataset['model_labels'][i] for i in test_idx]
                }

In [6]:
pca_comp = 20
exp = exps[19]
meta = boc.get_ophys_experiment_data(exp['id']).get_metadata()
cellnum = len(boc.get_ophys_experiment_data(exp['id']).get_cell_specimen_ids())
print(f'num of cell: {cellnum}')
pre = 15
post = 7
dff = get_fluo(boc, exp)
pca_dff, explained_var = pca_and_pad(dff, num_comp=pca_comp)
print(f'num comp: {pca_comp}, explained variance: {explained_var:.2f}')
stim_df = get_stim_df(boc, exp, stimulus_name='natural_scenes')
ref_data, ref_labels = extract_data_by_images(pca_dff, stim_df, pre, post)
ref_labels  = [118 if x == -1 else x for x in ref_labels]
sort_idx = np.argsort(ref_labels)
sort_data = [ref_data[i] for i in sort_idx]
sort_labels = [ref_labels[i] for i in sort_idx]
# sampled_data = []
# for i in range(119):
#     sampled_data = sampled_data + sort_data[i*50:i*50+20]
ref_concat_data = np.concatenate(sort_data, axis=1)
ref_concat_data.shape

num of cell: 331
num comp: 20, explained variance: 0.67


(20, 178500)

In [7]:
pca_comp = 20
exp = exps[3]
meta = boc.get_ophys_experiment_data(exp['id']).get_metadata()
cellnum = len(boc.get_ophys_experiment_data(exp['id']).get_cell_specimen_ids())
print(f'num of cell: {cellnum}')
pre = 15
post = 7
dff = get_fluo(boc, exp)
pca_dff, explained_var = pca_and_pad(dff, num_comp=pca_comp)
print(f'num comp: {pca_comp}, explained variance: {explained_var:.2f}')
stim_df = get_stim_df(boc, exp, stimulus_name='natural_scenes')
data, labels = extract_data_by_images(pca_dff, stim_df, pre, post)
labels  = [118 if x == -1 else x for x in labels]
sort_idx = np.argsort(labels)
sort_data = [data[i] for i in sort_idx]
sort_labels = [labels[i] for i in sort_idx]
# sampled_data = []
# for i in range(119):
#     sampled_data = sampled_data + sort_data[i*50:i*50+20]
concat_data = np.concatenate(sort_data, axis=1)
concat_data.shape

num of cell: 249
num comp: 20, explained variance: 0.33


(20, 178500)

In [8]:
pca_comp_corr = []
for i in range(concat_data.shape[0]):
    pca_comp_corr.append(np.corrcoef(ref_concat_data[i,:], concat_data[i,:])[0,1])
np.mean(pca_comp_corr)

0.00091427237905062

In [9]:
cca = CCA(n_components=ref_concat_data.shape[0])

cca.fit(ref_concat_data.T, concat_data.T)

trans_data = concat_data.T.dot(cca.y_rotations_).dot(np.linalg.inv(cca.x_rotations_)).T

In [10]:
return_data, return_labels = cca_align(ref_data, ref_labels, data, labels)

In [12]:
len(return_data)

5950

In [None]:
plot_traces(trans_data, x_range=range(2000), input_type='pca',figsize=(5,5))

In [11]:
pca_comp_corr = []
for i in range(concat_data.shape[0]):
    pca_comp_corr.append(np.corrcoef(ref_concat_data[i,:], trans_data[i,:])[0,1])
np.mean(pca_comp_corr)

0.027901439423242202