In [2]:
import os
os.environ["OMP_NUM_THREADS"] = "8"  # export OMP_NUM_THREADS=4
os.environ["OPENBLAS_NUM_THREADS"] = "8"  # export OPENBLAS_NUM_THREADS=4
os.environ["MKL_NUM_THREADS"] = "8"  # export MKL_NUM_THREADS=6
os.environ["VECLIB_MAXIMUM_THREADS"] = "8"  # export VECLIB_MAXIMUM_THREADS=4
os.environ["NUMEXPR_NUM_THREADS"] = "8"  # export NUMEXPR_NUM_THREADS=6

import math
import torch
import torch.nn as nn
import xfads.utils as utils
import xfads.prob_utils as prob_utils
import pytorch_lightning as lightning
import matplotlib.pyplot as plt

from hydra import compose, initialize
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from xfads.ssm_modules.likelihoods import PoissonLikelihood
from xfads.ssm_modules.dynamics import DenseGaussianDynamics
from xfads.ssm_modules.dynamics import DenseGaussianInitialCondition
from xfads.ssm_modules.encoders import LocalEncoderLRMvn, BackwardEncoderLRMvn
from xfads.smoothers.lightning_trainers import LightningNonlinearSSM, LightningMonkeyReaching
from xfads.smoothers.nonlinear_smoother import NonlinearFilter, LowRankNonlinearStateSpaceModel

In [4]:
torch.cuda.empty_cache()

In [6]:
initialize(version_base=None, config_path="", job_name="monkey_reaching")
cfg = compose(config_name="config")

n_bins_bhv = 10
seeds = [1234, 1235, 1236]

In [8]:
"""config"""

for seed in seeds:
    cfg.seed = seed

    lightning.seed_everything(cfg.seed, workers=True)
    torch.set_default_dtype(torch.float32)

Seed set to 1234
Seed set to 1235
Seed set to 1236


In [10]:
"""data"""

data_path = 'data/data_{split}_{bin_sz_ms}ms.pt'

train_data = torch.load(data_path.format(split='train', bin_sz_ms=cfg.bin_sz_ms))
val_data = torch.load(data_path.format(split='valid', bin_sz_ms=cfg.bin_sz_ms))
test_data = torch.load(data_path.format(split='test', bin_sz_ms=cfg.bin_sz_ms))

y_valid_obs = val_data['y_obs'].type(torch.float32).to(cfg.data_device)
y_train_obs = train_data['y_obs'].type(torch.float32).to(cfg.data_device)
y_test_obs = test_data['y_obs'].type(torch.float32).to(cfg.data_device)

vel_valid = val_data['velocity'].type(torch.float32).to(cfg.data_device)
vel_train = train_data['velocity'].type(torch.float32).to(cfg.data_device)
vel_test = test_data['velocity'].type(torch.float32).to(cfg.data_device)

n_trials, n_time_bins, n_neurons_obs = y_train_obs.shape
n_time_bins_enc = train_data['n_time_bins_enc']

y_train_dataset = torch.utils.data.TensorDataset(y_train_obs, vel_train)
y_val_dataset = torch.utils.data.TensorDataset(y_valid_obs, vel_valid)
y_test_dataset = torch.utils.data.TensorDataset(y_test_obs, vel_test)

train_dataloader = torch.utils.data.DataLoader(y_train_dataset, batch_size=cfg.batch_sz, shuffle=True)
valid_dataloader = torch.utils.data.DataLoader(y_val_dataset, batch_size=y_valid_obs.shape[0], shuffle=False)
test_dataloader = torch.utils.data.DataLoader(y_test_dataset, batch_size=y_valid_obs.shape[0], shuffle=False)

In [12]:
"""likelihood pdf"""
H = utils.ReadoutLatentMask(cfg.n_latents, cfg.n_latents_read)
readout_fn = nn.Sequential(H, nn.Linear(cfg.n_latents_read, n_neurons_obs))
readout_fn[-1].bias.data = prob_utils.estimate_poisson_rate_bias(train_dataloader, cfg.bin_sz)
likelihood_pdf = PoissonLikelihood(readout_fn, n_neurons_obs, cfg.bin_sz, device=cfg.device)

In [14]:
"""dynamics module"""
Q_diag = 1. * torch.ones(cfg.n_latents, device=cfg.device)
dynamics_fn = utils.build_gru_dynamics_function(cfg.n_latents, cfg.n_hidden_dynamics, device=cfg.device)
dynamics_mod = DenseGaussianDynamics(dynamics_fn, cfg.n_latents, Q_diag, device=cfg.device)

In [16]:
"""initial condition"""
m_0 = torch.zeros(cfg.n_latents, device=cfg.device)
Q_0_diag = 1. * torch.ones(cfg.n_latents, device=cfg.device)
initial_condition_pdf = DenseGaussianInitialCondition(cfg.n_latents, m_0, Q_0_diag, device=cfg.device)

In [18]:
"""local/backward encoder"""
backward_encoder = BackwardEncoderLRMvn(cfg.n_latents, cfg.n_hidden_backward, cfg.n_latents,
                                        rank_local=cfg.rank_local, rank_backward=cfg.rank_backward,
                                        device=cfg.device)
local_encoder = LocalEncoderLRMvn(cfg.n_latents, n_neurons_obs, cfg.n_hidden_local, cfg.n_latents, rank=cfg.rank_local,
                                  device=cfg.device, dropout=cfg.p_local_dropout)
nl_filter = NonlinearFilter(dynamics_mod, initial_condition_pdf, device=cfg.device)

In [20]:
"""sequence vae"""
ssm = LowRankNonlinearStateSpaceModel(dynamics_mod, likelihood_pdf, initial_condition_pdf, backward_encoder,
                                      local_encoder, nl_filter, device=cfg.device)

In [22]:
"""lightning"""
seq_vae = LightningMonkeyReaching(ssm, cfg, n_time_bins_enc, n_bins_bhv)
csv_logger = CSVLogger('logs/smoother/acausal/', name=f'sd_{cfg.seed}_r_y_{cfg.rank_local}_r_b_{cfg.rank_backward}', version='smoother_acausal')
ckpt_callback = ModelCheckpoint(save_top_k=3, monitor='r2_valid_enc', mode='max', dirpath='ckpts/smoother/acausal/', save_last=True,
                                filename='{epoch:0}_{valid_loss:0.2f}_{r2_valid_enc:0.2f}_{r2_valid_bhv:0.2f}_{valid_bps_enc:0.2f}')

In [24]:
trainer = lightning.Trainer(max_epochs=cfg.n_epochs,
                            gradient_clip_val=1.0,
                            default_root_dir='lightning/',
                            callbacks=[ckpt_callback],
                            logger=csv_logger,
                            strategy='ddp_notebook',
                            accelerator='cpu',
                            )

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [26]:
trainer.fit(model=seq_vae, train_dataloaders=train_dataloader, val_dataloaders=valid_dataloader)
torch.save(ckpt_callback.best_model_path, 'ckpts/smoother/acausal/best_model_path.pt')

[rank: 0] Seed set to 1236
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
----------------------------------------------------------------------------------------------------
distributed_backend=gloo
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------

/opt/anaconda3/envs/xfads/lib/python3.10/site-packages/lightning_fabric/loggers/csv_logs.py:268: Experiment logs directory logs/smoother/acausal/sd_1236_r_y_2_r_b_2/smoother_acausal exists and is not empty. Previous log files in this directory will be deleted when the new ones are saved!
/opt/anaconda3/envs/xfads/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:653: Checkpoint directory /Users/mahmoud/catnip/xfads/workshop/monkey_reaching/ckpts/smoother/acausal exists and is not empty.

  | Name | Type                            | Params
-------------------------------------------------------

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/opt/anaconda3/envs/xfads/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


                                                                           

/opt/anaconda3/envs/xfads/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.
/opt/anaconda3/envs/xfads/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (14) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 0:  14%|█▍        | 2/14 [00:04<00:27,  0.44it/s, v_num=usal]

/opt/anaconda3/envs/xfads/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [27]:
trainer.test(dataloaders=test_dataloader, ckpt_path='last')

[rank: 0] Seed set to 1236
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
[W socket.cpp:464] [c10d] The server socket has failed to bind to [::]:63190 (errno: 48 - Address already in use).
[W socket.cpp:464] [c10d] The server socket has failed to bind to 0.0.0.0:63190 (errno: 48 - Address already in use).
[E socket.cpp:500] [c10d] The server socket has failed to listen on any local network address.


ProcessRaisedException: 

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/opt/anaconda3/envs/xfads/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 68, in _wrap
    fn(i, *args)
  File "/opt/anaconda3/envs/xfads/lib/python3.10/site-packages/pytorch_lightning/strategies/launchers/multiprocessing.py", line 173, in _wrapping_function
    results = function(*args, **kwargs)
  File "/opt/anaconda3/envs/xfads/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 794, in _test_impl
    results = self._run(model, ckpt_path=ckpt_path)
  File "/opt/anaconda3/envs/xfads/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 943, in _run
    self.strategy.setup_environment()
  File "/opt/anaconda3/envs/xfads/lib/python3.10/site-packages/pytorch_lightning/strategies/ddp.py", line 154, in setup_environment
    self.setup_distributed()
  File "/opt/anaconda3/envs/xfads/lib/python3.10/site-packages/pytorch_lightning/strategies/ddp.py", line 203, in setup_distributed
    _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
  File "/opt/anaconda3/envs/xfads/lib/python3.10/site-packages/lightning_fabric/utilities/distributed.py", line 291, in _init_dist_connection
    torch.distributed.init_process_group(torch_distributed_backend, rank=global_rank, world_size=world_size, **kwargs)
  File "/opt/anaconda3/envs/xfads/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 86, in wrapper
    func_return = func(*args, **kwargs)
  File "/opt/anaconda3/envs/xfads/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 1177, in init_process_group
    store, rank, world_size = next(rendezvous_iterator)
  File "/opt/anaconda3/envs/xfads/lib/python3.10/site-packages/torch/distributed/rendezvous.py", line 246, in _env_rendezvous_handler
    store = _create_c10d_store(master_addr, master_port, rank, world_size, timeout, use_libuv)
  File "/opt/anaconda3/envs/xfads/lib/python3.10/site-packages/torch/distributed/rendezvous.py", line 174, in _create_c10d_store
    return TCPStore(
torch.distributed.DistNetworkError: The server socket has failed to listen on any local network address. The server socket has failed to bind to [::]:63190 (errno: 48 - Address already in use). The server socket has failed to bind to 0.0.0.0:63190 (errno: 48 - Address already in use).
