In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys

from pathlib import Path
from hydra import initialize_config_dir, compose
from hydra.utils import instantiate
from omegaconf import OmegaConf
import tqdm
import torch
import matplotlib.pyplot as plt

from tbfm import film
from tbfm import multisession

# --- Notebook configuration -------------------------------------------------
# Root of the opto-coproc data checkout. The default is the original author's
# location; set the OPTO_COPROC_DATA_DIR environment variable to run this
# notebook on another machine without editing the cell.
DATA_DIR = os.environ.get(
    "OPTO_COPROC_DATA_DIR", "/home/mmattb/Projects/opto-coproc/data"
)
# DATA_DIR is put on sys.path right before `import dataset`, so the module is
# presumably resolved from there. Guard the append so that re-running this cell
# does not keep adding duplicate entries.
if DATA_DIR not in sys.path:
    sys.path.append(DATA_DIR)
# imported from JNE project
import dataset
meta = dataset.load_meta(DATA_DIR)

OUT_DIR = "data"  # Local data cache; i.e. not reading from the opto-coproc folder.
EMBEDDING_REST_SUBDIR = "embedding_rest"

# Hydra requires an absolute config directory, hence the resolve().
conf_dir = Path("./conf").resolve()

# Initialize Hydra with the configuration directory
with initialize_config_dir(config_dir=str(conf_dir), version_base=None):
    # Compose the configuration
    cfg = compose(config_name="config")   # i.e. conf/config.yaml

# Device is pinned here rather than read from the config (cfg.device).
DEVICE = "cuda:0"  # cfg.device
WINDOW_SIZE = cfg.data.trial_len
NUM_HELD_OUT_SESSIONS = cfg.training.num_held_out_sessions
BASE_MODEL_PATH = "session3.torch"

In [6]:
# Okay, now a stim data loader...
# Other held-in session sets that were tried:
#   ["MonkeyG_20150914_Session1_S1", "MonkeyG_20150915_Session2_S1"]
#   ["MonkeyG_20150925_Session2_S1"]
d, held_out_session_ids = multisession.load_stim_batched(
    data_dir=DATA_DIR,
    session_subdir="torchraw",
    held_in_session_ids=[
        "MonkeyG_20150925_Session2_S1",
        "MonkeyJ_20160630_Session3_S1",
        "MonkeyG_20150917_Session1_M1",
    ],
    num_held_out_sessions=NUM_HELD_OUT_SESSIONS,
    window_size=WINDOW_SIZE,
    batch_size=7500,
)

# Split the loaded data into the train and test sets used by the cells below.
data_train, data_test = d.train_test_split(5000, test_cut=2500)

# The held-in ids are taken from the training split itself.
held_in_session_ids = data_train.session_ids

# Gather cached rest embeddings...
embeddings_rest = multisession.load_rest_embeddings(held_in_session_ids, device=DEVICE)

In [7]:
# Verify batch sizes...
# Pull a single batch from the training loader. The batch is a mapping (it is
# queried with .keys() and indexed by key); take its first key as the probe.
batch = next(iter(data_train))
k = list(batch.keys())
k0 = k[0]

# Reuse the batch already fetched instead of building a second iterator and
# loading another full batch just to index into it.
b = batch[k0]

# The leading dimension of the entry's first element is reported as the
# per-session batch size.
print(f"per session batch size: {b[0].shape[0]}")

per session batch size: 5000


In [38]:
# Build the model from the composed config and the training data.
# Note that AEs will still be PCA warm-started, and normalizers too.
ms = multisession.build_from_cfg(
    cfg,
    data_train,
    device=DEVICE,
    base_model_path=BASE_MODEL_PATH,
)

# TODO: Figure out how we want to do batch size here. Probably our efficiency
#       experiment is: batch_size = session_count * bsize.
# TODO: Get an optimizer for the FiLM parts only:
#       model_optims = multisession.get_optims(cfg, ms). film.inner_update_stopgrad()
# TODO: Sample efficiency experiment.

Building and fitting normalizers...
Building and warm starting AEs...
Loading base TBFM from file...
BOOM! Dino DNA!


In [39]:
# Silly validation: run test-time adaptation on the *held in* training data,
# relying on the warm-started AEs and normalizers.
# This should get okay-ish performance?
# NOTE(review): the saved output of this cell ends in an AssertionError raised
# during this call; the cause is not visible from the notebook itself.
embeddings_stim, results = multisession.test_time_adaptation(
    cfg, ms, embeddings_rest, data_train,
    data_test=None,
    epochs=1000,
)

Parameter containing:
tensor([[ 0.2081, -0.1057,  0.1777, -0.0155],
        [ 0.3461, -0.3218,  0.4037,  0.1337],
        [-0.3590,  0.3446, -0.1036,  0.0948],
        [ 0.4034, -0.2442, -0.1809,  0.3859]], device='cuda:0',
       requires_grad=True) ---


AssertionError: 

In [None]:
# Then: let's try on a held-out session, but full training set.

In [5]:
# Cleared for takeoff...
# `model_optims` is not assigned in any other visible cell, so on a fresh
# kernel (Restart & Run All) this cell would fail with a NameError. Build the
# optimizers here, right before they are used.
# NOTE(review): assumes multisession.get_optims(cfg, ms) is the intended
# optimizer factory for `ms` - confirm the call and its signature.
model_optims = multisession.get_optims(cfg, ms)

# Train on the training split; data_test and test_interval are passed so the
# run can be evaluated against the held-back test split as it goes.
embeddings_stim, results = multisession.train_from_cfg(
    cfg,
    ms,
    data_train,
    model_optims,
    embeddings_rest,
    data_test=data_test,
    test_interval=1000,
    epochs=20001,
)

# 15k batch size 3000 0.6836317658424378 0.24544973075389862
# 10k batch size 3000 0.6723785549402237 0.24294960126280785
# 7.5k batch size 3000 0.6306539237499237 0.34320169389247895
# 5k batch size 

# 7.5k no ae coadapt: 

---- 0 1.8234667778015137 1.486933708190918 -0.8891817132631937 -0.5328980684280396
---- 1000 0.5899742841720581 0.5507943630218506 0.3826422741015752 0.43499526381492615
---- 2000 0.581556499004364 0.549232006072998 0.39112886786460876 0.43645283579826355
---- 3000 0.578947901725769 0.5422236323356628 0.3938383956750234 0.44358503818511963
---- 4000 0.5930135250091553 0.5486187934875488 0.3809671103954315 0.43679529428482056
---- 5000 0.572654664516449 0.535224437713623 0.40044142802556354 0.4508146643638611
---- 6000 0.5688191652297974 0.5323787927627563 0.40421965221563977 0.45374569296836853
---- 7000 0.5617053508758545 0.532183051109314 0.41133420666058856 0.45391416549682617
---- 8000 0.5615584850311279 0.5309649705886841 0.4116242031256358 0.455045223236084
---- 9000 0.5558925867080688 0.5284721255302429 0.41714655856291455 0.4577196538448334
---- 10000 0.5591580271720886 0.5314217805862427 0.4140358219544093 0.45462727546691895
---- 11000 0.554527223110199 0.5307040214538574 0.