In [15]:
import jax
import jax.numpy as jnp
from crystalformer.src.utils import GLXYZAW_from_file
from crystalformer.src.elements import element_dict
from crystalformer.src.formula import find_composition_vector, formula_string
from crystalformer.src.wyckoff import mult_table

In [16]:
# Previous dataset location, kept for reference:
#alex20_folder = '/opt/data/bcmdata/ZONES/data/PROJECTS/datafile/PRIVATE/zdcao/crystal_gpt/dataset/alex/PBE/alex20/'
alex20_folder = '/opt/data/bcmdata/ZONES/data/PROJECTS/datafile/PRIVATE/zdcao/crystal_gpt/dataset/alex/PBE_20241204/'

# The folder string ends with '/', so strip it before appending '/<name>';
# otherwise every path contains a doubled separator ('.../PBE_20241204//train.lmdb').
dataset_root = alex20_folder.rstrip('/')

train_path = dataset_root + '/train.lmdb'
valid_path = dataset_root + '/val.lmdb'
test_path = dataset_root + '/test.lmdb'


In [17]:
# Loader arguments shared by all three splits.  n_max is the per-structure
# site capacity: the arrays printed below have a second dimension of 21.
# NOTE(review): atom_types / wyck_types are presumably the sizes of the element
# and Wyckoff-letter vocabularies (including a padding entry) - confirm against
# GLXYZAW_from_file.
atom_types, wyck_types, n_max = 119, 28, 21

# Read the splits in train / validation / test order (same order as the
# original three calls, so the printed shape summaries keep their order).
train_dataset, valid_dataset, test_dataset = [
    GLXYZAW_from_file(split_path, atom_types, wyck_types, n_max)
    for split_path in (train_path, valid_path, test_path)
]





G: (1387800,)
L: (1387800, 6)
XYZ: (1387800, 21, 3)
A: (1387800, 21)
W: (1387800, 21)
G: (173475,)
L: (173475, 6)
XYZ: (173475, 21, 3)
A: (173475, 21)
W: (173475, 21)
G: (173475,)
L: (173475, 6)
XYZ: (173475, 21, 3)
A: (173475, 21)
W: (173475, 21)


In [18]:
# Unpack the training split into its five arrays.  Shapes, as printed in the
# first summary above: G (N,), L (N, 6), XYZ (N, 21, 3), A (N, 21), W (N, 21)
# with N = 1387800.
# NOTE(review): presumably G = space-group number, L = lattice parameters,
# XYZ = site coordinates, A = element codes, W = Wyckoff indices - confirm
# against GLXYZAW_from_file.
G, L, XYZ, A, W = train_dataset

In [19]:
def find_element(x, elements):
    """Return the indices of the rows of ``x`` made up of exactly ``elements``.

    A row matches when
      (1) every entry is either the padding value 0 or the ``element_dict``
          code of one of ``elements``, and
      (2) each requested element occurs at least once in the row.

    Padding is tolerated but not required: a row in which every slot is
    occupied (no zeros at all) still matches.  Previously 0 was part of the
    "must be present" set, so such fully occupied rows were silently dropped.

    Args:
        x: 2-D array of element codes, one row per structure (``x[i, j]``).
        elements: iterable of ``element_dict`` keys, e.g. ``['Cs', 'Pb', 'I']``.

    Returns:
        1-D array holding the indices of the matching rows, in ascending order.
    """
    x = jnp.asarray(x)
    targets = jnp.asarray([element_dict[e] for e in elements])

    # hit[i, j, k] = (x[i, j] == targets[k])
    hit = x[:, :, None] == targets[None, None, :]

    # (1) nothing foreign: each entry is padding (0) or one of the targets
    only_allowed = jnp.all((x == 0) | jnp.any(hit, axis=2), axis=1)   # shape (n,)

    # (2) nothing missing: each target appears somewhere in the row.
    #     Padding is deliberately not part of this check.
    all_present = jnp.all(jnp.any(hit, axis=1), axis=1)               # shape (n,)

    # The result length depends on the data, so call this outside of jax.jit.
    return jnp.where(only_allowed & all_present)[0]

In [32]:
# Indices of the training rows of A that contain only 0, Cs, Pb and I entries,
# with each of the three elements present at least once.
idx = find_element(A, ['Cs', 'Pb', 'I'])


In [33]:
# Element-code rows of the matched structures (82, 55 and 53 in the saved
# output, zero-padded to the row length).
A[idx]

Array([[82, 55, 53, 55, 53, 53,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0],
       [82, 53, 55, 53,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0],
       [82, 55, 53,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0],
       [55, 82, 55, 53,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0],
       [82, 53, 55, 53,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0],
       [55, 53, 82, 53,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0],
       [82, 53, 55, 53, 53,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0],
       [55, 55, 53, 82, 53,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0],
       [53, 82, 55, 53,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0],
       [82, 55, 82, 53,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,

In [34]:
# G entries of the matched structures.
# NOTE(review): presumably space-group numbers - confirm against the loader.
G[idx]

Array([ 14, 139, 221, 167,  62,  71,  62,  38, 127, 230, 140,  63, 148,
       139, 139, 189,  12], dtype=int32)

In [36]:
def lookup(g, w):
    """Index ``mult_table`` for one structure.

    ``g`` is shifted down by one to select the table row and ``w`` selects the
    columns.  NOTE(review): presumably this yields the multiplicity of each
    Wyckoff site of space group ``g`` - confirm in crystalformer.src.wyckoff.
    """
    table_row = g - 1
    return mult_table[table_row, w]


# Batch the per-structure lookup over the leading axis of both arguments.
lookup = jax.vmap(lookup, in_axes=(0, 0))

M = lookup(G[idx], W[idx])  # (batchsize, n_max)


In [37]:
# Apply find_composition_vector to each (A row, M row) pair of the matched
# structures via vmap.
# NOTE(review): presumably this gives per-structure element counts weighted by
# the site multiplicities in M - confirm in crystalformer.src.formula.
composition = jax.vmap(find_composition_vector)(A[idx], M)

In [38]:
# Print one formula string per matched structure, in the same order as idx.
for c in composition:
    print (formula_string(c))

I6Cs3Pb
I4Cs2Pb
I3CsPb
I6Cs4Pb
I3CsPb
I5CsPb2
I3CsPb
I5Cs2Pb2
I3CsPb
I12Cs3Pb5
I5CsPb2
I3CsPb
I8Cs6Pb
I10Cs4Pb3
I7Cs3Pb2
I12Cs5Pb3
I4Cs2Pb


In [None]:
# W row of matched structure 6.
# NOTE(review): this cell has no execution count and its saved output lists 4
# non-zero entries, while XYZ[idx][6] below shows 5 non-zero rows - the saved
# output looks stale; re-run after Restart & Run All.
W[idx][6]

Array([2, 3, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],      dtype=int32)

In [None]:
# Element-code row of matched structure 6.
# NOTE(review): this cell has no execution count and its saved output has 4
# non-zero entries, whereas row 6 of the A[idx] display above has 5
# (82, 53, 55, 53, 53) - the saved output looks stale; re-run.
A[idx][6]

Array([82, 53, 55, 53,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0], dtype=int32)

In [50]:
# XYZ rows of matched structure 6; in the saved output the first 5 rows are
# non-zero and the remaining rows are all zeros.
# NOTE(review): presumably fractional coordinates with zero padding - confirm.
XYZ[idx][6]

Array([[0.16235477, 0.25      , 0.06006606],
       [0.16400707, 0.25      , 0.49800405],
       [0.42519137, 0.25      , 0.82704854],
       [0.52676034, 0.25      , 0.6094674 ],
       [0.792392  , 0.25      , 0.2865102 ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        ]], dtype=float32)