In [2]:
import wandb
import os
import pytorch_lightning as pl
import torch
import numpy as np
from legendre.models.spline_cnode import SplineCNODEClass
from legendre.models.cnode_ext import CNODExtClassification, CNODExt
from legendre.models.node_ext import NODExtClassification, NODExt
from legendre.models.node_mod import NODEClassification, NODE
from legendre.models.rnn import RNNClassification, RNN
from legendre.models.hippo import HIPPO, HippoClassification
from legendre.data_utils.simple_path_utils import SimpleTrajDataModule
from legendre.data_utils.pMNIST_utils import pMNISTDataModule
from legendre.data_utils.character_utils import CharacterTrajDataModule
from legendre.data_utils.mimic_utils import MIMICDataModule
from legendre.data_utils.lorenz_utils import LorenzDataModule
from legendre.train_scripts.classif import get_init_model
from sklearn.metrics import accuracy_score, roc_auc_score
from pytorch_lightning.core.saving import _load_state

%load_ext autoreload
%autoreload 2


def mean_dict(d):
    """Return the arithmetic mean of every non-empty value list in ``d``.

    Keys whose list is empty are omitted from the result.
    """
    averaged = {}
    for key, values in d.items():
        if len(values) == 0:
            continue
        averaged[key] = sum(values) / len(values)
    return averaged
def std_dict(d):
    """Return ``np.std`` (population std, ddof=0) of every non-empty value list in ``d``.

    Keys whose list is empty are omitted from the result.
    """
    spreads = {}
    for key, values in d.items():
        if len(values) > 0:
            spreads[key] = np.std(np.array(values))
    return spreads

def process_dict(d):
    """Format each key's mean and std as a LaTeX string ``$mean \\pm std$``.

    Values are rendered with three decimals. Keys with empty value lists (and
    empty-string keys) are left out.
    """
    means = mean_dict(d)
    stds = std_dict(d)
    formatted = {}
    for key, mean in means.items():
        if len(key) == 0:
            continue
        formatted[key] = f"${mean:.3f} \\pm {stds[key]:.3f}$"
    return formatted

  from .autonotebook import tqdm as notebook_tqdm


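# Classification

Downstream evaluation of each model family on the test split. With `regression_mode = True` (the setting used in the cells below), the reported metric is the test MSE (lower is better). Otherwise it is accuracy or AUROC, depending on the shape of the model output.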

## SplineCNODE

In [10]:
# Evaluation settings for the SplineCNODE sweep. Only runs whose W&B config
# matches these filters are evaluated by the loop in this cell.
sweep_id = "v4vdss0h"
multivariate = None      # runs without a "multivariate" key compare as None
regression_mode = True   # True -> the reported metric is the test MSE
spline_type = "Linear"   # runs without a "spline_type" key compare as "Hermite"

model_cls = SplineCNODEClass
irregular_rates = [0.3, 0.4, 0.5]  # alternative setting: [0.7, 0.8, 0.9, 1.0]
acc_dict = {}  # str(irregular_rate) -> list of per-run metrics

api = wandb.Api()
sweep = api.sweep(f"edebrouwer/orthopoly/{sweep_id}")
runs = sweep.runs

def data_cls_choice(data_type, hparams=None):
    """Instantiate the test-time data module that matches ``data_type``.

    Args:
        data_type: dataset name stored in the model hparams; one of
            ``"SimpleTraj"``, ``"pMNIST"``, ``"Character"``, ``"Lorenz"``, ``"MIMIC"``.
        hparams: keyword arguments forwarded to the data module constructor.
            Defaults to the notebook-global ``model.hparams`` set by the
            evaluation loop in this cell (the original behaviour).

    Returns:
        The constructed data module.

    Raises:
        ValueError: if ``data_type`` is not supported (this used to surface as
            an ``UnboundLocalError`` on ``data_cls``).
    """
    supported = ("SimpleTraj", "pMNIST", "Character", "Lorenz", "MIMIC")
    if data_type not in supported:
        raise ValueError(f"Unsupported data_type {data_type!r}; expected one of {supported}")
    if hparams is None:
        # Backward-compatible fallback: read the global model loaded by the loop.
        hparams = model.hparams
    if data_type == "SimpleTraj":
        return SimpleTrajDataModule(**hparams)
    if data_type == "pMNIST":
        return pMNISTDataModule(**hparams)
    if data_type == "Character":
        return CharacterTrajDataModule(**hparams)
    if data_type == "Lorenz":
        return LorenzDataModule(**hparams)
    return MIMICDataModule(**hparams)


for irregular_rate in irregular_rates:

    accs_ = []

    # Runs matching this cell's filters; missing config keys fall back to the
    # defaults used when comparing (None / False / "Hermite").
    run_sub = [
        r for r in runs
        if r.config["irregular_rate"] == irregular_rate
        and r.config.get("multivariate", None) == multivariate
        and r.config.get("regression_mode", False) == regression_mode
        and r.config.get("spline_type", "Hermite") == spline_type
    ]

    for run in run_sub:
        ckpt_names = [f.name for f in run.files() if "ckpt" in f.name]
        if not ckpt_names:
            raise FileNotFoundError(f"No checkpoint file found for run {run.id}")
        fname = ckpt_names[0]
        run.file(fname).download(replace=True, root=".")
        try:
            checkpoint = torch.load(fname, map_location=lambda storage, loc: storage)
            hparams = checkpoint["hyper_parameters"]
            model = model_cls.load_from_checkpoint(fname, num_dims=hparams.get("num_dims", 1))
        finally:
            # Always delete the downloaded checkpoint, even if loading fails.
            os.remove(fname)

        dataset = data_cls_choice(model.hparams.data_type)
        dataset.prepare_data()

        trainer = pl.Trainer(gpus=1)
        outputs = trainer.predict(model, dataset.test_dataloader())

        preds = torch.cat([x["preds"] for x in outputs])
        # Y and T are not used by the metric; kept as notebook globals for inspection.
        Y = torch.cat([x["Y"] for x in outputs])
        T = torch.cat([x["T"] for x in outputs])
        labels = torch.cat([x["labels"] for x in outputs])

        if hparams.get("regression_mode", False):
            # Regression: "accuracy" holds the test MSE (lower is better).
            accuracy = torch.nn.MSELoss()(preds, labels)
        elif len(preds.shape) > 1:
            # Per-class scores -> predicted class index -> accuracy.
            preds = torch.nn.functional.softmax(preds, dim=-1).argmax(-1)
            accuracy = accuracy_score(labels.long().cpu().numpy(), preds.cpu().numpy())
        else:
            # 1-D scores -> binary AUROC.
            accuracy = roc_auc_score(labels.long().cpu().numpy(), preds.cpu().numpy())

        accs_.append(accuracy)

    acc_dict[f"{irregular_rate}"] = accs_
    print(acc_dict)
    print(process_dict(acc_dict))

Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:11<00:00,  1.69it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 19.99it/s] 
Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:10<00:00,  1.85it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 14.51it/s] 
Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:11<00:00,  1.77it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 15.10it/s] 
{'0.3': [tensor(0.8652), tensor(0.7028), tensor(0.7937)]}
{'0.3': '$0.787 \\pm 0.066$'}
Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:11<00:00,  1.78it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 17.22it/s] 
Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:11<00:00,  1.75it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 14.92it/s] 
Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:10<00:00,  1.84it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 21.12it/s] 
{'0.3': [tensor(0.8652), tensor(0.7028), tensor(0.7937)], '0.4': [tensor(0.7999), tensor(0.8259), tensor(0.7349)]}
{'0.3': '$0.787 \\pm 0.066$', '0.4': '$0.787 \\pm 0.038$'}
Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:11<00:00,  1.72it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 20.26it/s] 
Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:10<00:00,  1.82it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 16.67it/s] 
Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:11<00:00,  1.73it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 19.47it/s] 
{'0.3': [tensor(0.8652), tensor(0.7028), tensor(0.7937)], '0.4': [tensor(0.7999), tensor(0.8259), tensor(0.7349)], '0.5': [tensor(0.7523), tensor(0.7519), tensor(0.8274)]}
{'0.3': '$0.787 \\pm 0.066$', '0.4': '$0.787 \\pm 0.038$', '0.5': '$0.777 \\pm 0.036$'}


## CNODExt

In [7]:
# Evaluation settings for the CNODExt downstream sweep.
sweep_id = "apwiz3xd"
regression_mode = True   # True -> the reported metric is the test MSE

model_cls = CNODExtClassification
init_model_cls = CNODExt
irregular_rates = [0.3, 0.4, 0.5]  # alternatives: [0.7, 0.8, 0.9, 1.0] or [1.0]
device = torch.device("cuda")
# NOTE(review): not read by this cell; the init model uses the seed stored in each checkpoint.
seed = 421
acc_dict = {}  # str(irregular_rate) -> list of per-run metrics

api = wandb.Api()
sweep = api.sweep(f"edebrouwer/orthopoly/{sweep_id}")
runs = sweep.runs

def data_cls_choice(data_type, init_model):
    """Instantiate the test-time data module that matches ``data_type``.

    Args:
        data_type: dataset name; one of ``"SimpleTraj"``, ``"pMNIST"``,
            ``"Character"``, ``"MIMIC"``, ``"Lorenz"``.
        init_model: any object exposing ``.hparams``; its hparams are forwarded
            as keyword arguments to the data module constructor.

    Returns:
        The constructed data module.

    Raises:
        ValueError: if ``data_type`` is not supported (this used to surface as
            an ``UnboundLocalError`` on ``data_cls``).
    """
    supported = ("SimpleTraj", "pMNIST", "Character", "MIMIC", "Lorenz")
    if data_type not in supported:
        raise ValueError(f"Unsupported data_type {data_type!r}; expected one of {supported}")
    hparams = init_model.hparams
    if data_type == "SimpleTraj":
        return SimpleTrajDataModule(**hparams)
    if data_type == "pMNIST":
        return pMNISTDataModule(**hparams)
    if data_type == "Character":
        return CharacterTrajDataModule(**hparams)
    if data_type == "MIMIC":
        return MIMICDataModule(**hparams)
    return LorenzDataModule(**hparams)

for irregular_rate in irregular_rates:

    accs_ = []
    run_sub = [
        r for r in runs
        if r.config["irregular_rate"] == irregular_rate
        and r.config.get("regression_mode", False) == regression_mode
    ]

    for run in run_sub:
        ckpt_names = [f.name for f in run.files() if "ckpt" in f.name]
        if not ckpt_names:
            raise FileNotFoundError(f"No checkpoint file found for run {run.id}")
        fname = ckpt_names[0]
        run.file(fname).download(replace=True, root=".")

        try:
            checkpoint = torch.load(fname, map_location=lambda storage, loc: storage)
            hparams = checkpoint["hyper_parameters"]
            # Strip training-time entries before rebuilding the model from hparams.
            for key in ("callbacks", "logger", "wandb_id_file_path"):
                hparams.pop(key, None)
            checkpoint.pop("callbacks", None)
            # checkpoint_ shares the same hparams dict object as checkpoint.
            checkpoint_ = {"state_dict": checkpoint["state_dict"], "hyper_parameters": hparams}
            # The init model is passed explicitly to _load_state below instead.
            hparams.pop("init_model", None)

            init_model = get_init_model(
                hparams["init_sweep_id"],
                init_model_cls,
                hparams["irregular_rate"],
                hparams["seed"],
                hparams.get("multivariate", False),
            )

            # NOTE(review): _load_state is a private Lightning helper and may change between versions.
            model = _load_state(model_cls, checkpoint_, init_model=init_model).to(device)
        finally:
            # Always delete the downloaded checkpoint, even if loading fails.
            os.remove(fname)

        # NOTE(review): the data module is built from the downstream model's
        # hparams (passed as ``init_model``), not from ``init_model`` itself.
        dataset = data_cls_choice(model.hparams.data_type, init_model=model)
        dataset.set_test_only()
        dataset.prepare_data()

        trainer = pl.Trainer(gpus=1, logger=None)
        outputs = trainer.predict(model, dataset.test_dataloader())

        preds = torch.cat([x["preds"] for x in outputs])
        # Y and T are not used by the metric; kept as notebook globals for inspection.
        Y = torch.cat([x["Y"] for x in outputs])
        T = torch.cat([x["T"] for x in outputs])
        labels = torch.cat([x["labels"] for x in outputs])

        if hparams.get("regression_mode", False):
            # Regression: "accuracy" holds the test MSE (lower is better).
            accuracy = torch.nn.MSELoss()(preds, labels)
        elif len(preds.shape) > 1:
            # Per-class scores -> predicted class index -> accuracy.
            preds = torch.nn.functional.softmax(preds, dim=-1).argmax(-1)
            accuracy = accuracy_score(labels.long().cpu().numpy(), preds.cpu().numpy())
        else:
            # 1-D scores -> binary AUROC.
            accuracy = roc_auc_score(labels.long().cpu().numpy(), preds.cpu().numpy())

        accs_.append(accuracy)

    acc_dict[f"{irregular_rate}"] = accs_
print(acc_dict)
print(process_dict(acc_dict))

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_workers
    if not self._shutdown:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_shutdown'
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_

Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:07<00:00,  2.82it/s]
  rank_zero_deprecation(
  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(f"attribute '{k}' removed from hparams because it cannot be pickled")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 26.01it/s] 


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_workers
    if not self._shutdown:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_shutdown'
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_

Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:07<00:00,  2.77it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 24.55it/s] 
Pre-computing ODE Projection embeddings....


  5%|▌         | 1/20 [00:00<00:07,  2.67it/s]Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_workers
    if not self._shutdown:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_shutdown'
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils

Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 23.97it/s] 


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_workers
    if not self._shutdown:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_shutdown'
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_

Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:07<00:00,  2.83it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 26.44it/s] 
Pre-computing ODE Projection embeddings....


  5%|▌         | 1/20 [00:00<00:07,  2.66it/s]Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_workers
    if not self._shutdown:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_shutdown'
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils

Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 24.72it/s] 


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_workers
    if not self._shutdown:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_shutdown'
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_

Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:07<00:00,  2.83it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 26.27it/s] 
Pre-computing ODE Projection embeddings....


  0%|          | 0/20 [00:00<?, ?it/s]Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_workers
    if not self._shutdown:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_shutdown'
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/da

Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 23.71it/s] 


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_workers
    if not self._shutdown:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_shutdown'
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_

Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:07<00:00,  2.86it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 23.34it/s] 
Pre-computing ODE Projection embeddings....


  0%|          | 0/20 [00:00<?, ?it/s]Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_workers
    if not self._shutdown:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_shutdown'
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/da

Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 25.33it/s] 
{'0.3': [tensor(0.0486), tensor(0.0368), tensor(0.0287)], '0.4': [tensor(0.0249), tensor(0.0213), tensor(0.0167)], '0.5': [tensor(0.0132), tensor(0.0093), tensor(0.0115)]}
{'0.3': '$0.038 \\pm 0.008$', '0.4': '$0.021 \\pm 0.003$', '0.5': '$0.011 \\pm 0.002$'}


## Hippo

In [3]:
# Evaluation settings for the HIPPO sweep.
sweep_id = "vzu6ezhf"
multivariate = None      # runs without a "multivariate" key compare as None
regression_mode = True   # True -> the reported metric is the test MSE

model_cls = HIPPO

irregular_rates = [0.3, 0.4, 0.5]  # alternatives: [0.7, 0.8, 0.9, 1.0] or [1.0]
# NOTE(review): not read by this cell.
seed = 421
acc_dict = {}  # str(irregular_rate) -> list of per-run metrics

api = wandb.Api()
sweep = api.sweep(f"edebrouwer/orthopoly/{sweep_id}")
runs = sweep.runs

def data_cls_choice(data_type, hparams=None):
    """Build the test-time data module for ``data_type`` in forecast mode.

    Args:
        data_type: dataset name; one of ``"SimpleTraj"``, ``"pMNIST"``,
            ``"Character"``, ``"Lorenz"``, ``"MIMIC"``.
        hparams: keyword arguments forwarded to the data module constructor.
            Defaults to the notebook-global ``model.hparams`` set by the
            evaluation loop in this cell (the original behaviour).

    Returns:
        The constructed data module, always created with ``forecast_mode=True``.

    Raises:
        ValueError: if ``data_type`` is not supported (this used to surface as
            an ``UnboundLocalError`` on ``data_cls``).
    """
    supported = ("SimpleTraj", "pMNIST", "Character", "Lorenz", "MIMIC")
    if data_type not in supported:
        raise ValueError(f"Unsupported data_type {data_type!r}; expected one of {supported}")
    if hparams is None:
        hparams = model.hparams
    # Remove any stored forecast_mode so the explicit forecast_mode=True below
    # does not clash with it as a duplicate keyword argument.
    # NOTE(review): this pops in place, so with the default it also removes the
    # key from the global model.hparams before prediction; kept because the
    # effect on the model's predict step is unverified.
    if "forecast_mode" in hparams:
        hparams.pop("forecast_mode")
    if data_type == "SimpleTraj":
        return SimpleTrajDataModule(**hparams, forecast_mode=True)
    if data_type == "pMNIST":
        return pMNISTDataModule(**hparams, forecast_mode=True)
    if data_type == "Character":
        return CharacterTrajDataModule(**hparams, forecast_mode=True)
    if data_type == "Lorenz":
        return LorenzDataModule(**hparams, forecast_mode=True)
    return MIMICDataModule(**hparams, forecast_mode=True)

for irregular_rate in irregular_rates:

    accs_ = []
    run_sub = [
        r for r in runs
        if r.config["irregular_rate"] == irregular_rate
        and r.config.get("multivariate", None) == multivariate
        and r.config.get("regression_mode", False) == regression_mode
    ]

    for run in run_sub:
        ckpt_names = [f.name for f in run.files() if "ckpt" in f.name]
        if not ckpt_names:
            raise FileNotFoundError(f"No checkpoint file found for run {run.id}")
        fname = ckpt_names[0]
        run.file(fname).download(replace=True, root=".")
        try:
            checkpoint = torch.load(fname, map_location=lambda storage, loc: storage)
            hparams = checkpoint["hyper_parameters"]
            model = model_cls.load_from_checkpoint(fname, num_dims=hparams.get("num_dims", 1))
        finally:
            # Always delete the downloaded checkpoint, even if loading fails.
            os.remove(fname)

        dataset = data_cls_choice(model.hparams.data_type)
        dataset.prepare_data()

        trainer = pl.Trainer(gpus=1)
        outputs = trainer.predict(model, dataset.test_dataloader())

        preds = torch.cat([x["preds"] for x in outputs])
        labels = torch.cat([x["labels"] for x in outputs])

        if hparams.get("regression_mode", False):
            # Regression: "accuracy" holds the test MSE (lower is better).
            # NOTE(review): this cell's saved output shows a warning raised from
            # F.mse_loss here, which suggests preds and labels differ in shape
            # and are being broadcast; confirm the shapes before trusting the MSE.
            accuracy = torch.nn.MSELoss()(preds, labels)
        else:
            # Only the binary AUROC path is used for this model.
            accuracy = roc_auc_score(labels.long().cpu().numpy(), preds.cpu().numpy())
            print(accuracy)

        accs_.append(accuracy)

    acc_dict[f"{irregular_rate}"] = accs_
print(acc_dict)
print(process_dict(acc_dict))

  rank_zero_deprecation(
  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00,  7.07it/s]


  return F.mse_loss(input, target, reduction=self.reduction)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00,  5.47it/s]


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00,  5.78it/s]


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00,  6.51it/s]


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00,  5.86it/s]


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00,  6.81it/s]


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00,  6.55it/s]


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00,  6.72it/s]


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00,  6.99it/s]
{'0.3': [tensor(1.9604), tensor(1.8589), tensor(2.0101)], '0.4': [tensor(1.9190), tensor(1.8837), tensor(1.5984)], '0.5': [tensor(1.7212), tensor(1.7034), tensor(1.6540)]}
{'0.3': '$1.943 \\pm 0.063$', '0.4': '$1.800 \\pm 0.144$', '0.5': '$1.693 \\pm 0.028$'}


## Hippo-RNN

In [14]:
# Evaluation settings for the Hippo-RNN (HippoClassification) sweep.
sweep_id = "et71zf2x"
device = torch.device("cuda")
multivariate = None      # runs without a "multivariate" key compare as None
regression_mode = True   # True -> the reported metric is the test MSE

model_cls = HippoClassification
init_model_cls = HIPPO
irregular_rates = [0.3, 0.4, 0.5]  # alternatives: [0.7, 0.8, 0.9, 1.0] or [1.0]
seed = 421
acc_dict = {}  # str(irregular_rate) -> list of per-run metrics

api = wandb.Api()
sweep = api.sweep(f"edebrouwer/orthopoly/{sweep_id}")
runs = sweep.runs

def data_cls_choice(data_type, init_model):
    """Instantiate the Lightning data module matching ``data_type``.

    Args:
        data_type: dataset name; one of "SimpleTraj", "pMNIST", "Character",
            "MIMIC" or "Lorenz".
        init_model: any object exposing ``.hparams``; those hyper-parameters
            are forwarded verbatim as keyword arguments to the data module.

    Returns:
        The constructed data module.

    Raises:
        ValueError: if ``data_type`` is not one of the supported names.
    """
    if data_type == "SimpleTraj":
        return SimpleTrajDataModule(**init_model.hparams)
    if data_type == "pMNIST":
        return pMNISTDataModule(**init_model.hparams)
    if data_type == "Character":
        return CharacterTrajDataModule(**init_model.hparams)
    if data_type == "MIMIC":
        return MIMICDataModule(**init_model.hparams)
    if data_type == "Lorenz":
        return LorenzDataModule(**init_model.hparams)
    # Fail loudly instead of falling through to an unbound local variable.
    raise ValueError(
        f"Unsupported data_type {data_type!r}; expected one of "
        "'SimpleTraj', 'pMNIST', 'Character', 'MIMIC', 'Lorenz'."
    )

# Score every matching fine-tuned run on its test split, grouped by irregular rate.
for irregular_rate in irregular_rates:

    accs_ = []
    # Runs trained at this rate with the requested multivariate / regression settings.
    run_sub = [
        r for r in runs
        if r.config.get("irregular_rate") == irregular_rate
        and r.config.get("multivariate", None) == multivariate
        and r.config.get("regression_mode", False) == regression_mode
    ]

    for run in run_sub:
        ckpt_names = [f.name for f in run.files() if "ckpt" in f.name]
        if not ckpt_names:
            raise FileNotFoundError(f"No checkpoint file found in W&B run {run.id}")
        fname = ckpt_names[0]
        run.file(fname).download(replace=True, root=".")

        # Load onto CPU, and always delete the downloaded file, even if loading fails.
        try:
            checkpoint = torch.load(fname, map_location="cpu")
        finally:
            os.remove(fname)

        hparams = checkpoint["hyper_parameters"]
        # Training-time objects that must not be passed back into the constructor.
        for key in ("callbacks", "logger", "wandb_id_file_path", "init_model"):
            hparams.pop(key, None)
        checkpoint_ = {"state_dict": checkpoint["state_dict"], "hyper_parameters": hparams}

        # Presumably restores the pretrained backbone from `init_sweep_id` — see get_init_model.
        init_model = get_init_model(
            hparams["init_sweep_id"],
            init_model_cls,
            hparams["irregular_rate"],
            hparams["seed"],
            hparams.get("multivariate", False),
        )
        model = _load_state(model_cls, checkpoint_, init_model=init_model).to(device)

        dataset = data_cls_choice(model.hparams.data_type, init_model=model)
        dataset.set_test_only()
        dataset.prepare_data()

        trainer = pl.Trainer(accelerator="gpu", devices=1, logger=False)
        outputs = trainer.predict(model, dataset.test_dataloader())

        preds = torch.cat([x["preds"] for x in outputs])
        labels = torch.cat([x["labels"] for x in outputs])

        # Regression -> MSE; 2-D class scores -> accuracy; 1-D scores -> ROC-AUC
        # (presumably a binary task — confirm per dataset).
        if hparams.get("regression_mode", False):
            accuracy = torch.nn.MSELoss()(preds, labels).item()
        elif preds.dim() > 1:
            preds = torch.nn.functional.softmax(preds, dim=-1).argmax(-1)
            accuracy = accuracy_score(labels.long().cpu().numpy(), preds.cpu().numpy())
        else:
            accuracy = roc_auc_score(labels.long().cpu().numpy(), preds.cpu().numpy())

        accs_.append(accuracy)

    acc_dict[f"{irregular_rate}"] = accs_
print(acc_dict)
print(process_dict(acc_dict))

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_workers
    if not self._shutdown:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_shutdown'
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_

Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:03<00:00,  6.13it/s]
  rank_zero_deprecation(
  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 11.05it/s] 


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_workers
    if not self._shutdown:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_shutdown'
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_

Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:02<00:00,  9.14it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 10.95it/s] 
Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:02<00:00,  7.25it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 12.14it/s] 


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_workers
    if not self._shutdown:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_shutdown'
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_

Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:01<00:00, 10.45it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 10.64it/s] 
Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:02<00:00,  9.84it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00,  9.87it/s] 


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_workers
    if not self._shutdown:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_shutdown'
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_

Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:02<00:00,  6.71it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 12.15it/s] 
Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:03<00:00,  6.56it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 11.32it/s] 


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_workers
    if not self._shutdown:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_shutdown'
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_

Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:02<00:00,  6.77it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 13.60it/s] 
Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:02<00:00,  9.93it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 12.35it/s] 
{'0.3': [tensor(0.2087), tensor(0.2067), tensor(0.1785)], '0.4': [tensor(0.1900), tensor(0.1526), tensor(0.1383)], '0.5': [tensor(0.1175), tensor(0.0970), tensor(0.1061)]}
{'0.3': '$0.198 \\pm 0.014$', '0.4': '$0.160 \\pm 0.022$', '0.5': '$0.107 \\pm 0.008$'}


## NODExt

In [10]:
# --- NODE: which fine-tuned runs to score ------------------------------------
sweep_id = "x9vkty3m"
model_cls = NODEClassification   # fine-tuned head restored from each run's checkpoint
init_model_cls = NODE            # class handed to get_init_model for the backbone

# Run filters. `multivariate = None` matches runs whose config has no
# "multivariate" key (the loop compares with r.config.get(..., None)).
irregular_rates = [0.3, 0.4, 0.5]
multivariate, regression_mode = None, True

seed = 421
device = torch.device("cuda")

# str(irregular_rate) -> list of per-run scores, filled by the loop below.
acc_dict = {}

api = wandb.Api()
sweep = api.sweep(f"edebrouwer/orthopoly/{sweep_id}")
runs = sweep.runs

def data_cls_choice(data_type, init_model):
    """Instantiate the Lightning data module matching ``data_type``.

    Args:
        data_type: dataset name; one of "SimpleTraj", "pMNIST", "Character",
            "MIMIC" or "Lorenz".
        init_model: any object exposing ``.hparams``; those hyper-parameters
            are forwarded verbatim as keyword arguments to the data module.

    Returns:
        The constructed data module.

    Raises:
        ValueError: if ``data_type`` is not one of the supported names.
    """
    if data_type == "SimpleTraj":
        return SimpleTrajDataModule(**init_model.hparams)
    if data_type == "pMNIST":
        return pMNISTDataModule(**init_model.hparams)
    if data_type == "Character":
        return CharacterTrajDataModule(**init_model.hparams)
    if data_type == "MIMIC":
        return MIMICDataModule(**init_model.hparams)
    if data_type == "Lorenz":
        return LorenzDataModule(**init_model.hparams)
    # Fail loudly instead of falling through to an unbound local variable.
    raise ValueError(
        f"Unsupported data_type {data_type!r}; expected one of "
        "'SimpleTraj', 'pMNIST', 'Character', 'MIMIC', 'Lorenz'."
    )

# Score every matching fine-tuned run on its test split, grouped by irregular rate.
for irregular_rate in irregular_rates:

    accs_ = []
    # Runs trained at this rate with the requested multivariate / regression settings.
    run_sub = [
        r for r in runs
        if r.config.get("irregular_rate") == irregular_rate
        and r.config.get("multivariate", None) == multivariate
        and r.config.get("regression_mode", False) == regression_mode
    ]

    for run in run_sub:
        ckpt_names = [f.name for f in run.files() if "ckpt" in f.name]
        if not ckpt_names:
            raise FileNotFoundError(f"No checkpoint file found in W&B run {run.id}")
        fname = ckpt_names[0]
        run.file(fname).download(replace=True, root=".")

        # Load onto CPU, and always delete the downloaded file, even if loading fails.
        try:
            checkpoint = torch.load(fname, map_location="cpu")
        finally:
            os.remove(fname)

        hparams = checkpoint["hyper_parameters"]
        # Training-time objects that must not be passed back into the constructor.
        for key in ("callbacks", "logger", "wandb_id_file_path", "init_model"):
            hparams.pop(key, None)
        checkpoint_ = {"state_dict": checkpoint["state_dict"], "hyper_parameters": hparams}

        # Presumably restores the pretrained backbone from `init_sweep_id` — see get_init_model.
        init_model = get_init_model(
            hparams["init_sweep_id"],
            init_model_cls,
            hparams["irregular_rate"],
            hparams["seed"],
            hparams.get("multivariate", False),
        )
        model = _load_state(model_cls, checkpoint_, init_model=init_model).to(device)

        dataset = data_cls_choice(model.hparams.data_type, init_model=model)
        dataset.set_test_only()
        dataset.prepare_data()

        trainer = pl.Trainer(accelerator="gpu", devices=1, logger=False)
        outputs = trainer.predict(model, dataset.test_dataloader())

        preds = torch.cat([x["preds"] for x in outputs])
        labels = torch.cat([x["labels"] for x in outputs])

        # Regression -> MSE; 2-D class scores -> accuracy; 1-D scores -> ROC-AUC
        # (presumably a binary task — confirm per dataset).
        if hparams.get("regression_mode", False):
            accuracy = torch.nn.MSELoss()(preds, labels).item()
        elif preds.dim() > 1:
            preds = torch.nn.functional.softmax(preds, dim=-1).argmax(-1)
            accuracy = accuracy_score(labels.long().cpu().numpy(), preds.cpu().numpy())
        else:
            accuracy = roc_auc_score(labels.long().cpu().numpy(), preds.cpu().numpy())

        accs_.append(accuracy)

    acc_dict[f"{irregular_rate}"] = accs_
print(acc_dict)
print(process_dict(acc_dict))

  rank_zero_warn(


Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:18<00:00,  1.08it/s]
  rank_zero_deprecation(
  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 11.28it/s] 
Pre-computing ODE Projection embeddings....


 20%|██        | 4/20 [00:02<00:10,  1.49it/s]Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_workers
    if not self._shutdown:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_shutdown'
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils

Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 14.64it/s] 
Pre-computing ODE Projection embeddings....


  0%|          | 0/20 [00:00<?, ?it/s]Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_workers
    if not self._shutdown:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_shutdown'
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/da

Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 12.10it/s] 


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_workers
    if not self._shutdown:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_shutdown'
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_

Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:19<00:00,  1.01it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00,  8.99it/s] 


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_workers
    if not self._shutdown:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_shutdown'
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_

Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:16<00:00,  1.20it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 14.24it/s] 
Pre-computing ODE Projection embeddings....


 50%|█████     | 10/20 [00:10<00:10,  1.01s/it]Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_workers
    if not self._shutdown:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_shutdown'
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/util

Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00,  8.23it/s] 
Pre-computing ODE Projection embeddings....


 20%|██        | 4/20 [00:03<00:12,  1.32it/s]Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_workers
    if not self._shutdown:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_shutdown'
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils

Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 10.06it/s] 
Pre-computing ODE Projection embeddings....


  0%|          | 0/20 [00:00<?, ?it/s]Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_workers
    if not self._shutdown:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_shutdown'
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/da

Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00,  9.92it/s] 


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_workers
    if not self._shutdown:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_shutdown'
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fda0ae6d280>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_

Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:21<00:00,  1.07s/it]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 11.14it/s] 
{'0.3': [tensor(0.2569), tensor(0.1916), tensor(0.1947)], '0.4': [tensor(0.2382), tensor(0.2024), tensor(0.2348)], '0.5': [tensor(0.2156), tensor(0.1986), tensor(0.1620)]}
{'0.3': '$0.214 \\pm 0.030$', '0.4': '$0.225 \\pm 0.016$', '0.5': '$0.192 \\pm 0.022$'}


## RNN

In [15]:
# --- RNN: which fine-tuned runs to score -------------------------------------
sweep_id = "1vc3k85v"
model_cls = RNNClassification   # fine-tuned head restored from each run's checkpoint
init_model_cls = RNN            # class handed to get_init_model for the backbone

# Run filters applied to each W&B run's config.
irregular_rates = [0.3, 0.4, 0.5, 0.6]
multivariate, regression_mode = True, False

seed = 421
device = torch.device("cuda")

# str(irregular_rate) -> list of per-run scores, filled by the loop below.
acc_dict = {}

api = wandb.Api()
sweep = api.sweep(f"edebrouwer/orthopoly/{sweep_id}")
runs = sweep.runs

def data_cls_choice(data_type, init_model):
    """Instantiate the Lightning data module matching ``data_type``.

    Args:
        data_type: dataset name; one of "SimpleTraj", "pMNIST", "Character",
            "MIMIC" or "Lorenz".
        init_model: any object exposing ``.hparams``; those hyper-parameters
            are forwarded verbatim as keyword arguments to the data module.

    Returns:
        The constructed data module.

    Raises:
        ValueError: if ``data_type`` is not one of the supported names.
    """
    if data_type == "SimpleTraj":
        return SimpleTrajDataModule(**init_model.hparams)
    if data_type == "pMNIST":
        return pMNISTDataModule(**init_model.hparams)
    if data_type == "Character":
        return CharacterTrajDataModule(**init_model.hparams)
    if data_type == "MIMIC":
        return MIMICDataModule(**init_model.hparams)
    if data_type == "Lorenz":
        return LorenzDataModule(**init_model.hparams)
    # Fail loudly instead of falling through to an unbound local variable.
    raise ValueError(
        f"Unsupported data_type {data_type!r}; expected one of "
        "'SimpleTraj', 'pMNIST', 'Character', 'MIMIC', 'Lorenz'."
    )

# Score every matching fine-tuned run on its test split, grouped by irregular rate.
for irregular_rate in irregular_rates:

    accs_ = []
    # Runs trained at this rate with the requested multivariate / regression settings.
    run_sub = [
        r for r in runs
        if r.config.get("irregular_rate") == irregular_rate
        and r.config.get("multivariate", None) == multivariate
        and r.config.get("regression_mode", False) == regression_mode
    ]

    for run in run_sub:
        ckpt_names = [f.name for f in run.files() if "ckpt" in f.name]
        if not ckpt_names:
            raise FileNotFoundError(f"No checkpoint file found in W&B run {run.id}")
        fname = ckpt_names[0]
        run.file(fname).download(replace=True, root=".")

        # Load onto CPU, and always delete the downloaded file, even if loading fails.
        try:
            checkpoint = torch.load(fname, map_location="cpu")
        finally:
            os.remove(fname)

        hparams = checkpoint["hyper_parameters"]
        # Training-time objects that must not be passed back into the constructor.
        for key in ("callbacks", "logger", "wandb_id_file_path", "init_model"):
            hparams.pop(key, None)
        checkpoint_ = {"state_dict": checkpoint["state_dict"], "hyper_parameters": hparams}

        # Presumably restores the pretrained backbone from `init_sweep_id` — see get_init_model.
        init_model = get_init_model(
            hparams["init_sweep_id"],
            init_model_cls,
            hparams["irregular_rate"],
            hparams["seed"],
            hparams.get("multivariate", False),
        )
        model = _load_state(model_cls, checkpoint_, init_model=init_model).to(device)

        dataset = data_cls_choice(model.hparams.data_type, init_model=model)
        dataset.prepare_data()

        trainer = pl.Trainer(accelerator="gpu", devices=1, logger=False)
        outputs = trainer.predict(model, dataset.test_dataloader())

        preds = torch.cat([x["preds"] for x in outputs])
        labels = torch.cat([x["labels"] for x in outputs])

        # Honour regression runs like the other model cells: the run filter above
        # selects on regression_mode, so scoring must too (argmax on a regression
        # output would be meaningless).
        # Regression -> MSE; 2-D class scores -> accuracy; 1-D scores -> ROC-AUC
        # (presumably a binary task — confirm per dataset).
        if hparams.get("regression_mode", False):
            accuracy = torch.nn.MSELoss()(preds, labels).item()
        elif preds.dim() > 1:
            preds = torch.nn.functional.softmax(preds, dim=-1).argmax(-1)
            accuracy = accuracy_score(labels.long().cpu().numpy(), preds.cpu().numpy())
        else:
            accuracy = roc_auc_score(labels.long().cpu().numpy(), preds.cpu().numpy())

        accs_.append(accuracy)

    acc_dict[f"{irregular_rate}"] = accs_
print(acc_dict)
print(process_dict(acc_dict))

  rank_zero_warn(
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f11c0f9f9d0>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_workers
    if not self._shutdown:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_shutdown'
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f11c0f9f9d0>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1

Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:00<00:00, 37.79it/s]


Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:00<00:00, 38.12it/s]
  rank_zero_deprecation(
  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 23/23 [00:00<00:00, 111.57it/s]


  rank_zero_warn(
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f11c0f9f9d0>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_workers
    if not self._shutdown:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_shutdown'
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f11c0f9f9d0>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1

Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:00<00:00, 37.70it/s]


Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:00<00:00, 37.70it/s]
  rank_zero_deprecation(
  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 23/23 [00:00<00:00, 161.18it/s]


  rank_zero_warn(
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f11c0f9f9d0>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_workers
    if not self._shutdown:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_shutdown'
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f11c0f9f9d0>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1

Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:00<00:00, 56.22it/s]


Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:00<00:00, 56.54it/s]
  rank_zero_deprecation(
  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 23/23 [00:00<00:00, 143.26it/s]


  rank_zero_warn(
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f11c0f9f9d0>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_workers
    if not self._shutdown:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_shutdown'
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f11c0f9f9d0>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1

Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:00<00:00, 52.67it/s]


Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:00<00:00, 52.93it/s]
  rank_zero_deprecation(
  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 23/23 [00:00<00:00, 155.72it/s]
{'0.3': [0.883484954513646], '0.4': [0.8540937718684395], '0.5': [0.8992302309307207], '0.6': [0.9062281315605318]}


# Forecasting

In [4]:
# --- Forecasting: which HIPPO runs to score ----------------------------------
sweep_id = "2l62tcyx"
model_cls = HIPPO

# Run filters. `multivariate = None` matches runs whose config has no
# "multivariate" key (the loop compares with r.config.get(..., None)).
irregular_rates = [1.0]
multivariate = None

# str(irregular_rate) -> per-run scores for forecasts / reconstructions.
acc_dict, acc_rec_dict = {}, {}

api = wandb.Api()
sweep = api.sweep(f"edebrouwer/orthopoly/{sweep_id}")
runs = sweep.runs

def data_cls_choice(data_type, hparams):
    """Build the data module matching ``data_type`` with ``forecast_mode=True``.

    The module is constructed from the given hyper-parameters, with any stored
    ``"forecast_mode"`` value replaced by ``forecast_mode=True``.

    Args:
        data_type: Dataset name; one of ``"SimpleTraj"``, ``"pMNIST"``,
            ``"Character"``, ``"MIMIC"`` or ``"Lorenz"``.
        hparams: Mapping of hyper-parameters (here ``model.hparams``). It is
            not modified: a shallow copy is taken before ``"forecast_mode"``
            is dropped.

    Returns:
        The instantiated data module.

    Raises:
        ValueError: If ``data_type`` is not one of the supported names.
    """
    # Copy first: popping from the caller's mapping would silently strip
    # "forecast_mode" from model.hparams for the rest of the session.
    data_kwargs = dict(hparams)
    # Drop any stored value so it cannot clash with the explicit keyword below.
    data_kwargs.pop("forecast_mode", None)

    if data_type == "SimpleTraj":
        data_module_cls = SimpleTrajDataModule
    elif data_type == "pMNIST":
        data_module_cls = pMNISTDataModule
    elif data_type == "Character":
        data_module_cls = CharacterTrajDataModule
    elif data_type == "MIMIC":
        data_module_cls = MIMICDataModule
    elif data_type == "Lorenz":
        data_module_cls = LorenzDataModule
    else:
        raise ValueError(f"Unsupported data_type for forecasting: {data_type!r}")
    return data_module_cls(**data_kwargs, forecast_mode=True)

# Evaluate every run of the sweep at each requested irregular-sampling rate:
# download its checkpoint, rebuild the matching data module in forecasting
# mode, predict on the test split and record two losses from
# model.compute_loss -- on the future window (acc_dict) and on the
# reconstructed window (acc_rec_dict).
# Y_future / preds / mask_* / Y_rec / pred_rec are intentionally left at
# notebook scope so the inspection cells below can look at the last run.
for irregular_rate in irregular_rates:

    accs_ = []
    accs_rec_ = []
    # .get() so a run whose config lacks "irregular_rate" is skipped
    # instead of raising KeyError.
    run_sub = [
        r for r in runs
        if r.config.get("irregular_rate") == irregular_rate
        and r.config.get("multivariate", None) == multivariate
    ]

    for run in run_sub:
        ckpt_names = [f.name for f in run.files() if "ckpt" in f.name]
        if not ckpt_names:
            raise FileNotFoundError(f"Run {run.id} has no checkpoint file to evaluate")
        fname = ckpt_names[0]
        run.file(fname).download(replace=True, root=".")
        try:
            model = model_cls.load_from_checkpoint(fname)
            dataset = data_cls_choice(model.hparams["data_type"], model.hparams)
        finally:
            # Remove the downloaded checkpoint even if loading fails.
            os.remove(fname)

        dataset.prepare_data()

        trainer = pl.Trainer(gpus=1, logger=None)
        outputs = trainer.predict(model, dataset.test_dataloader())

        # Each predict batch is a dict of tensors; concatenate them along dim 0.
        Y_future = torch.cat([x["Y_future"] for x in outputs])
        preds = torch.cat([x["preds"] for x in outputs])
        mask_future = torch.cat([x["mask_future"] for x in outputs])
        Y_rec = torch.cat([x["Y_rec"] for x in outputs])
        pred_rec = torch.cat([x["pred_rec"] for x in outputs])
        mask_rec = torch.cat([x["mask_rec"] for x in outputs])

        mse = model.compute_loss(Y_future, preds, mask_future)
        mse_rec = model.compute_loss(Y_rec, pred_rec, mask_rec)

        # Store plain Python floats: compute_loss may return (possibly CUDA)
        # tensors, which np.std inside process_dict cannot convert and which
        # print as "tensor(...)".
        accs_.append(float(mse))
        accs_rec_.append(float(mse_rec))

    acc_dict[f"{irregular_rate}"] = accs_
    acc_rec_dict[f"{irregular_rate}"] = accs_rec_
print(acc_dict)
print(process_dict(acc_dict))
print(acc_rec_dict)
print(process_dict(acc_rec_dict))

In [3]:
# Shape of the concatenated forecast targets left over from the last evaluated run.
Y_future.size()

torch.Size([4790, 24, 10])

In [4]:
# Shape of the concatenated forecasts from the last evaluated run.
# NOTE(review): the recorded outputs show preds as (N, 1) while Y_future is
# (N, 24, 10) -- confirm compute_loss broadcasts these as intended.
preds.size()

torch.Size([4790, 1])

In [8]:
# Shape of the reconstruction-window observation mask from the last evaluated run.
mask_rec.size()

torch.Size([4790, 12, 10])

In [12]:
# Last index along dim 1 where mask_rec is maximal, per (sample, channel):
# argmax returns the FIRST maximal index, so applying it to the time-reversed
# mask and mapping back gives the LAST maximal index of the original mask.
# NOTE(review): a channel whose mask is constant (e.g. never observed) gives
# argmax 0 and therefore maps to the final step -- filter those downstream.
# NOTE(review): argmax may not support bool tensors; assumes mask_rec is numeric.
last_idx = mask_rec.size(1) - 1 - mask_rec.flip(1).argmax(dim=1)

In [16]:
# Select Y_rec at each channel's last index along dim 1 (a single step per sample).
Y_rec.gather(1, last_idx.unsqueeze(1)).shape

torch.Size([4790, 1, 10])

In [17]:
# Shape of the reconstruction-window targets from the last evaluated run.
Y_rec.size()

torch.Size([4790, 12, 10])