In [1]:
# Set CUDA_VISIBLE_DEVICES=2 in this kernel's environment.
# NOTE(review): presumably this pins the run to a single GPU and has to execute
# before anything initialises CUDA; confirm the index for the machine in use.
%env CUDA_VISIBLE_DEVICES=2

env: CUDA_VISIBLE_DEVICES=2


In [2]:
import numpy as np
import pandas as pd
import seaborn as sns
from ase.io import read
import torch
from tqdm import tqdm
from glob import glob
from pathlib import Path
from itertools import product
from mace import data, tools
from mace.tools import torch_geometric
from mace.tools.torch_tools import to_numpy
from contextlib import contextmanager

In [3]:
# Collect the test-set files and the saved models; sorting keeps the sweep
# order stable from run to run.
test_files = sorted(
    glob("/home/hatemh/BOTNet-datasets/dataset_3BPA/test_*")
)
checkpoint_files = sorted(
    glob("/home/hatemh/mace/experiments/checkpoints/*.model")
)

# Every checkpoint is evaluated on every test file.
num_tests, num_checkpoints = len(test_files), len(checkpoint_files)
total_combinations = num_tests * num_checkpoints
total_combinations

96

In [4]:
# Shared settings read by the helper functions defined later in the notebook.

# Table built from atomic numbers 1, 6, 7 and 8; passed as `z_table` when
# converting configurations to AtomicData.
z_table = tools.utils.AtomicNumberTable([1, 6, 7, 8])
# Passed as `cutoff` when building AtomicData.
# NOTE(review): units are not stated here (presumably Angstrom); confirm.
r_max = 5.0
# Default device that the model and every batch are moved to.
device = "cuda"

In [5]:
def parse_clip_grad(name: str):
    """Extract the gradient-clipping value encoded in a run name.

    The name is split on ``_``; the first token that starts with ``cg`` is
    used, and the text after its last ``-`` is converted to ``float``.

    Args:
        name: Run name, e.g. the stem of a checkpoint file.

    Returns:
        The parsed value as ``float``, or ``-1`` when no ``cg`` token exists.

    Raises:
        ValueError: If the text after the last ``-`` is not a valid float.
    """
    for token in name.split("_"):
        if token.startswith("cg"):
            # Only the first matching token counts; later ones are ignored.
            return float(token.rsplit("-", 1)[-1])
    return -1


def parse_loss_scale(name: str):
    """Extract the loss-scale value encoded in a run name.

    The name is split on ``_``; the first token that starts with ``ls`` is
    used, and the text after its last ``-`` is converted to ``float``.

    Args:
        name: Run name, e.g. the stem of a checkpoint file.

    Returns:
        The parsed value as ``float``, or ``1.0`` when no ``ls`` token exists.

    Raises:
        ValueError: If the text after the last ``-`` is not a valid float.
    """
    matches = [token for token in name.split("_") if token.startswith("ls")]
    if not matches:
        return 1.0

    # rpartition yields the whole token as the tail when it holds no "-".
    _, _, value_text = matches[0].rpartition("-")
    return float(value_text)


def parse_values(checkpoint_file: str):
    """Decode the run settings encoded in a checkpoint file name.

    Only the stem of the path (file name without directory and final suffix)
    is inspected.

    Args:
        checkpoint_file: Path to a checkpoint file.

    Returns:
        dict with the keys:
            ``"seed"``: ``int`` parsed from the text after the last ``-``.
            ``"dtype"``: ``"float64"`` if the stem contains ``fp64``,
                otherwise ``"float32"``.
            ``"clip_grad"``: result of ``parse_clip_grad`` on the stem.
            ``"loss_scale"``: result of ``parse_loss_scale`` on the stem.

    Raises:
        ValueError: If the text after the last ``-`` is not an integer, or if
            one of the delegated parsers fails.
    """
    stem = Path(checkpoint_file).stem
    *_, seed_text = stem.split("-")
    return {
        "seed": int(seed_text),
        "dtype": "float64" if "fp64" in stem else "float32",
        "clip_grad": parse_clip_grad(stem),
        "loss_scale": parse_loss_scale(stem),
    }


def make_loader(file: str):
    """Build a fixed-order batch loader over the structures stored in ``file``.

    The file is read with ``index=":"``, each entry is turned into a
    configuration and then into an ``AtomicData`` object using the
    notebook-level ``z_table`` and ``r_max``.

    Args:
        file: Path to a structure file readable by ``ase.io.read``.

    Returns:
        A ``DataLoader`` with batch size 32, no shuffling and no dropped
        trailing batch.
    """
    frames = read(file, index=":")
    graphs = []
    for config in data.utils.config_from_atoms_list(frames):
        graphs.append(
            data.AtomicData.from_config(config, z_table=z_table, cutoff=r_max)
        )
    return torch_geometric.dataloader.DataLoader(
        dataset=graphs,
        batch_size=32,
        shuffle=False,
        drop_last=False,
    )


@contextmanager
def default_dtype(dtype: torch.dtype):
    """Temporarily switch torch's default floating-point dtype.

    The default that was active on entry is restored on exit, including when
    the body of the ``with`` block raises, so a failed evaluation cannot leave
    the kernel-wide default changed for later cells.

    Args:
        dtype: Floating-point dtype to install as the default inside the block.

    Yields:
        None.
    """
    init = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Runs on both normal exit and exceptions raised inside the block.
        torch.set_default_dtype(init)


def eval_model(checkpoint_file: str, test_file: str, device: str = device):
    """Evaluate one saved model on one test file and summarise its energy errors.

    Args:
        checkpoint_file: Path to a serialized model loadable with
            ``torch.load``. Its file name is also parsed by ``parse_values``
            to recover the run settings.
        test_file: Path to the test-set file handed to ``make_loader``.
        device: Device the model is loaded onto and every batch is moved to.

    Returns:
        dict containing the entries from ``parse_values`` plus
        ``"dataset"`` (stem of ``test_file``), ``"mae (meV)"`` and
        ``"rmse (meV)"``. Both metrics are computed from the absolute
        differences between ``batch.energy`` and the model's ``"energy"``
        output, multiplied by 1e3.
    """
    # NOTE(security): torch.load unpickles arbitrary objects; only point this
    # at checkpoints from a trusted source.
    # map_location makes loading honour `device`, so a checkpoint saved from a
    # different device index (or evaluated on CPU) still deserializes.
    model = torch.load(checkpoint_file, map_location=device)
    model.to(device)
    # Take the dtype of the first parameter: deterministic even if the
    # parameters were ever to have mixed dtypes.
    model_dtype = next(model.parameters()).dtype

    # Build the data under the model's dtype so inputs match its parameters.
    with default_dtype(model_dtype):
        loader = make_loader(test_file)
        errors = []

        # NOTE(review): no torch.no_grad() here; presumably the model needs
        # autograd inside its forward pass. Confirm before adding it.
        for batch in tqdm(loader):
            batch = batch.to(device)
            target = to_numpy(batch.energy)
            output = model(batch.to_dict())
            pred = to_numpy(output["energy"])
            errors.append(np.abs(target - pred))

    # Drop the references before asking the CUDA allocator to release cached
    # memory, so consecutive evaluations do not accumulate GPU usage.
    del model
    del loader
    torch.cuda.empty_cache()

    # Scale by 1e3 to match the "(meV)" keys below.
    # NOTE(review): assumes the stored energies are in eV; confirm.
    errors = 1e3 * np.concatenate(errors, axis=0)

    return {
        **parse_values(checkpoint_file),
        "dataset": Path(test_file).stem,
        "mae (meV)": np.mean(errors),
        "rmse (meV)": np.sqrt(np.mean(errors**2)),
    }

In [6]:
# Evaluate every (checkpoint, test file) pair; each call yields one result row.
sweep_pairs = product(checkpoint_files, test_files)
records = [
    eval_model(*pair) for pair in tqdm(sweep_pairs, total=total_combinations)
]

# One row per pair, with columns taken from the dict keys of eval_model.
df = pd.DataFrame(records)
df

100%|██████████| 67/67 [00:19<00:00,  3.38it/s]
100%|██████████| 53/53 [00:12<00:00,  4.15it/s]
100%|██████████| 67/67 [00:16<00:00,  4.06it/s]
100%|██████████| 221/221 [00:53<00:00,  4.14it/s]
100%|██████████| 67/67 [00:16<00:00,  4.15it/s]
100%|██████████| 53/53 [00:12<00:00,  4.18it/s]
100%|██████████| 67/67 [00:16<00:00,  4.10it/s]
100%|██████████| 221/221 [00:52<00:00,  4.19it/s]
100%|██████████| 67/67 [00:16<00:00,  4.14it/s]
100%|██████████| 53/53 [00:12<00:00,  4.18it/s]
100%|██████████| 67/67 [00:16<00:00,  4.10it/s]
100%|██████████| 221/221 [00:52<00:00,  4.20it/s]
100%|██████████| 67/67 [00:21<00:00,  3.05it/s]
100%|██████████| 53/53 [00:16<00:00,  3.21it/s]
100%|██████████| 67/67 [00:21<00:00,  3.14it/s]
100%|██████████| 221/221 [01:08<00:00,  3.22it/s]
100%|██████████| 67/67 [00:20<00:00,  3.19it/s]
100%|██████████| 53/53 [00:16<00:00,  3.21it/s]
100%|██████████| 67/67 [00:21<00:00,  3.14it/s]
100%|██████████| 221/221 [01:08<00:00,  3.22it/s]
100%|██████████| 67/67 [00:20<

Unnamed: 0,seed,dtype,clip_grad,loss_scale,dataset,mae (meV),rmse (meV)
0,1702,float32,10.0,1.0,test_1200K,34.625736,49.534935
1,1702,float32,10.0,1.0,test_300K,3.413580,5.112343
2,1702,float32,10.0,1.0,test_600K,8.639244,13.212955
3,1702,float32,10.0,1.0,test_dih,13.687396,36.273258
4,285,float32,10.0,1.0,test_1200K,44.702732,62.189041
...,...,...,...,...,...,...,...
91,285,float32,-1.0,1.0,test_dih,7.106304,16.125027
92,43,float32,-1.0,1.0,test_1200K,24.524092,35.421421
93,43,float32,-1.0,1.0,test_300K,5.227447,6.760247
94,43,float32,-1.0,1.0,test_600K,8.650207,13.371348


In [7]:
# Summary statistics of the RMSE for each combination of dtype, clipping,
# loss scale and dataset.
sweep_results = (
    df.groupby(["dtype", "clip_grad", "loss_scale", "dataset"])["rmse (meV)"]
    .describe()
)

# Display the settings ordered from lowest to highest mean RMSE.
sweep_results.sort_values(by="mean")

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,count,mean,std,min,25%,50%,75%,max
dtype,clip_grad,loss_scale,dataset,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
float64,-1.0,1000.0,test_300K,3.0,4.957189,0.984851,4.014803,4.445974,4.877144,5.428382,5.979619
float64,-1.0,1.0,test_300K,3.0,5.015252,0.596109,4.425874,4.713942,5.002009,5.309941,5.617873
float32,-1.0,1000.0,test_300K,3.0,5.471029,0.629171,4.773877,5.208226,5.642576,5.819606,5.996635
float32,-1.0,1.0,test_300K,3.0,5.534762,1.539809,3.806381,4.92202,6.037659,6.398953,6.760247
float64,10.0,1.0,test_300K,3.0,6.292116,1.778453,4.55051,5.385545,6.22058,7.162919,8.105258
float32,10.0,1.0,test_300K,3.0,6.877461,2.222139,5.112343,5.629774,6.147205,7.76002,9.372835
float64,-1.0,1.0,test_dih,3.0,10.877778,1.550219,9.901228,9.984037,10.066846,11.366053,12.665261
float64,-1.0,1000.0,test_dih,3.0,10.948316,1.433837,10.033759,10.122061,10.210363,11.405594,12.600825
float32,-1.0,1000.0,test_600K,3.0,12.327121,0.76805,11.442189,12.080528,12.718867,12.769587,12.820307
float32,-1.0,1.0,test_600K,3.0,12.44514,1.203852,11.084336,11.982036,12.879735,13.125542,13.371348
