In [1]:
from ocpmodels.common.relaxation.ase_utils import OCPCalculator
import os
import torch
from ase.db import connect
from tqdm import tqdm
import numpy as np
import time

In [2]:
def get_pt_files(directory, suffix=".pt"):
    """Recursively collect checkpoint files under ``directory``.

    Args:
        directory: Root directory to search; every sub-directory is visited
            via ``os.walk``.
        suffix: File-name ending to match. Defaults to ``".pt"`` (PyTorch
            checkpoints), which is what the original hard-coded.

    Returns:
        A sorted list of paths (``os.path.join(root, name)``) for every file
        whose name ends with ``suffix``. Sorting makes the order independent of
        the filesystem's directory-listing order, so repeated runs (and other
        machines) visit the checkpoints in the same sequence. An empty list is
        returned when nothing matches; because ``os.walk`` ignores listing
        errors by default, a non-existent ``directory`` also yields ``[]``.
    """
    return sorted(
        os.path.join(root, name)
        for root, _dirs, files in os.walk(directory)
        for name in files
        if name.endswith(suffix)
    )

In [3]:
db = connect("1model_compare.db")
# Turn every row of the reference database into an ase.Atoms object.
atoms_l = list(map(lambda row: row.toatoms(), db.select()))
# Capture the forces stored with each structure up front; these are the
# reference labels (DFT, per the variable name) every model is compared to.
f_dft_l = list(map(lambda atoms: atoms.get_forces(), atoms_l))

In [4]:
directory = "web_ml_ocp/checkpoint"
pt_files = get_pt_files(directory)

# Label normalization injected into every checkpoint config below, overriding
# whatever "normalizer" entry the checkpoint carries. The force std reuses the
# energy std, as in the original run.
normalizer = {'normalize_labels': True,
              'target_mean': -0.7554450631141663,
              'target_std': 2.887317180633545,
              'grad_target_mean': 0.0,
              'grad_target_std': 2.887317180633545}

# Replacement LR schedule for configs that still use the legacy
# 'warmup_epochs' key. NOTE(review): values copied verbatim from the original
# run; they presumably only matter for trainer/scheduler construction, not for
# the force predictions measured here.
WARMUP_STEPS = 348786
LR_MILESTONES = [523179, 871966, 1220752]

for file in pt_files:
    # torch.load unpickles the file -- only run this on checkpoints you trust.
    # Only the config is kept: the file path is handed to OCPCalculator below,
    # which (presumably) reads the weights itself, so holding the whole
    # checkpoint dict here would just duplicate it in memory.
    config = torch.load(file, map_location=torch.device("cpu"))["config"]
    config["normalizer"] = dict(normalizer)  # per-model copy of the shared dict
    config["amp"] = False

    optim = config.get("optim", {})  # tolerate configs without an optim section
    if "warmup_epochs" in optim:
        del optim["warmup_epochs"]
        optim["warmup_steps"] = WARMUP_STEPS
        optim["lr_milestones"] = list(LR_MILESTONES)

    calc = OCPCalculator(config_yml=config, checkpoint_path=file, cpu=False)
    # `.module` assumes the trainer wraps the model (DataParallel-style) -- true
    # for every checkpoint in the results below.
    params = calc.trainer.model.module.num_params

    # Predict on copies so the shared reference structures in atoms_l are not
    # left carrying this model's calculator after the cell finishes.
    # Copying happens outside the timed region.
    atoms_ml_l = [atoms.copy() for atoms in atoms_l]
    for atoms_ml in atoms_ml_l:
        atoms_ml.calc = calc

    # NOTE(review): the timed region includes any one-time warm-up cost incurred
    # by the first prediction of each model.
    time_ini = time.perf_counter()
    f_ml_l = [atoms_ml.get_forces() for atoms_ml in atoms_ml_l]
    time_used = time.perf_counter() - time_ini

    # Per-atom force error |F_ref - F_ML| (Euclidean norm over the last axis),
    # pooled over all atoms of all structures.
    fe_l = np.concatenate([np.linalg.norm(f_dft - f_ml, axis=1)
                           for f_dft, f_ml in zip(f_dft_l, f_ml_l)])

    print(os.path.basename(file), np.max(fe_l), np.mean(fe_l), params, "%.3f" % time_used)

  return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),


spinconv_force_centric_2M.pt 0.34110718189396766 0.02317709666321024 8473371 3.165
scn_all_md_s2ef.pt 0.22927953601253248 0.02720846583820647 168921090 25.198
gemnet_t_direct_h512_2M.pt 0.27900381109225014 0.027173708963477143 31671825 1.301
schnet_200k.pt 0.04999883437251112 0.013517339493401776 5704193 0.615
escn_l6_m2_lay12_all_md_s2ef.pt 0.19775642180061 0.025513182181898287 51844608 5.469
dimenetpp_200k.pt 0.2529661740943477 0.018501487097748626 1810182 3.035
escn_l6_m2_lay12_2M_s2ef.pt 0.3085905253222206 0.023051683920072036 51844608 5.736




eq2_83M_2M.pt 0.28574598251964933 0.02526135544242244 83164802 5.625
escn_l4_m2_lay12_2M_s2ef.pt 0.2165425177967108 0.023020795310581675 36112896 5.326
scn_t1_b1_s2ef_2M.pt 0.221690568137876 0.021813690721458803 123561474 11.453
schnet_20M.pt 0.6935090640721074 0.04035769935060489 9088513 0.727
painn_h512_s2ef_all.pt 0.28298535644036255 0.02176883813218182 20073481 0.828
scn_t4_b2_s2ef_2M.pt 0.3635444927499501 0.02452655518997509 126710274 17.211




eq2_31M_ec4_allmd.pt 0.18179591022429295 0.024480499182563648 31058690 3.803
spinconv_force_centric_all.pt 0.2682144231150142 0.02463537885678031 8473371 2.733
dimenetpp_all.pt 0.3734772821838549 0.03296979773613204 1810182 3.292
escn_l6_m3_lay20_all_md_s2ef.pt 0.25788701904841865 0.03093125334368715 200234496 11.023




eq2_153M_ec4_allmd.pt 0.23705128624054386 0.028931406512820584 153602690 9.542
dimenetpp_2M.pt 0.41999582492193366 0.03373823264567013 1810182 3.154
schnet_all_large.pt 0.4113086718347453 0.06412513237039072 9088513 0.785
schnet_2M.pt 0.5521859449950031 0.03064488962206107 9088513 0.730
dimenetpp_20M.pt 0.50419820801938 0.03294546908488895 1810182 3.024




gemnet_t_direct_h512_all.pt 0.3219040724331238 0.03232864744843634 31671825 1.271
gemnet_oc_base_s2ef_all.pt 0.26670431936045075 0.03439738334800376 38864438 3.423




gemnet_oc_large_s2ef_all_md.pt 0.26698158343146094 0.03573003225206806 216408144 5.108
cgcnn_20M.pt 3.3007072074237382 0.2006684284000295 3611649 0.584




cgcnn_all.pt 3.4609059260323063 0.19497880233654433 3611649 0.567
gemnet_oc_base_s2ef_2M.pt 0.2637933305845974 0.02365721314116364 38864438 3.411
dimenet_2M.pt 0.5206047457080465 0.03564879529541184 775206 2.866




cgcnn_200k.pt 1.3294208064635973 0.07857945529963299 245889 0.496
gemnet_oc_base_s2ef_all_md.pt 0.2520762141072302 0.031208742318633294 38864438 3.528
cgcnn_2M.pt 3.451738451978899 0.22388876688510856 2127233 0.601
dimenet_200k.pt 0.6842350310410088 0.04225173051612203 725670 2.848
