In [None]:
# %matplotlib tk
import matplotlib
matplotlib.use('Qt5Agg')  # 使用 Qt5 后端
import matplotlib.pyplot as plt
from matplotlib import font_manager
from concurrent.futures import ThreadPoolExecutor
import numpy as np

from einops import rearrange, repeat

from compute import *

# hyper-parameter: names shown in legends
model_name = 'EdSr'
control_name = 'MD'
benchmark_name = 'BM'
""" Temp Press PotEng KinEng Enthalpy E_vdwl E_coul E_pair E_bond E_angle E_dihed E_long E_tail E_mol Ecouple Econserve TotEng Lx Ly Lz"""

# Unit strings embedded in axis labels
energy_unit = "kcal $\\cdot$ mol$^{-1}$"
press_unit = "ATM"
temperature = "K"
distance_unit = "Angstrom"
time_unit = "fs"

# Conversion factors between time units, expressed in fs
time_scale = dict(fs=1, ps=1000, ns=1000000)
xaxis_time_unit = 'ps'

# One axis label per physical quantity; every thermo keyword below maps to
# exactly one of these four labels.
_temp_label = f"Kelvin ({temperature})"
_press_label = f'ATMosphere ({press_unit})'
_energy_label = f'energy ({energy_unit})'
_length_label = f'length ({distance_unit})'

# (label, keys) groups, listed in the order the keys must appear in the dict.
# Keys come as keyword/header pairs, except "eimp" which has no header twin.
_thermo_keys_by_label = (
    (_temp_label, ('temp', 'Temp')),
    (_press_label, ('press', 'Press')),
    (_energy_label, (
        "pe", 'PotEng',
        "ke", 'KinEng',
        "enthalpy", 'Enthalpy',
        "evdwl", 'E_vdwl',
        "ecoul", 'E_coul',
        "epair", 'E_pair',
        "ebond", 'E_bond',
        "eangle", 'E_angle',
        "edihed", 'E_dihed',
        "eimp",
        "elong", 'E_long',
        "etail", 'E_tail',
        "emol", 'E_mol',
        "ecouple", 'Ecouple',
        "econserve", 'Econserve',
        "etotal", 'TotEng',
    )),
    (_length_label, ('lx', 'Lx', 'ly', 'Ly', 'lz', 'Lz')),
)

# thermo keyword / column header -> axis label with unit
thermo_style_unit = {
    key: label
    for label, keys in _thermo_keys_by_label
    for key in keys
}

# thermo keyword / column header (plus a few analysis quantities) -> plot title.
# Each row lists the lower-case keyword and its column-header twin; both must
# map to the same title.
# NOTE: enthalpy is defined as etotal + press*vol, whereas "pe + ke" is the
# definition of etotal, so the two must not share the "Total energy" title.
title_mapping = {
    'temp'     : 'Temperature',                                                'Temp'     : 'Temperature',
    'press'    : 'Pressure',                                                   'Press'    : 'Pressure',
    "pe"       : 'Potential energy',                                           'PotEng'   : 'Potential energy',
    "ke"       : 'Kinetic energy',                                             'KinEng'   : 'Kinetic energy',
    "enthalpy" : 'Enthalpy (etotal + press*vol)',                              'Enthalpy' : 'Enthalpy (etotal + press*vol)',
    "evdwl"    : 'Van der Waals pairwise energy',                              'E_vdwl'   : 'Van der Waals pairwise energy',
    "ecoul"    : 'Coulombic pairwise energy',                                  'E_coul'   : 'Coulombic pairwise energy',
    "epair"    : 'Pairwise energy',                                            'E_pair'   : 'Pairwise energy',
    "ebond"    : 'Bond energy',                                                'E_bond'   : 'Bond energy',
    "eangle"   : 'Angle energy',                                               'E_angle'  : 'Angle energy',
    "edihed"   : 'Dihedral energy',                                            'E_dihed'  : 'Dihedral energy',
    "eimp"     : 'Improper energy',
    "elong"    : 'Long-range kspace energy',                                   'E_long'   : 'Long-range kspace energy',
    "etail"    : 'Van der Waals energy long-range tail correction',            'E_tail'   : 'Van der Waals energy long-range tail correction',
    "emol"     : 'Intramolecular energy',                                      'E_mol'    : 'Intramolecular energy',
    "ecouple"  : 'Cumulative energy change due to thermo/baro statting fixes', 'Ecouple'  : 'Cumulative energy change due to thermo/baro statting fixes',
    "econserve": 'Etotal + ecouple',                                           'Econserve': 'Etotal + ecouple',
    "etotal"   : 'Total energy',                                               'TotEng'   : 'Total energy',
    'lx'       : 'Length of x-axis',                                           'Lx'       : 'Length of x-axis',
    'ly'       : 'Length of y-axis',                                           'Ly'       : 'Length of y-axis',
    'lz'       : 'Length of z-axis',                                           'Lz'       : 'Length of z-axis',
    'Rg'       : 'Radius of gyration',                                         'RG'       : 'Radius of gyration',
    'rmsd'     : 'RMSD',                                                       'RMSD'     : 'RMSD',
    'msd'      : 'MSD',                                                        'MSD'      : 'MSD',
}



# Never abbreviate arrays with "..." when they are printed.
np.set_printoptions(threshold=np.inf)

# Font presets; the plotting cells read only their sizes via .get_size().
fontsize = font_manager.FontProperties(size=12)        # axis labels
tick_fontsize = font_manager.FontProperties(size=10)   # tick labels
title_fontsize = font_manager.FontProperties(size=13)  # per-panel titles
# The misspelled name is kept on purpose: the plotting cells reference it as-is.
legned_fontsize = font_manager.FontProperties(size=8)  # legends
item_fontsize = font_manager.FontProperties(size=12)   # panel letters

# Keyword arguments forwarded to plt.subplots_adjust(); the plotting cells
# skip the adjustment when this is None.
tight_layout_arg = {
    'top': 0.921,
    'bottom': 0.116,
    'left': 0.052,
    'right': 0.998,
    'hspace': 0.222,
    'wspace': 0.102,
}

# Subplot grid dimensions
row = 2
col = 2

# One shared figure; its size scales with the grid (10 in wide, 20 in tall per cell).
fig = plt.figure(figsize=[10 * col, 20 * row], dpi=256)

# Running counters, each bumped by a plotting cell before it draws:
pic_idx = 0                   # 1-based subplot slot of the last drawn panel
subgraph_item = ord('g') - 1  # code point of the last used panel letter, so the first panel is 'g'
item_pos = (-0.05, 1.1)       # panel-letter position in axes coordinates

## $$\textit{ 30.0\ fs }$$

In [None]:
# intv30 — run parameters for this section
maxIter: int   = 500          # iteration count encoded in the taylor data directory name
intv   : int   = 30           # frame interval, in multiples of `basis`
basis  : float = 1.0          # base step length (presumably in `time_unit` — confirm)
start  : int   = 1            # first frame index in the frames file name
end    : int   = 10000        # last frame index in the frames file name
group  : str   = "id 1 163"   # "<id|type> <first> <last>" selector for an atom group

# Each list is [data directory, frames file, global-variable file, RDF file].
# The first three are passed to extraction(); the last one is read separately.
benchmark_params = [
    f'data/benchmark_nve_basis{basis}_intv{intv}/',
    f'frames{start}_{end}.npz',
    'GlobalVariable.npz',
    'rdf_pro2wa.dat',
]
taylor_params = [
    # Derive the suffix from maxIter so the path cannot drift from the
    # declared iteration count (it used to be hard-coded as "iter500").
    f'data/taylor_nve_basis{basis}_intv{intv}_iter{maxIter}/',
    f'frames{start}_{end}.npz',
    'GlobalVariable.npz',
    'rdf_pro2wa.dat',
]

# mass shape: (natoms,), id shape: (natoms,),  x shape: (ntrajs, natoms, 3), v shape: (ntrajs, natoms, 3)
# Run both extraction() calls on a thread pool. The trailing RDF file name of
# each parameter list is not an extraction() argument, hence the [:-1] slice.
# Leaving the `with` block waits for both jobs to finish.
with ThreadPoolExecutor(max_workers=3) as executor:
    futures = [
        executor.submit(extraction, *params[:-1], filter=None)
        for params in (benchmark_params, taylor_params)
    ]

# futures[0] -> benchmark run, futures[1] -> taylor run; each result is an 11-tuple.
(
    benchmark_x, benchmark_v, benchmark_delta_t,
    benchmark_mass, benchmark_atype, benchmark_id,
    benchmark_boundary, benchmark_heads, benchmark_ppties,
    benchmark_init_state, benchmark_last_state,
) = futures[0].result()
(
    taylor_x, taylor_v, taylor_delta_t,
    taylor_mass, taylor_atype, taylor_id,
    taylor_boundary, taylor_heads, taylor_ppties,
    taylor_init_state, taylor_last_state,
) = futures[1].result()

# One entry per index 0..end-1, scaled by the frame interval and base step.
times = np.arange(end) * intv * basis

In [None]:
def read(path, file) -> np.ndarray:
    """Load a whitespace-separated text table holding three values per record.

    Parameters
    ----------
    path : str
        Directory containing the file; a trailing separator is optional.
    file : str
        Name of the file inside ``path``.

    Returns
    -------
    np.ndarray
        Float array of shape ``(3, npoints)``; row ``i`` holds the i-th value
        of every record. An empty file yields shape ``(3, 0)``.

    Raises
    ------
    ValueError
        If a token is not a valid float, or if the number of values is not a
        multiple of three.
    """
    import os

    with open(os.path.join(path, file), 'r') as f:
        # split() without an argument treats any run of spaces, tabs and
        # newlines as one separator, so repeated blanks or trailing whitespace
        # never produce empty tokens that float() would reject.
        tokens = f.read().split()

    values = np.array(list(map(float, tokens)), dtype=float)
    # Records are stored one after another: (npoints * 3,) -> (npoints, 3) -> (3, npoints)
    return values.reshape(-1, 3).T

# Each RDF table comes back as (3, npoints); the plot below uses row 0 as x and row 1 as y.
benchmark_rdf = read(benchmark_params[0], benchmark_params[-1])
taylor_rdf = read(taylor_params[0], taylor_params[-1])

# Point-wise absolute difference between the two plotted curves
absve_bt = np.fabs(benchmark_rdf[1] - taylor_rdf[1])

# ! Picture idx
pic_idx += 1
subgraph_item += 1

# Use the running slot counter (1 on a fresh run) rather than a hard-coded 1,
# matching how the velocity panel picks its slot.
ax = plt.subplot(row, col, pic_idx)
# Panel letter just outside the top-left corner of the axes
plt.text(*item_pos, chr(subgraph_item), transform = ax.transAxes, fontsize = item_fontsize.get_size(), fontweight = 'bold', va = 'top', ha = 'left')

# ! title of column
plt.text(0.5, 1.15, "Radial distribution function", transform = ax.transAxes, fontsize = title_fontsize.get_size(), fontweight = 'bold', va = 'center', ha = 'center')

# Raw string: "\A" is not a valid escape sequence in a normal string literal.
plt.xlabel(r'radius ($\mathrm{\AA}$)', fontproperties = fontsize, fontweight = 'bold')
plt.plot(benchmark_rdf[0], benchmark_rdf[1], label = benchmark_name, linewidth = 2)
plt.plot(taylor_rdf[0], taylor_rdf[1], label = model_name, linewidth = 1, alpha = 0.7)

if tight_layout_arg is not None:
    plt.subplots_adjust(**tight_layout_arg)
plt.legend(fontsize = legned_fontsize.get_size(), loc = 'lower right', ncol = 1)

In [None]:
# velocity distribution of all atoms
# Each state is a 5-tuple; by the names used here: positions, velocities,
# atom types, atom ids, masses (TODO confirm against extraction()).
bix, biv, bitype, biid, bim  = benchmark_init_state
blx, blv, bltype, blid, blm  = benchmark_last_state

tix, tiv, titype, tiid, tim  = taylor_init_state
tlx, tlv, tltype, tlid, tlm  = taylor_last_state

# `group` is "<id|type> <first> <last>". The selector is stored as
# `group_kind` so the builtin `type` is not shadowed for the rest of the kernel.
group_kind, st, en = group.split(" ")
if group_kind == 'id':
    mask = np.logical_and(biid >= int(st), biid <= int(en))
elif group_kind == 'type':
    mask = np.logical_and(bitype >= int(st), bitype <= int(en))
else:
    # Fail here instead of leaving `mask` undefined for whichever code uses it.
    raise ValueError(f"unsupported group selector {group_kind!r}; expected 'id' or 'type'")

# Per-atom speed (Euclidean norm over the last axis) in the final state
blv_norm = np.sqrt(np.sum(blv**2, axis = -1))
tlv_norm = np.sqrt(np.sum(tlv**2, axis = -1))

# Common histogram range covering the velocity components of both runs
axis_left, axis_right = min(blv.min(), tlv.min()), max(blv.max(), tlv.max())

# Same, for the speeds
norm_axis_left, norm_axis_right = min(blv_norm.min(), tlv_norm.min()), max(blv_norm.max(), tlv_norm.max())

# ! Picture idx
pic_idx += 1
subgraph_item += 1

ax = plt.subplot(row, col, pic_idx)
# Panel letter, drawn with the dedicated item font size like the RDF panel
plt.text(*item_pos, chr(subgraph_item), transform = ax.transAxes, fontsize = item_fontsize.get_size(), fontweight = 'bold', va = 'top', ha = 'left')

# ! title of column
plt.text(0.5, 1.15, "Velocity distribution", transform = ax.transAxes, fontsize = title_fontsize.get_size(), fontweight = 'bold', va = 'center', ha = 'center')

# Side-by-side density histograms of all velocity components in the last
# state, binned over the range shared by both runs.
plt.hist(
    [blv.reshape(-1), tlv.reshape(-1)],
    bins = 20,
    density = True,
    histtype = 'bar',
    range = (axis_left, axis_right),
    edgecolor = 'white',
    color = ['b', 'orange'],
    align = 'mid',
    label = [benchmark_name, model_name],
    rwidth = 1.0,
)

# Raw string: "\c" is not a valid escape sequence in a normal string literal.
plt.xlabel(r'velocity ($\mathrm{\AA \cdot fs^{-1}}$)', fontproperties = fontsize, fontweight = 'bold')

plt.tick_params(axis = 'y', labelsize = tick_fontsize.get_size() - 1)
# Fixed view window and ticks chosen for this data set
plt.xticks([-5e-3, 0, 5e-3])
plt.xlim(left = -0.0057, right = 0.0057)
plt.ylim(top = 260)

if tight_layout_arg is not None:
    plt.subplots_adjust(**tight_layout_arg)
plt.legend(fontsize = legned_fontsize.get_size(), loc = 'upper left', ncol = 1)

In [None]:
# Apply the shared margins once more before display. Guarded like the plotting
# cells, because unpacking None with ** would raise a TypeError.
if tight_layout_arg is not None:
    plt.subplots_adjust(**tight_layout_arg)
plt.show()