In [None]:
# %matplotlib tk
import matplotlib
matplotlib.use('Qt5Agg')
import matplotlib.pyplot as plt
from matplotlib import font_manager
import numpy as np
from concurrent.futures import ThreadPoolExecutor
from einops import rearrange, repeat

from compute import *

# hyper-parameter
# ---------------------------------------------------------------------------
# Names, units and label maps shared by every plotting cell below.
# The unit strings are consistent with LAMMPS `units real`
# (kcal/mol, atm, K, Angstrom, fs).
# ---------------------------------------------------------------------------
model_name = 'EdSr'        # legend label of the model trajectory
control_name = 'MD'        # legend label of the control trajectory
benchmark_name = 'BM'      # legend label of the benchmark (reference) trajectory
# Thermo columns expected in the log:
#   Temp Press PotEng KinEng Enthalpy E_vdwl E_coul E_pair E_bond E_angle E_dihed
#   E_long E_tail E_mol Ecouple Econserve TotEng Lx Ly Lz
energy_unit = "kcal $\\cdot$ mol$^{-1}$"
press_unit = "ATM"
temperature = "K"
distance_unit = "Angstrom"
time_unit = "fs"
# Multiplier converting each time unit to fs.
time_scale = {
    "fs": 1,
    "ps": 1000,
    "ns": 1000000,
}
xaxis_time_unit = 'ps'

# Axis-unit label per thermo keyword. Both the lowercase thermo_style keyword
# and the capitalised log-column header map to the same text.
thermo_style_unit = {
    'temp'     : f"Kelvin ({temperature})",    'Temp'     : f'Kelvin ({temperature})',
    'press'    : f'ATMosphere ({press_unit})', 'Press'    : f'ATMosphere ({press_unit})',
    "pe"       : f'energy ({energy_unit})',    'PotEng'   : f'energy ({energy_unit})',
    "ke"       : f'energy ({energy_unit})',    'KinEng'   : f'energy ({energy_unit})',
    "enthalpy" : f'energy ({energy_unit})',    'Enthalpy' : f'energy ({energy_unit})',
    "evdwl"    : f'energy ({energy_unit})',    'E_vdwl'   : f'energy ({energy_unit})',
    "ecoul"    : f'energy ({energy_unit})',    'E_coul'   : f'energy ({energy_unit})',
    "epair"    : f'energy ({energy_unit})',    'E_pair'   : f'energy ({energy_unit})',
    "ebond"    : f'energy ({energy_unit})',    'E_bond'   : f'energy ({energy_unit})',
    "eangle"   : f'energy ({energy_unit})',    'E_angle'  : f'energy ({energy_unit})',
    "edihed"   : f'energy ({energy_unit})',    'E_dihed'  : f'energy ({energy_unit})',
    "eimp"     : f'energy ({energy_unit})',    'E_impro'  : f'energy ({energy_unit})',
    "elong"    : f'energy ({energy_unit})',    'E_long'   : f'energy ({energy_unit})',
    "etail"    : f'energy ({energy_unit})',    'E_tail'   : f'energy ({energy_unit})',
    "emol"     : f'energy ({energy_unit})',    'E_mol'    : f'energy ({energy_unit})',
    "ecouple"  : f'energy ({energy_unit})',    'Ecouple'  : f'energy ({energy_unit})',
    "econserve": f'energy ({energy_unit})',    'Econserve': f'energy ({energy_unit})',
    "etotal"   : f'energy ({energy_unit})',    'TotEng'   : f'energy ({energy_unit})',
    'lx'       : f'length ({distance_unit})',  'Lx'       : f'length ({distance_unit})',
    'ly'       : f'length ({distance_unit})',  'Ly'       : f'length ({distance_unit})',
    'lz'       : f'length ({distance_unit})',  'Lz'       : f'length ({distance_unit})',
}

# Human-readable panel title per property (thermo keywords plus the
# structural observables Rg / RMSD / MSD computed in this notebook).
title_mapping = {
     'temp'     : r'Temperature',                                                'Temp'     : r'Temperature',
     'press'    : r'Pressure',                                                   'Press'    : r'Pressure',
     "pe"       : r'Potential energy',                                           'PotEng'   : r'Potential energy',
     "ke"       : r'Kinetic energy',                                             'KinEng'   : r'Kinetic energy',
     # Enthalpy is etotal + P*V; "pe + ke" is the definition of etotal.
     "enthalpy" : r'Enthalpy (etotal + press*vol)',                              'Enthalpy' : r'Enthalpy (etotal + press*vol)',
     "evdwl"    : r'Van der Waals pairwise energy',                              'E_vdwl'   : r'Van der Waals pairwise energy',
     "ecoul"    : r'Coulombic pairwise energy',                                  'E_coul'   : r'Coulombic pairwise energy',
     "epair"    : r'Pairwise energy',                                            'E_pair'   : r'Pairwise energy',
     "ebond"    : r'Bond energy',                                                'E_bond'   : r'Bond energy',
     "eangle"   : r'Angle energy',                                               'E_angle'  : r'Angle energy',
     "edihed"   : r'Dihedral energy',                                            'E_dihed'  : r'Dihedral energy',
     "eimp"     : r'Improper energy',                                            'E_impro'  : r'Improper energy',
     "elong"    : r'Long-range kspace energy',                                   'E_long'   : r'Long-range kspace energy',
     "etail"    : r'Van der Waals energy long-range tail correction',            'E_tail'   : r'Van der Waals energy long-range tail correction',
     "emol"     : r'Intramolecular energy',                                      'E_mol'    : r'Intramolecular energy',
     "ecouple"  : r'Cumulative energy change due to thermo/baro statting fixes', 'Ecouple'  : r'Cumulative energy change due to thermo/baro statting fixes',
     "econserve": r'Etotal + ecouple',                                           'Econserve': r'Etotal + ecouple',
     "etotal"   : r'Total energy',                                               'TotEng'   : r'Total energy',
     'lx'       : r'Length of x-axis',                                           'Lx'       : r'Length of x-axis',
     'ly'       : r'Length of y-axis',                                           'Ly'       : r'Length of y-axis',
     'lz'       : r'Length of z-axis',                                           'Lz'       : r'Length of z-axis',
     'Rg'       : r'Radius of gyration',                                         'RG'       : r'Radius of gyration',
     'rmsd'     : r'RMSD',                                                       'RMSD'     : r'RMSD',
     'msd'      : r'MSD',                                                        'MSD'      : r'MSD',
}



# Print NumPy arrays in full instead of summarising them with "...".
np.set_printoptions(threshold = np.inf)

# Font properties shared by all panels:
#   fontsize        -> axis labels         tick_fontsize   -> tick labels
#   title_fontsize  -> column/inset titles legned_fontsize -> legends
#   item_fontsize   -> bold panel letters (a, b, c, ...)
fontsize, tick_fontsize, title_fontsize, legned_fontsize, item_fontsize = (
    font_manager.FontProperties(size = pt) for pt in (12, 10, 13, 8, 12)
)

# Sub-plot spacing handed to `subplots_adjust` (fractions of the figure).
tight_layout_arg = {
    'top': 0.896,
    'bottom': 0.111,
    'left': 0.025,
    'right': 0.997,
    'hspace': 0.222,
    'wspace': 0.177,
}

# Figure grid of `row` x `col` panels.
row = 2
col = 3
fig = plt.figure(figsize = [col * 10, row * 20], dpi = 256)

# Running panel counter and panel-letter code point. Each panel cell
# increments both before drawing, so the first panel is 1 / 'a'.
pic_idx = 0
subgraph_item = ord('a') - 1
# Panel-letter anchor in axes coordinates (left of and above the axes).
item_pos = (-0.06, 1.12)

## $$\textit{ 10.0\ fs }$$

In [None]:
# intv10
# ---- Load benchmark / control / EdSr trajectories saved every 10 steps ----
maxIter: int   = 500          # EdSr iteration count; part of its output directory name
intv   : int   = 10           # steps between saved frames
basis  : float = 1.0          # base timestep, in `time_unit`
start  : int   = 1            # first frame index (used in the frames file name)
end    : int   = 10000        # last frame index (used in the frames file name)
group  : str   = "id 1 163"   # atom selection "<id|type> <lo> <hi>", bounds inclusive

# Each list: [run directory, frames file, global-variable file, RDF file].
# Only the first three entries are passed to `extraction` below.
benchmark_params = [
    f'data/benchmark_nve_basis{basis}_intv{intv}/',
    f'frames{start}_{end}.npz',
    f'GlobalVariable.npz',
    f'rdf_pro2wa.dat',
]
control_params = [
    f'data/control_nve_basis{basis}_intv{intv}/',
    f'frames{start}_{end}.npz',
    f'GlobalVariable.npz',
    f'rdf_pro2wa.dat',
]
edsr_params = [
    # Directory name follows `maxIter` instead of a hard-coded "_iter500".
    f'data/EdSr_nve_basis{basis}_intv{intv}_iter{maxIter}/',
    f'frames{start}_{end}.npz',
    f'GlobalVariable.npz',
    f'rdf_pro2wa.dat',
]

# mass shape: (natoms,), id shape: (natoms,),  x shape: (ntrajs, natoms, 3), v shape: (ntrajs, natoms, 3)
# The three runs are loaded concurrently (presumably file I/O inside `extraction`).
with ThreadPoolExecutor(max_workers = 3) as executor:
    futures = [
        executor.submit(extraction, *benchmark_params[:-1], filter = None),
        executor.submit(extraction, *control_params[:-1], filter = None),
        executor.submit(extraction, *edsr_params[:-1] , filter = None),
    ]

benchmark_x, benchmark_v, benchmark_delta_t, benchmark_mass, benchmark_atype, benchmark_id, benchmark_boundary, benchmark_heads, benchmark_ppties, benchmark_init_state, benchmark_last_state = futures[0].result()
control_x, control_v, control_delta_t, control_mass, control_atype, control_id, control_boundary, control_heads, control_ppties, control_init_state, control_last_state = futures[1].result()
edsr_x, edsr_v, edsr_delta_t, edsr_mass, edsr_atype, edsr_id, edsr_boundary, edsr_heads, edsr_ppties, edsr_init_state, edsr_last_state = futures[2].result()

# The frame-wise comparisons in the following cells need equally long runs.
n_frames = len(benchmark_x)
if len(control_x) != n_frames or len(edsr_x) != n_frames:
    raise ValueError(
        f"frame count mismatch: {benchmark_name}={n_frames}, "
        f"{control_name}={len(control_x)}, {model_name}={len(edsr_x)}"
    )

# Time of each saved frame in `xaxis_time_unit`, starting at 0. Sized from the
# loaded data so the axis always matches the trajectories.
times = np.arange(n_frames)*intv*basis * (time_scale[time_unit]/time_scale[xaxis_time_unit])

In [None]:
# coordinate figure
# ---- Top-row panel: per-frame coordinate MAE against the benchmark ----

# Absolute coordinate deviations from the benchmark run. `traj_abs_diff` comes
# from `compute` and receives each run's boundary (presumably for PBC handling).
xdiff_bc = traj_abs_diff(benchmark_x, benchmark_boundary, control_x, control_boundary)   # benchmark vs control
xdiff_bt = traj_abs_diff(benchmark_x, benchmark_boundary, edsr_x, edsr_boundary)         # benchmark vs EdSr

# Average over the two trailing axes (presumably atoms and xyz) -> one value per frame.
frame_xdiff_bc = np.mean(xdiff_bc, axis = (-1,  -2))
print(frame_xdiff_bc.shape)
frame_xdiff_bt = np.mean(xdiff_bt, axis = (-1,  -2))

# ! Picture idx
pic_idx += 1
subgraph_item += 1

# Smooth by averaging every `skip` consecutive frames; the first raw frame is
# prepended so each curve starts at t = 0. Each block mean is placed at the
# time of the block's last frame, matching the other panels.
skip = 200
if skip > 1:
    gmean_xdiff_bc = np.pad(np.mean(frame_xdiff_bc.reshape(-1, skip), axis = -1), pad_width = (1, 0), mode = 'constant', constant_values = frame_xdiff_bc[0])
    gmean_xdiff_bt = np.pad(np.mean(frame_xdiff_bt.reshape(-1, skip), axis = -1), pad_width = (1, 0), mode = 'constant', constant_values = frame_xdiff_bt[0])
    group_times = np.pad(times.reshape(-1, skip)[:, -1], pad_width = (1, 0), mode = 'constant', constant_values = 0)
else:
    gmean_xdiff_bc = frame_xdiff_bc
    gmean_xdiff_bt = frame_xdiff_bt
    group_times = times

ax = plt.subplot(row, col, pic_idx)
plt.text(*item_pos, chr(subgraph_item), transform = ax.transAxes, fontsize = item_fontsize.get_size(), fontweight = 'bold', va = 'top', ha = 'left')

# ! title of column (backslashes escaped: "\m" / "\A" are invalid string escapes)
plt.text(0.5, 1.15, "MAE of Coordinates\n($\\mathrm{\\AA}$)", transform = ax.transAxes, fontsize = title_fontsize.get_size(), fontweight = 'bold', va = 'center', ha = 'center')

plt.plot(group_times, gmean_xdiff_bt, label = f"{model_name}", linewidth = 2)
plt.plot(group_times, gmean_xdiff_bc, label = f"{control_name}", linewidth = 1)
# Dashed lines: time-averaged (unsmoothed) per-frame MAE of each run.
plt.axhline(y = np.mean(frame_xdiff_bt), color = 'g', linestyle = '--', linewidth = 2, label = f'$\\mathrm{{AVE_{{{model_name}}}}}$')
plt.axhline(y = np.mean(frame_xdiff_bc), color = 'r', linestyle = '--', linewidth = 1, label = f'$\\mathrm{{AVE_{{{control_name}}}}}$')

plt.ticklabel_format(axis = 'y', style = 'scientific', scilimits = (-2, 2), useMathText = True, useLocale = True)
plt.tick_params(axis = 'both', labelsize = tick_fontsize.get_size())
plt.legend(fontsize = legned_fontsize.get_size(), loc = 'lower right', ncol = 2)

In [None]:
# RMSD
# ---- Top-row panel: RMSD of the selected atom group vs. time ----
ppty = 'rmsd'

# Parse the selection "<id|type> <lo> <hi>" into per-run boolean atom masks
# (bounds inclusive). `group_kind` avoids shadowing the builtin `type`.
group_kind, st, en = group.split(" ")
lo, hi = int(st), int(en)
if group_kind == 'id':
    benchmark_mask = np.logical_and(benchmark_id >= lo, benchmark_id <= hi)
    control_mask = np.logical_and(control_id >= lo, control_id <= hi)
    edsr_mask = np.logical_and(edsr_id >= lo, edsr_id <= hi)
elif group_kind == 'type':
    benchmark_mask = np.logical_and(benchmark_atype >= lo, benchmark_atype <= hi)
    control_mask = np.logical_and(control_atype >= lo, control_atype <= hi)
    edsr_mask = np.logical_and(edsr_atype >= lo, edsr_atype <= hi)
else:
    # Without this the masks would be undefined, or silently stale from an earlier cell.
    raise ValueError(f"unsupported group selector {group_kind!r}; expected 'id' or 'type'")

# RMSD of the selected atoms via `compute_RMSD` (from `compute`); the `0`
# argument is presumably the reference frame index -- confirm in `compute`.
benchmark_RMSD = compute_RMSD(benchmark_x[:, benchmark_mask], benchmark_mass[benchmark_mask], 0, benchmark_boundary)
control_RMSD = compute_RMSD(control_x[:, control_mask], control_mass[control_mask], 0, control_boundary)
edsr_RMSD = compute_RMSD(edsr_x[:, edsr_mask], edsr_mass[edsr_mask], 0, edsr_boundary)

# Per-frame absolute deviation of each run's RMSD from the benchmark's.
rmsd_diff_bc = np.fabs(benchmark_RMSD - control_RMSD)
rmsd_diff_bt = np.fabs(benchmark_RMSD - edsr_RMSD)

# ! Picture idx
pic_idx += 1
subgraph_item += 1

# Block averages over `skip` frames (first raw value prepended). The main plot
# below draws the raw per-frame series; these are kept for inspection.
skip = 200
if skip > 1:
    grmsd_diff_bc = np.pad(np.mean(rmsd_diff_bc.reshape(-1, skip), axis = -1), pad_width = (1, 0), mode = 'constant', constant_values = rmsd_diff_bc[0])
    grmsd_diff_bt = np.pad(np.mean(rmsd_diff_bt.reshape(-1, skip), axis = -1), pad_width = (1, 0), mode = 'constant', constant_values = rmsd_diff_bt[0])
    group_times = np.pad(times.reshape(-1, skip)[:, -1], pad_width = (1, 0), mode = 'constant', constant_values = 0)
    skip_brmsd = np.pad(np.mean(benchmark_RMSD.reshape(-1, skip), axis = -1), pad_width = (1, 0), mode = 'constant', constant_values = benchmark_RMSD[0])
    skip_crmsd = np.pad(np.mean(control_RMSD.reshape(-1, skip), axis = -1), pad_width = (1, 0), mode = 'constant', constant_values = control_RMSD[0])
    skip_trmsd = np.pad(np.mean(edsr_RMSD.reshape(-1, skip), axis = -1), pad_width = (1, 0), mode = 'constant', constant_values = edsr_RMSD[0])
else:
    grmsd_diff_bc = rmsd_diff_bc
    grmsd_diff_bt = rmsd_diff_bt
    group_times = times
    # Keep both branches defining the same names.
    skip_brmsd = benchmark_RMSD
    skip_crmsd = control_RMSD
    skip_trmsd = edsr_RMSD


ax = plt.subplot(row, col, pic_idx)
plt.text(*item_pos, chr(subgraph_item), transform = ax.transAxes, fontsize = item_fontsize.get_size(), fontweight = 'bold', va = 'top', ha = 'left')

# attn title of column
plt.text(0.5, 1.15, f"{title_mapping[ppty]}\n($\\mathrm{{\\AA}}$)", fontsize = title_fontsize.get_size(), transform = ax.transAxes, fontweight = 'bold', va = 'center', ha = 'center')

plt.plot(times, benchmark_RMSD, label = f"{benchmark_name}", linewidth = 2)
plt.plot(times, edsr_RMSD, label = f"{model_name}", linewidth = 1, alpha = 0.7)
plt.plot(times, control_RMSD, label = f"{control_name}", linewidth = 1, alpha = 0.7)

plt.ticklabel_format(axis = 'y', style = 'scientific', scilimits = (-2, 2), useMathText = True, useLocale = True)
plt.tick_params(axis = 'both', labelsize = tick_fontsize.get_size())
plt.ylim(bottom = -0.8)
plt.subplots_adjust(**tight_layout_arg)
plt.legend(fontsize = legned_fontsize.get_size(), loc = 'lower right', ncol = 1)

# Inset (figure coordinates, anchored at this panel's lower-left corner) with
# the time-averaged MAE of each run's RMSD against the benchmark.
mae_bt = np.mean(rmsd_diff_bt)
mae_bc = np.mean(rmsd_diff_bc)
box = ax.get_position()
x0, y0, x1, y1 = box.x0, box.y0, box.x1, box.y1
aia = fig.add_axes([x0 + 0.055, y0 + 0.011, 0.12, 0.12])
aia.set_title('MAE', fontdict = {'fontsize': title_fontsize.get_size() - 3, 'verticalalignment': 'top', 'fontweight': 'bold'})

aia.axhline(y = mae_bt, color = 'orange', linestyle = '--', linewidth = 2, label = f'{model_name}')
aia.axhline(y = mae_bc, color = 'green', linestyle = '--', linewidth = 1, label = f'{control_name}')
aia.set_xticks([])
aia.set_yticks([round(min(mae_bt, mae_bc), 2), round(max(mae_bt, mae_bc), 2)])
aia.ticklabel_format(axis = 'y', style = 'scientific', scilimits = (-1, 2), useMathText = True, useLocale = True)
aia.tick_params(axis = 'y', labelsize = tick_fontsize.get_size() - 4)
# Tuned lower limit, lowered when needed so the smaller MAE line stays visible.
aia.set_ylim(bottom = min(0.02, 0.9 * min(mae_bt, mae_bc)), top = max(mae_bt, mae_bc) * 1.1)
aia.yaxis.get_offset_text().set(size = tick_fontsize.get_size() - 4)
aia.legend(fontsize = legned_fontsize.get_size() - 3, loc = 'lower right', ncol = 2)

In [None]:
# Rg
# ---- Top-row panel: radius of gyration of the selected atom group ----
ppty = 'Rg'

# Parse the selection "<id|type> <lo> <hi>" into per-run boolean atom masks
# (bounds inclusive). `group_kind` avoids shadowing the builtin `type`.
group_kind, st, en = group.split(" ")
lo, hi = int(st), int(en)
if group_kind == 'id':
    benchmark_mask = np.logical_and(benchmark_id >= lo, benchmark_id <= hi)
    control_mask = np.logical_and(control_id >= lo, control_id <= hi)
    edsr_mask = np.logical_and(edsr_id >= lo, edsr_id <= hi)
elif group_kind == 'type':
    benchmark_mask = np.logical_and(benchmark_atype >= lo, benchmark_atype <= hi)
    control_mask = np.logical_and(control_atype >= lo, control_atype <= hi)
    edsr_mask = np.logical_and(edsr_atype >= lo, edsr_atype <= hi)
else:
    # Without this the masks would be undefined, or silently stale from an earlier cell.
    raise ValueError(f"unsupported group selector {group_kind!r}; expected 'id' or 'type'")

# Radius of gyration of the selected atoms via `compute_RG` (from `compute`).
benchmark_Rg = compute_RG(benchmark_x[:, benchmark_mask], benchmark_mass[benchmark_mask], benchmark_boundary)
control_Rg = compute_RG(control_x[:, control_mask], control_mass[control_mask], control_boundary)
edsr_Rg = compute_RG(edsr_x[:, edsr_mask], edsr_mass[edsr_mask], edsr_boundary)

# Per-frame absolute deviation of each run's Rg from the benchmark's.
Rg_diff_bc = np.fabs(benchmark_Rg - control_Rg)
Rg_diff_bt = np.fabs(benchmark_Rg - edsr_Rg)

# ! Picture idx
pic_idx += 1
subgraph_item += 1

# Block averages over `skip` frames with the first raw value prepended; the
# plot below draws the block-averaged Rg curves.
skip = 200
if skip > 1:
    gRg_diff_bc = np.pad(np.mean(Rg_diff_bc.reshape(-1, skip), axis = -1), pad_width = (1, 0), mode = 'constant', constant_values = Rg_diff_bc[0])
    gRg_diff_bt = np.pad(np.mean(Rg_diff_bt.reshape(-1, skip), axis = -1), pad_width = (1, 0), mode = 'constant', constant_values = Rg_diff_bt[0])
    group_times = np.pad(times.reshape(-1, skip)[:, -1], pad_width = (1, 0), mode = 'constant', constant_values = 0)
    skip_brg = np.pad(np.mean(benchmark_Rg.reshape(-1, skip), axis = -1), pad_width = (1, 0), mode = 'constant', constant_values = benchmark_Rg[0])
    skip_crg = np.pad(np.mean(control_Rg.reshape(-1, skip), axis = -1), pad_width = (1, 0), mode = 'constant', constant_values = control_Rg[0])
    skip_trg = np.pad(np.mean(edsr_Rg.reshape(-1, skip), axis = -1), pad_width = (1, 0), mode = 'constant', constant_values = edsr_Rg[0])
else:
    gRg_diff_bc = Rg_diff_bc
    gRg_diff_bt = Rg_diff_bt
    group_times = times
    # The plot below reads skip_*rg, so they must exist in this branch too.
    skip_brg = benchmark_Rg
    skip_crg = control_Rg
    skip_trg = edsr_Rg

ax = plt.subplot(row, col, pic_idx)
plt.text(*item_pos, chr(subgraph_item), transform = ax.transAxes, fontsize = item_fontsize.get_size(), fontweight = 'bold', va = 'top', ha = 'left')

# attn title of column
plt.text(0.5, 1.15, f"{title_mapping[ppty]}\n($\\mathrm{{\\AA}}$)", fontsize = title_fontsize.get_size(), transform = ax.transAxes, fontweight = 'bold', va = 'center', ha = 'center')

plt.plot(group_times, skip_brg, label = f"{benchmark_name}", linewidth = 2)
plt.plot(group_times, skip_trg, label = f"{model_name}", linewidth = 1, alpha = 0.7)
plt.plot(group_times, skip_crg, label = f"{control_name}", linewidth = 1, alpha = 0.7)

plt.ticklabel_format(axis = 'y', style = 'scientific', scilimits = (-1, 2), useMathText = True, useLocale = True)
plt.tick_params(axis = 'both', labelsize = tick_fontsize.get_size())

plt.yticks([10.8, 10.9, 11.0, 11.1, 11.2])
plt.ylim(bottom = 10.8)
plt.subplots_adjust(**tight_layout_arg)
plt.legend(fontsize = legned_fontsize.get_size(), loc = 'lower right', ncol = 1)

# Inset (figure coordinates, anchored at this panel's lower-left corner) with
# the time-averaged MAE of each run's Rg against the benchmark.
mae_bt = np.mean(Rg_diff_bt)
mae_bc = np.mean(Rg_diff_bc)
box = ax.get_position()
x0, y0, x1, y1 = box.x0, box.y0, box.x1, box.y1
aia = fig.add_axes([x0 + 0.055, y0 + 0.011, 0.12, 0.12])
aia.set_title('MAE', fontdict = {'fontsize': title_fontsize.get_size() - 3, 'verticalalignment': 'top', 'fontweight': 'bold'})

aia.axhline(y = mae_bt, color = 'orange', linestyle = '--', linewidth = 2, label = f'{model_name}')
aia.axhline(y = mae_bc, color = 'green', linestyle = '--', linewidth = 1, label = f'{control_name}')
aia.set_xticks([])
aia.set_yticks([round(min(mae_bt, mae_bc), 3), round(max(mae_bt, mae_bc), 3)])
aia.ticklabel_format(axis = 'y', style = 'scientific', scilimits = (-1, 2), useMathText = True, useLocale = True)
aia.tick_params(axis = 'y', labelsize = tick_fontsize.get_size() - 4)
# Tuned limits (0.03 .. 0.062), widened when needed so neither MAE line is clipped.
aia.set_ylim(bottom = min(0.03, 0.9 * min(mae_bt, mae_bc)), top = max(0.062, 1.1 * max(mae_bt, mae_bc)))
aia.yaxis.get_offset_text().set(size = tick_fontsize.get_size() - 4)
aia.legend(fontsize = legned_fontsize.get_size() - 3, loc = 'lower right', ncol = 2)

## $$\textit{ 20.0\ fs }$$

In [None]:
# intv20
# ---- Load benchmark / control / EdSr trajectories saved every 20 steps ----
maxIter: int   = 500          # EdSr iteration count; part of its output directory name
intv   : int   = 20           # steps between saved frames
basis  : float = 1.0          # base timestep, in `time_unit`
start  : int   = 1            # first frame index (used in the frames file name)
end    : int   = 10000        # last frame index (used in the frames file name)
group  : str   = "id 1 163"   # atom selection "<id|type> <lo> <hi>", bounds inclusive

# Each list: [run directory, frames file, global-variable file, RDF file].
# Only the first three entries are passed to `extraction` below.
benchmark_params = [
    f'data/benchmark_nve_basis{basis}_intv{intv}/',
    f'frames{start}_{end}.npz',
    f'GlobalVariable.npz',
    f'rdf_pro2wa.dat',
]
control_params = [
    f'data/control_nve_basis{basis}_intv{intv}/',
    f'frames{start}_{end}.npz',
    f'GlobalVariable.npz',
    f'rdf_pro2wa.dat',
]
edsr_params = [
    # Directory name follows `maxIter` instead of a hard-coded "_iter500".
    f'data/EdSr_nve_basis{basis}_intv{intv}_iter{maxIter}/',
    f'frames{start}_{end}.npz',
    f'GlobalVariable.npz',
    f'rdf_pro2wa.dat',
]

# mass shape: (natoms,), id shape: (natoms,),  x shape: (ntrajs, natoms, 3), v shape: (ntrajs, natoms, 3)
# The three runs are loaded concurrently (presumably file I/O inside `extraction`).
with ThreadPoolExecutor(max_workers = 3) as executor:
    futures = [
        executor.submit(extraction, *benchmark_params[:-1], filter = None),
        executor.submit(extraction, *control_params[:-1], filter = None),
        executor.submit(extraction, *edsr_params[:-1] , filter = None),
    ]

benchmark_x, benchmark_v, benchmark_delta_t, benchmark_mass, benchmark_atype, benchmark_id, benchmark_boundary, benchmark_heads, benchmark_ppties, benchmark_init_state, benchmark_last_state = futures[0].result()
control_x, control_v, control_delta_t, control_mass, control_atype, control_id, control_boundary, control_heads, control_ppties, control_init_state, control_last_state = futures[1].result()
edsr_x, edsr_v, edsr_delta_t, edsr_mass, edsr_atype, edsr_id, edsr_boundary, edsr_heads, edsr_ppties, edsr_init_state, edsr_last_state = futures[2].result()

# The frame-wise comparisons in the following cells need equally long runs.
n_frames = len(benchmark_x)
if len(control_x) != n_frames or len(edsr_x) != n_frames:
    raise ValueError(
        f"frame count mismatch: {benchmark_name}={n_frames}, "
        f"{control_name}={len(control_x)}, {model_name}={len(edsr_x)}"
    )

# Time of each saved frame in `xaxis_time_unit`, starting at 0. Sized from the
# loaded data so the axis always matches the trajectories.
times = np.arange(n_frames)*intv*basis * (time_scale[time_unit]/time_scale[xaxis_time_unit])

In [None]:
# coordinate figure
# ---- Bottom-row panel: per-frame coordinate MAE against the benchmark ----

# Absolute coordinate deviations from the benchmark run (`traj_abs_diff` comes
# from `compute` and is handed each run's simulation boundary).
xdiff_bc = traj_abs_diff(benchmark_x, benchmark_boundary, control_x, control_boundary)
xdiff_bt = traj_abs_diff(benchmark_x, benchmark_boundary, edsr_x, edsr_boundary)

# Average over the two trailing axes -> one scalar deviation per frame.
frame_xdiff_bc = np.mean(xdiff_bc, axis = (-1,  -2))
print(frame_xdiff_bc.shape)
frame_xdiff_bt = np.mean(xdiff_bt, axis = (-1,  -2))

# ! Picture idx -- advance panel number and panel letter together
pic_idx, subgraph_item = pic_idx + 1, subgraph_item + 1


def _lead_and_block_mean(series, width):
    """Return series[0] followed by the mean of each consecutive `width`-frame block."""
    block_means = np.mean(series.reshape(-1, width), axis = -1)
    return np.pad(block_means, pad_width = (1, 0), mode = 'constant', constant_values = series[0])


skip = 200
if skip <= 1:
    gmean_xdiff_bc, gmean_xdiff_bt, group_times = frame_xdiff_bc, frame_xdiff_bt, times
else:
    gmean_xdiff_bc = _lead_and_block_mean(frame_xdiff_bc, skip)
    gmean_xdiff_bt = _lead_and_block_mean(frame_xdiff_bt, skip)
    # Each block mean sits at the time of its last frame; t = 0 leads the axis.
    group_times = np.pad(times.reshape(-1, skip)[:, -1], pad_width = (1, 0), mode = 'constant', constant_values = 0)

ax = plt.subplot(row, col, pic_idx)
ax.text(*item_pos, chr(subgraph_item), transform = ax.transAxes,
        fontsize = item_fontsize.get_size(), fontweight = 'bold', va = 'top', ha = 'left')

# Smoothed deviation curves: model drawn thicker than control.
for curve, tag, lw in ((gmean_xdiff_bt, model_name, 2), (gmean_xdiff_bc, control_name, 1)):
    ax.plot(group_times, curve, label = f"{tag}", linewidth = lw)

# Dashed reference lines at the time-averaged (unsmoothed) per-frame deviation.
ax.axhline(y = np.mean(frame_xdiff_bt), color = 'g', linestyle = '--', linewidth = 2,
           label = f'$\\mathrm{{AVE_{{{model_name}}}}}$')
ax.axhline(y = np.mean(frame_xdiff_bc), color = 'r', linestyle = '--', linewidth = 1,
           label = f'$\\mathrm{{AVE_{{{control_name}}}}}$')

ax.set_xlabel(f't ($\\mathrm{{{xaxis_time_unit}}}$)', fontproperties = fontsize, fontweight = 'bold')

ax.ticklabel_format(axis = 'y', style = 'scientific', scilimits = (-2, 2), useMathText = True, useLocale = True)
ax.tick_params(axis = 'both', labelsize = tick_fontsize.get_size())
ax.legend(fontsize = legned_fontsize.get_size(), loc = 'lower right', ncol = 2)

In [None]:
# RMSD
# ---- Bottom-row panel: RMSD of the selected atom group vs. time ----
ppty = 'rmsd'

# Parse the selection "<id|type> <lo> <hi>" into per-run boolean atom masks
# (bounds inclusive). `group_kind` avoids shadowing the builtin `type`.
group_kind, st, en = group.split(" ")
lo, hi = int(st), int(en)
if group_kind == 'id':
    benchmark_mask = np.logical_and(benchmark_id >= lo, benchmark_id <= hi)
    control_mask = np.logical_and(control_id >= lo, control_id <= hi)
    edsr_mask = np.logical_and(edsr_id >= lo, edsr_id <= hi)
elif group_kind == 'type':
    benchmark_mask = np.logical_and(benchmark_atype >= lo, benchmark_atype <= hi)
    control_mask = np.logical_and(control_atype >= lo, control_atype <= hi)
    edsr_mask = np.logical_and(edsr_atype >= lo, edsr_atype <= hi)
else:
    # Without this the masks would be undefined, or silently stale from an earlier cell.
    raise ValueError(f"unsupported group selector {group_kind!r}; expected 'id' or 'type'")

# RMSD of the selected atoms via `compute_RMSD` (from `compute`); the `0`
# argument is presumably the reference frame index -- confirm in `compute`.
benchmark_RMSD = compute_RMSD(benchmark_x[:, benchmark_mask], benchmark_mass[benchmark_mask], 0, benchmark_boundary)
control_RMSD = compute_RMSD(control_x[:, control_mask], control_mass[control_mask], 0, control_boundary)
edsr_RMSD = compute_RMSD(edsr_x[:, edsr_mask], edsr_mass[edsr_mask], 0, edsr_boundary)

# Per-frame absolute deviation of each run's RMSD from the benchmark's.
rmsd_diff_bc = np.fabs(benchmark_RMSD - control_RMSD)
rmsd_diff_bt = np.fabs(benchmark_RMSD - edsr_RMSD)

# ! Picture idx
pic_idx += 1
subgraph_item += 1

# Block averages over `skip` frames (first raw value prepended). The main plot
# below draws the raw per-frame series; these are kept for inspection.
skip = 200
if skip > 1:
    grmsd_diff_bc = np.pad(np.mean(rmsd_diff_bc.reshape(-1, skip), axis = -1), pad_width = (1, 0), mode = 'constant', constant_values = rmsd_diff_bc[0])
    grmsd_diff_bt = np.pad(np.mean(rmsd_diff_bt.reshape(-1, skip), axis = -1), pad_width = (1, 0), mode = 'constant', constant_values = rmsd_diff_bt[0])
    group_times = np.pad(times.reshape(-1, skip)[:, -1], pad_width = (1, 0), mode = 'constant', constant_values = 0)
    skip_brmsd = np.pad(np.mean(benchmark_RMSD.reshape(-1, skip), axis = -1), pad_width = (1, 0), mode = 'constant', constant_values = benchmark_RMSD[0])
    skip_crmsd = np.pad(np.mean(control_RMSD.reshape(-1, skip), axis = -1), pad_width = (1, 0), mode = 'constant', constant_values = control_RMSD[0])
    skip_trmsd = np.pad(np.mean(edsr_RMSD.reshape(-1, skip), axis = -1), pad_width = (1, 0), mode = 'constant', constant_values = edsr_RMSD[0])
else:
    grmsd_diff_bc = rmsd_diff_bc
    grmsd_diff_bt = rmsd_diff_bt
    group_times = times
    # Keep both branches defining the same names.
    skip_brmsd = benchmark_RMSD
    skip_crmsd = control_RMSD
    skip_trmsd = edsr_RMSD


ax = plt.subplot(row, col, pic_idx)
plt.text(*item_pos, chr(subgraph_item), transform = ax.transAxes, fontsize = item_fontsize.get_size(), fontweight = 'bold', va = 'top', ha = 'left')

plt.plot(times, benchmark_RMSD, label = f"{benchmark_name}", linewidth = 2)
plt.plot(times, edsr_RMSD, label = f"{model_name}", linewidth = 1, alpha = 0.7)
plt.plot(times, control_RMSD, label = f"{control_name}", linewidth = 1, alpha = 0.7)

plt.xlabel(f't ($\\mathrm{{{xaxis_time_unit}}}$)', fontproperties = fontsize, fontweight = 'bold')

plt.ticklabel_format(axis = 'y', style = 'scientific', scilimits = (-2, 2), useMathText = True, useLocale = True)
plt.tick_params(axis = 'both', labelsize = tick_fontsize.get_size())
plt.yticks([0., 1, 2, 3])
plt.ylim(bottom = -0.95)
plt.subplots_adjust(**tight_layout_arg)
plt.legend(fontsize = legned_fontsize.get_size(), loc = 'lower right', ncol = 1)

# Inset with the time-averaged MAE of each run's RMSD against the benchmark.
# Both runs use the raw per-frame deviations. Averaging the padded block means
# (grmsd_diff_bc) would weight frame 0 twice and make the two numbers incomparable.
mae_bt = np.mean(rmsd_diff_bt)
mae_bc = np.mean(rmsd_diff_bc)
box1 = ax.get_position()
x0, y0, x1, y1 = box1.x0, box1.y0, box1.x1, box1.y1
aia = fig.add_axes([x0 + 0.055, y0 + 0.011, 0.12, 0.12])
aia.set_title('MAE', fontdict = {'fontsize': title_fontsize.get_size() - 3, 'verticalalignment': 'top', 'fontweight': 'bold'})
aia.axhline(y = mae_bt, color = 'orange', linestyle = '--', linewidth = 2, label = f'{model_name}')
aia.axhline(y = mae_bc, color = 'green', linestyle = '--', linewidth = 1, label = f'{control_name}')
aia.set_xticks([])
aia.set_yticks([round(min(mae_bt, mae_bc), 2), round(max(mae_bt, mae_bc), 2)])
aia.ticklabel_format(axis = 'y', style = 'scientific', scilimits = (-1, 2), useMathText = True, useLocale = True)
aia.tick_params(axis = 'y', labelsize = tick_fontsize.get_size() - 4)
aia.yaxis.get_offset_text().set(size = tick_fontsize.get_size() - 4)
aia.set_ylim(bottom = -0.2, top = max(mae_bt, mae_bc) * 1.1)
aia.legend(fontsize = legned_fontsize.get_size() - 3, loc = 'lower right', ncol = 2)

In [None]:
# Rg
# ---- Bottom-row panel: radius of gyration of the selected atom group ----
ppty = 'Rg'

# Parse the selection "<id|type> <lo> <hi>" into per-run boolean atom masks
# (bounds inclusive). `group_kind` avoids shadowing the builtin `type`.
group_kind, st, en = group.split(" ")
lo, hi = int(st), int(en)
if group_kind == 'id':
    benchmark_mask = np.logical_and(benchmark_id >= lo, benchmark_id <= hi)
    control_mask = np.logical_and(control_id >= lo, control_id <= hi)
    edsr_mask = np.logical_and(edsr_id >= lo, edsr_id <= hi)
elif group_kind == 'type':
    benchmark_mask = np.logical_and(benchmark_atype >= lo, benchmark_atype <= hi)
    control_mask = np.logical_and(control_atype >= lo, control_atype <= hi)
    edsr_mask = np.logical_and(edsr_atype >= lo, edsr_atype <= hi)
else:
    # Without this the masks would be undefined, or silently stale from an earlier cell.
    raise ValueError(f"unsupported group selector {group_kind!r}; expected 'id' or 'type'")

# Radius of gyration of the selected atoms via `compute_RG` (from `compute`).
benchmark_Rg = compute_RG(benchmark_x[:, benchmark_mask], benchmark_mass[benchmark_mask], benchmark_boundary)
control_Rg = compute_RG(control_x[:, control_mask], control_mass[control_mask], control_boundary)
edsr_Rg = compute_RG(edsr_x[:, edsr_mask], edsr_mass[edsr_mask], edsr_boundary)

# Per-frame absolute deviation of each run's Rg from the benchmark's.
Rg_diff_bc = np.fabs(benchmark_Rg - control_Rg)
Rg_diff_bt = np.fabs(benchmark_Rg - edsr_Rg)

# ! Picture idx
pic_idx += 1
subgraph_item += 1

# Block averages over `skip` frames with the first raw value prepended; the
# plot below draws the block-averaged Rg curves.
skip = 200
if skip > 1:
    gRg_diff_bc = np.pad(np.mean(Rg_diff_bc.reshape(-1, skip), axis = -1), pad_width = (1, 0), mode = 'constant', constant_values = Rg_diff_bc[0])
    gRg_diff_bt = np.pad(np.mean(Rg_diff_bt.reshape(-1, skip), axis = -1), pad_width = (1, 0), mode = 'constant', constant_values = Rg_diff_bt[0])
    group_times = np.pad(times.reshape(-1, skip)[:, -1], pad_width = (1, 0), mode = 'constant', constant_values = 0)
    skip_brg = np.pad(np.mean(benchmark_Rg.reshape(-1, skip), axis = -1), pad_width = (1, 0), mode = 'constant', constant_values = benchmark_Rg[0])
    skip_crg = np.pad(np.mean(control_Rg.reshape(-1, skip), axis = -1), pad_width = (1, 0), mode = 'constant', constant_values = control_Rg[0])
    skip_trg = np.pad(np.mean(edsr_Rg.reshape(-1, skip), axis = -1), pad_width = (1, 0), mode = 'constant', constant_values = edsr_Rg[0])
else:
    gRg_diff_bc = Rg_diff_bc
    gRg_diff_bt = Rg_diff_bt
    group_times = times
    # The plot below reads skip_*rg, so they must exist in this branch too.
    skip_brg = benchmark_Rg
    skip_crg = control_Rg
    skip_trg = edsr_Rg

ax = plt.subplot(row, col, pic_idx)
plt.text(*item_pos, chr(subgraph_item), transform = ax.transAxes, fontsize = item_fontsize.get_size(), fontweight = 'bold', va = 'top', ha = 'left')

plt.plot(group_times, skip_brg, label = f"{benchmark_name}", linewidth = 2)
plt.plot(group_times, skip_trg, label = f"{model_name}", linewidth = 1, alpha = 0.7)
plt.plot(group_times, skip_crg, label = f"{control_name}", linewidth = 1, alpha = 0.7)

plt.xlabel(f't ($\\mathrm{{{xaxis_time_unit}}}$)', fontproperties = fontsize, fontweight = 'bold')
plt.ticklabel_format(axis = 'y', style = 'scientific', scilimits = (-1, 2), useMathText = True, useLocale = True)
plt.tick_params(axis = 'both', labelsize = tick_fontsize.get_size())

plt.yticks([10.8, 10.9, 11.0, 11.1, 11.2])
plt.ylim(bottom = 10.8)
plt.subplots_adjust(**tight_layout_arg)
plt.legend(fontsize = legned_fontsize.get_size(), loc = 'lower right', ncol = 1)

# Inset (figure coordinates, anchored at this panel's lower-left corner) with
# the time-averaged MAE of each run's Rg against the benchmark.
mae_bt = np.mean(Rg_diff_bt)
mae_bc = np.mean(Rg_diff_bc)
box = ax.get_position()
x0, y0, x1, y1 = box.x0, box.y0, box.x1, box.y1
aia = fig.add_axes([x0 + 0.055, y0 + 0.011, 0.12, 0.12])
aia.set_title('MAE', fontdict = {'fontsize': title_fontsize.get_size() - 3, 'verticalalignment': 'top', 'fontweight': 'bold'})
aia.axhline(y = mae_bt, color = 'orange', linestyle = '--', linewidth = 2, label = f'{model_name}')
aia.axhline(y = mae_bc, color = 'green', linestyle = '--', linewidth = 1, label = f'{control_name}')
aia.ticklabel_format(axis = 'y', style = 'scientific', scilimits = (-1, 2), useMathText = True, useLocale = True)
# One call covers both axes; the separate y-only call with the same size was redundant.
aia.tick_params(axis = 'both', labelsize = tick_fontsize.get_size() - 4)
aia.set_xticks([])
aia.set_yticks([round(min(mae_bt, mae_bc), 3), round(max(mae_bt, mae_bc), 3)])
# Tuned lower limit, lowered when needed so the smaller MAE line stays visible.
aia.set_ylim(bottom = min(0.035, 0.9 * min(mae_bt, mae_bc)), top = max(mae_bt, mae_bc) * 1.1)
aia.yaxis.get_offset_text().set(size = tick_fontsize.get_size() - 4)
aia.legend(fontsize = legned_fontsize.get_size() - 3, loc = 'lower right', ncol = 2)

In [None]:
# Apply the shared panel spacing to the whole figure one last time, then render it.
fig.subplots_adjust(**tight_layout_arg)
plt.show()