In [1]:
import pandas as pd
import pickle
from pymatgen.core import Element
import gzip

In [2]:
with gzip.open("cache/mp_20/data.pkl.gz", "rb") as f:
    data = pickle.load(f)

In [3]:
import json
with gzip.open("generated/AI4AM_mp_20.json.gz") as f:
    generated = json.load(f)

In [4]:
with open("cache/wychoffs_enumerated_by_ss.pkl.gz", "rb") as f:
    wychoffs_index = pickle.load(f)[2]

In [5]:
def generated_to_fingerprint(wy_dict):
    elements = []
    site_symmetries = []
    for specie, sites in zip(wy_dict["species"], wy_dict["sites"]):
        element = Element(specie)
        for site in sites:
            elements.append(element)
            site_symmetries.append(wychoffs_index[wy_dict["group"]][site[-1:]])
    return (
        wy_dict["group"],
        frozenset(map(tuple, zip(elements, site_symmetries))),
    )

In [6]:
generated_fp = list(map(generated_to_fingerprint, generated))

In [7]:
generated_fp[0]

(71,
 frozenset({(Element Dy, 'mmm'),
            (Element In, 'm2m'),
            (Element In, 'mm2'),
            (Element Cu, '-1'),
            (Element Cu, 'mm2')}))

In [8]:
def to_fingerprint(row, include_enum=False):
    if include_enum:
        return (
            row["spacegroup_number"],
            frozenset(map(tuple, zip(row["elements"], row["site_symmetries"]))),
            frozenset(map(tuple, row["sites_enumeration_augmented"]))
            )
    return (
        row["spacegroup_number"],
        frozenset(map(tuple, zip(row["elements"], row["site_symmetries"]))))

In [9]:
combined_train = pd.concat([data["train"], data["val"]], axis=0)

In [15]:
train_val_fp = frozenset(combined_train.map(to_fingerprint, axis=1))
test_fp = data['test'].apply(to_fingerprint, axis=1)

TypeError: to_fingerprint() got an unexpected keyword argument 'axis'

In [16]:
test_fp.map(lambda x: x in train_val_fp).value_counts()[False] / len(test_fp)

0.9451691355295158

In [256]:
sum((fp not in train_val_fp for fp in generated_fp[:len(test_fp)])) / len(test_fp)

0.8676763210258678

In [264]:
len(frozenset(test_fp.tolist()))/len(test_fp)

0.9834180853415875

In [263]:
len(frozenset(generated_fp[:len(test_fp)]))/len(test_fp)

0.9771169577713906