In [1]:
%load_ext autoreload
%autoreload 2
import pytorch_lightning as pl
import pytorch_lightning.callbacks as pl_callbacks
import torch
import eq
import wandb
from tqdm.notebook import trange
import numpy as np

In [2]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs end to end
# on machines without CUDA (every `.to(device)` below would otherwise raise).
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
# Build the ANSS multi-catalog through the project's `eq` package. The resulting
# object exposes the `train`, `val` and `test` splits whose dataloaders are
# consumed by `train()` below.
# NOTE(review): mag_completeness=4.5 presumably discards events below that
# magnitude -- confirm against the eq.catalogs documentation.
catalog = eq.catalogs.ANSS_MultiCatalog(mag_completeness=4.5)

Loading existing catalog from /home/zekai/repos/recast/data/ANSS_MultiCatalog.


In [4]:
import warnings
# Suppresses *all* Python warnings issued after this cell runs, for the rest of
# the kernel session.
# NOTE(review): a blanket 'ignore' also hides deprecation and numerical warnings
# (e.g. NumPy's "mean of empty slice"); consider restricting it to the specific
# category/module that is noisy, and moving the import to the top import cell.
warnings.filterwarnings('ignore')

In [5]:
def train(config):
    """Train a default ``eq.models.RecurrentTPP`` and return its mean test NLL.

    The model is fitted on ``catalog.train`` for a fixed 200 epochs with AdamW.
    After every epoch the mean validation NLL is computed on ``catalog.val``;
    whenever it improves, the model weights are written to ``temp_best_model``
    in the current working directory. The best weights are finally restored
    and evaluated on ``catalog.test``.

    Relies on the notebook globals ``catalog`` and ``device``.

    Args:
        config: Mapping with the optimizer hyper-parameters
            ``"learning_rate"``, ``"betas"`` and ``"weight_decay"``.
            It is *not* forwarded to the model constructor in this version.

    Returns:
        Mean over test batches of the per-batch mean NLL (a NumPy float).

    Raises:
        KeyError: If a required key is missing from ``config``.
        RuntimeError: If the validation loss never improved on ``inf`` (e.g. it
            was NaN every epoch), so no checkpoint was written in this run.
    """
    dl_train = catalog.train.get_dataloader()
    dl_val = catalog.val.get_dataloader()
    dl_test = catalog.test.get_dataloader()

    model = eq.models.RecurrentTPP()
    model = model.to(device)

    epochs = 200

    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=config["learning_rate"],
                                  betas=config["betas"],
                                  weight_decay=config["weight_decay"])

    best_model_path = "temp_best_model"
    best_val_loss = float('inf')
    checkpoint_saved = False

    def mean_nll(loader):
        """Return the mean of per-batch mean NLLs over ``loader`` (eval mode, no grad)."""
        model.eval()
        with torch.no_grad():
            batch_losses = [model.nll_loss(batch.to(device)).mean().item()
                            for batch in loader]
        return np.mean(batch_losses)

    for _ in trange(epochs):
        model.train()
        for data in dl_train:
            data = data.to(device)
            optimizer.zero_grad()
            nll = model.nll_loss(data).mean()
            nll.backward()
            optimizer.step()

        avg_val_loss = mean_nll(dl_val)

        # Keep only the weights that achieved the lowest validation loss so far.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), best_model_path)
            checkpoint_saved = True

    if not checkpoint_saved:
        # Without this guard a stale file from an earlier run would be loaded silently.
        raise RuntimeError(
            "validation loss never improved; no checkpoint was written for this run")

    # The checkpoint holds a state_dict (a plain mapping of tensors), not a model
    # object, so it has to be loaded back into `model` before evaluating.
    model.load_state_dict(torch.load(best_model_path, map_location=device))
    avg_test_loss = mean_nll(dl_test)

    return avg_test_loss


In [6]:
# One training run with a hand-picked AdamW configuration.
res = train(dict(learning_rate=1e-4, betas=(0.9, 0.999), weight_decay=0.01))

  0%|          | 0/200 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
def train(config):
    """Best-effort training run: return the mean test NLL, or NaN on any failure.

    Builds ``eq.models.RecurrentTPP(**config)``, fits it on ``catalog.train``
    for a fixed 200 epochs with AdamW, checkpoints the weights with the lowest
    mean validation NLL to ``temp_best_model`` in the current working
    directory, then restores those weights and evaluates them on
    ``catalog.test``.

    Relies on the notebook globals ``catalog`` and ``device``.

    Args:
        config: Mapping that is unpacked into the model constructor *and* read
            for the optimizer keys ``"learning_rate"``, ``"betas"`` and
            ``"weight_decay"``.

    Returns:
        Mean over test batches of the per-batch mean NLL, or ``float("nan")``
        if anything raised an ``Exception`` (the traceback is printed so the
        failure is not silent).
    """
    try:
        dl_train = catalog.train.get_dataloader()
        dl_val = catalog.val.get_dataloader()
        dl_test = catalog.test.get_dataloader()

        # NOTE(review): the same dict carries the optimizer keys ("betas",
        # "weight_decay"); if RecurrentTPP's constructor does not accept them
        # this raises TypeError and every run reports NaN -- confirm.
        model = eq.models.RecurrentTPP(**config)
        model = model.to(device)

        epochs = 200

        optimizer = torch.optim.AdamW(model.parameters(),
                                      lr=config["learning_rate"],
                                      betas=config["betas"],
                                      weight_decay=config["weight_decay"])

        best_model_path = "temp_best_model"
        best_val_loss = float('inf')
        checkpoint_saved = False

        def mean_nll(loader):
            """Return the mean of per-batch mean NLLs over ``loader`` (eval mode, no grad)."""
            model.eval()
            with torch.no_grad():
                batch_losses = [model.nll_loss(batch.to(device)).mean().item()
                                for batch in loader]
            return np.mean(batch_losses)

        for _ in trange(epochs):
            model.train()
            for data in dl_train:
                data = data.to(device)
                optimizer.zero_grad()
                nll = model.nll_loss(data).mean()
                nll.backward()
                optimizer.step()

            avg_val_loss = mean_nll(dl_val)

            # Keep only the weights that achieved the lowest validation loss so far.
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                torch.save(model.state_dict(), best_model_path)
                checkpoint_saved = True

        if not checkpoint_saved:
            # Without this guard a stale file from an earlier run would be loaded silently.
            raise RuntimeError(
                "validation loss never improved; no checkpoint was written for this run")

        # The checkpoint holds a state_dict (a plain mapping of tensors), not a
        # model object, so it has to be loaded back into `model` before evaluating.
        model.load_state_dict(torch.load(best_model_path, map_location=device))
        avg_test_loss = mean_nll(dl_test)

    except Exception:
        # Best-effort contract: a failed run yields NaN instead of aborting the
        # caller. Print the traceback (warnings are globally silenced in this
        # notebook) so that systematic failures are visible.
        import traceback
        traceback.print_exc()
        avg_test_loss = float("nan")

    return avg_test_loss
