In [1]:
# Reload edited modules (e.g. the local `tndm` package) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
%config Completer.use_jedi = False

import sys
sys.path.append('..')

import os
import tensorflow as tf
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

use_cpu = False
cuda_device = '0'

if use_cpu:
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
    os.environ['CPU_ONLY'] = "TRUE"
    physical_devices = tf.config.list_physical_devices('CPU')
    tf.config.set_logical_device_configuration(
        physical_devices[0],
        [tf.config.LogicalDeviceConfiguration() for i in range(8)])
    logical_devices = tf.config.list_logical_devices('CPU')

    print(logical_devices)
else:
    os.environ['CUDA_VISIBLE_DEVICES'] = cuda_device
    os.environ['CPU_ONLY'] = "FALSE"
    physical_devices = tf.config.list_physical_devices('GPU')
    print(physical_devices)
    
from collections import defaultdict
from sklearn.metrics import r2_score
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
from datetime import datetime

from tndm.data import DataManager
from tndm import TNDM
from tndm.runtime import Runtime, ModelType
from tndm.utils import AdaptiveWeights
from tndm.models.model_loader import ModelLoader


In [None]:
# Environment sanity check: display the build configuration of the installed TensorFlow
# (on GPU builds this includes the CUDA/cuDNN versions it was compiled against).
tf.sysconfig.get_build_info()

In [17]:

import h5py

# Use the same HDF5 file as the LFADS baseline so both models see identical data.
DATASET_PATH = 'data/baselines/lfads/data/datasets/monkey_5ms.h5'
with h5py.File(DATASET_PATH, 'r') as f:
    dataset = {key: np.array(f[key]) for key in f.keys()}

# The file provides three separate splits (train / valid / test). Each split is
# cast to float64 for the model.
neural_data = dataset['train_data'].astype('float')
valid_neural_data = dataset['valid_data'].astype('float')
test_neural_data = dataset['test_data'].astype('float')

behavioural_data = dataset['train_beh'].astype('float')
valid_behavioural_data = dataset['valid_beh'].astype('float')
test_behavioural_data = dataset['test_beh'].astype('float')

# Within a split, the neural and behavioural arrays must agree on (trials, time bins).
for split_name, split_spikes, split_beh in [
        ('train', neural_data, behavioural_data),
        ('valid', valid_neural_data, valid_behavioural_data),
        ('test', test_neural_data, test_behavioural_data)]:
    assert split_spikes.shape[:2] == split_beh.shape[:2], \
        (split_name, split_spikes.shape, split_beh.shape)

print(neural_data.shape, behavioural_data.shape)
print(valid_neural_data.shape, valid_behavioural_data.shape)
print(test_neural_data.shape, test_behavioural_data.shape)

(2008, 140, 182) (2008, 140, 2)
(71, 140, 182) (71, 140, 2)
(216, 140, 182) (216, 140, 2)


In [5]:

# Center behaviour at zero using the mean of the first time step over the
# *training* trials (not strictly required). The same training-derived offset is
# applied to the validation and test splits.
b_mean = np.mean(behavioural_data[:, 0, :], axis=0)  # one offset per behavioural dimension
# Broadcasting over the last axis handles any number of behavioural dimensions
# (the previous loop hard-coded 2). Arrays are still modified in place.
behavioural_data -= b_mean
valid_behavioural_data -= b_mean
test_behavioural_data -= b_mean


In [None]:
# ---- Model hyper-parameters ------------------------------------------------

l2_reg = 1                      # L2 penalty on the recurrent weights of both decoders
initial_neural_weight = 1.0     # weight of the neural NLL term
initial_behaviour_weight = .2   # weight of the behaviour loss term
lambda_q = 100.0
update_rate = .0005
dropout = .15
seed = 0
GRU_pre_activation = False
var_min = 0.0001
prior_variance = 1

optimizer = tf.keras.optimizers.Adam(
    learning_rate=1e-2, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

# Every layer gets its own freshly built initializer/regularizer unless overridden below.
layers_settings = defaultdict(lambda: {
    'kernel_initializer': tf.keras.initializers.VarianceScaling(
        scale=1.0, mode='fan_in', distribution='normal'),
    'kernel_regularizer': tf.keras.regularizers.l2(l2=0.0),
})

layers_settings['encoder'].update(var_min=var_min, var_trainable=True)

# Both decoders share the same configuration; only the recurrent weights are regularised.
for decoder_name in ('relevant_decoder', 'irrelevant_decoder'):
    layers_settings[decoder_name].update(
        kernel_regularizer=tf.keras.regularizers.l2(l2=0),
        recurrent_regularizer=tf.keras.regularizers.l2(l2=l2_reg),
        original_cell=False)

layers_settings['behavioural_dense'].update(behaviour_type='causal')

In [7]:
# Timestamp that makes each run's directories unique. Explicit fields are used
# instead of the locale-dependent '%X', which produces colons (invalid in Windows
# paths) and, in some locales, spaces or AM/PM markers.
T = datetime.today().strftime("%y_%m_%d_%H_%M_%S")

spike_data_dir = "tndm_exp"

# The log and model directories share the same hyper-parameter/timestamp suffix.
run_suffix = f"l2_reg_{l2_reg}_do_{dropout}_no_norm_{T}"
logdir = os.path.join(spike_data_dir, 'log_' + run_suffix)
modeldir = os.path.join(spike_data_dir, 'model_' + run_suffix)

In [None]:
T0 = datetime.now()

model, history = Runtime.train(
    model_type=ModelType.TNDM,
    adaptive_lr=dict(factor=0.95, patience=10, min_lr=1e-5),
    model_settings=dict(
        rel_factors=5,
        irr_factors=5,
        encoded_dim=64,
        max_grad_norm=200,
        dropout=dropout,
        prior_variance=prior_variance,
        GRU_pre_activation=GRU_pre_activation,
        timestep=0.005,  # bin width, hard-coded to match the monkey_5ms dataset
        seed=seed
    ),
    layers_settings=layers_settings,
    optimizer=optimizer,
    epochs=1000,
    # Pass `logdir` (built in the cell above) instead of None to write training logs.
    logdir=None,
    train_dataset=(neural_data, behavioural_data),
    val_dataset=(valid_neural_data, valid_behavioural_data),
    # One entry per loss term: [0] neural NLL weight, [1] behaviour weight, [4] lambda_q.
    # These three are pinned (min == max == initial). Entries [2], [3] and [5] start at
    # 0.0 with update_rate, capped at 1.0.
    # NOTE(review): if update_start counts epochs, entry [3] (start 1000) never updates
    # within epochs=1000 — confirm against AdaptiveWeights.
    adaptive_weights=AdaptiveWeights(
        initial=[initial_neural_weight, initial_behaviour_weight, .0, .0, lambda_q, .0],
        update_start=[0, 0, 0, 1000, 1000, 0],
        update_rate=[0., 0., update_rate, update_rate, 0.0, update_rate],
        min_weight=[initial_neural_weight, initial_behaviour_weight, 0.0, 0.0, lambda_q, 0.0],
        max_weight=[initial_neural_weight, initial_behaviour_weight, 1.0, 1.0, lambda_q, 1.0],
    ),
    batch_size=128,
    verbose=2  # 2 prints the losses during training
)

# Report the training time before saving. The figure then excludes the save
# itself and is still shown if saving fails.
training_time = datetime.now() - T0
print('Training took ' + str(training_time))

model.save(modeldir)

In [12]:
# Replace layers_settings' default_factory with a wrapper.
# The previous form, `lambda: model.layers_settings.default_factory`, looked the
# attribute up lazily. It therefore returned *itself*, so any missing key would
# have been filled with a function instead of a settings dict. Binding the current
# factory first keeps missing keys producing the same per-layer settings dict.
# NOTE(review): the purpose of this patch (presumably a save/serialisation
# workaround) is not visible here — confirm it is still needed.
_previous_layers_factory = model.layers_settings.default_factory
model.layers_settings.default_factory = lambda: _previous_layers_factory()

In [None]:
# Save the model again, after default_factory was reassigned in the cell above.
# Only the save is timed here. `datetime.now() - T0` would also count however long
# the notebook sat idle between cells, so it is not a training time.
save_start = datetime.now()
model.save(modeldir)

print('Saving took ' + str(datetime.now() - save_start))

In [10]:
# Manual save path: settings and weights are written to a fixed directory.
# Note that this rebinds `modeldir`, replacing the timestamped directory built earlier.
import json

from tndm.utils import CustomEncoder, upsert_empty_folder

modeldir = 'tndm_exp'
location = modeldir


In [13]:

# Persist the model by hand: its settings as JSON, plus a weights file.
settings = model.get_settings()

# NOTE(review): `location` is 'tndm_exp', which is also `spike_data_dir`, the parent
# of the timestamped log/model directories. Confirm that upsert_empty_folder does not
# wipe existing contents.
upsert_empty_folder(location)

settings_path = os.path.join(location, "settings.json")
weights_path = os.path.join(location, ".weights.h5")

with open(settings_path, "w") as settings_file:
    json.dump(settings, settings_file, cls=CustomEncoder)
model.save_weights(weights_path)

In [None]:
# Inspect the directory the model is saved to and loaded from.
modeldir

In [None]:
# Inspect the working directory; relative paths such as 'tndm_exp' and the dataset path resolve against it.
os.getcwd()

In [None]:
# Rebuild a TNDM model from `modeldir`.
# NOTE(review): presumably this reads the settings.json / .weights.h5 written by the
# manual-save cell — confirm against ModelLoader.
model = ModelLoader.load(modeldir, model_class=TNDM)

# Latent space

## Training data

In [None]:
from tqdm.auto import tqdm


def run_tndm_on_trials(model, data, test_sample_mode='posterior_sample', batch_size_eval=128):
    """Run a trained TNDM over every trial in `data` and return its outputs.

    In 'mean' mode the model is called once on the whole array. In any other mode,
    each trial is repeated `batch_size_eval` times and the model is called on that
    batch. Every output is then averaged over the repeats (axis 0), and the
    per-trial averages are stacked back along a new leading trial axis.

    Args:
        model: the trained TNDM, called as
            model(x, training=False, test_sample_mode=test_sample_mode).
        data: array of trials; iterated over its first axis.
        test_sample_mode: forwarded to the model.
        batch_size_eval: number of repeats averaged per trial (non-'mean' modes only).

    Returns:
        (log_f, b, g0_r, mean_r, logvar_r, z_r, g0_i, mean_i, logvar_i, z_i)

    Raises:
        ValueError: if `data` has no trials in a non-'mean' mode (nothing to stack).
    """
    if test_sample_mode == 'mean':
        log_f, b, (g0_r, mean_r, logvar_r), (g0_i, mean_i, logvar_i), (z_r, z_i) = \
            model(data.astype('float'), training=False, test_sample_mode=test_sample_mode)
        return log_f, b, g0_r, mean_r, logvar_r, z_r, g0_i, mean_i, logvar_i, z_i

    if len(data) == 0:
        raise ValueError("data contains no trials")

    # One list per output, in the same order as the returned tuple.
    per_output = [[] for _ in range(10)]
    for neural_datum in tqdm(data):
        repeated = np.repeat(np.expand_dims(neural_datum, 0), batch_size_eval, axis=0)
        log_f, b, (g0_r, mean_r, logvar_r), (g0_i, mean_i, logvar_i), (z_r, z_i) = \
            model(repeated.astype('float'), training=False, test_sample_mode=test_sample_mode)
        trial_outputs = (log_f, b, g0_r, mean_r, logvar_r, z_r, g0_i, mean_i, logvar_i, z_i)
        for accumulator, output in zip(per_output, trial_outputs):
            accumulator.append(np.mean(output, 0))
    return tuple(tf.stack(accumulator) for accumulator in per_output)


data = neural_data
test_sample_mode = 'posterior_sample'  # choose 'mean' for previous behaviour

(log_f, b, g0_r, mean_r, logvar_r, z_r,
 g0_i, mean_i, logvar_i, z_i) = run_tndm_on_trials(model, data, test_sample_mode)

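## Test data

Note: the cell below currently evaluates `neural_data` (the training split) in `'mean'` mode. Set `data = test_neural_data` to evaluate the held-out test split.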

In [100]:
from tqdm.auto import tqdm

# Split to evaluate; set `data = test_neural_data` to run on the held-out trials instead.
data = neural_data
test_sample_mode = 'mean'  # choose 'mean' for previous behaviour

if test_sample_mode != 'mean':
    # Repeat each trial batch_size_eval times, average every model output over the
    # repeats, then stack the per-trial averages back together.
    batch_size_eval = 128
    output_names = ('log_f', 'b', 'g0_r', 'mean_r', 'logvar_r', 'z_r',
                    'g0_i', 'mean_i', 'logvar_i', 'z_i')
    per_trial = {name: [] for name in output_names}
    for neural_datum in tqdm(data):
        repeated = np.repeat(neural_datum[np.newaxis], batch_size_eval, axis=0)
        log_f, b, (g0_r, mean_r, logvar_r), (g0_i, mean_i, logvar_i), (z_r, z_i) = \
            model(repeated.astype('float'), training=False, test_sample_mode=test_sample_mode)
        trial_outputs = (log_f, b, g0_r, mean_r, logvar_r, z_r, g0_i, mean_i, logvar_i, z_i)
        for name, value in zip(output_names, trial_outputs):
            per_trial[name].append(np.mean(value, 0))
    (log_f, b, g0_r, mean_r, logvar_r, z_r,
     g0_i, mean_i, logvar_i, z_i) = [tf.stack(per_trial[name]) for name in output_names]
else:
    log_f, b, (g0_r, mean_r, logvar_r), (g0_i, mean_i, logvar_i), (z_r, z_i) = \
        model(data.astype('float'), training=False, test_sample_mode=test_sample_mode)

In [106]:
# Shapes of the raw model outputs for the evaluated split.
for name, value in [("log_f", log_f), ("b", b), ("g0_r", g0_r), ("mean_r", mean_r),
                    ("logvar_r", logvar_r), ("z_r", z_r), ("g0_i", g0_i),
                    ("mean_i", mean_i), ("logvar_i", logvar_i), ("z_i", z_i)]:
    print(name, value.shape)

# Collect everything needed downstream as numpy arrays.
# `gt_spikes` uses `data`, i.e. whichever split the cell above evaluated. The
# ground truth therefore always lines up trial-for-trial with the model outputs:
# setting `data = test_neural_data` above exports the test split with no other
# change. (The dict keeps its historical name.)
bin_width = 0.005  # same value as timestep=0.005 passed to Runtime.train
ret_dict_train = {
    "ae_rates": (bin_width * tf.math.exp(log_f)).numpy(),
    "ae_latents_relevant": z_r.numpy(),
    "ae_latents_irrelevant": z_i.numpy(),
    "ae_behaviour": b.numpy(),
    "gt_spikes": data,
    "init_states_gt_relevant": g0_r.numpy(),
    "init_states_gt_irrelevant": g0_i.numpy(),
}
assert ret_dict_train["gt_spikes"].shape == ret_dict_train["ae_rates"].shape, \
    (ret_dict_train["gt_spikes"].shape, ret_dict_train["ae_rates"].shape)

for key, val in ret_dict_train.items():
    print(key, val.shape)

import pickle
os.makedirs(spike_data_dir, exist_ok=True)
samples_path = os.path.join(spike_data_dir, 'tndm_samples.pkl')
with open(samples_path, 'wb') as f:
    pickle.dump(ret_dict_train, f)
    


log_f (2008, 140, 182)
b (2008, 140, 2)
g0_r (2008, 64)
mean_r (2008, 64)
logvar_r (2008, 64)
z_r (2008, 140, 5)
g0_i (2008, 64)
mean_i (2008, 64)
logvar_i (2008, 64)
z_i (2008, 140, 5)
ae_rates (2008, 140, 182)
ae_latents_relevant (2008, 140, 5)
ae_latents_irrelevant (2008, 140, 5)
ae_behaviour (2008, 140, 2)
gt_spikes (2008, 140, 182)
init_states_gt_relevant (2008, 64)
init_states_gt_irrelevant (2008, 64)
