In [1]:
import numpy as np
import os
import sys
import copy

import torch
torch.set_num_threads(os.cpu_count())
from torch import nn

import importlib
import json

import matplotlib as mpl
import matplotlib.pyplot as plt

from pipnet import data
from pipnet import model
from pipnet import utils

#import nmrglue as ng
import scipy
import scipy.io

if torch.cuda.is_available():
    # A CUDA device is visible to PyTorch: run the network there.
    device = "cuda"
else:
    device = "cpu"

# Seed NumPy's global (legacy) RNG so any random draws in this notebook repeat.
np.random.seed(123)

In [2]:
# --- Model to evaluate --------------------------------------------------------
model_name = "PIPNet_2022_11_15_2_layers"
model_dir = f"../../data/2D/{model_name}/"
fig_dir = f"../../figures/2D/{model_name}/"

# Training epoch whose checkpoint is loaded.
epoch = 250

# Requested rates: 50000. to 100000. in steps of 1000. (51 float values).
# Each one is matched to the closest entry of the dataset's `ws` array, and
# kept only if that entry lies strictly within `dw_max` of it.
sel_wrs = [float(wr) for wr in range(50_000, 100_001, 1_000)]

dw_max = 500.

# Chemical-shift windows (ppm, inclusive) to which the x and y axes are cropped.
ppm_range_x = [-2, 12.]
ppm_range_y = [-7., 20.]

# --- Experimental datasets ----------------------------------------------------
compounds = ["20220216_ampicillin_baba_vmas_9", "20220215_tyrosine_baba_vmas_9"]

exp_dir = "../../data/experimental_spectra/2D/"

# Per-compound value 800. — presumably the spectrometer field in MHz (TODO confirm).
fields = {
    "20220216_ampicillin_baba_vmas_9": 800.,
    "20220215_tyrosine_baba_vmas_9": 800.,
}

# Per-compound placeholders, all empty for now. Every entry is its own list object.
peaks = {compound: [] for compound in compounds}
peaks2 = {compound: [] for compound in compounds}
int_regions = {compound: [] for compound in compounds}

In [3]:
# Fail early if the trained model directory is missing.
if not os.path.exists(model_dir):
    raise ValueError(f"Unknown model: {model_name}")

# Output directory for the evaluation figures. os.makedirs also creates
# fig_dir and any missing parents; os.mkdir would raise FileNotFoundError
# when e.g. ../../figures/2D/ does not exist yet. exist_ok=True keeps the
# cell idempotent on re-runs.
fdir = fig_dir + "eval_experimental/"
os.makedirs(fdir, exist_ok=True)

In [4]:
# Hyper-parameters saved alongside the trained network.
with open(model_dir + "model_pars.json") as fh:
    model_pars = json.load(fh)
# The stored "noise" setting is overridden with 0. for this evaluation.
model_pars["noise"] = 0.

# Data-pipeline settings used at training time (e.g. "encode_imag", "wr_norm_factor").
with open(model_dir + "data_pars.json") as fh:
    data_pars = json.load(fh)

In [5]:
# Rebuild the architecture from the stored hyper-parameters, then load the
# weights saved at `epoch`, mapping all tensors onto `device`.
checkpoint_path = f"{model_dir}epoch_{epoch}_network"
net = model.ConvLSTMEnsemble(**model_pars)
net = net.to(device)
net.load_state_dict(torch.load(checkpoint_path, map_location=torch.device(device)))
# Evaluation mode; assigned (not a bare expression) so the cell displays nothing.
net = net.eval()

In [6]:
%%time

for compound in compounds:

    print(compound)

    # Load the experimental 2D dataset. The imaginary parts are only requested
    # when the network was trained with them (data_pars["encode_imag"]).
    # Otherwise xri/xir/xii are passed through untouched, presumably as None
    # (TODO confirm against utils.extract_2d_dataset).
    ppm_x, ppm_y, hz_x, hz_y, ws, xrr, xri, xir, xii = utils.extract_2d_dataset(
        f"{exp_dir}{compound}/", 1, 30, load_imag=data_pars["encode_imag"]
    )

    # Crop both axes to the configured ppm windows (bounds inclusive).
    inds_x = np.where(np.logical_and(ppm_x >= ppm_range_x[0], ppm_x <= ppm_range_x[1]))[0]
    inds_y = np.where(np.logical_and(ppm_y >= ppm_range_y[0], ppm_y <= ppm_range_y[1]))[0]

    # x indexes the last axis of the data arrays.
    ppm_x = ppm_x[inds_x]
    hz_x = hz_x[inds_x]
    xrr = xrr[:, :, inds_x]
    if data_pars["encode_imag"]:
        xri = xri[:, :, inds_x]
        xir = xir[:, :, inds_x]
        xii = xii[:, :, inds_x]

    # y indexes the middle axis of the data arrays.
    ppm_y = ppm_y[inds_y]
    hz_y = hz_y[inds_y]
    xrr = xrr[:, inds_y, :]
    if data_pars["encode_imag"]:
        xri = xri[:, inds_y, :]
        xir = xir[:, inds_y, :]
        xii = xii[:, inds_y, :]

    X = utils.prepare_2d_input(xrr, ws, data_pars, xri=xri, xir=xir, xii=xii, xmax=0.5)

    # For each requested rate, keep the index of the closest entry of ws,
    # but only if it lies strictly within dw_max of the request.
    wr_inds = []
    for w in sel_wrs:
        dw = np.abs(ws - w)
        i_best = np.argmin(dw)
        if dw[i_best] < dw_max:
            wr_inds.append(i_best)
    if not wr_inds:
        raise ValueError(
            f"{compound}: no entry of ws lies within {dw_max} of any value in sel_wrs"
        )
    X = X[:, wr_inds]

    # The network lives on `device`, so the input must be moved there.
    # Predictions come back to the CPU because .numpy() fails on CUDA tensors.
    # X itself stays on the CPU for plotting.
    with torch.no_grad():
        y_pred, _, _ = net(X.to(device))
    y_pred = y_pred.cpu().numpy()

    utils.plot_2d_iso_prediction(
        X[0].numpy(),
        y_pred[0],
        xvals=ppm_x,
        yvals=ppm_y,
        wr_factor=data_pars["wr_norm_factor"],
        all_steps=True,
        show=False,
        level_neg=True,
        equal_axes=False,
        xinv=True,
        yinv=True,
        lw=0.5,
        save=f"{fdir}sel_wr_{compound}.pdf"
    )

20220216_ampicillin_baba_vmas_9
20220215_tyrosine_baba_vmas_9
CPU times: user 14min 20s, sys: 2min 27s, total: 16min 48s
Wall time: 3min 10s


In [7]:
# Datasets stored as MATLAB .mat files directly in exp_dir. The names
# suggest unsheared data (TODO confirm).
compounds2 = [
    "ampicillin_unshear_gandl_220307",
    "tyrosine_unshear_gls_220302_0toInf_0to5",
]

In [8]:
%%time

for compound in compounds2:

    print(compound)

    # Load the .mat file and move the last axis of "newsignal2" to the front.
    # Presumably this is the rate axis, matching the layout prepare_2d_input
    # expects (TODO confirm).
    m = scipy.io.loadmat(f"{exp_dir}{compound}.mat")
    xrr = np.transpose(m["newsignal2"], axes=(2, 0, 1))
    ws = m["wr"][0].astype(float)

    X = utils.prepare_2d_input(xrr, ws, data_pars, xmax=0.5)

    # For each requested rate, keep the index of the closest entry of ws,
    # but only if it lies strictly within dw_max of the request.
    wr_inds = []
    for w in sel_wrs:
        dw = np.abs(ws - w)
        i_best = np.argmin(dw)
        if dw[i_best] < dw_max:
            wr_inds.append(i_best)
    if not wr_inds:
        raise ValueError(
            f"{compound}: no entry of ws lies within {dw_max} of any value in sel_wrs"
        )
    X = X[:, wr_inds]

    # The network lives on `device`, so the input must be moved there.
    # Predictions come back to the CPU because .numpy() fails on CUDA tensors.
    with torch.no_grad():
        y_pred, _, _ = net(X.to(device))
    y_pred = y_pred.cpu().numpy()

    # No xvals/yvals are passed here: the .mat file does not provide ppm axes.
    utils.plot_2d_iso_prediction(
        X[0].numpy(),
        y_pred[0],
        wr_factor=data_pars["wr_norm_factor"],
        all_steps=True,
        show=False,
        level_neg=True,
        equal_axes=False,
        save=f"{fdir}sel_wr_{compound}.pdf"
    )

ampicillin_unshear_gandl_220307
tyrosine_unshear_gls_220302_0toInf_0to5
CPU times: user 48.3 s, sys: 3.22 s, total: 51.6 s
Wall time: 13.7 s
