In [None]:
import torch
import pickle
import random
import matplotlib.pyplot as plt
import numpy as np

import torchinfo
import importlib as imp

import data_loader.data_loader as data_loader
from trainer.trainer import Trainer
from model.model import TorchModel
from utils import utils
import model.loss as module_loss
import model.metric as module_metric

# Per the torch.set_warn_always docs, False (the default) means some PyTorch
# warnings may be emitted only once per process instead of on every call.
torch.set_warn_always(False)

# Project layout follows this template: https://github.com/victoresque/pytorch-template/tree/master

# TODO: move the `settings` dict defined in the next cell into config files

In [None]:
# Hyper-parameters and run options read by the setup / training cells below
# (loader batch size, device request, optimizer learning rate, loss and metric
# names looked up in model.loss / model.metric, and Trainer options).
settings = dict(
    batch_size=32,
    device="gpu",
    max_epochs=10_000,
    learning_rate=1e-5,
    patience=2,
    min_delta=0.02,
    criterion="ShashNLL",
    metrics=("custom_mae", "iqr_capture", "sign_test"),
)

In [None]:
# Exploratory load of the train / validation splits and their loaders.
# NOTE(review): the seeded setup cell further down rebuilds all four objects,
# so on Restart & Run All the pickles are read twice; this cell is only useful
# for inspecting the data before training.
trainset, valset = (
    data_loader.CustomData("data/train_data.pickle"),
    data_loader.CustomData("data/val_data.pickle"),
)
train_loader, val_loader = (
    torch.utils.data.DataLoader(
        trainset, batch_size=settings["batch_size"], shuffle=True
    ),
    torch.utils.data.DataLoader(
        valset, batch_size=settings["batch_size"], shuffle=False
    ),
)

In [None]:
# Look up the loss class named in settings["criterion"] on model.loss,
# instantiate it with no arguments, and display it.
loss_name = settings["criterion"]
criterion = getattr(module_loss, loss_name)()
criterion

In [None]:
# Seeded experiment setup. Seeds are set first so that any randomness consumed
# afterwards (DataLoader shuffling; presumably model weight init) is repeatable.
SEED = 44
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
torch.backends.cudnn.deterministic = True
# cuDNN autotuning may select non-deterministic kernels; False is the default,
# set explicitly here to document intent.
torch.backends.cudnn.benchmark = False


trainset = data_loader.CustomData("data/train_data.pickle")
valset = data_loader.CustomData("data/val_data.pickle")
train_loader = torch.utils.data.DataLoader(
    trainset,
    batch_size=settings["batch_size"],
    shuffle=True,
    drop_last=False,
)
val_loader = torch.utils.data.DataLoader(
    valset,
    batch_size=settings["batch_size"],
    shuffle=False,
    drop_last=False,
)

model = TorchModel(target=trainset.target)
model.freeze_layers(freeze_id="tau")
device = utils.prepare_device(settings["device"])
# Move the model BEFORE building the optimizer, as the torch.optim docs
# recommend, so the optimizer is constructed over the on-device parameters.
model = model.to(device)

criterion = getattr(module_loss, settings["criterion"])()
optimizer = torch.optim.Adam(model.parameters(), lr=settings["learning_rate"])
metric_funcs = [getattr(module_metric, met) for met in settings["metrics"]]

trainer = Trainer(
    model,
    criterion,
    metric_funcs,
    optimizer,
    max_epochs=settings["max_epochs"],
    device=device,
    data_loader=train_loader,
    validation_data_loader=val_loader,
    settings=settings,
)

# Optional architecture summary (disabled):
# torchinfo.summary(
#     model,
#     [
#         trainset.input[: settings["batch_size"]].shape,
#         trainset.input_unit[: settings["batch_size"]].shape,
#     ],
#     verbose=0,
# )

In [None]:
# Make sure the model is on the training device (Module.to is a no-op if it
# already is), then run the training loop configured in the setup cell.
model = model.to(device)
trainer.fit()

In [None]:
# Training curves: one panel per tracked quantity ("loss" plus each metric in
# settings["metrics"]), with the training and "val_"-prefixed validation
# histories overlaid. The panel count follows the metrics list instead of a
# hard-coded 4, so adding or removing a metric no longer breaks the grid.
history = trainer.log.history
print(history.keys())

curve_names = ("loss", *settings["metrics"])
# squeeze=False keeps `axes` 2-D even when there is only a single panel.
fig, axes = plt.subplots(1, len(curve_names), figsize=(16, 4), squeeze=False)
for ax, m in zip(axes[0], curve_names):
    ax.plot(history["epoch"], history[m], label=m)
    ax.plot(history["epoch"], history["val_" + m], label="val_" + m)
    ax.set(title=m, xlabel="epoch")
    ax.legend()
fig.tight_layout()
plt.show()

In [None]:
# Spot-check the trained model on the first few validation samples.
# eval() switches layers such as dropout / batch-norm (if the model has any) to
# inference behaviour; the mode the Trainer leaves the model in after fit() is
# not visible from this notebook, so set it explicitly.
# NOTE(review): re-running trainer.fit() afterwards assumes the Trainer calls
# model.train() itself — confirm.
n_show = 3
model.eval()
# Named val_input* to avoid shadowing the builtin `input`.
val_input = torch.as_tensor(valset.input[:n_show], dtype=torch.float32, device=device)
val_input_unit = torch.as_tensor(
    valset.input_unit[:n_show], dtype=torch.float32, device=device
)
with torch.no_grad():
    pred = model(val_input, val_input_unit)
pred = pred.cpu().numpy()
pred

In [None]:
from shash.shash_torch import Shash

# Plot the Shash distribution built from one row of `pred`; columns 0-3 are
# passed as mu, sigma, gamma, tau respectively.
sample = 1
x = np.arange(-13, 13, 0.01)
dist = Shash(
    mu=pred[sample, 0],
    sigma=pred[sample, 1],
    gamma=pred[sample, 2],
    tau=pred[sample, 3],
)
p = dist.prob(torch.tensor(x))

fig, ax = plt.subplots()
ax.plot(x, p)
ax.set(
    title=f"Predicted Shash distribution (validation sample {sample})",
    xlabel="x",
    ylabel="p(x)",
)
# plt.show() renders the figure and avoids echoing the Line2D repr.
plt.show()