In [1]:
import numpy as np
import os
os.environ["OMP_NUM_THREADS"] = "4"
os.environ["MKL_NUM_THREADS"] = "4"

import sys
import copy

import mkl
mkl.set_num_threads(4)

import torch
mkl.set_num_threads(4)
torch.set_num_threads=4
from torch import nn
from torchsummary import summary

import importlib

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.animation as anm

sys.path.insert(0, "../src/")
import data
import model
import train

import nmrglue as ng
import scipy
import scipy.io

np.random.seed(123)

In [2]:
def clean_split(l, delimiter):
    """
    Split a line with the desired delimiter, ignoring delimiters present in arrays or strings
    
    Inputs: - l     Input line
    
    Output: - ls    List of sub-strings making up the line
    """
    
    # Initialize sub-strings
    ls = []
    clean_l = ""
    
    # Loop over all line characters
    in_dq = False
    in_sq = False
    arr_depth = 0
    for li in l:
        # Identify strings with double quotes
        if li == "\"":
            if not in_dq:
                in_dq = True
            else:
                in_dq = False
        
        # Identify strings with single quotes
        if li == "\'":
            if not in_sq:
                in_sq = True
            else:
                in_sq = False
        
        # Identify arrays
        if li == "[":
            if not in_sq and not in_dq:
                arr_depth += 1
        if li == "]":
            if not in_sq and not in_dq:
                arr_depth -= 1
        
        # If the delimiter is not within quotes or in an array, split the line at that character
        if li == delimiter and not in_dq and not in_sq and arr_depth == 0:
            ls.append(clean_l)
            clean_l = ""
        else:
            clean_l += li
    
    ls.append(clean_l)
        
    return ls

In [3]:
def get_array(l):
    """
    Get the values in an array contained in a line
    
    Input:  - l         Input line
    
    Output: - vals      Array of values
    """
    
    # Identify empty array
    if l.strip() == "[]":
        return []
    
    # Initialize array
    vals = []
    clean_l = ""
    
    # Loop over all line characters
    arr_depth = 0
    for li in l:
    
        # Identify end of array
        if li == "]":
            arr_depth -= 1
            
            # Check that there are not too many closing brackets for the opening ones
            if arr_depth < 0:
                raise ValueError("Missing \"[\" for matching the number of \"]\"")
        
        # If we are within the array, extract the character
        if arr_depth > 0:
            clean_l += li
    
        # Identify start of array
        if li == "[":
            arr_depth += 1
    
    # Check that the array is properly closed at the end
    if arr_depth > 0:
        raise ValueError("Missing \"]\" for matching the number of \"[\"")
    
    # Extract elements in the array
    ls = clean_split(clean_l, ",")
    
    # Get the value of each element in the array
    for li in ls:
        vals.append(get_val(li.strip()))

    return vals

In [4]:
def get_val(val):
    
    # Remove tailing comma
    if val.endswith(","):
        val = val[:-1]
    
    # Float / Int
    if val.isnumeric():
        
        if "." in val:
            return float(val)
        else:
            return int(val)
    
    # Bool
    if val.lower() == "true":
        return True
    if val.lower() == "false":
        return False
    
    # String
    if val.startswith("\""):
        return val.split("\"")[1]
    
    # List
    if val.startswith("["):
        
        return get_array(val)
    
    # Try to return a float anyway
    return float(val)

In [5]:
def plot_checkpoint(X, y_trg, y_pred, y_std, ys=None, input_factor=0.5, show=True, save=None, anim=False, all_steps=False,
                    c_start=[0., 1., 1.], c_stop=[0., -1., 0.], c2_start=[1., 1., 0.], c2_stop=[0., -1., 0.]):
    
    if y_std is None:
        y_std = [[] for _ in range(X.shape[0])]
    
    n_pts = X.shape[-1]
    
    n = X.shape[1] - 1
    
    colors = [[c0 + (i / n) * c1 for c0, c1 in zip(c_start, c_stop)] for i in range(n+1)]
    
    if ys is not None:
        n2 = ys.shape[0] - 1
        colors2 = [[c0 + (i / n2) * c1 for c0, c1 in zip(c2_start, c2_stop)] for i in range(n2+1)]
    
    for i, (Xi, yi_trg, yi_pred, yi_std) in enumerate(zip(X, y_trg, y_pred, y_std)):
        
        for k in range(yi_pred.shape[0]):
            
            if all_steps or k == yi_pred.shape[0] - 1:

                # Print target vs. pred
                fig = plt.figure(figsize=(4,3))
                ax = fig.add_subplot(1,1,1)

                ax.plot(yi_trg[0], linewidth=1.)
                ax.plot(yi_pred[k], linewidth=1.)

                if len(yi_std) > 0:
                    ax.fill_between(range(n_pts), yi_pred[k] - yi_std[k], yi_pred[k] + yi_std[k], facecolor="C1", alpha=0.3)

                ax.legend(["Ground-truth", "Prediction"])

                fig.tight_layout()
                
                if show:
                    plt.show()

                if save is not None:
                    plt.savefig(save + f"sample_{i+1}_pred_step_{k+1}.pdf")

                plt.close()
        
        if anim:
            
            fig = plt.figure(figsize=(4,3))
            ax = fig.add_subplot(1,1,1)
            trg, = ax.plot(range(n_pts), yi_trg[0], lw=1.)
            pred, = ax.plot(range(n_pts), yi_pred[0], lw=1.)
            ax.fill_between(range(n_pts), yi_pred[0] - yi_std[0], yi_pred[0] + yi_std[0], facecolor="C1", alpha=0.3)

            ax.legend(["Ground-truth", f"Prediction {1:02}"])
            
            fig.tight_layout()

            def init():
                ax.set_ylim(-0.1, 1.)
                ax.set_xlim(-1, n_pts)
                return fig,
            
            def update(k):
                ax.collections.clear()
                pred.set_data(range(n_pts), yi_pred[k])
                ax.fill_between(range(n_pts), yi_pred[k] - yi_std[k], yi_pred[k] + yi_std[k], facecolor="C1", alpha=0.3)
                ax.legend(["Ground-truth", f"Prediction {k+1:02}"])
                return fig,
            
            A = anm.FuncAnimation(fig, update, init_func=init, frames=yi_pred.shape[0], interval=250, blit=True)
            
            if save is not None:
                A.save(save + f"sample_{i+1}_pred_anim.gif", dpi=300)
            
            plt.close()
        
        if ys is not None:
            
            for k in range(ys.shape[2]):
                # Print target vs. all preds
                fig = plt.figure(figsize=(4,3))
                ax = fig.add_subplot(1,1,1)

                h1 = ax.plot(yi_trg[0], linewidth=1.)

                for c, yi in zip(colors2, ys[:, i, k, :]):
                    ax.plot(yi, linewidth=1., color=c)

                h2 = ax.plot(yi_pred[k], linewidth=1., color="k")

                ax.legend([h1[0], h2[0]], ["Ground-truth", "Prediction"], loc=0)

                fig.tight_layout()
                
                if show:
                    plt.show()

                if save is not None:
                    plt.savefig(save + f"sample_{i+1}_all_preds_step_{k+1}.pdf")

                plt.close()
            
            for j in range(ys.shape[0]):
                
                for k in range(ys.shape[2]):

                    # Print target vs. all preds
                    fig = plt.figure(figsize=(4,3))
                    ax = fig.add_subplot(1,1,1)

                    h1 = ax.plot(yi_trg[0], linewidth=1.)

                    ax.plot(ys[j, i, k, :], linewidth=1., color=colors2[j])

                    h2 = ax.plot(yi_pred[k], linewidth=1., color="k")

                    ax.legend([h1[0], h2[0]], ["Ground-truth", "Prediction"])

                    fig.tight_layout()
                    
                    if show:
                        plt.show()

                    if save is not None:
                        plt.savefig(save + f"sample_{i+1}_pred_step_{k+1}_model_{j+1}.pdf")

                    plt.close()
        
        # Print input and target
        fig = plt.figure(figsize=(4,3))
        ax = fig.add_subplot(1,1,1)
        
        for x, c in zip(Xi[:, 0, :], colors):
            ax.plot(x / np.max(Xi[:, 0, :]), linewidth=1., color=c)
        ax.plot(yi_trg[0] / np.max(yi_trg[0]) * input_factor, "r", linewidth=1.)
        
        fig.tight_layout()
        
        if show:
            plt.show()
            
        if save is not None:
            plt.savefig(save + f"sample_{i+1}_input.pdf")
            
        plt.close()
    
    return

In [6]:
loss = model.CustomLoss(brd_w=1., int_w=1., int_exp=1., return_components=True)

data_pars = dict(
                 # General parameters
                 td = 256, # Number of points
                 Fs = 12800, # Sampling frequency
                 debug = False, # Print data generation details

                 # Peak parameters
                 pmin = 1, # Minimum number of Gaussians in a peak
                 pmax = 1, # Maximum number of Gaussians in a peak
                 ds = 0.03, # Spread of chemical shift values for each peak
                 lw = [[5e1, 2e2], [1e2, 1e3]], # Linewidth range for Gaussians
                 iso_p = [0.9, 0.1],
                 iso_p_peakwise = True,
                 iso_int = [0.5, 1.], # Intensity
                 phase = 0., # Spread of phase

                 # Isotropic parameters
                 nmin = 1, # Minimum number of peaks
                 nmax = 15, # Maximum number of peaks
                 shift_range = [2000., 10000.], # Chemical shift range
                 positive = True, # Force the spectrum to be positive

                 # MAS-dependent parameters
                 mas_g_range = [[1e10, 1e11], [1e10, 5e11]], # MAS-dependent Gaussian broadening range
                 mas_l_range = [[1e7, 1e8], [1e7, 5e8]], # MAS-dependent Lorentzian broadening range
                 mas_s_range = [[-1e7, 1e7], [-1e7, 1e7]], # MAS-dependent shift range
                 mas_p = [0.5, 0.5],
                 mas_phase = 0.1, # Random phase range for MAS spectra
                 peakwise_phase = True, # Whether the phase should be peak-wise or spectrum-wise
                 encode_imag = False, # Encode the imaginary part of the MAS spectra
                 nw = 24, # Number of MAS rates
                 mas_w_range = [30000, 100000], # MAS rate range
                 random_mas = True,
                 encode_w = True, # Encode the MAS rate of the spectra

                 # Post-processing parameters
                 noise = 0., # Noise level
                 smooth_end_len = 10, # Smooth ends of spectra
                 iso_norm = 256., # Normalization factor for peaks
                 brd_norm = 64., # Normalization factor for MAS spectra
                 offset = 0., # Baseline offset
                 norm_wr = True, # Normalize MAS rate values
                 wr_inv = False # Encode inverse of MAS rate instead of MAS rate
                )

n_samples = 16

dataset = data.PIPDataset(**data_pars)

X = []
y = []

for i in range(n_samples):
    Xi, _, yi = dataset.__getitem__(0)
    X.append(Xi.unsqueeze(0))
    y.append(yi.unsqueeze(0))

X = torch.cat(X, dim=0)
y = torch.cat(y, dim=0)

In [7]:
models = ["Ensemble_PIPNet_2021_12_09_w_001", "Ensemble_PIPNet_2021_12_09_w_01",
          "Ensemble_PIPNet_2021_12_07_w_1", "Ensemble_PIPNet_2021_12_07_w_10",
          "Ensemble_PIPNet_2021_12_10_w_1", "Ensemble_PIPNet_2021_12_10_w_10", "Ensemble_PIPNet_2021_11_30_more_mas"]

name = "24_mas"

noises = [0., 1e-5, 1e-4, 1e-3]

fig_dir = f"../figures/model_comparison/"

for noise in noises:

    # Load networks
    losses = []
    all_losses = []
    for mod in models:

        print(f"Model {mod}")

        in_dir = f"../data/{mod}/"

        if not os.path.exists(in_dir):
            raise ValueError(f"Unknown model: {mod}")

        # Get model architecture
        with open(mod + ".py", "r") as F:
            lines = F.read().split("\n")

        model_pars = {}
        in_pars = False

        # Parse script
        for l in lines:

            # Identify model parameter block start
            if "model_pars = " in l:
                in_pars = True

            # Identify model parameter block end
            if l.strip() == ")":
                in_pars = False

            if in_pars:
                # Get line
                if "(" in l:
                    L = l.split("(")[1].split("#")[0]
                else:
                    L = l.strip().split("#")[0]

                key, val = L.split("=")

                v = get_val(val.strip())

                model_pars[key.strip()] = v

        model_pars["noise"] = noise

        # Load loss to get best model
        all_val_losses = np.load(in_dir + "all_val_losses.npy")
        mean_val_losses = np.mean(all_val_losses, axis=1)
        best_chk = np.argmin(mean_val_losses)

        net = model.ConvLSTMEnsemble(**model_pars)
        net.eval()
        net.load_state_dict(torch.load(in_dir + f"checkpoint_{best_chk+1}_network", map_location=torch.device("cpu")))

        print("  Prediction...")
        y_pred, y_std, ys = net(X)
        y_trg = y.repeat((1, y_pred.shape[-2], 1))

        print("  Evaluation...")
        _, components = loss(y_pred, y_trg)

        losses.append(components)

        these_losses = []
        for y_tmp in ys:
            _, components = loss(y_tmp, y_trg)
            these_losses.append(components)

        all_losses.append(these_losses)

        print("  Done!")

    losses = np.array(losses)
    all_losses = np.array(all_losses)

    if not os.path.exists(fig_dir):
        os.mkdir(fig_dir)
        
    for i, l in enumerate(losses.T):
        fig = plt.figure(figsize=(4,5))
        ax = fig.add_subplot(1,1,1)

        ax.plot(l)
        ax.set_xticks(range(len(models)))
        ax.set_xticklabels(models, rotation=60, ha="right")

        ax.set_ylabel("Loss")
    
        fig.tight_layout()

        plt.savefig(f"{fig_dir}{name}_noise_{noise}_loss_component_{i}.pdf")

        plt.show()
        plt.close()
    
    mean_losses = np.mean(all_losses, axis=1)
    std_losses = np.std(all_losses, axis=1)
    
    for i, (m, s) in enumerate(zip(mean_losses.T, std_losses.T)):
        fig = plt.figure(figsize=(4,5))
        ax = fig.add_subplot(1,1,1)

        ax.errorbar(range(len(m)), m, s, capsize=5)
        ax.set_xticks(range(len(models)))
        ax.set_xticklabels(models, rotation=60, ha="right")

        ax.set_ylabel("Loss")
    
        fig.tight_layout()

        plt.savefig(f"{fig_dir}{name}_noise_{noise}_losses_component_{i}.pdf")

        plt.show()
        plt.close()

Model Ensemble_PIPNet_2021_12_09_w_001
  Prediction...
  Evaluation...
  Done!
Model Ensemble_PIPNet_2021_12_09_w_01
  Prediction...
  Evaluation...
  Done!
Model Ensemble_PIPNet_2021_12_07_w_1
  Prediction...
  Evaluation...
  Done!
Model Ensemble_PIPNet_2021_12_07_w_10
  Prediction...
  Evaluation...
  Done!
Model Ensemble_PIPNet_2021_12_10_w_1
  Prediction...
  Evaluation...
  Done!
Model Ensemble_PIPNet_2021_12_10_w_10
  Prediction...
  Evaluation...
  Done!
Model Ensemble_PIPNet_2021_11_30_more_mas
  Prediction...
  Evaluation...
  Done!


ValueError: Format '0_loss_component_0' is not supported (supported formats: eps, jpeg, jpg, pdf, pgf, png, ps, raw, rgba, svg, svgz, tif, tiff)