In [1]:
from typing import Tuple
import pandas as pd
from tqdm.notebook import tqdm
from evaluation.generated_dataset import GeneratedDataset, load_all_from_config
from evaluation.novelty import NoveltyFilter, filter_by_unique_structure

In a future release, impute_nan will be set to True by default.
                    This means that features that are missing or are NaNs for elements
                    from the data source will be replaced by the average of that value
                    over the available elements.
                    This avoids NaNs after featurization that are often replaced by
                    dataset-dependent averages.


In [2]:
# Maps the method label used as a table row name to the key (a tuple of
# strings) under which that method's dataset is stored in `all_datasets`.
# Insertion order is the row order of the tables built from this dict.
datasets = {
    #"WyckoffTransformer-raw": ("WyckoffTransformer",),
    "WyFormer": ("WyckoffTransformer", "CrySPR", "CHGNet_fix_release"),
    "WyForDiffCSP++": ("WyckoffTransformer", "DiffCSP++", "CHGNet_fix"),
    #"WyckoffTransformer-free": ("WyckoffTransformer", "CrySPR", "CHGNet_free"),
    "CrystalFormer": ("CrystalFormer", "CHGNet_fix_release"),
    #"DiffCSP++ raw": ("DiffCSP++",),
    "DiffCSP++": ("DiffCSP++", "CHGNet_fix_release"),
    "DiffCSP": ("DiffCSP", "CHGNet_fix"),
    "FlowMM": ("FlowMM", "CHGNet_fix"),
    #"MP-20 train": ("split", "train"),
    #"MP-20 test": ("split", "test"),
}
# Same kind of mapping, iterated by the cell that writes
# "tables/cdvae_metrics_no_relax_table.*".
# NOTE(review): the "WyFormer" entry is identical to the one in `datasets`
# (it still ends with "CHGNet_fix_release") — confirm this is intended.
# NOTE(review): the label "WyFormerDiffCSP++" is spelled differently from
# "WyForDiffCSP++" in `datasets` — confirm which label the tables should use.
raw_datasets = {
    "WyFormer": ("WyckoffTransformer", "CrySPR", "CHGNet_fix_release"),
    "WyFormerDiffCSP++": ("WyckoffTransformer", "DiffCSP++"),
    "CrystalFormer": ("CrystalFormer",),
    "DiffCSP++": ("DiffCSP++",),
    "DiffCSP": ("DiffCSP",),
    "FlowMM": ("FlowMM",),
}

In [3]:
# Request every dataset the notebook reads from the "mp_20" configuration:
# the entries of both method dicts, the three reference splits, and one
# additional WyckoffTransformer/CrySPR/CHGNet_fix dataset.
all_datasets = load_all_from_config(
    datasets=[
        *datasets.values(),
        *raw_datasets.values(),
        ("split", "train"),
        ("split", "val"),
        ("split", "test"),
        ("WyckoffTransformer", "CrySPR", "CHGNet_fix"),
    ],
    dataset_name="mp_20",
)

In [4]:
# WyCryst is registered by hand: its data is read from the "mp_20_biternary"
# cache instead of being requested through load_all_from_config.
wycryst_key = ('WyCryst', 'CHGNet_fix')
for registry in (datasets, raw_datasets):
    registry["WyCryst"] = wycryst_key
all_datasets[wycryst_key] = GeneratedDataset.from_cache(wycryst_key, "mp_20_biternary")

In [5]:
# The train and validation splits together are handed to NoveltyFilter as the
# reference set; verify_integrity makes the concat fail on duplicate index labels.
reference_splits = ("train", "val")
novelty_reference = pd.concat(
    [all_datasets[("split", split_name)].data for split_name in reference_splits],
    axis=0,
    verify_integrity=True,
)
novelty_filter = NoveltyFilter(novelty_reference)

In [6]:
import evaluation.statistical_evaluator

# The evaluator for the summary table is built from the test split after
# de-duplication and after removing everything the novelty filter rejects.
test_split_data = all_datasets[("split", "test")].data
test_unique = filter_by_unique_structure(test_split_data)
test_novel = novelty_filter.get_novel(test_unique)
test_evaluator = evaluation.statistical_evaluator.StatisticalEvaluator(test_novel)

In [7]:
import evaluation.novelty

# Anonymous fingerprint of every row of the novelty reference, collected into
# an immutable set for the membership test in the summary-table loop.
reference_fingerprints = novelty_reference.apply(
    evaluation.novelty.record_to_anonymous_fingerprint, axis=1)
train_w_template_set = frozenset(reference_fingerprints)

In [8]:
def is_sg_preserved(relaxed_sg, transformations: Tuple[str, ...]) -> pd.Series:
    """Tell, per structure, whether its space group equals the source one.

    The source space groups are read from the ``spacegroup_number`` column of
    the dataset stored in ``all_datasets`` under the one-element key made of
    the first entry of ``transformations``.

    Args:
        relaxed_sg: Series of space group numbers to check.
        transformations: Dataset key of any length; only its first element
            is used.

    Returns:
        Boolean Series with the index of ``relaxed_sg``. Labels missing from
        the source dataset compare as ``False``.

    Raises:
        KeyError: if ``all_datasets`` has no entry ``(transformations[0],)``.
    """
    source_sg = all_datasets[(transformations[0],)].data.spacegroup_number
    # Align the source values to the rows being checked before comparing.
    return relaxed_sg == source_sg.reindex_like(relaxed_sg)

In [None]:
# All three MP-20 splits in one frame; verify_integrity makes the concat fail
# on duplicate index labels.
mp_20 = pd.concat([
    all_datasets[('split', 'train')].data,
    all_datasets[('split', 'test')].data,
    all_datasets[('split', 'val')].data], axis=0, verify_integrity=True)
# Only the last expression of a cell is displayed, so the fraction of entries
# with space group number 1 is printed explicitly instead of being discarded.
mp_20_p1_fraction = (mp_20.spacegroup_number == 1).mean()
print(f"MP-20 fraction with space group 1: {mp_20_p1_fraction}")
mp_20.smact_validity.mean()

0.9057020937893829

In [10]:
# Complement of the mean of `smact_validity` over the concatenated MP-20 splits.
1-mp_20.smact_validity.mean()

0.09429790621061707

In [11]:
# Mean of len() over the entries of the `sites` column of MP-20.
mp_20.sites.map(len).mean()

3.012403546397223

In [12]:
from pathlib import Path

# How many novel structures are exported per method, and the directory that
# receives the exported CIF tables (created up front).
novel_save_count = 105
novel_save_path = Path("generated") / "Dropbox" / f"novel_{novel_save_count}"
novel_save_path.mkdir(parents=True, exist_ok=True)

In [13]:
from pymatgen.io.cif import CifWriter
from pymatgen.core import Structure

def to_cif(structure: Structure) -> str:
    """Return the CIF text that CifWriter produces for ``structure.to_primitive()``."""
    primitive_cell = structure.to_primitive()
    return str(CifWriter(primitive_cell))

In [14]:
# Summary table: one row per method label in `datasets`, one column per metric.
# Every cell starts empty and is filled in by the loop below.
table = pd.DataFrame(
    index=datasets.keys(), columns=[
        "Novelty (%)", "Structural", "Compositional", 
        "Recall", "Precision",
        r"$\rho$", "$E$", "# Elements",
        "S.U.N. (%)",
        "Novel Template (%)", "P1 (%)",
        "Space Group", "S.S.U.N. (%)"])
table.index.name = "Method"
# Upper bound on `corrected_chgnet_ehull` for a row to count towards S.U.N.
# NOTE(review): units are not visible here (presumably eV/atom) — confirm.
E_hull_threshold = 0.08
# NOTE(review): not referenced anywhere in the visible part of the notebook.
unique_sample_size = 992
for name, transformations in tqdm(datasets.items()):
    dataset = all_datasets[transformations]
    unique = filter_by_unique_structure(dataset.data)
    print(f"Unique: {len(unique)} / {len(dataset.data)} = {len(unique) / len(dataset.data)}")
    # Share of unique rows whose anonymous fingerprint is absent from the
    # fingerprints of the novelty reference.
    novel_template = ~unique.apply(evaluation.novelty.record_to_anonymous_fingerprint, axis=1).isin(train_w_template_set)
    table.loc[name, "Novel Template (%)"] = 100 * novel_template.mean()
    # With the entries currently enabled in `datasets` the train branch is not
    # taken; it applies only if ("split", "train") is added as a method.
    if transformations == ("split", "train"):
        novel = unique
    else:
        novel = novelty_filter.get_novel(unique)
    # Novelty is reported relative to the unique rows, not the full dataset.
    table.loc[name, "Novelty (%)"] = 100 * len(novel) / len(unique)
    if "structural_validity" in novel.columns:
        table.loc[name, "Structural"] = 100 * novel.structural_validity.mean()
        table.loc[name, "Compositional"] = 100 * novel.smact_validity.mean()
    if "cdvae_crystal" in novel.columns:
        # Coverage is computed on all novel rows, before the validity filter.
        cov_metrics = test_evaluator.get_coverage(novel.cdvae_crystal)    
        table.loc[name, "Recall"] = 100 * cov_metrics["cov_recall"]
        table.loc[name, "Precision"] = 100 * cov_metrics["cov_precision"]
        # From here on `novel` holds only structurally valid rows, so the CIF
        # export, P1, space-group and S.U.N. values below use that subset
        # whenever the `cdvae_crystal` column is present.
        novel = novel[novel.structural_validity]
        table.loc[name, r"$\rho$"] = test_evaluator.get_density_emd(novel)
        table.loc[name, "$E$"] = test_evaluator.get_cdvae_e_emd(novel)
        table.loc[name, "# Elements"] = test_evaluator.get_num_elements_emd(novel)
    # Export the first `novel_save_count` novel structures as CIF text.
    novel_cif = novel.structure.iloc[:novel_save_count].apply(to_cif)
    site_counts = novel.structure.iloc[:novel_save_count].apply(lambda s: len(s.to_primitive().sites))
    # NOTE(review): the saved cell output shows the placeholders of the second
    # string unformatted, so it appears to predate the `f` prefix on that
    # string — re-run the cell to refresh the stored output.
    print(f"Novel_{novel_save_count} primitive site counts: "
          f"{site_counts.mean()} ± {site_counts.std()}; max: {site_counts.max()}")
    novel_cif.name = "cif"
    # One sub-directory per dataset key, e.g. <novel_save_path>/<t0>/<t1>/...
    cif_path = novel_save_path.joinpath(*transformations)
    cif_path.mkdir(parents=True, exist_ok=True)
    novel_cif.to_csv(cif_path.joinpath("cif.csv.gz"), index_label="index")
    # NOTE(review): this uses the `group` column while other cells use
    # `spacegroup_number` — confirm both hold the space group number.
    table.loc[name, "P1 (%)"] = 100 * (novel.group == 1).mean()
    # table.loc[name, "# DoF"] = test_evaluator.get_dof_emd(novel)
    table.loc[name, "Space Group"] = test_evaluator.get_sg_chi2(novel)
    #try:
    #    table.loc[name, "SG preserved (%)"] = 100 * is_sg_preserved(novel.spacegroup_number, transformations).mean()
    #except KeyError:
    #    pass
    #table.loc[name, "Elements"] = test_evaluator.get_elements_chi2(novel)
    if "corrected_chgnet_ehull" in novel.columns:
        # S.U.N. is measured with respect to the initial structures
        # Denominator: rows of the full dataset with a non-missing e_hull.
        has_ehull = dataset.data.corrected_chgnet_ehull.notna().sum()
        is_sun = (novel.corrected_chgnet_ehull <= E_hull_threshold)
        table.loc[name, "S.U.N. (%)"] = 100 * is_sun.sum() / has_ehull
        # S.S.U.N. additionally requires `group` to differ from 1.
        table.loc[name, "S.S.U.N. (%)"] = 100 * (is_sun & (novel.group != 1)).sum() / has_ehull
    if transformations == ("split", "train"):
        # Train forms the baseline of the hull
        # NOTE(review): this branch rebinds the notebook-level `test_unique`
        # and `test_novel`; `test_evaluator` was built earlier and is not
        # rebuilt. The key ("split", "test", "CHGNet_fix") is not among the
        # keys requested from load_all_from_config in the visible code —
        # confirm it is available before enabling a train row.
        test_dataset = all_datasets[("split", "test", "CHGNet_fix")].data
        test_with_ehull = test_dataset[test_dataset.corrected_chgnet_ehull.notna()]
        test_unique = filter_by_unique_structure(test_with_ehull)
        test_novel = novelty_filter.get_novel(test_unique)
        table.loc[name, "S.U.N. (%)"] = 100 * (test_novel.corrected_chgnet_ehull <= E_hull_threshold).sum() / len(test_with_ehull)
# Display the filled table as the cell output.
table

  0%|          | 0/7 [00:00<?, ?it/s]

Unique: 1000 / 1000 = 1.0
Novel_105 primitive site counts: {site_counts.mean()} ± {site_counts.std()}; max: {site_counts.max()}
Unique: 1000 / 1000 = 1.0
Novel_105 primitive site counts: {site_counts.mean()} ± {site_counts.std()}; max: {site_counts.max()}
Unique: 988 / 992 = 0.9959677419354839
Novel_105 primitive site counts: {site_counts.mean()} ± {site_counts.std()}; max: {site_counts.max()}
Unique: 999 / 1000 = 0.999
Novel_105 primitive site counts: {site_counts.mean()} ± {site_counts.std()}; max: {site_counts.max()}
Unique: 996 / 1000 = 0.996
Novel_105 primitive site counts: {site_counts.mean()} ± {site_counts.std()}; max: {site_counts.max()}
Unique: 994 / 997 = 0.9969909729187563
Novel_105 primitive site counts: {site_counts.mean()} ± {site_counts.std()}; max: {site_counts.max()}
Unique: 994 / 994 = 1.0
Novel_105 primitive site counts: {site_counts.mean()} ± {site_counts.std()}; max: {site_counts.max()}


Unnamed: 0_level_0,Novelty (%),Structural,Compositional,Recall,Precision,$\rho$,$E$,# Elements,S.U.N. (%),Novel Template (%),P1 (%),Space Group,S.S.U.N. (%)
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
WyFormer,90.0,99.555556,80.444444,98.666667,96.715328,0.73714,0.053159,0.096566,38.938939,19.8,3.236607,0.222677,38.038038
WyForDiffCSP++,89.5,99.664804,80.335196,99.217877,96.792745,0.665291,0.050328,0.098188,36.6,20.0,1.457399,0.211951,35.9
CrystalFormer,76.923077,86.842105,82.368421,99.868421,95.13382,0.524171,0.099555,0.163017,33.870968,9.716599,0.909091,0.276161,33.770161
DiffCSP++,89.68969,100.0,85.044643,99.330357,95.79739,0.147837,0.036122,0.503632,41.4,1.001001,2.566964,0.25525,40.8
DiffCSP,90.060241,100.0,80.936455,99.554069,96.206591,0.822378,0.052409,0.293636,57.4,9.538153,36.566332,7.988511,40.6
FlowMM,90.140845,96.205357,82.477679,99.665179,96.361424,0.315307,0.044476,0.114668,49.548646,6.136821,49.303944,14.007995,29.889669
WyCryst,52.615694,99.808795,75.525813,98.852772,87.104623,0.962397,0.112944,0.285756,36.619718,17.203219,4.789272,0.710472,35.211268


In [15]:
# Persist the summary table. The output directory is created first so that
# the cell also works when "tables/" does not exist yet (pandas does not
# create missing parent directories).
tables_dir = Path("tables")
tables_dir.mkdir(parents=True, exist_ok=True)
table.to_csv(tables_dir / "paper_summary_table.csv")
table.to_pickle(tables_dir / "paper_summary_table.pkl")

In [16]:
# Columns in which the largest entry is rendered in bold.
max_subset = [
    "Novelty (%)", "Structural", "Compositional", "Recall", "Precision",
    "S.S.U.N. (%)", "S.U.N. (%)", "Novel Template (%)",
]


def highlight_max_value(s):
    """Return one CSS string per element of column ``s``, bolding its maximum.

    Columns whose name is not in ``max_subset`` get empty styles throughout.
    Every row takes part in the comparison, and all entries equal to the
    maximum are highlighted.
    """
    if s.name in max_subset:
        peak = s.max()
        return ['font-weight: bold' if hit else '' for hit in (s == peak)]
    return [''] * len(s)

# Columns in which the smallest entry is rendered in bold.
min_subset = [
    r"$\rho$", "$E$", "# Elements", "# DoF", "Space Group", "Elements", "P1 (%)",
]


def highlight_min_value(s):
    """Return one CSS string per element of column ``s``, bolding its minimum.

    Columns whose name is not in ``min_subset`` get empty styles throughout.
    Every row takes part in the comparison, and all entries equal to the
    minimum are highlighted.
    """
    if s.name in min_subset:
        floor = s.min()
        return ['font-weight: bold' if hit else '' for hit in (s == floor)]
    return [''] * len(s)

In [None]:
def prettify(table):
    """Return a Styler for ``table`` with number formats and bold extremes.

    Each known column name is mapped to a fixed-precision format string, then
    ``highlight_max_value`` and ``highlight_min_value`` are applied in turn.
    """
    two_decimals = (
        "Novelty (%)", "Structural", "Compositional", "Recall", "Precision",
        r"$\rho$", "Novel Template (%)", "P1 (%)")
    three_decimals = ("$E$", "# Elements", "# DoF", "Space Group", "Elements")
    one_decimal = ("S.U.N. (%)", "S.S.U.N. (%)")
    number_formats = {
        **dict.fromkeys(two_decimals, "{:.2f}"),
        **dict.fromkeys(three_decimals, "{:.3f}"),
        **dict.fromkeys(one_decimal, "{:.1f}"),
    }
    styled = table.style.format(number_formats)
    styled = styled.apply(highlight_max_value)
    return styled.apply(highlight_min_value)
prettify(table)

Unnamed: 0_level_0,Novelty (%),Structural,Compositional,Recall,Precision,$\rho$,$E$,# Elements,S.U.N. (%),Novel Template (%),P1 (%),Space Group,S.S.U.N. (%)
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
WyFormer,90.0,99.56,80.44,98.67,96.72,0.74,0.053,0.097,38.9,19.8,3.24,0.223,38.0
WyForDiffCSP++,89.5,99.66,80.34,99.22,96.79,0.67,0.05,0.098,36.6,20.0,1.46,0.212,35.9
CrystalFormer,76.92,86.84,82.37,99.87,95.13,0.52,0.1,0.163,33.9,9.72,0.91,0.276,33.8
DiffCSP++,89.69,100.0,85.04,99.33,95.8,0.15,0.036,0.504,41.4,1.0,2.57,0.255,40.8
DiffCSP,90.06,100.0,80.94,99.55,96.21,0.82,0.052,0.294,57.4,9.54,36.57,7.989,40.6
FlowMM,90.14,96.21,82.48,99.67,96.36,0.32,0.044,0.115,49.5,6.14,49.3,14.008,29.9
WyCryst,52.62,99.81,75.53,98.85,87.1,0.96,0.113,0.286,36.6,17.2,4.79,0.71,35.2


In [18]:
# The first nine columns are written to the "similarity" LaTeX file, the
# remaining ones to the "symmetry" file.
latex_targets = (
    (table.iloc[:, :9], "tables/summary_similarity_raw.tex"),
    (table.iloc[:, 9:], "tables/summary_symmetry_raw.tex"),
)
for table_part, tex_path in latex_targets:
    prettify(table_part).to_latex(tex_path, siunitx=True, convert_css=True)

In [19]:
# Metrics against the complete test split (no uniqueness or novelty filter),
# one row per method label in `datasets`.
raw_test_evaluator = evaluation.statistical_evaluator.StatisticalEvaluator(
    all_datasets[('split', 'test')].data)
metric_columns = [
    "Structural", "Compositional",
    "Recall", "Precision",
    r"$\rho$", "$E$", "# Elements"]
cdvae_table = pd.DataFrame(
    index=pd.Index(datasets.keys(), tupleize_cols=False),
    columns=metric_columns)
sample_size = 992
for name, transformations in tqdm(datasets.items()):
    dataset = all_datasets[transformations]
    if "structure" not in dataset.data.columns:
        # Without a `structure` column the row is left empty.
        continue
    generated = dataset.data
    # Validity rates are taken over all rows of the dataset.
    cdvae_table.loc[name, "Compositional"] = 100 * generated.smact_validity.mean()
    cdvae_table.loc[name, "Structural"] = 100 * generated.structural_validity.mean()
    # The remaining metrics use only rows flagged by `naive_validity`.
    valid = generated[generated.naive_validity]
    cov_metrics = raw_test_evaluator.get_coverage(valid.cdvae_crystal)
    for column, metric_key in (("Recall", "cov_recall"), ("Precision", "cov_precision")):
        cdvae_table.loc[name, column] = 100 * cov_metrics[metric_key]
    cdvae_table.loc[name, r"$\rho$"] = raw_test_evaluator.get_density_emd(valid)
    cdvae_table.loc[name, "$E$"] = raw_test_evaluator.get_cdvae_e_emd(valid)
    cdvae_table.loc[name, "# Elements"] = raw_test_evaluator.get_num_elements_emd(valid)
cdvae_table.to_csv("tables/cdvae_metrics_table.csv")
prettify(cdvae_table).to_latex(
    "tables/cdvae_metrics_table.tex", siunitx=True, convert_css=True)
prettify(cdvae_table)

  0%|          | 0/7 [00:00<?, ?it/s]

Unnamed: 0,Structural,Compositional,Recall,Precision,$\rho$,$E$,# Elements
WyFormer,99.6,81.4,98.77,95.94,0.39,0.078,0.081
WyForDiffCSP++,99.7,81.4,99.26,95.85,0.33,0.07,0.078
CrystalFormer,89.92,84.88,99.87,95.45,0.19,0.139,0.119
DiffCSP++,100.0,85.8,99.42,95.48,0.13,0.036,0.453
DiffCSP,100.0,82.5,99.64,95.18,0.46,0.075,0.321
FlowMM,96.59,83.25,99.75,95.83,0.17,0.055,0.107
WyCryst,99.9,82.09,99.63,96.16,0.44,0.33,0.322


In [20]:
# Same metrics as the table built from `datasets`, here for the entries of
# `raw_datasets`; results go to the "no_relax" files.
# NOTE(review): this cell repeats the loop of the preceding metrics cell with
# a different dict and output path — a shared function would remove the copy.
raw_test_evaluator = evaluation.statistical_evaluator.StatisticalEvaluator(
    all_datasets[('split', 'test')].data)
metric_columns = [
    "Structural", "Compositional",
    "Recall", "Precision",
    r"$\rho$", "$E$", "# Elements"]
cdvae_table = pd.DataFrame(
    index=pd.Index(raw_datasets.keys(), tupleize_cols=False),
    columns=metric_columns)
sample_size = 992
for name, transformations in tqdm(raw_datasets.items()):
    dataset = all_datasets[transformations]
    if "structure" not in dataset.data.columns:
        # Without a `structure` column the row is left empty.
        continue
    generated = dataset.data
    # Validity rates are taken over all rows of the dataset.
    cdvae_table.loc[name, "Compositional"] = 100 * generated.smact_validity.mean()
    cdvae_table.loc[name, "Structural"] = 100 * generated.structural_validity.mean()
    # The remaining metrics use only rows flagged by `naive_validity`.
    valid = generated[generated.naive_validity]
    cov_metrics = raw_test_evaluator.get_coverage(valid.cdvae_crystal)
    for column, metric_key in (("Recall", "cov_recall"), ("Precision", "cov_precision")):
        cdvae_table.loc[name, column] = 100 * cov_metrics[metric_key]
    cdvae_table.loc[name, r"$\rho$"] = raw_test_evaluator.get_density_emd(valid)
    cdvae_table.loc[name, "$E$"] = raw_test_evaluator.get_cdvae_e_emd(valid)
    cdvae_table.loc[name, "# Elements"] = raw_test_evaluator.get_num_elements_emd(valid)
cdvae_table.to_csv("tables/cdvae_metrics_no_relax_table.csv")
prettify(cdvae_table).to_latex(
    "tables/cdvae_metrics_no_relax_table.tex", siunitx=True, convert_css=True)
prettify(cdvae_table)

  0%|          | 0/7 [00:00<?, ?it/s]

Unnamed: 0,Structural,Compositional,Recall,Precision,$\rho$,$E$,# Elements
WyFormer,99.6,81.4,98.77,95.94,0.39,0.078,0.081
WyFormerDiffCSP++,99.8,81.4,99.51,95.81,0.36,0.083,0.079
CrystalFormer,93.39,84.98,99.62,94.56,0.19,0.208,0.128
DiffCSP++,99.94,85.13,99.67,99.54,0.31,0.069,0.399
DiffCSP,100.0,83.2,99.82,99.51,0.35,0.095,0.347
FlowMM,96.87,83.11,99.73,99.39,0.12,0.073,0.094
WyCryst,99.9,82.09,99.63,96.16,0.44,0.33,0.322
