In [1]:
import pandas as pd
from evaluation.generated_dataset import GeneratedDataset, load_all_from_config
from evaluation.novelty import NoveltyFilter, filter_by_unique_structure

In [2]:
# Load every dataset listed in the evaluation config. Keys are transformation
# tuples such as ('split', 'test'); each value exposes its DataFrame via `.data`.
all_datasets = load_all_from_config()

In [3]:
# The novelty reference is the union of the train and val splits.
novelty_reference = pd.concat(
    (all_datasets[('split', split_name)].data for split_name in ("train", "val")),
    axis=0,
    verify_integrity=True,  # raise if train and val share index labels
)
# Hashable fingerprint set for fast membership tests (fingerprint-level "WY" novelty).
novelty_reference_set = frozenset(novelty_reference.fingerprint)
novelty_filter = NoveltyFilter(novelty_reference)

In [4]:
import evaluation.statistical_evaluator
import importlib

# Re-execute the evaluator module so edits to it take effect without a kernel restart.
importlib.reload(evaluation.statistical_evaluator)

# Reference distribution for the statistical metrics: the test split, deduplicated
# by structure and then passed through the train+val novelty filter.
test_split_data = all_datasets[('split', 'test')].data
test_unique = filter_by_unique_structure(test_split_data)
test_novel = novelty_filter.get_novel(test_unique)
test_evaluator = evaluation.statistical_evaluator.StatisticalEvaluator(test_novel)

In [5]:
from tqdm.notebook import tqdm

# Sample sizes for the uniqueness estimates; sample_size is also passed to the chi2 tests.
sample_size = 992
sample_size_big = 10000
# Seed for the uniqueness sub-samples so that re-running the notebook reproduces the table.
random_seed = 42

results = pd.DataFrame(index=pd.Index(all_datasets.keys(), tupleize_cols=False),
    columns=[f"Unique{sample_size}", f"Unique{sample_size_big}", "Novel WY", "Novel SM",
        "StructVal", "SMACT",
        "& sites EMD", "& sites KS", "& #elements KS", "& #elements EMD",
        "#DoF KS", "#DoF EMD",
        "density EMD", "E EMD",
        "COV-R", "COV-P",
        "& SG chi2", "& elements chi2",
        "S.U.N. 0", "S.U.N. 0.08"])

for transformations, dataset in tqdm(all_datasets.items()):
    frame = dataset.data
    # The index stores whole tuples as labels (tupleize_cols=False); wrapping the key
    # in a list stops .loc from interpreting the tuple as a (row, column) pair.
    row = [transformations]
    has_e_hull_column = "corrected_chgnet_ehull" in frame.columns

    # --- Uniqueness on a fixed-size sample ---
    # When e_hull values exist, only entries that have one are eligible for the sample.
    if has_e_hull_column:
        has_e_hull = frame.corrected_chgnet_ehull.notnull()
        unique_pool = frame.loc[has_e_hull]
    else:
        unique_pool = frame
    # Check *before* sampling: DataFrame.sample(n) already raises its own error when n
    # exceeds the number of rows, so a check placed after the call could never fire.
    if len(unique_pool) < sample_size:
        raise ValueError(f"Dataset {transformations} has less than {sample_size} entries")
    sample_for_unique = unique_pool.sample(sample_size, random_state=random_seed)
    unique_sample = filter_by_unique_structure(sample_for_unique)
    results.loc[row, f"Unique{sample_size}"] = len(unique_sample) / len(sample_for_unique)

    if len(frame) >= sample_size_big:
        # Drawn from the whole dataset, regardless of e_hull availability.
        big_sample_for_unique = frame.sample(sample_size_big, random_state=random_seed)
        big_unique_sample = filter_by_unique_structure(big_sample_for_unique)
        results.loc[row, f"Unique{sample_size_big}"] = len(big_unique_sample) / len(big_sample_for_unique)

    # --- Novelty ---
    unique = frame.drop_duplicates(subset="fingerprint")
    # Train/val splits form the novelty reference, so novelty is not measured for them.
    is_reference_split = len(transformations) == 2 and transformations[1] in ("train", "val")
    if is_reference_split:
        novel = unique
    else:
        is_novel = ~unique.fingerprint.isin(novelty_reference_set)
        novel_wy = unique.loc[is_novel]
        results.loc[row, "Novel WY"] = len(novel_wy) / len(unique)
        novel = novelty_filter.get_novel(unique)
        results.loc[row, "Novel SM"] = len(novel) / len(unique)

    # --- Metrics on the unique (and, outside train/val, novel) subset ---
    # Validity is reported on this subset only; the full-dataset rates are in cdvae_table.
    if "smact_validity" in frame.columns:
        results.loc[row, "StructVal"] = novel.structural_validity.mean()
        results.loc[row, "SMACT"] = novel.smact_validity.mean()

    results.loc[row, "& sites EMD"] = test_evaluator.get_num_sites_emd(novel)
    results.loc[row, "& sites KS"] = test_evaluator.get_num_sites_ks(novel).statistic
    results.loc[row, "& #elements KS"] = test_evaluator.get_num_elements_ks(novel).statistic
    results.loc[row, "& #elements EMD"] = test_evaluator.get_num_elements_emd(novel)
    results.loc[row, "#DoF KS"] = test_evaluator.get_dof_ks(novel).statistic
    results.loc[row, "#DoF EMD"] = test_evaluator.get_dof_emd(novel)
    if "structure" in frame.columns:
        results.loc[row, "density EMD"] = test_evaluator.get_density_emd(novel)
        results.loc[row, "E EMD"] = test_evaluator.get_cdvae_e_emd(novel)
        cov_metrics = test_evaluator.get_coverage(novel.cdvae_crystal)
        results.loc[row, "COV-R"] = cov_metrics["cov_recall"]
        results.loc[row, "COV-P"] = cov_metrics["cov_precision"]

    results.loc[row, "& SG chi2"] = test_evaluator.get_sg_chi2(novel, sample_size=sample_size)
    results.loc[row, "& elements chi2"] = test_evaluator.get_elements_chi2(novel, sample_size=sample_size)

    if has_e_hull_column:
        # Denominator: every entry with an e_hull (duplicates included);
        # numerator: unique/novel entries under the e_hull threshold.
        total_with_ehull = has_e_hull.sum()
        stable_0 = (novel.corrected_chgnet_ehull <= 0).sum()
        stable_008 = (novel.corrected_chgnet_ehull < 0.08).sum()
        results.loc[row, "S.U.N. 0"] = stable_0 / total_with_ehull
        results.loc[row, "S.U.N. 0.08"] = stable_008 / total_with_ehull

results

  0%|          | 0/20 [00:00<?, ?it/s]

AttributeError: 'Series' object has no attribute 'structure'

In [8]:
# Persist the metrics table assembled by the evaluation loop.
results.to_csv(path_or_buf="evaluation_results.csv")

In [9]:
# Metrics for the CDVAE comparison table, measured against the raw test split
# (no uniqueness or novelty filtering, unlike test_evaluator).
raw_test_evaluator = evaluation.statistical_evaluator.StatisticalEvaluator(all_datasets[('split', 'test')].data)
cdvae_table = pd.DataFrame(index=pd.Index(all_datasets.keys(), tupleize_cols=False),
    columns=[
        "StructVal", "SMACT",
        "COV-R", "COV-P",
        "density EMD", "E EMD", "#elements EMD"])
for transformations, dataset in tqdm(all_datasets.items()):
    # Without a "structure" column none of these metrics are computed; the row stays NaN.
    if "structure" not in dataset.data.columns:
        continue
    row = [transformations]  # list-wrap so .loc treats the tuple as a single index label
    # Validity rates use every entry in the dataset...
    cdvae_table.loc[row, "SMACT"] = dataset.data.smact_validity.mean()
    cdvae_table.loc[row, "StructVal"] = dataset.data.structural_validity.mean()
    # ...while the remaining metrics use only rows flagged by naive_validity.
    valid = dataset.data[dataset.data.naive_validity]
    cov_metrics = raw_test_evaluator.get_coverage(valid.cdvae_crystal)
    cdvae_table.loc[row, "COV-R"] = cov_metrics["cov_recall"]
    cdvae_table.loc[row, "COV-P"] = cov_metrics["cov_precision"]
    cdvae_table.loc[row, "density EMD"] = raw_test_evaluator.get_density_emd(valid)
    cdvae_table.loc[row, "E EMD"] = raw_test_evaluator.get_cdvae_e_emd(valid)
    cdvae_table.loc[row, "#elements EMD"] = raw_test_evaluator.get_num_elements_emd(valid)
cdvae_table.to_csv("cdvae_metrics_table.csv")
cdvae_table

  0%|          | 0/20 [00:00<?, ?it/s]

Unnamed: 0,StructVal,SMACT,COV-R,COV-P,density EMD,E EMD,#elements EMD
"(WyckoffTransformer,)",,,,,,,
"(WyckoffTransformer, CrySPR, CHGNet_fix)",0.997,0.814,0.987685,0.959319,0.392825,0.078476,0.080551
"(WyckoffTransformer, CrySPR, CHGNet_free)",0.999,0.814,0.99016,0.954897,0.404807,0.064897,0.079908
"(WyckoffTransformer, CrySPR, CHGNet_fix_release)",0.996,0.814,0.98767,0.95943,0.387186,0.077984,0.081197
"(WyckoffTransformer, DiffCSP++)",0.998,0.814,0.995074,0.958103,0.361735,0.08327,0.078874
"(WyckoffTransformer, DiffCSP++, CHGNet_fix)",0.997,0.814,0.992602,0.958545,0.327697,0.070343,0.077776
"(WyckoffTransformer, DiffCSP++, CHGNet_free)",0.997,0.814,0.992602,0.956113,0.33662,0.067816,0.077776
"(CrystalFormer,)",0.933934,0.84985,0.996226,0.945611,0.193111,0.208384,0.128222
"(CrystalFormer, CHGNet_fix_release)",0.899194,0.84879,0.9987,0.954455,0.185524,0.138571,0.118921
"(DiffCSP,)",1.0,0.832,0.998197,0.995136,0.350799,0.094984,0.346595
