In [1]:
from evaluation.generated_dataset import GeneratedDataset, StructureStorage, WyckoffStorage
# Reference dataset: the MP-20 test split.
# Structures are read from the CDVAE-style CSV; Wyckoff data is read from the
# cache file under the "test" key.
mp_20_test = GeneratedDataset()
mp_20_test.load_structures("cdvae/data/mp_20/test.csv", storage_type=StructureStorage.CDVAE_csv_cif)
mp_20_test.load_wyckoffs("cache/mp_20/data.pkl.gz", storage_type=WyckoffStorage.WTCache, cache_key="test")





In [2]:
# Generated dataset: CrystalFormer samples for MP-20, loaded in the CrystalFormer storage format.
crystalformer_raw = GeneratedDataset()
crystalformer_raw.load_structures("generated/CrystalFormer_mp_20.csv.gz", StructureStorage.CrystalFormer)

In [3]:
# Wyckoff data for the generated structures is computed here rather than loaded from a cache.
# NOTE(review): n_jobs=20 presumably sets the number of parallel workers — confirm and tune for the machine.
crystalformer_raw.compute_wyckoffs(n_jobs=20)

In [4]:
# Convert the computed Wyckoff data to a pyxtal-style representation.
# NOTE(review): presumably this is what adds the group/sites/species fields shown in the
# record inspected further down — confirm against GeneratedDataset.
crystalformer_raw.convert_wyckoffs_to_pyxtal()


In [5]:
# Build the CDVAE crystal objects for the generated structures.
# NOTE(review): n_jobs=20 presumably sets the number of parallel workers — confirm.
crystalformer_raw.compute_cdvae_crystals(n_jobs=20)



In [6]:
# Inspect the first generated record to see which fields are available for fingerprinting.
crystalformer_raw.data.iloc[0]

structure                      [[5.03535423 5.03535423 5.03535423] Cu, [5.035...
site_symmetries                                      [-43m, -43m, .3m, .3m, .3m]
elements                                                     [Cu, Cu, S, Se, Cr]
multiplicity                                                  [4, 4, 16, 16, 16]
wyckoff_letters                                                  [b, c, e, e, e]
sites_enumeration                                                [1, 2, 0, 0, 0]
dof                                                              [0, 0, 1, 1, 1]
spacegroup_number                                                            216
sites_enumeration_augmented    ((2, 0, 0, 0, 0), (1, 3, 0, 0, 0), (0, 3, 0, 0...
composition                                       {Cu: 8, S: 16, Se: 16, Cr: 16}
group                                                                        216
sites                                            [[4b, 4c], [16e], [16e], [16e]]
species                     

In [7]:
from typing import Dict
from pandas import Series
def record_to_augmented_fingerprint(row: Dict|Series) -> tuple:
    """
    Computes a fingerprint includeing possible Wyckoff position enumeration.
    Args:
        row contains the Wyckoff information:
        - spacegroup_number
        - elements
        - site_symmetries
        - sites_enumeration_augmented
    Returns:
        A tuple of the spacegroup number and a frozenset of tuples of the elements,
        site symmetries and the Wyckoff position enumeration.
    """
    spacegroup_number = row["spacegroup_number"]
    transposed_augmentations = zip(*row["sites_enumeration_augmented"])
    return (
        row["spacegroup_number"],
        frozenset(
            map(
                tuple,
                zip(row["elements"], row["site_symmetries"], *transposed_augmentations)
            )
        )
    )

In [18]:
# Fingerprint every record (row-wise apply) of the reference and the generated datasets.
mp_20_test_fingerprints, crystalformer_raw_fingerprints = (
    dataset.data.apply(record_to_augmented_fingerprint, axis=1)
    for dataset in (mp_20_test, crystalformer_raw)
)

In [20]:
# Number of distinct fingerprints in the MP-20 test set divided by the number of records.
len(frozenset(mp_20_test_fingerprints))/len(mp_20_test_fingerprints)

0.8841476895865575

In [13]:
# Number of generated records that received a fingerprint.
len(crystalformer_raw_fingerprints)

999

In [24]:
# Share of distinct generated fingerprints that do not occur in the MP-20 test set.
generated_unique = frozenset(crystalformer_raw_fingerprints)
reference_unique = frozenset(mp_20_test_fingerprints)
len(generated_unique - reference_unique) / len(generated_unique)

0.9579055441478439