In [2]:
import sys
from typing import Tuple
import pandas as pd
from tqdm.notebook import tqdm
from pymatgen.core import Structure
sys.path.append("..")
from evaluation.generated_dataset import GeneratedDataset, load_all_from_config
from evaluation.novelty import NoveltyFilter, filter_by_unique_structure

In a future release, impute_nan will be set to True by default.
                    This means that features that are missing or are NaNs for elements
                    from the data source will be replaced by the average of that value
                    over the available elements.
                    This avoids NaNs after featurization that are often replaced by
                    dataset-dependent averages.


In [3]:
datasets = {
    "WyckoffTransformer-raw": ("WyckoffTransformer",),
    "WyFormer-harmonic-raw": ("WyckoffTransformer-harmonic",),
    "WyFormer-letters": ("WyckoffTransformer-letters",),
    "WyFormer-letters-DiffCSP++": ("WyckoffTransformer-letters", "DiffCSP++", "CHGNet_fix"),
    "SymmCD": ("SymmCD", "CHGNet_fix"),
    "WyFormer": ("WyckoffTransformer", "CrySPR", "CHGNet_fix_release"),
    "WyFormer-harmonic-DiffCSP++": ("WyckoffTransformer-harmonic", "DiffCSP++", "CHGNet_fix"),
    "WyForDiffCSP++": ("WyckoffTransformer", "DiffCSP++", "CHGNet_fix"),
    "WyLLM-naive-DiffCSP++": ("WyckoffLLM-naive", "DiffCSP++", "CHGNet_fix"),
    "WyLLM-vanilla-DiffCSP++": ("WyckoffLLM-vanilla", "DiffCSP++"),
    "WyLLM-site-symmetry-DiffCSP++": ("WyckoffLLM-site-symmetry", "DiffCSP++"),
    #"WyckoffTransformer-free": ("WyckoffTransformer", "CrySPR", "CHGNet_free"),
    "CrystalFormer": ("CrystalFormer", "CHGNet_fix_release"),
    #"DiffCSP++ raw": ("DiffCSP++",),
    "DiffCSP++": ("DiffCSP++", "CHGNet_fix_release"),
    "DiffCSP": ("DiffCSP", "CHGNet_fix"),
    "FlowMM": ("FlowMM", "CHGNet_fix"),
    "MiAD": ("MiAD", "CHGNet_free"),
    #"MP-20 train": ("split", "train"),
    #"MP-20 test": ("split", "test"),
}
raw_datasets = {
    "SymmCD": ("SymmCD",),
    "WyFormer": ("WyckoffTransformer", "CrySPR", "CHGNet_fix_release"),
    "WyFormerDiffCSP++": ("WyckoffTransformer", "DiffCSP++"),
    "WyFormer-harmonic-DiffCSP++": ("WyckoffTransformer-harmonic", "DiffCSP++"),
    "WyFormer-letters-DiffCSP++": ("WyckoffTransformer-letters", "DiffCSP++"),
    "WyLLM-DiffCSP++": ("WyckoffLLM-naive", "DiffCSP++"),
    "CrystalFormer": ("CrystalFormer",),
    "DiffCSP++": ("DiffCSP++",),
    "DiffCSP": ("DiffCSP",),
    "FlowMM": ("FlowMM",),
}

In [4]:
all_datasets = load_all_from_config(
    datasets=list(datasets.values()) + list(raw_datasets.values()) + \
        [("split", "train"), ("split", "val"), ("split", "test"), ("WyckoffTransformer", "CrySPR", "CHGNet_fix")],
    dataset_name="mp_20")

In [5]:
wycryst_transformations = ('WyCryst', 'CrySPR', 'CHGNet_fix')
datasets["WyCryst"] = wycryst_transformations
raw_datasets["WyCryst"] = wycryst_transformations
all_datasets[wycryst_transformations] = GeneratedDataset.from_cache(wycryst_transformations, "mp_20_biternary")

In [6]:
novelty_reference = pd.concat([
    all_datasets[('split', 'train')].data,
    all_datasets[('split', 'val')].data], axis=0, verify_integrity=True)
novelty_filter = NoveltyFilter(novelty_reference)

In [7]:
import evaluation.statistical_evaluator
test_evaluator = evaluation.statistical_evaluator.StatisticalEvaluator(all_datasets[('split', 'test')].data)

In [8]:
import evaluation.novelty
train_w_template_set = frozenset(novelty_reference.apply(evaluation.novelty.record_to_anonymous_fingerprint, axis=1))

In [9]:
def is_sg_preserved(relaxed_sg, transformations: Tuple[str]) -> pd.Series:
    source_sg = all_datasets[(transformations[0],)].data.spacegroup_number
    return relaxed_sg == source_sg.reindex_like(relaxed_sg)

In [10]:
mp_20 = pd.concat([
    all_datasets[('split', 'train')].data,
    all_datasets[('split', 'test')].data,
    all_datasets[('split', 'val')].data], axis=0, verify_integrity=True)
(mp_20.spacegroup_number == 1).mean()
mp_20.smact_validity.mean()

0.9057020937893829

In [11]:
from collections import Counter
from operator import itemgetter
from itertools import chain
element_counts = Counter(chain(*mp_20.elements))

In [12]:
represented_elements=frozenset(map(itemgetter(0), element_counts.most_common(30)))

In [13]:
def check_represented_composition(structure: Structure) -> bool:
    for element in structure.composition:
        if element not in represented_elements:
            return False
    return True

In [14]:
def check_double_represented_composition(structure: Structure) -> bool:
    found_first_exotic = False
    for element in structure.composition:
        if element not in represented_elements:
            if found_first_exotic:
                return False
            found_first_exotic = True
    return True

In [15]:
top_10_groups = frozenset(mp_20.spacegroup_number.value_counts().iloc[:10].index)
n_elements_dist = {}

Validity
1. Vanilla; Valid records: 2866 / 9648 = 29.71%
2. Naive; Valid records: 9492 / 9804 = 96.82%
3. Site Symmetry; Valid records: 8955 / 9709 = 92.23%

Formal validity:
1. WyFomer: 0.9772; [WanDB](https://wandb.ai/symmetry-advantage/WyckoffTransformer/runs/yj1cme83/logs?nw=nwuserkazeev)
2. SymmCD: 0.9580

In [16]:
table = pd.DataFrame(
    index=datasets.keys(), columns=[
        "Novelty (%)",
        "Represented composition",
        "Structural", "Compositional", 
        "Recall", "Precision",
        r"$\rho$", "$E$", "# Elements",
        "S.U.N. (%)",
        "R.S.U.N. (%)",
        "~R.S.U.N. (%)",
        "Top-10 S.U.N. (%)",
        "Novel Uniques Templates (#)",
        "Novel Template (%)", 
        "P1 (%)",
        "Space Group", "S.S.U.N. (%)"])
table.index.name = "Method"
E_hull_threshold = 0.08
precision_etc_sample_size = 1000
for name, transformations in tqdm([("train", ("split", "train"))]+list(datasets.items())):
    data = all_datasets[transformations].data.copy()
    if len(data) > precision_etc_sample_size:
        data = data.sample(precision_etc_sample_size, random_state=42)
    unique = filter_by_unique_structure(data)
    print(f"{name} unique: {len(unique)} / {len(data)} = {len(unique) / len(data) :.3%}")
    if transformations == ("split", "train"):
        novel = unique
    else:
        novel = novelty_filter.get_novel(unique)
    table.loc[name, "Novelty (%)"] = 100 * len(novel) / len(unique)
    if "structural_validity" in novel.columns:
        table.loc[name, "Structural"] = 100 * novel.structural_validity.mean()
        table.loc[name, "Compositional"] = 100 * novel.smact_validity.mean()
    if "cdvae_crystal" in novel.columns:
        table.loc[name, "Represented composition"] = novel.structure.apply(check_represented_composition).mean()
        cov_metrics = test_evaluator.get_coverage(novel.cdvae_crystal)    
        table.loc[name, "Recall"] = 100 * cov_metrics["cov_recall"]
        table.loc[name, "Precision"] = 100 * cov_metrics["cov_precision"]
        novel = novel[novel.structural_validity]
        all_templates = novel.apply(evaluation.novelty.record_to_anonymous_fingerprint, axis=1)
        novel_template = ~all_templates.isin(train_w_template_set)
        table.loc[name, "Novel Template (%)"] = 100 * novel_template.mean()
        table.loc[name, "Novel Uniques Templates (#)"] = all_templates[novel_template].nunique() 
        table.loc[name, r"$\rho$"] = test_evaluator.get_density_emd(novel)
        table.loc[name, "$E$"] = test_evaluator.get_cdvae_e_emd(novel)
        table.loc[name, "# Elements"] = test_evaluator.get_num_elements_emd(novel)
        n_elements_dist[name] = novel.elements.apply(lambda e: len(frozenset(e))).value_counts() / len(novel)
    table.loc[name, "P1 (%)"] = 100 * (novel.group == 1).mean()
    # table.loc[name, "# DoF"] = test_evaluator.get_dof_emd(novel)
    table.loc[name, "Space Group"] = test_evaluator.get_sg_chi2(novel)
    #try:
    #    table.loc[name, "SG preserved (%)"] = 100 * is_sg_preserved(novel.spacegroup_number, transformations).mean()
    #except KeyError:
    #    pass
    #table.loc[name, "Elements"] = test_evaluator.get_elements_chi2(novel)
    if "corrected_chgnet_ehull" in novel.columns:
        # S.U.N. is measured with respect to the initial structures
        has_ehull = data.corrected_chgnet_ehull.notna()
        data_is_represented = data.structure.apply(check_represented_composition)
        is_sun = (novel.corrected_chgnet_ehull <= E_hull_threshold) # & (novel.elements.apply(lambda x: len(frozenset(x))) >= 2)
        table.loc[name, "S.U.N. (%)"] = 100 * is_sun.sum() / has_ehull.sum()
        is_represented_novel = novel.structure.apply(check_represented_composition)
        table.loc[name, "R.S.U.N. (%)"] = 100 * (is_sun & is_represented_novel).sum() / (has_ehull & data_is_represented).sum()
        table.loc[name, "~R.S.U.N. (%)"] = 100 * (is_sun & ~is_represented_novel).sum() / (has_ehull & ~data_is_represented).sum()
        table.loc[name, "S.S.U.N. (%)"] = 100 * (is_sun & (novel.group != 1)).sum() / has_ehull.sum()
        has_ehull_top_10 = data[data.spacegroup_number.isin(top_10_groups)].corrected_chgnet_ehull.notna().sum()
        table.loc[name, "Top-10 S.U.N. (%)"] = 100 * is_sun[novel.spacegroup_number.isin(top_10_groups)].sum() / has_ehull_top_10
    #if transformations == ("split", "train"):
        # Train forms the baseline of the hull
      #  test_dataset = all_datasets[("split", "test")].data
      #  test_with_ehull = test_dataset[test_dataset.corrected_chgnet_ehull.notna()]
      #  test_unique = filter_by_unique_structure(test_with_ehull)
      #  test_novel = novelty_filter.get_novel(test_unique)
      #  table.loc[name, "S.U.N. (%)"] = 100 * (test_novel.corrected_chgnet_ehull <= E_hull_threshold).sum() / len(test_with_ehull)
n_elements_dist["MP-20"] = mp_20.elements.apply(lambda e: len(frozenset(e))).value_counts() / len(mp_20)
table

  0%|          | 0/18 [00:00<?, ?it/s]

train unique: 1000 / 1000 = 100.000%
WyckoffTransformer-raw unique: 999 / 1000 = 99.900%
WyFormer-harmonic-raw unique: 999 / 1000 = 99.900%
WyFormer-letters unique: 995 / 1000 = 99.500%
WyFormer-letters-DiffCSP++ unique: 958 / 959 = 99.896%
SymmCD unique: 998 / 1000 = 99.800%
WyFormer unique: 1000 / 1000 = 100.000%
WyFormer-harmonic-DiffCSP++ unique: 983 / 983 = 100.000%
WyForDiffCSP++ unique: 1000 / 1000 = 100.000%
WyLLM-naive-DiffCSP++ unique: 979 / 981 = 99.796%
WyLLM-vanilla-DiffCSP++ unique: 567 / 999 = 56.757%
WyLLM-site-symmetry-DiffCSP++ unique: 998 / 999 = 99.900%
CrystalFormer unique: 988 / 992 = 99.597%
DiffCSP++ unique: 999 / 1000 = 99.900%
DiffCSP unique: 996 / 1000 = 99.600%
FlowMM unique: 994 / 997 = 99.699%
MiAD unique: 997 / 1000 = 99.700%
WyCryst unique: 994 / 994 = 100.000%


Unnamed: 0_level_0,Novelty (%),Represented composition,Structural,Compositional,Recall,Precision,$\rho$,$E$,# Elements,S.U.N. (%),R.S.U.N. (%),~R.S.U.N. (%),Top-10 S.U.N. (%),Novel Uniques Templates (#),Novel Template (%),P1 (%),Space Group,S.S.U.N. (%)
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1
WyckoffTransformer-raw,90.990991,,,,,,,,,,,,,,,2.20022,0.207734,
WyFormer-harmonic-raw,92.892893,,,,,,,,,,,,,,,2.37069,0.174319,
WyFormer-letters,90.351759,,,,,,,,,,,,,,,1.557286,0.209694,
WyFormer-letters-DiffCSP++,90.39666,0.207852,99.538106,82.909931,98.152425,96.982092,0.434744,0.038203,0.112837,31.386861,29.6875,31.812256,33.6714,250.0,38.167053,1.160093,0.218429,30.865485
SymmCD,88.777555,0.190745,95.823928,84.875847,99.548533,94.660623,0.621964,0.102004,0.524995,34.1,33.333333,34.271726,38.045738,161.0,28.857479,2.355713,0.241021,33.2
WyFormer,90.0,0.2,99.555556,80.444444,98.666667,96.705726,0.735854,0.053076,0.096573,38.938939,43.455497,37.871287,44.190871,180.0,22.098214,3.236607,0.222591,38.038038
WyFormer-harmonic-DiffCSP++,92.268566,0.175303,99.779493,82.690187,99.228225,95.931904,0.597028,0.096215,0.040025,35.503561,32.53012,36.107711,41.078838,254.0,38.01105,2.320442,0.173972,34.486267
WyForDiffCSP++,89.5,0.198883,99.664804,80.335196,99.217877,96.783109,0.664005,0.050209,0.098194,36.6,38.219895,36.217553,39.314516,186.0,22.309417,1.457399,0.212014,35.9
WyLLM-naive-DiffCSP++,95.097038,0.185822,99.677766,82.814178,98.496241,95.556047,0.472273,0.059707,0.01768,31.600408,28.49162,32.294264,36.938776,293.0,39.978448,1.293103,0.172013,30.88685
WyLLM-vanilla-DiffCSP++,95.590829,0.765683,99.815498,88.745387,94.464945,59.672784,2.225215,0.233504,0.252764,,,,,87.0,28.096118,2.033272,0.621622,


In [17]:
symmetry_table = table.loc[["WyFormer-letters-DiffCSP++", "WyForDiffCSP++"],
                           ["Novel Uniques Templates (#)", "P1 (%)", "Space Group", "S.U.N. (%)", "S.S.U.N. (%)"]]

In [18]:
symmetry_table.style.format("{:.1f}").highlight_max(axis=0, props="font-weight: bold")

Unnamed: 0_level_0,Novel Uniques Templates (#),P1 (%),Space Group,S.U.N. (%),S.S.U.N. (%)
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
WyFormer-letters-DiffCSP++,250.0,1.2,0.2,31.4,30.9
WyForDiffCSP++,186.0,1.5,0.2,36.6,35.9


In [19]:
print(symmetry_table.to_markdown())

| Method                     |   Novel Uniques Templates (#) |   P1 (%) |   Space Group |   S.U.N. (%) |   S.S.U.N. (%) |
|:---------------------------|------------------------------:|---------:|--------------:|-------------:|---------------:|
| WyFormer-letters-DiffCSP++ |                           250 |  1.16009 |      0.218429 |      31.3869 |        30.8655 |
| WyForDiffCSP++             |                           186 |  1.4574  |      0.212014 |      36.6    |        35.9    |


In [20]:
symmetry_table_symmcd = table.loc[["SymmCD", "WyForDiffCSP++"],
                           ["Novel Uniques Templates (#)", "P1 (%)", "Space Group", "S.U.N. (%)", "S.S.U.N. (%)"]]

In [21]:
symmetry_table_symmcd.style.format("{:.2f}").highlight_max(axis=0, props="font-weight: bold")

Unnamed: 0_level_0,Novel Uniques Templates (#),P1 (%),Space Group,S.U.N. (%),S.S.U.N. (%)
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
SymmCD,161.0,2.36,0.24,34.1,33.2
WyForDiffCSP++,186.0,1.46,0.21,36.6,35.9


In [22]:
print(symmetry_table_symmcd.to_markdown())

| Method         |   Novel Uniques Templates (#) |   P1 (%) |   Space Group |   S.U.N. (%) |   S.S.U.N. (%) |
|:---------------|------------------------------:|---------:|--------------:|-------------:|---------------:|
| SymmCD         |                           161 |  2.35571 |      0.241021 |         34.1 |           33.2 |
| WyForDiffCSP++ |                           186 |  1.4574  |      0.212014 |         36.6 |           35.9 |


In [23]:
table.loc[["WyFormer-letters-DiffCSP++", "WyForDiffCSP++"],
          ["Novelty (%)", "Structural", "Compositional", "Recall", "Precision", r"$\rho$", "$E$", "# Elements"]]

Unnamed: 0_level_0,Novelty (%),Structural,Compositional,Recall,Precision,$\rho$,$E$,# Elements
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
WyFormer-letters-DiffCSP++,90.39666,99.538106,82.909931,98.152425,96.982092,0.434744,0.038203,0.112837
WyForDiffCSP++,89.5,99.664804,80.335196,99.217877,96.783109,0.664005,0.050209,0.098194


In [24]:
table.loc[:, ["S.U.N. (%)", "S.S.U.N. (%)"]]

Unnamed: 0_level_0,S.U.N. (%),S.S.U.N. (%)
Method,Unnamed: 1_level_1,Unnamed: 2_level_1
WyckoffTransformer-raw,,
WyFormer-harmonic-raw,,
WyFormer-letters,,
WyFormer-letters-DiffCSP++,31.386861,30.865485
SymmCD,34.1,33.2
WyFormer,38.938939,38.038038
WyFormer-harmonic-DiffCSP++,35.503561,34.486267
WyForDiffCSP++,36.6,35.9
WyLLM-naive-DiffCSP++,31.600408,30.88685
WyLLM-vanilla-DiffCSP++,,


In [25]:
all_datasets[('split', 'test')].data.apply(evaluation.novelty.record_to_anonymous_fingerprint, axis=1).isin(train_w_template_set).mean()

0.9713685606898077

In [27]:
table.to_csv("tables/paper_summary_table.csv")
table.to_pickle("tables/paper_summary_table.pkl")

In [28]:
max_subset=["Novelty (%)", "Structural", "Compositional", "Recall", "Precision", "S.S.U.N. (%)", "S.U.N. (%)", "Novel Template (%)"]
# -1 to exclude the MP-20 training set
def highlight_max_value(s):
    if s.name not in max_subset:
        return ['' for _ in s]
    is_max = s == s.max()
    #is_max.iloc[-1] = False
    return ['font-weight: bold' if v else '' for v in is_max]

min_subset=[r"$\rho$", "$E$", "# Elements", "# DoF", "Space Group", "Elements", "P1 (%)"]
def highlight_min_value(s):
    if s.name not in min_subset:
        return ['' for _ in s]
    is_min = s == s.min()
    #is_min.iloc[-1] = False
    return ['font-weight: bold' if v else '' for v in is_min]

In [29]:
def prettify(table):
    return table.style.format({
    "Novelty (%)": "{:.2f}",
    "Structural": "{:.2f}",
    "Compositional": "{:.2f}",
    "Recall": "{:.2f}",
    "Precision": "{:.2f}",
    r"$\rho$": "{:.2f}",
    "$E$": "{:.3f}",
    "# Elements": "{:.3f}",
    "# DoF": "{:.3f}",
    "Space Group": "{:.3f}",
    "Elements": "{:.3f}",
    "Novel Template (%)": "{:.2f}",
    "P1 (%)": "{:.2f}",
    "S.U.N. (%)": "{:.1f}",
    "S.S.U.N. (%)": "{:.1f}",
}).apply(highlight_max_value).apply(highlight_min_value)
prettify(table)

Unnamed: 0_level_0,Novelty (%),Represented composition,Structural,Compositional,Recall,Precision,$\rho$,$E$,# Elements,S.U.N. (%),R.S.U.N. (%),~R.S.U.N. (%),Top-10 S.U.N. (%),Novel Uniques Templates (#),Novel Template (%),P1 (%),Space Group,S.S.U.N. (%)
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1
WyckoffTransformer-raw,90.99,,,,,,,,,,,,,,,2.2,0.208,
WyFormer-harmonic-raw,92.89,,,,,,,,,,,,,,,2.37,0.174,
WyFormer-letters,90.35,,,,,,,,,,,,,,,1.56,0.21,
WyFormer-letters-DiffCSP++,90.4,0.207852,99.54,82.91,98.15,96.98,0.43,0.038,0.113,31.4,29.6875,31.812256,33.6714,250.0,38.17,1.16,0.218,30.9
SymmCD,88.78,0.190745,95.82,84.88,99.55,94.66,0.62,0.102,0.525,34.1,33.333333,34.271726,38.045738,161.0,28.86,2.36,0.241,33.2
WyFormer,90.0,0.2,99.56,80.44,98.67,96.71,0.74,0.053,0.097,38.9,43.455497,37.871287,44.190871,180.0,22.1,3.24,0.223,38.0
WyFormer-harmonic-DiffCSP++,92.27,0.175303,99.78,82.69,99.23,95.93,0.6,0.096,0.04,35.5,32.53012,36.107711,41.078838,254.0,38.01,2.32,0.174,34.5
WyForDiffCSP++,89.5,0.198883,99.66,80.34,99.22,96.78,0.66,0.05,0.098,36.6,38.219895,36.217553,39.314516,186.0,22.31,1.46,0.212,35.9
WyLLM-naive-DiffCSP++,95.1,0.185822,99.68,82.81,98.5,95.56,0.47,0.06,0.018,31.6,28.49162,32.294264,36.938776,293.0,39.98,1.29,0.172,30.9
WyLLM-vanilla-DiffCSP++,95.59,0.765683,99.82,88.75,94.46,59.67,2.23,0.234,0.253,,,,,87.0,28.1,2.03,0.622,


In [30]:
LLM_columns_1 = ["Novelty (%)", "Structural", "Compositional", "Recall", "Precision", r"$\rho$", "$E$", "# Elements"]
LLM_columns_2 = ["Novel Uniques Templates (#)", "P1 (%)", "Space Group"]

In [31]:
LLM_rows = ["WyForDiffCSP++", "WyLLM-naive-DiffCSP++", "WyLLM-vanilla-DiffCSP++", "WyLLM-site-symmetry-DiffCSP++"]

In [32]:
pt = prettify(table.loc[LLM_rows, LLM_columns_1])
pt.to_latex("tables/llm_1.tex", siunitx=True, convert_css=True)
pt

Unnamed: 0_level_0,Novelty (%),Structural,Compositional,Recall,Precision,$\rho$,$E$,# Elements
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
WyForDiffCSP++,89.5,99.66,80.34,99.22,96.78,0.66,0.05,0.098
WyLLM-naive-DiffCSP++,95.1,99.68,82.81,98.5,95.56,0.47,0.06,0.018
WyLLM-vanilla-DiffCSP++,95.59,99.82,88.75,94.46,59.67,2.23,0.234,0.253
WyLLM-site-symmetry-DiffCSP++,89.58,99.89,83.89,99.44,96.31,0.29,,0.039


In [34]:
prettify(table.loc[["WyForDiffCSP++", "DiffCSP++", "DiffCSP", "FlowMM", "MiAD"], LLM_columns_1])

Unnamed: 0_level_0,Novelty (%),Structural,Compositional,Recall,Precision,$\rho$,$E$,# Elements
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
WyForDiffCSP++,89.5,99.66,80.34,99.22,96.78,0.66,0.05,0.098
DiffCSP++,89.69,100.0,85.04,99.33,95.79,0.15,0.036,0.504
DiffCSP,90.06,100.0,80.94,99.55,96.2,0.82,0.052,0.294
FlowMM,90.14,96.21,82.48,99.67,96.35,0.31,0.044,0.115
MiAD,86.36,99.07,84.2,99.88,94.33,0.19,0.055,0.066


In [44]:
len(GeneratedDataset.from_cache(("WyckoffLLM-naive", )).data)

9492

In [45]:
len(GeneratedDataset.from_cache(("WyckoffLLM-site-symmetry", )).data)

8955

In [46]:
len(GeneratedDataset.from_cache(("WyckoffLLM-vanilla", )).data)

2866

In [40]:
pt = prettify(table.loc[LLM_rows, LLM_columns_2])
pt.to_latex("tables/llm_2.tex", siunitx=True, convert_css=True)
pt

Unnamed: 0_level_0,Novel Uniques Templates (#),P1 (%),Space Group
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
WyForDiffCSP++,186,1.46,0.212
WyLLM-naive-DiffCSP++,237,1.38,0.167
WyLLM-vanilla-DiffCSP++,87,2.03,0.621
WyLLM-site-symmetry-DiffCSP++,191,2.24,0.158


In [37]:
prettify(table.iloc[:, :9]).to_latex("tables/summary_similarity_raw.tex", siunitx=True, convert_css=True)
prettify(table.iloc[:, 9:]).to_latex("tables/summary_symmetry_raw.tex", siunitx=True, convert_css=True)

In [1]:
raw_test_evaluator = evaluation.statistical_evaluator.StatisticalEvaluator(all_datasets[('split', 'test')].data)
cdvae_table = pd.DataFrame(index=pd.Index(datasets.keys(), tupleize_cols=False),
    columns=[
        "Structural", "Compositional",
        "Recall", "Precision",
        r"$\rho$", "$E$", "# Elements"])
sample_size = 992
for name, transformations in tqdm(datasets.items()):
    dataset = all_datasets[transformations]
    if "structure" in dataset.data.columns:
        cdvae_table.loc[name, "Compositional"] = 100*dataset.data.smact_validity.mean()
        cdvae_table.loc[name, "Structural"] = 100*dataset.data.structural_validity.mean()
        valid = dataset.data[dataset.data.naive_validity]
        cov_metrics = raw_test_evaluator.get_coverage(valid.cdvae_crystal)
        cdvae_table.loc[name, "Recall"] = 100*cov_metrics["cov_recall"]
        cdvae_table.loc[name, "Precision"] = 100*cov_metrics["cov_precision"]
        cdvae_table.loc[name, r"$\rho$"] = raw_test_evaluator.get_density_emd(valid)
        cdvae_table.loc[name, "$E$"] = raw_test_evaluator.get_cdvae_e_emd(valid)
        cdvae_table.loc[name, "# Elements"] = raw_test_evaluator.get_num_elements_emd(valid)
cdvae_table.to_csv("tables/cdvae_metrics_table.csv")
prettify(cdvae_table).to_latex("tables/cdvae_metrics_table.tex", siunitx=True, convert_css=True)
prettify(cdvae_table)

NameError: name 'evaluation' is not defined

In [40]:
raw_test_evaluator = evaluation.statistical_evaluator.StatisticalEvaluator(all_datasets[('split', 'test')].data)
cdvae_table = pd.DataFrame(index=pd.Index(raw_datasets.keys(), tupleize_cols=False),
    columns=[
        "Structural", "Compositional",
        "Recall", "Precision",
        r"$\rho$", "$E$", "# Elements"])
sample_size = 992
for name, transformations in tqdm(list(raw_datasets.items())):
    dataset = all_datasets[transformations]
    if "structure" in dataset.data.columns:
        cdvae_table.loc[name, "Compositional"] = 100*dataset.data.smact_validity.mean()
        cdvae_table.loc[name, "Structural"] = 100*dataset.data.structural_validity.mean()
        valid = dataset.data[dataset.data.naive_validity]
        cov_metrics = raw_test_evaluator.get_coverage(valid.cdvae_crystal)
        cdvae_table.loc[name, "Recall"] = 100*cov_metrics["cov_recall"]
        cdvae_table.loc[name, "Precision"] = 100*cov_metrics["cov_precision"]
        cdvae_table.loc[name, r"$\rho$"] = raw_test_evaluator.get_density_emd(valid)
        cdvae_table.loc[name, "$E$"] = raw_test_evaluator.get_cdvae_e_emd(valid)
        cdvae_table.loc[name, "# Elements"] = raw_test_evaluator.get_num_elements_emd(valid)
cdvae_table.to_csv("tables/cdvae_metrics_no_relax_table.csv")
prettify(cdvae_table).to_latex("tables/cdvae_metrics_no_relax_table.tex", siunitx=True, convert_css=True)
prettify(cdvae_table)

  0%|          | 0/8 [00:00<?, ?it/s]

Unnamed: 0,Structural,Compositional,Recall,Precision,$\rho$,$E$,# Elements
WyFormer,99.6,81.4,98.77,95.94,0.39,0.078,0.081
WyFormerDiffCSP++,99.8,81.4,99.51,95.81,0.36,0.083,0.079
WyLLM-DiffCSP++,99.8,83.33,98.91,94.09,0.19,0.09,0.029
CrystalFormer,93.39,84.98,99.62,94.56,0.19,0.208,0.128
DiffCSP++,99.94,85.13,99.67,99.54,0.31,0.069,0.399
DiffCSP,100.0,83.2,99.82,99.51,0.35,0.095,0.347
FlowMM,96.87,83.11,99.73,99.39,0.12,0.073,0.094
WyCryst,99.9,82.09,99.63,96.16,0.44,0.33,0.322
