In [1]:
from typing import Tuple
import pandas as pd
from tqdm.notebook import tqdm
from evaluation.generated_dataset import GeneratedDataset, load_all_from_config
from evaluation.novelty import NoveltyFilter, filter_by_unique_structure

In [2]:
# Generated datasets to summarise, in table order: row label -> key tuple under
# which the dataset is looked up in `all_datasets`.
# The reference splits are not tabulated; they could be added as, e.g.,
# "train": ("split", "train").
datasets = dict([
    ("WyckoffTransformer", ("WyckoffTransformer",)),
    ("WyCryst", ("WyCryst",)),
    ("WyCryst-chgnet", ("WyCryst", "CHGNet_fix")),
])

In [3]:
# Load every dataset listed in `datasets` together with the train / val / test
# reference splits of "mp_20_biternary" in one call.
all_datasets = load_all_from_config(
    dataset_name="mp_20_biternary",
    datasets=[
        *datasets.values(),
        ("split", "train"),
        ("split", "val"),
        ("split", "test"),
    ],
)

In [4]:
# Reference set for the novelty check: train and val stacked row-wise.
# verify_integrity=True makes the concat raise if the two splits share index labels.
novelty_reference = pd.concat(
    [all_datasets[("split", split_name)].data for split_name in ("train", "val")],
    axis=0,
    verify_integrity=True,
)
novelty_filter = NoveltyFilter(novelty_reference)

In [5]:
import evaluation.statistical_evaluator

# Evaluator for the mp_20_biternary test split: de-duplicate the split, keep
# only the entries the novelty filter returns as novel, and build the
# evaluator from what remains.
test_unique = filter_by_unique_structure(
    all_datasets[("split", "test")].data
)
test_novel = novelty_filter.get_novel(test_unique)
test_evaluator = evaluation.statistical_evaluator.StatisticalEvaluator(
    test_novel
)

In [6]:
def is_sg_preserved(relaxed_sg, transformations: Tuple[str, ...], source_datasets=None) -> pd.Series:
    """Flag, per entry, whether a space group number equals the one in the source dataset.

    The source dataset is the one stored under the one-element key
    ``(transformations[0],)``, i.e. the first element of the key tuple.

    Args:
        relaxed_sg: Series of space group numbers to check (presumably taken
            after the transformations were applied, judging by the name).
        transformations: Dataset key tuple of any length; only its first
            element is used.
        source_datasets: Mapping from dataset key to an object exposing
            ``.data.spacegroup_number``. When omitted, the notebook-level
            ``all_datasets`` is used, which is the original behaviour.

    Returns:
        Boolean Series indexed like ``relaxed_sg``. Entries that are absent
        from the source dataset become NaN after reindexing and therefore
        compare as ``False``.

    Raises:
        KeyError: if ``(transformations[0],)`` is not a key of the mapping.
    """
    if source_datasets is None:
        source_datasets = all_datasets
    source_sg = source_datasets[(transformations[0],)].data.spacegroup_number
    # Align the source onto the index of relaxed_sg so the comparison is label-wise.
    return relaxed_sg == source_sg.reindex_like(relaxed_sg)

In [7]:
# All three mp_20_biternary splits stacked row-wise (train, test, val order);
# verify_integrity=True raises if any index label occurs in more than one split.
mp_20_biternary = pd.concat(
    [all_datasets[("split", split_name)].data for split_name in ("train", "test", "val")],
    axis=0,
    verify_integrity=True,
)
# Share of entries whose spacegroup_number is 1.
print(mp_20_biternary.spacegroup_number.eq(1).mean())
# Mean of the smact_validity column, shown as the cell result.
mp_20_biternary.smact_validity.mean()

0.010822206605762474


0.8909065354884048

In [8]:
# Same unique -> novel -> evaluator pipeline, applied to the cached test split
# of the "mp_20" dataset.
# NOTE(review): `novelty_filter` was built from the mp_20_biternary train/val
# splits, not from the "mp_20" ones - confirm that this is the intended reference.
mp_20_test = GeneratedDataset.from_cache(("split", "test"), "mp_20")
mp_20_test_unique = filter_by_unique_structure(mp_20_test.data)
mp_20_test_novel = novelty_filter.get_novel(mp_20_test_unique)
mp_20_test_evaluator = evaluation.statistical_evaluator.StatisticalEvaluator(
    mp_20_test_novel
)

In [9]:
# Replace the mp_20_biternary test evaluator with the one built from the
# "mp_20" test split. Per the original note, this hardly affects the WyCryst
# numbers, but it allows all results to be put into a single table.
test_evaluator = mp_20_test_evaluator

In [19]:
import evaluation.novelty
import importlib

# Re-execute evaluation.novelty so that edits to the module are picked up
# without restarting the kernel. Names bound earlier via
# `from evaluation.novelty import ...` are not rebound by this.
importlib.reload(evaluation.novelty)

# Train and val splits of "mp_20", read from cache and stacked row-wise;
# verify_integrity=True raises on index labels shared between the two.
mp_20_train_val = pd.concat(
    [
        GeneratedDataset.from_cache(("split", split_name), "mp_20").data
        for split_name in ("train", "val")
    ],
    axis=0,
    verify_integrity=True,
)

AttributeError: 'DataFrame' object has no attribute 'test_dataset'

In [20]:
# One anonymous fingerprint per row of the mp_20 train+val table, collected
# into an immutable set for fast membership tests.
train_w_template_set = frozenset(
    mp_20_train_val.apply(
        evaluation.novelty.record_to_anonymous_fingerprint,
        axis="columns",
    )
)

In [21]:
# Summary table: one row per entry of `datasets`, one column per reported metric.
# Cells start out empty and are filled in by the loop below; a metric whose
# required input column is missing for a dataset stays empty.
table = pd.DataFrame(
    index=datasets.keys(), columns=[
        "Novelty (%)", "Structural", "Compositional", 
        "Recall", "Precision",
        r"$\rho$", "$E$", "# Elements",
        "S.U.N. (%)",
        "Novel Template (%)", "P1 (%)",
        "Space Group", "S.S.U.N. (%)"])
table.index.name = "Method"
# Upper bound on corrected_chgnet_ehull for an entry to count towards S.U.N.
# (units are not stated here - TODO confirm).
E_hull_threshold = 0.08
# NOTE(review): not referenced anywhere else in the visible notebook - confirm it is still needed.
unique_sample_size = 992
for name, transformations in tqdm(datasets.items()):
    dataset = all_datasets[transformations]
    # De-duplicate first; every later metric is computed on `unique` or a subset of it.
    unique = filter_by_unique_structure(dataset.data)
    # Diagnostic: number of unique entries, number of all entries, and their ratio.
    print(len(unique), len(dataset.data), len(unique) / len(dataset.data))
    # An entry has a novel template when its anonymous fingerprint does not occur
    # in the mp_20 train+val fingerprint set.
    novel_template = ~unique.apply(evaluation.novelty.record_to_anonymous_fingerprint, axis=1).isin(train_w_template_set)
    table.loc[name, "Novel Template (%)"] = 100 * novel_template.mean()
    # The train / val splits make up the novelty reference themselves, so they are
    # not filtered (this branch is only reached if they are added to `datasets`).
    if transformations in (("split", "train"), ("split", "val")):
        novel = unique
    else:
        novel = novelty_filter.get_novel(unique)
    table.loc[name, "Novelty (%)"] = 100 * len(novel) / len(unique)
    if "structural_validity" in novel.columns:
        table.loc[name, "Structural"] = 100 * novel.structural_validity.mean()
        table.loc[name, "Compositional"] = 100 * novel.smact_validity.mean()
    if "cdvae_crystal" in novel.columns:
        # Coverage is computed on all novel entries, before the validity filter below.
        cov_metrics = test_evaluator.get_coverage(novel.cdvae_crystal)    
        table.loc[name, "Recall"] = 100 * cov_metrics["cov_recall"]
        table.loc[name, "Precision"] = 100 * cov_metrics["cov_precision"]
        # From here on `novel` holds only structurally valid entries - but only for
        # datasets that have a cdvae_crystal column. The P1, Space Group and S.U.N.
        # values further down are therefore computed on the filtered set for those
        # datasets and on the unfiltered set for all others.
        # NOTE(review): confirm this asymmetry is intended.
        novel = novel[novel.structural_validity]
        table.loc[name, r"$\rho$"] = test_evaluator.get_density_emd(novel)
        table.loc[name, "$E$"] = test_evaluator.get_cdvae_e_emd(novel)
        table.loc[name, "# Elements"] = test_evaluator.get_num_elements_emd(novel)
    # Share of entries whose `group` equals 1.
    table.loc[name, "P1 (%)"] = 100 * (novel.group == 1).mean()
    # table.loc[name, "# DoF"] = test_evaluator.get_dof_emd(novel)
    table.loc[name, "Space Group"] = test_evaluator.get_sg_chi2(novel)
    #try:
    #    table.loc[name, "SG preserved (%)"] = 100 * is_sg_preserved(novel.spacegroup_number, transformations).mean()
    #except KeyError:
    #    pass
    #table.loc[name, "Elements"] = test_evaluator.get_elements_chi2(novel)
    if "corrected_chgnet_ehull" in novel.columns:
        # S.U.N. is measured with respect to the initial structures
        # i.e. the denominator counts rows of the full dataset (not of `unique` or
        # `novel`) that have an e-hull value. notna().sum() already yields that
        # count, so the further .sum() calls on it below appear to be no-ops.
        has_ehull = dataset.data.corrected_chgnet_ehull.notna().sum()
        is_sun = (novel.corrected_chgnet_ehull <= E_hull_threshold)
        table.loc[name, "S.U.N. (%)"] = 100 * is_sun.sum() / has_ehull.sum()
        # Same as S.U.N., additionally requiring `group` to differ from 1.
        table.loc[name, "S.S.U.N. (%)"] = 100 * (is_sun & (novel.group != 1)).sum() / has_ehull.sum()

  0%|          | 0/3 [00:00<?, ?it/s]

In [22]:
# Write the summary table to disk in two formats: CSV and pickle.
table.to_csv(
    path_or_buf="tables/mp_20_biternary_paper_summary_table.csv",
)
table.to_pickle(
    path="tables/mp_20_biternary_paper_summary_table.pkl",
)

In [23]:
# Columns for which the largest value is the one to emphasise.
max_subset = [
    "Novelty (%)", "Structural", "Compositional", "Recall", "Precision",
    "S.S.U.N. (%)", "S.U.N. (%)", "Novel Template (%)",
]


def highlight_max_value(s):
    """Return one CSS string per element of ``s``, bold for the column maximum.

    Intended as a column-wise callback for ``DataFrame.style.apply``.
    Columns whose name is not in ``max_subset`` get an empty style for every
    element. Ties are all highlighted.

    The last row is currently *not* excluded from the comparison; the
    original carried a disabled ``is_max.iloc[-1] = False`` for leaving out
    the MP-20 training set row.
    """
    if s.name in max_subset:
        peak = s.max()
        return ['font-weight: bold' if hit else '' for hit in (s == peak)]
    return [''] * len(s)

# Columns for which the smallest value is the one to emphasise.
min_subset = [
    r"$\rho$", "$E$", "# Elements", "# DoF",
    "Space Group", "Elements", "P1 (%)",
]


def highlight_min_value(s):
    """Return one CSS string per element of ``s``, bold for the column minimum.

    Intended as a column-wise callback for ``DataFrame.style.apply``.
    Columns whose name is not in ``min_subset`` get an empty style for every
    element. Ties are all highlighted, and the last row is not excluded.
    """
    if s.name in min_subset:
        lowest = s.min()
        return ['font-weight: bold' if hit else '' for hit in (s == lowest)]
    return [''] * len(s)

In [25]:
def prettify(table):
    """Return ``table.style`` with a fixed number of decimals per metric column.

    The format specification is keyed by column name and grouped below by
    precision. Max/min highlighting via ``highlight_max_value`` /
    ``highlight_min_value`` is currently not applied.
    """
    two_decimals = (
        "Novelty (%)", "Structural", "Compositional", "Recall", "Precision",
        r"$\rho$", "Novel Template (%)", "P1 (%)",
    )
    three_decimals = ("$E$", "# Elements", "# DoF", "Space Group", "Elements")
    one_decimal = ("S.U.N. (%)", "S.S.U.N. (%)")
    column_formats = {
        **dict.fromkeys(two_decimals, "{:.2f}"),
        **dict.fromkeys(three_decimals, "{:.3f}"),
        **dict.fromkeys(one_decimal, "{:.1f}"),
    }
    return table.style.format(column_formats)
# Show the formatted summary table as the cell result.
prettify(table)

Unnamed: 0_level_0,Novelty (%),Structural,Compositional,Recall,Precision,$\rho$,$E$,# Elements,S.U.N. (%),Novel Template (%),P1 (%),Space Group,S.S.U.N. (%)
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
WyckoffTransformer,89.37,,,,,,,,63.2,22.85,1.12,0.079,62.4
WyCryst,61.71,,,,,,,,29.5,30.38,4.46,0.489,27.6
WyCryst-chgnet,52.62,99.81,75.53,98.85,87.1,0.96,0.113,0.286,36.6,17.2,4.79,0.71,35.2


In [18]:
# Export the table as two LaTeX files: the first nine columns
# ("Novelty (%)" through "S.U.N. (%)") and the remaining ones
# ("Novel Template (%)" through "S.S.U.N. (%)").
prettify(table.iloc[:, :9]).to_latex(
    buf="tables/mp_20_biternary_summary_similarity_raw.tex",
    siunitx=True,
    convert_css=True,
)
prettify(table.iloc[:, 9:]).to_latex(
    buf="tables/mp_20_biternary_summary_symmetry_raw.tex",
    siunitx=True,
    convert_css=True,
)