In [1]:
import wandb
import os
import pytorch_lightning as pl
import torch
from legendre.models.spline_cnode import SplineCNODEClass
from legendre.models.cnode_ext import CNODExtClassification, CNODExt
from legendre.models.node_ext import NODExtClassification, NODExt
from legendre.models.rnn import RNNClassification, RNN
from legendre.models.hippo import HIPPO, HippoClassification
from legendre.data_utils.simple_path_utils import SimpleTrajDataModule
from legendre.data_utils.pMNIST_utils import pMNISTDataModule
from legendre.data_utils.character_utils import CharacterTrajDataModule
from legendre.data_utils.mimic_utils import MIMICDataModule
from legendre.data_utils.lorenz_utils import LorenzDataModule
from legendre.train_scripts.classif import get_init_model
from sklearn.metrics import accuracy_score, roc_auc_score
from pytorch_lightning.core.saving import _load_state

  from .autonotebook import tqdm as notebook_tqdm


# Classification

## SplineCNODE

In [7]:
#from urllib.parse import non_hierarchical
sweep_id = "luqt6vso"
multivariate = None
api= wandb.Api()
sweep = api.sweep("edebrouwer/orthopoly/"+sweep_id)
runs = sweep.runs

model_cls = SplineCNODEClass
irregular_rates = [0.3,0.4,0.5]
acc_dict = {}

def data_cls_choice(data_type):
    if data_type == "SimpleTraj":
        data_cls = SimpleTrajDataModule(**model.hparams)
    elif data_type == "pMNIST":
        data_cls = pMNISTDataModule(**model.hparams)
    elif data_type == "Character":
        data_cls = CharacterTrajDataModule(**model.hparams)
    elif data_type == "Lorenz":
        data_cls = LorenzDataModule(**model.hparams)
    return data_cls


for irregular_rate in irregular_rates:

    accs_ = []
    
    run_sub = [r for r in runs if (r.config["irregular_rate"]==irregular_rate) and (r.config.get("multivariate",None)==multivariate)]

    for run in run_sub:
        fname = [f.name for f in run.files() if "ckpt" in f.name][0]
        run.file(fname).download(replace = True, root = ".")
        checkpoint = torch.load(fname, map_location=lambda storage, loc: storage)
        hparams = checkpoint["hyper_parameters"]
        model = model_cls.load_from_checkpoint(fname,num_dims = hparams.get("num_dims",1))
        os.remove(fname)
        dataset = data_cls_choice(model.hparams.data_type)
        dataset.prepare_data()
        
        trainer = pl.Trainer(gpus=1)
        outputs = trainer.predict(model,dataset.test_dataloader())

        preds = torch.cat([x["preds"] for x in outputs])
        Y = torch.cat([x["Y"] for x in outputs])
        T = torch.cat([x["T"] for x in outputs])
        labels = torch.cat([x["labels"] for x in outputs])
        
        if len(preds.shape)>1:
            preds = torch.nn.functional.softmax(preds, dim=-1).argmax(-1)
            accuracy = accuracy_score(labels.long().cpu().numpy(), preds.cpu().numpy())
        else:
            accuracy = roc_auc_score(labels.long().cpu().numpy(),preds.cpu().numpy())
        #print(accuracy)
        
        accs_.append(accuracy)
        
    acc_dict[f"{irregular_rate}"]=accs_
    print(acc_dict)

Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [01:16<00:00,  3.85s/it]
  rank_zero_deprecation(
  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 41.36it/s] 
Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [01:15<00:00,  3.79s/it]
  rank_zero_deprecation(
  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 34.03it/s] 
Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [01:15<00:00,  3.79s/it]
  rank_zero_deprecation(
  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 41.16it/s] 
Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [01:15<00:00,  3.78s/it]
  rank_zero_deprecation(
  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 39.86it/s] 
Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [01:15<00:00,  3.77s/it]
  rank_zero_deprecation(
  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 42.26it/s] 
Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [01:15<00:00,  3.78s/it]
  rank_zero_deprecation(
  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 35.86it/s] 
Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [01:15<00:00,  3.77s/it]
  rank_zero_deprecation(
  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 37.85it/s] 
Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [01:15<00:00,  3.76s/it]
  rank_zero_deprecation(
  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 40.78it/s] 
Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [01:15<00:00,  3.77s/it]
  rank_zero_deprecation(
  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 40.37it/s] 


In [13]:
def mean_dict(d):
    return {k:sum(v)/len(v) for k,v in d.items()}
def std_dict(d):
    return {k:sum([(x-mean_dict(d)[k])**2 for x in v])/len(v) for k,v in d.items()}

#print(mean_dict(acc_dict))
print(mean_dict(acc_dict))
print(std_dict(acc_dict))

{'0.3': 0.9763237226125073, '0.4': 0.9842834347404965, '0.5': 0.9874872873552635}
{'0.3': 1.857002613376012e-05, '0.4': 2.9012533550325858e-05, '0.5': 3.767992416202517e-05}


## CNODExt

In [14]:
sweep_id = "hme1gza6"
api= wandb.Api()
sweep = api.sweep("edebrouwer/orthopoly/"+sweep_id)
#init_sweep_id = "edebrouwer/orthopoly/t77q79lw"
runs = sweep.runs

model_cls = CNODExtClassification
init_model_cls = CNODExt
irregular_rates = [0.3, 0.4, 0.5]
#irregular_rates = [0.7, 0.8, 0.9, 1.0]
device = torch.device("cuda")
seed = 421
acc_dict = {}

def data_cls_choice(data_type, init_model):
    if data_type == "SimpleTraj":
        data_cls = SimpleTrajDataModule(**init_model.hparams)
    elif data_type == "pMNIST":
        data_cls = pMNISTDataModule(**init_model.hparams)
    elif data_type == "Character":
        data_cls = CharacterTrajDataModule(**init_model.hparams)
    elif data_type == "MIMIC":
        data_cls = MIMICDataModule(**init_model.hparams)
    elif data_type == "Lorenz":
        data_cls = LorenzDataModule(**init_model.hparams)
    return data_cls

for irregular_rate in irregular_rates:
    
    accs_ = []
    run_sub = [r for r in runs if (r.config["irregular_rate"]==irregular_rate)]

    for run in run_sub:
        fname = [f.name for f in run.files() if "ckpt" in f.name][0]
        run.file(fname).download(replace = True, root = ".")
         
        checkpoint = torch.load(
        fname, map_location=lambda storage, loc: storage)
        checkpoint["hyper_parameters"].pop("callbacks", None)
        checkpoint["hyper_parameters"].pop("logger", None)
        checkpoint["hyper_parameters"].pop("wandb_id_file_path", None)
        checkpoint.pop("callbacks",None)
        checkpoint_ = {"state_dict":checkpoint["state_dict"],"hyper_parameters":checkpoint["hyper_parameters"]}
        checkpoint_["hyper_parameters"].pop("init_model",None)

        init_model = get_init_model(checkpoint["hyper_parameters"]["init_sweep_id"],init_model_cls,checkpoint["hyper_parameters"]["irregular_rate"],checkpoint["hyper_parameters"]["seed"],checkpoint["hyper_parameters"].get("multivariate", False))
        
        model = _load_state(model_cls, checkpoint_,init_model = init_model).to(device)
               
        os.remove(fname)
        
        dataset = data_cls_choice(model.hparams.data_type, init_model = model)
        dataset.set_test_only()
        dataset.prepare_data()
    
        trainer = pl.Trainer(gpus=1,logger = None)
        
        outputs = trainer.predict(model,dataset.test_dataloader())

        preds = torch.cat([x["preds"] for x in outputs])
        Y = torch.cat([x["Y"] for x in outputs])
        T = torch.cat([x["T"] for x in outputs])
        labels = torch.cat([x["labels"] for x in outputs])
        
        if len(preds.shape)>1:
            preds = torch.nn.functional.softmax(preds, dim=-1).argmax(-1)
            accuracy = accuracy_score(labels.long().cpu().numpy(), preds.cpu().numpy())
        else:
            accuracy = roc_auc_score(labels.long().cpu().numpy(),preds.cpu().numpy())
        
        accs_.append(accuracy)
    
    acc_dict[f"{irregular_rate}"]=accs_
print(acc_dict)

KeyError: 'init_sweep_id'

## Hippo

In [4]:
sweep_id = "vn074zre"
multivariate = None
api= wandb.Api()
sweep = api.sweep("edebrouwer/orthopoly/"+sweep_id)
runs = sweep.runs

model_cls = HIPPO

irregular_rates = [0.3, 0.4, 0.5, 0.6]
seed = 421
acc_dict = {}

def data_cls_choice(data_type):
    if data_type == "SimpleTraj":
        data_cls = SimpleTrajDataModule(**model.hparams)
    elif data_type == "pMNIST":
        data_cls = pMNISTDataModule(**model.hparams)
    elif data_type == "Character":
        data_cls = CharacterTrajDataModule(**model.hparams)
    elif data_type == "Lorenz":
        data_cls = LorenzDataModule(**model.hparams)
    return data_cls

for irregular_rate in irregular_rates:
    
    accs_ = []
    run_sub = [r for r in runs if (r.config["irregular_rate"]==irregular_rate) and (r.config.get("multivariate",None)==multivariate)]

    for run in run_sub:
        fname = [f.name for f in run.files() if "ckpt" in f.name][0]
        run.file(fname).download(replace = True, root = ".")
        checkpoint = torch.load(fname, map_location=lambda storage, loc: storage)
        hparams = checkpoint["hyper_parameters"]
        model = model_cls.load_from_checkpoint(fname,num_dims = hparams.get("num_dims",1))
        os.remove(fname)
        dataset = data_cls_choice(model.hparams.data_type)
        dataset.prepare_data()
        
        trainer = pl.Trainer(gpus=1)
        outputs = trainer.predict(model,dataset.test_dataloader())

        preds = torch.cat([x["preds"] for x in outputs])
        Y = torch.cat([x["Y"] for x in outputs])
        T = torch.cat([x["T"] for x in outputs])
        labels = torch.cat([x["labels"] for x in outputs])
        
        if len(preds.shape)>1:
            preds = torch.nn.functional.softmax(preds, dim=-1).argmax(-1)
            accuracy = accuracy_score(labels.long().cpu().numpy(), preds.cpu().numpy())
        else:
            accuracy = roc_auc_score(labels.long().cpu().numpy(),preds.cpu().numpy())
        #print(accuracy)
        
        accs_.append(accuracy)
    
    acc_dict[f"{irregular_rate}"]=accs_
print(acc_dict)

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f19cd868550>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_workers
    if not self._shutdown:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_shutdown'
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f19cd868550>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_

UnboundLocalError: local variable 'data_cls' referenced before assignment

## Hippo-RNN

In [5]:
sweep_id = "pifyehqf"
device = torch.device("cuda")
multivariate = None
api= wandb.Api()
sweep = api.sweep("edebrouwer/orthopoly/"+sweep_id)
#init_sweep_id = "edebrouwer/orthopoly/t77q79lw"
runs = sweep.runs

model_cls = HippoClassification
init_model_cls = HIPPO
#irregular_rates = [0.3, 0.4, 0.5]#, 0.9, 1.0]
irregular_rates = [0.7, 0.8, 0.9, 1.0]#, 0.9, 1.0]
seed = 421
acc_dict = {}

def data_cls_choice(data_type, init_model):
    if data_type == "SimpleTraj":
        data_cls = SimpleTrajDataModule(**init_model.hparams)
    elif data_type == "pMNIST":
        data_cls = pMNISTDataModule(**init_model.hparams)
    elif data_type == "Character":
        data_cls = CharacterTrajDataModule(**init_model.hparams)
    elif data_type == "MIMIC":
        data_cls = MIMICDataModule(**init_model.hparams)
    elif data_type == "Lorenz":
        data_cls = LorenzDataModule(**init_model.hparams)
    return data_cls

for irregular_rate in irregular_rates:
    
    accs_ = []
    run_sub = [r for r in runs if (r.config["irregular_rate"]==irregular_rate) and (r.config.get("multivariate",None)==multivariate)]

    for run in run_sub:
        fname = [f.name for f in run.files() if "ckpt" in f.name][0]
        run.file(fname).download(replace = True, root = ".")
         
        checkpoint = torch.load(
        fname, map_location=lambda storage, loc: storage)
        checkpoint["hyper_parameters"].pop("callbacks", None)
        checkpoint["hyper_parameters"].pop("logger", None)
        checkpoint["hyper_parameters"].pop("wandb_id_file_path", None)
        checkpoint.pop("callbacks",None)
        checkpoint_ = {"state_dict":checkpoint["state_dict"],"hyper_parameters":checkpoint["hyper_parameters"]}
        checkpoint_["hyper_parameters"].pop("init_model",None)
        
        #input_dim = checkpoint["hyper_parameters"].pop("input_dim",None)
        #if input_dim is None:
        #    init_model = NODExt( **checkpoint["hyper_parameters"],output_dim = checkpoint["hyper_parameters"]["num_dims"])
        #else:
        #    init_model = NODExt(output_dim = input_dim, **checkpoint["hyper_parameters"])
        
        init_model = get_init_model(checkpoint["hyper_parameters"]["init_sweep_id"],init_model_cls,checkpoint["hyper_parameters"]["irregular_rate"],checkpoint["hyper_parameters"]["seed"],checkpoint["hyper_parameters"].get("multivariate", False))
        
        model = _load_state(model_cls, checkpoint_,init_model = init_model).to(device)
               
        #model = model_cls.load_from_checkpoint(fname)
        os.remove(fname)
        
        dataset = data_cls_choice(model.hparams.data_type, init_model = model)
        #dataset.pre_compute_ode = False
        dataset.set_test_only()
        dataset.prepare_data()
    
        trainer = pl.Trainer(gpus=1,logger = None)
        
        outputs = trainer.predict(model,dataset.test_dataloader())

        preds = torch.cat([x["preds"] for x in outputs])
        Y = torch.cat([x["Y"] for x in outputs])
        T = torch.cat([x["T"] for x in outputs])
        labels = torch.cat([x["labels"] for x in outputs])
        
        if len(preds.shape)>1:
            preds = torch.nn.functional.softmax(preds, dim=-1).argmax(-1)
            accuracy = accuracy_score(labels.long().cpu().numpy(), preds.cpu().numpy())
        else:
            accuracy = roc_auc_score(labels.long().cpu().numpy(),preds.cpu().numpy())
        #print(accuracy)
        
        accs_.append(accuracy)
    
    acc_dict[f"{irregular_rate}"]=accs_
print(acc_dict)

AttributeError: 'SimpleTrajDataModule' object has no attribute 'set_test_only'

## NODExt

In [1]:
sweep_id = "pifyehqf"
device = torch.device("cuda")
multivariate = None
api= wandb.Api()
sweep = api.sweep("edebrouwer/orthopoly/"+sweep_id)
#init_sweep_id = "edebrouwer/orthopoly/t77q79lw"
runs = sweep.runs

model_cls = NODExtClassification
init_model_cls = NODExt
irregular_rates = [0.3, 0.4, 0.5]#, 0.9, 1.0]
seed = 421
acc_dict = {}

def data_cls_choice(data_type, init_model):
    if data_type == "SimpleTraj":
        data_cls = SimpleTrajDataModule(**init_model.hparams)
    elif data_type == "pMNIST":
        data_cls = pMNISTDataModule(**init_model.hparams)
    elif data_type == "Character":
        data_cls = CharacterTrajDataModule(**init_model.hparams)
    elif data_type == "MIMIC":
        data_cls = MIMICDataModule(**init_model.hparams)
    elif data_type == "Lorenz":
        data_cls = LorenzDataModule(**init_model.hparams)
    return data_cls

for irregular_rate in irregular_rates:
    
    accs_ = []
    run_sub = [r for r in runs if (r.config["irregular_rate"]==irregular_rate) and (r.config.get("multivariate",None)==multivariate)]

    for run in run_sub:
        fname = [f.name for f in run.files() if "ckpt" in f.name][0]
        run.file(fname).download(replace = True, root = ".")
         
        checkpoint = torch.load(
        fname, map_location=lambda storage, loc: storage)
        checkpoint["hyper_parameters"].pop("callbacks", None)
        checkpoint["hyper_parameters"].pop("logger", None)
        checkpoint["hyper_parameters"].pop("wandb_id_file_path", None)
        checkpoint.pop("callbacks",None)
        checkpoint_ = {"state_dict":checkpoint["state_dict"],"hyper_parameters":checkpoint["hyper_parameters"]}
        checkpoint_["hyper_parameters"].pop("init_model",None)
        
        #input_dim = checkpoint["hyper_parameters"].pop("input_dim",None)
        #if input_dim is None:
        #    init_model = NODExt( **checkpoint["hyper_parameters"],output_dim = checkpoint["hyper_parameters"]["num_dims"])
        #else:
        #    init_model = NODExt(output_dim = input_dim, **checkpoint["hyper_parameters"])
        
        init_model = get_init_model(checkpoint["hyper_parameters"]["init_sweep_id"],init_model_cls,checkpoint["hyper_parameters"]["irregular_rate"],checkpoint["hyper_parameters"]["seed"],checkpoint["hyper_parameters"].get("multivariate", False))
        
        model = _load_state(model_cls, checkpoint_,init_model = init_model).to(device)
               
        #model = model_cls.load_from_checkpoint(fname)
        os.remove(fname)
        
        dataset = data_cls_choice(model.hparams.data_type, init_model = model)
        #dataset.pre_compute_ode = False
        dataset.set_test_only()
        dataset.prepare_data()
    
        trainer = pl.Trainer(gpus=1,logger = None)
        
        outputs = trainer.predict(model,dataset.test_dataloader())

        preds = torch.cat([x["preds"] for x in outputs])
        Y = torch.cat([x["Y"] for x in outputs])
        T = torch.cat([x["T"] for x in outputs])
        labels = torch.cat([x["labels"] for x in outputs])
        
        if len(preds.shape)>1:
            preds = torch.nn.functional.softmax(preds, dim=-1).argmax(-1)
            accuracy = accuracy_score(labels.long().cpu().numpy(), preds.cpu().numpy())
        else:
            accuracy = roc_auc_score(labels.long().cpu().numpy(),preds.cpu().numpy())
        #print(accuracy)
        
        accs_.append(accuracy)
    
    acc_dict[f"{irregular_rate}"]=accs_
print(acc_dict)

NameError: name 'torch' is not defined

## RNN

In [15]:
sweep_id = "1vc3k85v"
device = torch.device("cuda")
multivariate = True
api= wandb.Api()
sweep = api.sweep("edebrouwer/orthopoly/"+sweep_id)
#init_sweep_id = "edebrouwer/orthopoly/t77q79lw"
runs = sweep.runs

model_cls = RNNClassification
init_model_cls = RNN
#irregular_rates = [0.6, 0.7, 0.8, 0.9, 1.0]
irregular_rates = [0.3, 0.4, 0.5, 0.6]
seed = 421
acc_dict = {}

def data_cls_choice(data_type, init_model):
    if data_type == "SimpleTraj":
        data_cls = SimpleTrajDataModule(**init_model.hparams)
    elif data_type == "pMNIST":
        data_cls = pMNISTDataModule(**init_model.hparams)
    elif data_type == "Character":
        data_cls = CharacterTrajDataModule(**init_model.hparams)
    return data_cls

for irregular_rate in irregular_rates:
    
    accs_ = []
    run_sub = [r for r in runs if (r.config["irregular_rate"]==irregular_rate) and (r.config.get("multivariate",None)==multivariate)]

    for run in run_sub:
        fname = [f.name for f in run.files() if "ckpt" in f.name][0]
        run.file(fname).download(replace = True, root = ".")
         
        checkpoint = torch.load(
        fname, map_location=lambda storage, loc: storage)
        checkpoint["hyper_parameters"].pop("callbacks", None)
        checkpoint["hyper_parameters"].pop("logger", None)
        checkpoint["hyper_parameters"].pop("wandb_id_file_path", None)
        checkpoint.pop("callbacks",None)
        checkpoint_ = {"state_dict":checkpoint["state_dict"],"hyper_parameters":checkpoint["hyper_parameters"]}
        checkpoint_["hyper_parameters"].pop("init_model",None)
        
        #input_dim = checkpoint["hyper_parameters"].pop("input_dim",None)
        #if input_dim is None:
        #    init_model = RNN( **checkpoint["hyper_parameters"], output_dim = checkpoint["hyper_parameters"]["num_dims"])
        #else:
        #    init_model = RNN(output_dim = input_dim, **checkpoint["hyper_parameters"])

        init_model = get_init_model(checkpoint["hyper_parameters"]["init_sweep_id"],init_model_cls,checkpoint["hyper_parameters"]["irregular_rate"],checkpoint["hyper_parameters"]["seed"],checkpoint["hyper_parameters"].get("multivariate", False))
        
        model = _load_state(model_cls, checkpoint_,init_model = init_model).to(device)
               
        #model = model_cls.load_from_checkpoint(fname)
        os.remove(fname)
        
        dataset = data_cls_choice(model.hparams.data_type, init_model = model)
        #dataset.pre_compute_ode = False
        dataset.prepare_data()
    
        trainer = pl.Trainer(gpus=1,logger = None)
        
        outputs = trainer.predict(model,dataset.test_dataloader())

        preds = torch.cat([x["preds"] for x in outputs])
        Y = torch.cat([x["Y"] for x in outputs])
        T = torch.cat([x["T"] for x in outputs])
        labels = torch.cat([x["labels"] for x in outputs])
        
        if len(preds.shape)>1:
            preds = torch.nn.functional.softmax(preds, dim=-1).argmax(-1)
            accuracy = accuracy_score(labels.long().cpu().numpy(), preds.cpu().numpy())
        else:
            accuracy = roc_auc_score(labels.long().cpu().numpy(),preds.cpu().numpy())
        #print(accuracy)
        
        accs_.append(accuracy)
    
    acc_dict[f"{irregular_rate}"]=accs_
print(acc_dict)

  rank_zero_warn(
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f11c0f9f9d0>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_workers
    if not self._shutdown:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_shutdown'
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f11c0f9f9d0>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1

Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:00<00:00, 37.79it/s]


Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:00<00:00, 38.12it/s]
  rank_zero_deprecation(
  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 23/23 [00:00<00:00, 111.57it/s]


  rank_zero_warn(
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f11c0f9f9d0>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_workers
    if not self._shutdown:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_shutdown'
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f11c0f9f9d0>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1

Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:00<00:00, 37.70it/s]


Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:00<00:00, 37.70it/s]
  rank_zero_deprecation(
  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 23/23 [00:00<00:00, 161.18it/s]


  rank_zero_warn(
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f11c0f9f9d0>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_workers
    if not self._shutdown:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_shutdown'
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f11c0f9f9d0>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1

Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:00<00:00, 56.22it/s]


Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:00<00:00, 56.54it/s]
  rank_zero_deprecation(
  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 23/23 [00:00<00:00, 143.26it/s]


  rank_zero_warn(
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f11c0f9f9d0>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1412, in _shutdown_workers
    if not self._shutdown:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_shutdown'
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f11c0f9f9d0>
Traceback (most recent call last):
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1481, in __del__
    self._shutdown_workers()
  File "/voyager/projects/edebrouwer/.conda/envs/orthopoly/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1

Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:00<00:00, 52.67it/s]


Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [00:00<00:00, 52.93it/s]
  rank_zero_deprecation(
  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Predicting DataLoader 0: 100%|██████████| 23/23 [00:00<00:00, 155.72it/s]
{'0.3': [0.883484954513646], '0.4': [0.8540937718684395], '0.5': [0.8992302309307207], '0.6': [0.9062281315605318]}


# Forecasting

In [1]:
sweep_id = "c6yoaram"
multivariate = None
api= wandb.Api()
sweep = api.sweep("edebrouwer/orthopoly/"+sweep_id)
#init_sweep_id = "edebrouwer/orthopoly/t77q79lw"
runs = sweep.runs

model_cls = NODExt
irregular_rates = [0.6, 0.7, 0.8, 0.9, 1.0]
seed = 421
acc_dict = {}

def data_cls_choice(data_type, hparams):
    if data_type == "SimpleTraj":
        data_cls = SimpleTrajDataModule(**hparams, forecast_mode = True)
    elif data_type == "pMNIST":
        data_cls = pMNISTDataModule(**hparams, forecast_mode = True)
    elif data_type == "Character":
        data_cls = CharacterTrajDataModule(**hparams, forecast_mode = True)
    return data_cls

for irregular_rate in irregular_rates:
    
    accs_ = []
    run_sub = [r for r in runs if (r.config["irregular_rate"]==irregular_rate) and (r.config.get("multivariate",None)==multivariate)]

    for run in run_sub:
        fname = [f.name for f in run.files() if "ckpt" in f.name][0]
        run.file(fname).download(replace = True, root = ".")
         
        checkpoint = torch.load(
        fname, map_location=lambda storage, loc: storage)
        checkpoint["hyper_parameters"].pop("callbacks", None)
        checkpoint["hyper_parameters"].pop("logger", None)
        checkpoint["hyper_parameters"].pop("wandb_id_file_path", None)
        checkpoint.pop("callbacks",None)
        checkpoint_ = {"state_dict":checkpoint["state_dict"],"hyper_parameters":checkpoint["hyper_parameters"]}
        checkpoint_["hyper_parameters"].pop("init_model",None)
        
        input_dim = checkpoint["hyper_parameters"].pop("input_dim",None)
                
        dataset = data_cls_choice(checkpoint["hyper_parameters"]["data_type"],checkpoint["hyper_parameters"])

        if input_dim is None:
            model = _load_state(model_cls, checkpoint_, output_dim = dataset.num_dims)

            #init_model = CNODExt( **checkpoint["hyper_parameters"],output_dim = checkpoint["hyper_parameters"]["num_dims"])
        else:
            model = _load_state(model_cls, checkpoint_, output_dim = input_dim)
            #init_model = CNODExt(output_dim = input_dim, **checkpoint["hyper_parameters"])
        
               
        #model = model_cls.load_from_checkpoint(fname)
        os.remove(fname)
        
        #dataset.pre_compute_ode = False
        dataset.prepare_data()
    
        trainer = pl.Trainer(gpus=1,logger = None)
        
        outputs = trainer.predict(model,dataset.test_dataloader())

        preds = torch.cat([x["preds"] for x in outputs])
        Y = torch.cat([x["Y"] for x in outputs])
        T = torch.cat([x["T"] for x in outputs])
        labels = torch.cat([x["labels"] for x in outputs])
        
        if len(preds.shape)>1:
            preds = torch.nn.functional.softmax(preds, dim=-1).argmax(-1)
            accuracy = accuracy_score(labels.long().cpu().numpy(), preds.cpu().numpy())
        else:
            accuracy = roc_auc_score(labels.long().cpu().numpy(),preds.cpu().numpy())
        #print(accuracy)
        
        accs_.append(accuracy)
    
    acc_dict[f"{irregular_rate}"]=accs_
print(acc_dict)

NameError: name 'wandb' is not defined