In [1]:
from itertools import repeat
from ast import literal_eval
from pathlib import Path
import logging
import json
import gzip
import pandas as pd

In [2]:
# Named logger for this notebook's own messages (named after the running module).
logger = logging.getLogger(__name__)

In [3]:
# Load the WyCryst generations. The listed columns hold Python literals
# (lists / dicts) serialised as text, so each is parsed with literal_eval.
wycryst_data_raw = pd.read_csv(
    Path("generated", "WyCryst_mp_20_biternary.csv"),
    index_col=0,
    converters=dict.fromkeys(
        ("reconstructed_ratio1", "reconstructed_wyckoff", "str_wyckoff", "ter_sys"),
        literal_eval,
    ),
)

In [4]:
# Preview the first rows to sanity-check the parsed columns.
wycryst_data_raw.head()

Unnamed: 0,n_element,reconstructed_formula,reconstructed_ratio,reconstructed_ratio1,reconstructed_wyckoff,reconstructed_sg,predicted_property,reconstructed_DoF,str_wyckoff,ter_sys,n_sites,oxid_test
0,3,Nd2Al16Cu8,2168,"[2.0, 16.0, 8.0]","{'Nd': ['2a'], 'Al': ['8i', '8j'], 'Cu': ['8f']}",139,,2.0,"{'Nd': ['2a'], 'Al': ['8i', '8j'], 'Cu': ['8f']}","[Nd, Al, Cu]",26.0,
1,3,Li4Mn4Ir8,448,"[4.0, 4.0, 8.0]","{'Li': ['4b'], 'Mn': ['4a'], 'Ir': ['8c']}",225,,0.0,"{'Li': ['4b'], 'Mn': ['4a'], 'Ir': ['8c']}","[Li, Mn, Ir]",16.0,False
4,3,Na8Bi8O12,8812,"[8.0, 8.0, 12.0]","{'Na': ['2a', '2d', '4h'], 'Bi': ['4g', '4i'],...",12,,9.0,"{'Na': ['2a', '2d', '4h'], 'Bi': ['4g', '4i'],...","[Na, Bi, O]",28.0,
5,3,In4Bi4S6,446,"[4.0, 4.0, 6.0]","{'In': ['2a', '2e'], 'Bi': ['2e', '2e'], 'S': ...",11,,12.0,"{'In': ['2a', '2e'], 'Bi': ['2e', '2e'], 'S': ...","[In, Bi, S]",14.0,True
7,2,Ti4Cu16,416,"[4.0, 16.0]","{'Ti': ['4c'], 'Cu': ['4c', '4c', '4c', '4c']}",62,,10.0,"{'Ti': ['4c'], 'Cu': ['4c', '4c', '4c', '4c']}","[Ti, Cu]",20.0,False


In [5]:
def wycryst_to_pyxtal_dict(record):
    """Convert one WyCryst row into a species/sites/numIons/group dict.

    The row is validated before conversion: for every element, the sum of
    the multiplicities encoded in its Wyckoff labels must equal the count
    reported for that element. A Wyckoff label is read as "<multiplicity>
    <one-character letter>", e.g. ``"8i"`` contributes 8 ions.

    Parameters
    ----------
    record
        A row-like object (e.g. a ``pandas.Series`` produced by
        ``DataFrame.apply(..., axis=1)``) that provides
        ``record["reconstructed_ratio1"]`` (per-element counts),
        ``record["reconstructed_wyckoff"]`` (mapping element -> list of
        Wyckoff labels, in the same element order as the counts),
        ``record.reconstructed_sg`` and ``record.name``.

    Returns
    -------
    dict or None
        ``{"species", "sites", "numIons", "group"}`` when the row is
        self-consistent; ``None`` (after logging a warning) when the number
        of reported counts differs from the number of elements or when any
        per-element count disagrees with its Wyckoff multiplicities.
    """
    # Same logger object as the notebook-level `logger`; looked up here so
    # the function does not depend on a global defined in another cell.
    log = logging.getLogger(__name__)
    reported_counts = record["reconstructed_ratio1"]
    wyckoff_by_element = record["reconstructed_wyckoff"]

    # zip() below would silently drop the surplus entries, letting an
    # inconsistent record through with elements missing.
    if len(reported_counts) != len(wyckoff_by_element):
        log.warning(
            "Number of reported counts %s does not match number of elements %s for"
            " record %s. Elements: %s",
            len(reported_counts), len(wyckoff_by_element), record.name,
            wyckoff_by_element.keys())
        return None

    species = []
    all_sites = []
    numIons = []
    for reported_count, (element, sites) in zip(reported_counts, wyckoff_by_element.items()):
        # Everything except the trailing letter is the site multiplicity.
        counted_ions = sum(int(site[:-1]) for site in sites)
        if counted_ions != reported_count:
            # Lazy %-formatting also avoids nesting double quotes inside an
            # f-string, which is a SyntaxError before Python 3.12.
            log.warning(
                "Reported count %s does not match sum of site counts %s for"
                " record %s element %s. Elements: %s",
                reported_count, counted_ions, record.name, element,
                wyckoff_by_element.keys())
            return None
        species.append(element)
        all_sites.append(list(sites))
        numIons.append(counted_ions)
    return {
        "species": species,
        "sites": all_sites,
        "numIons": numIons,
        "group": record.reconstructed_sg
    }

In [6]:
# Convert every row; rows for which the converter returned None are removed
# by dropna() before the result is turned into a plain list.
generated_wycryst = (
    wycryst_data_raw
    .apply(wycryst_to_pyxtal_dict, axis=1)
    .dropna()
    .tolist()
)



In [7]:
# Load the WyckoffTransformer generations from gzip-compressed JSON.
# The encoding is pinned to UTF-8: without it, text-mode gzip.open falls back
# to the locale's default encoding, which is not UTF-8 on every platform.
with gzip.open(Path("generated", "AI4AM_mp_20_biternary.json.gz"), "rt", encoding="utf-8") as f:
    generated_wyckoff_transformer = json.load(f)

In [8]:
# NOTE(review): these imports sit mid-notebook; consider moving them to the
# first cell so all dependencies are visible up front.
# NOTE(review): timed_smact_validity_from_record is not used in the cells
# visible here — confirm it is needed further down.
import pickle
from wyckoff_transformer.evaluation import timed_smact_validity_from_record, StatisticalEvaluator
# NOTE(review): pickle.load can execute arbitrary code from the file; only
# load caches produced by a trusted pipeline.
with gzip.open(Path("cache", "mp_20_biternary", "data.pkl.gz"), "rb") as f:
    datasets_pd = pickle.load(f)
# The evaluator receives the "test" split and the concatenation of the
# "train" and "val" splits (presumably test set vs. reference set — confirm
# against StatisticalEvaluator's signature).
evaluator = StatisticalEvaluator(datasets_pd["test"], pd.concat([datasets_pd["train"], datasets_pd["val"]], axis=0))
print("Test novelty: ",evaluator.get_test_novelty())
print("WyCryst novelty: ", evaluator.count_novel(generated_wycryst))
print("WyckoffTransformer novelty: ", evaluator.count_novel(generated_wyckoff_transformer))
# Keep the subsets returned by get_novel for the "Novel" comparisons below.
novel_wycryst = evaluator.get_novel(generated_wycryst)
novel_wyckoff_transformer = evaluator.get_novel(generated_wyckoff_transformer)

Test novelty:  0.9598201742062377
WyCryst novelty:  0.29177947812438687
WyckoffTransformer novelty:  0.8648


In [9]:
# Second evaluator: its first argument is the test split passed through
# evaluator.get_novel_dataframe; its second argument is the same train+val
# concatenation used for `evaluator`.
# presumably this restricts the comparison to novel test structures — confirm
# against StatisticalEvaluator.
novel_evaluator = StatisticalEvaluator(evaluator.get_novel_dataframe(datasets_pd["test"]),
                                       pd.concat([datasets_pd["train"], datasets_pd["val"]], axis=0))

In [15]:
# NOTE(review): this cell's execution count is out of sequence with its
# neighbours, and the saved output spells the label "statiscic" while the code
# prints "statistic" — the stored output predates the current code. Restart the
# kernel and run all cells before relying on the numbers.
# Each entry pairs a generated set with the evaluator it is scored against:
# full sets use `evaluator`, novel-only subsets use `novel_evaluator`.
datasets = {
    "WyCryst": (generated_wycryst, evaluator),
    "WyckoffTransformer": (generated_wyckoff_transformer, evaluator),
    "WyCryst Novel": (novel_wycryst, novel_evaluator),
    "WyckoffTransformer Novel": (novel_wyckoff_transformer, novel_evaluator)}
for name, (dataset, this_evaluator) in datasets.items():
    print(name)
    print("Num sites KS: ", this_evaluator.get_num_sites_ks(dataset))
    print("Num elements KS :", this_evaluator.get_num_elements_ks(dataset))
    print("DoF KS: ", this_evaluator.get_dof_ks(dataset))
    # The result is read through its .statistic and .pvalue attributes.
    sg_chi2 = this_evaluator.get_sg_chi2(dataset, sample_size="test")
    print(f"SG chi2: statistic={sg_chi2.statistic}, pvalue={sg_chi2.pvalue}")

WyCryst
Num sites KS:  KstestResult(statistic=0.03975535196011648, pvalue=1.4236638119437514e-09, statistic_location=4, statistic_sign=1)
Num elements KS : KstestResult(statistic=0.03838199508352241, pvalue=5.947351369026711e-09, statistic_location=2, statistic_sign=-1)




DoF KS:  KstestResult(statistic=0.04117394256950424, pvalue=3.0839361418463316e-10, statistic_location=7, statistic_sign=1)
SG chi2: statiscic=703.5853771023179, pvalue=1.6072519094623623e-73
WyckoffTransformer
Num sites KS:  KstestResult(statistic=0.02380339983141332, pvalue=0.017604870205337915, statistic_location=4, statistic_sign=1)
Num elements KS : KstestResult(statistic=0.0257, pvalue=0.008049610835145911, statistic_location=3, statistic_sign=1)
DoF KS:  KstestResult(statistic=0.0188548468670975, pvalue=0.10231187840614872, statistic_location=2, statistic_sign=1)
SG chi2: statiscic=152.11975499959556, pvalue=0.3691994206142184
WyCryst Novel
Num sites KS:  KstestResult(statistic=0.2796992422819146, pvalue=0.0, statistic_location=3, statistic_sign=1)
Num elements KS : KstestResult(statistic=0.004451737339108447, pvalue=0.9998335648003727, statistic_location=2, statistic_sign=1)
DoF KS:  KstestResult(statistic=0.20917388936306303, pvalue=6.022183272547049e-221, statistic_location=2



Num sites KS:  KstestResult(statistic=0.039500254595520624, pvalue=9.854143711269751e-06, statistic_location=4, statistic_sign=1)
Num elements KS : KstestResult(statistic=0.02971785383903793, pvalue=0.0019699337641728505, statistic_location=3, statistic_sign=1)
DoF KS:  KstestResult(statistic=0.042402857967930065, pvalue=1.5342722610037222e-06, statistic_location=2, statistic_sign=1)
SG chi2: statiscic=215.41483375782673, pvalue=0.00020286983495509275


In [11]:
# Space-group numbers seen in train or val (mutable set), and those seen in
# test (immutable set).
train_sg = set().union(
    datasets_pd["train"].spacegroup_number,
    datasets_pd["val"].spacegroup_number,
)
test_sg = frozenset(datasets_pd["test"].spacegroup_number)

In [12]:
# Sizes of the train+val space-group set, the test set, and their intersection.
len(train_sg), len(test_sg), len(train_sg & test_sg)

(169, 148, 145)