In [1]:
import numpy as np
import os
import sys
import copy

import torch
# Use every available CPU core for intra-op parallelism. os.cpu_count() returns
# None when the count cannot be determined, and torch.set_num_threads requires an
# int, so fall back to a single thread in that case.
torch.set_num_threads(os.cpu_count() or 1)
from torch import nn

import json

import matplotlib as mpl
import matplotlib.pyplot as plt

from pipnet import data
from pipnet import model
from pipnet import utils

# Run on the GPU when torch can see one, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
# Name of the trained model to evaluate; selects both the input and figure directories.
mod = "final_model_mixed"

in_dir = f"../../data/1D/{mod}/"
fig_dir = f"../../figures/1D/{mod}/"

# Number of samples generated (and plotted) for the evaluation batch.
batch_size = 64

# Checkpoint epochs to evaluate; each needs an "epoch_{N}_network" file in in_dir.
epochs = [1, 50, 100, 150, 200, 250]
# NOTE(review): not read by the evaluation loop as written, which passes
# all_steps=False to the plotting helper -- confirm which is intended.
eval_all_steps = True

# Overrides for the isotropic generator settings, stored in data_pars["iso_pars"] below.
iso_pars = dict(
    td = 512,
    Fs = 12_800,
    nmin = 1,
    nmax = 15,
    freq_range = [2_000., 10_000.],
    gmin = 1,
    gmax = 1,
    spread = 5.,
    lw_range = [[5e1, 2e2], [1e2, 5e2], [1e2, 1e3]],
    lw_probs = [0.7, 0.2, 0.1],
    int_range = [0.5, 1.], # Intensity
    phase = 0.,
    debug = False,
)

# Overrides for the MAS-dependent generator settings, stored in data_pars["mas_pars"] below.
mas_pars = dict(
    nw = 24,
    mas_w_range = [30_000., 100_000.],
    random_mas = False,
    mas_phase_p = 0.5,
    mas_phase_scale = 0.05,
    
    # First-order MAS-dependent parameters
    mas1_lw_range = [[1e7, 5e7], [5e7, 1e8]],
    mas1_lw_probs = [0.8, 0.2],
    mas1_m_range = [[0., 0.], [0., 1e4], [1e4, 5e4]],
    mas1_m_probs = [0.1, 0.1, 0.8],
    mas1_s_range = [[-1e7, 1e7]],
    mas1_s_probs = [1.],

    # Second-order MAS-dependent parameters
    mas2_prob = 1.,
    mas2_lw_range = [[0., 0.], [1e11, 5e11]],
    mas2_lw_probs = [0.5, 0.5],
    mas2_m_range = [[0., 0.], [1e8, 5e8]],
    mas2_m_probs = [0.8, 0.2],
    mas2_s_range = [[0., 0.], [-2e10, 2e10]],
    mas2_s_probs = [0.8, 0.2],
    
    # Other MAS-dependent parameters
    non_mas_p = 0.5,
    non_mas_m_trends = ["constant", "increase", "decrease"],
    non_mas_m_probs = [0.34, 0.33, 0.33],
    non_mas_m_range = [0., 1.],
    
    int_decrease_p = 0.1,
    int_decrease_scale =[0.3, 0.7],
    debug = False,
)

# Validate the model directory *before* reading from it, so a typo in `mod`
# yields a clear error rather than a FileNotFoundError on data_pars.json.
if not os.path.exists(in_dir):
    raise ValueError(f"Unknown model: {mod}")

# Start from the data-generation settings saved alongside the model, then
# override the generator parameters and noise levels for this evaluation.
with open(f"{in_dir}data_pars.json", "r") as F:
    data_pars = json.load(F)

data_pars["iso_pars"] = iso_pars
data_pars["mas_pars"] = mas_pars
data_pars["noise"] = 0.
data_pars["mas_l_noise"] = 0.05
data_pars["mas_s_noise"] = 25.
data_pars["gen_mas_shifts"] = False

# Two PIPLoss configurations that differ only in `exp` and `int_exp` (1.0 vs 2.0).
loss_pars1 = dict(
    trg_fuzz = 0.,
    trg_fuzz_len = 0,
    ndim = 1,
    exp = 1.0,
    offset = 1.0,
    factor = 0.0,
    int_w = 1.0,
    int_exp = 1.0,
    return_components = True,
    device = device,
)

loss_pars2 = dict(
    trg_fuzz = 0.,
    trg_fuzz_len = 0,
    ndim = 1,
    exp = 2.0,
    offset = 1.0,
    factor = 0.0,
    int_w = 1.0,
    int_exp = 2.0,
    return_components = True,
    device = device,
)

loss1 = model.PIPLoss(**loss_pars1)
loss2 = model.PIPLoss(**loss_pars2)

In [3]:
# Fail early with a readable message if the selected model directory is missing.
if not os.path.exists(in_dir):
    raise ValueError(f"Unknown model: {mod}")

# os.makedirs(..., exist_ok=True) also creates any missing parent directories
# (os.mkdir would fail if e.g. ../../figures/1D/ does not exist yet) and is a
# no-op when the directory is already there, so re-running the cell is safe.
os.makedirs(fig_dir, exist_ok=True)

# Destination for the per-sample and per-epoch figures saved by later cells.
fdir = fig_dir + "evaluate_final_epochs/"
os.makedirs(fdir, exist_ok=True)

# Architecture hyper-parameters saved at training time; used to rebuild the
# network before each checkpoint's weights are loaded.
with open(f"{in_dir}model_pars.json", "r") as F:
    model_pars = json.load(F)
# Disable the model's noise setting for evaluation (presumably a training-time
# augmentation -- TODO confirm).
model_pars["noise"] = 0.

In [4]:
# Seed numpy's global RNG before constructing the dataset so the evaluation
# batch is reproducible between runs (assumes data.Dataset draws from
# np.random -- TODO confirm). Statement order matters here.
np.random.seed(1)
dataset = data.Dataset(**data_pars)

# A single batch of `batch_size` inputs X and targets y, reused for every
# checkpoint evaluated below.
X, y = dataset.generate_batch(size=batch_size)

In [5]:
# Save one figure of the input and its target for each sample in the batch.
# Interactive display is suppressed (show=False); the PDFs are written to fdir.
for sample_idx in range(batch_size):
    out_path = f"{fdir}sample_{sample_idx + 1}_input.pdf"
    utils.plot_1d_dataset(
        X[sample_idx],
        y[sample_idx],
        y_scale=0.5,
        offset=0.01,
        xvals=dataset.f,
        show=False,
        save=out_path,
    )

In [6]:
# Evaluate each saved checkpoint on the same batch and save one prediction
# figure per (sample, epoch).
#
# The network is moved to `device`, so its input must be too; otherwise the
# forward pass fails with a device mismatch whenever CUDA is available. The
# batch does not change between epochs, so it is moved once, here.
X_in = X.to(device) if torch.is_tensor(X) else X

for epoch in epochs:
    
    print(f"Epoch {epoch}...")
    
    # Rebuild the architecture, then load this epoch's weights onto `device`.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    net = model.ConvLSTMEnsemble(**model_pars).to(device)
    net.load_state_dict(torch.load(in_dir + f"epoch_{epoch}_network", map_location=torch.device(device)))
    net = net.eval()
    
    with torch.no_grad():
        y_pred, y_std, ys_pred = net(X_in)  # ys_pred is not used below
    
    # Bring the outputs back to host memory so the plotting helper can read them
    # (CUDA tensors cannot be converted to numpy directly); no-op on CPU.
    y_pred = y_pred.cpu()
    y_std = y_std.cpu()
    
    for ishow in range(batch_size):
        # NOTE(review): the config cell defines eval_all_steps = True, but this
        # call hard-codes all_steps=False -- confirm which is intended.
        utils.plot_1d_iso_prediction(
            X[ishow],
            y_pred[ishow],
            y_std[ishow],
            y_trg = y[ishow, 0],
            pred_scale=0.5,
            trg_scale=0.5,
            X_offset = 0.2,
            pred_offset=0.1,
            xvals=dataset.f,
            wr_factor=dataset.norm_wr,
            all_steps=False,
            show=False,
            save=f"{fdir}sample_{ishow+1}_epoch_{epoch}.pdf",
        )

Epoch 1...
Epoch 50...
Epoch 100...
Epoch 150...
Epoch 200...
Epoch 250...
