In [3]:
import pandas as pd
from tqdm.notebook import tqdm
import sys
sys.path.append('..')
from evaluation.generated_dataset import GeneratedDataset, load_all_from_config
from evaluation.novelty import NoveltyFilter, filter_by_unique_structure

In a future release, impute_nan will be set to True by default.
                    This means that features that are missing or are NaNs for elements
                    from the data source will be replaced by the average of that value
                    over the available elements.
                    This avoids NaNs after featurization that are often replaced by
                    dataset-dependent averages.


In [1]:
# Label -> tuple of stage names. Every tuple listed here ends with "DFT";
# the tuples themselves are what gets passed to load_all_from_config.
dft_datasets = {
    "WyFormerDirect": ("WyckoffTransformer", "DFT"),
    "WyFormerCrySPR": ("WyckoffTransformer", "CrySPR", "CHGNet_fix", "DFT"),
    "WyFormerDiffCSP++": ("WyckoffTransformer", "DiffCSP++", "DFT"),
    "DiffCSP": ("DiffCSP", "DFT"),
    "DiffCSP++": ("DiffCSP++", "DFT"),
}

# Same labels, with the last stage of each tuple removed.
source_datasets = {
    label: tuple(upstream) for label, (*upstream, _) in dft_datasets.items()
}

In [4]:
# Request every tuple from dft_datasets and source_datasets, followed by the
# three ("split", ...) entries, for dataset_name "mp_20".
all_datasets = load_all_from_config(
    datasets=[
        *dft_datasets.values(),
        *source_datasets.values(),
        ("split", "train"),
        ("split", "val"),
        ("split", "test"),
    ],
    dataset_name="mp_20",
)

In [8]:
from typing import Tuple
def is_sg_preserved(relaxed_sg: pd.Series, transformations: Tuple[str, ...]) -> pd.Series:
    """Flag entries whose space group number equals the source dataset's.

    The source dataset is read from the notebook-level ``all_datasets``
    under the key ``(transformations[0],)``.

    Args:
        relaxed_sg: Space group numbers to check. Its index decides which
            source entries are compared and in what order.
        transformations: Stage names of the dataset being checked; may have
            any length >= 1, but only the first element is used.

    Returns:
        Boolean Series indexed like ``relaxed_sg``. Labels that are missing
        from the source dataset come out as False, because ``reindex_like``
        fills them with NaN and NaN never compares equal.
    """
    source_sg = all_datasets[(transformations[0],)].data.spacegroup_number
    # Align the source onto relaxed_sg's index so that == compares label by label.
    return relaxed_sg == source_sg.reindex_like(relaxed_sg)

def is_wyckoff_preserved(transformations: Tuple[str, ...]) -> pd.Series:
    """Flag entries whose fingerprint is unchanged relative to the first stage.

    Both datasets are read from the notebook-level ``all_datasets``: the
    dataset under the full ``transformations`` key, and the dataset under
    ``(transformations[0],)``.

    Args:
        transformations: Stage names identifying the dataset to check; may
            have any length >= 1.

    Returns:
        Boolean Series indexed like the ``fingerprint`` column of
        ``all_datasets[transformations].data``. Labels that are missing from
        the first-stage dataset come out as False, because ``reindex_like``
        fills them with NaN and NaN never compares equal.
    """
    dft_fingerprint = all_datasets[transformations].data.fingerprint
    initial_fingerprint = all_datasets[(transformations[0],)].data.fingerprint
    # Align the first-stage fingerprints onto the checked dataset's index
    # so that == compares label by label.
    return dft_fingerprint == initial_fingerprint.reindex_like(dft_fingerprint)

In [12]:
# Count the entries of this dataset whose fingerprint equals the fingerprint
# of the same-index entry in the ("WyckoffTransformer",) dataset.
is_wyckoff_preserved(("WyckoffTransformer", "CrySPR", "CHGNet_fix", "DFT")).sum()

80

In [17]:
# Dataset key inspected by the cells that follow.
transformations = ("WyckoffTransformer", "DiffCSP++", "DFT")
# Boolean mask: True where the fingerprint matches the ("WyckoffTransformer",) one.
preserved = is_wyckoff_preserved(transformations)

In [21]:
# Upper bound on corrected_e_hull for an entry to count as stable.
# NOTE(review): units are not visible here (presumably eV/atom) — confirm.
E_HULL_STABLE_MAX = 0.08
# Boolean mask over the selected dataset: True where corrected_e_hull is at or below the bound.
stable = all_datasets[transformations].data.corrected_e_hull <= E_HULL_STABLE_MAX

In [24]:
# Preserved-fingerprint counts split by the `stable` mask: (stable, not stable).
preserved[stable].sum(), preserved[~stable].sum()

(14, 63)

In [25]:
# Key of the two-stage dataset, listed as "WyFormerDirect" in dft_datasets.
transformations_naive = ("WyckoffTransformer", "DFT")

In [34]:
# Fingerprint columns of the two datasets being compared.
fingerprint_naive, fingerprint_transformed = (
    all_datasets[key].data.fingerprint
    for key in (transformations_naive, transformations)
)

# Keep only the labels present in both, and put both series on that shared index.
common_index = fingerprint_naive.index.intersection(fingerprint_transformed.index)
aligned_fingerprint_naive, aligned_fingerprint_transformed = (
    fingerprints.reindex(common_index)
    for fingerprints in (fingerprint_naive, fingerprint_transformed)
)

# Label-wise equality of the two aligned fingerprint series.
comparison = aligned_fingerprint_naive == aligned_fingerprint_transformed
comparison

true_index
2       True
3      False
4       True
5       True
6       True
       ...  
113     True
114     True
115     True
116     True
117     True
Name: fingerprint, Length: 88, dtype: bool

In [35]:
# Number of shared entries whose two fingerprints are equal.
comparison.sum()

77