In [1]:
import numpy as np
import os
import sys
import copy

import torch
torch.set_num_threads(os.cpu_count())
from torch import nn

import importlib
import json

import matplotlib as mpl
import matplotlib.pyplot as plt

from pipnet import data
from pipnet import model
from pipnet import utils

#import nmrglue as ng
import scipy
import scipy.io

# Run on the GPU when torch reports one as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Seed NumPy's global random number generator.
np.random.seed(123)

In [2]:
# 2D model to evaluate, the directory holding its files and the figure output directory.
model_name = "PIPNet_2022_11_21_2_layers"
model_dir = "../../data/2D/" + model_name + "/"
fig_dir = "../../figures/2D/" + model_name + "/"

# Training epoch whose saved network is loaded.
epoch = 250

# Requested `wr` values: 50000 to 100000 inclusive in steps of 1000, as floats.
# Each one is matched to the closest value available in a dataset.
sel_wrs = [float(wr) for wr in range(50000, 100001, 1000)]

# A requested value is only used when the closest available one is nearer than this.
dw_max = 500.

# ppm windows the spectra are cropped to along each axis.
ppm_range_x = [-5, 15.]
ppm_range_y = [-10., 20.]

# Experimental datasets (sub-directories of exp_dir).
compounds = [
    "20220216_ampicillin_baba_vmas_9",
    "20220215_tyrosine_baba_vmas_9",
]

exp_dir = "../../data/experimental_spectra/2D/"

# Per-compound metadata, keyed by dataset name.
fields = {name: 800. for name in compounds}
peaks = {name: [] for name in compounds}
peaks2 = {name: [] for name in compounds}
int_regions = {name: [] for name in compounds}

In [3]:
# The model directory must already exist: it holds the parameters and weights
# loaded further down, so there is nothing sensible to do without it.
if not os.path.exists(model_dir):
    raise ValueError(f"Unknown model: {model_name}")

# Output directories for the figures of this model. os.makedirs is used instead
# of os.mkdir so that missing parent directories are created as well, and
# exist_ok=True makes the cell safe to re-run (no check-then-create race).
os.makedirs(fig_dir, exist_ok=True)

fdir = fig_dir + "eval_experimental/"
os.makedirs(fdir, exist_ok=True)

In [4]:
# Read the parameters the 2D model was saved with; the "noise" entry is forced
# to zero before the network is built.
with open(f"{model_dir}model_pars.json", "r") as pars_file:
    model_pars = json.load(pars_file)
model_pars["noise"] = 0.

# Data parameters stored next to the model (used when preparing the inputs).
with open(f"{model_dir}data_pars.json", "r") as pars_file:
    data_pars = json.load(pars_file)

# Build the network, restore the weights saved at the selected epoch and
# switch it to evaluation mode.
net = model.ConvLSTMEnsemble(**model_pars)
saved_state = torch.load(f"{model_dir}epoch_{epoch}_network", map_location=device)
net.load_state_dict(saved_state)
net = net.eval()

In [5]:
%%time

# Run the 2D model on every experimental dataset; save two contour figures and
# a .mat file with the inputs and the prediction for each of them.
for compound in compounds:

    print(compound)

    # Load the dataset of this compound.
    ppm_x, ppm_y, hz_x, hz_y, ws, xrr, xri, xir, xii = utils.extract_2d_dataset(
        f"{exp_dir}{compound}/",
        1,
        30,
        load_imag=data_pars["encode_imag"],
    )

    # Indices of the points falling inside the requested ppm window on each axis.
    in_window_x = np.logical_and(ppm_x >= ppm_range_x[0], ppm_x <= ppm_range_x[1])
    in_window_y = np.logical_and(ppm_y >= ppm_range_y[0], ppm_y <= ppm_range_y[1])
    inds_x = np.nonzero(in_window_x)[0]
    inds_y = np.nonzero(in_window_y)[0]

    # Crop the axes accordingly.
    ppm_x, hz_x = ppm_x[inds_x], hz_x[inds_x]
    ppm_y, hz_y = ppm_y[inds_y], hz_y[inds_y]

    # Crop the spectra: last axis with the x indices, second axis with the y indices.
    # The three extra components are only cropped when they were loaded.
    xrr = xrr[:, :, inds_x][:, inds_y, :]
    if data_pars["encode_imag"]:
        xri, xir, xii = (
            part[:, :, inds_x][:, inds_y, :] for part in (xri, xir, xii)
        )

    X = utils.prepare_2d_input(xrr, ws, data_pars, xri=xri, xir=xir, xii=xii, xmax=1.)

    # For every requested value keep the index of the closest entry of `ws`,
    # skipping requests whose closest entry is dw_max or more away.
    wr_inds = [
        np.argmin(np.abs(ws - w))
        for w in sel_wrs
        if np.min(np.abs(ws - w)) < dw_max
    ]
    X = X[:, wr_inds]

    # Inference only: no gradients needed.
    with torch.no_grad():
        y_pred, _, _ = net(X)
    y_pred = y_pred.numpy()

    # Same figure twice: once with default contour levels, once with
    # level_min=0.01 (saved with the "_low_contour" suffix).
    for suffix, extra_kwargs in (("", {}), ("_low_contour", {"level_min": 0.01})):
        utils.plot_2d_iso_prediction(
            X[0].numpy(),
            y_pred[0],
            xvals=ppm_x,
            yvals=ppm_y,
            wr_factor=data_pars["wr_norm_factor"],
            all_steps=True,
            show=False,
            level_neg=True,
            equal_axes=False,
            xinv=True,
            yinv=True,
            lw=0.1,
            save=f"{fdir}{compound}{suffix}.pdf",
            **extra_kwargs,
        )

    # Export the cropped axes, index 0 along axis 2 of the first input sample,
    # the matched `ws` values and the last entry along axis 1 of the prediction.
    output = {
        "ppm_x": ppm_x,
        "ppm_y": ppm_y,
        "X": X[0, :, 0].numpy(),
        "wr": ws[wr_inds],
        "pred": y_pred[0, -1],
    }

    scipy.io.savemat(f"{fdir}{compound}_preds.mat", output)

20220216_ampicillin_baba_vmas_9
20220215_tyrosine_baba_vmas_9
CPU times: user 16min 10s, sys: 2min 41s, total: 18min 51s
Wall time: 6min 26s


# Predict with 1D model

In [6]:
# Location of the trained 1D model.
model_name_1d = "PIPNet_model"
model_dir_1d = f"../../trained_models/{model_name_1d}/"

# Model parameters of the 1D network; the "noise" entry is forced to zero.
with open(f"{model_dir_1d}model_pars.json", "r") as pars_file:
    model_pars_1d = json.load(pars_file)
model_pars_1d["noise"] = 0.

# Data parameters stored with the 1D model.
with open(f"{model_dir_1d}data_pars.json", "r") as pars_file:
    data_pars_1d = json.load(pars_file)

# Number of 1D slices passed to the 1D network at once.
batch_size_1d = 64

In [7]:
# Build the 1D network on the selected device, restore its saved weights and
# switch it to evaluation mode.
net_1d = model.ConvLSTMEnsemble(**model_pars_1d).to(device)
saved_state_1d = torch.load(f"{model_dir_1d}network", map_location=torch.device(device))
net_1d.load_state_dict(saved_state_1d)
net_1d = net_1d.eval()

In [8]:
%%time

# Apply the 1D model independently to every slice along axis 3 of the prepared
# 2D input ("rows"), stack the results back into a 2D prediction, then save
# the figures and a .mat file for each compound.
for compound in compounds:

    print(compound)

    # NOTE(review): loading and input preparation use the 2D model's `data_pars`
    # although the prediction is made with the 1D model; `data_pars_1d` is
    # loaded but not used here -- confirm this is intended.
    ppm_x, ppm_y, hz_x, hz_y, ws, xrr, xri, xir, xii = utils.extract_2d_dataset(
        f"{exp_dir}{compound}/", 1, 30, load_imag=data_pars["encode_imag"]
    )

    # Indices of the points falling inside the requested ppm window on each axis.
    inds_x = np.where(np.logical_and(ppm_x >= ppm_range_x[0], ppm_x <= ppm_range_x[1]))[0]
    inds_y = np.where(np.logical_and(ppm_y >= ppm_range_y[0], ppm_y <= ppm_range_y[1]))[0]

    # Crop the axes and the spectra (last axis <- x indices, second axis <- y indices).
    ppm_x, hz_x = ppm_x[inds_x], hz_x[inds_x]
    ppm_y, hz_y = ppm_y[inds_y], hz_y[inds_y]
    xrr = xrr[:, :, inds_x][:, inds_y, :]
    if data_pars["encode_imag"]:
        xri, xir, xii = (
            part[:, :, inds_x][:, inds_y, :] for part in (xri, xir, xii)
        )

    X = utils.prepare_2d_input(xrr, ws, data_pars, xri=xri, xir=xir, xii=xii, xmax=0.5)

    # Keep, for every requested value, the closest entry of `ws` if it is
    # nearer than dw_max.
    wr_inds = [np.argmin(np.abs(ws - w)) for w in sel_wrs if np.min(np.abs(ws - w)) < dw_max]
    X = X[:, wr_inds]

    # Swap axis 0 and axis 3 so that axis 3 becomes the batch axis, then drop
    # the singleton left at position 3 (assumes X.shape[0] == 1 -- TODO confirm).
    # Only that axis is squeezed: a bare squeeze() would also remove any other
    # axis that happens to have length 1.
    X_rows = X.transpose(0, 3).squeeze(3)

    # Ceiling division: `// batch_size_1d + 1` scheduled a trailing *empty*
    # batch whenever the number of rows was an exact multiple of the batch size.
    n_rows = (X_rows.shape[0] + batch_size_1d - 1) // batch_size_1d

    y_rows = []
    for i in range(n_rows):
        print(f"{i+1} / {n_rows}")
        # net_1d lives on `device`, so the batch has to be moved there too and
        # the result brought back to the CPU before the later .numpy() call.
        batch = X_rows[i*batch_size_1d:(i+1)*batch_size_1d].to(device)
        with torch.no_grad():
            y_pred, _, _ = net_1d(batch)
        y_rows.append(y_pred.cpu())

    # Stack the per-row predictions, re-insert a singleton axis and swap it to
    # the front so the result is indexed like the input `X`
    # (presumably (1, steps, y, x) -- TODO confirm).
    y_pred = torch.cat(y_rows)
    y_pred = y_pred.unsqueeze(2).transpose(0, 2).numpy()

    # Same figure twice: default contour levels, then level_min=0.01.
    for suffix, extra_kwargs in (("", {}), ("_low_contour", {"level_min": 0.01})):
        utils.plot_2d_iso_prediction(
            X[0].numpy(),
            y_pred[0],
            xvals=ppm_x,
            yvals=ppm_y,
            wr_factor=data_pars["wr_norm_factor"],
            all_steps=True,
            show=False,
            level_neg=True,
            equal_axes=False,
            xinv=True,
            yinv=True,
            lw=0.1,
            save=f"{fdir}rows_{compound}{suffix}.pdf",
            **extra_kwargs,
        )

    # Export the cropped axes, index 0 along axis 2 of the first input sample,
    # the matched `ws` values and the last entry along axis 1 of the prediction.
    output = {}
    output["ppm_x"] = ppm_x
    output["ppm_y"] = ppm_y
    output["X"] = X[0, :, 0].numpy()
    output["wr"] = ws[wr_inds]
    output["pred"] = y_pred[0, -1]

    scipy.io.savemat(f"{fdir}{compound}_preds_rows.mat", output)

20220216_ampicillin_baba_vmas_9
1 / 13
2 / 13
3 / 13
4 / 13
5 / 13
6 / 13
7 / 13
8 / 13
9 / 13
10 / 13
11 / 13
12 / 13
13 / 13
20220215_tyrosine_baba_vmas_9
1 / 17
2 / 17
3 / 17
4 / 17
5 / 17
6 / 17
7 / 17
8 / 17
9 / 17
10 / 17
11 / 17
12 / 17
13 / 17
14 / 17
15 / 17
16 / 17
17 / 17
CPU times: user 3h 44min 16s, sys: 30min 17s, total: 4h 14min 34s
Wall time: 59min 17s


In [9]:
%%time

# Apply the 1D model independently to every slice along axis 4 of the prepared
# 2D input ("columns"), stack the results back into a 2D prediction, then save
# the figures and a .mat file for each compound.
for compound in compounds:

    print(compound)

    # NOTE(review): loading and input preparation use the 2D model's `data_pars`
    # although the prediction is made with the 1D model; `data_pars_1d` is
    # loaded but not used here -- confirm this is intended.
    ppm_x, ppm_y, hz_x, hz_y, ws, xrr, xri, xir, xii = utils.extract_2d_dataset(
        f"{exp_dir}{compound}/", 1, 30, load_imag=data_pars["encode_imag"]
    )

    # Indices of the points falling inside the requested ppm window on each axis.
    inds_x = np.where(np.logical_and(ppm_x >= ppm_range_x[0], ppm_x <= ppm_range_x[1]))[0]
    inds_y = np.where(np.logical_and(ppm_y >= ppm_range_y[0], ppm_y <= ppm_range_y[1]))[0]

    # Crop the axes and the spectra (last axis <- x indices, second axis <- y indices).
    ppm_x, hz_x = ppm_x[inds_x], hz_x[inds_x]
    ppm_y, hz_y = ppm_y[inds_y], hz_y[inds_y]
    xrr = xrr[:, :, inds_x][:, inds_y, :]
    if data_pars["encode_imag"]:
        xri, xir, xii = (
            part[:, :, inds_x][:, inds_y, :] for part in (xri, xir, xii)
        )

    X = utils.prepare_2d_input(xrr, ws, data_pars, xri=xri, xir=xir, xii=xii, xmax=0.5)

    # Keep, for every requested value, the closest entry of `ws` if it is
    # nearer than dw_max.
    wr_inds = [np.argmin(np.abs(ws - w)) for w in sel_wrs if np.min(np.abs(ws - w)) < dw_max]
    X = X[:, wr_inds]

    # Swap axis 0 and axis 4 so that axis 4 becomes the batch axis, then drop
    # the singleton left at position 4 (assumes X.shape[0] == 1 -- TODO confirm).
    # Only that axis is squeezed: a bare squeeze() would also remove any other
    # axis that happens to have length 1.
    X_cols = X.transpose(0, 4).squeeze(4)

    # Ceiling division: `// batch_size_1d + 1` scheduled a trailing *empty*
    # batch whenever the number of columns was an exact multiple of the batch size.
    n_cols = (X_cols.shape[0] + batch_size_1d - 1) // batch_size_1d

    y_cols = []
    for i in range(n_cols):
        print(f"{i+1} / {n_cols}")
        # net_1d lives on `device`, so the batch has to be moved there too and
        # the result brought back to the CPU before the later .numpy() call.
        batch = X_cols[i*batch_size_1d:(i+1)*batch_size_1d].to(device)
        with torch.no_grad():
            y_pred, _, _ = net_1d(batch)
        y_cols.append(y_pred.cpu())

    # Stack the per-column predictions, append a singleton axis and swap it to
    # the front so the result is indexed like the input `X`
    # (presumably (1, steps, y, x) -- TODO confirm).
    y_pred = torch.cat(y_cols)
    y_pred = y_pred.unsqueeze(3).transpose(0, 3).numpy()

    # Same figure twice: default contour levels, then level_min=0.01.
    for suffix, extra_kwargs in (("", {}), ("_low_contour", {"level_min": 0.01})):
        utils.plot_2d_iso_prediction(
            X[0].numpy(),
            y_pred[0],
            xvals=ppm_x,
            yvals=ppm_y,
            wr_factor=data_pars["wr_norm_factor"],
            all_steps=True,
            show=False,
            level_neg=True,
            equal_axes=False,
            xinv=True,
            yinv=True,
            lw=0.1,
            save=f"{fdir}cols_{compound}{suffix}.pdf",
            **extra_kwargs,
        )

    # Export the cropped axes, index 0 along axis 2 of the first input sample,
    # the matched `ws` values and the last entry along axis 1 of the prediction.
    output = {}
    output["ppm_x"] = ppm_x
    output["ppm_y"] = ppm_y
    output["X"] = X[0, :, 0].numpy()
    output["wr"] = ws[wr_inds]
    output["pred"] = y_pred[0, -1]

    scipy.io.savemat(f"{fdir}{compound}_preds_cols.mat", output)

20220216_ampicillin_baba_vmas_9
1 / 13
2 / 13
3 / 13
4 / 13
5 / 13
6 / 13
7 / 13
8 / 13
9 / 13
10 / 13
11 / 13
12 / 13
13 / 13
20220215_tyrosine_baba_vmas_9
1 / 13
2 / 13
3 / 13
4 / 13
5 / 13
6 / 13
7 / 13
8 / 13
9 / 13
10 / 13
11 / 13
12 / 13
13 / 13
CPU times: user 3h 47min 56s, sys: 30min 39s, total: 4h 18min 35s
Wall time: 56min 47s


In [10]:
# Datasets read from f"{exp_dir}{name}.mat" by the evaluation loop that follows.
compounds2 = ["ampicillin_unshear_gandl_220307", "tyrosine_unshear_gls_220302_0toInf_0to5"]

In [11]:
%%time

# Run the 2D model on the datasets stored as .mat files; save two contour
# figures and a .mat file with the inputs and the prediction for each of them.
for compound in compounds2:

    print(compound)

    # The spectra are stored under "newsignal2"; its last axis is moved to the
    # front before the input is prepared. The "wr" entry is read as a float vector.
    m = scipy.io.loadmat(f"{exp_dir}{compound}.mat")
    xrr = np.transpose(m["newsignal2"], axes=(2, 0, 1))
    ws = m["wr"][0].astype(float)

    X = utils.prepare_2d_input(xrr, ws, data_pars, xmax=1.)

    # Keep, for every requested value, the closest entry of `ws` if it is
    # nearer than dw_max.
    wr_inds = [np.argmin(np.abs(ws - w)) for w in sel_wrs if np.min(np.abs(ws - w)) < dw_max]
    X = X[:, wr_inds]

    # Inference only: no gradients needed.
    with torch.no_grad():
        y_pred, _, _ = net(X)
    y_pred = y_pred.numpy()

    # Same figure twice: default contour levels, then level_min=0.01.
    # No axis values are passed because none are loaded for these datasets.
    for suffix, extra_kwargs in (("", {}), ("_low_contour", {"level_min": 0.01})):
        utils.plot_2d_iso_prediction(
            X[0].numpy(),
            y_pred[0],
            wr_factor=data_pars["wr_norm_factor"],
            all_steps=True,
            show=False,
            level_neg=True,
            equal_axes=False,
            lw=0.1,
            save=f"{fdir}{compound}{suffix}.pdf",
            **extra_kwargs,
        )

    # "ppm_x" / "ppm_y" are deliberately NOT exported here. This loop never
    # defines them: the values previously written were leftovers from the last
    # compound processed by an earlier cell (a different, cropped dataset), so
    # they did not describe this data, and the cell failed with a NameError on
    # a fresh kernel. If the axes of these datasets are needed downstream, load
    # them from the .mat file and add them explicitly.
    output = {}
    output["X"] = X[0, :, 0].numpy()
    output["wr"] = ws[wr_inds]
    output["pred"] = y_pred[0, -1]

    scipy.io.savemat(f"{fdir}{compound}_preds.mat", output)

ampicillin_unshear_gandl_220307
tyrosine_unshear_gls_220302_0toInf_0to5
CPU times: user 1min 8s, sys: 5.67 s, total: 1min 14s
Wall time: 31.5 s
