In [10]:
%load_ext autoreload
%autoreload 2
import torch
import eq
import wandb
from tqdm.notebook import trange
import numpy as np
import warnings
warnings.filterwarnings('ignore')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:
device = "cuda"
catalog = eq.catalogs.ANSS_MultiCatalog(mag_completeness=4.5)

Loading existing catalog from /home/zekai/repos/recast/data/ANSS_MultiCatalog.


In [12]:
def train(config=None):
    with wandb.init(config=config):
        config = wandb.config
        try:
            dl_train = catalog.train.get_dataloader(batch_size=config.batch_size)
            dl_val = catalog.val.get_dataloader(batch_size=1)
            dl_test = catalog.test.get_dataloader(batch_size=1)

            model = eq.models.RecurrentTPP(context_size=config.context_size,
                                           num_components=config.num_components,
                                           rnn_type=config.rnn_type,
                                           dropout_proba=config.dropout_proba,
                                           learning_rate=config.lr)
            model = model.to(device)

            epochs = 200
            avg_train_loss_list = []
            avg_val_loss_list = []

            optimizer = torch.optim.AdamW(model.parameters(), 
                                          lr=config.lr, 
                                          betas=config.betas, 
                                          weight_decay=config.weight_decay)
            
            best_model_path = "temp_best_model"
            best_val_loss = float('inf')

            for epoch in trange(epochs):
                running_training_loss = []
                model.train()
                for i, data in enumerate(dl_train):
                    data = data.to(device)
                    optimizer.zero_grad()
                    nll = model.nll_loss(data).mean()
                    nll.backward()
                    optimizer.step()
                    running_training_loss.append(nll.item())
                
                model.eval()
                with torch.no_grad():
                    running_val_loss = []
                    for i, data in enumerate(dl_val):
                        data = data.to(device)
                        nll = model.nll_loss(data).mean()
                        running_val_loss.append(nll.item())

                avg_val_loss = np.mean(running_val_loss)

                avg_train_loss_list.append(np.mean(running_training_loss))
                avg_val_loss_list.append(avg_val_loss)

                if avg_val_loss < best_val_loss:
                    best_val_loss = avg_val_loss
                    torch.save(model.state_dict(), best_model_path)

            best_model = eq.models.RecurrentTPP(context_size=config.context_size,
                                                num_components=config.num_components,
                                                rnn_type=config.rnn_type,
                                                dropout_proba=config.dropout_proba,
                                                learning_rate=config.lr)
            best_model.load_state_dict(torch.load(best_model_path))
            best_model = best_model.to(device)
            best_model.eval()
            with torch.no_grad():
                running_test_loss = []
                for i, data in enumerate(dl_test):
                    data = data.to(device)
                    nll = best_model.nll_loss(data).mean()
                    running_test_loss.append(nll.item())
            avg_test_loss = np.mean(running_test_loss)

        except Exception as e:
            print(e)
            avg_test_loss = float("nan")

        wandb.log({"avg_test_loss": avg_test_loss})


In [13]:
import os
os.environ["WANDB_NOTEBOOK_NAME"] = "sweep.ipynb"
wandb.login()



True

In [14]:
sweep_config = {'method': "random"}
metric = {"name": "avg_test_loss",
          "goal": "minimize"}
sweep_config["metric"] = metric

parameter_dict = {"context_size": {'values': [8, 16, 32, 64, 128]},
                  "num_components": {"values": [8, 16, 32, 64, 128]},
                  "rnn_type": {"values": ["RNN", "GRU", "LSTM"]},
                  "dropout_proba": {"values": [0, 0.1, 0.2, 0.3, 0.4, 0.5]},
                  "lr": {"values": [1e-3, 2e-3, 5e-3, 1e-2, 2e-2, 5e-2]},
                  "betas": {"value": (0.9, 0.999)},
                  "weight_decay": {"values": [0, 1e-3, 2e-3, 5e-3, 1e-2, 2e-2, 5e-2]},
                  "batch_size": {"values": [8, 16, 32, 64, 128]}}

sweep_config["parameters"] = parameter_dict

In [15]:
sweep_id = wandb.sweep(sweep_config, project="RecurrentTPP on ANSS Nov 5")

Create sweep with ID: 27ca0onv
Sweep URL: https://wandb.ai/zekai-wang/RecurrentTPP%20on%20ANSS%20Nov%205/sweeps/27ca0onv


In [16]:
wandb.agent(sweep_id, function=train, count=100)



<IPython.core.display.HTML object>


wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)
wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)


<IPython.core.display.HTML object>
<IPython.core.display.HTML object>


[34m[1mwandb[0m: Agent Starting Run: tdq51qjk with config:
[34m[1mwandb[0m: 	batch_size: 128
[34m[1mwandb[0m: 	betas: [0.9, 0.999]
[34m[1mwandb[0m: 	context_size: 8
[34m[1mwandb[0m: 	dropout_proba: 0
[34m[1mwandb[0m: 	lr: 0.01
[34m[1mwandb[0m: 	num_components: 128
[34m[1mwandb[0m: 	rnn_type: GRU
[34m[1mwandb[0m: 	weight_decay: 0.05


  0%|          | 0/200 [00:00<?, ?it/s]

Exception in thread Thread-20 (_run_job):
Traceback (most recent call last):
  File "/home/zekai/miniconda3/envs/eq/lib/python3.11/site-packages/wandb/agents/pyagent.py", line 298, in _run_job
    self._function()
  File "/tmp/ipykernel_15643/1366785928.py", line 2, in train
  File "/home/zekai/miniconda3/envs/eq/lib/python3.11/site-packages/wandb/sdk/wandb_run.py", line 3120, in __exit__
    self._finish(exit_code=exit_code)
  File "/home/zekai/miniconda3/envs/eq/lib/python3.11/site-packages/wandb/sdk/wandb_run.py", line 1938, in _finish
    hook.call()
  File "/home/zekai/miniconda3/envs/eq/lib/python3.11/site-packages/wandb/sdk/wandb_init.py", line 464, in _jupyter_teardown
    ipython.display_pub.publish = ipython.display_pub._orig_publish
                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'ZMQDisplayPublisher' object has no attribute '_orig_publish'

During handling of the above exception, another exception occurred:

Traceback (most recent call la

<IPython.core.display.HTML object>
VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max=1.0)))
<IPython.core.display.HTML object>
<IPython.core.display.HTML object>
<IPython.core.display.HTML object>


[34m[1mwandb[0m: Agent Starting Run: exm3zv6a with config:
[34m[1mwandb[0m: 	batch_size: 128
[34m[1mwandb[0m: 	betas: [0.9, 0.999]
[34m[1mwandb[0m: 	context_size: 8
[34m[1mwandb[0m: 	dropout_proba: 0.1
[34m[1mwandb[0m: 	lr: 0.02
[34m[1mwandb[0m: 	num_components: 128
[34m[1mwandb[0m: 	rnn_type: LSTM
[34m[1mwandb[0m: 	weight_decay: 0.02


  0%|          | 0/200 [00:00<?, ?it/s]

Exception in thread Exception in thread ChkStopThr:
NetStatThr:
Traceback (most recent call last):
Exception in thread Traceback (most recent call last):
IntMsgThr  File "/home/zekai/miniconda3/envs/eq/lib/python3.11/threading.py", line 1038, in _bootstrap_inner
:
  File "/home/zekai/miniconda3/envs/eq/lib/python3.11/threading.py", line 1038, in _bootstrap_inner
Traceback (most recent call last):
  File "/home/zekai/miniconda3/envs/eq/lib/python3.11/threading.py", line 1038, in _bootstrap_inner
        self.run()self.run()
    
self.run()  File "/home/zekai/miniconda3/envs/eq/lib/python3.11/threading.py", line 975, in run

  File "/home/zekai/miniconda3/envs/eq/lib/python3.11/threading.py", line 975, in run
  File "/home/zekai/miniconda3/envs/eq/lib/python3.11/threading.py", line 975, in run
    self._target(*self._args, **self._kwargs)
          File "/home/zekai/miniconda3/envs/eq/lib/python3.11/site-packages/wandb/sdk/wandb_run.py", line 267, in check_network_status
self._target(*se

VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
avg_test_loss,▁

0,1
avg_test_loss,0.16265


[34m[1mwandb[0m: Agent Starting Run: hqe7mqof with config:
[34m[1mwandb[0m: 	batch_size: 8
[34m[1mwandb[0m: 	betas: [0.9, 0.999]
[34m[1mwandb[0m: 	context_size: 128
[34m[1mwandb[0m: 	dropout_proba: 0.1
[34m[1mwandb[0m: 	lr: 0.05
[34m[1mwandb[0m: 	num_components: 32
[34m[1mwandb[0m: 	rnn_type: RNN
[34m[1mwandb[0m: 	weight_decay: 0.001


  0%|          | 0/200 [00:00<?, ?it/s]

Expected parameter scale (Tensor of shape (8, 462, 32)) of distribution Weibull(scale: torch.Size([8, 462, 32]), shape: torch.Size([8, 462, 32])) to satisfy the constraint GreaterThan(lower_bound=0.0), but found invalid values:
tensor([[[   nan, 0.7629, 0.7606,  ..., 0.7818, 0.8243, 0.8388],
         [   nan,    nan,    nan,  ...,    nan,    nan,    nan],
         [   nan,    nan,    nan,  ...,    nan,    nan,    nan],
         ...,
         [   nan,    nan,    nan,  ...,    nan,    nan,    nan],
         [   nan,    nan,    nan,  ...,    nan,    nan,    nan],
         [   nan,    nan,    nan,  ...,    nan,    nan,    nan]],

        [[   nan, 0.7629, 0.7606,  ..., 0.7818, 0.8243, 0.8388],
         [   nan,    nan,    nan,  ...,    nan,    nan,    nan],
         [   nan,    nan,    nan,  ...,    nan,    nan,    nan],
         ...,
         [   nan,    nan,    nan,  ...,    nan,    nan,    nan],
         [   nan,    nan,    nan,  ...,    nan,    nan,    nan],
         [   nan,    nan,  

0,1
avg_test_loss,


[34m[1mwandb[0m: Agent Starting Run: 3f23ig8i with config:
[34m[1mwandb[0m: 	batch_size: 8
[34m[1mwandb[0m: 	betas: [0.9, 0.999]
[34m[1mwandb[0m: 	context_size: 32
[34m[1mwandb[0m: 	dropout_proba: 0.4
[34m[1mwandb[0m: 	lr: 0.01
[34m[1mwandb[0m: 	num_components: 64
[34m[1mwandb[0m: 	rnn_type: GRU
[34m[1mwandb[0m: 	weight_decay: 0


  0%|          | 0/200 [00:00<?, ?it/s]

0,1
avg_test_loss,▁

0,1
avg_test_loss,0.18022


[34m[1mwandb[0m: Agent Starting Run: mozsxwej with config:
[34m[1mwandb[0m: 	batch_size: 8
[34m[1mwandb[0m: 	betas: [0.9, 0.999]
[34m[1mwandb[0m: 	context_size: 32
[34m[1mwandb[0m: 	dropout_proba: 0.3
[34m[1mwandb[0m: 	lr: 0.005
[34m[1mwandb[0m: 	num_components: 32
[34m[1mwandb[0m: 	rnn_type: RNN
[34m[1mwandb[0m: 	weight_decay: 0


  0%|          | 0/200 [00:00<?, ?it/s]

VBox(children=(Label(value='0.003 MB of 0.019 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.184864…

0,1
avg_test_loss,▁

0,1
avg_test_loss,0.17367


[34m[1mwandb[0m: Agent Starting Run: u688d0ul with config:
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	betas: [0.9, 0.999]
[34m[1mwandb[0m: 	context_size: 32
[34m[1mwandb[0m: 	dropout_proba: 0
[34m[1mwandb[0m: 	lr: 0.01
[34m[1mwandb[0m: 	num_components: 64
[34m[1mwandb[0m: 	rnn_type: LSTM
[34m[1mwandb[0m: 	weight_decay: 0.02


  0%|          | 0/200 [00:00<?, ?it/s]

VBox(children=(Label(value='0.003 MB of 0.014 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.246308…

0,1
avg_test_loss,▁

0,1
avg_test_loss,0.16217


[34m[1mwandb[0m: Agent Starting Run: j0h7zq2i with config:
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	betas: [0.9, 0.999]
[34m[1mwandb[0m: 	context_size: 64
[34m[1mwandb[0m: 	dropout_proba: 0.4
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	num_components: 16
[34m[1mwandb[0m: 	rnn_type: RNN
[34m[1mwandb[0m: 	weight_decay: 0.001


  0%|          | 0/200 [00:00<?, ?it/s]

0,1
avg_test_loss,▁

0,1
avg_test_loss,0.17515


[34m[1mwandb[0m: Agent Starting Run: 21wsc8e7 with config:
[34m[1mwandb[0m: 	batch_size: 8
[34m[1mwandb[0m: 	betas: [0.9, 0.999]
[34m[1mwandb[0m: 	context_size: 16
[34m[1mwandb[0m: 	dropout_proba: 0.4
[34m[1mwandb[0m: 	lr: 0.05
[34m[1mwandb[0m: 	num_components: 8
[34m[1mwandb[0m: 	rnn_type: GRU
[34m[1mwandb[0m: 	weight_decay: 0.01


  0%|          | 0/200 [00:00<?, ?it/s]

Expected parameter scale (Tensor of shape (8, 582, 8)) of distribution Weibull(scale: torch.Size([8, 582, 8]), shape: torch.Size([8, 582, 8])) to satisfy the constraint GreaterThan(lower_bound=0.0), but found invalid values:
tensor([[[0.8988, 3.6720, 0.7561,  ..., 1.1632, 8.1114, 2.0297],
         [   nan,    nan,    nan,  ...,    nan,    nan,    nan],
         [   nan,    nan,    nan,  ...,    nan,    nan,    nan],
         ...,
         [   nan,    nan,    nan,  ...,    nan,    nan,    nan],
         [   nan,    nan,    nan,  ...,    nan,    nan,    nan],
         [   nan,    nan,    nan,  ...,    nan,    nan,    nan]],

        [[0.8988, 3.6720, 0.7561,  ..., 1.1632, 8.1114, 2.0297],
         [   nan,    nan,    nan,  ...,    nan,    nan,    nan],
         [   nan,    nan,    nan,  ...,    nan,    nan,    nan],
         ...,
         [   nan,    nan,    nan,  ...,    nan,    nan,    nan],
         [   nan,    nan,    nan,  ...,    nan,    nan,    nan],
         [   nan,    nan,    n

0,1
avg_test_loss,


[34m[1mwandb[0m: Agent Starting Run: sw1p9aoi with config:
[34m[1mwandb[0m: 	batch_size: 8
[34m[1mwandb[0m: 	betas: [0.9, 0.999]
[34m[1mwandb[0m: 	context_size: 128
[34m[1mwandb[0m: 	dropout_proba: 0.4
[34m[1mwandb[0m: 	lr: 0.05
[34m[1mwandb[0m: 	num_components: 16
[34m[1mwandb[0m: 	rnn_type: LSTM
[34m[1mwandb[0m: 	weight_decay: 0.01


  0%|          | 0/200 [00:00<?, ?it/s]

Expected parameter scale (Tensor of shape (8, 767, 16)) of distribution Weibull(scale: torch.Size([8, 767, 16]), shape: torch.Size([8, 767, 16])) to satisfy the constraint GreaterThan(lower_bound=0.0), but found invalid values:
tensor([[[0.6533, 0.7147, 0.8922,  ..., 0.6525, 0.9216, 0.9456],
         [   nan,    nan,    nan,  ...,    nan,    nan,    nan],
         [   nan,    nan,    nan,  ...,    nan,    nan,    nan],
         ...,
         [   nan,    nan,    nan,  ...,    nan,    nan,    nan],
         [   nan,    nan,    nan,  ...,    nan,    nan,    nan],
         [   nan,    nan,    nan,  ...,    nan,    nan,    nan]],

        [[0.6533, 0.7147, 0.8922,  ..., 0.6525, 0.9216, 0.9456],
         [   nan,    nan,    nan,  ...,    nan,    nan,    nan],
         [   nan,    nan,    nan,  ...,    nan,    nan,    nan],
         ...,
         [   nan,    nan,    nan,  ...,    nan,    nan,    nan],
         [   nan,    nan,    nan,  ...,    nan,    nan,    nan],
         [   nan,    nan,  

0,1
avg_test_loss,


[34m[1mwandb[0m: Agent Starting Run: e5tlwvam with config:
[34m[1mwandb[0m: 	batch_size: 16
[34m[1mwandb[0m: 	betas: [0.9, 0.999]
[34m[1mwandb[0m: 	context_size: 128
[34m[1mwandb[0m: 	dropout_proba: 0.5
[34m[1mwandb[0m: 	lr: 0.005
[34m[1mwandb[0m: 	num_components: 64
[34m[1mwandb[0m: 	rnn_type: RNN
[34m[1mwandb[0m: 	weight_decay: 0.001


  0%|          | 0/200 [00:00<?, ?it/s]

VBox(children=(Label(value='0.014 MB of 0.019 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.757529…

0,1
avg_test_loss,▁

0,1
avg_test_loss,0.18658


[34m[1mwandb[0m: Agent Starting Run: 3d69ednq with config:
[34m[1mwandb[0m: 	batch_size: 16
[34m[1mwandb[0m: 	betas: [0.9, 0.999]
[34m[1mwandb[0m: 	context_size: 64
[34m[1mwandb[0m: 	dropout_proba: 0.4
[34m[1mwandb[0m: 	lr: 0.01
[34m[1mwandb[0m: 	num_components: 16
[34m[1mwandb[0m: 	rnn_type: LSTM
[34m[1mwandb[0m: 	weight_decay: 0.002


  0%|          | 0/200 [00:00<?, ?it/s]

0,1
avg_test_loss,▁

0,1
avg_test_loss,0.18031


[34m[1mwandb[0m: Agent Starting Run: j9ifcthf with config:
[34m[1mwandb[0m: 	batch_size: 128
[34m[1mwandb[0m: 	betas: [0.9, 0.999]
[34m[1mwandb[0m: 	context_size: 64
[34m[1mwandb[0m: 	dropout_proba: 0.2
[34m[1mwandb[0m: 	lr: 0.002
[34m[1mwandb[0m: 	num_components: 16
[34m[1mwandb[0m: 	rnn_type: GRU
[34m[1mwandb[0m: 	weight_decay: 0.05


  0%|          | 0/200 [00:00<?, ?it/s]

0,1
avg_test_loss,▁

0,1
avg_test_loss,0.16392


[34m[1mwandb[0m: Agent Starting Run: y9wv6zrt with config:
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	betas: [0.9, 0.999]
[34m[1mwandb[0m: 	context_size: 8
[34m[1mwandb[0m: 	dropout_proba: 0.4
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	num_components: 64
[34m[1mwandb[0m: 	rnn_type: RNN
[34m[1mwandb[0m: 	weight_decay: 0.002


  0%|          | 0/200 [00:00<?, ?it/s]

0,1
avg_test_loss,▁

0,1
avg_test_loss,0.22648


[34m[1mwandb[0m: Agent Starting Run: 257wzujr with config:
[34m[1mwandb[0m: 	batch_size: 128
[34m[1mwandb[0m: 	betas: [0.9, 0.999]
[34m[1mwandb[0m: 	context_size: 128
[34m[1mwandb[0m: 	dropout_proba: 0
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	num_components: 32
[34m[1mwandb[0m: 	rnn_type: LSTM
[34m[1mwandb[0m: 	weight_decay: 0.002


  0%|          | 0/200 [00:00<?, ?it/s]

0,1
avg_test_loss,▁

0,1
avg_test_loss,0.17032


[34m[1mwandb[0m: Agent Starting Run: 7i8chsrd with config:
[34m[1mwandb[0m: 	batch_size: 128
[34m[1mwandb[0m: 	betas: [0.9, 0.999]
[34m[1mwandb[0m: 	context_size: 16
[34m[1mwandb[0m: 	dropout_proba: 0
[34m[1mwandb[0m: 	lr: 0.002
[34m[1mwandb[0m: 	num_components: 16
[34m[1mwandb[0m: 	rnn_type: GRU
[34m[1mwandb[0m: 	weight_decay: 0.005


  0%|          | 0/200 [00:00<?, ?it/s]

VBox(children=(Label(value='0.003 MB of 0.019 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.184621…

0,1
avg_test_loss,▁

0,1
avg_test_loss,0.17476


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: q18o6yyc with config:
[34m[1mwandb[0m: 	batch_size: 8
[34m[1mwandb[0m: 	betas: [0.9, 0.999]
[34m[1mwandb[0m: 	context_size: 128
[34m[1mwandb[0m: 	dropout_proba: 0.3
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	num_components: 8
[34m[1mwandb[0m: 	rnn_type: LSTM
[34m[1mwandb[0m: 	weight_decay: 0.02


  0%|          | 0/200 [00:00<?, ?it/s]

0,1
avg_test_loss,▁

0,1
avg_test_loss,0.16572


[34m[1mwandb[0m: Agent Starting Run: e8dapfta with config:
[34m[1mwandb[0m: 	batch_size: 128
[34m[1mwandb[0m: 	betas: [0.9, 0.999]
[34m[1mwandb[0m: 	context_size: 32
[34m[1mwandb[0m: 	dropout_proba: 0.2
[34m[1mwandb[0m: 	lr: 0.002
[34m[1mwandb[0m: 	num_components: 32
[34m[1mwandb[0m: 	rnn_type: RNN
[34m[1mwandb[0m: 	weight_decay: 0.001


  0%|          | 0/200 [00:00<?, ?it/s]

VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
avg_test_loss,▁

0,1
avg_test_loss,0.17972


[34m[1mwandb[0m: Agent Starting Run: e2f7bwjk with config:
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	betas: [0.9, 0.999]
[34m[1mwandb[0m: 	context_size: 8
[34m[1mwandb[0m: 	dropout_proba: 0.3
[34m[1mwandb[0m: 	lr: 0.001
[34m[1mwandb[0m: 	num_components: 32
[34m[1mwandb[0m: 	rnn_type: GRU
[34m[1mwandb[0m: 	weight_decay: 0.01


  0%|          | 0/200 [00:00<?, ?it/s]