In [1]:
import sys
from pathlib import Path
from typing import Tuple

import pandas as pd
from tqdm.notebook import tqdm

# The project package lives one directory up; extend the path before importing it.
sys.path.append('..')
import evaluation.novelty
import evaluation.statistical_evaluator
from evaluation.generated_dataset import GeneratedDataset, load_all_from_config
from evaluation.novelty import NoveltyFilter, filter_by_unique_structure

In a future release, impute_nan will be set to True by default.
                    This means that features that are missing or are NaNs for elements
                    from the data source will be replaced by the average of that value
                    over the available elements.
                    This avoids NaNs after featurization that are often replaced by
                    dataset-dependent averages.


In [2]:
# Display name -> dataset key. Each key is a tuple of strings; the same tuples
# are handed to `load_all_from_config` and later used to index `all_datasets`.
raw_datasets = dict(
    [
        ("WyFormer", ("WyckoffTransformer",)),
        ("SymmCD", ("SymmCD",)),
        ("WyFormer-CHGNet", ("WyckoffTransformer", "CrySPR", "CHGNet_fix")),
        ("SymmCD-CHGNet", ("SymmCD", "CHGNet_fix")),
    ]
)

In [3]:
# Load every generated dataset named in `raw_datasets` plus the train / val /
# test splits of the "mpts_52" dataset in a single call.
split_keys = [("split", part) for part in ("train", "val", "test")]
all_datasets = load_all_from_config(
    datasets=[*raw_datasets.values(), *split_keys],
    dataset_name="mpts_52",
)

In [4]:
# Structures counted as "already known" for novelty purposes: train + val.
# verify_integrity=True makes the concat raise if the two splits share an index value.
reference_frames = [all_datasets[("split", part)].data for part in ("train", "val")]
novelty_reference = pd.concat(reference_frames, axis=0, verify_integrity=True)
novelty_filter = NoveltyFilter(novelty_reference)

In [5]:
# Evaluator built on the held-out test split. `evaluation.statistical_evaluator`
# is imported in the notebook's import cell rather than here, so that every
# later `evaluation.<submodule>` reference does not depend on this cell having run.
test_evaluator = evaluation.statistical_evaluator.StatisticalEvaluator(
    all_datasets[('split', 'test')].data, cdvae_eval_model_name="mp20")

In [6]:
# The complete mpts_52 dataset: train, val and test stacked row-wise.
# verify_integrity=True makes the concat raise if any index value repeats across splits.
mpts_52_parts = [all_datasets[("split", part)].data for part in ("train", "val", "test")]
mpts_52 = pd.concat(mpts_52_parts, axis=0, verify_integrity=True)

In [7]:
# Evaluator whose reference data is the training split; used below for the
# distribution-distance metrics of the generated samples.
train_split = all_datasets[("split", "train")].data
train_evaluator = evaluation.statistical_evaluator.StatisticalEvaluator(
    train_split,
    cdvae_eval_model_name="mp20",
)

In [8]:
# Mean of the `smact_validity` column over the data held by the train-split
# evaluator. NOTE(review): the attribute is called `test_dataset`, but this
# evaluator was constructed from the train split — presumably that is what it holds.
train_evaluator.test_dataset.smact_validity.mean()

0.8710372534696859

In [9]:
# Mean of the `structural_validity` column over the data held by the train-split
# evaluator (same caveat about the `test_dataset` attribute name: this evaluator
# was built from the train split).
train_evaluator.test_dataset.structural_validity.mean()

1.0

In [10]:
# Which of the indices 0..999 have no row in the WyFormer -> CrySPR -> CHGNet_fix dataset?
expected_ids = frozenset(range(1000))
present_ids = all_datasets[("WyckoffTransformer", "CrySPR", "CHGNet_fix")].data.index
expected_ids.difference(present_ids)

frozenset({278, 399, 467, 535, 697})

In [11]:
# Share of entries in the train-split evaluator's data whose space group number is 1.
train_reference = train_evaluator.test_dataset
(train_reference.spacegroup_number == 1).sum() / len(train_reference)

0.0023009495982468956

In [12]:
# Anonymous fingerprints of all reference structures, used to decide whether a
# generated template is new. Note: despite the "train" in the name, the set is
# built from `novelty_reference`, i.e. train + val.
reference_templates = novelty_reference.apply(
    evaluation.novelty.record_to_anonymous_fingerprint,
    axis=1,
)
train_w_template_set = frozenset(reference_templates)

In [13]:
# Number-of-sites EMD of the test split, measured by the train-split evaluator.
test_split = all_datasets[("split", "test")].data
train_evaluator.get_num_sites_emd(test_split)

2.141915765380806

In [14]:
# (number of distinct fingerprints, number of rows) in the full mpts_52 dataset
mpts_52["fingerprint"].nunique(), len(mpts_52)

(39907, 40476)

In [15]:
# Fraction of mpts_52 rows that carry a distinct fingerprint
mpts_52["fingerprint"].nunique() / len(mpts_52)

0.985942286787232

In [16]:
# Mean of the `smact_validity` column over the full mpts_52 dataset
mpts_52["smact_validity"].mean()

0.8824488585828639

Formal validity:
1. WyFormer: 0.935 ([W&B run](https://wandb.ai/symmetry-advantage/WyckoffTransformer/runs/9sqoiuvh?nw=nwuserkazeev))
2. SymmCD: 9170 / 10000

In [25]:
# Formal validity per generator, transcribed from the figures quoted in the
# markdown cell above (SymmCD is given as a raw count out of 10000 samples).
formal_validity = dict(
    WyckoffTransformer=0.935,
    SymmCD=9170 / 10000,
)

In [27]:
# Per-method summary of unique / novel generated structures.
# Novelty is judged against train + val (`novelty_filter`, `train_w_template_set`);
# the distribution distances come from `train_evaluator`.
table = pd.DataFrame(
    index=raw_datasets.keys(),
    columns=[
        "Novelty (%)", "Structural", "Compositional",
        "Recall", "Precision",
        r"$\rho$", "$E$", "# Elements",
        "Novel Template (%)",
        "P1 (%)",
        "Space Group",
    ],
)
table.index.name = "Method"
# Upper bound on `corrected_chgnet_ehull` for a structure to count towards S.U.N.
E_hull_threshold = 0.08
for name, transformations in tqdm(raw_datasets.items()):
    dataset = all_datasets[transformations]
    unique = filter_by_unique_structure(dataset.data)
    print(f"Unique: {len(unique)} / {len(dataset.data)} = {len(unique) / len(dataset.data)}")

    # Template novelty over all unique structures; replaced further down for
    # datasets that carry a `cdvae_crystal` column.
    unique_templates = unique.apply(
        evaluation.novelty.record_to_anonymous_fingerprint, axis=1)
    novel_template = ~unique_templates.isin(train_w_template_set)
    table.loc[name, "Novel Template (%)"] = 100 * novel_template.mean()

    # The training split is novel by definition; everything else goes through the filter.
    novel = (
        unique
        if transformations == ("split", "train")
        else novelty_filter.get_novel(unique)
    )
    table.loc[name, "Novelty (%)"] = 100 * len(novel) / len(unique)

    if "structural_validity" in novel.columns:
        for column, flag in (
            ("Structural", "structural_validity"),
            ("Compositional", "smact_validity"),
        ):
            table.loc[name, column] = 100 * novel[flag].mean()

    if "cdvae_crystal" in novel.columns:
        # Coverage is currently disabled, so "Recall" / "Precision" stay empty here:
        #cov_metrics = test_evaluator.get_coverage(novel.cdvae_crystal)
        #table.loc[name, "Recall"] = 100 * cov_metrics["cov_recall"]
        #table.loc[name, "Precision"] = 100 * cov_metrics["cov_precision"]

        # From here on `novel` only contains structurally valid structures.
        novel = novel[novel["structural_validity"]]
        all_templates = novel.apply(
            evaluation.novelty.record_to_anonymous_fingerprint, axis=1)
        novel_template = ~all_templates.isin(train_w_template_set)
        table.loc[name, "Novel Template (%)"] = 100 * novel_template.mean()
        table.loc[name, "Novel Uniques Templates (#)"] = all_templates[novel_template].nunique()
        for column, emd in (
            (r"$\rho$", train_evaluator.get_density_emd),
            ("$E$", train_evaluator.get_cdvae_e_emd),
            ("# Elements", train_evaluator.get_num_elements_emd),
        ):
            table.loc[name, column] = emd(novel)

    table.loc[name, "P1 (%)"] = 100 * (novel["group"] == 1).mean()
    table.loc[name, "Space Group"] = train_evaluator.get_sg_chi2(novel)

    if "corrected_chgnet_ehull" in dataset.data.columns:
        # Denominator: every generated row with a hull energy, not just the novel ones.
        has_ehull = dataset.data["corrected_chgnet_ehull"].notna()
        is_sun = novel["corrected_chgnet_ehull"] <= E_hull_threshold
        n_with_ehull = has_ehull.sum()
        table.loc[name, "S.U.N. (%)"] = 100 * is_sun.sum() / n_with_ehull
        # Same as S.U.N. but excluding structures in space group 1.
        table.loc[name, "S.S.U.N. (%)"] = 100 * (is_sun & (novel["group"] != 1)).sum() / n_with_ehull
table

  0%|          | 0/4 [00:00<?, ?it/s]

Unique: 9330 / 9350 = 0.9978609625668449
Unique: 9011 / 9170 = 0.9826608505997819
Unique: 995 / 995 = 1.0
Unique: 8734 / 8890 = 0.9824521934758155


Unnamed: 0_level_0,Novelty (%),Structural,Compositional,Recall,Precision,$\rho$,$E$,# Elements,Novel Template (%),P1 (%),Space Group,Novel Uniques Templates (#),S.U.N. (%),S.S.U.N. (%)
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
WyFormer,97.867095,,,,,,,,37.299035,0.208082,0.04972,,,
SymmCD,95.228055,100.0,79.641067,,,0.127218,0.275252,0.632112,27.828924,0.279688,0.050546,1131.0,,
WyFormer-CHGNet,98.693467,99.287169,76.680244,,,0.697912,0.108448,0.228226,44.512821,0.0,0.225362,386.0,24.422111,24.422111
SymmCD-CHGNet,93.061598,96.690453,79.466043,,,0.551155,0.173934,0.616459,28.464181,0.305382,0.063133,1075.0,25.073116,25.050619


In [18]:
# CDVAE-style metrics on the raw generated samples (no uniqueness / novelty
# filtering), evaluated against the test split through `test_evaluator`.
# Datasets without a "structure" column are skipped and keep an all-NaN row.
cdvae_table = pd.DataFrame(index=pd.Index(raw_datasets.keys(), tupleize_cols=False),
    columns=[
        "Structural", "Compositional",
        "Recall", "Precision",
        r"$\rho$", "$E$", "# Elements"])
# Sample size passed to the coverage computation.
sample_size = 900
# Make sure the output directory exists *before* the expensive loop, so a fresh
# checkout without `tables/` does not lose the computed metrics at the final write.
cdvae_table_path = Path("tables") / "cdvae_metrics_no_relax_table.csv"
cdvae_table_path.parent.mkdir(parents=True, exist_ok=True)
for name, transformations in tqdm(raw_datasets.items()):
    dataset = all_datasets[transformations]
    if "structure" in dataset.data.columns:
        # Validity rates are taken over all generated samples ...
        cdvae_table.loc[name, "Compositional"] = 100*dataset.data.smact_validity.mean()
        cdvae_table.loc[name, "Structural"] = 100*dataset.data.structural_validity.mean()
        # ... while coverage and distribution distances use only the rows
        # flagged by `naive_validity`.
        valid = dataset.data[dataset.data.naive_validity]
        cov_metrics = test_evaluator.get_coverage(valid.cdvae_crystal, sample_size)
        cdvae_table.loc[name, "Recall"] = 100*cov_metrics["cov_recall"]
        cdvae_table.loc[name, "Precision"] = 100*cov_metrics["cov_precision"]
        cdvae_table.loc[name, r"$\rho$"] = test_evaluator.get_density_emd(valid)
        cdvae_table.loc[name, "$E$"] = test_evaluator.get_cdvae_e_emd(valid)
        cdvae_table.loc[name, "# Elements"] = test_evaluator.get_num_elements_emd(valid)
cdvae_table.to_csv(cdvae_table_path)
#prettify(cdvae_table).to_latex("tables/cdvae_metrics_no_relax_table.tex", siunitx=True, convert_css=True)
#prettify(cdvae_table)

  0%|          | 0/4 [00:00<?, ?it/s]

Requested sample size 900 is larger than the number of generated samples 757.


In [19]:
# Display the CDVAE metrics table filled in (and written to CSV) by the previous cell.
cdvae_table

Unnamed: 0,Structural,Compositional,Recall,Precision,$\rho$,$E$,# Elements
WyFormer,,,,,,,
SymmCD,100.0,80.196292,97.777778,95.479249,1.074465,0.640633,0.364367
WyFormer-CHGNet,99.296482,76.683417,95.555556,95.034585,0.6083,0.479468,0.376129
SymmCD-CHGNet,96.974128,80.213723,97.0,97.369071,0.800725,0.552997,0.359354
