In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib import font_manager
from concurrent.futures import ThreadPoolExecutor
import matplotlib.image as mpimg

import numpy as np

from einops import rearrange, repeat

from compute import *

# hype parameter
model_name = 'EdSr'
control_name = 'MD'
""" Temp Press PotEng KinEng Enthalpy E_vdwl E_coul E_pair E_bond E_angle E_dihed E_long E_tail E_mol Ecouple Econserve TotEng Lx Ly Lz"""
energy_unit = "kcal $\\cdot$ mol$^{-1}$"
press_unit = "ATM"
temperature = "K"
distance_unit = "Angstrom"
time_unit = "fs"
thermo_style_unit = {
    'temp'     : f"Kelvin ({temperature})",    'Temp'     : f'Kelvin ({temperature})',
    'press'    : f'ATMosphere ({press_unit})', 'Press'    : f'ATMosphere ({press_unit})',
    "pe"       : f'energy ({energy_unit})',    'PotEng'   : f'energy ({energy_unit})',
    "ke"       : f'energy ({energy_unit})',    'KinEng'   : f'energy ({energy_unit})',
    "enthalpy" : f'energy ({energy_unit})',    'Enthalpy' : f'energy ({energy_unit})',
    "evdwl"    : f'energy ({energy_unit})',    'E_vdwl'   : f'energy ({energy_unit})',
    "ecoul"    : f'energy ({energy_unit})',    'E_coul'   : f'energy ({energy_unit})',
    "epair"    : f'energy ({energy_unit})',    'E_pair'   : f'energy ({energy_unit})',
    "ebond"    : f'energy ({energy_unit})',    'E_bond'   : f'energy ({energy_unit})',
    "eangle"   : f'energy ({energy_unit})',    'E_angle'  : f'energy ({energy_unit})',
    "edihed"   : f'energy ({energy_unit})',    'E_dihed'  : f'energy ({energy_unit})',
    "eimp"     : f'energy ({energy_unit})',
    "elong"    : f'energy ({energy_unit})',    'E_long'   : f'energy ({energy_unit})',
    "etail"    : f'energy ({energy_unit})',    'E_tail'   : f'energy ({energy_unit})',
    "emol"     : f'energy ({energy_unit})',    'E_mol'    : f'energy ({energy_unit})',
    "ecouple"  : f'energy ({energy_unit})',    'Ecouple'  : f'energy ({energy_unit})',
    "econserve": f'energy ({energy_unit})',    'Econserve': f'energy ({energy_unit})',
    "etotal"   : f'energy ({energy_unit})',    'TotEng'   : f'energy ({energy_unit})',
    'lx'       : f'length ({distance_unit})',  'Lx'       : f'length ({distance_unit})',
    'ly'       : f'length ({distance_unit})',  'Ly'       : f'length ({distance_unit})',
    'lz'       : f'length ({distance_unit})',  'Lz'       : f'length ({distance_unit})',
}

title_mapping = {
     'temp'     : r'Temperature',                                                'Temp'     : r'Temperature',
     'press'    : r'Pressure',                                                   'Press'    : r'Pressure',
     "pe"       : r'Potential energy',                                           'PotEng'   : r'potential energy',
     "ke"       : r'Kinetic energy',                                             'KinEng'   : r'Kinetic energy',
     "enthalpy" : r'Total energy (pe + ke)',                                     'Enthalpy' : r'Total energy (pe + ke)',
     "evdwl"    : r'Van der Waals pairwise energy',                              'E_vdwl'   : r'Van der Waals pairwise energy',
     "ecoul"    : r'Coulombic pairwise energy',                                  'E_coul'   : r'Coulombic pairwise energy',
     "epair"    : r'Pairwise energy',                                            'E_pair'   : r'Pairwise energy',
     "ebond"    : r'Bond energy',                                                'E_bond'   : r'Bond energy',
     "eangle"   : r'Angle energy',                                               'E_angle'  : r'Angle energy',
     "edihed"   : r'Dihedral energy',                                            'E_dihed'  : r'Dihedral energy',
     "eimp"     : r'Improper energy',
     "elong"    : r'Long-range kspace energy',                                   'E_long'   : r'Long-range kspace energy',
     "etail"    : r'Van der Waals energy long-range tail correction',            'E_tail'   : r'Van der Waals energy long-range tail correction',
     "emol"     : r'Intramolecular energy',                                      'E_mol'    : r'Intramolecular energy',
     "ecouple"  : r'Cumulative energy change due to thermo/baro statting fixes', 'Ecouple'  : r'Cumulative energy change due to thermo/baro statting fixes',
     "econserve": r'Etotal + ecouple',                                           'Econserve': r'Etotal + ecouple',
     "etotal"   : r'Total energy',                                               'TotEng'   : r'Total energy',
     'lx'       : r'Length of x-axis',                                           'Lx'       : r'Length of x-axis',
     'ly'       : r'Length of y-axis',                                           'Ly'       : r'Length of y-axis',
     'lz'       : r'Length of z-axis',                                           'Lz'       : r'Length of z-axis',
     'Rg'       : r'Radius of gyration',                                         'RG'       : r'Radius of gyration',
     'rmsd'     : r'RMSD',                                                       'RMSD'     : r'RMSD',
}

maxIter: int   = 500

np.set_printoptions(threshold = np.inf)

fontsize = font_manager.FontProperties(size = 10)
tick_fontsize = font_manager.FontProperties(size = 10)
title_fontsize = font_manager.FontProperties(size = 11)
legned_fontsize = font_manager.FontProperties(size = 8.5)

row, col = 3, 3

fig = plt.figure(figsize=[10*col,20*row], dpi=256)

pic_idx = 0

In [None]:
# Sampling settings shared by all three runs (they select the data directories).
intv = 20
basis = 0.01
start = 1
end = 10000
group = "type 3:18"

# Each entry is (run directory, frames file, global-variable file) for `extraction`.
benchmark_params = [
    f'data/benchmark_nve_basis{basis}_intv{intv}/',
    f'frames{start}_{end}.npz',
    'GlobalVariable.npz',
]
control_params = [
    f'data/control_nve_basis{basis}_intv{intv}/',
    f'frames{start}_{end}.npz',
    'GlobalVariable.npz',
]
taylor_params = [
    f'data/taylor_nve_basis{basis}_intv{intv}_iter500/',
    f'frames{start}_{end}.npz',
    'GlobalVariable.npz',
]

# Load the three runs concurrently; leaving the `with` block waits for all of them.
# mass shape: (natoms,), id shape: (natoms,),  x shape: (ntrajs, natoms, 3), v shape: (ntrajs, natoms, 3)
with ThreadPoolExecutor(max_workers = 3) as executor:
    futures = [
        executor.submit(extraction, *benchmark_params, filter = None),
        executor.submit(extraction, *control_params, filter = None),
        executor.submit(extraction, *taylor_params , filter = None),
    ]

benchmark_x, benchmark_v, benchmark_delta_t, benchmark_mass, benchmark_atype, benchmark_id, benchmark_boundary, benchmark_heads, benchmark_ppties = futures[0].result()
control_x, control_v, control_delta_t, control_mass, control_atype, control_id, control_boundary, control_heads, control_ppties = futures[1].result()
taylor_x, taylor_v, taylor_delta_t, taylor_mass, taylor_atype, taylor_id, taylor_boundary, taylor_heads, taylor_ppties = futures[2].result()

# Later cells subtract the runs frame by frame and plot them against `times`,
# so all three runs must contain the same number of frames.
n_frames = len(benchmark_x)
assert len(control_x) == n_frames and len(taylor_x) == n_frames, (
    f"frame counts differ: benchmark={n_frames}, control={len(control_x)}, taylor={len(taylor_x)}"
)

# Time of each stored frame (fs); frame i is i * intv * basis after the first frame.
# Derived from the loaded frame count rather than `end`, which only matched when start == 1.
times = np.arange(n_frames)*intv*basis

In [None]:
# Structure panel: a static rendering of the simulated system, shown without axes.
pic_idx += 1

# Image of the structure, read from the parent directory (name kept for other cells).
struture = mpimg.imread('../structure.png')

ax = plt.subplot(row, col, pic_idx)

# Hide the frame and remove all tick marks so only the image is visible.
plt.axis('off')
plt.xticks([])
plt.yticks([])

plt.imshow(struture, interpolation = 'bicubic', aspect = 'equal')

In [None]:
# coordinate figure
# Per-frame mean absolute coordinate deviation of the control (bc) and model (bt)
# runs from the benchmark run.

# traj_abs_diff is a project helper; it receives each run's boundary, presumably so
# periodic images are handled -- TODO confirm in compute.py.
xdiff_bc = traj_abs_diff(benchmark_x, benchmark_boundary, control_x, control_boundary)
xdiff_bt = traj_abs_diff(benchmark_x, benchmark_boundary, taylor_x, taylor_boundary)

# Collapse the last two axes (presumably atoms and xyz) to one value per frame.
frame_xdiff_bc = np.mean(xdiff_bc, axis = (-1, -2))
frame_xdiff_bt = np.mean(xdiff_bt, axis = (-1, -2))

# ! Picture idx
pic_idx += 1

# Smooth by averaging consecutive groups of `skip` frames (the frame count must be a
# multiple of `skip`). Each group mean is stamped at the time of the group's last
# frame, and the frame-0 value is prepended at t = 0.
skip = 100
if skip > 1:
    gmean_xdiff_bc = np.pad(np.mean(frame_xdiff_bc.reshape(-1, skip), axis = -1), pad_width = (1, 0), mode = 'constant', constant_values = frame_xdiff_bc[0])
    gmean_xdiff_bt = np.pad(np.mean(frame_xdiff_bt.reshape(-1, skip), axis = -1), pad_width = (1, 0), mode = 'constant', constant_values = frame_xdiff_bt[0])
    group_times = np.pad(times[::skip] + (skip - 1)*intv*basis, pad_width = (1, 0), mode = 'constant', constant_values = 0)
else:
    gmean_xdiff_bc = frame_xdiff_bc
    gmean_xdiff_bt = frame_xdiff_bt
    group_times = times

ax = plt.subplot(row, col, pic_idx)

# Solid curves: binned errors. Dashed lines: mean over all (unbinned) frames.
# The mathtext backslash is escaped ("\\mathcal") so the f-string does not rely on
# the invalid "\m" escape sequence (SyntaxWarning since Python 3.12).
ax.plot(group_times, gmean_xdiff_bt, label = f"{model_name} $\\mathcal{{D}}$", linewidth = 2)
ax.plot(group_times, gmean_xdiff_bc, label = f"{control_name} $\\mathcal{{D}}$", linewidth = 1)
ax.axhline(y = np.mean(frame_xdiff_bt), color = 'g', linestyle = '--', linewidth = 2, label = f'$\\mathrm{{MAE_{{{model_name}}}}}$')
ax.axhline(y = np.mean(frame_xdiff_bc), color = 'r', linestyle = '--', linewidth = 1, label = f'$\\mathrm{{MAE_{{{control_name}}}}}$')

# The unit lives in the title, so no y label is set.
ax.set_xlabel('t (fs)', fontproperties = fontsize)
ax.set_title(r"Coordinate ($\mathrm{\AA}$)", loc = 'center', fontproperties = title_fontsize)
ax.ticklabel_format(axis = 'x', style = 'scientific', scilimits = (0, 2), useMathText = True, useLocale = True)
ax.ticklabel_format(axis = 'y', style = 'scientific', scilimits = (0, 2), useMathText = True, useLocale = True)
ax.tick_params(axis = 'both', labelsize = tick_fontsize.get_size())
ax.figure.subplots_adjust(top=0.953,bottom=0.091,left=0.037,right=0.992,hspace=0.554,wspace=0.13)
ax.legend(fontsize = legned_fontsize.get_size(), loc = 'lower right', ncol = 2)

In [None]:
# velocity figure
# Per-frame mean absolute velocity deviation of the control (bc) and model (bt)
# runs from the benchmark run (plain element-wise difference, no boundary handling).
vdiff_bc = np.abs(benchmark_v - control_v)
vdiff_bt = np.abs(benchmark_v - taylor_v)

# Collapse the last two axes (atoms and xyz, per the loading cell's shape comment)
# to one value per frame.
frame_vdiff_bc = np.mean(vdiff_bc, axis = (-1, -2))
frame_vdiff_bt = np.mean(vdiff_bt, axis = (-1, -2))

# ! Picture idx
pic_idx += 1

# Smooth by averaging consecutive groups of `skip` frames (the frame count must be a
# multiple of `skip`). Each group mean is stamped at the time of the group's last
# frame, and the frame-0 value is prepended at t = 0.
skip = 100
if skip > 1:
    gmean_vdiff_bc = np.pad(np.mean(frame_vdiff_bc.reshape(-1, skip), axis = -1), pad_width = (1, 0), mode = 'constant', constant_values = frame_vdiff_bc[0])
    gmean_vdiff_bt = np.pad(np.mean(frame_vdiff_bt.reshape(-1, skip), axis = -1), pad_width = (1, 0), mode = 'constant', constant_values = frame_vdiff_bt[0])
    group_times = np.pad(times[::skip] + (skip - 1)*intv*basis, pad_width = (1, 0), mode = 'constant', constant_values = 0)
else:
    gmean_vdiff_bc = frame_vdiff_bc
    gmean_vdiff_bt = frame_vdiff_bt
    group_times = times

ax = plt.subplot(row, col, pic_idx)

# Solid curves: binned errors. Dashed lines: mean over all (unbinned) frames.
# The mathtext backslash is escaped ("\\mathcal") so the f-string does not rely on
# the invalid "\m" escape sequence (SyntaxWarning since Python 3.12).
ax.plot(group_times, gmean_vdiff_bt, label = f"{model_name} $\\mathcal{{V}}$", linewidth = 2)
ax.plot(group_times, gmean_vdiff_bc, label = f"{control_name} $\\mathcal{{V}}$", linewidth = 1)
ax.axhline(y = np.mean(frame_vdiff_bt), color = 'g', linestyle = '--', linewidth = 2, label = f'$\\mathrm{{MAE_{{{model_name}}}}}$')
ax.axhline(y = np.mean(frame_vdiff_bc), color = 'r', linestyle = '--', linewidth = 1, label = f'$\\mathrm{{MAE_{{{control_name}}}}}$')

# The unit lives in the title, so no y label is set.
ax.set_xlabel('t (fs)', fontproperties = fontsize)
ax.set_title(r"Velocity ($\mathrm{\AA} \cdot \mathrm{fs}^{-1}$)", loc = 'center', fontproperties = title_fontsize)
ax.ticklabel_format(axis = 'x', style = 'scientific', scilimits = (0, 2), useMathText = True, useLocale = True)
ax.ticklabel_format(axis = 'y', style = 'scientific', scilimits = (0, 2), useMathText = True, useLocale = True)
ax.tick_params(axis = 'both', labelsize = tick_fontsize.get_size())
ax.figure.subplots_adjust(top=0.953,bottom=0.091,left=0.037,right=0.992,hspace=0.554,wspace=0.13)
ax.legend(fontsize = legned_fontsize.get_size(), loc = 'lower right', ncol = 2)

In [None]:
# RMSD panel: absolute deviation of each run's RMSD curve from the benchmark's.
ppty = 'rmsd'

# The third argument (0) presumably selects frame 0 as the RMSD reference -- TODO confirm.
benchmark_RMSD, control_RMSD, taylor_RMSD = (
    compute_RMSD(run_x, run_mass, 0, run_boundary)
    for run_x, run_mass, run_boundary in (
        (benchmark_x, benchmark_mass, benchmark_boundary),
        (control_x, control_mass, control_boundary),
        (taylor_x, taylor_mass, taylor_boundary),
    )
)

# bc = |benchmark - control|, bt = |benchmark - taylor (model)|
rmsd_diff_bc, rmsd_diff_bt = (
    np.fabs(benchmark_RMSD - run_RMSD) for run_RMSD in (control_RMSD, taylor_RMSD)
)

pic_idx += 1   # ! Picture idx

skip = 200
if skip <= 1:
    grmsd_diff_bc, grmsd_diff_bt, group_times = rmsd_diff_bc, rmsd_diff_bt, times
else:
    # Average consecutive groups of `skip` frames, stamp each group at its last frame's
    # time, and prepend the frame-0 value at t = 0.
    grmsd_diff_bc = np.pad(rmsd_diff_bc.reshape(-1, skip).mean(axis = -1),
                           (1, 0), mode = 'constant', constant_values = rmsd_diff_bc[0])
    grmsd_diff_bt = np.pad(rmsd_diff_bt.reshape(-1, skip).mean(axis = -1),
                           (1, 0), mode = 'constant', constant_values = rmsd_diff_bt[0])
    group_times = np.pad(times[::skip] + (skip - 1)*intv*basis,
                         (1, 0), mode = 'constant', constant_values = 0)

ax = plt.subplot(row, col, pic_idx)

# Binned curves first, then the all-frame means as dashed reference lines.
ax.plot(group_times, grmsd_diff_bt, label = f"{model_name}", linewidth = 2)
ax.plot(group_times, grmsd_diff_bc, label = f"{control_name}", linewidth = 1)
ax.axhline(y = np.mean(rmsd_diff_bt), color = 'g', linestyle = '--', linewidth = 2, label = f'$\\mathrm{{MAE_{{{model_name}}}}}$')
ax.axhline(y = np.mean(rmsd_diff_bc), color = 'r', linestyle = '--', linewidth = 1, label = f'$\\mathrm{{MAE_{{{control_name}}}}}$')

ax.set_xlabel(f't (fs)', fontproperties = fontsize)
ax.set_title(f"{title_mapping[ppty]} ($\\mathrm{{\\AA}}$)", loc = 'center', fontproperties = title_fontsize)
for axis_name in ('x', 'y'):
    ax.ticklabel_format(axis = axis_name, style = 'scientific', scilimits = (0, 2), useMathText = True, useLocale = True)
ax.tick_params(axis = 'both', labelsize = tick_fontsize.get_size())
ax.set_yticks([0.0, 0.25, 0.5])
ax.figure.subplots_adjust(top=0.953,bottom=0.091,left=0.037,right=0.992,hspace=0.554,wspace=0.13)
ax.legend(fontsize = legned_fontsize.get_size(), loc = 'upper left', ncol = 2)

In [None]:
# Radius-of-gyration panel: absolute deviation of each run's Rg curve from the benchmark's.
ppty = 'Rg'

benchmark_Rg, control_Rg, taylor_Rg = (
    compute_RG(run_x, run_mass, run_boundary)
    for run_x, run_mass, run_boundary in (
        (benchmark_x, benchmark_mass, benchmark_boundary),
        (control_x, control_mass, control_boundary),
        (taylor_x, taylor_mass, taylor_boundary),
    )
)

# bc = |benchmark - control|, bt = |benchmark - taylor (model)|
Rg_diff_bc, Rg_diff_bt = (
    np.fabs(benchmark_Rg - run_Rg) for run_Rg in (control_Rg, taylor_Rg)
)

pic_idx += 1   # ! Picture idx

skip = 200
if skip <= 1:
    gRg_diff_bc, gRg_diff_bt, group_times = Rg_diff_bc, Rg_diff_bt, times
else:
    # Average consecutive groups of `skip` frames, stamp each group at its last frame's
    # time, and prepend the frame-0 value at t = 0.
    gRg_diff_bc = np.pad(Rg_diff_bc.reshape(-1, skip).mean(axis = -1),
                         (1, 0), mode = 'constant', constant_values = Rg_diff_bc[0])
    gRg_diff_bt = np.pad(Rg_diff_bt.reshape(-1, skip).mean(axis = -1),
                         (1, 0), mode = 'constant', constant_values = Rg_diff_bt[0])
    group_times = np.pad(times[::skip] + (skip - 1)*intv*basis,
                         (1, 0), mode = 'constant', constant_values = 0)

ax = plt.subplot(row, col, pic_idx)

# Binned curves first, then the all-frame means as dashed reference lines.
ax.plot(group_times, gRg_diff_bt, label = f"{model_name}", linewidth = 2)
ax.plot(group_times, gRg_diff_bc, label = f"{control_name}", linewidth = 1)
ax.axhline(y = np.mean(Rg_diff_bt), color = 'g', linestyle = '--', linewidth = 2, label = f'$\\mathrm{{MAE_{{{model_name}}}}}$')
ax.axhline(y = np.mean(Rg_diff_bc), color = 'r', linestyle = '--', linewidth = 1, label = f'$\\mathrm{{MAE_{{{control_name}}}}}$')

ax.set_xlabel(f't (fs)', fontproperties = fontsize)
ax.set_title(f"{title_mapping[ppty]} ($\\mathrm{{\\AA}}$)", loc = 'center', fontproperties = title_fontsize)
for axis_name in ('x', 'y'):
    ax.ticklabel_format(axis = axis_name, style = 'scientific', scilimits = (0, 2), useMathText = True, useLocale = True)
ax.tick_params(axis = 'both', labelsize = tick_fontsize.get_size())
ax.figure.subplots_adjust(top=0.953,bottom=0.091,left=0.037,right=0.992,hspace=0.554,wspace=0.13)
ax.legend(fontsize = legned_fontsize.get_size(), loc = 'upper left', ncol = 2)

In [None]:
# kinetic energy
ppty = 'ke'

# Select the `ppty` column from each run's thermo table (`*_heads` names the columns of
# `*_ppties`). Fail loudly if the column is missing: an empty selection would otherwise
# only surface later as an obscure np.pad broadcast error.
ke_bidx, = np.where(benchmark_heads == ppty)
ke_cidx, = np.where(control_heads == ppty)
ke_tidx, = np.where(taylor_heads == ppty)
for run_label, run_cols in (('benchmark', ke_bidx), ('control', ke_cidx), ('taylor', ke_tidx)):
    if run_cols.size == 0:
        raise KeyError(f"thermo column {ppty!r} not found in the {run_label} headers")
bke = benchmark_ppties[:, ke_bidx]
cke = control_ppties[:, ke_cidx]
tke = taylor_ppties[:, ke_tidx]

ke_diff_bc = np.fabs(cke - bke)
ke_diff_bt = np.fabs(tke - bke)

# ! Picture idx
pic_idx += 1

# Smooth by averaging consecutive groups of `skip` rows (the row count must be a
# multiple of `skip`). Each group mean is stamped at the time of the group's last
# frame, and the first row's value is prepended at t = 0.
skip = 100
if skip > 1:
    gke_diff_bc = np.pad(np.mean(ke_diff_bc.reshape(-1, skip), axis = -1), pad_width = (1, 0), mode = 'constant', constant_values = ke_diff_bc[0])
    gke_diff_bt = np.pad(np.mean(ke_diff_bt.reshape(-1, skip), axis = -1), pad_width = (1, 0), mode = 'constant', constant_values = ke_diff_bt[0])
    group_times = np.pad(times[::skip] + (skip - 1)*intv*basis, pad_width = (1, 0), mode = 'constant', constant_values = 0)
else:
    gke_diff_bc = ke_diff_bc
    gke_diff_bt = ke_diff_bt
    group_times = times

ax = plt.subplot(row, col, pic_idx)

# Solid curves: binned errors. Dashed lines: mean over all (unbinned) rows.
ax.plot(group_times, gke_diff_bt, label = f"{model_name}", linewidth = 2)
ax.plot(group_times, gke_diff_bc, label = f"{control_name}", linewidth = 1)
ax.axhline(y = np.mean(ke_diff_bt), color = 'g', linestyle = '--', linewidth = 2, label = f'$\\mathrm{{MAE_{{{model_name}}}}}$')
ax.axhline(y = np.mean(ke_diff_bc), color = 'r', linestyle = '--', linewidth = 1, label = f'$\\mathrm{{MAE_{{{control_name}}}}}$')

# The unit lives in the title, so no y label is set.
ax.set_xlabel('t (fs)', fontproperties = fontsize)
ax.set_title(f"{title_mapping[ppty]} (kcal$\\cdot \\mathrm{{mol^{{-1}}}}$)", loc = 'center', fontproperties = title_fontsize)
ax.ticklabel_format(axis = 'x', style = 'scientific', scilimits = (0, 2), useMathText = True, useLocale = True)
ax.ticklabel_format(axis = 'y', style = 'scientific', scilimits = (0, 2), useMathText = True, useLocale = True)
ax.tick_params(axis = 'both', labelsize = tick_fontsize.get_size())
ax.set_yticks([0, 25, 50])
ax.figure.subplots_adjust(top=0.953,bottom=0.091,left=0.037,right=0.992,hspace=0.554,wspace=0.13)
ax.legend(fontsize = legned_fontsize.get_size(), loc = 'lower right', ncol = 2)

In [None]:
# Pairwise energy
ppty = 'epair'

# Select the `ppty` column from each run's thermo table (`*_heads` names the columns of
# `*_ppties`). Fail loudly if the column is missing: an empty selection would otherwise
# only surface later as an obscure np.pad broadcast error.
epair_bidx, = np.where(benchmark_heads == ppty)
epair_cidx, = np.where(control_heads == ppty)
epair_tidx, = np.where(taylor_heads == ppty)
for run_label, run_cols in (('benchmark', epair_bidx), ('control', epair_cidx), ('taylor', epair_tidx)):
    if run_cols.size == 0:
        raise KeyError(f"thermo column {ppty!r} not found in the {run_label} headers")
bepair = benchmark_ppties[:, epair_bidx]
cepair = control_ppties[:, epair_cidx]
tepair = taylor_ppties[:, epair_tidx]

epair_diff_bc = np.fabs(cepair - bepair)
epair_diff_bt = np.fabs(tepair - bepair)

# ! Picture idx
pic_idx += 1

# Smooth by averaging consecutive groups of `skip` rows (the row count must be a
# multiple of `skip`). Each group mean is stamped at the time of the group's last
# frame, and the first row's value is prepended at t = 0.
skip = 200
if skip > 1:
    gepair_diff_bc = np.pad(np.mean(epair_diff_bc.reshape(-1, skip), axis = -1), pad_width = (1, 0), mode = 'constant', constant_values = epair_diff_bc[0])
    gepair_diff_bt = np.pad(np.mean(epair_diff_bt.reshape(-1, skip), axis = -1), pad_width = (1, 0), mode = 'constant', constant_values = epair_diff_bt[0])
    group_times = np.pad(times[::skip] + (skip - 1)*intv*basis, pad_width = (1, 0), mode = 'constant', constant_values = 0)
else:
    gepair_diff_bc = epair_diff_bc
    gepair_diff_bt = epair_diff_bt
    group_times = times

ax = plt.subplot(row, col, pic_idx)

# Solid curves: binned errors. Dashed lines: mean over all (unbinned) rows.
ax.plot(group_times, gepair_diff_bt, label = f"{model_name}", linewidth = 2)
ax.plot(group_times, gepair_diff_bc, label = f"{control_name}", linewidth = 1)
ax.axhline(y = np.mean(epair_diff_bt), color = 'g', linestyle = '--', linewidth = 2, label = f'$\\mathrm{{MAE_{{{model_name}}}}}$')
ax.axhline(y = np.mean(epair_diff_bc), color = 'r', linestyle = '--', linewidth = 1, label = f'$\\mathrm{{MAE_{{{control_name}}}}}$')

# The unit lives in the title (one point smaller so the longer title fits).
ax.set_xlabel('t (fs)', fontproperties = fontsize)
ax.set_title(f"{title_mapping[ppty]} (kcal$\\cdot \\mathrm{{mol^{{-1}}}}$)", loc = 'center', fontsize = title_fontsize.get_size() - 1)
ax.ticklabel_format(axis = 'x', style = 'scientific', scilimits = (0, 2), useMathText = True, useLocale = True)
ax.ticklabel_format(axis = 'y', style = 'scientific', scilimits = (0, 2), useMathText = True, useLocale = True)
ax.tick_params(axis = 'both', labelsize = tick_fontsize.get_size())
ax.set_yticks([0, 30, 60])
ax.figure.subplots_adjust(top=0.953,bottom=0.091,left=0.037,right=0.992,hspace=0.554,wspace=0.13)
ax.legend(fontsize = legned_fontsize.get_size(), loc = 'lower right', ncol = 2)

In [None]:
# Molecular energy
ppty = 'emol'

# Select the `ppty` column from each run's thermo table (`*_heads` names the columns of
# `*_ppties`). Fail loudly if the column is missing: an empty selection would otherwise
# only surface later as an obscure np.pad broadcast error.
emol_bidx, = np.where(benchmark_heads == ppty)
emol_cidx, = np.where(control_heads == ppty)
emol_tidx, = np.where(taylor_heads == ppty)
for run_label, run_cols in (('benchmark', emol_bidx), ('control', emol_cidx), ('taylor', emol_tidx)):
    if run_cols.size == 0:
        raise KeyError(f"thermo column {ppty!r} not found in the {run_label} headers")
bemol = benchmark_ppties[:, emol_bidx]
cemol = control_ppties[:, emol_cidx]
temol = taylor_ppties[:, emol_tidx]

emol_diff_bc = np.fabs(cemol - bemol)
emol_diff_bt = np.fabs(temol - bemol)

# ! Picture idx
pic_idx += 1

# Smooth by averaging consecutive groups of `skip` rows (the row count must be a
# multiple of `skip`). Each group mean is stamped at the time of the group's last
# frame, and the first row's value is prepended at t = 0.
skip = 200
if skip > 1:
    gemol_diff_bc = np.pad(np.mean(emol_diff_bc.reshape(-1, skip), axis = -1), pad_width = (1, 0), mode = 'constant', constant_values = emol_diff_bc[0])
    gemol_diff_bt = np.pad(np.mean(emol_diff_bt.reshape(-1, skip), axis = -1), pad_width = (1, 0), mode = 'constant', constant_values = emol_diff_bt[0])
    group_times = np.pad(times[::skip] + (skip - 1)*intv*basis, pad_width = (1, 0), mode = 'constant', constant_values = 0)
else:
    gemol_diff_bc = emol_diff_bc
    gemol_diff_bt = emol_diff_bt
    group_times = times

ax = plt.subplot(row, col, pic_idx)

# Solid curves: binned errors. Dashed lines: mean over all (unbinned) rows.
ax.plot(group_times, gemol_diff_bt, label = f"{model_name}", linewidth = 2)
ax.plot(group_times, gemol_diff_bc, label = f"{control_name}", linewidth = 1)
ax.axhline(y = np.mean(emol_diff_bt), color = 'g', linestyle = '--', linewidth = 2, label = f'$\\mathrm{{MAE_{{{model_name}}}}}$')
ax.axhline(y = np.mean(emol_diff_bc), color = 'r', linestyle = '--', linewidth = 1, label = f'$\\mathrm{{MAE_{{{control_name}}}}}$')

# The unit lives in the title (one point smaller so the longer title fits).
ax.set_xlabel('t (fs)', fontproperties = fontsize)
ax.set_title(f"{title_mapping[ppty]} (kcal$\\cdot \\mathrm{{mol^{{-1}}}}$)", loc = 'center', fontsize = title_fontsize.get_size() - 1)
ax.ticklabel_format(axis = 'x', style = 'scientific', scilimits = (0, 2), useMathText = True, useLocale = True)
ax.ticklabel_format(axis = 'y', style = 'scientific', scilimits = (0, 2), useMathText = True, useLocale = True)
ax.tick_params(axis = 'both', labelsize = tick_fontsize.get_size())
# Log y axis (replaces the scientific formatter set above).
# NOTE(review): if the prepended t = 0 value is exactly 0 (e.g. identical starting
# states), that point cannot be placed on a log axis -- confirm the plot looks right.
ax.set_yscale('log')
ax.set_yticks([1e1, 4e1])
ax.figure.subplots_adjust(top=0.953,bottom=0.091,left=0.037,right=0.992,hspace=0.554,wspace=0.13)
ax.legend(fontsize = legned_fontsize.get_size(), loc = 'lower right', ncol = 2)

In [None]:
# Pressure
ppty = 'press'

# Select the `ppty` column from each run's thermo table (`*_heads` names the columns of
# `*_ppties`). Fail loudly if the column is missing: an empty selection would otherwise
# only surface later as an obscure np.pad broadcast error.
press_bidx, = np.where(benchmark_heads == ppty)
press_cidx, = np.where(control_heads == ppty)
press_tidx, = np.where(taylor_heads == ppty)
for run_label, run_cols in (('benchmark', press_bidx), ('control', press_cidx), ('taylor', press_tidx)):
    if run_cols.size == 0:
        raise KeyError(f"thermo column {ppty!r} not found in the {run_label} headers")
bpress = benchmark_ppties[:, press_bidx]
cpress = control_ppties[:, press_cidx]
tpress = taylor_ppties[:, press_tidx]

press_diff_bc = np.fabs(cpress - bpress)
press_diff_bt = np.fabs(tpress - bpress)

# ! Picture idx
pic_idx += 1

# Smooth by averaging consecutive groups of `skip` rows (the row count must be a
# multiple of `skip`). Each group mean is stamped at the time of the group's last
# frame, and the first row's value is prepended at t = 0.
skip = 200
if skip > 1:
    gpress_diff_bc = np.pad(np.mean(press_diff_bc.reshape(-1, skip), axis = -1), pad_width = (1, 0), mode = 'constant', constant_values = press_diff_bc[0])
    gpress_diff_bt = np.pad(np.mean(press_diff_bt.reshape(-1, skip), axis = -1), pad_width = (1, 0), mode = 'constant', constant_values = press_diff_bt[0])
    group_times = np.pad(times[::skip] + (skip - 1)*intv*basis, pad_width = (1, 0), mode = 'constant', constant_values = 0)
else:
    gpress_diff_bc = press_diff_bc
    gpress_diff_bt = press_diff_bt
    group_times = times

ax = plt.subplot(row, col, pic_idx)

# Solid curves: binned errors. Dashed lines: mean over all (unbinned) rows.
ax.plot(group_times, gpress_diff_bt, label = f"{model_name}", linewidth = 2)
ax.plot(group_times, gpress_diff_bc, label = f"{control_name}", linewidth = 1)
ax.axhline(y = np.mean(press_diff_bt), color = 'g', linestyle = '--', linewidth = 2, label = f'$\\mathrm{{MAE_{{{model_name}}}}}$')
ax.axhline(y = np.mean(press_diff_bc), color = 'r', linestyle = '--', linewidth = 1, label = f'$\\mathrm{{MAE_{{{control_name}}}}}$')

# The pressure unit (press_unit from the setup cell) lives in the title, so no y label is set.
ax.set_xlabel('t (fs)', fontproperties = fontsize)
ax.set_title(f"{title_mapping[ppty]} ({press_unit})", loc = 'center', fontsize = title_fontsize.get_size() - 1)
ax.ticklabel_format(axis = 'x', style = 'scientific', scilimits = (0, 2), useMathText = True, useLocale = True)
ax.ticklabel_format(axis = 'y', style = 'scientific', scilimits = (0, 2), useMathText = True, useLocale = True)
ax.tick_params(axis = 'both', labelsize = tick_fontsize.get_size())
ax.figure.subplots_adjust(top=0.953,bottom=0.091,left=0.037,right=0.992,hspace=0.554,wspace=0.13)
ax.legend(fontsize = legned_fontsize.get_size(), loc = 'lower right', ncol = 2)

In [None]:
# fig.tight_layout()
# Render the assembled multi-panel figure as this cell's output. Figure.show() needs a
# GUI figure manager and does not reliably render under the inline backend; ending the
# cell with the figure lets Jupyter display it through its rich display hook.
fig