In [1]:
import numpy as np
import os
import sys
import copy

import torch
# Limit the number of intra-op CPU threads used by torch.
# (The function must be called: assigning to it would overwrite the API and set nothing.)
torch.set_num_threads(4)
from torch import nn
import scipy as sp

import importlib
import pickle as pk

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.animation as anm

sys.path.insert(0, "../src/")
import data
import model
import train

import nmrglue as ng
import scipy
import scipy.io

np.random.seed(123)

In [2]:
# Name of the model to evaluate: selects the data/figure directories below and the
# "<mod>.py" script from which the model, data and loss parameters are parsed later on
mod = "Ensemble_PIPNet_2022_02_01_5_layers"

in_dir = f"../data/{mod}/"
fig_dir = f"../figures/{mod}/"

# MAS rates used for the "selected rates" evaluation: the spectrum recorded at the rate
# closest to each value is picked (values in Hz, as the rates are plotted after division by 1000 as kHz)
sel_wrs = [30000., 32000., 34000., 36000., 38000.,
           40000., 42000., 44000., 46000., 48000.,
           50000., 52000., 54000., 56000., 58000.,
           60000., 62000., 64000., 66000., 68000.,
           70000., 72000., 74000., 76000., 78000., 80000.]

# Per compound, list of [low, high] chemical shift windows, each passed to extract_linewidth
# to measure one peak position and linewidth
peaks = {"ampicillin": [[0.25, 0.7], [1.3, 1.7], [3.5, 4.1], [4.3, 4.7], [4.8, 5.2], [6.1, 6.6], [9.5, 10.5]],
         "aspala": [[0.5, 1.], [1.9, 2.2], [2.3, 3.], [3.7, 4.1], [4.5, 5.], [7., 7.5], [7.7, 8.2], [12.3, 12.8]],
         "flutamide": [[0.7, 1.5], [1.8, 2.2]],
         "histidine": [[2.3, 2.9], [2.8, 3.5], [4.8, 5.2], [7.2, 7.8], [8.6, 9.2], [12, 12.8], [16.7, 17.4]],
         "thymol": [[0.2, 0.7], [0.8, 1.3], [1.3, 1.8], [3, 3.5], [5, 5.7], [5.7, 6.4], [6.7, 7.2], [9, 9.5]],
         "tyrosine": [[2.3, 2.7], [6.4, 6.8], [9.7, 10.2]],
         "mdma": []}

# Switches for the evaluation modes run in the main loop
# NOTE(review): "rand" holds a number rather than a flag; its use is not visible in this part of the notebook
evals = {"sel": True, "opt": True, "all": True, "high": True, "low": True, "rand": 0}
# Directory containing one sub-directory of TopSpin experiments per compound
exp_dir = f"../data/experimental_spectra/topspin/4096/"
# Each value is used in turn as the maximum intensity to which the network input is rescaled
x_scales = [0.2]
exp_compounds = ["ampicillin", "aspala", "flutamide", "histidine", "thymol", "tyrosine", "mdma"]
# Per compound, [start, stop] point indices used to slice the ppm axis and the spectra
exp_range = {"ampicillin": [1500, 2500],
             "aspala": [1500, 2500],
             "flutamide": [1500, 2500],
             "histidine": [1500, 2500],
             "thymol": [1500, 2500],
             "tyrosine": [1500, 2500],
             "mdma": [1600, 2400],
            }

# Per compound, pairs of chemical shift bounds over which relative integrals are compared (see compare_integrals)
int_regions = {"ampicillin": [[12., 8.5], [8.5, 3.], [3., -2.]],
               "aspala": [[15., 10.], [10., 6.], [6., 3.4], [3.4, 1.5], [1.5, 0.]],
               "flutamide": [[11., 9.], [9., 5.], [5., -2.]],
               "histidine": [[20., 15.], [15., 10.5], [10.5, 6.5], [6.5, 4.2], [4.2, -2.]],
               "thymol": [[12., 8.], [8., 4.5], [4.5, 2.7], [2.7, -2]],
               "tyrosine": [[14., 11.], [11., 9.], [9., 6.1], [6.1, 3.5], [3.5, 0.]],
               "mdma": [[15., 10.], [10., 7.], [7., 1.]],
              }

# Per compound, chemical shift window containing the peak used to align the spectra
# recorded at different MAS rates (see extract_and_shift_exp_topspin)
align_regions = {"ampicillin": [12., 9.],
                 "aspala": [15., 11.],
                 "flutamide": [11., 9.],
                 "histidine": [18., 16.],
                 "thymol": [6.7, 5.7],
                 "tyrosine": [14., 11.],
                 "mdma": [4., 2.],
                }

# Inputs of extract_fit_model: for each compound, exp_parts[i] and exp_res[i] go together;
# the "res" entry is the suffix of the directory ("<compound>_<res>/") and the "part" entry
# appears in the name of the .mat files read from it
iso_dir = "../data/experimental_spectra/iso/"
exp_res = {"ampicillin": ["4k", "4k", "4k", "4k"],
           "aspala": ["4k", "4k", "4k", "4k", "4k"],
           "flutamide": ["2k", "2k", "2k", "2k"],
           "histidine": ["4k", "4k", "4k", "4k", "4k"],
           "thymol": ["4k", "4k", "4k", "4k"],
           "tyrosine": ["4k", "4k", "4k", "4k", "4k"],
           "mdma": []}
exp_parts = {'ampicillin': ['NH3', 'NHAr5', 'Ar6104b', 'Me2'],
             'aspala': ['OH', 'NHNH3', 'CHCH', 'CH2', 'CH3'],
             'flutamide': ['H5b', 'H38', 'H6', 'H101112'],
             'histidine': ['H5', 'H7', 'H618', 'H9', 'H342'],
             'thymol': ['H7', 'H321c', 'H4', 'H556'],
             'tyrosine': ['COOH', 'OH', 'NH3H76', 'H5823', 'H3dia'],
             "mdma": []}

In [3]:
# Abort early when no data directory exists for the selected model
if not os.path.exists(in_dir):
    raise ValueError(f"Unknown model: {mod}")

# Create the figure directory together with any missing parent directory;
# exist_ok avoids failing when the directory is already there
os.makedirs(fig_dir, exist_ok=True)

In [4]:
def clean_split(l, delimiter):
    """
    Split a line with the desired delimiter, ignoring delimiters present in arrays or strings

    Inputs: - l             Input line
            - delimiter     Single character at which the line is split

    Output: - ls            List of sub-strings making up the line
                            (always contains at least one element, sub-strings are not stripped)
    """

    # Initialize sub-strings
    ls = []
    clean_l = ""

    # Loop over all line characters
    in_dq = False
    in_sq = False
    arr_depth = 0
    for li in l:

        if li == "\"":
            # A double quote inside a single-quoted string is an ordinary character
            if not in_sq:
                in_dq = not in_dq

        elif li == "\'":
            # A single quote inside a double-quoted string (e.g. an apostrophe) is an ordinary character
            if not in_dq:
                in_sq = not in_sq

        elif not in_dq and not in_sq:
            # Brackets only count as array delimiters outside of strings
            if li == "[":
                arr_depth += 1
            elif li == "]":
                arr_depth -= 1

        # If the delimiter is not within quotes or in an array, split the line at that character
        if li == delimiter and not in_dq and not in_sq and arr_depth == 0:
            ls.append(clean_l)
            clean_l = ""
        else:
            clean_l += li

    # Whatever follows the last delimiter is the last sub-string
    ls.append(clean_l)

    return ls

In [5]:
def get_array(l):
    """
    Get the values in an array contained in a line

    Input:  - l         Input line, containing an array written as "[a, b, ...]"
                        (nested arrays are converted recursively through get_val)

    Output: - vals      List of values (empty list for "[]", "[ ]", ...)

    Raises: ValueError  if the line contains no array or if the brackets are unbalanced
    """

    # Extract the content of the outermost brackets
    clean_l = ""
    arr_depth = 0
    found = False
    for li in l:

        # Identify end of array
        if li == "]":
            arr_depth -= 1

            # Check that there are not too many closing brackets for the opening ones
            if arr_depth < 0:
                raise ValueError("Missing \"[\" for matching the number of \"]\"")

        # If we are within the array, extract the character
        # (the closing bracket of a nested array is kept, the outermost one is not)
        if arr_depth > 0:
            clean_l += li

        # Identify start of array
        if li == "[":
            arr_depth += 1
            found = True

    # Check that the array is properly closed at the end
    if arr_depth > 0:
        raise ValueError("Missing \"]\" for matching the number of \"[\"")

    if not found:
        raise ValueError(f"No array found in line: {l}")

    # Identify empty array (possibly containing only whitespace)
    if clean_l.strip() == "":
        return []

    # Extract elements in the array
    ls = clean_split(clean_l, ",")

    # A trailing comma ("[1, 2,]") is valid Python and leaves an empty last element: drop it
    if ls[-1].strip() == "":
        ls = ls[:-1]

    # Get the value of each element in the array
    return [get_val(li.strip()) for li in ls]

In [6]:
def get_val(val):
    """
    Convert the textual value of a parameter into the corresponding Python object

    Input:  - val       String containing the value (a single trailing comma is ignored)

    Output: - Converted value: bool, None, str (for quoted text), list (for "[...]"),
              int (for integer literals, including negative ones) or float

    Raises: ValueError  if the value cannot be interpreted
    """

    val = val.strip()

    # Remove trailing comma
    if val.endswith(","):
        val = val[:-1].strip()

    # Bool
    if val.lower() == "true":
        return True
    if val.lower() == "false":
        return False

    # None
    if val == "None":
        return None

    # String (double or single quotes): return the text between the first pair of quotes
    if val.startswith("\"") or val.startswith("\'"):
        return val.split(val[0])[1]

    # List
    if val.startswith("["):
        return get_array(val)

    # Int: int() rejects anything containing a decimal point or an exponent,
    # so only true integer literals (with an optional sign) are returned as int
    try:
        return int(val)
    except ValueError:
        pass

    # Float (raises ValueError if the value is not a number either)
    return float(val)

In [7]:
# Get model architecture
with open(mod + ".py", "r") as F:
    lines = F.read().split("\n")

model_pars = {}
in_pars = False

# Parse script
for l in lines:

    # Identify model parameter block start
    if "model_pars = " in l:
        in_pars = True

    # Identify model parameter block end
    if l.strip() == ")":
        in_pars = False

    if in_pars:
        # Get line: keep what follows the opening parenthesis on the first line of the block,
        # and drop any trailing comment
        if "(" in l:
            L = l.split("(")[1].split("#")[0]
        else:
            L = l.strip().split("#")[0]

        # Skip lines that do not define a parameter (blank or comment-only lines),
        # as done when parsing the data and loss parameters
        if "=" in L:

            # Split at the first "=" only, so that a value containing "=" is kept whole
            key, val = L.split("=", 1)

            v = get_val(val.strip())

            model_pars[key.strip()] = v

# Override the noise parameter with zero for the evaluation
# NOTE(review): presumably disables noise injection in the network - confirm in model.ConvLSTMEnsemble
model_pars["noise"] = 0.

In [8]:
# Read the dataset parameters from the script that defines the model
with open(mod + ".py", "r") as F:
    lines = F.read().split("\n")

data_pars = {}
in_pars = False

for l in lines:

    # The block opens on the line assigning "data_pars" and closes on a line made of a single ")"
    if "data_pars = " in l:
        in_pars = True
    if l.strip() == ")":
        in_pars = False

    # Nothing to extract outside of the parameter block
    if not in_pars:
        continue

    # Text following the opening parenthesis (first line of the block) or the stripped line,
    # without any trailing comment
    L = (l.split("(")[1] if "(" in l else l.strip()).split("#")[0]

    # Only lines of the form "key = value" define a parameter
    if "=" not in L:
        continue

    key, val = L.split("=")
    data_pars[key.strip()] = get_val(val.strip())

# Build the dataset with the parsed parameters
dataset = data.PIPDataset(**data_pars)

In [9]:
# Read the loss parameters from the script that defines the model
with open(mod + ".py", "r") as F:
    lines = F.read().split("\n")

loss_pars = {}
in_pars = False

for l in lines:

    # The block opens on the line assigning "loss_pars" and closes on a line made of a single ")"
    if "loss_pars = " in l:
        in_pars = True
    if l.strip() == ")":
        in_pars = False

    # Nothing to extract outside of the parameter block
    if not in_pars:
        continue

    # Text following the opening parenthesis (first line of the block) or the stripped line,
    # without any trailing comment
    L = (l.split("(")[1] if "(" in l else l.strip()).split("#")[0]

    # Only lines of the form "key = value" define a parameter
    if "=" not in L:
        continue

    key, val = L.split("=")
    loss_pars[key.strip()] = get_val(val.strip())

# List the loss components that have a strictly positive weight
loss_components = []
loss_components_with_tot = ["Total"]

for weight_key, label in (("srp_w", "Sharp component"),
                          ("brd_w", "Broad component"),
                          ("int_w", "Integral component")):
    if weight_key in loss_pars and loss_pars[weight_key] > 0.:
        loss_components.append(label)
        loss_components_with_tot.append(label)

In [10]:
# Load loss and learning rate
all_lrs = np.load(in_dir + "all_lrs.npy")
all_losses = np.load(in_dir + "all_losses.npy")
all_val_losses = np.load(in_dir + "all_val_losses.npy")

# The loss components are optional: fall back to None when the files are missing or
# cannot be read/averaged, without hiding unrelated errors (e.g. a keyboard interrupt)
try:
    all_loss_components = np.load(in_dir + "all_loss_components.npy")
    all_val_loss_components = np.load(in_dir + "all_val_loss_components.npy")
    mean_loss_components = np.mean(all_loss_components, axis=1)
    mean_val_loss_components = np.mean(all_val_loss_components, axis=1)
except (OSError, ValueError):
    all_loss_components = None
    all_val_loss_components = None
    mean_loss_components = None
    mean_val_loss_components = None

# Average the losses over the second axis, leaving one value per entry of the first axis
# NOTE(review): the first axis presumably indexes the checkpoints - confirm against the training script
mean_losses = np.mean(all_losses, axis=1)
mean_val_losses = np.mean(all_val_losses, axis=1)

# Select the entry with the lowest mean validation loss
n_chk = all_losses.shape[0]
best_chk = np.argmin(mean_val_losses)
print(best_chk)

113


# Load model

In [11]:
# Build the network from the parsed model parameters and switch it to evaluation mode
net = model.ConvLSTMEnsemble(**model_pars).eval()

# Evaluate experimental data

In [12]:
def _read_topspin_pars(path, keys):
    """
    Read the selected "##$KEY= value" entries of a TopSpin parameter file

    Inputs: - path      Path to the parameter file
            - keys      Names of the parameters to read (without the "##$" prefix)

    Output: - found     Dictionary mapping each key to its raw (stripped) string value

    Raises: ValueError  if any of the requested parameters is missing from the file
    """

    with open(path, "r") as F:
        lines = F.read().split("\n")

    found = {}
    for l in lines:
        for k in keys:
            # The "=" is part of the prefix so that e.g. "MASR" does not match "MASRLST"
            if l.startswith(f"##${k}="):
                found[k] = l.split("=")[1].strip()

    missing = [k for k in keys if k not in found]
    if len(missing) > 0:
        raise ValueError(f"Missing parameter(s) {missing} in {path}")

    return found


def load_topspin_spectrum(d):
    """
    Load a processed spectrum and its frequency axis from a TopSpin experiment directory

    Input:  - d         Experiment directory (with trailing separator), containing "acqus" and "pdata/1/"

    Outputs: - dr       Content of "pdata/1/1r" (real part of the spectrum) as floats
             - di       Content of "pdata/1/1i" (imaginary part of the spectrum) as floats
             - wr       MAS rate read from the "MASR" entry of "acqus"
             - ppm      Chemical shift axis [ppm]
             - hz       Frequency axis [Hz]

    Raises: ValueError  if a required parameter is missing from "acqus" or "procs"
    """

    pdata_dir = f"{d}pdata/1/"

    # The spectra are stored as raw 32-bit integers
    # NOTE(review): read with the byte order of this machine - confirm against BYTORDP in procs
    with open(pdata_dir + "1r", "rb") as F:
        dr = np.fromfile(F, np.int32).astype(float)

    with open(pdata_dir + "1i", "rb") as F:
        di = np.fromfile(F, np.int32).astype(float)

    # Acquisition parameters
    acq = _read_topspin_pars(f"{d}acqus", ["MASR", "TD", "SW_h"])
    wr = int(float(acq["MASR"]))
    TD = int(acq["TD"])
    SW = float(acq["SW_h"])

    # Processing parameters
    proc = _read_topspin_pars(f"{pdata_dir}procs", ["SI", "OFFSET", "SF"])
    n_pts = int(proc["SI"])
    offset = float(proc["OFFSET"])
    SF = float(proc["SF"])

    AQ = TD / (2 * SW)

    # The axis starts at OFFSET (in ppm, converted to Hz with SF) and decreases by SW / n_pts per point
    hz = offset * SF - np.arange(n_pts) / (2 * AQ * n_pts / TD)

    ppm = hz / SF

    return dr, di, wr, ppm, hz

In [13]:
def extract_exp_topspin(in_dir, compound):
    """
    Load the real spectra of all the experiments of a compound, sorted by MAS rate

    Inputs: - in_dir        Directory containing one sub-directory per compound (with trailing separator)
            - compound      Name of the compound

    Outputs: - ppm          Chemical shift axis of the last experiment loaded
             - sorted_ws    Array of MAS rates, in increasing order
             - sorted_X     Array of real spectra, in the same order as sorted_ws
    """

    root = f"{in_dir}{compound}/"

    rates = []
    spectra = []

    # Experiments are stored in the sub-directories that have a purely numeric name
    for entry in os.listdir(root):
        if not entry.isnumeric():
            continue

        real_part, _, rate, ppm, _ = load_topspin_spectrum(f"{root}{entry}/")
        spectra.append(real_part)
        rates.append(rate)

    # Reorder everything by increasing MAS rate
    order = np.argsort(rates)
    sorted_ws = np.array(rates)[order]
    sorted_X = np.array(spectra)[order]

    return ppm, sorted_ws, sorted_X

In [14]:
def extract_fit_model(in_dir, compound, parts, res):
    """
    Assemble a spectrum (mean and standard deviation) from the .mat files found for each part of a compound

    Inputs: - in_dir        Directory containing the "<compound>_<res>/" directories (with trailing separator)
            - compound      Name of the compound (its first three letters prefix the file names)
            - parts         Names of the parts of the spectrum
            - res           Directory suffix associated with each part

    Outputs: - ys_ppms          Concatenated chemical shifts of all parts
             - ys_part_means    Mean over the files found, at each chemical shift
             - ys_part_stds     Standard deviation over the files found, at each chemical shift

    Three empty lists are returned when no part is given or when a directory is missing.
    """

    # Nothing to assemble
    if len(parts) == 0 or len(res) == 0:
        return [], [], []

    ys_ppms = []
    ys_part_means = []
    ys_part_stds = []

    for part, suffix in zip(parts, res):

        folder = f"{in_dir}{compound}_{suffix}/"
        if not os.path.exists(folder):
            return [], [], []

        # Files are numbered from 1 upwards: read them until the first missing number
        prefix = f"{folder}{compound[:3]}_{part}_guess_r"
        runs = []
        i_run = 1
        while os.path.exists(f"{prefix}{i_run}.mat"):
            mat = scipy.io.loadmat(f"{prefix}{i_run}.mat")

            # The last three rows of "x" are left out
            # NOTE(review): presumably they are not spectral points - confirm against the fitting code
            runs.append(mat["x"][:-3])
            ppm = mat["ppm"][0, mat["range"][0]]
            i_run += 1

        # Only keep the points lying outside of the chemical shift range already covered by previous parts
        if len(ys_ppms) > 0:
            covered = np.concatenate(ys_ppms, axis=0)
            inds = np.where(np.logical_or(ppm < np.min(covered), ppm > np.max(covered)))[0]
        else:
            inds = range(len(ppm))

        ys_ppms.append(ppm[inds])

        # Statistics over the files, which are stacked along the second axis
        runs = np.concatenate(runs, axis=1)
        ys_part_means.append(np.mean(runs, axis=1)[inds])
        ys_part_stds.append(np.std(runs, axis=1)[inds])

    # Merge all parts
    merged = [np.concatenate(chunks, axis=0) for chunks in (ys_ppms, ys_part_means, ys_part_stds)]

    return merged[0], merged[1], merged[2]

In [15]:
def make_input(X, ws, data_pars, x_max=0.25):
    """
    Build the network input from a set of spectra recorded at different MAS rates

    Inputs: - X             Array of spectra, one row per MAS rate
            - ws            Array of MAS rates, one per row of X
            - data_pars     Data parameters; "encode_w" is always read, "norm_wr" and
                            "mas_w_range" are read only when "encode_w" is set
            - x_max         Maximum intensity of the spectra after scaling

    Outputs: - X_torch      Tensor of shape (1, n_rates, n_channels, n_points) for a 2D X, where the first
                            channel is the scaled spectrum and, if "encode_w" is set, the second one
                            is the (optionally normalized) MAS rate repeated along the points
             - ws[inds]     MAS rates sorted in increasing order
    """

    # Sort the spectra by increasing MAS rate and add the batch and channel dimensions
    inds = np.argsort(ws)
    X_torch = torch.Tensor(X[inds])
    X_torch = torch.unsqueeze(X_torch, dim=0)
    X_torch = torch.unsqueeze(X_torch, dim=2)

    # Scale the spectra such that the global maximum is x_max
    X_torch /= torch.max(X_torch)
    X_torch *= x_max

    if data_pars["encode_w"]:
        # Encode the MAS rate as a constant channel with the same length as the spectra
        W = torch.Tensor(ws[inds])
        W = torch.unsqueeze(W, dim=0)
        W = torch.unsqueeze(W, dim=2)
        W = torch.unsqueeze(W, dim=3)
        W = W.repeat(1, 1, 1, X_torch.shape[-1])

        if data_pars["norm_wr"]:
            # Map the MAS rate range onto [0, 1]
            W -= data_pars["mas_w_range"][0]
            W /= data_pars["mas_w_range"][1] - data_pars["mas_w_range"][0]

        # The MAS rate channel only exists when the encoding is requested
        X_torch = torch.cat([X_torch, W], dim=2)

    return X_torch, ws[inds]

In [16]:
def plot_exp(ppm, X, show=True, save=None, x_offset=0., xl=[20., -5.], c0=[0., 1., 1.], dc = [0., -1., 0.]):
    """
    Plot a set of spectra, normalized to their global maximum

    Inputs: - ppm           Chemical shift axis
            - X             Spectra to plot, one per row (array or torch tensor)
            - show          Whether the figure should be displayed
            - save          Path of the file where the figure is saved (not saved if None or empty)
            - x_offset      Vertical offset between two consecutive spectra
            - xl            Limits of the x-axis
            - c0            Color (RGB) of the first spectrum
            - dc            Color change (RGB) between the first and the last spectrum
    """

    # Initialize figure
    fig = plt.figure(figsize=(4,3))
    ax = fig.add_subplot(1,1,1)

    # Work on an independent float copy of the spectra (tensors are converted to arrays first)
    X2 = np.array(X.detach().cpu().numpy() if hasattr(X, "detach") else X, dtype=float)
    X2 /= np.max(X2)

    # Color ramp from c0 to c0 + dc (a single spectrum gets the final color)
    n = X2.shape[0]
    if n == 1:
        colors = [[ci + dci for ci, dci in zip(c0, dc)]]
    else:
        colors = [[ci + (dci * i / (n-1)) for ci, dci in zip(c0, dc)] for i in range(n)]

    # Plot inputs
    for i, (c, x) in enumerate(zip(colors, X2)):
        ax.plot(ppm, x + i * x_offset, color=c, linewidth=1)

    # Update axis
    ax.set_xlim(xl)
    ax.set_yticks([])
    ax.set_xlabel("Chemical shift [ppm]")

    # Cleanup layout
    fig.tight_layout()

    # Save figure (explicitly this one, not whichever figure happens to be current)
    if save:
        fig.savefig(save)

    # Show figure
    if show:
        plt.show()

    # Close figure
    plt.close(fig)

    return

In [17]:
def plot_exp_vs_pred(ppm, X, y_pred, y_std, ppm_trg, y_trg_avg, y_trg_std, show=True, save=None, x_offset=0.,
                     y0_pred=0., y0_trg=0., y_pred_scale=0.5, y_trg_scale=0.5, reverse_trg=False, xl=[20., -5.], c0=[0., 1., 1.], dc = [0., -1., 0.]):
    """
    Plot a set of spectra together with a prediction and, if available, a target spectrum

    Inputs: - ppm           Chemical shift axis of the spectra and of the prediction
            - X             Spectra to plot, one per row (array or torch tensor)
            - y_pred        Predicted spectrum (array or torch tensor)
            - y_std         Uncertainty of the prediction, drawn as a band around it
            - ppm_trg       Chemical shift axis of the target
            - y_trg_avg     Target spectrum (nothing is drawn if empty)
            - y_trg_std     Uncertainty of the target, drawn as a band around it
            - show          Whether the figure should be displayed
            - save          Path of the file where the figure is saved (not saved if None or empty)
            - x_offset      Vertical offset between two consecutive spectra
            - y0_pred       Vertical offset of the prediction
            - y0_trg        Vertical offset of the target
            - y_pred_scale  Maximum of the prediction after scaling
            - y_trg_scale   Maximum of the target after scaling
            - reverse_trg   Whether the target should be drawn upside down
            - xl            Limits of the x-axis
            - c0            Color (RGB) of the first spectrum
            - dc            Color change (RGB) between the first and the last spectrum

    The input arrays are never modified: all rescaling is done on copies.
    """

    def as_float_copy(a):
        # Tensors are converted to arrays first; the copy protects the caller's data from the rescaling below
        if hasattr(a, "detach"):
            a = a.detach().cpu().numpy()
        return np.array(a, dtype=float)

    # Initialize figure
    fig = plt.figure(figsize=(4,3))
    ax = fig.add_subplot(1,1,1)

    # Normalize the spectra to their global maximum
    X2 = as_float_copy(X)
    X2 /= np.max(X2)

    # Color ramp from c0 to c0 + dc (a single spectrum gets the final color)
    n = X2.shape[0]
    if n == 1:
        colors = [[ci + dci for ci, dci in zip(c0, dc)]]
    else:
        colors = [[ci + (dci * i / (n-1)) for ci, dci in zip(c0, dc)] for i in range(n)]

    # Scale the prediction (and its uncertainty) such that its maximum is y_pred_scale
    y_pred2 = as_float_copy(y_pred)
    y_std2 = as_float_copy(y_std)
    factor = np.max(y_pred2) / y_pred_scale
    y_pred2 /= factor
    y_std2 /= factor

    # Plot inputs
    for i, (c, x) in enumerate(zip(colors, X2)):
        ax.plot(ppm, x + i * x_offset, color=c, linewidth=1)

    # Plot predictions
    ax.plot(ppm, y_pred2 + y0_pred, "r", linewidth=1)
    ax.fill_between(ppm, y_pred2 - y_std2 + y0_pred, y_pred2 + y_std2 + y0_pred, color="r", alpha=0.3)

    if len(y_trg_avg) > 0:
        # Scale the target (and its uncertainty) such that its maximum is y_trg_scale
        y_trg_avg2 = as_float_copy(y_trg_avg)
        y_trg_std2 = as_float_copy(y_trg_std)
        factor = np.max(y_trg_avg2) / y_trg_scale
        y_trg_avg2 /= factor
        y_trg_std2 /= factor

        # Plot target, mirrored around its baseline if requested
        center = (-y_trg_avg2 if reverse_trg else y_trg_avg2) + y0_trg
        ax.plot(ppm_trg, center, "k", linewidth=1)
        ax.fill_between(ppm_trg, center - y_trg_std2, center + y_trg_std2, color="k", alpha=0.3)

    # Update axis
    ax.set_xlim(xl)
    ax.set_yticks([])
    ax.set_xlabel("Chemical shift [ppm]")

    # Cleanup layout
    fig.tight_layout()

    # Save figure (explicitly this one, not whichever figure happens to be current)
    if save:
        fig.savefig(save)

    # Show figure
    if show:
        plt.show()

    # Close figure
    plt.close(fig)

    return

In [18]:
def extract_linewidth(x, y, r, nfit=3):
    """
    Extract the full width at half maximum and the position of the highest peak in a range

    Inputs: - x         Axis of the spectrum
            - y         Intensities of the spectrum
            - r         Range [low, high] of x in which the peak is searched (bounds excluded)
            - nfit      Unused, kept for compatibility with existing calls

    Outputs: - lw       Distance between the rising and the falling crossing of the half maximum.
                        The crossings are linearly interpolated between consecutive points of the range;
                        if several crossings of one kind exist, the last one found is used, and if
                        none exists the corresponding end of the range is used instead
             - x0       Position of the maximum in the range

    Raises: ValueError  if no point of x lies within the range
    """

    x = np.asarray(x)
    y = np.asarray(y)

    inds = np.where(np.logical_and(x > r[0], x < r[1]))[0]

    if len(inds) == 0:
        raise ValueError(f"No point found in the range {r}")

    top = np.max(y[inds])
    half = top / 2

    i0 = np.argmax(y[inds])

    xr = None
    xl = None

    for i, j in zip(inds[:-1], inds[1:]):

        # Falling crossing. The comparisons are arranged such that a point lying exactly at
        # half maximum is assigned to exactly one crossing instead of being missed
        if y[i] > half >= y[j]:
            # Linear interpolation using the actual spacing between the two points,
            # which remains valid on non-uniform axes
            xr = x[i] + (x[j] - x[i]) * (half - y[i]) / (y[j] - y[i])

        # Rising crossing
        if y[i] <= half < y[j]:
            xl = x[i] + (x[j] - x[i]) * (half - y[i]) / (y[j] - y[i])

    # Fall back on the ends of the range when the peak does not drop below half maximum
    if xl is None:
        xl = x[inds[0]]
    if xr is None:
        xr = x[inds[-1]]

    return abs(xl - xr), x[inds[i0]]

In [19]:
def plot_lw(all_lws_fit, all_lws_net, all_pks_fit, all_pks_net, compounds, save):
    """
    Compare the linewidths and peak positions obtained from the fits with those obtained from the network

    Inputs: - all_lws_fit   Fitted linewidths, one array per compound
            - all_lws_net   Predicted linewidths, one array per compound
            - all_pks_fit   Fitted peak positions, one array per compound
            - all_pks_net   Predicted peak positions, one array per compound
            - compounds     Names of the compounds, used as legend entries
            - save          Prefix of the two files written: "<save>_preds.pdf" (predicted against
                            fitted values) and "<save>_preds_diff.pdf" (differences and mean absolute errors)
    """

    # First figure: predicted against fitted values
    fig, (ax_lw, ax_pk) = plt.subplots(1, 2, figsize=(7,3))

    for fitted, predicted in zip(all_lws_fit, all_lws_net):
        ax_lw.scatter(fitted, predicted, s=5)

    for fitted, predicted in zip(all_pks_fit, all_pks_net):
        ax_pk.scatter(fitted, predicted, s=5)

    ax_lw.set_xlabel("Fitted linewidth [ppm]")
    ax_lw.set_ylabel("PIPNet linewidth [ppm]")
    ax_pk.set_xlabel("Fitted peak [ppm]")
    ax_pk.set_ylabel("PIPNet peak [ppm]")

    ax_pk.legend(compounds)

    fig.tight_layout()

    fig.savefig(f"{save}_preds.pdf")
    plt.close(fig)

    # Second figure: difference between predicted and fitted values
    fig, (ax_lw, ax_pk) = plt.subplots(1, 2, figsize=(7,3))

    for fitted, predicted in zip(all_lws_fit, all_lws_net):
        ax_lw.scatter(fitted, predicted - fitted, s=5)

    for fitted, predicted in zip(all_pks_fit, all_pks_net):
        ax_pk.scatter(fitted, predicted - fitted, s=5)

    ax_lw.set_xlabel("Fitted linewidth [ppm]")
    ax_lw.set_ylabel("linewidth difference [ppm]")
    ax_pk.set_xlabel("Fitted peak [ppm]")
    ax_pk.set_ylabel("peak difference [ppm]")

    # Mean absolute errors over all compounds, displayed as titles
    lw_mae = np.mean(np.abs(np.concatenate(all_lws_fit) - np.concatenate(all_lws_net)))
    pk_mae = np.mean(np.abs(np.concatenate(all_pks_fit) - np.concatenate(all_pks_net)))

    ax_lw.title.set_text(f"MAE = {lw_mae:.2f} ppm")
    ax_pk.title.set_text(f"MAE = {pk_mae:.2f} ppm")

    ax_pk.legend(compounds)

    fig.tight_layout()

    fig.savefig(f"{save}_preds_diff.pdf")
    plt.close(fig)

    return

In [20]:
def compare_integrals(ppm, X, y, regions):
    """
    Compute the relative integrals of two spectra over a set of regions

    Inputs: - ppm       Chemical shift axis shared by both spectra
            - X         First spectrum
            - y         Second spectrum
            - regions   List of pairs of chemical shift bounds (in any order)

    Outputs: - X_int    Sum of X over each region, divided by the total over all regions
             - y_int    Sum of y over each region, divided by the total over all regions
    """

    last = len(ppm) - 2

    def nearest(shift):
        # Index of the point closest to the given chemical shift, kept away from both ends of the axis
        return min(max(1, np.argmin(np.abs(ppm - shift))), last)

    X_int = []
    y_int = []

    for r1, r2 in regions:
        # Bounds may be given in any order: always sum from the lower to the higher index (excluded)
        start, stop = sorted((nearest(r1), nearest(r2)))

        X_int.append(np.sum(X[start:stop]))
        y_int.append(np.sum(y[start:stop]))

    # Normalize by the total over all regions
    return np.array(X_int) / np.sum(X_int), np.array(y_int) / np.sum(y_int)

In [21]:
def plot_integrals(all_X_int, all_y_int, compounds, int_regions, w=0.2, label_h=0.1, show=True, save=None):
    """
    Plot the relative integrals of two sets of spectra side by side, grouped by compound

    Inputs: - all_X_int     Relative integrals of the first set, one array per compound
            - all_y_int     Relative integrals of the second set, one array per compound
            - compounds     Names of the compounds
            - int_regions   Integration regions of each compound (compounds without regions are skipped)
            - w             Width of the bars
            - label_h       Height of the markers indicating the boundaries between regions
            - show          Whether the figure should be displayed
            - save          Path of the file where the figure is saved (not saved if None)
    """

    # Layout of the compounds along the x-axis: tick label, separator after the group,
    # middle of the group and boundaries between the regions of the group
    labels = []
    sep = []
    mid = []
    bounds = []
    n_bars = 0
    for name in compounds:
        if name not in int_regions:
            continue

        n_reg = len(int_regions[name])
        labels.append(name)
        sep.append(n_bars+n_reg - 0.5)
        mid.append(n_bars + (n_reg / 2) - 0.5)
        bounds.append([reg[1] for reg in int_regions[name][:-1]])
        n_bars += n_reg

    # No separator is needed after the last compound
    sep = sep[:-1]

    # Relative error [%] of the second set with respect to the first one, for each compound
    rel_errs = [np.abs(xi-yi) / xi*100 for xi, yi in zip(all_X_int, all_y_int)]
    err_avg = [np.mean(e) for e in rel_errs]
    err_std = [np.std(e) for e in rel_errs]

    x = np.arange(n_bars)

    xint = np.concatenate(all_X_int)
    yint = np.concatenate(all_y_int)

    # Upper limit of the y-axis
    M = max(np.max(xint), np.max(yint)) * 1.1

    fig, ax = plt.subplots(figsize=(n_bars*0.4,3))

    ax.bar(x-(w/2), xint, width=w)
    ax.bar(x+(w/2), yint, width=w)

    # The legend is created before any line is drawn, so that its two entries are the bars
    ax.legend(["100 kHz MAS", "PIPNet"], bbox_to_anchor=(0.,0.95), loc="upper left")

    # Separators between compounds
    for s in sep:
        ax.plot([s, s], [0., M], "k")

    # Boundaries between the regions of each compound
    lx = 0.5
    for group in bounds:
        for bi in group:
            ax.plot([lx, lx], [0, label_h], "k:")
            ax.text(lx, label_h, f" {bi} ppm", rotation=90, ha="center", va="bottom", size=8)
            lx += 1
        lx += 1

    # Error of each compound
    for em, es, m in zip(err_avg, err_std, mid):
        ax.text(m, M*0.99, f"{em:.0f}±{es:.0f}% error", ha="center", va="top", size=8)

    ax.set_xticks(mid)
    ax.set_xticklabels(labels)

    ax.set_ylabel("Relative integral")

    ax.set_ylim(0., M)
    ax.set_xlim(-0.5, n_bars-0.5)

    if save is not None:
        fig.savefig(save)

    if show:
        plt.show()

    plt.close(fig)

    return

In [22]:
def get_maximum(ppm, X, r, method="direct"):
    """
    Get the position of the maximum of a spectrum within a range

    Inputs: - ppm       Axis of the spectrum (uniformly spaced for the "interp" method)
            - X         Intensities of the spectrum
            - r         Two bounds of the range, in any order (bounds excluded)
            - method    "direct": position of the highest point in the range
                        "interp": position of the maximum of a cubic interpolation of the points
                                  in the range, evaluated on 1001 equally spaced positions

    Output: - w0        Position of the maximum

    Raises: ValueError  if the method is unknown or if no point lies within the range
    """

    if method not in ["direct", "interp"]:
        raise ValueError(f"Unknown method: {method}")

    r0 = min(r)
    r1 = max(r)

    # Points within the range, shared by both methods
    inds = np.where(np.logical_and(ppm > r0, ppm < r1))[0]

    if len(inds) == 0:
        raise ValueError(f"No point found in the range [{r0}, {r1}]")

    if method == "direct":
        return ppm[inds[np.argmax(X[inds])]]

    # Imported here so that the function does not rely on scipy.interpolate
    # having been loaded as a side effect of another import
    from scipy.interpolate import interp1d

    # Interpolate as a function of the point index within the range
    f = interp1d(range(len(inds)), X[inds], kind="cubic")

    x = np.linspace(0, len(inds)-1, 1001)

    # Convert the fractional indices to positions using the spacing between the first two points
    x_ppm = ppm[inds[0]] + x * (ppm[inds[1]] - ppm[inds[0]])

    return x_ppm[np.argmax(f(x))]

In [23]:
def shift_spectrum(hz, Xr, Xi, dw):
    """
    Shift a complex spectrum along its frequency axis by applying a phase ramp in the time domain

    Inputs: - hz        Frequency axis (only the spacing between its first two points is used)
            - Xr        Real part of the spectrum
            - Xi        Imaginary part of the spectrum
            - dw        Shift to apply, in the units of hz (a shift equal to the spacing
                        moves the spectrum by one point towards higher indices)

    Outputs: - Real part of the shifted spectrum
             - Imaginary part of the shifted spectrum
    """

    n = Xr.shape[0]
    spacing = np.abs(hz[1] - hz[0])

    # Time associated with each point of the inverse Fourier transform
    t = np.arange(n) / spacing / n

    # Go to the time domain, apply the phase ramp and come back
    phase = np.exp(1j*dw*t * 2 * np.pi)
    shifted = np.fft.fft(np.fft.ifft(Xr + 1j * Xi) * phase)

    return np.real(shifted), np.imag(shifted)

In [24]:
def extract_and_shift_exp_topspin(in_dir, compound, align_region, align_ind=-1, method="interp"):
    """
    Load the spectra of a compound, sorted by MAS rate, and align them on a reference peak

    Inputs: - in_dir        Directory containing one sub-directory per compound (with trailing separator)
            - compound      Name of the compound
            - align_region  Chemical shift bounds of the region containing the reference peak
            - align_ind     Index (after sorting by MAS rate) of the spectrum used as reference
            - method        Method used to locate the peak (see get_maximum)

    Outputs: - ppm          Chemical shift axis of the last experiment loaded
             - sorted_ws    Array of MAS rates, in increasing order
             - Array of real parts of the shifted spectra
             - Array of imaginary parts of the shifted spectra
             - Array of peak position differences with respect to the reference, on the Hz axis
             - Array of peak position differences with respect to the reference, on the ppm axis
             - sorted_Xr    Array of real parts of the original spectra
             - sorted_Xi    Array of imaginary parts of the original spectra
    """

    root = f"{in_dir}{compound}/"

    rates = []
    reals = []
    imags = []

    # Experiments are stored in the sub-directories that have a purely numeric name
    for entry in os.listdir(root):
        if not entry.isnumeric():
            continue

        real_part, imag_part, rate, ppm, hz = load_topspin_spectrum(f"{root}{entry}/")
        reals.append(real_part)
        imags.append(imag_part)
        rates.append(rate)

    # Reorder everything by increasing MAS rate
    order = np.argsort(rates)
    sorted_ws = np.array([rates[k] for k in order])
    sorted_Xr = np.array([reals[k] for k in order])
    sorted_Xi = np.array([imags[k] for k in order])

    # Bounds of the alignment region on the Hz axis (values of the closest points)
    align_region_hz = [hz[np.argmin(np.abs(ppm - align_region[k]))] for k in (0, 1)]

    # Position of the reference peak on both axes
    w0 = get_maximum(hz, sorted_Xr[align_ind], align_region_hz, method=method)
    w0_ppm = get_maximum(ppm, sorted_Xr[align_ind], align_region, method=method)

    shifted_Xr = []
    shifted_Xi = []
    all_dw = []
    all_dw_ppm = []

    for real_part, imag_part in zip(sorted_Xr, sorted_Xi):

        # Difference between the peak position in this spectrum and in the reference
        dw = get_maximum(hz, real_part, align_region_hz, method=method) - w0
        dw_ppm = get_maximum(ppm, real_part, align_region, method=method) - w0_ppm
        all_dw.append(dw)
        all_dw_ppm.append(dw_ppm)

        # Shift spectrum by the difference measured on the Hz axis
        new_real, new_imag = shift_spectrum(hz, real_part, imag_part, dw)
        shifted_Xr.append(new_real)
        shifted_Xi.append(new_imag)

    return (ppm, sorted_ws, np.array(shifted_Xr), np.array(shifted_Xi),
            np.array(all_dw), np.array(all_dw_ppm), sorted_Xr, sorted_Xi)

In [25]:
align_ind = -1

# Load best model
net.load_state_dict(torch.load(in_dir + f"checkpoint_{best_chk+1}_network", map_location=torch.device("cpu")))

for xscale in x_scales:

    all_lws_fit = []
    all_pks_fit = []

    sel_all_lws_net = []
    sel_all_pks_net = []
    opt_all_lws_net = []
    opt_all_pks_net = []
    all_all_lws_net = []
    all_all_pks_net = []
    low_all_lws_net = []
    low_all_pks_net = []
    high_all_lws_net = []
    high_all_pks_net = []
    rand_all_lws_net = []
    rand_all_pks_net = []

    fdir = fig_dir + f"eval_exp_4k_scale_{xscale}_shifted_{align_ind}/"
    if not os.path.exists(fdir):
        os.mkdir(fdir)

    all_X_int = []
    all_y_int = []
    sel_X_int = []
    sel_y_int = []
    opt_X_int = []
    opt_y_int = []
    all_X_int = []
    all_y_int = []
    low_X_int = []
    low_y_int = []
    high_X_int = []
    high_y_int = []
    rand_X_int = []
    rand_y_int = []

    for compound in exp_compounds:
        r = exp_range[compound]
        print(compound)
        ys_ppms, ys_part_means, ys_part_stds = extract_fit_model(iso_dir, compound,
                                                                 exp_parts[compound], exp_res[compound])
        ppm, ws, X, _, all_dw, all_dw_ppm, X0, _ = extract_and_shift_exp_topspin(exp_dir, compound, align_regions[compound], align_ind=align_ind)

        if len(ys_part_means) > 0:
            ymax = np.max(ys_part_means)
            ys_part_means /= ymax / 0.5
            ys_part_stds /= ymax / 0.5

        ppm = ppm[r[0]:r[1]]
        X = X[:, r[0]:r[1]]

        X /= np.sum(X, axis=1)[:, np.newaxis]
        
        plot_exp(ppm, X, xl=align_regions[compound], show=False, save=f"{fdir}{compound}_exp.pdf")
        
        fig = plt.figure(figsize=(4,3))
        ax = fig.add_subplot(1,1,1)
        ax.plot(ws / 1000, all_dw)
        ax.set_xlabel("MAS rate [kHz]")
        ax.set_ylabel("Peak difference [Hz]")
        
        fig.tight_layout()
        plt.savefig(f"{fdir}{compound}_dw_hz.pdf")
        plt.close()
        
        fig = plt.figure(figsize=(4,3))
        ax = fig.add_subplot(1,1,1)
        ax.plot(ws / 1000, all_dw_ppm)
        ax.set_xlabel("MAS rate [kHz]")
        ax.set_ylabel("Peak difference [ppm]")
        
        fig.tight_layout()
        plt.savefig(f"{fdir}{compound}_dw_ppm.pdf")
        plt.close()

        X_torch, ws = make_input(X, ws, data_pars)

        lws_fit = []
        pks_fit = []

        for p in peaks[compound]:

            lw, pk = extract_linewidth(ys_ppms, ys_part_means, p)

            lws_fit.append(lw)
            pks_fit.append(pk)

        all_lws_fit.append(np.array(lws_fit))
        all_pks_fit.append(np.array(pks_fit))

        # Selected rates
        if evals["sel"]:

            print("  Predictions on selected MAS rates...")

            w_inds = []
            for w in sel_wrs:
                w_inds.append(np.argmin(np.abs(ws - w)))

            X_net = X_torch[:, w_inds]
            X_net[:, :, 0] /= torch.max(X_net[:, :, 0]) / xscale

            y_pred, y_std, ys = net(X_net)

            y_pred = y_pred.detach().numpy()[0]
            y_std = y_std.detach().numpy()[0]

            ymax = np.max(y_pred)
            y_pred /= ymax / 0.5

            y_std /= ymax / 0.5

            for i, (yi_pred, yi_std) in enumerate(zip(y_pred, y_std)):
                plot_exp_vs_pred(ppm, X[-1:], yi_pred, yi_std, ys_ppms, ys_part_means, ys_part_stds, x_offset=0.1,
                                 y0_pred=-0.1, y0_trg=-0.1, reverse_trg=True, show=False,
                                 save=f"{fdir}{compound}_sel_w_pred_{i+1}.pdf")

            plot_exp_vs_pred(ppm, X, yi_pred, yi_std, ys_ppms, ys_part_means, ys_part_stds, x_offset=0.1,
                             y0_pred=-0.1, y0_trg=-0.1, reverse_trg=True, show=False,
                             save=f"{fdir}{compound}_sel_w_pred_final.pdf")

            lws_net = []
            pks_net = []

            for p in peaks[compound]:

                lw, pk = extract_linewidth(ppm, yi_pred, p)

                lws_net.append(lw)
                pks_net.append(pk)

            sel_all_lws_net.append(np.array(lws_net))
            sel_all_pks_net.append(np.array(pks_net))

            X_int, y_int = compare_integrals(ppm, X[-1], y_pred[-1], int_regions[compound])
            sel_X_int.append(X_int)
            sel_y_int.append(y_int)

        # Optimized rates
        if evals["opt"]:

            print("  Predictions on optimal MAS rates...")

            opt_wrs = np.linspace(data_pars["mas_w_range"][0], data_pars["mas_w_range"][1], num=data_pars["nw"])

            w_inds = []
            for w in opt_wrs:
                w_inds.append(np.argmin(np.abs(ws - w)))

            X_net = X_torch[:, w_inds]
            X_net[:, :, 0] /= torch.max(X_net[:, :, 0]) / xscale

            y_pred, y_std, ys = net(X_net)

            y_pred = y_pred.detach().numpy()[0]
            y_std = y_std.detach().numpy()[0]

            ymax = np.max(y_pred)
            y_pred /= ymax / 0.5

            y_std /= ymax / 0.5

            for i, (yi_pred, yi_std) in enumerate(zip(y_pred, y_std)):
                plot_exp_vs_pred(ppm, X[-1:], yi_pred, yi_std, ys_ppms, ys_part_means, ys_part_stds, x_offset=0.1,
                                 y0_pred=-0.1, y0_trg=-0.1, reverse_trg=True, show=False,
                                 save=f"{fdir}{compound}_opt_w_pred_{i+1}.pdf")

            plot_exp_vs_pred(ppm, X, yi_pred, yi_std, ys_ppms, ys_part_means, ys_part_stds, x_offset=0.1,
                             y0_pred=-0.1, y0_trg=-0.1, reverse_trg=True, show=False,
                             save=f"{fdir}{compound}_opt_w_pred_final.pdf")

            lws_net = []
            pks_net = []

            for p in peaks[compound]:

                lw, pk = extract_linewidth(ppm, yi_pred, p)

                lws_net.append(lw)
                pks_net.append(pk)

            opt_all_lws_net.append(np.array(lws_net))
            opt_all_pks_net.append(np.array(pks_net))

            X_int, y_int = compare_integrals(ppm, X[-1], y_pred[-1], int_regions[compound])
            opt_X_int.append(X_int)
            opt_y_int.append(y_int)

        # All rates
        if evals["all"]:

            print("  Predictions on all MAS rates...")

            # Order the spectra by ascending value of `ws`. Selecting with an index
            # array returns a copy, so the in-place normalisation below leaves
            # X_torch untouched.
            w_inds = np.argsort(ws)

            X_net = X_torch[:, w_inds]
            # Scale index 0 of the third axis so that its maximum equals xscale
            X_net[:, :, 0] /= torch.max(X_net[:, :, 0]) / xscale

            y_pred, y_std, ys = net(X_net)

            # Keep the first entry along the leading axis
            # (presumably the batch axis -- TODO confirm against the model)
            y_pred = y_pred.detach().numpy()[0]
            y_std = y_std.detach().numpy()[0]

            # Rescale so that the largest predicted value is 0.5; the
            # uncertainties are scaled by the same factor
            ymax = np.max(y_pred)
            y_pred /= ymax / 0.5

            y_std /= ymax / 0.5

            # One figure per entry of y_pred, each drawn against the last spectrum of X only
            for i, (yi_pred, yi_std) in enumerate(zip(y_pred, y_std)):
                plot_exp_vs_pred(ppm, X[-1:], yi_pred, yi_std, ys_ppms, ys_part_means, ys_part_stds, x_offset=0.1,
                                 y0_pred=-0.1, y0_trg=-0.1, reverse_trg=True, show=False,
                                 save=f"{fdir}{compound}_all_w_pred_{i+1}.pdf")

            # Last entry of y_pred against every spectrum in X. Indexed explicitly
            # instead of relying on the loop variables leaking out of the loop above.
            plot_exp_vs_pred(ppm, X, y_pred[-1], y_std[-1], ys_ppms, ys_part_means, ys_part_stds, x_offset=0.1,
                             y0_pred=-0.1, y0_trg=-0.1, reverse_trg=True, show=False,
                             save=f"{fdir}{compound}_all_w_pred_final.pdf")

            lws_net = []
            pks_net = []

            for p in peaks[compound]:

                # Do not unpack into `pk`: that name is the notebook's `pickle`
                # alias (`import pickle as pk`) and would be overwritten.
                lw, peak = extract_linewidth(ppm, y_pred[-1], p)

                lws_net.append(lw)
                pks_net.append(peak)

            all_all_lws_net.append(np.array(lws_net))
            all_all_pks_net.append(np.array(pks_net))

            # Integrals of the last spectrum of X vs. the last entry of y_pred
            X_int, y_int = compare_integrals(ppm, X[-1], y_pred[-1], int_regions[compound])
            all_X_int.append(X_int)
            all_y_int.append(y_int)

        # Lowest rates
        if evals["low"]:

            print("  Predictions on lowest MAS rates...")

            # Take the first data_pars["nw"] spectra. Basic slicing returns a view of
            # X_torch, so clone it: otherwise the in-place normalisation below would
            # rescale part of X_torch, which the other evaluation modes also read.
            X_net = X_torch[:, :data_pars["nw"]].clone()
            # Scale index 0 of the third axis so that its maximum equals xscale
            X_net[:, :, 0] /= torch.max(X_net[:, :, 0]) / xscale

            y_pred, y_std, ys = net(X_net)

            # Keep the first entry along the leading axis
            # (presumably the batch axis -- TODO confirm against the model)
            y_pred = y_pred.detach().numpy()[0]
            y_std = y_std.detach().numpy()[0]

            # Rescale so that the largest predicted value is 0.5; the
            # uncertainties are scaled by the same factor
            ymax = np.max(y_pred)
            y_pred /= ymax / 0.5

            y_std /= ymax / 0.5

            # One figure per entry of y_pred, each drawn against the last spectrum of X only
            for i, (yi_pred, yi_std) in enumerate(zip(y_pred, y_std)):
                plot_exp_vs_pred(ppm, X[-1:], yi_pred, yi_std, ys_ppms, ys_part_means, ys_part_stds, x_offset=0.1,
                                 y0_pred=-0.1, y0_trg=-0.1, reverse_trg=True, show=False,
                                 save=f"{fdir}{compound}_low_w_pred_{i+1}.pdf")

            # Last entry of y_pred against every spectrum in X. Indexed explicitly
            # instead of relying on the loop variables leaking out of the loop above.
            plot_exp_vs_pred(ppm, X, y_pred[-1], y_std[-1], ys_ppms, ys_part_means, ys_part_stds, x_offset=0.1,
                             y0_pred=-0.1, y0_trg=-0.1, reverse_trg=True, show=False,
                             save=f"{fdir}{compound}_low_w_pred_final.pdf")

            lws_net = []
            pks_net = []

            for p in peaks[compound]:

                # Do not unpack into `pk`: that name is the notebook's `pickle`
                # alias (`import pickle as pk`) and would be overwritten.
                lw, peak = extract_linewidth(ppm, y_pred[-1], p)

                lws_net.append(lw)
                pks_net.append(peak)

            low_all_lws_net.append(np.array(lws_net))
            low_all_pks_net.append(np.array(pks_net))

            # Integrals of the last spectrum of X vs. the last entry of y_pred
            X_int, y_int = compare_integrals(ppm, X[-1], y_pred[-1], int_regions[compound])
            low_X_int.append(X_int)
            low_y_int.append(y_int)

        # Highest rates
        if evals["high"]:

            print("  Predictions on highest MAS rates...")

            # Take the last data_pars["nw"] spectra. Basic slicing returns a view of
            # X_torch, so clone it: otherwise the in-place normalisation below would
            # rescale part of X_torch, which the other evaluation modes also read.
            X_net = X_torch[:, -data_pars["nw"]:].clone()
            # Scale index 0 of the third axis so that its maximum equals xscale
            X_net[:, :, 0] /= torch.max(X_net[:, :, 0]) / xscale

            y_pred, y_std, ys = net(X_net)

            # Keep the first entry along the leading axis
            # (presumably the batch axis -- TODO confirm against the model)
            y_pred = y_pred.detach().numpy()[0]
            y_std = y_std.detach().numpy()[0]

            # Rescale so that the largest predicted value is 0.5; the
            # uncertainties are scaled by the same factor
            ymax = np.max(y_pred)
            y_pred /= ymax / 0.5

            y_std /= ymax / 0.5

            # One figure per entry of y_pred, each drawn against the last spectrum of X only
            for i, (yi_pred, yi_std) in enumerate(zip(y_pred, y_std)):
                plot_exp_vs_pred(ppm, X[-1:], yi_pred, yi_std, ys_ppms, ys_part_means, ys_part_stds, x_offset=0.1,
                                 y0_pred=-0.1, y0_trg=-0.1, reverse_trg=True, show=False,
                                 save=f"{fdir}{compound}_high_w_pred_{i+1}.pdf")

            # Last entry of y_pred against every spectrum in X. Indexed explicitly
            # instead of relying on the loop variables leaking out of the loop above.
            plot_exp_vs_pred(ppm, X, y_pred[-1], y_std[-1], ys_ppms, ys_part_means, ys_part_stds, x_offset=0.1,
                             y0_pred=-0.1, y0_trg=-0.1, reverse_trg=True, show=False,
                             save=f"{fdir}{compound}_high_w_pred_final.pdf")

            lws_net = []
            pks_net = []

            for p in peaks[compound]:

                # Do not unpack into `pk`: that name is the notebook's `pickle`
                # alias (`import pickle as pk`) and would be overwritten.
                lw, peak = extract_linewidth(ppm, y_pred[-1], p)

                lws_net.append(lw)
                pks_net.append(peak)

            high_all_lws_net.append(np.array(lws_net))
            high_all_pks_net.append(np.array(pks_net))

            # Integrals of the last spectrum of X vs. the last entry of y_pred
            X_int, y_int = compare_integrals(ppm, X[-1], y_pred[-1], int_regions[compound])
            high_X_int.append(X_int)
            high_y_int.append(y_int)

        # Randomly selected rates
        if evals["rand"] > 0:

            print("  Predictions on randomly selected MAS rates...")

            # Predictions collected over all random selections, aggregated after the loop
            all_ys = []
            for k in range(evals["rand"]):

                print(f"    Selection {k+1}/{evals['rand']}...")

                # Draw data_pars["nw"] distinct spectrum indices and keep them in ascending order
                w_inds = np.sort(np.random.choice(range(X.shape[0]), size=data_pars["nw"], replace=False))

                # Selecting with an index array returns a copy, so the in-place
                # normalisation below leaves X_torch untouched
                X_net = X_torch[:, w_inds]
                # Scale index 0 of the third axis so that its maximum equals xscale
                X_net[:, :, 0] /= torch.max(X_net[:, :, 0]) / xscale

                y_pred, y_std, ys = net(X_net)

                # Keep the first entry along the leading axis
                # (presumably the batch axis -- TODO confirm against the model)
                y_pred = y_pred.detach().numpy()[0]
                y_std = y_std.detach().numpy()[0]

                # Rescale so that the largest predicted value is 0.5; the
                # uncertainties are scaled by the same factor
                ymax = np.max(y_pred)
                y_pred /= ymax / 0.5

                y_std /= ymax / 0.5

                # Store index 0 of the second axis of `ys`, unscaled
                # (presumably the individual model outputs -- TODO confirm)
                all_ys.append(ys.detach().numpy()[:, 0])

                # One figure per entry of y_pred, each drawn against the last spectrum of X only
                for i, (yi_pred, yi_std) in enumerate(zip(y_pred, y_std)):
                    plot_exp_vs_pred(ppm, X[-1:], yi_pred, yi_std, ys_ppms, ys_part_means, ys_part_stds, x_offset=0.1,
                                     y0_pred=-0.1, y0_trg=-0.1, reverse_trg=True, show=False,
                                     save=f"{fdir}{compound}_rand_w_{k+1}_pred_{i+1}.pdf")

                # Last entry of y_pred against every spectrum in X. Indexed explicitly
                # instead of relying on the loop variables leaking out of the loop above.
                plot_exp_vs_pred(ppm, X, y_pred[-1], y_std[-1], ys_ppms, ys_part_means, ys_part_stds, x_offset=0.1,
                                 y0_pred=-0.1, y0_trg=-0.1, reverse_trg=True, show=False,
                                 save=f"{fdir}{compound}_rand_w_{k+1}_pred_final.pdf")

            # Pool the stored outputs of every selection along axis 0 and take
            # their mean and standard deviation
            ys = np.concatenate(all_ys, axis=0)
            y_pred = np.mean(ys, axis=0)
            y_std = np.std(ys, axis=0)

            # Same rescaling as above, applied to the pooled prediction
            ymax = np.max(y_pred)
            y_pred /= ymax / 0.5

            y_std /= ymax / 0.5

            for i, (yi_pred, yi_std) in enumerate(zip(y_pred, y_std)):
                plot_exp_vs_pred(ppm, X[-1:], yi_pred, yi_std, ys_ppms, ys_part_means, ys_part_stds, x_offset=0.1,
                                 y0_pred=-0.1, y0_trg=-0.1, reverse_trg=True, show=False,
                                 save=f"{fdir}{compound}_rands_pred_{i+1}.pdf")

            # Last entry of the pooled prediction against every spectrum in X
            plot_exp_vs_pred(ppm, X, y_pred[-1], y_std[-1], ys_ppms, ys_part_means, ys_part_stds, x_offset=0.1,
                             y0_pred=-0.1, y0_trg=-0.1, reverse_trg=True, show=False,
                             save=f"{fdir}{compound}_rands_pred_final.pdf")

            lws_net = []
            pks_net = []

            for p in peaks[compound]:

                # Do not unpack into `pk`: that name is the notebook's `pickle`
                # alias (`import pickle as pk`) and would be overwritten.
                lw, peak = extract_linewidth(ppm, y_pred[-1], p)

                lws_net.append(lw)
                pks_net.append(peak)

            rand_all_lws_net.append(np.array(lws_net))
            rand_all_pks_net.append(np.array(pks_net))

            # Integrals of the last spectrum of X vs. the last entry of the pooled prediction
            X_int, y_int = compare_integrals(ppm, X[-1], y_pred[-1], int_regions[compound])
            rand_X_int.append(X_int)
            rand_y_int.append(y_int)


    # Linewidth summaries: one plot_lw call for each evaluation mode that
    # collected at least one result, in the order sel, opt, all, low, high, rand.
    for _lws_net, _pks_net, _prefix in (
        (sel_all_lws_net, sel_all_pks_net, f"{fdir}sel"),
        (opt_all_lws_net, opt_all_pks_net, f"{fdir}opt"),
        (all_all_lws_net, all_all_pks_net, f"{fdir}all"),
        (low_all_lws_net, low_all_pks_net, f"{fdir}low"),
        (high_all_lws_net, high_all_pks_net, f"{fdir}high"),
        (rand_all_lws_net, rand_all_pks_net, f"{fdir}rand"),
    ):
        if len(_lws_net) > 0:
            plot_lw(all_lws_fit, _lws_net, all_pks_fit, _pks_net, exp_compounds, _prefix)

    # Integral summaries: same modes and order, each saved to its own PDF
    # without being shown.
    for _X_int, _y_int, _pdf in (
        (sel_X_int, sel_y_int, f"{fdir}sel_integrals.pdf"),
        (opt_X_int, opt_y_int, f"{fdir}opt_integrals.pdf"),
        (all_X_int, all_y_int, f"{fdir}all_integrals.pdf"),
        (low_X_int, low_y_int, f"{fdir}low_integrals.pdf"),
        (high_X_int, high_y_int, f"{fdir}high_integrals.pdf"),
        (rand_X_int, rand_y_int, f"{fdir}rand_integrals.pdf"),
    ):
        if len(_y_int) > 0:
            plot_integrals(_X_int, _y_int, exp_compounds, int_regions, show=False, save=_pdf)

ampicillin
  Predictions on selected MAS rates...
  Predictions on optimal MAS rates...
  Predictions on all MAS rates...
  Predictions on lowest MAS rates...
  Predictions on highest MAS rates...
aspala
  Predictions on selected MAS rates...
  Predictions on optimal MAS rates...
  Predictions on all MAS rates...
  Predictions on lowest MAS rates...
  Predictions on highest MAS rates...
flutamide
  Predictions on selected MAS rates...
  Predictions on optimal MAS rates...
  Predictions on all MAS rates...
  Predictions on lowest MAS rates...
  Predictions on highest MAS rates...
histidine
  Predictions on selected MAS rates...
  Predictions on optimal MAS rates...
  Predictions on all MAS rates...
  Predictions on lowest MAS rates...
  Predictions on highest MAS rates...
thymol
  Predictions on selected MAS rates...
  Predictions on optimal MAS rates...
  Predictions on all MAS rates...
  Predictions on lowest MAS rates...
  Predictions on highest MAS rates...
tyrosine
  Predictions on