In [1]:
%load_ext autoreload
%autoreload 2
import os
os.chdir('/nfs/homedirs/ket/uq4molecules/DimeNet/dimenet_pytorch/')

In [2]:
import yaml

In [3]:
from utils import init_and_load_model, get_id_dataloader, get_ood_dataloader
from uq_model import UQModel
from utils_metrics import get_maes_and_calibration_evidential, get_uncertainties_evidential, get_metrics_and_uncertainties_dropout
from ood_detection import anomaly_detection

In [4]:
with open('evaluation/seml/configs/config_dropout_qm7x.yaml', 'r') as c:
    config = yaml.safe_load(c)
# For strings that yaml doesn't parse (e.g. None)
for key, val in config.items():
    if type(val) is str:
        try:
            config[key] = ast.literal_eval(val)
        except (ValueError, SyntaxError):
            pass

In [5]:
model_name = config['fixed']['model_name']
model_params = config['fixed']['model_params']
path_to_trained = './logs/20220912_133237_6bFxfh_QM7X_final' # dropout dimenet++ qm7x eq
#path_to_trained = './logs/20220912_133758_dTvWL5_QM7X_final' # dropout dimenet++ qm7x non-eq
#path_to_trained = './logs/20220912_163902_EBf1fL_MD17_final' # dropout dimenet++ md17 aspirin
#path_to_trained = './logs/20220914_095120_WU9P2c_QM7X_final' # evidential dimenet++ qm7x eq
suffix = '/best/model.pth' 

dataset_name = config['fixed']['dataset_name']
dataset_params = config['fixed']['dataset_params']


In [6]:
model = init_and_load_model(model_name, model_params, path_to_trained+suffix)

In [7]:
uq_model = UQModel(model_name, model, n_mc_dropout_runs=150)

In [8]:
id_loader = get_id_dataloader(dataset_name, dataset_params)

In [9]:
ood_loader = get_ood_dataloader(dataset_name, dataset_params)

In [20]:
energy_mae_id, forces_mae_id, calibration_id, energy_uncertainties_id, dets_id, traces_id, largest_eigs_id = get_metrics_and_uncertainties_dropout(uq_model, id_loader)


 62%|██████▏   | 3082/5000 [00:13<00:35, 54.39it/s]  

cov_mat: tensor([[ 1.3296e+06, -1.2224e+06,  3.0229e+05, -2.6082e+02, -8.0480e+03,
          4.5136e+03,  9.9056e+03, -1.7982e+03,  1.8340e+03, -1.3312e+06,
          1.2298e+06, -3.1985e+05,  6.0013e+02,  1.1127e+03,  7.7723e+03,
          4.1149e+02, -5.9707e+02,  2.3833e+01, -9.0712e+03,  1.9102e+03,
          3.4141e+03],
        [-1.2224e+06,  1.1310e+06, -2.8057e+05, -4.7646e+02,  7.9491e+03,
         -4.3004e+03, -1.0320e+04,  2.3079e+03, -1.6678e+03,  1.2249e+06,
         -1.1387e+06,  2.9731e+05, -5.6995e+02, -1.0292e+03, -7.2481e+03,
         -2.2354e+02,  5.6744e+02, -1.0944e+01,  9.1483e+03, -2.0464e+03,
         -3.5039e+03],
        [ 3.0229e+05, -2.8057e+05,  7.2053e+04,  1.3935e+03, -3.0750e+03,
          8.7888e+02,  2.4026e+03, -3.8764e+02,  5.0098e+02, -3.0384e+05,
          2.8338e+05, -7.5971e+04,  2.0427e+01,  3.1849e+02,  1.6352e+03,
          8.4399e+01, -1.5489e+02,  1.6600e+01, -2.3533e+03,  4.9374e+02,
          8.8684e+02],
        [-2.6082e+02, -4.7646e+02,

100%|██████████| 5000/5000 [00:15<00:00, 327.23it/s] 

energy_uncertainties:[tensor(73.5583)]
dets:[tensor(1.1033e-21, dtype=torch.float64)]
traces:[tensor(5112817.1385, dtype=torch.float64)]
largest_eigs:[tensor(5083093.2725, dtype=torch.float64)]





In [26]:
import torch
torch.cat([torch.Tensor([1.0]), torch.Tensor([2.0])])

tensor([1., 2.])

In [14]:
energy_mae, forces_mae, calibration = get_maes_and_calibration(uq_model, id_loader)
print(f'energy_mae: {energy_mae}. forces_mae: {forces_mae}, calibration:{calibration}')

energy_mae: 0.01884382776916027. forces_mae: 0.022280756384134293, calibration:2.9770755767822266


In [34]:
uncertainties_id_1, uncertainties_id_2 = get_uncertainties(uq_model, id_loader)
uncertainties_ood_1, uncertainties_ood_2 = get_uncertainties(uq_model, ood_loader)

In [39]:
roc_1 = anomaly_detection(1/uncertainties_id_1, 1/uncertainties_ood_1, score_type='AUROC')
apr_1 = anomaly_detection(1/uncertainties_id_1, 1/uncertainties_ood_1, score_type='APR')
print(f'roc_1={roc_1}, apr_1={apr_1}')

roc_1=0.58604892, apr_1=0.5395716203931644


In [31]:
output_dict = {}
print(output_dict)

{}


In [33]:
output_dict['a'] = 1
print(output_dict)
output_dict['b'] = 2
print(output_dict)

{'a': 1}
{'a': 1, 'b': 2}


In [40]:
roc_2 = anomaly_detection(1/uncertainties_id_2, 1/uncertainties_ood_2, score_type='AUROC')
apr_2 = anomaly_detection(1/uncertainties_id_2, 1/uncertainties_ood_2, score_type='APR')
print(f'roc_2={roc_2}, apr_2={apr_2}')

roc_2=0.45485764000000006, apr_2=0.44588229576946925


In [11]:
roc1, pr1, roc2, pr2 = get_uncertainty_metrics(uq_model, combined_loader)
print(f'Energy:\t \t \t \t \t \t Forces:\n  \
energy_roc: {roc1}, \t \t   forces_roc: {roc2}\n  \
energy_pr: {pr1}, \t \t   forces_pr: {pr2}\n \n \n')

                not been set for this class (AUCPR). The property determines if `update` by
                default needs access to the full metric state. If this is not the case, significant speedups can be
                achieved and we recommend setting this to `False`.
                We provide an checking function
                `from torchmetrics.utilities import check_forward_no_full_state`
                that can be used to check if the `full_state_update=True` (old and potential slower behaviour,
                default for now) or if `full_state_update=False` can be used safely.
                


Energy:	 	 	 	 	 	 Forces:
  energy_roc: 0.15305501222610474, 	 	   forces_roc: 0.9778759479522705
  energy_pr: 0.3309403359889984, 	 	   forces_pr: 0.9805764555931091
 
 

