In [1]:
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
from IPython import display
from torch import nn

from pipnet import data
from pipnet import model
from pipnet import train

# Let torch use every available core for CPU work. os.cpu_count() returns
# None when the number of CPUs cannot be determined, and
# torch.set_num_threads() needs an int, so fall back to one thread then.
torch.set_num_threads(os.cpu_count() or 1)

# Seed every RNG this notebook touches so Restart & Run All is reproducible.
# numpy was already seeded here; torch is seeded too because the network
# built later (model.ConvLSTMEnsemble) presumably draws its initial weights
# from torch's RNG — confirm if exact reproducibility matters.
SEED = 1
np.random.seed(SEED)
torch.manual_seed(SEED)

# Define data generation parameters

In [7]:
import pickle
import json
import os

In [9]:
# Convert every pickled parameter file (*.pk) in the trained-model directory
# into a JSON file with the same stem, written next to it.
# Uses the `pickle` module imported in the previous cell: the `pk` alias is
# only bound by a later cell, so referencing it here raised NameError on a
# fresh kernel (Restart & Run All).
# NOTE(review): pickle.load can execute arbitrary code — only run this on
# trusted model files.
# NOTE(review): json.dump raises TypeError for objects it cannot serialise
# (e.g. numpy arrays) — confirm the .pk files hold plain Python containers.
in_dir = "../../trained_models/PIPNet_model/"
for fname in sorted(os.listdir(in_dir)):
    stem, ext = os.path.splitext(fname)
    if ext != ".pk":
        continue

    with open(os.path.join(in_dir, fname), "rb") as fh:
        pars = pickle.load(fh)

    # splitext keeps any other ".pk" inside the name intact, unlike
    # str.replace which rewrote every occurrence.
    with open(os.path.join(in_dir, stem + ".json"), "w") as fh:
        json.dump(pars, fh)

In [2]:
# Generation settings for the isotropic ("iso") spectra; handed to
# data.Dataset2D through data_pars["iso_pars"].
iso_pars = {
    "td": 128,
    "Fs": 3_200,
    "nmin": 1,
    "nmax": 5,
    "freq_range": [500., 2700.],
    "gmin": 1,
    "gmax": 1,
    "spread": 5.,
    # Presumably each linewidth range is chosen with the probability at the
    # same index of lw_probs (the two lists have matching lengths).
    "lw_range": [[5e1, 2e2], [1e2, 5e2], [1e2, 1e3]],
    "lw_probs": [0.7, 0.2, 0.1],
    "int_range": [0.5, 1.],  # Intensity
    "phase": 0.,
    "debug": False,
}

# Generation settings for the MAS spectra; handed to data.Dataset2D through
# data_pars["mas_pars"]. Each *_range list is paired index-by-index with the
# *_probs list of the same prefix (matching lengths).
mas_pars = {
    "nw": 8,
    "mas_w_range": [50_000., 100_000.],
    "random_mas": False,
    "mas_phase_p": 0.5,
    "mas_phase_scale": 0.05,

    # First-order MAS-dependent parameters
    "mas1_lw_range": [[1e7, 5e7], [5e7, 1e8]],
    "mas1_lw_probs": [0.8, 0.2],
    "mas1_m_range": [[0., 0.], [0., 1e4], [1e4, 5e4]],
    "mas1_m_probs": [0.1, 0.1, 0.8],
    "mas1_s_range": [[-1e7, 1e7]],
    "mas1_s_probs": [1.],

    # Second-order MAS-dependent parameters
    "mas2_prob": 1.,
    "mas2_lw_range": [[0., 0.], [1e11, 5e11]],
    "mas2_lw_probs": [0.5, 0.5],
    "mas2_m_range": [[0., 0.], [1e8, 5e8]],
    "mas2_m_probs": [0.8, 0.2],
    "mas2_s_range": [[0., 0.], [-2e10, 2e10]],
    "mas2_s_probs": [0.8, 0.2],

    # Other MAS-dependent parameters
    "non_mas_p": 0.5,
    "non_mas_m_trends": ["constant", "increase", "decrease"],
    "non_mas_m_probs": [0.34, 0.33, 0.33],
    "non_mas_m_range": [0., 1.],

    "int_decrease_p": 0.1,
    "int_decrease_scale": [0.3, 0.7],
    "debug": False,
}

# Top-level data-generation settings: bundles iso_pars and mas_pars with the
# encoding, noise and normalisation options. The same dict is passed as both
# params_x and params_y to data.Dataset2D.
data_pars = {
    "iso_pars": iso_pars,
    "mas_pars": mas_pars,

    "positive_iso": True,
    "encode_imag": False,  # Encode the imaginary part of the MAS spectra
    "encode_wr": True,  # Encode the MAS rate of the spectra

    # Noise parameters
    "noise": 0.,  # Noise level
    "mas_l_noise": 0.05,
    "mas_s_noise": 25.,

    "smooth_end_len": 10,  # Smooth ends of spectra
    "iso_spec_norm": 40.,  # Normalization factor for peaks
    "mas_spec_norm": 8.,  # Normalization factor for MAS spectra
    "wr_norm_factor": 100_000.,
    "wr_inv": False,  # Encode inverse of MAS rate instead of MAS rate
    "gen_mas_shifts": True,
}

# Constructor keyword arguments for model.ConvLSTMEnsemble; also written to
# model_pars.json, so every value must stay JSON-serialisable.
# hidden_dim and kernel_size each carry num_layers entries (presumably one
# per layer).
model_pars = {
    "input_dim": 2,
    "n_models": 1,
    "hidden_dim": [64, 64],
    "kernel_size": [5, 5],
    "num_layers": 2,
    "batch_input": 1,
    "bias": True,
    "output_bias": True,
    "return_all_layers": True,
    "batch_norm": False,
    "ndim": 2,
    "independent": True,
    "output_kernel_size": 5,
    "output_act": "sigmoid",
    "noise": 0.,
    "invert": False,
}

# Constructor keyword arguments for model.PIPLoss. trg_fuzz and factor are
# later overridden per epoch through train_pars["change_loss"].
# NOTE(review): device is hard-coded to "cpu" while train_pars["device"]
# selects "cuda" when available — confirm PIPLoss/train.train reconcile the
# two on GPU machines.
loss_pars = {
    "trg_fuzz": 3.,
    "trg_fuzz_len": 25,
    "ndim": 2,
    "exp": 1.0,
    "offset": 1.0,
    "factor": 100.0,
    "int_w": 0.0,
    "int_exp": 2.0,
    "return_components": False,
    "device": "cpu",
}

# Extra keyword arguments forwarded to train.train.
# change_loss maps an epoch number to loss-parameter overrides: the log
# below shows the epoch-3 override applied after 3 * batches_per_epoch = 30
# training batches.
_cuda_ok = torch.cuda.is_available()
train_pars = {
    "batch_size": 4,
    "num_workers": 4,
    "batches_per_epoch": 10,
    "batches_per_eval": 5,
    "n_epochs": 10,
    "change_loss": {
        3: {"trg_fuzz": 3.0, "factor": 10.},
        6: {"trg_fuzz": 0.0, "factor": 0.},
    },
    "out_dir": "../../data/2D/sanity_check/",
    "device": "cuda" if _cuda_ok else "cpu",
    "monitor_end": "\r",
}
del _cuda_ok

# Output directory for figures. os.makedirs creates any missing parent
# directories (os.mkdir raised FileNotFoundError when "../../figures/2D/" did
# not exist yet), and exist_ok=True makes the cell safe to re-run without a
# racy exists-then-create check.
fig_dir = "../../figures/2D/sanity_check/"
os.makedirs(fig_dir, exist_ok=True)

In [3]:
import pickle as pk
import json

In [5]:
# Save the network constructor arguments as JSON. The path is relative to the
# current working directory, not train_pars["out_dir"].
with open("model_pars.json", "w") as fh:
    json.dump(obj=model_pars, fp=fh)

In [3]:
# Dataset2D receives the same data_pars dict for both params_x and params_y.
dataset = data.Dataset2D(
    params_x=data_pars,
    params_y=data_pars,
)

# Define network, loss and optimizer

In [4]:
net = model.ConvLSTMEnsemble(**model_pars)

# Adam at lr=1e-3; ReduceLROnPlateau halves the learning rate (factor=0.5)
# once `patience` consecutive sch.step() calls show no improvement.
# NOTE(review): with patience=50 and only 10 epochs, the LR can only drop if
# train.train steps the scheduler per batch or per evaluation — confirm.
opt = torch.optim.Adam(params=net.parameters(), lr=1e-3)
sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=opt,
    factor=0.5,
    patience=50,
)

# Loss object configured from loss_pars (its trg_fuzz/factor are later
# changed by train_pars["change_loss"]).
loss = model.PIPLoss(**loss_pars)

In [None]:
# Run training; batch/epoch counts, the loss schedule, output directory and
# device are supplied by train_pars.
train.train(dataset, net, opt, loss, sch, **train_pars)

Starting training...
    Training batch   10: loss =  1.3097e+00, mean loss =  3.8998e+00, lr =  1.0000e-03...
  Checkpoint reached, evaluating the model...
    Validation batch    5: loss =  1.8110e+00, mean loss =  3.2327e+00...
  End of evaluation.
    Training batch   20: loss =  2.3213e+00, mean loss =  3.0231e+00, lr =  1.0000e-03...
  Checkpoint reached, evaluating the model...
    Validation batch    5: loss =  3.1381e+00, mean loss =  1.9597e+00...
  End of evaluation.
    Training batch   30: loss =  2.1449e+00, mean loss =  3.2398e+00, lr =  1.0000e-03...
  Checkpoint reached, evaluating the model...
    Validation batch    5: loss =  1.8491e+00, mean loss =  3.7431e+00...
  End of evaluation.

    Changing loss parameter trg_fuzz to 3.0...

    Changing loss parameter factor to 10.0...
    Training batch   40: loss =  2.5312e-01, mean loss =  4.1699e-01, lr =  1.0000e-03...
  Checkpoint reached, evaluating the model...
    Validation batch    5: loss =  5.0436e-01, mean los