In [1]:
import numpy as np
import os
import sys
import copy

import torch
from torch import nn

import importlib
import pickle as pk

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.animation as anm

sys.path.insert(0, "../src/")
import data
import model
import train

import scipy
import scipy.io

np.random.seed(123)

In [2]:
# Model to evaluate: parameters/weights are read from in_dir, figures are written to fig_dir
mod = "final_model"

in_dir = f"../data/{mod}/"
fig_dir = f"../figures/{mod}/"

# Evaluation of the experimental 4k-point TopSpin datasets
eval_exp_4k = True
# MAS-rate subsets to evaluate (the evaluation cell reads "sel", "opt" and "all")
evals_4k = {"sel": True, "opt": True, "all": True, "high": True, "low": True, "rand": 0}
exp_dir_4k = "../data/experimental_spectra/topspin/4096/"
# Target maximum of the real input channel; a wider sweep would be e.g. [0.2, 0.5]
x_scales_4k = [0.5]
exp_compounds_4k = ["aspala", "aspala_lb_100", "aspala_lb_250", "aspala_lb_500"]
# Point-index window [start, stop) cropped from each spectrum
exp_range_4k = {"aspala": [1500, 2500],
                "aspala_lb_100": [1500, 2500],
                "aspala_lb_250": [1500, 2500],
                "aspala_lb_500": [1500, 2500],
               }

# MAS-rate grid (Hz) for the "opt" evaluation: opt_range[0]..opt_range[1] in steps of dw
opt_range = [40000., 100000.]
dw = 2000.
# Hand-picked MAS rates (Hz) for the "sel" evaluation
sel_wrs = [40000., 42000., 44000., 46000., 48000.,
           50000., 52000., 54000., 56000., 58000.,
           60000., 62000., 64000., 66000., 68000.,
           70000., 72000., 74000., 76000., 78000., 80000.]

In [3]:
# Fail early if the model directory does not exist
if not os.path.exists(in_dir):
    raise ValueError(f"Unknown model: {mod}")

# Create the figure directory, including any missing parents; no-op if it already exists
os.makedirs(fig_dir, exist_ok=True)

In [4]:
# Model hyper-parameters and data parameters stored next to the trained weights.
# NOTE: unpickling can execute arbitrary code -- only load parameter files from a trusted source.
pars_path = in_dir + "PIPNet_model_pars.pk"
with open(pars_path, "rb") as pars_file:
    model_pars, data_pars = pk.load(pars_file)

In [5]:
# Build the ensemble from the saved hyper-parameters, load the trained weights
# onto the CPU, and switch the network to evaluation mode
weights_path = in_dir + "PIPNet_model.weights"
net = model.ConvLSTMEnsemble(**model_pars)
state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
net.load_state_dict(state_dict)
net.eval()
print("Model loaded!")

Model loaded!


# Evaluate experimental data

In [6]:
def _read_jcamp_params(path):
    """Parse the single-line ``##$KEY= value`` entries of a Bruker parameter file (acqus/procs).

    Returns a dict mapping each key (e.g. ``"TD"``, ``"SW_h"``) to its stripped
    value string. Multi-line array values are not expanded. If a key occurs more
    than once, the last occurrence wins.
    """
    params = {}
    # Parameter files are ASCII but may contain stray non-UTF-8 bytes in
    # comments; replace them rather than failing to decode the whole file.
    with open(path, "r", encoding="utf-8", errors="replace") as F:
        for line in F:
            if line.startswith("##$") and "=" in line:
                key, value = line[3:].split("=", 1)
                params[key.strip()] = value.strip()
    return params


def _mas_rate_from_title(title_path):
    """Return the MAS rate in Hz written in a TopSpin title file, or None.

    A line containing "<number> kHz" or "<number> Hz" (case-insensitive)
    provides the rate. If several lines match, the last one wins. Lines where
    no number precedes the unit are ignored.
    """
    if not os.path.exists(title_path):
        return None

    wr = None
    with open(title_path, "r", encoding="utf-8", errors="replace") as F:
        for line in F:
            upper = line.upper()
            if "KHZ" in upper:
                unit, factor = "KHZ", 1000.
            elif "HZ" in upper:
                unit, factor = "HZ", 1.
            else:
                continue
            try:
                wr = float(upper.split(unit)[0].split()[-1]) * factor
            except (ValueError, IndexError):
                # e.g. "Hz" without a preceding number: not a MAS-rate line
                continue
    return wr


def _processed_dtype(procs):
    """numpy dtype of the processed data files (1r/1i) as declared in procs.

    BYTORDP gives the byte order (0 little-endian, 1 big-endian) and DTYPP the
    storage type (0 int32, 2 float64). Missing keys fall back to native-endian
    int32, which is what this loader always used previously.
    """
    order = {"0": "<", "1": ">"}.get(procs.get("BYTORDP"), "=")
    kind = "f8" if procs.get("DTYPP") == "2" else "i4"
    return np.dtype(order + kind)


def load_topspin_spectrum(d):
    """Load a processed 1D TopSpin spectrum together with its axes and MAS rate.

    Parameters
    ----------
    d : str
        Experiment directory, including the trailing "/". It must contain
        ``acqus`` and ``pdata/1/{1r,1i,procs}``. ``pdata/1/title`` is optional.

    Returns
    -------
    dr, di : np.ndarray
        Real and imaginary processed data as float arrays. These are the raw
        stored values; the NC_proc scaling is not applied.
    wr : float
        MAS rate in Hz. It is taken from the title ("... kHz" / "... Hz"), else
        from ``##$MASR`` in acqus, else -1. when neither is usable.
    ppm, hz : np.ndarray
        Chemical-shift axis in ppm and Hz, starting at OFFSET and decreasing
        by SW_h / SI per point.

    Raises
    ------
    ValueError
        If TD or SW_h (acqus), or SI, OFFSET or SF (procs) is missing.
    """
    pdata_dir = f"{d}pdata/1/"

    acqus = _read_jcamp_params(f"{d}acqus")
    procs = _read_jcamp_params(f"{pdata_dir}procs")

    missing = ([k for k in ("TD", "SW_h") if k not in acqus]
               + [k for k in ("SI", "OFFSET", "SF") if k not in procs])
    if missing:
        raise ValueError(f"Missing TopSpin parameter(s) {missing} in {d}")

    TD = int(acqus["TD"])
    SW = float(acqus["SW_h"])
    n_pts = int(procs["SI"])
    offset = float(procs["OFFSET"])
    SF = float(procs["SF"])

    dtype = _processed_dtype(procs)
    with open(f"{pdata_dir}1r", "rb") as F:
        dr = np.fromfile(F, dtype).astype(float)
    with open(f"{pdata_dir}1i", "rb") as F:
        di = np.fromfile(F, dtype).astype(float)

    # MAS rate: the title takes precedence over the acquisition parameter
    wr = _mas_rate_from_title(f"{pdata_dir}title")
    if wr is None:
        try:
            wr = float(acqus["MASR"])
        except (KeyError, ValueError):
            wr = -1.

    # With AQ = TD / (2 * SW), the point spacing TD / (2 * AQ * SI) reduces to SW / SI
    hz = offset * SF - np.arange(n_pts) * (SW / n_pts)
    ppm = hz / SF

    return dr, di, wr, ppm, hz

In [7]:
def extract_exp_topspin(in_dir, compound):
    """Load every numbered TopSpin experiment of a compound, sorted by MAS rate.

    Parameters
    ----------
    in_dir : str
        Directory holding one sub-directory per compound, with a trailing "/".
    compound : str
        Compound sub-directory name. Its numerically named children are
        loaded as TopSpin experiments with ``load_topspin_spectrum``.

    Returns
    -------
    ppm : np.ndarray
        Chemical-shift axis of the lowest-numbered experiment. All experiments
        are assumed to share the same axis.
    sorted_ws : np.ndarray
        MAS rates in ascending order. Ties keep experiment-number order.
    sorted_X_real, sorted_X_imag : np.ndarray
        Real and imaginary spectra stacked row-wise, in the order of sorted_ws.

    Raises
    ------
    ValueError
        If no numbered experiment exists or the spectra differ in length.
    """
    d0 = f"{in_dir}{compound}/"

    # Sort numerically so the result does not depend on the arbitrary os.listdir order
    exp_dirs = sorted((d for d in os.listdir(d0) if d.isdecimal()), key=int)
    if not exp_dirs:
        raise ValueError(f"No numbered TopSpin experiment directories found in {d0}")

    ws = []
    X_real = []
    X_imag = []
    ppm = None
    for d in exp_dirs:
        Xr, Xi, wr, exp_ppm, _ = load_topspin_spectrum(f"{d0}{d}/")
        if ppm is None:
            ppm = exp_ppm
        X_real.append(Xr)
        X_imag.append(Xi)
        ws.append(wr)

    lengths = {len(x) for x in X_real + X_imag}
    if len(lengths) > 1:
        raise ValueError(f"Spectra in {d0} have different numbers of points: {sorted(lengths)}")

    # Stable sort: experiments recorded at the same MAS rate stay in experiment-number order
    sorted_inds = np.argsort(ws, kind="stable")

    sorted_ws = np.asarray(ws)[sorted_inds]
    sorted_X_real = np.array(X_real)[sorted_inds]
    sorted_X_imag = np.array(X_imag)[sorted_inds]

    return ppm, sorted_ws, sorted_X_real, sorted_X_imag

In [8]:
def extract_fit_model(in_dir, compound, parts, res):
    """Collect the mean and standard deviation of fitted-model spectra over all saved guesses.

    For each pair ``(p, n)`` of ``zip(parts, res)``, the files
    ``{in_dir}{compound}_{n}/{prefix}_{p}_guess_r{k}.mat`` for k = 1, 2, ... are
    read until one is missing. From each file, ``x`` minus its last three
    entries is taken as the fitted spectrum, and ``ppm[0, range[0]]`` as its
    axis. The parts are concatenated. Points of a later part that fall inside
    the ppm span already covered by earlier parts are dropped.

    Parameters
    ----------
    in_dir : str
        Base directory with a trailing "/".
    compound : str
        Compound name. The file prefix is "mdma" for "mdma", "molnu" for
        "molnupiravir", and otherwise the first three characters.
    parts, res : sequence
        Part labels and directory suffixes, iterated in parallel.

    Returns
    -------
    (ppm, mean, std) : tuple of np.ndarray
        The concatenated axis, mean and std over guesses. Returns
        ``([], [], [])`` if parts/res is empty, or if a directory or a part's
        guess files are missing.
    """
    prefix = {"mdma": "mdma", "molnupiravir": "molnu"}.get(compound, compound[:3])

    if len(parts) == 0 or len(res) == 0:
        return [], [], []

    ys_part_means = []
    ys_part_stds = []
    ys_ppms = []

    for p, n in zip(parts, res):

        d = f"{in_dir}{compound}_{n}/"
        if not os.path.exists(d):
            return [], [], []

        ys_part = []
        ppm = None
        i_guess = 1
        while os.path.exists(f"{d}{prefix}_{p}_guess_r{i_guess}.mat"):
            m = scipy.io.loadmat(f"{d}{prefix}_{p}_guess_r{i_guess}.mat")
            # Last three entries of x are not part of the spectrum
            ys_part.append(m["x"][:-3])
            # NOTE(review): "range" is used directly as a 0-based index -- confirm it is not MATLAB 1-based
            ppm = m["ppm"][0, m["range"][0]]
            i_guess += 1

        if not ys_part:
            # No fit was saved for this part: treat it like a missing directory
            # instead of crashing in np.concatenate or reusing the previous ppm axis
            return [], [], []

        if len(ys_ppms) > 0:
            # Keep only the points outside the span already covered by earlier parts
            already_ppms = np.concatenate(ys_ppms, axis=0)
            inds = np.where(np.logical_or(ppm < np.min(already_ppms), ppm > np.max(already_ppms)))[0]
        else:
            inds = np.arange(len(ppm))

        ys_ppms.append(ppm[inds])
        ys_part = np.concatenate(ys_part, axis=1)
        ys_part_means.append(np.mean(ys_part, axis=1)[inds])
        ys_part_stds.append(np.std(ys_part, axis=1)[inds])

    ys_ppms = np.concatenate(ys_ppms, axis=0)
    ys_part_means = np.concatenate(ys_part_means, axis=0)
    ys_part_stds = np.concatenate(ys_part_stds, axis=0)

    return ys_ppms, ys_part_means, ys_part_stds

In [9]:
def make_input(X, ws, data_pars, Xi=None, x_max=0.25):
    """Build the network input tensor from a series of spectra recorded at different MAS rates.

    Spectra are sorted by MAS rate and scaled so that the largest real
    intensity equals ``x_max``. The imaginary part uses the same factor. The
    channels are stacked along dim 2 in this order: real part, imaginary part
    (only if ``Xi`` is given), then the MAS-rate channel (only if
    ``data_pars["encode_w"]``).

    Parameters
    ----------
    X : np.ndarray, shape (n_spectra, n_points)
        Real parts of the spectra.
    ws : array-like, shape (n_spectra,)
        MAS rate of each spectrum.
    data_pars : dict
        Reads "encode_w". When that is true it also reads "norm_wr", and when
        "norm_wr" is true, "mas_w_range" ([min, max]), which is used to map
        the rates to [0, 1].
    Xi : np.ndarray or None
        Imaginary parts, same shape as X.
    x_max : float
        Target maximum of the real channel.

    Returns
    -------
    X_torch : torch.Tensor, shape (1, n_spectra, n_channels, n_points)
    ws_sorted : np.ndarray
        The MAS rates in the order used in X_torch.
    """
    ws = np.asarray(ws)
    inds = np.argsort(ws)

    def _to_batch(A):
        # (n_spectra, n_points) -> (1, n_spectra, 1, n_points), sorted by MAS rate
        return torch.Tensor(np.asarray(A)[inds]).unsqueeze(0).unsqueeze(2)

    X_torch = _to_batch(X)
    M = torch.max(X_torch)

    channels = [X_torch / M * x_max]

    if Xi is not None:
        # Same scaling as the real part, so the real/imaginary ratio is preserved
        channels.append(_to_batch(Xi) / M * x_max)

    if data_pars["encode_w"]:
        # One constant channel per spectrum holding its MAS rate
        W = torch.Tensor(ws[inds]).reshape(1, -1, 1, 1).repeat(1, 1, 1, X_torch.shape[-1])

        if data_pars["norm_wr"]:
            W = (W - data_pars["mas_w_range"][0]) / (data_pars["mas_w_range"][1] - data_pars["mas_w_range"][0])

        channels.append(W)

    return torch.cat(channels, dim=2), ws[inds]

In [10]:
def plot_exp_vs_pred(ppm, X, y_pred, y_std, show=True, save=None, x_offset=0.,
                     y0_pred=0., y_pred_scale=0.5, xl=(20., -5.), c0=(0., 1., 1.), dc=(0., -1., 0.)):
    """Plot experimental spectra together with a prediction and its +/- std band.

    Parameters
    ----------
    ppm : array-like
        Chemical-shift axis shared by X, y_pred and y_std.
    X : np.ndarray or torch.Tensor, shape (n_spectra, n_points)
        Experimental spectra. They are divided by their common maximum, drawn
        with a vertical offset of ``i * x_offset``, and colored on a ramp from
        ``c0`` to ``c0 + dc``.
    y_pred, y_std : np.ndarray or torch.Tensor, shape (n_points,)
        Prediction and its standard deviation. Both are scaled so that the
        prediction peaks at ``y_pred_scale``, and shifted by ``y0_pred``.
        The caller's arrays/tensors are never modified.
    show : bool
        Call ``plt.show()``.
    save : str or None
        If truthy, path the figure is saved to.
    xl : pair
        x-axis limits (descending by default for a ppm axis).
    c0, dc : RGB triples
        Start color and color change over the experimental spectra.
    """
    def _float_copy(a):
        # Detach tensors and always copy, so the in-place scaling below cannot
        # alter the caller's data (and integer inputs can be divided in place).
        if hasattr(a, "detach"):
            a = a.detach().cpu().numpy()
        return np.array(a, dtype=float)

    X2 = _float_copy(X)
    X2 /= np.max(X2)

    y_pred2 = _float_copy(y_pred)
    y_std2 = _float_copy(y_std)
    factor = np.max(y_pred2) / y_pred_scale
    y_pred2 /= factor
    y_std2 /= factor

    n = X2.shape[0]
    if n == 1:
        # A single spectrum gets the end color of the ramp
        colors = [[ci + dci for ci, dci in zip(c0, dc)]]
    else:
        colors = [[ci + dci * i / (n - 1) for ci, dci in zip(c0, dc)] for i in range(n)]

    fig, ax = plt.subplots(figsize=(4, 3))

    # Plot inputs
    for i, (c, x) in enumerate(zip(colors, X2)):
        ax.plot(ppm, x + i * x_offset, color=c, linewidth=1)

    # Plot prediction and uncertainty band
    ax.plot(ppm, y_pred2 + y0_pred, "r", linewidth=1)
    ax.fill_between(ppm, y_pred2 - y_std2 + y0_pred, y_pred2 + y_std2 + y0_pred, color="r", alpha=0.3)

    ax.set_xlim(xl)
    ax.set_yticks([])
    ax.set_xlabel("Chemical shift [ppm]")

    fig.tight_layout()

    if save:
        fig.savefig(save)

    if show:
        plt.show()

    plt.close(fig)

    return

In [11]:
def extract_linewidth(x, y, r, nfit=3):
    """Estimate the full width at half maximum and the position of the tallest peak in a window.

    Parameters
    ----------
    x : array-like
        Abscissa, e.g. ppm. Either direction is accepted. Crossings are
        interpolated using the local spacing between neighbouring samples.
    y : array-like
        Intensities, same length as x.
    r : pair
        Open interval ``(r[0], r[1])`` of x values to search.
    nfit : int
        Unused; kept for backward compatibility.

    Returns
    -------
    (width, position) : tuple of float
        ``width`` is the absolute distance between the linearly interpolated
        rising and falling half-maximum crossings. If a crossing is not found
        on one side, the first/last x of the window is used instead.
        ``position`` is the x of the maximum.

    Raises
    ------
    ValueError
        If no x lies strictly inside r.
    """
    x = np.asarray(x)
    y = np.asarray(y)

    inds = np.where(np.logical_and(x > r[0], x < r[1]))[0]
    if len(inds) == 0:
        raise ValueError(f"No points inside the range {r}")

    i_peak = inds[np.argmax(y[inds])]
    half = y[i_peak] / 2

    x_rise = None
    x_fall = None

    for i, j in zip(inds[:-1], inds[1:]):
        # Inclusive on one side so that a sample lying exactly at half height is
        # detected once instead of being skipped by both strict comparisons
        if y[i] >= half > y[j]:
            x_fall = x[i] + (x[j] - x[i]) * (half - y[i]) / (y[j] - y[i])
        if y[i] < half <= y[j]:
            x_rise = x[i] + (x[j] - x[i]) * (half - y[i]) / (y[j] - y[i])

    if x_rise is None:
        x_rise = x[inds[0]]
    if x_fall is None:
        x_fall = x[inds[-1]]

    return abs(x_rise - x_fall), x[i_peak]

In [12]:
def plot_lw(all_lws_fit, all_lws_net, all_pks_fit, all_pks_net, compounds, save):
    """Compare fitted and PIPNet linewidths and peak positions, and save two figures.

    Writes ``{save}_preds.pdf``, which plots PIPNet against the fit, and
    ``{save}_preds_diff.pdf``, which plots (PIPNet - fit) against the fit with
    the mean absolute error in each panel title.

    Parameters
    ----------
    all_lws_fit, all_lws_net, all_pks_fit, all_pks_net : sequence of array-like
        One entry per compound: fitted and PIPNet linewidths and peak positions in ppm.
    compounds : sequence of str
        Legend labels, one per compound.
    save : str
        Path prefix of the output files.
    """
    # Convert once, so the difference panel also works when plain lists are passed
    lw_pairs = [(np.asarray(f, dtype=float), np.asarray(p, dtype=float))
                for f, p in zip(all_lws_fit, all_lws_net)]
    pk_pairs = [(np.asarray(f, dtype=float), np.asarray(p, dtype=float))
                for f, p in zip(all_pks_fit, all_pks_net)]

    def _draw(as_difference, lw_ylabel, pk_ylabel):
        """Create the two-panel scatter figure (linewidths left, peaks right)."""
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 3))
        for ax, pairs in ((ax1, lw_pairs), (ax2, pk_pairs)):
            for fit, pred in pairs:
                ax.scatter(fit, pred - fit if as_difference else pred, s=5)
        ax1.set_xlabel("Fitted linewidth [ppm]")
        ax1.set_ylabel(lw_ylabel)
        ax2.set_xlabel("Fitted peak [ppm]")
        ax2.set_ylabel(pk_ylabel)
        return fig, ax1, ax2

    def _finish(fig, ax2, fname):
        """Add the compound legend, then save and close the figure."""
        ax2.legend(compounds, bbox_to_anchor=(1., 1.))
        fig.tight_layout()
        fig.savefig(fname)
        plt.close(fig)

    # Predictions vs fit
    fig, ax1, ax2 = _draw(False, "PIPNet linewidth [ppm]", "PIPNet peak [ppm]")
    _finish(fig, ax2, f"{save}_preds.pdf")

    # Differences vs fit, with the overall mean absolute error
    fig, ax1, ax2 = _draw(True, "linewidth difference [ppm]", "peak difference [ppm]")

    lw_mae = np.mean(np.abs(np.concatenate([f for f, _ in lw_pairs]) - np.concatenate([p for _, p in lw_pairs])))
    pk_mae = np.mean(np.abs(np.concatenate([f for f, _ in pk_pairs]) - np.concatenate([p for _, p in pk_pairs])))

    ax1.title.set_text(f"MAE = {lw_mae:.2f} ppm")
    ax2.title.set_text(f"MAE = {pk_mae:.2f} ppm")

    _finish(fig, ax2, f"{save}_preds_diff.pdf")

    return

In [13]:
def compare_integrals(ppm, X, y, regions):
    """Relative integrals of two spectra over a list of chemical-shift regions.

    Each region ``(r1, r2)`` is mapped to the nearest grid indices, which are
    clamped to ``[1, len(ppm) - 2]`` and ordered. The region is then summed
    over the half-open index range ``[lo, hi)``. Both integral vectors are
    normalised to sum to one.

    Parameters
    ----------
    ppm : np.ndarray
        Chemical-shift axis.
    X, y : np.ndarray
        Spectra on the ``ppm`` grid (e.g. reference and prediction).
    regions : iterable of (float, float)
        Region limits in ppm, in either order.

    Returns
    -------
    (X_int, y_int) : tuple of np.ndarray
        Normalised integrals, one entry per region.
    """
    upper = len(ppm) - 2

    def nearest_index(shift):
        # Grid point closest to `shift`, kept away from the first and last sample
        return min(max(1, np.argmin(np.abs(ppm - shift))), upper)

    index_ranges = [sorted((nearest_index(a), nearest_index(b))) for a, b in regions]

    X_int = np.array([np.sum(X[lo:hi]) for lo, hi in index_ranges])
    y_int = np.array([np.sum(y[lo:hi]) for lo, hi in index_ranges])

    return X_int / np.sum(X_int), y_int / np.sum(y_int)

In [14]:
def plot_integrals(all_X_int, all_y_int, compounds, int_regions, w=0.2, label_h=0.1, show=True, save=None):
    """Grouped bar chart comparing reference ("100 kHz MAS") and PIPNet relative integrals.

    Each region of each compound gets a pair of bars. Compounds are separated
    by vertical lines, region boundaries are marked by dotted lines labelled
    with the boundary in ppm, and the per-compound mean +/- std relative
    error (in percent) is written at the top of each group.

    Parameters
    ----------
    all_X_int, all_y_int : sequence of array-like
        Per-compound relative integrals of the reference and of the prediction
        (presumably as returned by ``compare_integrals``).
        NOTE(review): these are zipped with the compounds of ``compounds`` that
        appear in ``int_regions``, in that order, and their total length must
        equal the total number of regions -- confirm callers keep them aligned.
    compounds : sequence of str
        Compound names; those missing from ``int_regions`` are skipped.
    int_regions : dict
        Maps a compound name to its list of ``(r1, r2)`` regions. ``r1``
        of all regions and ``r2`` of the last region are not used for labels.
    w : float
        Width of each bar.
    label_h : float
        Height of the dotted boundary markers and y position of their labels.
    show : bool
        Display the figure with ``plt.show()``.
    save : str or None
        If not None, path the figure is saved to.
    """
    
    # Compound separations
    # i counts the bars placed so far; sep holds x of the separator lines,
    # mid the x of each compound's tick label, and bounds the upper limits r[1]
    # of every region except the last one (drawn as boundary markers)
    sep = []
    mid = []
    labels = []
    bounds = []
    i = 0
    for k in compounds:
        if k in int_regions:
            l = len(int_regions[k])
            labels.append(k)
            sep.append(i+l - 0.5)
            mid.append(i + (l / 2) - 0.5)
            i += l
            bounds.append([r[1] for r in int_regions[k][:-1]])
    # No separator after the last compound
    sep = sep[:-1]
    
    # Per-compound mean / std of the relative error |X - y| / X, in percent
    err_avg = []
    err_std = []
    for xint, yint in zip(all_X_int, all_y_int):
        err_avg.append(np.mean(np.abs(xint-yint) / xint*100))
        err_std.append(np.std(np.abs(xint-yint) / xint*100))
    
    # One bar position per region, over all compounds
    x = np.array(range(i))
    
    xint = np.concatenate(all_X_int)
    yint = np.concatenate(all_y_int)
    
    # Shared y-limit with 10% headroom for the error labels
    M = max(np.max(xint), np.max(yint)) * 1.1
    
    # Figure width grows with the number of bars
    fig = plt.figure(figsize=(i*0.4,3))
    ax = fig.add_subplot(1,1,1)
    
    ax.bar(x-(w/2), xint, width=w)
    ax.bar(x+(w/2), yint, width=w)
    
    ax.legend(["100 kHz MAS", "PIPNet"], bbox_to_anchor=(0.,0.95), loc="upper left")
    
    for s in sep:
        ax.plot([s, s], [0., M], "k")
    
    # Boundary markers sit halfway between adjacent bars of a compound; the
    # extra increment after each compound skips the compound separator
    lx = 0.5
    for b in bounds:
        for bi in b:
            ax.plot([lx, lx], [0, label_h], "k:")
            ax.text(lx, label_h, f" {bi} ppm", rotation=90, ha="center", va="bottom", size=8)
            lx += 1
        lx += 1
    
    for em, es, m in zip(err_avg, err_std, mid):
        ax.text(m, M*0.99, f"{em:.0f}±{es:.0f}% error", ha="center", va="top", size=8)
    
    ax.set_xticks(mid)
    ax.set_xticklabels(labels)
    
    ax.set_ylabel("Relative integral")
    
    ax.set_ylim(0., M)
    ax.set_xlim(-0.5, i-0.5)
    
    if save is not None:
        plt.savefig(save)
    
    if show:
        plt.show()
    
    plt.close()
    
    return

In [16]:
def _closest_rate_indices(ws, targets):
    """Indices of the recorded spectra whose MAS rate is nearest to each target rate.

    Duplicates (several targets mapping to the same recorded rate) are removed,
    so no spectrum is fed to the network twice. The indices come back in
    ascending order.
    """
    return np.unique([np.argmin(np.abs(ws - w)) for w in targets])


def _predict_and_plot(X_torch, X_real, ppm, w_inds, xscale, save_prefix):
    """Run the network on the spectra selected by ``w_inds`` and save the figures.

    The real input channel is rescaled so that its maximum equals ``xscale``.
    Predictions are scaled so that their overall maximum is 0.5. Two kinds of
    figure are written:

    * one per network output step: ``{save_prefix}_{i+1}.pdf``, overlaid on
      the last row of ``X_real``;
    * a final one: ``{save_prefix}_final.pdf``, overlaid on all rows of
      ``X_real``.
    """
    idx = torch.as_tensor(np.asarray(w_inds), dtype=torch.long)
    # Advanced indexing returns a copy, so X_torch itself is not rescaled
    X_net = X_torch[:, idx]
    X_net[:, :, 0] /= torch.max(X_net[:, :, 0]) / xscale

    # Inference only: avoid building the autograd graph
    with torch.no_grad():
        y_pred, y_std, _ = net(X_net)

    y_pred = y_pred.detach().numpy()[0]
    y_std = y_std.detach().numpy()[0]

    scale = np.max(y_pred) / 0.5
    y_pred /= scale
    y_std /= scale

    for i, (yi_pred, yi_std) in enumerate(zip(y_pred, y_std)):
        plot_exp_vs_pred(ppm, X_real[-1:], yi_pred, yi_std, x_offset=0.1,
                         y0_pred=-0.1, show=False,
                         save=f"{save_prefix}_{i+1}.pdf")

    plot_exp_vs_pred(ppm, X_real, y_pred[-1], y_std[-1], x_offset=0.1,
                     y0_pred=-0.1, show=False,
                     save=f"{save_prefix}_final.pdf")


if eval_exp_4k:

    # NOTE(review): fdir does not depend on xscale, so with several entries in
    # x_scales_4k the figures of later scales overwrite those of earlier ones.
    # NOTE(review): evals_4k["high"], ["low"] and ["rand"] are not used here.
    fdir = fig_dir + "eval_lb/"
    os.makedirs(fdir, exist_ok=True)

    # MAS-rate grid for the "opt" evaluation (loop-invariant)
    opt_wrs = np.linspace(opt_range[0], opt_range[1], num=int((opt_range[1] - opt_range[0]) / dw) + 1)

    for xscale in x_scales_4k:

        for compound in exp_compounds_4k:
            r = exp_range_4k[compound]
            print(compound)
            ppm, ws, X_real, X_imag = extract_exp_topspin(exp_dir_4k, compound)

            # Crop to the point window of interest
            ppm = ppm[r[0]:r[1]]
            X_real = X_real[:, r[0]:r[1]]
            X_imag = X_imag[:, r[0]:r[1]]

            # Normalise each spectrum by the sum of its real part
            normalization = np.sum(X_real, axis=1)[:, np.newaxis]
            X_real /= normalization
            X_imag /= normalization

            Xi = X_imag if data_pars["encode_imag"] else None
            X_torch, ws = make_input(X_real, ws, data_pars, Xi=Xi)

            # Selected rates
            if evals_4k["sel"]:
                print("  Predictions on selected MAS rates...")
                w_inds = _closest_rate_indices(ws, sel_wrs)
                _predict_and_plot(X_torch, X_real, ppm, w_inds, xscale,
                                  f"{fdir}{compound}_sel_w_pred")

            # Optimized rates
            if evals_4k["opt"]:
                print("  Predictions on optimal MAS rates...")
                w_inds = _closest_rate_indices(ws, opt_wrs)
                print(ws[w_inds])
                _predict_and_plot(X_torch, X_real, ppm, w_inds, xscale,
                                  f"{fdir}{compound}_opt_w_pred")

            # All rates
            if evals_4k["all"]:
                print("  Predictions on all MAS rates...")
                _predict_and_plot(X_torch, X_real, ppm, np.argsort(ws), xscale,
                                  f"{fdir}{compound}_all_w_pred")

aspala
  Predictions on selected MAS rates...
  Predictions on optimal MAS rates...
[ 40000.  42000.  44000.  46000.  48000.  50000.  52000.  54000.  56000.
  58000.  60000.  62000.  64000.  66000.  68000.  70000.  72000.  74000.
  76000.  78000.  80000.  82000.  84000.  86000.  88000.  90000.  92000.
  94000.  96000.  98000. 100000.]
  Predictions on all MAS rates...
aspala_lb_100
  Predictions on selected MAS rates...
  Predictions on optimal MAS rates...
[ 40000.  42000.  44000.  46000.  48000.  50000.  52000.  54000.  56000.
  58000.  60000.  62000.  64000.  66000.  68000.  70000.  72000.  74000.
  76000.  78000.  80000.  82000.  84000.  86000.  88000.  90000.  92000.
  94000.  96000.  98000. 100000.]
  Predictions on all MAS rates...
aspala_lb_250
  Predictions on selected MAS rates...
  Predictions on optimal MAS rates...
[ 40000.  42000.  44000.  46000.  48000.  50000.  52000.  54000.  56000.
  58000.  60000.  62000.  64000.  66000.  68000.  70000.  72000.  74000.
  76000.  7800