In [None]:
"""train.lmdb: 405,073 structures
test.lmdb: 135,116 structures
val.lmdb: 135,015 structures"""

In [None]:
import lmdb
import pickle
import torch
import pyarrow as pa
import pyarrow.parquet as pq

In [None]:
# Record keys retained from each LMDB entry, grouped by what they describe.
keep_rows = (
    # structure
    ["pos", "cell", "atomic_numbers", "pbc"]
    # identifiers and composition metadata
    + [
        "material_id",
        "reduced_formula",
        "space_group",
        "chemical_system",
        "num_sites",
        "cif",
    ]
    # scalar properties
    + [
        "energy_above_hull",
        "dft_band_gap",
        "dft_bulk_modulus",
        "dft_mag_density",
        "hhi_score",
        "ml_bulk_modulus",
    ]
)

In [None]:
def read_lmdb_content(lmdb_path):
    """Yield every unpickled value stored in an LMDB file.

    Args:
        lmdb_path: Path handed to ``lmdb.open``. It is opened with
            ``subdir=False`` and ``readonly=True``.

    Yields:
        The object returned by ``pickle.loads`` for each stored value, in
        cursor order. Keys are ignored.

    The environment is closed when iteration finishes, when an error is
    raised, or when the generator itself is closed, so repeated calls do not
    leak open LMDB handles.
    """
    env = lmdb.open(
        lmdb_path,
        subdir=False,
        readonly=True,
    )
    try:
        with env.begin() as txn:
            for _key, value in txn.cursor():
                # SECURITY: pickle.loads can execute arbitrary code; only run
                # this on LMDB files from a trusted source.
                yield pickle.loads(value)
    finally:
        # Runs on exhaustion, on exceptions, and on generator close().
        env.close()

In [None]:
def create_table(lmdb_content, keep_rows):
    """Turn an iterable of record mappings into a pyarrow Table.

    Args:
        lmdb_content: Iterable of dict-like records (anything with ``.items()``).
        keep_rows: Collection of record keys to retain; every key read by the
            column spec below must be in it, otherwise a ``KeyError`` is raised.

    Returns:
        The result of ``pa.Table.from_pylist`` over one dict per record, with
        columns in the order listed in ``column_spec``.
    """

    def flat_floats(nested):
        # Flatten one nesting level and cast every element to float.
        return [float(item) for inner in nested for item in inner]

    def int_list(seq):
        return [int(item) for item in seq]

    # (output column, source key, converter) in output column order.
    column_spec = (
        ("positions", "pos", flat_floats),
        ("cell", "cell", flat_floats),
        ("atomic_numbers", "atomic_numbers", int_list),
        ("pbc", "pbc", int_list),
        ("material_id", "material_id", str),
        ("reduced_formula", "reduced_formula", str),
        ("space_group", "space_group", str),
        ("chemical_system", "chemical_system", str),
        ("num_sites", "num_sites", int),
        ("cif", "cif", str),
        ("energy_above_hull", "energy_above_hull", float),
        ("dft_band_gap", "dft_band_gap", float),
        ("dft_bulk_modulus", "dft_bulk_modulus", float),
        ("dft_mag_density", "dft_mag_density", float),
        ("hhi_score", "hhi_score", float),
        ("ml_bulk_modulus", "ml_bulk_modulus", float),
    )

    records = []
    for entry in lmdb_content:
        # Keep only the requested keys, turning tensors into plain Python data.
        kept = {}
        for name, value in entry.items():
            if name not in keep_rows:
                continue
            kept[name] = value.tolist() if isinstance(value, torch.Tensor) else value
        records.append(
            {column: convert(kept[source]) for column, source, convert in column_spec}
        )

    return pa.Table.from_pylist(records)

In [None]:
# Convert the test split and persist it as ZSTD-compressed parquet.
test = create_table(
    lmdb_content=read_lmdb_content(lmdb_path="data/test.lmdb"),
    keep_rows=keep_rows,
)
pq.write_table(
    table=test,
    where="parquets/test.parquet",
    compression="ZSTD",
    compression_level=18,
)

In [None]:
# Convert the validation split and write it next to the other splits. With the
# write commented out, a fresh top-to-bottom run never produced val.parquet,
# so the later upload of the "parquets" folder would be missing this split.
val = create_table(read_lmdb_content("data/val.lmdb"), keep_rows)
pq.write_table(val, "parquets/val.parquet", compression="ZSTD", compression_level=18)

In [None]:
# Convert the training split; it is written to parquet further down.
train = create_table(
    lmdb_content=read_lmdb_content(lmdb_path="data/train.lmdb"),
    keep_rows=keep_rows,
)

In [None]:
# Sanity check: number of rows in the converted training split.
len(train)

In [None]:
# create_table() already returns a pyarrow Table, so it must not be passed
# through pa.Table.from_pylist again (that expects a list of dicts and fails
# on a Table). Keep the train_table name for the cells below.
train_table = train

In [None]:
# Row count of the table that is written to parquet below.
train_table.num_rows

In [None]:
# Persist the training split with the same ZSTD settings as the other splits.
pq.write_table(
    table=train_table,
    where="parquets/train.parquet",
    compression="ZSTD",
    compression_level=18,
)

### Hugging Face upload

In [None]:
from huggingface_hub import HfApi
from dotenv import load_dotenv
import os

# Load variables from a .env file (if any), read HF_TOKEN from the environment
# and build the Hub client with it. token is None when HF_TOKEN is unset.
load_dotenv()
token = os.environ.get("HF_TOKEN")
api = HfApi(token=token)

In [None]:
# exist_ok=True keeps this cell re-runnable: without it the call fails once
# the dataset repo has already been created.
api.create_repo(
    repo_id="colabfit/Alex-MP-20_Polymorph_Split",
    repo_type="dataset",
    token=token,
    exist_ok=True,
)

In [None]:
# Push the contents of the local "parquets" folder to the dataset repo.
api.upload_folder(
    repo_id="colabfit/Alex-MP-20_Polymorph_Split",
    repo_type="dataset",
    folder_path="parquets",
    token=token,
)

In [None]:
# Upload the local README.md as the README.md of the dataset repo.
api.upload_file(
    repo_id="colabfit/Alex-MP-20_Polymorph_Split",
    repo_type="dataset",
    path_in_repo="README.md",
    path_or_fileobj="README.md",
    token=token,
)

In [None]:
# Delete two parquet files stored under a "parquets/" prefix in the dataset
# repo, val first and then test.
# NOTE(review): presumably cleanup of an earlier upload that used that prefix;
# confirm these paths still exist in the repo before re-running.
for stale_path in ("parquets/val.parquet", "parquets/test.parquet"):
    api.delete_file(
        repo_id="colabfit/Alex-MP-20_Polymorph_Split",
        repo_type="dataset",
        path_in_repo=stale_path,
        token=token,
    )