In [1]:
import pickle
import gzip
import json
from itertools import chain
from pathlib import Path
from pandas import Series

In [2]:
# Load the cached MP-20 dataset (a mapping indexed by split name, e.g. 'train'/'val').
# NOTE(review): pickle.load executes arbitrary code from the file — only use on a
# locally generated, trusted cache.
mp_20_cache_path = Path("cache", "mp_20", "data.pkl.gz")
with gzip.open(mp_20_cache_path, "rb") as cache_file:
    mp_20 = pickle.load(cache_file)

In [3]:
def to_json_naive(row: Series):
    """Convert one structure row into a JSON-friendly dict of Wyckoff sites.

    Args:
        row: A record indexable by "elements", "wyckoff_letters" and
            "spacegroup_number" (typically a DataFrame row).

    Returns:
        A dict with the space-group number as a plain ``int`` and a list of
        ``(str(element), wyckoff_letter)`` pairs ordered by Wyckoff letter,
        ties broken by element.
    """
    # Pair each element with its Wyckoff letter; order by letter first, then element.
    ordered_pairs = sorted(
        zip(row["elements"], row["wyckoff_letters"]),
        key=lambda pair: (pair[1], pair[0]),
    )
    site_list = [(str(atom), wyckoff) for atom, wyckoff in ordered_pairs]
    return {
        "spacegroup_number": int(row["spacegroup_number"]),
        "wyckoff_sites": site_list,
    }

In [4]:
def to_json_enumerations(row: Series):
    """Yield one JSON-friendly dict per entry of ``row["sites_enumeration_augmented"]``.

    Each entry of ``row["sites_enumeration_augmented"]`` is zipped position-wise
    with ``row["elements"]`` and ``row["site_symmetries"]``, so it is expected to
    hold one enumeration value per site (presumably an augmentation variant of
    the structure — TODO confirm against the data pipeline).

    Args:
        row: A record indexable by "elements", "site_symmetries",
            "sites_enumeration_augmented" and "spacegroup_number".

    Yields:
        Dicts with the space-group number as a plain ``int`` and a list of
        ``(str(element), site_symmetry, enumeration)`` triples ordered by site
        symmetry, then enumeration, then element.
    """
    for site_enumerations in row["sites_enumeration_augmented"]:
        # Distinct loop names here: the original reused `enumeration` for both the
        # outer and inner loop, which only worked because zip captured it first.
        # Note: zip silently truncates if the per-site sequences differ in length.
        sites = sorted(
            zip(row["elements"], row["site_symmetries"], site_enumerations),
            # Sort by site symmetry, then enumeration, then element.
            key=lambda site: (site[1], site[2], site[0]),
        )
        yield {
            "spacegroup_number": int(row["spacegroup_number"]),
            "wyckoff_sites": [
                (str(element), site_symmetry, enumeration)
                for element, site_symmetry, enumeration in sites
            ],
        }

In [5]:
# Naive representation for every structure in the train and val splits (in that order).
naive_dicts = [
    record
    for split in ("train", "val")
    for record in mp_20[split].apply(to_json_naive, axis=1).to_list()
]
naive_path = Path("generated", "Dropbox", "mp_20", "wyckoff_naive.json.gz")
# gzip.open does not create missing directories; make sure the output dir exists.
naive_path.parent.mkdir(parents=True, exist_ok=True)
with gzip.open(naive_path, "wt") as f:
    json.dump(naive_dicts, f)



In [6]:
# Flatten all augmented records: train rows first, then val rows; each row
# contributes its records in enumeration order. Using a nested comprehension
# avoids star-unpacking a generator into chain(*...), which built a per-row
# list plus one huge argument tuple before flattening.
augmented_dicts = [
    record
    for split in ("train", "val")
    for _, row in mp_20[split].iterrows()  # iterrows yields each row as a Series
    for record in to_json_enumerations(row)
]

In [7]:
augmented_path = Path("generated", "Dropbox", "mp_20", "wyckoff_augmented.json.gz")
# gzip.open does not create missing directories; make sure the output dir exists.
augmented_path.parent.mkdir(parents=True, exist_ok=True)
with gzip.open(augmented_path, "wt") as f:
    json.dump(augmented_dicts, f)