In [1]:
from typing import Tuple
import pandas as pd
from scipy.stats import pearsonr
from tqdm.notebook import tqdm
import sys
sys.path.append('..')
from evaluation.generated_dataset import GeneratedDataset, load_all_from_config
from evaluation.novelty import NoveltyFilter, filter_by_unique_structure

In a future release, impute_nan will be set to True by default.
                    This means that features that are missing or are NaNs for elements
                    from the data source will be replaced by the average of that value
                    over the available elements.
                    This avoids NaNs after featurization that are often replaced by
                    dataset-dependent averages.


In [3]:
dft_datasets = {
    "WyFormerDirect": ("WyckoffTransformer", "DFT"),
    "WyFormerCrySPR": ("WyckoffTransformer", "CrySPR", "CHGNet_fix", "DFT"),
    "WyFormerDiffCSP++": ("WyckoffTransformer", "DiffCSP++", "DFT"),
    "WyLLM-DiffCSP++": ("WyckoffLLM-naive", "DiffCSP++", "DFT"),
    "WyFormer-letters-DiffCSP++": ("WyckoffTransformer-letters", "DiffCSP++", "DFT"),
    "SymmCD": ("SymmCD", "DFT"),
    "DiffCSP": ("DiffCSP", "DFT"),
    "CrystalFormer": ("CrystalFormer", "DFT"),
    "DiffCSP++": ("DiffCSP++", "DFT"),
    "FlowMM": ("FlowMM", "DFT")
}

source_datasets = {name: t[:-1] for name, t in dft_datasets.items()}

In [4]:
chgnet_datasets = {
    "WyFormerDirect": ("WyckoffTransformer", "CrySPR", "CHGNet_fix_release"),
    "WyFormerCrySPR": ("WyckoffTransformer", "CrySPR", "CHGNet_fix_release"),
    "WyFormerDiffCSP++": ("WyckoffTransformer", "DiffCSP++", "CHGNet_fix"),
    "WyLLM-DiffCSP++": ("WyckoffLLM-naive", "DiffCSP++", "CHGNet_fix"),
    "WyFormer-letters-DiffCSP++": ("WyckoffTransformer-letters", "DiffCSP++", "CHGNet_fix"),
    "SymmCD": ("SymmCD", "CHGNet_fix"),
    "DiffCSP": ("DiffCSP", "CHGNet_fix"),
    "CrystalFormer": ("CrystalFormer", "CHGNet_fix_release"),
    "DiffCSP++": ("DiffCSP++", "CHGNet_fix_release"),
    "FlowMM": ("FlowMM", "CHGNet_fix")
}

In [5]:
chgnet_data = load_all_from_config(datasets=list(chgnet_datasets.values()) + [('WyckoffTransformer', 'CrySPR', 'CHGNet_fix')])

In [6]:
all_datasets = load_all_from_config(
    datasets=list(dft_datasets.values()) + list(source_datasets.values()) + \
        [("split", "train"), ("split", "val"), ("split", "test")],
    dataset_name="mp_20")

In [7]:
wycryst_transformations = ('WyCryst', 'CrySPR', 'CHGNet_fix')
source_datasets["WyCryst"] = wycryst_transformations
chgnet_datasets["WyCryst"] = wycryst_transformations
chgnet_data[wycryst_transformations] = GeneratedDataset.from_cache(wycryst_transformations, "mp_20_biternary")
dft_datasets["WyCryst"] = tuple(list(wycryst_transformations) + ["DFT"])
all_datasets[dft_datasets["WyCryst"]] = GeneratedDataset.from_cache(dft_datasets["WyCryst"], "mp_20_biternary")

In [8]:
excluded_categories = frozenset(["radioactive", "rare_earth_metal", "noble_gas"])
from pymatgen.core import Structure
def check_composition(structure: Structure) -> bool:
    for category in excluded_categories:
        if structure.composition.contains_element_type(category):
            return False
    return True

In [9]:
novelty_reference = pd.concat([
    all_datasets[('split', 'train')].data,
    all_datasets[('split', 'val')].data], axis=0, verify_integrity=True)
novelty_filter = NoveltyFilter(novelty_reference)

In [10]:
import evaluation.statistical_evaluator
test_evaluator = evaluation.statistical_evaluator.StatisticalEvaluator(all_datasets[('split', 'test')].data)

In [11]:
import evaluation.novelty
train_w_template_set = frozenset(novelty_reference.apply(evaluation.novelty.record_to_anonymous_fingerprint, axis=1))

In [12]:
def is_sg_preserved(relaxed_sg, transformations: Tuple[str]) -> pd.Series:
    source_sg = all_datasets[transformations[:-1]].data.spacegroup_number
    return relaxed_sg == source_sg.reindex_like(relaxed_sg)

In [13]:
mp_20 = pd.concat([
    all_datasets[('split', 'train')].data,
    all_datasets[('split', 'test')].data,
    all_datasets[('split', 'val')].data], axis=0, verify_integrity=True)
(mp_20.spacegroup_number == 1).mean()
mp_20.smact_validity.mean()

0.9057020937893829

In [14]:
from collections import Counter
from operator import itemgetter
from itertools import chain
element_counts = Counter(chain(*mp_20.elements))

In [15]:
represented_elements=frozenset(map(itemgetter(0), element_counts.most_common(30)))

In [16]:
def check_represented_composition(structure: Structure) -> bool:
    for element in structure.composition:
        if element not in represented_elements:
            return False
    return True

In [17]:
top_10_groups = frozenset(mp_20.spacegroup_number.value_counts().iloc[:10].index)
n_elements_dist = {}

Validity
1. Vanilla; Valid records: 2866 / 9648 = 29.71%
2. Naive; Valid records: 9492 / 9804 = 96.82%
3. Site Symmetry; Valid records: 8955 / 9709 = 92.23%

In [18]:
table = pd.DataFrame(
    index=dft_datasets.keys(), columns=[
        "DFT dataset size",
        "Source Novelty (%)",
        "In-DFT Novelty (%)",
        "S.U.N. (%)",
        "P1 in source (%)",
        "S.S.U.N. (%)"])
table.index.name = "Method"
E_hull_threshold = 0.08
for name, transformations in tqdm(dft_datasets.items()):
    dataset = all_datasets[transformations]
    table.loc[name, "DFT dataset size"] = len(dataset.data)
    try:
        source_dataset = all_datasets[transformations[:-1]]
    except KeyError:
        source_dataset = chgnet_data[transformations[:-1]]
    chgnet_dataset = chgnet_data[chgnet_datasets[name]]

    unique = filter_by_unique_structure(dataset.data)
    novel = novelty_filter.get_novel(unique)
    table.loc[name, "In-DFT Novelty (%)"] = 100 * len(novel) / len(unique)
    source_novel = novelty_filter.get_novel(source_dataset.data)
    source_novelty = 100 * len(source_novel) / len(source_dataset.data)
    table.loc[name, "Source Novelty (%)"] = len(novel) / len(unique) * source_novelty
    table.loc[name, "P1 in source (%)"] = 100 * (source_novel.group == 1).mean()
    try:
        table.loc[name, "SG preserved (%)"] = 100 * is_sg_preserved(novel.spacegroup_number, transformations).mean()
    except KeyError:
        pass
    # source_novel_symmetric = (source_novel.group != 1).sum() / len(source_dataset.data)
    # table["Source Novel !P1 (%)"] = 100 * source_novel_symmetric
    # DFT failure == unreal structure
    dft_structures = 105
    has_ehull = dataset.data.corrected_e_hull.notna()
    is_sun = (novel.corrected_e_hull <= E_hull_threshold) # & (novel.elements.apply(lambda x: len(frozenset(x))) >= 2)
    table.loc[name, "S.U.N. (%)"] = source_novelty * is_sun.sum() / dft_structures
    table.loc[name, "total_sun"] = is_sun.sum().astype(int)
    table.loc[name, "S.S.U.N. (%)"] = source_novelty * (is_sun & (novel.group != 1)).sum() / dft_structures
    table.loc[name, "total_ssun"] = (is_sun & (novel.group != 1)).sum().astype(int)
    table.loc[name, "P1 in stable (%)"] = 100 * (novel[is_sun].group == 1).mean()

    chgnet_unique = filter_by_unique_structure(chgnet_dataset.data)
    chgnet_novel = novelty_filter.get_novel(chgnet_unique)
    chgnet_is_sun = (chgnet_novel.corrected_chgnet_ehull < E_hull_threshold)
    #table.loc[name, "CHGNet dataset size"] = chgnet_dataset.data.corrected_chgnet_ehull.notna().sum()
    table.loc[name, "S.U.N. (CHGNet) (%)"] =  100 * chgnet_is_sun.sum() / chgnet_dataset.data.corrected_chgnet_ehull.notna().sum()
    table.loc[name, "S.S.U.N. (CHGNet) (%)"] = 100 * (chgnet_is_sun & (chgnet_novel.group != 1)).sum() / chgnet_dataset.data.corrected_chgnet_ehull.notna().sum()
    
    chgnet_dft_available = chgnet_dataset.data.reindex(dataset.data.index[has_ehull])
    table.loc[name, "r DFT CHGNet"] = \
        pearsonr((chgnet_dft_available.corrected_chgnet_ehull < E_hull_threshold).astype(float),
                 (dataset.data.corrected_e_hull < E_hull_threshold).astype(float)).correlation
table

  0%|          | 0/11 [00:00<?, ?it/s]

Unnamed: 0_level_0,DFT dataset size,Source Novelty (%),In-DFT Novelty (%),S.U.N. (%),P1 in source (%),S.S.U.N. (%),SG preserved (%),total_sun,total_ssun,P1 in stable (%),S.U.N. (CHGNet) (%),S.S.U.N. (CHGNet) (%),r DFT CHGNet
Method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
WyFormerDirect,94,90.09,100.0,4.29,1.964702,4.29,86.170213,5.0,5.0,0.0,39.239239,38.238238,0.269486
WyFormerCrySPR,97,90.0,100.0,6.857143,1.555556,6.857143,96.907216,8.0,8.0,0.0,39.239239,38.238238,0.335979
WyFormerDiffCSP++,99,88.59596,98.989899,12.785714,1.564246,12.785714,90.816327,15.0,15.0,0.0,36.7,36.0,0.435583
WyLLM-DiffCSP++,98,94.578313,100.0,9.007458,1.380042,9.007458,94.897959,10.0,10.0,0.0,31.600408,30.88685,0.546646
WyFormer-letters-DiffCSP++,96,90.562249,100.0,6.899981,1.108647,6.899981,91.666667,8.0,8.0,0.0,31.491137,30.96976,0.320983
SymmCD,88,90.649077,100.0,12.086544,2.177203,12.086544,92.045455,14.0,14.0,0.0,34.551148,34.070981,0.102101
DiffCSP,96,88.874479,98.958333,19.672667,31.566641,13.685333,55.789474,23.0,16.0,30.434783,57.4,40.6,0.326947
CrystalFormer,85,77.06059,98.823529,13.367653,1.797176,13.367653,91.666667,18.0,18.0,0.0,37.600806,37.399194,0.229633
DiffCSP++,95,88.95,100.0,7.624286,1.843732,7.624286,94.736842,9.0,9.0,0.0,41.4,40.8,0.322832
FlowMM,97,87.837258,93.814433,16.942341,40.690097,16.050639,43.956044,19.0,18.0,5.263158,50.952859,30.692076,-0.208221


In [121]:
def prettify(table):
    return table.style.format({
    "S.U.N. (%)": "{:.1f}",
    "S.S.U.N. (%)": "{:.1f}",
    #"S.U.N. (CHGNet) (%)": "{:.1f}",
    #"S.S.U.N. (CHGNet) (%)": "{:.1f}",
    #"r DFT CHGNet": "{:.2f}",
}).highlight_max(props="font-weight: bold", axis=0, subset=["S.U.N. (%)", "S.S.U.N. (%)"])

In [122]:
selected_table = table.loc[:, ["S.U.N. (%)", "S.S.U.N. (%)"]]

In [123]:
pretty_table = prettify(selected_table)
pretty_table.to_latex("tables/dft.tex", siunitx=True, convert_css=True)
pretty_table

Unnamed: 0_level_0,S.U.N. (%),S.S.U.N. (%)
Method,Unnamed: 1_level_1,Unnamed: 2_level_1
WyFormerDirect,4.3,4.3
WyFormerCrySPR,6.9,6.9
WyFormerDiffCSP++,12.8,12.8
DiffCSP,19.7,13.7
CrystalFormer,13.4,13.4
DiffCSP++,7.6,7.6
FlowMM,16.9,16.1
WyCryst,5.5,5.5


In [25]:
all_datasets[('split', 'test')].data.apply(evaluation.novelty.record_to_anonymous_fingerprint, axis=1).isin(train_w_template_set).mean()

0.9713685606898077

In [26]:
all_datasets[('split', 'test')].data.apply(evaluation.novelty.record_to_strict_AFLOW_fingerprint, axis=1).isin(train_strict_AFLOW_set).mean()

0.9328985186822906

In [27]:
CF_CG = GeneratedDataset.from_cache(("CrystalFormer", "CHGNet_fix_release"))

In [99]:
from scipy.stats import ttest_ind
import numpy as np
def get_observation(name, column="total_ssun"):
    all_observations = np.zeros(dft_structures)
    all_observations[:int(table.at[name, column])] = table.loc[name, "Source Novelty (%)"]/100
    return all_observations

In [101]:
for second in table.index:
    print(second, ttest_ind(get_observation("FlowMM"), get_observation(second)))

WyFormerDirect TtestResult(statistic=2.8699643276771534, pvalue=0.004529164843545496, df=208.0)
WyFormerCrySPR TtestResult(statistic=2.0489237904086535, pvalue=0.04172504759818785, df=208.0)
WyFormerDiffCSP++ TtestResult(statistic=0.5399251990582908, pvalue=0.5898261986038633, df=208.0)
DiffCSP TtestResult(statistic=0.3358706725126662, pvalue=0.7373069464508927, df=208.0)
CrystalFormer TtestResult(statistic=0.4278113819160564, pvalue=0.669231052265232, df=208.0)
DiffCSP++ TtestResult(statistic=1.8300352833914524, pvalue=0.0686759122341314, df=208.0)
FlowMM TtestResult(statistic=0.0, pvalue=1.0, df=208.0)


In [103]:
for second in table.index:
    print(second, ttest_ind(get_observation("DiffCSP", column="total_sun"), get_observation(second, column="total_sun")))

WyFormerDirect TtestResult(statistic=3.7329421055392644, pvalue=0.0002442600601858129, df=208.0)
WyFormerCrySPR TtestResult(statistic=2.933952952207651, pvalue=0.0037223172815396767, df=208.0)
WyFormerDiffCSP++ TtestResult(statistic=1.444489793625256, pvalue=0.15010593136692968, df=208.0)
DiffCSP TtestResult(statistic=0.0, pvalue=1.0, df=208.0)
CrystalFormer TtestResult(statistic=1.3621411208732088, pvalue=0.1746266271124153, df=208.0)
DiffCSP++ TtestResult(statistic=2.7203519120538746, pvalue=0.007073135428130352, df=208.0)
FlowMM TtestResult(statistic=0.7296066705938044, pvalue=0.4664514752444603, df=208.0)
