In [1]:
import gzip
import pickle
from wyckoff_transformer.evaluation import smac_validity_from_counter, smact_validity_from_record

In [44]:
# Name of the dataset sub-directory under cache/ that holds the pickled splits.
dataset = "mp_20"
# gzip.open defaults to binary read mode ("rb"), which is what pickle.load expects.
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only load
# cache files that were produced locally / come from a trusted source.
with gzip.open(f"cache/{dataset}/data.pkl.gz") as f:
    data = pickle.load(f)

In [45]:
# Mean of smac_validity_from_counter over the raw (un-reduced) training compositions.
# presumably the helper returns a bool per composition, making this the valid fraction — TODO confirm
data['train'].composition.map(smac_validity_from_counter).mean()

0.43399911556603776

In [46]:
# Same validity mean, but with every element count doubled first.
# NOTE(review): within the visible notebook double_composition is only defined in a
# later cell, so a fresh Restart & Run All raises NameError here; the definition
# should be moved above this cell.
data['train'].composition.map(double_composition).map(smac_validity_from_counter).mean()

0.35112028301886794

In [62]:
# Inspect the first training record.
# NOTE(review): the saved output of this cell differs from the saved output of the
# other `data['train'].iloc[0]` cell further down, so the two were executed against
# different contents of `data`; re-run the notebook top to bottom to refresh outputs.
data['train'].iloc[0]

site_symmetries                      [m, m, m, m, m, m, m, m, m, m, m, m]
elements                       [Na, Na, Na, Mn, Co, Ni, O, O, O, O, O, O]
multiplicity                         [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
wyckoff_letters                      [a, a, a, a, a, a, a, a, a, a, a, a]
sites_enumeration                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
dof                                  [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
spacegroup_number                                                       8
sites_enumeration_augmented        [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
composition                           {Na: 6, Mn: 2, Co: 2, Ni: 2, O: 12}
Name: 37228, dtype: object

In [65]:
# Show the first training composition with its counts divided by their common divisor.
# NOTE(review): within the visible notebook gcd_composition (and the `gcd` import it
# needs) only appear in later cells, so this cell fails on a fresh top-to-bottom run.
gcd_composition(data['train'].iloc[0].composition)

{Element Na: 3, Element Mn: 1, Element Co: 1, Element Ni: 1, Element O: 6}

In [47]:
def double_composition(composition):
    """Return a new dict in which every count of ``composition`` is multiplied by 2.

    The input mapping is left untouched and the keys keep their iteration order.
    """
    doubled = {}
    for element, count in composition.items():
        doubled[element] = count * 2
    return doubled
# Spot check on a single record: validity of the first training composition after
# doubling its counts. Calls double_composition rather than repeating its dict
# comprehension inline, so the two cannot drift apart.
smac_validity_from_counter(double_composition(data['train'].composition.iloc[0]))

False

In [48]:
from math import gcd

In [49]:
def gcd_composition(composition):
    """Reduce a composition to its smallest integer ratio.

    Every count is floor-divided by the greatest common divisor of all counts,
    e.g. ``{"Na": 6, "O": 12}`` becomes ``{"Na": 1, "O": 2}``.

    Args:
        composition: mapping from element to an integer count.

    Returns:
        A new dict with the same keys (in the same order) and the reduced counts.
        The input mapping is not modified. When the common divisor is 0 (empty
        mapping, or every count is 0) the counts are returned unchanged.
    """
    this_gcd = gcd(*composition.values())
    if this_gcd == 0:
        # gcd is 0 only when there is no non-zero count; dividing by it would
        # raise ZeroDivisionError, and there is nothing to reduce anyway.
        return dict(composition)
    return {key: value // this_gcd for key, value in composition.items()}

In [50]:
# Validity mean on the training split after reducing each composition by its gcd
# (compare with the un-reduced and doubled variants computed earlier in the notebook).
data['train'].composition.map(gcd_composition).map(smac_validity_from_counter).mean()

0.9052181603773585

In [51]:
# Validity mean on the raw (un-reduced) test-split compositions.
data['test'].composition.map(smac_validity_from_counter).mean()

0.4359938094185275

In [52]:
# Validity mean on the raw (un-reduced) validation-split compositions.
data['val'].composition.map(smac_validity_from_counter).mean()

0.43671935448214877

In [53]:
from wyckoff_transformer.tokenization import load_tensors_and_tokenisers, tensor_to_pyxtal, pyxtal_cascade_order, get_letter_from_ss_enum_idx, get_wp_index

In [54]:
# Load the tokenised tensors together with their tokenisers and engineers.
# NOTE(review): this loads "mp_20_biternary" while the pickle cell above uses
# dataset = "mp_20" — confirm the two are meant to describe the same structures.
tensors, tokenisers, engineers = load_tensors_and_tokenisers("mp_20_biternary", "mp_20_CSP")

In [55]:
import torch

In [56]:
# Lookup built from the 'sites_enumeration' tokeniser; it is passed to tensor_to_pyxtal below.
letter_from_ss_enum_idx = get_letter_from_ss_enum_idx(tokenisers['sites_enumeration'])

In [57]:
# Stack the training tensors of the fields in pyxtal_cascade_order along a new axis 2.
# NOTE: despite the name, these are the *training* tensors, not model samples.
generated_tensors = torch.stack([tensors["train"][field] for field in pyxtal_cascade_order], dim=2)

In [58]:
# Index object required by tensor_to_pyxtal (passed as its last argument below).
wp_index = get_wp_index()

In [59]:
# Convert every training example (space-group token paired with its stacked field
# tensor) through tensor_to_pyxtal; results are collected in input order.
generated_wp = [
    tensor_to_pyxtal(
        spacegroup, fields,
        tokenisers, pyxtal_cascade_order, letter_from_ss_enum_idx, wp_index,
    )
    for spacegroup, fields in zip(tensors["train"]["spacegroup_number"], generated_tensors)
]

In [60]:
# Number of converted records (one per training example).
len(generated_wp)

21317

In [14]:
# Count the conversions that did not come back as None.
sum(1 for record in generated_wp if record is not None)

21317

In [15]:
# Shape check of the stacked tensor; the last axis comes from the dim=2 stack over fields.
generated_tensors.shape

torch.Size([21317, 21, 3])

In [16]:
# Sum of smact_validity_from_record over all converted records
# (presumably a count of valid records if the helper returns a bool — TODO confirm).
r = sum(smact_validity_from_record(record) for record in generated_wp)

In [17]:
# Normalise the summed validity by the number of records to get a mean.
r/len(generated_wp)

0.4897499648168129

In [18]:
# Inspect the first converted record.
generated_wp[0]

{'group': 139,
 'sites': [['2a'], ['8j', '8i'], ['8f']],
 'species': ['Nd', 'Al', 'Cu'],
 'numIons': [2, 16, 8]}

In [19]:
from collections import Counter

In [20]:
# Rebuild the per-element totals of the first training record by summing the
# multiplicity of every site that holds that element.
c = Counter()
for e, m in zip(data['train'].iloc[0].elements, data['train'].iloc[0].multiplicity):
    c[e] += m  # missing keys start at 0 in a Counter
c

Counter({Element Al: 16, Element Cu: 8, Element Nd: 2})

In [21]:
# Count how often each (element, multiplicity) pair occurs among the sites of the
# first training record — this counts sites, it does not sum multiplicities.
Counter(zip(data['train'].iloc[0].elements, data['train'].iloc[0].multiplicity))

Counter({(Element Al, 8): 2, (Element Nd, 2): 1, (Element Cu, 8): 1})

In [22]:
# Re-display the per-element totals accumulated in `c` for comparison.
c

Counter({Element Al: 16, Element Cu: 8, Element Nd: 2})

In [23]:
# Inspect the first training record again.
# NOTE(review): the saved output here does not match the saved output of the earlier
# `data['train'].iloc[0]` cell, which indicates out-of-order execution against
# different data; re-run the notebook top to bottom before trusting either output.
data['train'].iloc[0]

site_symmetries                  [4/mmm, m2m., m2m., ..2/m]
elements                                   [Nd, Al, Al, Cu]
multiplicity                                   [2, 8, 8, 8]
wyckoff_letters                                [a, j, i, f]
sites_enumeration                              [0, 1, 0, 0]
dof                                            [0, 1, 1, 0]
spacegroup_number                                       139
sites_enumeration_augmented    [[1, 0, 1, 0], [0, 1, 0, 0]]
composition                          {Nd: 2, Al: 16, Cu: 8}
Name: 19480, dtype: object