In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os
import xarray
import pickle
import torch

from zfa.core.default_dirs import DATA_ROOT

from dandi.dandiapi import DandiAPIClient
from pynwb import NWBHDF5IO
from nwbwidgets import nwb2widget

%matplotlib inline
%load_ext autoreload
%autoreload 2


assert torch.cuda.is_available(), print('Not on a compute node. Do not run anything.')

In [2]:
# NWB session files (presumably from DANDI dandiset 000350, given the path).
# Per the original note, these sessions contain both neurons and glia.
sesh1, sesh2, sesh3, sesh4 = (
    '/om2/group/fiete/zfa/000350/sub-20161109-2/sub-20161109-2_ses-20161109T211950_ophys.nwb',
    '/om2/group/fiete/zfa/000350/sub-20170203-1/sub-20170203-1_ses-20170203T122038_ophys.nwb',
    '/om2/group/fiete/zfa/000350/sub-20170228-3/sub-20170228-3_ses-20170228T165730_ophys.nwb',
    '/om2/group/fiete/zfa/000350/sub-20170228-4/sub-20170228-4_ses-20170228T185002_ophys.nwb',
)



In [16]:
glial_save_dir = '/om2/group/yanglab/zfa/'
neural_save_dir = '/om2/group/yanglab/zfa/'

# Sanity check: reload the pickled glial and neural trial dictionaries.
# NOTE: pickle.load can execute arbitrary code; only load files we produced.
glial_pickle_path = os.path.join(glial_save_dir, 'glial_trials.pickle')
with open(glial_pickle_path, 'rb') as handle:
    glial_trials = pickle.load(handle)

neural_pickle_path = os.path.join(neural_save_dir, 'neural_trials.pickle')
with open(neural_pickle_path, 'rb') as handle:
    neural_trials = pickle.load(handle)

In [4]:
# The inter-animal analyses below run on the glial tensors; their keys
# (session names) define the list of animals.
brain_data = glial_trials['tensors']
ANIMALS = list(brain_data)

In [5]:
def get_unit_ranges(data, processing_chunk_size):
    """Split the unit axis of ``data`` into contiguous ``(start, stop)`` chunks.

    The unit axis is assumed to be the LAST axis of ``data`` (``data.shape[-1]``).
    Every chunk holds ``processing_chunk_size`` units except possibly the last,
    which holds the remainder. No empty chunk is produced: when the number of
    units is an exact multiple of the chunk size there is no residual range,
    and an array with zero units yields an empty list.

    Parameters
    ----------
    data : array-like with a ``.shape`` attribute
        Array whose last axis indexes units.
    processing_chunk_size : int
        Maximum number of units per chunk; must be positive.

    Returns
    -------
    list of tuple(int, int)
        Half-open ``(start, stop)`` unit index ranges covering all units in order.

    Raises
    ------
    ValueError
        If ``processing_chunk_size`` is not positive.
    """
    if processing_chunk_size <= 0:
        raise ValueError(
            f"processing_chunk_size must be positive, got {processing_chunk_size}"
        )

    num_units = data.shape[-1]
    # The last chunk is clipped to num_units, so a residual chunk appears only
    # when there actually are leftover units (no spurious (n, n) empty job).
    return [
        (start, min(start + processing_chunk_size, num_units))
        for start in range(0, num_units, processing_chunk_size)
    ]

In [44]:
from zfa.core.default_dirs import BASE_DIR

# For every animal, record the unit counts, the (start, stop) unit ranges per
# processing chunk, and the resulting number of jobs, for glia and neurons.
processing_chunk_size = 10000

job_info_for_chunking = {}

for animal in ANIMALS:
    print(animal)
    glial_tensor = glial_trials['tensors'][animal]
    neural_tensor = neural_trials['tensors'][animal]

    # unit index ranges to process, one entry per job
    glial_units = get_unit_ranges(glial_tensor, processing_chunk_size=processing_chunk_size)
    neural_units = get_unit_ranges(neural_tensor, processing_chunk_size=processing_chunk_size)
    glial_num_jobs, neural_num_jobs = len(glial_units), len(neural_units)

    job_info_for_chunking[animal] = {
        'glial_num_units': glial_tensor.sizes['units'],
        'neural_num_units': neural_tensor.sizes['units'],
        'glial_units': glial_units,
        'neural_units': neural_units,
        'glial_num_jobs': glial_num_jobs,
        'neural_num_jobs': neural_num_jobs,
        'glial_processing_chunk_size': processing_chunk_size,
        'neural_processing_chunk_size': processing_chunk_size,
    }



sub-20170228-3_ses-20170228T165730_ophys
sub-20170228-4_ses-20170228T185002_ophys


In [43]:
# Persist the per-animal chunking info to BASE_DIR.
# NOTE(review): this cell's execution count precedes the cell that builds
# job_info_for_chunking; re-run it after that cell to save the latest dict.
job_info_path = os.path.join(BASE_DIR, 'job_info_for_chunking.pickle')
with open(job_info_path, 'wb') as handle:
    pickle.dump(job_info_for_chunking, handle, pickle.HIGHEST_PROTOCOL)

In [46]:
# Redundant with ANIMALS above: job_info_for_chunking was keyed by that same
# list, so this just re-reads the animal names from the saved job info.
ANIMALS = list(job_info_for_chunking)

['sub-20170228-3_ses-20170228T165730_ophys',
 'sub-20170228-4_ses-20170228T185002_ophys']

# Inter-Animal Consistency

In [4]:
from sklearn.model_selection import KFold
from brainmodel_utils.core.constants import RIDGECV_ALPHA_CV
from brainmodel_utils.neural_mappers.utils import (
    generate_train_test_splits,
    convert_dict_to_tuple,
)
from brainmodel_utils.metrics.consistency import get_linregress_consistency
import itertools

In [26]:
# Quick interactive run of inter-animal consistency on a small subset of the
# glial data. The first NUM_STIM indices along `time` act as the "stimuli"
# that are split into train/test; only the first NUM_UNITS_SUBSET units are
# used so the run stays short.
# NOTE(review): `[:, :, :NUM_UNITS_SUBSET]` assumes units is the last axis
# (get_unit_ranges makes the same assumption, and the saved `results` output
# shows a `units: 20` dim) - confirm.
# NOTE(review): with NUM_STIM=20 and TRAIN_FRAC=0.5 each of the 5 CV folds
# holds out only 2 stimuli (see the `cv_splits` output), so per-fold
# correlations are very unstable.
NUM_STIM = 20
NUM_UNITS_SUBSET = 20
NUM_SPLITS = 5
TRAIN_FRAC = 0.5
NUM_BOOTSTRAP_ITERS = 10
num_cv_splits = 5

splits = generate_train_test_splits(
    num_stim=NUM_STIM, num_splits=NUM_SPLITS, train_frac=TRAIN_FRAC
)

# Slice each animal once rather than once per (split, animal pair).
resp_subset = {
    animal: brain_data[animal][:, :, :NUM_UNITS_SUBSET] for animal in ANIMALS
}

results_list = []
for s in splits:
    results = {}
    for source_animal, target_animal in itertools.permutations(ANIMALS, r=2):
        # cross-validation only ever sees the train stimuli of this split
        target_resp_train = resp_subset[target_animal].isel(time=s["train"])
        source_resp_train = resp_subset[source_animal].isel(time=s["train"])

        kf = KFold(n_splits=num_cv_splits)
        cv_splits = [
            {"train": cv_train_idx, "test": cv_val_idx}
            for cv_train_idx, cv_val_idx in kf.split(
                X=source_resp_train.mean(dim="trials", skipna=True)
            )
        ]

        target_results = results.setdefault(target_animal, {})

        for alpha in RIDGECV_ALPHA_CV:
            map_kwargs = {
                "map_type": "sklinear",
                "map_kwargs": {
                    "regression_type": "Ridge",
                    "regression_kwargs": {"alpha": alpha},
                },
            }
            # turn it into immutable tuple to store as a key
            map_kwargs_key = convert_dict_to_tuple(map_kwargs)
            target_results[map_kwargs_key] = get_linregress_consistency(
                source=source_resp_train,
                target=target_resp_train,
                map_kwargs=map_kwargs,
                num_bootstrap_iters=NUM_BOOTSTRAP_ITERS,
                num_parallel_jobs=1,
                splits=cv_splits,
            )

    # Exactly one entry per split (matching perform_cv). Appending inside the
    # animal-pair loop stored the same dict once per pair.
    results_list.append(results)


  return (2.0 * r) / (1.0 + r)
  return (2.0 * r) / (1.0 + r)
  return (2.0 * r) / (1.0 + r)
  return (2.0 * r) / (1.0 + r)
  return (2.0 * r) / (1.0 + r)
  dual_coef = linalg.solve(K, y, assume_a="pos", overwrite_a=False)
  return (2.0 * r) / (1.0 + r)
  return (2.0 * r) / (1.0 + r)
  return (2.0 * r) / (1.0 + r)
  return (2.0 * r) / (1.0 + r)
  return (2.0 * r) / (1.0 + r)
  return (2.0 * r) / (1.0 + r)
  return (2.0 * r) / (1.0 + r)
  return (2.0 * r) / (1.0 + r)
  return (2.0 * r) / (1.0 + r)
  return (2.0 * r) / (1.0 + r)
  return (2.0 * r) / (1.0 + r)
  return (2.0 * r) / (1.0 + r)
  return (2.0 * r) / (1.0 + r)
  return (2.0 * r) / (1.0 + r)
  return (2.0 * r) / (1.0 + r)
  return (2.0 * r) / (1.0 + r)
  return (2.0 * r) / (1.0 + r)
  return (2.0 * r) / (1.0 + r)
  return (2.0 * r) / (1.0 + r)
  return (2.0 * r) / (1.0 + r)
  return (2.0 * r) / (1.0 + r)
  return (2.0 * r) / (1.0 + r)
  return (2.0 * r) / (1.0 + r)
  return (2.0 * r) / (1.0 + r)
  return (2.0 * r) / (1.0 + r)
  

KeyboardInterrupt: 

In [30]:
# Display the results dict built for the most recent split of the loop above
# (if that loop was interrupted, this is the partially filled split).
results

{'sub-20170228-4_ses-20170228T185002_ophys': {(('map_type', 'sklinear'),
   ('map_kwargs',
    (('regression_type', 'Ridge'),
     ('regression_kwargs', (('alpha', 1e-09),))))): {'train': defaultdict(list,
               {'r_xy_n_sb': <xarray.DataArray (trial_bootstrap_iters: 10, train_test_splits: 5, units: 20)>
                array([[[  2.32653135,   2.2481547 ,  29.17836965,   1.01304877,
                           1.01720292,   1.01947972,   1.01337736,   1.0138568 ,
                           1.03766139,   1.0310176 ,   1.01636091,   1.05197426,
                           1.0310622 ,   1.01051069,   1.01355136,   1.01281055,
                           1.00826759,   1.03492238,   1.05183938,   1.00796609],
                        [  3.18349756,   1.65757765,  17.04095914,   1.02653482,
                           1.01409864,   1.01649484,   1.01556737,   1.01433956,
                           1.03321026,   1.03069565,   1.0150743 ,   1.04017385,
                           1.0308105

In [18]:
# Show the KFold folds built for the last animal pair processed above; the
# indices refer to positions within that split's train stimuli.
cv_splits

[{'train': array([2, 3, 4, 5, 6, 7, 8, 9]), 'test': array([0, 1])},
 {'train': array([0, 1, 4, 5, 6, 7, 8, 9]), 'test': array([2, 3])},
 {'train': array([0, 1, 2, 3, 6, 7, 8, 9]), 'test': array([4, 5])},
 {'train': array([0, 1, 2, 3, 4, 5, 8, 9]), 'test': array([6, 7])},
 {'train': array([0, 1, 2, 3, 4, 5, 6, 7]), 'test': array([8, 9])}]

In [None]:
def build_param_lookup(args):
    """Map job ids ("0", "1", ...) to keyword arguments for ``perform_cv``.

    One entry is produced per brain area, in order. All entries share the
    remaining CLI settings taken from ``args``. Each entry is its own dict,
    so mutating one does not affect the others.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed CLI arguments. ``args.brain_areas`` is a comma-separated string;
        whitespace around names and empty entries (e.g. a trailing comma) are
        ignored. If it is None, ``BRAIN_AREAS`` is used instead.

    Returns
    -------
    dict[str, dict]
        Keys are stringified job indices; values are ``perform_cv`` kwargs.
    """
    if args.brain_areas is None:
        # NOTE(review): BRAIN_AREAS is not defined in this notebook; it appears
        # to come from the script this cell was adapted from.
        brain_areas = BRAIN_AREAS
    else:
        brain_areas = [
            area.strip() for area in args.brain_areas.split(",") if area.strip()
        ]

    shared_params = {
        "train_frac": args.train_frac,
        "time_mode": args.time_mode,
        "temporal": args.temporal,
        "trial_threshold": args.trial_threshold,
        "downsample_rate": args.downsample_rate,
        "start_offset_sec": args.start_offset_sec,
        "trial_frac_lower_bound": args.trial_frac_lower_bound,
        "additional_target_offset_sec": args.additional_target_offset_sec,
        "num_splits": args.num_splits,
        "num_cv_splits": args.num_cv_splits,
        "num_bootstrap_iters": args.num_bootstrap_iters,
        "num_parallel_jobs": args.num_parallel_jobs,
        "enforce_finite_mean": not args.no_finite_mean_filt,
    }

    # `{..., **shared_params}` builds a fresh dict per job.
    return {
        str(key): {"brain_area": brain_area, **shared_params}
        for key, brain_area in enumerate(brain_areas)
    }

def perform_cv(
    brain_area,
    train_frac,
    num_splits,
    num_cv_splits,
    num_parallel_jobs,
    temporal=False,
    trial_threshold=360,
    downsample_rate=None,
    start_offset_sec=0.0,
    trial_frac_lower_bound=0.5,
    additional_target_offset_sec=0.2,
    num_bootstrap_iters=1000,
    metric="pearsonr",
    enforce_finite_mean=True,
    time_mode="target_gocue",
):
    """Cross-validate ridge alphas for inter-animal linear-regression consistency.

    Steps:
    1. Load the common-stimulus responses for ``brain_area``. They are read
       from a packaged ``.npz`` under ``NEURAL_RESP_PACKAGED`` when it exists;
       otherwise ``get_single_area_common_stim_resp`` builds them.
    2. Draw ``num_splits`` train/test stimulus splits.
    3. For each split and each ordered (source, target) animal pair, run
       ``get_linregress_consistency`` on the train stimuli only, once per alpha
       in ``RIDGECV_ALPHA_CV``, using ``num_cv_splits`` KFold folds.

    The list of per-split result dicts (``{target_animal: {map_kwargs_key:
    consistency}}``) is saved with ``np.savez`` into
    ``NEURAL_FIT_CV_SEARCH_DIR``. Nothing is returned.
    """
    assert len(ANIMALS) == 2  # makes code below simpler

    # The stimulus axis is named differently for temporal vs. non-temporal data.
    stim_dim = "stim_time" if temporal else "stimuli"

    packaged_name = get_packaged_data_filename(
        brain_area=brain_area,
        time_mode=time_mode,
        temporal=temporal,
        collapse_temporal=True,
        trial_threshold=trial_threshold,
        start_offset_sec=start_offset_sec,
        downsample_rate=downsample_rate,
        enforce_finite_mean=enforce_finite_mean,
        trial_frac_lower_bound=trial_frac_lower_bound,
        additional_target_offset_sec=additional_target_offset_sec,
    )
    packaged_path = os.path.join(NEURAL_RESP_PACKAGED, packaged_name)

    if os.path.exists(packaged_path):
        print(f"Loading packaged data from {packaged_path}")
        brain_data = np.load(packaged_path, allow_pickle=True)["arr_0"][()]
    else:
        # Arguments common to both modes, plus the mode-specific extras.
        loader_kwargs = dict(
            brain_area=brain_area,
            time_mode=time_mode,
            temporal=temporal,
            enforce_finite_mean=enforce_finite_mean,
        )
        if temporal:
            loader_kwargs.update(
                collapse_temporal=True,
                trial_threshold=trial_threshold,
                start_offset_sec=start_offset_sec,
                downsample_rate=downsample_rate,
            )
        else:
            loader_kwargs.update(
                trial_frac_lower_bound=trial_frac_lower_bound,
                additional_target_offset_sec=additional_target_offset_sec,
            )
        brain_data = get_single_area_common_stim_resp(**loader_kwargs)

    num_stim = len(getattr(brain_data[ANIMALS[0]], stim_dim))
    splits = generate_train_test_splits(
        num_stim=num_stim, num_splits=num_splits, train_frac=train_frac
    )

    results_list = []
    for split in splits:
        split_results = {}
        for source_animal, target_animal in itertools.permutations(ANIMALS, r=2):
            # cross-validation is only ever run on the train stimuli
            train_selector = {stim_dim: split["train"]}
            target_resp_train = brain_data[target_animal].isel(train_selector)
            source_resp_train = brain_data[source_animal].isel(train_selector)
            assert target_resp_train.ndim == 3
            assert source_resp_train.ndim == 3

            folds = KFold(n_splits=num_cv_splits).split(
                X=source_resp_train.mean(dim="trials", skipna=True)
            )
            cv_splits = [
                {"train": fold_train, "test": fold_val} for fold_train, fold_val in folds
            ]

            per_target = split_results.setdefault(target_animal, {})
            for alpha in RIDGECV_ALPHA_CV:
                map_kwargs = {
                    "map_type": "sklinear",
                    "map_kwargs": {
                        "regression_type": "Ridge",
                        "regression_kwargs": {"alpha": alpha},
                    },
                }
                # hashable tuple form of map_kwargs, used as the dict key
                map_kwargs_key = convert_dict_to_tuple(map_kwargs)
                per_target[map_kwargs_key] = get_linregress_consistency(
                    source=source_resp_train,
                    target=target_resp_train,
                    map_kwargs=map_kwargs,
                    num_bootstrap_iters=num_bootstrap_iters,
                    num_parallel_jobs=num_parallel_jobs,
                    splits=cv_splits,
                    metric=metric,
                )

        results_list.append(split_results)

    out_name = get_filename(
        map_name="ridgecv",
        brain_area=brain_area,
        interanimal=True,
        common_resp=True,
        num_splits=num_splits,
        num_cv_splits=num_cv_splits,
        train_frac=train_frac,
        trial_frac_lower_bound=trial_frac_lower_bound,
        additional_target_offset_sec=additional_target_offset_sec,
        num_bootstrap_iters=num_bootstrap_iters,
        metric=metric,
        enforce_finite_mean=enforce_finite_mean,
        time_mode=time_mode,
        temporal=temporal,
        collapse_temporal=True,
        trial_threshold=trial_threshold,
        start_offset_sec=start_offset_sec,
        downsample_rate=downsample_rate,
    )
    np.savez(os.path.join(NEURAL_FIT_CV_SEARCH_DIR, out_name), results_list)

def str2bool(value):
    """Parse a command-line boolean such as "True", "false", "1" or "no".

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    "False") as True, so boolean options need an explicit parser.
    Actual ``bool`` values are passed through unchanged.

    Raises
    ------
    argparse.ArgumentTypeError
        If ``value`` is not a recognised boolean spelling.
    """
    import argparse

    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in {"true", "t", "yes", "y", "1"}:
        return True
    if normalized in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    import argparse

    # NOTE(review): inside a Jupyter kernel __name__ is also "__main__", so
    # running this cell parses the kernel's own argv and exits; run the code
    # as a script (e.g. a SLURM array job) instead.
    parser = argparse.ArgumentParser()
    # NOTE(review): default=None is unreachable while required=True, so the
    # BRAIN_AREAS fallback in build_param_lookup is never used from the CLI.
    parser.add_argument("--brain-areas", type=str, default=None, required=True)
    parser.add_argument("--train-frac", type=float, default=0.8)
    parser.add_argument("--time-mode", type=str, default="target_gocue")
    parser.add_argument("--temporal", type=str2bool, default=False)
    parser.add_argument("--trial-threshold", type=int, default=360)
    parser.add_argument("--downsample-rate", type=int, default=None)
    parser.add_argument("--start-offset-sec", type=float, default=0.0)
    parser.add_argument("--trial-frac-lower-bound", type=float, default=0.5)
    parser.add_argument("--additional-target-offset-sec", type=float, default=0.2)
    parser.add_argument("--num-splits", type=int, default=5)
    parser.add_argument("--num-bootstrap-iters", type=int, default=1000)
    parser.add_argument("--num-cv-splits", type=int, default=5)
    parser.add_argument("--num-parallel-jobs", type=int, default=1)
    parser.add_argument("--no-finite-mean-filt", type=str2bool, default=False)
    args = parser.parse_args()

    params = build_param_lookup(args)
    print(f"Num jobs: {len(params)}")

    # Each SLURM array task runs exactly one job from the lookup.
    task_id = os.environ.get("SLURM_ARRAY_TASK_ID")
    if task_id is None:
        parser.error("SLURM_ARRAY_TASK_ID is not set; run this as a SLURM array job")
    if task_id not in params:
        parser.error(
            f"SLURM_ARRAY_TASK_ID={task_id} has no matching job; "
            f"valid ids: {sorted(params, key=int)}"
        )
    perform_cv(**params[task_id])