In [1]:
import lmdb
from pathlib import Path
import pickle
from fairchem.core.common.utils import pyg2_data_transform

In [2]:
# NOTE(review): absolute, machine-specific location of the LMDB file to inspect;
# point this at your own copy before running the notebook.
db_path = "/Users/curtischong/Documents/dev/joule/src/fairchem/data/s2ef_test_lmdbs/test_data/s2ef/all/test_id/data.0000.lmdb"

def connect_db(lmdb_path: Path | None = None) -> lmdb.Environment:
    """Open an LMDB database file in read-only mode.

    Args:
        lmdb_path: Path of the LMDB data file (a file, not a directory,
            because the environment is opened with ``subdir=False``).
            Despite the ``None`` default, a real path is required.

    Returns:
        The opened ``lmdb.Environment``.

    Raises:
        ValueError: If ``lmdb_path`` is ``None``.
    """
    if lmdb_path is None:
        # Without this guard str(None) would be handed to LMDB as the literal
        # path "None", which fails later with a confusing error.
        raise ValueError("lmdb_path must be provided; got None")

    return lmdb.open(
        str(lmdb_path),
        subdir=False,  # the path names the data file itself
        readonly=True,  # inspection only; no writes are allowed
        lock=False,
        readahead=True,
        meminit=False,
        max_readers=1,
    )

# Open the environment once; the cells below read from this shared handle.
cur_env = connect_db(db_path)

In [3]:
# Determine how many entries the database holds.
# If a "length" record (ASCII key) is present, use its pickled value.
# The read transaction is scoped with `with` so it is released as soon as
# the lookup is done instead of being left open.
with cur_env.begin() as txn:
    length_entry = txn.get("length".encode("ascii"))

if length_entry is not None:
    # NOTE(review): pickle.loads can execute arbitrary code; only use this
    # notebook on LMDB files from a trusted source.
    num_entries = pickle.loads(length_entry)
else:
    # Otherwise fall back to the number of stored key/value pairs that LMDB
    # itself reports for the database.
    num_entries = cur_env.stat()["entries"]


In [4]:
# Show the entry count determined in the previous cell.
num_entries

625

In [27]:
import crystal_toolkit
from pymatgen.core import Structure
import numpy as np


def tile_structure(lattice, species, coordinates, scale):
    """Repeat a cell ``scale`` times along each of its three lattice directions.

    Args:
        lattice: Cell matrix; it is multiplied by ``scale`` as a whole.
        species: Iterable with one entry per site of the original cell.
        coordinates: Per-site coordinates, one row per site. They are shifted
            by integer offsets and divided by ``scale``, i.e. they are treated
            as fractional coordinates of the original cell.
        scale: Number of repetitions along every axis.

    Returns:
        A tuple ``(tiled_lattice, tiled_species, tiled_coordinates)`` in which
        the species list and the coordinate rows follow the same site order.
    """
    # Same visiting order as three nested loops: i outermost, k innermost.
    offsets = [
        (i, j, k)
        for i in range(scale)
        for j in range(scale)
        for k in range(scale)
    ]

    # One shifted and rescaled copy of the coordinates per image of the cell.
    coord_blocks = [(coordinates + np.array(offset)) / scale for offset in offsets]
    # The species sequence is repeated once per image, in matching order.
    supercell_species = [entry for _ in offsets for entry in species]

    supercell_lattice = lattice * scale
    supercell_coordinates = np.concatenate(coord_blocks, axis=0)
    return supercell_lattice, supercell_species, supercell_coordinates

def visualize_sample(res, tile_amount=1, is_tags_visible=False):
    """Display one sample as a pymatgen ``Structure``, optionally tiled.

    Args:
        res: Data object exposing ``cell``, ``atomic_numbers``, ``pos`` and
            (when ``is_tags_visible`` is set) ``tags``.
        tile_amount: Number of repetitions of the cell along each axis.
        is_tags_visible: When true, sites are labelled by ``tags + 1`` instead
            of their atomic numbers so the tag groups show up as different
            species.
    """
    cell, atomic_numbers, positions = res.cell, res.atomic_numbers, res.pos

    # Tags: slab atoms below the surface are 0, surface atoms 1, adsorbate 2.
    # They are shifted by one before being used as species
    # (presumably because 0 is not a valid atomic number; confirm).
    shown_species = res.tags + 1 if is_tags_visible else atomic_numbers

    supercell = tile_structure(cell, shown_species, positions, tile_amount)

    # NOTE(review): coords_are_cartesian=False means res.pos is interpreted as
    # fractional coordinates; confirm that matches how the dataset stores pos.
    structure = Structure(*supercell, coords_are_cartesian=False)
    display(structure)
    

def get_and_vis_sample(el_idx, tile_amount=2, is_tags_visible=False):
    """Load one entry from the open LMDB environment and display it.

    Args:
        el_idx: Index of the entry; its decimal string form, ASCII-encoded,
            is the LMDB key that is looked up.
        tile_amount: Number of repetitions of the cell along each axis,
            forwarded to ``visualize_sample``. Defaults to 2.
        is_tags_visible: Forwarded to ``visualize_sample``. Defaults to False.

    Raises:
        KeyError: If no record is stored under the key for ``el_idx``.
    """
    key = f"{el_idx}".encode("ascii")

    # Scope the read transaction so it is released right after the lookup.
    with cur_env.begin() as txn:
        datapoint_pickled = txn.get(key)

    if datapoint_pickled is None:
        # Fail with a clear message instead of letting pickle.loads(None)
        # raise an unrelated-looking TypeError.
        raise KeyError(f"no entry stored under key {el_idx!r} in the LMDB")

    # NOTE(review): pickle.loads can execute arbitrary code; only use this on
    # LMDB files from a trusted source.
    data_object = pyg2_data_transform(pickle.loads(datapoint_pickled))
    visualize_sample(data_object, tile_amount, is_tags_visible)
# Example: load the entry stored under key "4" and render it.
get_and_vis_sample(4)