In [1]:
from ase.io import iread 

# Reference dataset: extended-XYZ frames carrying energies, forces and stress/virial data.
dataset_path = '/home/emi/ML/simpleGNN/simplegnn/schnet_test/pressure_test/dataset_with_stress_energy_forces.extxyz'

# Only the first frame is used. `read(..., index=0)` parses it and closes the file,
# whereas a half-consumed `iread` generator keeps the file handle open.
from ase.io import read
sample = read(dataset_path, index=0, format='extxyz')

sample

Atoms(symbols='O140Si70', pbc=True, cell=[12.5, 12.5, 12.5], calculator=SinglePointCalculator(...))

In [2]:
from simplegnn.schnet_calculator import SchNetCalculator

In [3]:
from simplegnn.schnet import SchNetModel
import torch
# SchNet hyperparameters. They must match the training run of the checkpoint below:
# layer-size mismatches make load_state_dict raise, and a different cutoff may
# load without error but give wrong predictions (NOTE(review): confirm cutoff is
# not stored in the state dict).
cutoff = 5.0
num_gaussians = 60
hidden_dim = 100
num_interactions = 3
num_filters = 100
model = SchNetModel(hidden_dim=hidden_dim, num_gaussians=num_gaussians,
                    num_filters=num_filters, num_interactions=num_interactions, cutoff=cutoff)


model_path = '/home/emi/ML/simpleGNN/simplegnn/schnet_test/pressure_test/model_schnet_torch_full_stepLR.pth'

# Deserialize onto the CPU. The freshly built model lives on the CPU, and this way
# a checkpoint saved from CUDA tensors can be loaded on a machine without a GPU.
model.load_state_dict(torch.load(model_path, map_location='cpu', weights_only=True))

<All keys matched successfully>

In [4]:
# Keep the reference results shipped with the dataset (held by the reader's
# calculator) before replacing it, so model predictions can be compared against them.
reference_results = dict(sample.calc.results)
# Use the GPU when present; otherwise fall back to the CPU instead of failing.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
sample.calc = SchNetCalculator(model=model, device=device, cutoff=cutoff)

In [5]:
# Forces from the SchNet calculator attached above: one row of (x, y, z) per atom.
predicted_forces = sample.get_forces()
predicted_forces

array([[-1.55686665e+00,  1.60748982e+00,  2.77963161e+00],
       [-7.42225647e+00,  8.65026534e-01,  8.05118799e-01],
       [ 4.62892652e-01, -1.48293877e+00, -3.82932854e+00],
       [-6.73227692e+00,  3.01396656e+00,  1.06998234e+01],
       [ 2.57364130e+00,  3.98089933e+00, -4.17519951e+00],
       [ 3.12799788e+00,  9.06673670e-02,  2.90984583e+00],
       [ 3.32760143e+00,  2.25375605e+00,  2.54430151e+00],
       [-2.79948378e+00,  3.92904520e+00, -2.71374178e+00],
       [-3.30022407e+00,  6.28093338e+00,  1.08550596e+00],
       [-1.03351736e+00, -3.12837929e-01,  2.43773913e+00],
       [-3.88828492e+00,  6.61169887e-01,  3.22870684e+00],
       [-3.90161395e-01,  2.01483870e+00, -5.73840737e-01],
       [-2.67237473e+00, -2.39814758e-01,  4.66731548e-01],
       [ 3.67334914e+00,  3.29642534e-01,  3.22386599e+00],
       [ 9.40057874e-01,  1.22271562e+00, -2.76904202e+00],
       [-4.31134748e+00, -4.48660314e-01,  2.54892826e+00],
       [-1.41597199e+00,  1.68830180e+00

In [6]:
# Everything the calculator has cached so far (the output shows at least energy and forces).
cached_results = sample.calc.results
cached_results

{'energy': -792.6171875,
 'forces': array([[-1.55686665e+00,  1.60748982e+00,  2.77963161e+00],
        [-7.42225647e+00,  8.65026534e-01,  8.05118799e-01],
        [ 4.62892652e-01, -1.48293877e+00, -3.82932854e+00],
        [-6.73227692e+00,  3.01396656e+00,  1.06998234e+01],
        [ 2.57364130e+00,  3.98089933e+00, -4.17519951e+00],
        [ 3.12799788e+00,  9.06673670e-02,  2.90984583e+00],
        [ 3.32760143e+00,  2.25375605e+00,  2.54430151e+00],
        [-2.79948378e+00,  3.92904520e+00, -2.71374178e+00],
        [-3.30022407e+00,  6.28093338e+00,  1.08550596e+00],
        [-1.03351736e+00, -3.12837929e-01,  2.43773913e+00],
        [-3.88828492e+00,  6.61169887e-01,  3.22870684e+00],
        [-3.90161395e-01,  2.01483870e+00, -5.73840737e-01],
        [-2.67237473e+00, -2.39814758e-01,  4.66731548e-01],
        [ 3.67334914e+00,  3.29642534e-01,  3.22386599e+00],
        [ 9.40057874e-01,  1.22271562e+00, -2.76904202e+00],
        [-4.31134748e+00, -4.48660314e-01,  2.5489

In [7]:
# Per-frame metadata read from the extxyz header line; for this frame it holds
# source_dir, virial_eV_s6, virial_eV_tensor and pressure_kB_s6.
frame_info = sample.info
frame_info

{'source_dir': 'run_0',
 'virial_eV_s6': array([842.4187 , 783.21333, 661.19142, 109.84022, -80.41969,  10.41788]),
 'virial_eV_tensor': array([[842.4187 , 109.84022,  10.41788],
        [109.84022, 783.21333, -80.41969],
        [ 10.41788, -80.41969, 661.19142]]),
 'pressure_kB_s6': array([691.04852, 642.48148, 542.3851 ,  90.10356, -65.96947,   8.54594])}

In [8]:
# Predicted stress in ASE's 6-component Voigt form (xx, yy, zz, yz, xz, xy), in eV/Å^3.
predicted_stress = sample.get_stress()
predicted_stress

array([ 0.49474493,  0.47504511,  0.38069385, -0.03512139,  0.01345136,
        0.05347294])

In [9]:
# Convert the predicted stress from eV/Å^3 to kbar using ASE's unit constants
# instead of a bare literal (1 eV/Å^3 ≈ 1602.18 kbar).
# NOTE(review): the result is in ASE Voigt order (xx, yy, zz, yz, xz, xy), while the
# *_s6 arrays in sample.info follow (xx, yy, zz, xy, yz, zx) judging by
# virial_eV_tensor, so the shear components are permuted when comparing.
# NOTE(review): in ASE's convention compressive stress is negative, yet these values
# and pressure_kB_s6 are both positive — check the sign used by SchNetCalculator.
from ase import units
EV_PER_A3_TO_KBAR = 1.0 / (1e3 * units.bar)
sample.get_stress() * EV_PER_A3_TO_KBAR

array([792.66876133, 761.10618296, 609.93874808, -56.27069226,
        21.55145943,  85.67309482])

In [10]:
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution,Stationary
from ase.md.nvtberendsen import NVTBerendsen
from ase.md import MDLogger
from ase import units
from time import perf_counter

In [11]:
# --- MD input parameters --------------------------------------------------
temperature = 2000      # target temperature [K]
time_step = 1.0         # integration time step [fs]
taut = 1.0              # Berendsen coupling time constant [fs]
# NOTE(review): taut equal to the time step makes the Berendsen thermostat act
# almost like velocity rescaling every step; ~100 fs is more usual when the
# dynamics themselves (not just fast thermalisation) matter — confirm intent.
num_md_steps = 5_000    # total number of MD steps
num_interval = 10       # stride (in steps) for printing, logging and trajectory frames

In [12]:
# One GPa expressed in ASE's internal pressure unit (eV/Å^3).
gpa_in_ev_per_a3 = units.GPa
gpa_in_ev_per_a3

0.006241509125883258

In [13]:
# Inverse factor: GPa per eV/Å^3, i.e. what ASE stresses are divided by to get GPa.
ev_per_a3_in_gpa = 1 / units.GPa
ev_per_a3_in_gpa

160.21766208

In [14]:
print("temperature = ", temperature)

# Output files for this run. The base name stays fixed because the viewer cell
# reads this trajectory back.
output_filename = "output"
log_filename = output_filename + ".log"
traj_filename = output_filename + ".traj"
print("log_filename = ", log_filename)
print("traj_filename = ", traj_filename)

# Run MD on a copy so `sample` keeps its own calculator and state.
atoms = sample.copy()
# Use the GPU when present; otherwise fall back to the CPU instead of failing.
md_device = 'cuda' if torch.cuda.is_available() else 'cpu'
atoms.calc = SchNetCalculator(model=model, device=md_device, cutoff=cutoff)

# Draw initial momenta for the target temperature (force_temp rescales them to
# exactly `temperature`), then remove the centre-of-mass drift.
# The seeded generator makes the initial velocities reproducible.
import numpy as np
MD_SEED = 42
MaxwellBoltzmannDistribution(atoms, temperature_K=temperature, force_temp=True,
                             rng=np.random.default_rng(MD_SEED))
Stationary(atoms)

# NVT dynamics with a Berendsen thermostat; a trajectory frame is written
# every `num_interval` steps.
dyn = NVTBerendsen(atoms, time_step * units.fs,
                   temperature_K=temperature, taut=taut * units.fs,
                   loginterval=num_interval, trajectory=traj_filename)


# Print statements
def print_dyn():
    imd = dyn.get_number_of_steps()
    etot  = atoms.get_total_energy()
    temp_K = atoms.get_temperature()
    stress = atoms.get_stress(include_ideal_gas=True)/units.GPa
    stress_ave = (stress[0]+stress[1]+stress[2])/3.0
    elapsed_time = perf_counter() - start_time
    print(f"  {imd: >3}   {etot:.3f}    {temp_K:.2f}    {stress_ave:.2f}  {stress[0]:.2f}  {stress[1]:.2f}  {stress[2]:.2f}  {stress[3]:.2f}  {stress[4]:.2f}  {stress[5]:.2f}    {elapsed_time:.3f}")


# Console progress and the log file share the trajectory's stride.
dyn.attach(print_dyn, interval=num_interval)
# mode="w": start a fresh log each time the cell runs (the trajectory is rewritten
# too), instead of appending to logs left over from earlier runs.
md_logger = MDLogger(dyn, atoms, log_filename, header=True, stress=True, peratom=True, mode="w")
dyn.attach(md_logger, interval=num_interval)

# Now run the dynamics. Afterwards, close the log and trajectory so both files
# are complete on disk before they are read back.
start_time = perf_counter()
print("    imd     Etot(eV)    T(K)    stress(mean,xx,yy,zz,yz,xz,xy)(GPa)  elapsed_time(sec)")
try:
    dyn.run(num_md_steps)
finally:
    md_logger.close()
    dyn.close()

temperature =  2000
log_filename =  output.log
traj_filename =  output.traj
    imd     Etot(eV)    T(K)    stress(mean,xx,yy,zz,yz,xz,xy)(GPa)  elapsed_time(sec)
    0   -738.328    2000.00    69.15  76.16  72.94  58.36  -5.61  1.93  8.47    0.029
   10   -882.853    3286.59    42.61  41.92  47.73  38.19  -0.19  3.26  5.35    0.233
   20   -1020.766    2362.77    34.43  32.81  38.95  31.52  -1.74  4.17  2.00    0.437
   30   -1107.826    2290.12    31.54  29.11  35.63  29.89  -1.35  3.88  -1.27    0.641
   40   -1177.723    2200.01    30.78  28.43  35.48  28.44  0.28  5.40  -2.37    0.844
   50   -1221.462    2127.34    28.03  26.15  31.10  26.85  1.07  4.04  0.96    1.048
   60   -1250.626    2108.49    23.92  22.32  25.15  24.28  1.70  2.67  1.65    1.250
   70   -1276.639    2090.96    21.85  20.10  23.93  21.52  0.33  1.90  2.09    1.454
   80   -1297.581    2046.61    23.26  20.87  25.90  23.00  0.14  0.70  2.93    1.658
   90   -1310.221    2046.37    23.35  21.99  23.42  24.64 

True

In [15]:
import nglview as nv
from ase.io import read

# Read back every saved frame of the MD run (for long runs, use a stride such as '::10').
frames = read(traj_filename, index=':')
view = nv.show_asetraj(frames)
view.add_unitcell()    # draw the periodic cell
view.add_spacefill()   # representation is a matter of taste (e.g. ball+stick)
view



NGLWidget(max_frame=500)