In [1]:
import pandas as pd
from evaluation.generated_dataset import GeneratedDataset, load_all_from_config

In [2]:
all_datasets = load_all_from_config()

In [3]:
novelty_reference = pd.concat([
    all_datasets[('split', 'train')].data.fingerprint,
    all_datasets[('split', 'val')].data.fingerprint], verify_integrity=True)
novelty_reference_set = frozenset(novelty_reference)

In [40]:
import evaluation.statistical_evaluator
import importlib
importlib.reload(evaluation.statistical_evaluator)
test_evaluator = evaluation.statistical_evaluator.StatisticalEvaluator(all_datasets[('split', 'test')].data)

In [41]:
results = pd.DataFrame(index=pd.Index(all_datasets.keys(), tupleize_cols=False),
    columns=["CompVal", "SMACT", "Unique", "& Novel",
        "& CompVal", "& SMACT",
        "& sites KS", "& elements KS", "& DoF KS",
        "& SG chi2", "& elements chi2",
        "S.U.N. 0", "S.U.N. 0.08"])
sample_size = 992
for transformations, dataset in all_datasets.items():
    if "smact_validity" in dataset.data.columns:
        results.loc[[transformations], "SMACT"] = dataset.data.smact_validity.mean()
        results.loc[[transformations], "CompVal"] = dataset.data.structural_validity.mean()
    if "corrected_chgnet_ehull" in dataset.data.columns:
        has_e_hull = dataset.data.corrected_chgnet_ehull.notnull()
        sample = dataset.data.loc[has_e_hull].sample(sample_size)
    else:
        sample = dataset.data.sample(sample_size)
    if len(sample) < sample_size:
        raise ValueError(f"Dataset {transformations} has less than {sample_size} entries")
    
    unique = sample.drop_duplicates(subset="fingerprint")
    results.loc[[transformations], "Unique"] = len(unique) / len(sample)
    
    if len(transformations) != 2 or transformations[1] not in ("train", "val"):
        is_novel = ~unique.fingerprint.isin(novelty_reference_set)
        novel = unique.loc[is_novel]
        results.loc[[transformations], "& Novel"] = len(novel) / len(sample)
    else:
        novel = unique
    if len(novel) == 0:
        continue
    if "smact_validity" in dataset.data.columns:
        results.loc[[transformations], "& CompVal"] = novel.structural_validity.mean()
        results.loc[[transformations], "& SMACT"] = novel.smact_validity.mean()
    
    results.loc[[transformations], "& sites KS"] = test_evaluator.get_num_sites_ks(novel).statistic
    results.loc[[transformations], "& elements KS"] = test_evaluator.get_num_elements_ks(novel).statistic
    results.loc[[transformations], "& DoF KS"] = test_evaluator.get_dof_ks(novel).statistic

    results.loc[[transformations], "& SG chi2"] = test_evaluator.get_sg_chi2(novel)
    results.loc[[transformations], "& elements chi2"] = test_evaluator.get_elements_chi2(novel)

    if "corrected_chgnet_ehull" in dataset.data.columns:
        stable_008 = (novel.corrected_chgnet_ehull < 0.08).sum()
        stable_0 = (novel.corrected_chgnet_ehull <= 0).sum()
        results.loc[[transformations], "S.U.N. 0"] = stable_0 / len(sample)
        results.loc[[transformations], "S.U.N. 0.08"] = stable_008 / len(sample)
    
results

Generated dataset has extra space groups compared to test
Generated dataset has extra elements compared to test
Generated dataset has extra elements compared to test
Generated dataset has extra elements compared to test
Generated dataset has extra elements compared to test
Generated dataset has extra elements compared to test
Generated dataset has extra space groups compared to test
Generated dataset has extra space groups compared to test
Generated dataset has extra space groups compared to test
Generated dataset has extra space groups compared to test
Generated dataset has extra space groups compared to test
Generated dataset has extra space groups compared to test
Generated dataset has extra elements compared to test


Unnamed: 0,CompVal,SMACT,Unique,& Novel,& CompVal,& SMACT,& sites KS,& elements KS,& DoF KS,& SG chi2,& elements chi2,S.U.N. 0,S.U.N. 0.08
"(WyckoffTransformer,)",,,0.996976,0.865927,,,0.036164,0.066072,0.04133,154.184541,766.543597,,
"(WyckoffTransformer, CHGNet_fix)",0.997,0.814,0.997984,0.879032,0.99656,0.806193,0.037924,0.056294,0.036035,171.948119,615.850959,0.127016,0.364919
"(WyckoffTransformer, CHGNet_free)",0.999,0.814,0.990927,0.865927,0.998836,0.804424,0.161876,0.050938,0.241596,1063.418461,594.802701,0.15625,0.436492
"(WyckoffTransformer, CHGNet_fix_release)",0.996,0.814,0.997984,0.879032,0.99656,0.805046,0.043658,0.056294,0.052091,163.332451,638.260894,0.136089,0.382056
"(WyckoffTransformer, DiffCSP++)",0.998,0.814,0.997984,0.878024,0.997704,0.804822,0.039291,0.056595,0.040708,163.102634,637.327192,0.125,0.356855
"(WyckoffTransformer, DiffCSP++, CHGNet_free)",0.997,0.814,0.997984,0.876008,0.997699,0.806674,0.068501,0.052594,0.080706,215.990514,612.10921,0.125,0.365927
"(CrystalFormer,)",0.933934,0.84985,0.996976,0.741935,0.915761,0.827446,0.116482,0.122272,0.091704,199.949631,1433.856439,,
"(CrystalFormer, CHGNet_fix_release)",0.899194,0.84879,0.995968,0.731855,0.870523,0.823691,0.107619,0.11714,0.08015,183.443203,1336.58298,0.193548,0.362903
"(DiffCSP,)",1.0,0.832,0.97379,0.831653,1.0,0.835152,0.370437,0.15119,0.342758,2245.579179,1145.168787,,
"(DiffCSP, CHGNet_fix)",1.0,0.825,0.976815,0.851815,1.0,0.813018,0.392126,0.168791,0.388252,2591.36026,1387.773703,0.177419,0.546371


In [42]:
results.to_latex("results.tex", float_format="%.3f", escape=True)