In [24]:
import json
import lmdb
from pathlib import Path
import pickle

In [15]:
# Dump of the per-sample losses written during evaluation.
predictions = "../out/worse_mae/worse_preds_epoch_1.json"

with open(predictions) as f:
    res = json.loads(f.read())

# Rank both lists so that index 0 is the sample with the largest loss.
energy_losses = list(res["worse_energy"])
energy_losses.sort(key=lambda entry: entry["energy_loss"], reverse=True)

forces_losses = list(res["worse_forces"])
forces_losses.sort(key=lambda entry: entry["forces_loss"], reverse=True)

In [34]:
# Sanity check for the LMDB lookup: the record fetched for this entry should
# contain as many atoms as there are predicted force vectors here.
len(forces_losses[0]["pred_forces"])

49

In [46]:
# Inspect the entry with the largest energy loss: its loss, the predicted energy,
# and the dataset_path / data_idx used to look the sample up in the LMDB.
energy_losses[0]

{'energy_loss': 215.30906677246094,
 'pred_energy': -298.4591064453125,
 'dataset_path': 'datasets/lmdb/real_mace3/val/0.lmdb',
 'data_idx': 29139}

In [62]:
import crystal_toolkit
from pymatgen.core import Structure
import numpy as np

def connect_db(lmdb_path: Path | None = None) -> lmdb.Environment:
    return lmdb.open(
        str(lmdb_path),
        subdir=False,
        readonly=True,
        lock=False,
        readahead=True,
        meminit=False,
        max_readers=1,
    )


def tile_structure(lattice, species, coordinates, scale, lattice_factor=0.75):
    """Repeat a cell ``scale`` times along each of the three axes.

    Every copy is produced by adding an integer offset ``(i, j, k)`` to
    ``coordinates`` and dividing by ``scale``, so the coordinates are presumably
    fractional (TODO confirm against the dataset).

    Args:
        lattice: Cell matrix; multiplied by ``scale * lattice_factor``.
        species: Per-atom species, repeated once per copy.
        coordinates: Per-atom coordinates, one row per atom.
        scale: Number of repetitions per axis (must be >= 1).
        lattice_factor: Extra multiplier applied to the tiled lattice. Defaults
            to the previously hard-coded 0.75.
            NOTE(review): a geometrically consistent supercell would use 1.0;
            confirm whether the shrink is intentional.

    Returns:
        ``(tiled_lattice, tiled_species, tiled_coordinates)`` with
        ``scale ** 3`` copies of the input atoms.

    Raises:
        ValueError: If ``scale`` is smaller than 1 (there would be nothing to
            concatenate).
    """
    if scale < 1:
        raise ValueError(f"scale must be at least 1, got {scale}")

    tiled_lattice = lattice * scale * lattice_factor

    tiled_coordinates_list = []
    tiled_species = []
    # np.ndindex walks (i, j, k) in the same order as three nested range loops.
    for offset in np.ndindex(scale, scale, scale):
        tiled_coordinates_list.append((coordinates + np.array(offset)) / scale)
        tiled_species.extend(species)

    tiled_coordinates = np.concatenate(tiled_coordinates_list, axis=0)
    return tiled_lattice, tiled_species, tiled_coordinates

def visualize_sample(loss_dict, tile_amount=1):
    """Fetch one sample from its LMDB file and display it as a pymatgen Structure.

    Args:
        loss_dict: One entry of the predictions JSON. Must provide
            ``dataset_path``, ``data_idx`` and ``energy_loss``.
        tile_amount: Repetitions per axis, forwarded to ``tile_structure``.

    Raises:
        KeyError: If the LMDB file has no record stored under ``data_idx``.
    """
    # The stored paths are resolved one directory above the notebook.
    dataset_path = "../" + loss_dict["dataset_path"]
    data_idx = loss_dict["data_idx"]

    # Close the environment even on failure: leaving it open leaks the handle
    # and makes the next call reopen the same LMDB file in this process.
    db = connect_db(dataset_path)
    try:
        with db.begin() as txn:
            datapoint_pickled = txn.get(f"{data_idx}".encode("ascii"))
    finally:
        db.close()

    if datapoint_pickled is None:
        raise KeyError(f"no entry with key {data_idx!r} in {dataset_path}")

    # NOTE: unpickling can execute arbitrary code; only read trusted datasets.
    res = pickle.loads(datapoint_pickled)
    lattice = res.cell
    species = res.atomic_numbers
    coordinates = res.pos
    tiled_lattice, tiled_species, tiled_coordinates = tile_structure(lattice, species, coordinates, tile_amount)
    print(f"energy loss: {loss_dict['energy_loss']}, num_atoms: {res.natoms}")
    display(Structure(tiled_lattice, tiled_species, tiled_coordinates))
    
# Show the entry at index 9 of the descending energy-loss ranking, without tiling.
visualize_sample(energy_losses[9], 1)

energy loss: 130.66990661621094, num_atoms: 48


In [70]:
# I want to see the number of atoms in each cell since I noticed that the model sucks at predicting the energy when there's a large number of atoms
def visualize_sample(loss_dicts, tile_amount=1):
    """Print the atom count of every sample referenced in ``loss_dicts``.

    NOTE(review): this reuses the name of the single-sample viewer defined in an
    earlier cell, so running this cell replaces that function with one that
    expects a list. Consider giving it a distinct name.

    Args:
        loss_dicts: Iterable of prediction entries, each providing
            ``dataset_path`` and ``data_idx``.
        tile_amount: Unused; kept so the signature matches the earlier cell.

    Raises:
        KeyError: If a ``data_idx`` has no record in its LMDB file.
    """
    # Open every distinct LMDB file once instead of once per sample, and make
    # sure all of them are closed again, even if a lookup fails.
    open_envs = {}
    try:
        for loss_dict in loss_dicts:
            dataset_path = "../" + loss_dict["dataset_path"]
            data_idx = loss_dict["data_idx"]

            if dataset_path not in open_envs:
                open_envs[dataset_path] = connect_db(dataset_path)

            with open_envs[dataset_path].begin() as txn:
                datapoint_pickled = txn.get(f"{data_idx}".encode("ascii"))
            if datapoint_pickled is None:
                raise KeyError(f"no entry with key {data_idx!r} in {dataset_path}")

            # NOTE: unpickling can execute arbitrary code; only read trusted datasets.
            res = pickle.loads(datapoint_pickled)
            print(res.natoms)
    finally:
        for env in open_envs.values():
            env.close()
# Print the atom count of every entry in the energy-loss ranking.
visualize_sample(energy_losses)

44
48
46
46
48
46
50
49
48
48
48
48
48
48
48
40
50
50
50
48
36
48
48
49
48
50
48
46
50
50
50
46
50
48
48
50
48
49
48
48
48
48
48
46
48
48
48
46
46
44
