In [1]:
%load_ext autoreload
%autoreload 2

import os
import random
import sys

os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

from pathlib import Path
from hydra import initialize_config_dir, compose
from hydra.utils import instantiate
from omegaconf import OmegaConf
import shutil
import tqdm
import torch
import matplotlib.pyplot as plt

from tbfm import film
from tbfm import multisession
from tbfm import utils

DATA_DIR = "/home/mmattb/Projects/opto-coproc/data"
sys.path.append(DATA_DIR)
# imported from JNE project
import dataset

EMBEDDING_REST_SUBDIR = "embedding_rest"

conf_dir = Path("./conf").resolve()

# Initialize Hydra with the configuration directory
with initialize_config_dir(config_dir=str(conf_dir), version_base=None):
    # Compose the configuration
    cfg = compose(config_name="config")   # i.e. conf/config.yaml

DEVICE = "cuda" #cfg.device
WINDOW_SIZE = cfg.data.trial_len
NUM_HELD_OUT_SESSIONS = 0
session_ids, _ = multisession.gather_session_ids(DATA_DIR, num_held_out_sessions=NUM_HELD_OUT_SESSIONS)
NUM_REPEATS = 3

RESULTS_DIR = "singles"

cfg.training.epochs = 20001

In [2]:
# Split the sessions into two contiguous halves so two copies of this notebook
# (one per GPU, see CUDA_VISIBLE_DEVICES in the setup cell) can train in parallel.
n = len(session_ids)
top_ids, bottom_ids = session_ids[:n // 2], session_ids[n // 2:]
print(top_ids, "----", bottom_ids, sep="\n")

# The half this copy trains on; use bottom_ids in the other copy.
my_ids = top_ids

['MonkeyJ_20160625_Session5_S1', 'MonkeyG_20150914_Session1_S1', 'MonkeyJ_20160428_Session3_S1', 'MonkeyG_20150918_Session1_M1', 'MonkeyJ_20160702_Session2_S1', 'MonkeyG_20150915_Session4_S1', 'MonkeyG_20150922_Session2_S1', 'MonkeyG_20150925_Session1_S1', 'MonkeyJ_20160624_Session4_S1', 'MonkeyJ_20160426_Session2_S1', 'MonkeyG_20150917_Session3_M1', 'MonkeyJ_20160702_Session4_S1', 'MonkeyG_20150914_Session3_S1', 'MonkeyG_20150921_Session5_S1', 'MonkeyJ_20160502_Session1_S1', 'MonkeyG_20150921_Session3_S1', 'MonkeyG_20150915_Session5_S1', 'MonkeyJ_20160630_Session1_S1', 'MonkeyJ_20160429_Session1_S1', 'MonkeyG_20150917_Session3_S1']
----
['MonkeyG_20150917_Session2_S1', 'MonkeyJ_20160428_Session2_S1', 'MonkeyJ_20160426_Session3_S1', 'MonkeyG_20150922_Session3_S1', 'MonkeyJ_20160624_Session3_S1', 'MonkeyG_20150917_Session1_S1', 'MonkeyJ_20160429_Session3_S1', 'MonkeyJ_20160426_Session1_S1', 'MonkeyG_20150915_Session2_S1', 'MonkeyJ_20160630_Session3_S1', 'MonkeyJ_20160627_Session2_S1', '

In [3]:
# Rest embeddings for just the sessions this copy of the notebook trains on;
# the training loop looks them up by session id.
embeddings_rest = multisession.load_rest_embeddings(
    my_ids,
    device=DEVICE,
)

In [None]:
# Train NUM_REPEATS independent models per session, each on that session alone,
# and save every run's stim embeddings and results to RESULTS_DIR/<session_id>/.
BATCH_SIZE = 6000     # batch_size passed to load_stim_batched
TEST_CUT = 2500       # test_cut passed to train_test_split
TEST_INTERVAL = 1000  # epochs between test evaluations during training

for session_id in my_ids:
    print("-----------------------------------------", session_id)
    out_dir = os.path.join(RESULTS_DIR, session_id)

    # Start each session from a clean results directory. WARNING: this deletes any
    # earlier results for the session. Only remove the directory if it exists, so
    # real failures (e.g. permissions) surface here instead of being swallowed and
    # resurfacing as a confusing FileExistsError from makedirs.
    if os.path.isdir(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir, mode=0o777, exist_ok=False)

    # Training only sees this session's rest embedding.
    _embeddings_rest = {session_id: embeddings_rest[session_id]}

    for nidx in range(NUM_REPEATS):
        print("@@@@@", nidx)

        # Drop the previous repeat's data, model, optimizers and outputs before
        # building new ones, so they can be freed first. Otherwise the old and new
        # models coexist during build_from_cfg, which raises peak GPU memory.
        d = data_train = data_test = ms = model_optims = None
        embeddings_stim = results = None

        d, held_out_session_ids = multisession.load_stim_batched(
            window_size=WINDOW_SIZE,
            session_subdir="torchraw",
            data_dir=DATA_DIR,
            held_in_session_ids=[session_id],
            batch_size=BATCH_SIZE,
            num_held_out_sessions=NUM_HELD_OUT_SESSIONS,
        )
        # NOTE(review): the meaning of the first positional argument isn't visible
        # here; it happens to equal BATCH_SIZE — confirm whether they must match.
        data_train, data_test = d.train_test_split(6000, test_cut=TEST_CUT)

        ms = multisession.build_from_cfg(cfg, data_train, device=DEVICE, quiet=True)
        model_optims = multisession.get_optims(cfg, ms)

        embeddings_stim, results = multisession.train_from_cfg(
            cfg,
            ms,
            data_train,
            model_optims,
            _embeddings_rest,
            data_test=data_test,
            test_interval=TEST_INTERVAL,
            epochs=cfg.training.epochs,
        )

        torch.save(embeddings_stim, os.path.join(out_dir, f"es_{nidx}.torch"))
        torch.save(results, os.path.join(out_dir, f"results_{nidx}.torch"))

----------------------------------------- MonkeyJ_20160625_Session5_S1
@@@@@ 0
---- 0 0.8711338043212891 2.2867069244384766 0.07048509269952774 0.43087905645370483
---- 1000 0.5032915472984314 2.0945048332214355 0.339429646730423 0.4787370562553406
---- 2000 0.40396246314048767 2.030858039855957 0.4779115319252014 0.49457255005836487
---- 3000 0.356317400932312 2.001826047897339 0.5401397347450256 0.5017723441123962
---- 4000 0.3439364731311798 1.9818607568740845 0.5583277940750122 0.5067331194877625
---- 5000 0.3321656882762909 1.970332384109497 0.5751792192459106 0.5095957517623901
---- 6000 0.33410730957984924 1.9648367166519165 0.5726734399795532 0.5109647512435913
---- 7000 0.32057344913482666 1.9523473978042603 0.5909767150878906 0.5140740871429443
---- 8000 0.32532086968421936 1.949306607246399 0.5847910642623901 0.5148305296897888
---- 9000 0.31028833985328674 1.9403859376907349 0.6051744818687439 0.5170496106147766
---- 10000 0.3054920434951782 1.9396368265151978 0.61130011081