In [2]:
import wandb
import os
import pytorch_lightning as pl
import torch
from legendre.models.spline_cnode import SplineCNODEClass
from legendre.models.cnode_ext import CNODExtClassification, CNODExt
from legendre.models.hippo import HIPPO
from legendre.data_utils.simple_path_utils import SimpleTrajDataModule
from legendre.data_utils.pMNIST_utils import pMNISTDataModule
from legendre.data_utils.character_utils import CharacterTrajDataModule
from sklearn.metrics import accuracy_score, roc_auc_score
from pytorch_lightning.core.saving import _load_state

  from .autonotebook import tqdm as notebook_tqdm
  warn(f"Failed to load image Python extension: {e}")


# Classification

## SplineCNODE

In [2]:
# --- SplineCNODE classification sweep: evaluation settings ---
model_cls = SplineCNODEClass
irregular_rates = [0.5]  # irregular-sampling rates to evaluate; used as (stringified) keys of acc_dict
acc_dict = {}            # str(irregular_rate) -> list of per-run scores
multivariate = True      # only runs whose config carries this `multivariate` value are evaluated

sweep_id = "94ldi5sg"
api = wandb.Api()
sweep = api.sweep(f"edebrouwer/orthopoly/{sweep_id}")
runs = sweep.runs

def data_cls_choice(data_type, hparams=None):
    """Build the data module matching `data_type`.

    Args:
        data_type: one of "SimpleTraj", "pMNIST" or "Character".
        hparams: mapping forwarded as keyword arguments to the data module
            constructor. When omitted, falls back to `model.hparams` from the
            notebook's global `model` (the original behaviour of this cell).

    Returns:
        The constructed data module instance.

    Raises:
        ValueError: if `data_type` is not one of the supported names.
    """
    if data_type == "SimpleTraj":
        datamodule_cls = SimpleTrajDataModule
    elif data_type == "pMNIST":
        datamodule_cls = pMNISTDataModule
    elif data_type == "Character":
        datamodule_cls = CharacterTrajDataModule
    else:
        # Previously fell through to an UnboundLocalError on the result variable.
        raise ValueError(f"Unknown data_type: {data_type!r}")
    if hparams is None:
        # Backward-compatible fallback: relies on the global `model` set by the evaluation loop.
        hparams = model.hparams
    return datamodule_cls(**hparams)

# Evaluate every run of the sweep on the test set, grouped by irregular-sampling rate.
# Multi-class (2-D) predictions are scored with accuracy; 1-D predictions with ROC-AUC.
for irregular_rate in irregular_rates:

    accs_ = []
    run_sub = [
        r for r in runs
        if r.config["irregular_rate"] == irregular_rate
        and r.config.get("multivariate", None) == multivariate
    ]

    for run in run_sub:
        # Fail with an explicit message instead of a bare IndexError when a run has no checkpoint.
        fname = next((f.name for f in run.files() if "ckpt" in f.name), None)
        if fname is None:
            raise FileNotFoundError(f"No checkpoint file found for run {run.id}")
        run.file(fname).download(replace=True, root=".")
        try:
            checkpoint = torch.load(fname, map_location=lambda storage, loc: storage)
            hparams = checkpoint["hyper_parameters"]
            model = model_cls.load_from_checkpoint(fname, num_dims=hparams.get("num_dims", 1))
        finally:
            # Remove the downloaded checkpoint even if loading fails.
            os.remove(fname)
        dataset = data_cls_choice(model.hparams.data_type)
        dataset.prepare_data()

        # `gpus=` is deprecated in recent Lightning releases; `accelerator`/`devices` is the replacement.
        trainer = pl.Trainer(accelerator="gpu", devices=1)
        outputs = trainer.predict(model, dataset.test_dataloader())

        preds = torch.cat([x["preds"] for x in outputs])
        Y = torch.cat([x["Y"] for x in outputs])
        T = torch.cat([x["T"] for x in outputs])
        labels = torch.cat([x["labels"] for x in outputs])

        if preds.dim() > 1:
            # Softmax is monotonic, so the argmax of the raw outputs is the same class.
            preds = preds.argmax(-1)
            accuracy = accuracy_score(labels.long().cpu().numpy(), preds.cpu().numpy())
        else:
            # 1-D scores (presumably a binary task): stored as ROC-AUC in the same list.
            accuracy = roc_auc_score(labels.long().cpu().numpy(), preds.cpu().numpy())

        accs_.append(accuracy)

    acc_dict[f"{irregular_rate}"] = accs_

Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [04:23<00:00, 13.17s/it]


Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [04:25<00:00, 13.28s/it]
  rank_zero_deprecation(
  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting DataLoader 0: 100%|██████████| 23/23 [00:00<00:00, 104.50it/s]


In [3]:
# Show the per-rate scores collected by the SplineCNODE evaluation cell
# (accuracy for multi-class outputs, ROC-AUC for 1-D outputs).
acc_dict

{'0.5': [0.9692092372288313]}

## CNODExt

In [13]:
# --- CNODExt classification sweep: evaluation settings ---
sweep_id = "ej2hbh5g"
api = wandb.Api()
sweep = api.sweep(f"edebrouwer/orthopoly/{sweep_id}")
# Earlier init-model sweep, kept for reference: "edebrouwer/orthopoly/t77q79lw"
runs = sweep.runs

model_cls = CNODExtClassification  # classifier whose checkpoints are evaluated
init_model_cls = CNODExt           # class intended for rebuilding the `init_model` backbone
irregular_rates = [0.3, 0.4, 0.5, 0.6]
# NOTE(review): `seed` is not referenced anywhere in this cell — TODO confirm whether evaluation should be seeded.
seed = 421
acc_dict = {}                      # str(irregular_rate) -> list of per-run scores

def data_cls_choice(data_type, init_model):
    """Build the data module matching `data_type`.

    Args:
        data_type: one of "SimpleTraj", "pMNIST" or "Character".
        init_model: any object exposing an `hparams` mapping; its entries are
            forwarded as keyword arguments to the data module constructor.

    Returns:
        The constructed data module instance.

    Raises:
        ValueError: if `data_type` is not one of the supported names.
    """
    if data_type == "SimpleTraj":
        datamodule_cls = SimpleTrajDataModule
    elif data_type == "pMNIST":
        datamodule_cls = pMNISTDataModule
    elif data_type == "Character":
        datamodule_cls = CharacterTrajDataModule
    else:
        # Previously fell through to an UnboundLocalError on the result variable.
        raise ValueError(f"Unknown data_type: {data_type!r}")
    return datamodule_cls(**init_model.hparams)

# Evaluate every CNODExt classification run on the test set, grouped by irregular-sampling rate.
# The classifier is restored with `_load_state`, passing a freshly built `init_model` backbone.
for irregular_rate in irregular_rates:

    accs_ = []
    # NOTE(review): unlike the other evaluation cells, runs are not filtered on
    # `multivariate` here — confirm this sweep contains a single setting.
    run_sub = [r for r in runs if r.config["irregular_rate"] == irregular_rate]

    for run in run_sub:
        # Fail with an explicit message instead of a bare IndexError when a run has no checkpoint.
        fname = next((f.name for f in run.files() if "ckpt" in f.name), None)
        if fname is None:
            raise FileNotFoundError(f"No checkpoint file found for run {run.id}")
        run.file(fname).download(replace=True, root=".")
        try:
            checkpoint = torch.load(fname, map_location=lambda storage, loc: storage)
        finally:
            # Only torch.load reads the file; remove it even if loading fails.
            os.remove(fname)

        # Strip training-time entries that must not be forwarded to the constructors.
        # `hparams` is the same dict object stored in `checkpoint_` below.
        hparams = checkpoint["hyper_parameters"]
        for key in ("callbacks", "logger", "wandb_id_file_path", "init_model"):
            hparams.pop(key, None)
        checkpoint_ = {"state_dict": checkpoint["state_dict"], "hyper_parameters": hparams}

        # The backbone's output dimension is the stored `input_dim` when present, else `num_dims`.
        input_dim = hparams.pop("input_dim", None)
        output_dim = hparams["num_dims"] if input_dim is None else input_dim
        init_model = init_model_cls(**hparams, output_dim=output_dim)

        # NOTE(review): `_load_state` is a private Lightning helper and may change between releases.
        model = _load_state(model_cls, checkpoint_, init_model=init_model)

        dataset = data_cls_choice(model.hparams.data_type, init_model=model)
        dataset.prepare_data()

        # `gpus=` is deprecated in recent Lightning releases; `accelerator`/`devices` is the replacement.
        trainer = pl.Trainer(accelerator="gpu", devices=1, logger=None)

        outputs = trainer.predict(model, dataset.test_dataloader())

        preds = torch.cat([x["preds"] for x in outputs])
        Y = torch.cat([x["Y"] for x in outputs])
        T = torch.cat([x["T"] for x in outputs])
        labels = torch.cat([x["labels"] for x in outputs])

        if preds.dim() > 1:
            # Softmax is monotonic, so the argmax of the raw outputs is the same class.
            preds = preds.argmax(-1)
            accuracy = accuracy_score(labels.long().cpu().numpy(), preds.cpu().numpy())
        else:
            # 1-D scores (presumably a binary task): stored as ROC-AUC in the same list.
            accuracy = roc_auc_score(labels.long().cpu().numpy(), preds.cpu().numpy())

        accs_.append(accuracy)

    acc_dict[f"{irregular_rate}"] = accs_
print(acc_dict)

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fbbe32e1b80>
Traceback (most recent call last):
  File "/ssd003/home/edebrouw/Projects/Legendre/.venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1510, in __del__
    self._shutdown_workers()
  File "/ssd003/home/edebrouw/Projects/Legendre/.venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1441, in _shutdown_workers
    if not self._shutdown:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_shutdown'
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fbbe32e1b80>
Traceback (most recent call last):
  File "/ssd003/home/edebrouw/Projects/Legendre/.venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1510, in __del__
    self._shutdown_workers()
  File "/ssd003/home/edebrouw/Projects/Legendre/.venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1441, in _shutdown_workers
    if not s

Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [04:49<00:00, 14.48s/it]


Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [04:48<00:00, 14.41s/it]
  rank_zero_deprecation(
  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting DataLoader 0: 100%|██████████| 23/23 [00:00<00:00, 109.57it/s]


  rank_zero_warn(


Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [04:47<00:00, 14.39s/it]


Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [04:47<00:00, 14.37s/it]
  rank_zero_deprecation(
  rank_zero_warn(
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fbbe32e1b80>
Traceback (most recent call last):
  File "/ssd003/home/edebrouw/Projects/Legendre/.venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1510, in __del__
    self._shutdown_workers()
  File "/ssd003/home/edebrouw/Projects/Legendre/.venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1441, in _shutdown_workers
    if not self._shutdown:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_shutdown'
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fbbe32e1b80>
Traceback (most recent call last):
  File "/ssd003/home/edebrouw/Projects/Legendre/.venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1510, in __del__
    self._shutdown_workers()
  File "/ssd003/home/edebrouw/Projects/Legendre/.venv/lib/python3.9/

Predicting DataLoader 0: 100%|██████████| 23/23 [00:00<00:00, 111.64it/s]


  rank_zero_warn(
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fbbe32e1b80>
Traceback (most recent call last):
  File "/ssd003/home/edebrouw/Projects/Legendre/.venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1510, in __del__
    self._shutdown_workers()
  File "/ssd003/home/edebrouw/Projects/Legendre/.venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1441, in _shutdown_workers
    if not self._shutdown:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_shutdown'
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fbbe32e1b80>
Traceback (most recent call last):
  File "/ssd003/home/edebrouw/Projects/Legendre/.venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1510, in __del__
    self._shutdown_workers()
  File "/ssd003/home/edebrouw/Projects/Legendre/.venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1441, in _shutdown_wo

Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [05:03<00:00, 15.15s/it]


Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [05:02<00:00, 15.11s/it]
  rank_zero_deprecation(
  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting DataLoader 0: 100%|██████████| 23/23 [00:00<00:00, 109.07it/s]


  rank_zero_warn(


Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [05:23<00:00, 16.16s/it]


Pre-computing ODE Projection embeddings....


100%|██████████| 20/20 [05:17<00:00, 15.88s/it]
  rank_zero_deprecation(
  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting DataLoader 0: 100%|██████████| 23/23 [00:00<00:00, 114.76it/s]
{'0.3': [0.9072778166550035], '0.4': [0.9429671098670399], '0.5': [0.9503149055283415], '0.6': [0.9657102869139258]}


In [14]:
# Show the per-rate scores collected by the CNODExt classification cell
# (accuracy for multi-class outputs, ROC-AUC for 1-D outputs).
acc_dict

{'0.3': [0.9072778166550035],
 '0.4': [0.9429671098670399],
 '0.5': [0.9503149055283415],
 '0.6': [0.9657102869139258]}

## HiPPO

In [12]:
# --- HIPPO classification sweep: evaluation settings ---
model_cls = HIPPO
irregular_rates = [0.3, 0.4, 0.5, 0.6]  # stringified values become keys of acc_dict
# NOTE(review): `seed` is not referenced anywhere in this cell — TODO confirm whether evaluation should be seeded.
seed = 421
acc_dict = {}        # str(irregular_rate) -> list of per-run scores
multivariate = True  # only runs whose config carries this `multivariate` value are evaluated

sweep_id = "iycs44zs"
api = wandb.Api()
sweep = api.sweep(f"edebrouwer/orthopoly/{sweep_id}")
runs = sweep.runs

def data_cls_choice(data_type, hparams=None):
    """Build the data module matching `data_type`.

    Args:
        data_type: one of "SimpleTraj", "pMNIST" or "Character".
        hparams: mapping forwarded as keyword arguments to the data module
            constructor. When omitted, falls back to `model.hparams` from the
            notebook's global `model` (the original behaviour of this cell).

    Returns:
        The constructed data module instance.

    Raises:
        ValueError: if `data_type` is not one of the supported names.
    """
    if data_type == "SimpleTraj":
        datamodule_cls = SimpleTrajDataModule
    elif data_type == "pMNIST":
        datamodule_cls = pMNISTDataModule
    elif data_type == "Character":
        datamodule_cls = CharacterTrajDataModule
    else:
        # Previously fell through to an UnboundLocalError on the result variable.
        raise ValueError(f"Unknown data_type: {data_type!r}")
    if hparams is None:
        # Backward-compatible fallback: relies on the global `model` set by the evaluation loop.
        hparams = model.hparams
    return datamodule_cls(**hparams)

# Evaluate every HIPPO run of the sweep on the test set, grouped by irregular-sampling rate.
# Multi-class (2-D) predictions are scored with accuracy; 1-D predictions with ROC-AUC.
for irregular_rate in irregular_rates:

    accs_ = []
    run_sub = [
        r for r in runs
        if r.config["irregular_rate"] == irregular_rate
        and r.config.get("multivariate", None) == multivariate
    ]

    for run in run_sub:
        # Fail with an explicit message instead of a bare IndexError when a run has no checkpoint.
        fname = next((f.name for f in run.files() if "ckpt" in f.name), None)
        if fname is None:
            raise FileNotFoundError(f"No checkpoint file found for run {run.id}")
        run.file(fname).download(replace=True, root=".")
        try:
            checkpoint = torch.load(fname, map_location=lambda storage, loc: storage)
            hparams = checkpoint["hyper_parameters"]
            model = model_cls.load_from_checkpoint(fname, num_dims=hparams.get("num_dims", 1))
        finally:
            # Remove the downloaded checkpoint even if loading fails.
            os.remove(fname)
        dataset = data_cls_choice(model.hparams.data_type)
        dataset.prepare_data()

        # `gpus=` is deprecated in recent Lightning releases; `accelerator`/`devices` is the replacement.
        trainer = pl.Trainer(accelerator="gpu", devices=1)
        outputs = trainer.predict(model, dataset.test_dataloader())

        preds = torch.cat([x["preds"] for x in outputs])
        Y = torch.cat([x["Y"] for x in outputs])
        T = torch.cat([x["T"] for x in outputs])
        labels = torch.cat([x["labels"] for x in outputs])

        if preds.dim() > 1:
            # Softmax is monotonic, so the argmax of the raw outputs is the same class.
            preds = preds.argmax(-1)
            accuracy = accuracy_score(labels.long().cpu().numpy(), preds.cpu().numpy())
        else:
            # 1-D scores (presumably a binary task): stored as ROC-AUC in the same list.
            accuracy = roc_auc_score(labels.long().cpu().numpy(), preds.cpu().numpy())

        accs_.append(accuracy)

    acc_dict[f"{irregular_rate}"] = accs_
print(acc_dict)

  rank_zero_deprecation(
  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting DataLoader 0: 100%|██████████| 23/23 [00:00<00:00, 28.80it/s]


  rank_zero_deprecation(
  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting DataLoader 0: 100%|██████████| 23/23 [00:00<00:00, 29.72it/s]


  rank_zero_deprecation(
  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting DataLoader 0: 100%|██████████| 23/23 [00:00<00:00, 28.50it/s]


  rank_zero_deprecation(
  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting DataLoader 0: 100%|██████████| 23/23 [00:00<00:00, 29.22it/s]
{'0.3': [0.9153254023792862], '0.4': [0.9450664800559833], '0.5': [0.9552134359692093], '0.6': [0.9636109167249826]}


# Forecasting

In [18]:
# --- CNODExt forecasting sweep: evaluation settings ---
model_cls = CNODExt
irregular_rates = [0.3, 0.4, 0.5, 0.6]  # stringified values become keys of acc_dict
# NOTE(review): `seed` is not referenced anywhere in this cell — TODO confirm whether evaluation should be seeded.
seed = 421
acc_dict = {}        # str(irregular_rate) -> list of per-run scores
multivariate = True  # only runs whose config carries this `multivariate` value are evaluated

sweep_id = "gztx8pq1"
api = wandb.Api()
sweep = api.sweep(f"edebrouwer/orthopoly/{sweep_id}")
# Earlier init-model sweep, kept for reference: "edebrouwer/orthopoly/t77q79lw"
runs = sweep.runs

def data_cls_choice(data_type, hparams):
    """Build the forecasting-mode data module matching `data_type`.

    Args:
        data_type: one of "SimpleTraj", "pMNIST" or "Character".
        hparams: mapping forwarded as keyword arguments to the data module
            constructor; `forecast_mode` is always forced to True.

    Returns:
        The constructed data module instance.

    Raises:
        ValueError: if `data_type` is not one of the supported names.
    """
    if data_type == "SimpleTraj":
        datamodule_cls = SimpleTrajDataModule
    elif data_type == "pMNIST":
        datamodule_cls = pMNISTDataModule
    elif data_type == "Character":
        datamodule_cls = CharacterTrajDataModule
    else:
        # Previously fell through to an UnboundLocalError on the result variable.
        raise ValueError(f"Unknown data_type: {data_type!r}")
    # Merge instead of `Cls(**hparams, forecast_mode=True)`, which raises TypeError
    # when the stored hparams already contain a `forecast_mode` entry.
    return datamodule_cls(**{**hparams, "forecast_mode": True})

# Evaluate every CNODExt forecasting run on the test set, grouped by irregular-sampling rate.
# NOTE(review): the saved run of this cell failed on the first predict batch with
# `ValueError: not enough values to unpack (expected 9, got 5)`. This points to a tuple-size
# mismatch between the forecast-mode data module and what the model unpacks — TODO fix.
# The scoring below reuses the classification metrics of the other cells; confirm they are
# meaningful for forecasting outputs.
for irregular_rate in irregular_rates:

    accs_ = []
    run_sub = [
        r for r in runs
        if r.config["irregular_rate"] == irregular_rate
        and r.config.get("multivariate", None) == multivariate
    ]

    for run in run_sub:
        # Fail with an explicit message instead of a bare IndexError when a run has no checkpoint.
        fname = next((f.name for f in run.files() if "ckpt" in f.name), None)
        if fname is None:
            raise FileNotFoundError(f"No checkpoint file found for run {run.id}")
        run.file(fname).download(replace=True, root=".")
        try:
            checkpoint = torch.load(fname, map_location=lambda storage, loc: storage)
        finally:
            # Only torch.load reads the file; remove it even if loading fails.
            os.remove(fname)

        # Strip training-time entries that must not be forwarded to the constructors.
        # `hparams` is the same dict object stored in `checkpoint_` below.
        hparams = checkpoint["hyper_parameters"]
        for key in ("callbacks", "logger", "wandb_id_file_path", "init_model"):
            hparams.pop(key, None)
        checkpoint_ = {"state_dict": checkpoint["state_dict"], "hyper_parameters": hparams}

        input_dim = hparams.pop("input_dim", None)

        dataset = data_cls_choice(hparams["data_type"], hparams)

        # Output dimension: the stored `input_dim` when present, else the data module's `num_dims`.
        output_dim = dataset.num_dims if input_dim is None else input_dim
        # NOTE(review): `_load_state` is a private Lightning helper and may change between releases.
        model = _load_state(model_cls, checkpoint_, output_dim=output_dim)

        dataset.prepare_data()

        # `gpus=` is deprecated in recent Lightning releases; `accelerator`/`devices` is the replacement.
        trainer = pl.Trainer(accelerator="gpu", devices=1, logger=None)

        outputs = trainer.predict(model, dataset.test_dataloader())

        preds = torch.cat([x["preds"] for x in outputs])
        Y = torch.cat([x["Y"] for x in outputs])
        T = torch.cat([x["T"] for x in outputs])
        labels = torch.cat([x["labels"] for x in outputs])

        if preds.dim() > 1:
            # Softmax is monotonic, so the argmax of the raw outputs is the same class.
            preds = preds.argmax(-1)
            accuracy = accuracy_score(labels.long().cpu().numpy(), preds.cpu().numpy())
        else:
            accuracy = roc_auc_score(labels.long().cpu().numpy(), preds.cpu().numpy())

        accs_.append(accuracy)

    acc_dict[f"{irregular_rate}"] = accs_
print(acc_dict)

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fb64881c550>
Traceback (most recent call last):
  File "/ssd003/home/edebrouw/Projects/Legendre/.venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1510, in __del__
    self._shutdown_workers()
  File "/ssd003/home/edebrouw/Projects/Legendre/.venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1441, in _shutdown_workers
    if not self._shutdown:
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_shutdown'
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fb64881c550>
Traceback (most recent call last):
  File "/ssd003/home/edebrouw/Projects/Legendre/.venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1510, in __del__
    self._shutdown_workers()
  File "/ssd003/home/edebrouw/Projects/Legendre/.venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1441, in _shutdown_workers
    if not s

Predicting DataLoader 0:   0%|          | 0/23 [00:00<?, ?it/s]

ValueError: not enough values to unpack (expected 9, got 5)