In [15]:
import torch
import pandas as pd
from processing.utils import filter_data_by_properties,select_structures
from processing.interpolation.Interpolation import apply_interpolation
from run_sigopt_experiment import build_sigopt_name
from processing.dataloader.dataloader import get_dataloader
from processing.create_model.create_model import create_model
from training.loss import contrastive_loss
from training.evaluate import evaluate_model

In [12]:
gpu_num = 0
device_name = "cuda:" + str(gpu_num)
device = torch.device(device_name)
torch.cuda.set_device(device)

In [2]:
target_prop = 'dft_e_hull'
is_relaxed = True
interpolation = True

In [3]:
train_data = pd.read_json('data/training_set.json')

test_data = pd.read_json('data/test_set.json')
# test_data = pd.read_json('data/holdout_set_B_sites.json')
# test_data = pd.read_json('data/holdout_set_series.json')

In [4]:
data = [train_data, test_data]
processed_data = []

for dataset in data:
    dataset = filter_data_by_properties(dataset,target_prop)

    if is_relaxed:
        dataset = select_structures(dataset,"relaxed")
    else:
        dataset = select_structures(dataset,"unrelaxed")

    if interpolation:
        dataset = apply_interpolation(dataset,target_prop)

    processed_data.append(dataset)

train_data = processed_data[0]
test_data = processed_data[1]

In [5]:
model_type = 'CGCNN'
batch_size = 10

# sigopt_name = build_sigopt_name(target_prop,is_relaxed,interpolation,model_type)
best_model_path = './saved_models_tmp_for_testing_inference/' + model_type # Placeholder for now

In [10]:
train_loader = get_dataloader(train_data, target_prop, model_type, batch_size, interpolation)
test_loader = get_dataloader(test_data, target_prop, model_type, 1, interpolation)

100%|██████████| 4298/4298 [00:01<00:00, 2162.58it/s]
100%|██████████| 1395/1395 [00:00<00:00, 2228.86it/s]


In [None]:
if model_type == "Painn":
    best_model = torch.load(best_model_path + "/best_model", map_location=device)
    normalizer = None
else:
    best_model, normalizer = create_model(model_type, train_loader)
    best_model.load_state_dict(torch.load(best_model_path + "/best_model.torch", map_location=device)['state'])

best_model = best_model.to(device)

In [16]:
if "contrastive" in model_type:
    loss_fn = contrastive_loss 
else:
    loss_fn = torch.nn.L1Loss()

In [20]:
test_result = evaluate_model(best_model, normalizer, model_type, test_loader, loss_fn, gpu_num)

In [28]:
predictions = test_result[0].detach().cpu().numpy()
targets = test_result[1].detach().cpu().numpy()

if interpolation:
    # TODO: Add interpolation

if "contrastive" in model_type:
    test_mae =test_result[2].detach().cpu().numpy()[1]
else:
    test_mae =test_result[2].detach().cpu().numpy()[0]

[-0.10181162,
 -0.11567287,
 -0.11168271,
 -0.08787635,
 -0.10790588,
 -0.07818812,
 -0.086743996,
 -0.08022143,
 -0.106173664,
 -0.07107455,
 -0.11154953,
 -0.07659783,
 -0.111201376,
 -0.09633633,
 -0.08506097,
 -0.100638,
 -0.10382499,
 -0.110770315,
 -0.013722036,
 -0.10768752,
 -0.12681136,
 -0.091164544,
 -0.093467414,
 -0.08913226,
 -0.08387511,
 -0.10631174,
 -0.101822525,
 -0.09404172,
 -0.07648434,
 -0.08493658,
 -0.122175425,
 -0.08706373,
 -0.09117818,
 -0.087539494,
 -0.09472328,
 -0.11096968,
 -0.08335557,
 -0.080961704,
 -0.105945215,
 -0.10930936,
 -0.098663956,
 -0.089232065,
 -0.104527906,
 -0.052381516,
 -0.080280066,
 -0.07965171,
 -0.09429461,
 -0.09571065,
 -0.082287386,
 -0.03937695,
 -0.10384087,
 -0.09053218,
 -0.08276685,
 -0.08934067,
 -0.07977376,
 -0.09057957,
 -0.08120884,
 -0.1043812,
 -0.096072465,
 -0.108489394,
 -0.086414196,
 -0.08398304,
 -0.0630749,
 -0.102858335,
 -0.09737562,
 -0.08756276,
 -0.086157195,
 -0.052565046,
 -0.10040212,
 -0.12287614,


In [29]:
test_data

Unnamed: 0,formula,framework,composition,n_atoms_unrelaxed,n_atoms_opt,unrelaxed_cryst_id,unrelaxed_struct,opt_cryst_id,opt_struct,es_job_id,dft_energy,dft_energy_per_atom,dft_e_hull,structure,dft_e_hull_interp,dft_e_hull_diff,ase_structure,idx
0,Ba1Ca7Co4Ni4O24,Ba0.125Ca0.875Co0.500Ni0.500O3,"{'sites': {'A': ['Ca', 'Ba'], 'B': ['Ni', 'Co'...",40,40,222996264,"{'@module': 'pymatgen.core.structure', '@class...",253605952,"{'@module': 'pymatgen.core.structure', '@class...",73347615,-224.385261,-5.609632,0.203860,"[[-0.00052024 0.00983376 0.00983376] Ba, [-9...",0.179466,0.024393,"(Atom('Ba', [-0.000520236127089, 0.00983375719...",0
1,Ba1Ca7Co4Ni4O24,Ba0.125Ca0.875Co0.500Ni0.500O3,"{'sites': {'A': ['Ca', 'Ba'], 'B': ['Ni', 'Co'...",40,40,222996266,"{'@module': 'pymatgen.core.structure', '@class...",253416559,"{'@module': 'pymatgen.core.structure', '@class...",73269340,-224.684450,-5.617111,0.196380,"[[0.11104373 0. 0. ] Ba, [-0.18...",0.179466,0.016914,"(Atom('Ba', [0.1110437342278133, 0.0, 0.0], ma...",1
2,Ba1Ca7Co4Ni4O24,Ba0.125Ca0.875Co0.500Ni0.500O3,"{'sites': {'A': ['Ca', 'Ba'], 'B': ['Ni', 'Co'...",40,40,222996268,"{'@module': 'pymatgen.core.structure', '@class...",253605948,"{'@module': 'pymatgen.core.structure', '@class...",73347612,-223.010442,-5.575261,0.238230,[[9.56871539e-06 9.56871539e-06 9.56871539e-06...,0.179466,0.058764,"(Atom('Ba', [9.5687153875e-06, 9.5687153875e-0...",2
3,Ba1Ca7Co4Ni4O24,Ba0.125Ca0.875Co0.500Ni0.500O3,"{'sites': {'A': ['Ca', 'Ba'], 'B': ['Ni', 'Co'...",40,40,222996265,"{'@module': 'pymatgen.core.structure', '@class...",252541128,"{'@module': 'pymatgen.core.structure', '@class...",72945680,-224.844699,-5.621117,0.192374,[[-4.78789493e-02 -4.78788480e-02 9.59606213e...,0.179466,0.012907,"(Atom('Ba', [-0.0478789492895896, -0.047878848...",3
4,Ba1Ca7Co4Ni4O24,Ba0.125Ca0.875Co0.500Ni0.500O3,"{'sites': {'A': ['Ca', 'Ba'], 'B': ['Ni', 'Co'...",40,40,222996267,"{'@module': 'pymatgen.core.structure', '@class...",252299215,"{'@module': 'pymatgen.core.structure', '@class...",72945179,-223.832115,-5.595803,0.217688,[[9.61961797e-06 9.61961797e-06 9.50698259e-06...,0.179466,0.038222,"(Atom('Ba', [9.619617975e-06, 9.619617975e-06,...",4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1390,Ga4O24Pd4Yb8,YbGa0.500Pd0.500O3,"{'sites': {'A': ['Yb'], 'B': ['Ga', 'Pd'], 'X'...",40,40,205590756,"{'@module': 'pymatgen.core.structure', '@class...",209625073,"{'@module': 'pymatgen.core.structure', '@class...",66018043,-232.698687,-5.817467,0.124067,[[1.15683360e-05 1.15517717e-05 1.15517717e-05...,0.187711,-0.063644,"(Atom('Ga', [1.15683360176e-05, 1.1551771696e-...",1390
1391,Mo3O24Rh5Yb8,YbMo0.375Rh0.625O3,"{'sites': {'A': ['Yb'], 'B': ['Mo', 'Rh'], 'X'...",40,40,205590519,"{'@module': 'pymatgen.core.structure', '@class...",213294660,"{'@module': 'pymatgen.core.structure', '@class...",69733467,-266.051457,-6.651286,0.221636,[[1.15007168e-05 1.15007168e-05 1.16969033e-05...,0.282433,-0.060797,"(Atom('Mo', [1.150071678e-05, 1.150071678e-05,...",1391
1392,Nb5O24Ru3Yb8,YbNb0.625Ru0.375O3,"{'sites': {'A': ['Yb'], 'B': ['Nb', 'Ru'], 'X'...",40,40,205592102,"{'@module': 'pymatgen.core.structure', '@class...",213096627,"{'@module': 'pymatgen.core.structure', '@class...",69733096,-309.251461,-7.731287,0.086486,"[[1.9579765 5.89298174 5.89298174] Nb, [5.873...",0.181544,-0.095058,"(Atom('Nb', [1.9579765017664261, 5.89298173889...",1392
1393,O24Ru3Sc5Yb8,YbRu0.375Sc0.625O3,"{'sites': {'A': ['Yb'], 'B': ['Ru', 'Sc'], 'X'...",40,40,205592215,"{'@module': 'pymatgen.core.structure', '@class...",213163645,"{'@module': 'pymatgen.core.structure', '@class...",71034563,-293.522982,-7.338075,0.116335,[[1.17474953e-05 1.17474953e-05 5.90615119e+00...,0.266070,-0.149736,"(Atom('O', [1.17474952518e-05, 1.17474952518e-...",1393
