In [None]:
import os

import numpy as np

from lammps import IPyLammps, PyLammps, lammps
from lammps import LMP_STYLE_GLOBAL, LMP_TYPE_SCALAR, LMP_TYPE_VECTOR

from compute import *

model_name = 'EdSr'
control_name = 'MD'
benchmark_name = 'BM'
# Capitalized aliases used as keys below (LAMMPS thermo output column headers):
#   Temp Press PotEng KinEng Enthalpy E_vdwl E_coul E_pair E_bond E_angle E_dihed E_impro
#   E_long E_tail E_mol Ecouple Econserve TotEng Lx Ly Lz
energy_unit = "kcal $\\cdot$ mol$^{-1}$"
press_unit = "ATM"
temperature = "K"
distance_unit = "Angstrom"
time_unit = "fs"
# Number of femtoseconds in one unit of each time scale.
time_scale = {
    "fs": 1,
    "ps": 1000,
    "ns": 1000000,
}
xaxis_time_unit = 'ps'

# Axis labels (quantity + unit) for each thermo keyword and its capitalized header alias.
thermo_style_unit = {
    'temp'     : f"Kelvin ({temperature})",    'Temp'     : f'Kelvin ({temperature})',
    'press'    : f'ATMosphere ({press_unit})', 'Press'    : f'ATMosphere ({press_unit})',
    "pe"       : f'energy ({energy_unit})',    'PotEng'   : f'energy ({energy_unit})',
    "ke"       : f'energy ({energy_unit})',    'KinEng'   : f'energy ({energy_unit})',
    "enthalpy" : f'energy ({energy_unit})',    'Enthalpy' : f'energy ({energy_unit})',
    "evdwl"    : f'energy ({energy_unit})',    'E_vdwl'   : f'energy ({energy_unit})',
    "ecoul"    : f'energy ({energy_unit})',    'E_coul'   : f'energy ({energy_unit})',
    "epair"    : f'energy ({energy_unit})',    'E_pair'   : f'energy ({energy_unit})',
    "ebond"    : f'energy ({energy_unit})',    'E_bond'   : f'energy ({energy_unit})',
    "eangle"   : f'energy ({energy_unit})',    'E_angle'  : f'energy ({energy_unit})',
    "edihed"   : f'energy ({energy_unit})',    'E_dihed'  : f'energy ({energy_unit})',
    "eimp"     : f'energy ({energy_unit})',    'E_impro'  : f'energy ({energy_unit})',
    "elong"    : f'energy ({energy_unit})',    'E_long'   : f'energy ({energy_unit})',
    "etail"    : f'energy ({energy_unit})',    'E_tail'   : f'energy ({energy_unit})',
    "emol"     : f'energy ({energy_unit})',    'E_mol'    : f'energy ({energy_unit})',
    "ecouple"  : f'energy ({energy_unit})',    'Ecouple'  : f'energy ({energy_unit})',
    "econserve": f'energy ({energy_unit})',    'Econserve': f'energy ({energy_unit})',
    "etotal"   : f'energy ({energy_unit})',    'TotEng'   : f'energy ({energy_unit})',
    'lx'       : f'length ({distance_unit})',  'Lx'       : f'length ({distance_unit})',
    'ly'       : f'length ({distance_unit})',  'Ly'       : f'length ({distance_unit})',
    'lz'       : f'length ({distance_unit})',  'Lz'       : f'length ({distance_unit})',
}

# Human-readable plot titles for each thermo keyword / header alias and derived observables.
title_mapping = {
     'temp'     : r'Temperature',                                                'Temp'     : r'Temperature',
     'press'    : r'Pressure',                                                   'Press'    : r'Pressure',
     "pe"       : r'Potential energy',                                           'PotEng'   : r'Potential energy',
     "ke"       : r'Kinetic energy',                                             'KinEng'   : r'Kinetic energy',
     # LAMMPS defines enthalpy as etotal + press*vol (pe + ke alone is etotal).
     "enthalpy" : r'Enthalpy (etotal + press*vol)',                              'Enthalpy' : r'Enthalpy (etotal + press*vol)',
     "evdwl"    : r'Van der Waals pairwise energy',                              'E_vdwl'   : r'Van der Waals pairwise energy',
     "ecoul"    : r'Coulombic pairwise energy',                                  'E_coul'   : r'Coulombic pairwise energy',
     "epair"    : r'Pairwise energy',                                            'E_pair'   : r'Pairwise energy',
     "ebond"    : r'Bond energy',                                                'E_bond'   : r'Bond energy',
     "eangle"   : r'Angle energy',                                               'E_angle'  : r'Angle energy',
     "edihed"   : r'Dihedral energy',                                            'E_dihed'  : r'Dihedral energy',
     "eimp"     : r'Improper energy',                                            'E_impro'  : r'Improper energy',
     "elong"    : r'Long-range kspace energy',                                   'E_long'   : r'Long-range kspace energy',
     "etail"    : r'Van der Waals energy long-range tail correction',            'E_tail'   : r'Van der Waals energy long-range tail correction',
     "emol"     : r'Intramolecular energy',                                      'E_mol'    : r'Intramolecular energy',
     "ecouple"  : r'Cumulative energy change due to thermo/baro statting fixes', 'Ecouple'  : r'Cumulative energy change due to thermo/baro statting fixes',
     "econserve": r'Etotal + ecouple',                                           'Econserve': r'Etotal + ecouple',
     "etotal"   : r'Total energy',                                               'TotEng'   : r'Total energy',
     'lx'       : r'Length of x-axis',                                           'Lx'       : r'Length of x-axis',
     'ly'       : r'Length of y-axis',                                           'Ly'       : r'Length of y-axis',
     'lz'       : r'Length of z-axis',                                           'Lz'       : r'Length of z-axis',
     'Rg'       : r'Radius of gyration',                                         'RG'       : r'Radius of gyration',
     'rmsd'     : r'RMSD',                                                       'RMSD'     : r'RMSD',
     'msd'      : r'MSD',                                                        'MSD'      : r'MSD',
}

In [None]:
# Arguments of the `thermo_style` command issued in create_simulation. The energy and
# box keywords have matching entries in thermo_style_unit / title_mapping.
# 'eimp' is included because an improper_style is configured for the system.
thermo_style = [
    'custom', 'step', 'time', 'spcpu',
    'temp', 'press',
    'pe', 'ke',
    'enthalpy', 'evdwl', 'ecoul', 'epair',
    'ebond', 'eangle', 'edihed', 'eimp',
    'elong', 'etail', 'emol',
    'ecouple', 'econserve', 'etotal',
    'lx', 'ly', 'lz',
]
def create_simulation(timestep = 0.2, cmdargs = None, num_threads: int = 1, ensemble: str = 'nve',
                      data_file: str = '../lmps/data/nve_protein.data') -> "PyLammps":
    """Build and initialize the LAMMPS protein/water system used for energy evaluation.

    Sets up real units with atom_style full. The force field is lj/cut/coul/cut (22.0
    cutoff, shifted), harmonic bonds, cosine/squared angles, fourier dihedrals and
    harmonic impropers, with dielectric 15.0. The data file is read and atom sorting is
    disabled. Groups ``protein`` (atom ids 1-163) and ``water`` (ids > 163) are defined.
    Three group/group computes are registered:

    * ``p2w``: protein-water
    * ``w2w``: water-water
    * ``p2p``: intramolecular protein-protein

    Finally a ``run 0`` initializes the system state.

    Parameters
    ----------
    timestep : float
        Value passed to the LAMMPS ``timestep`` command.
    cmdargs : list[str] | None
        Command-line arguments forwarded to ``lammps(cmdargs=...)``.
    num_threads : int
        When > 1, enables the ``omp`` package with this many threads and the ``omp`` suffix.
    ensemble : str
        Only ``'nve'`` is supported (``fix 1 all nve``).
    data_file : str
        LAMMPS data file to read. The default is the path previously hard-coded here.

    Returns
    -------
    PyLammps
        PyLammps wrapper around the new LAMMPS instance.

    Raises
    ------
    NotImplementedError
        If ``ensemble == 'nvt'``.
    ValueError
        If ``ensemble`` is any other unsupported value.
    """
    # Validate before creating the LAMMPS instance. An unknown ensemble used to fall
    # through silently and leave the system without any integrator fix.
    if ensemble == 'nvt':
        raise NotImplementedError("ensemble 'nvt' is not implemented yet")
    if ensemble != 'nve':
        raise ValueError(f"unsupported ensemble {ensemble!r}; expected 'nve'")

    lmp = lammps(cmdargs=cmdargs)

    MDsimulation = PyLammps(ptr = lmp)

    MDsimulation.enable_cmd_history = True
    if num_threads > 1:
        MDsimulation.package(f"omp {num_threads} neigh yes")
        MDsimulation.suffix('omp')

    MDsimulation.units('real')
    MDsimulation.atom_style('full')

    MDsimulation.pair_style('lj/cut/coul/cut 22.0')
    MDsimulation.bond_style('harmonic')
    MDsimulation.angle_style('cosine/squared')
    MDsimulation.dihedral_style('fourier')
    MDsimulation.improper_style('harmonic')

    MDsimulation.dielectric(15.0)
    MDsimulation.pair_modify('shift yes')
    MDsimulation.special_bonds('lj/coul 0.0 1.0 1.0')

    MDsimulation.read_data(data_file)

    # Turn off atom sorting so per-atom array order stays fixed between runs.
    MDsimulation.atom_modify('sort 0 0.0')

    MDsimulation.group('protein id 1:163')
    MDsimulation.group('water id > 163')

    MDsimulation.neighbor('10.0 bin')

    MDsimulation.compute("p2w protein group/group water pair yes")
    MDsimulation.compute("w2w water group/group water pair yes")
    MDsimulation.compute("p2p protein group/group protein pair yes molecule intra")

    MDsimulation.neigh_modify('every 1 delay 0 check yes')
    MDsimulation.fix('1 all nve')

    MDsimulation.thermo(1)

    MDsimulation.thermo_modify('lost/bond ignore')

    MDsimulation.thermo_style(' '.join(thermo_style))

    MDsimulation.timestep(timestep)

    # Initialize the system state (setup only, no post-run summary).
    MDsimulation.run(0, 'pre yes post no')

    MDsimulation.enable_cmd_history = False

    return MDsimulation

In [None]:
# Trajectory / run configuration
maxIter: int   = 500          # EdSr iteration count; part of the EdSr data directory name
intv   : int   = 20           # frame interval, in multiples of `basis`
basis  : float = 1.0          # base time step (time_unit); frame spacing is intv * basis
start  : int   = 1            # first frame index in the frames file name
end    : int   = 10000        # last frame index in the frames file name
group  : str   = "id 1 163"
mode   : str   = 'benchmark'  # 'EdSr' selects the EdSr data directory layout

# MD setting
num_threads = 16
ensemble = 'nve'
cmdargs = ["-log", "none"] # args: https://docs.lammps.org/latest/Run_options.html

data_dir = (
    f'data/EdSr_nve_basis{basis}_intv{intv}_iter{maxIter}/'
    if mode == 'EdSr'
    else f'data/{mode}_nve_basis{basis}_intv{intv}/'
)
params = [
    data_dir,
    f'frames{start}_{end}.npz',
    'GlobalVariable.npz',
    'rdf_pro2wa.dat',
]

x, v, delta_t, mass, atype, id, boundary, heads, ppties, init_state, last_state = extraction(*params[:-1], filter = None)

# One time stamp per loaded frame (not per `end`), so `times` always lines up with the
# per-frame results stacked later. Converted from time_unit to xaxis_time_unit.
times = np.arange(x.shape[0])*intv*basis * (time_scale[time_unit]/time_scale[xaxis_time_unit])

In [None]:
def Update(Lammps: IPyLammps | PyLammps, Position: np.ndarray, Velocity: np.ndarray, Id: np.ndarray, Atomtype: np.ndarray) -> None:
    """
    Load one stored frame into the LAMMPS instance and re-run setup on it.

    Writes Position, Velocity, Id and Atomtype in place into the per-atom arrays
    returned by ``Lammps.lmp.numpy.extract_atom`` for 'x', 'v', 'id' and 'type'.
    Then issues ``run 0`` with 'pre yes post no', so the caller can read compute
    values for that frame.

    Raises AssertionError if the atom-id array differs before and after the run,
    i.e. if LAMMPS reordered the atoms.
    """
    lammps_np = Lammps.lmp.numpy

    # Grab all four writable per-atom buffers first, then fill them from the frame.
    buffers = [lammps_np.extract_atom(name) for name in ('x', 'v', 'id', 'type')]
    for buffer, frame_values in zip(buffers, (Position, Velocity, Id, Atomtype)):
        buffer[:] = frame_values

    ids_before_run = lammps_np.extract_atom('id').copy()
    Lammps.run(0, 'pre yes post no')
    ids_after_run = lammps_np.extract_atom('id')

    assert (ids_before_run == ids_after_run).all(), 'array has been changed !!!!!'

In [None]:
# Build the LAMMPS instance once (timestep = intv * basis); frames are loaded into it later.
simulation = create_simulation(
    intv * basis,
    cmdargs=[*cmdargs],
    num_threads=num_threads,
    ensemble=ensemble,
)
llnp = simulation.lmp.numpy

In [None]:
# Per-frame scalars from the 'p2w' and 'w2w' group/group computes.
protein2water, water2water = [], []

In [None]:
# Evaluate the protein-water ('p2w') and water-water ('w2w') group/group computes for
# every stored frame. The accumulators are created here so re-running this cell starts
# fresh instead of appending duplicates (or failing once they have become arrays).
protein2water = []
water2water = []
# strict=True: positions and velocities must have the same number of frames.
for frame_x, frame_v in zip(x, v, strict=True):
    Update(simulation, frame_x, frame_v, id, atype)
    protein2water.append(llnp.extract_compute('p2w', LMP_STYLE_GLOBAL, LMP_TYPE_SCALAR))
    water2water.append(llnp.extract_compute('w2w', LMP_STYLE_GLOBAL, LMP_TYPE_SCALAR))

In [None]:
# Freeze the per-frame lists into NumPy arrays for stacking and saving.
protein2water, water2water = map(np.array, (protein2water, water2water))

In [None]:
# Save a (3, n_frames) array:
#   row 0: time (xaxis_time_unit)
#   row 1: protein-water compute values
#   row 2: water-water compute values
if not (len(times) == len(protein2water) == len(water2water)):
    raise ValueError(
        f"length mismatch: times={len(times)}, protein2water={len(protein2water)}, "
        f"water2water={len(water2water)}"
    )
data = np.stack([times, protein2water, water2water], axis=0)
# params[0] already ends with '/', so join instead of formatting to avoid a doubled separator.
np.save(os.path.join(params[0], 'compute.npy'), data)