In [1]:
import numpy as np
import os
import sys

import torch
from torch import nn
from torchsummary import summary

import importlib

import matplotlib.pyplot as plt

sys.path.insert(0, "../src/")
import data
import model
import train

In [2]:
# Settings for the synthetic dataset; unpacked into data.PIPDataset(**data_pars).
data_pars = {
    # --- General ---
    "td": 512,       # number of points
    "Fs": 10,        # sampling frequency
    "debug": False,  # print data generation details

    # --- Individual peaks ---
    "pmin": 1,           # fewest Gaussians in a peak
    "pmax": 10,          # most Gaussians in a peak
    "ds": 0.03,          # spread of chemical shift values within each peak
    "lw": [1e-2, 1e-1],  # Gaussian linewidth range
    "phase": 0.,         # spread of phase

    # --- Isotropic spectrum ---
    "nmin": 1,                 # fewest peaks
    "nmax": 10,                # most peaks
    "shift_range": [1., 9.],   # chemical shift range
    "positive": True,          # force the spectrum to be positive

    # --- MAS-dependent behaviour ---
    "mas_g_range": [1e4, 1e5],       # MAS-dependent Gaussian broadening range
    "mas_l_range": [1e4, 1e5],       # MAS-dependent Lorentzian broadening range
    "mas_s_range": [-1e4, 1e4],      # MAS-dependent shift range
    "mas_phase": 0.1,                # random phase range for MAS spectra
    "peakwise_phase": True,          # phase applied peak-wise (True) or spectrum-wise (False)
    "encode_imag": False,            # encode the imaginary part of the MAS spectra
    "nw": 4,                         # number of MAS rates
    "mas_w_range": [30000, 100000],  # MAS rate range
    "random_mas": False,             # NOTE(review): undocumented; presumably toggles random MAS rates - confirm in data.py
    "encode_w": False,               # encode the MAS rate of the spectra

    # --- Post-processing ---
    "noise": 0.,           # noise level
    "smooth_end_len": 10,  # smooth the ends of the spectra
    "scale_iso": 0.8,      # scale applied to isotropic spectra
    "offset": 0.,          # baseline offset
    "norm_wr": True,       # normalize MAS rate values
    "wr_inv": False,       # encode the inverse of the MAS rate instead of the rate
}

# Settings for the training loop; passed as-is to train.train(...).
train_pars = {
    "batch_size": 4,          # dataset batch size
    "num_workers": 8,         # parallel processes used to generate data
    "checkpoint": 100,        # evaluate after this many training batches
    "n_eval": 100,            # number of batches in each evaluation
    "max_checkpoints": 100,   # stop training after this many checkpoints
    "out_dir": "../data/PIPNet_2021_10_13/",  # output directory
    # Maps a checkpoint index to a factor. The original note was cut off
    # ("Checkpoints where ..."), so what the factor scales is defined in
    # train.py - TODO confirm and document here.
    "change_factor": {50: 100., 90: 10.},
    "device": "cpu",          # torch device the network is moved to
}

# Constructor keyword arguments for the network; unpacked into
# model.ConvLSTM(**model_pars). The meaning of each entry is defined by that
# class, which lives in ../src/model.py.
model_pars = {
    "input_dim": 1,
    "hidden_dim": 64,
    "kernel_size": [1, 3, 5],
    "num_layers": 3,
    "final_kernel_size": 1,
    "batch_input": 4,  # NOTE(review): equals data_pars["nw"] and train_pars["batch_size"]; confirm which one it must track
    "bias": True,
    "final_bias": True,
    "return_all_layers": False,
    "final_act": "sigmoid",
}
    
# Make sure the output directory exists before training writes to it.
# os.makedirs also creates missing parent directories (e.g. "../data"), and
# exist_ok=True makes re-running this cell a no-op without the
# check-then-create race of os.path.exists + os.mkdir. If the path exists but
# is not a directory this raises FileExistsError instead of failing later.
os.makedirs(train_pars["out_dir"], exist_ok=True)

In [3]:
# Data source, configured from data_pars.
dataset = data.PIPDataset(**data_pars)

# Network, built from model_pars and then moved to the configured device.
net = model.ConvLSTM(**model_pars)
net = net.to(train_pars["device"])

# Adam optimizer over all network parameters; every hyper-parameter is spelled
# out so the run configuration is visible in the notebook.
opt = torch.optim.Adam(
    net.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    eps=1e-08,
    weight_decay=0,
    amsgrad=False,
)

# Loss function. Alternatives kept for reference:
#   L1 loss: model.CustomLoss(exp=1., offset=1., factor=0., out_factor=0.)
#   L2 loss: model.CustomLoss(exp=2., offset=1., factor=0., out_factor=0.)
# Active choice is the custom loss (non-zero `factor`).
loss = model.CustomLoss(exp=1., offset=1., factor=1000., out_factor=0.)

# Learning-rate scheduler: multiplies the learning rate by 0.5 once the
# monitored metric has stopped improving for more than `patience` scheduler
# steps. When and with which metric it is stepped is decided inside train.train.
sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
    opt,
    factor=0.5,
    patience=5,
)

# Train the model

In [4]:
# Launch training: hands the dataset, network, optimizer, loss function,
# LR scheduler and the train_pars settings to train.train. Progress (training
# loss, then validation loss at each checkpoint) is printed to the cell output.
# NOTE(review): the return value, if any, is discarded - confirm that results
# are persisted by train.train itself (presumably under train_pars["out_dir"]).
train.train(dataset, net, opt, loss, sch, train_pars)

Starting training...
    Training batch  100: loss =  5.2151e+00, mean loss =  3.4729e+00, lr =  1.0000e-03...
  Checkpoint reached, evaluating the model...
    Validation batch  100: loss =  2.4629e+00, mean loss =  3.3331e+00...
  End of evaluation.
    Training batch  200: loss =  2.1847e+00, mean loss =  3.1343e+00, lr =  1.0000e-03...
  Checkpoint reached, evaluating the model...
    Validation batch  100: loss =  2.4950e+00, mean loss =  3.0372e+00...
  End of evaluation.
    Training batch  300: loss =  4.5865e-01, mean loss =  4.7897e-01, lr =  1.0000e-03...
  Checkpoint reached, evaluating the model...
    Validation batch  100: loss =  2.4724e-01, mean loss =  3.9583e-01...

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x12e679790>
Traceback (most recent call last):
  File "//anaconda3/envs/torch/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1324, in __del__
    self._shutdown_workers()
  File "//anaconda3/envs/torch/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1297, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "//anaconda3/envs/torch/lib/python3.8/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
  File "//anaconda3/envs/torch/lib/python3.8/multiprocessing/popen_fork.py", line 44, in wait
    if not wait([self.sentinel], timeout):
  File "//anaconda3/envs/torch/lib/python3.8/multiprocessing/connection.py", line 931, in wait
    ready = selector.select(timeout)
  File "//anaconda3/envs/torch/lib/python3.8/selectors.py", line 415, in select
    fd_event_list = self._selector.poll(timeout)
KeyboardInterrupt: 



  End of evaluation.
    Training batch  316: loss =  3.1203e-01, mean loss =  3.4638e-01, lr =  1.0000e-03...

KeyboardInterrupt: 