In [1]:
import os
import torch
import nglview as nv
from ase.io import read
from fairchem.core.datasets.ase_datasets import AseDBDataset
from src_v2.distill_datasets import LmdbDataset



## Main pipeline: selecting diverse atom indices for each structure

In [6]:
set_num = 3  # which split below to process (index into paths / info_paths)
parent_path = '/data/christine/esen_final_node_embedding'  # root of exported node embeddings

_toy_root = '/data/ishan-amin/OMOL/TOY/ligand_pocket_300'
_subset_root = '/data/ishan-amin/OMOL/4M/subsets/OMol_subset'

# Parallel lists: paths[k] is the ASE-DB source of split k and info_paths[k] is the
# name of that split's embedding directory under parent_path.
paths = [
    f'{_toy_root}/train',                               # plp300 train
    f'{_toy_root}/val',                                 # plp300 val
    f'{_subset_root}/protein_ligand_pockets_train_4M',  # plp 20k train
    f'{_subset_root}/protein_ligand_pockets_val',       # plp 20k val
]
info_paths = [
    f'plp{size}_{stage}_unshuffled'
    for size in ('300', '20k')
    for stage in ('train', 'eval')
]

path, info_path = paths[set_num], info_paths[set_num]

In [7]:
# Also attach the reference energy and forces to every converted sample.
# (Disabled option: "molecule_cell_size": 120.0)
a2g_args = dict(r_energy=True, r_forces=True)

dataset = AseDBDataset({"src": path, "a2g_args": a2g_args})
print(len(dataset))

92808


In [8]:
# Peek at system sizes: the atom count of each of the first few structures.
n_preview = 15
for sample_idx in range(n_preview):
    sample = dataset[sample_idx]
    print(len(sample.atomic_numbers))

306
297
273
273
247
247
205
205
147
147
100
100
328
266
266


In [59]:
# Minimum atom count over the whole split (kept commented out: it loads every sample).
# NOTE(review): the first-15 printout above already shows 100-atom systems in this split,
# and the stale cell output reads 35, so the "350" recorded here looks wrong — re-verify.
# min([len(dataset[i].atomic_numbers) for i in range(len(dataset))]) # 350 for plp20k train, also 350 for plp20k eval

35

In [9]:
# Load the exported per-batch node embeddings and split them into one tensor per
# structure. Batch b is stored as two files in whole_path:
#   x_message_{b}.pt -- one row of node features per atom in the batch
#   batch_{b}.pt     -- for each row of x_message, the index of its structure in the batch
whole_path = f'{parent_path}/{info_path}'
# Count only the x_message files rather than halving the directory listing, so any
# extra file in the directory cannot shift the batch count.
num_batches = sum(
    1 for fname in os.listdir(whole_path)
    if fname.startswith("x_message_") and fname.endswith(".pt")
)
x_all = []
for b in range(num_batches):
    x_message = torch.load(f"{whole_path}/x_message_{b}.pt", map_location="cpu")
    batch = torch.load(f"{whole_path}/batch_{b}.pt", map_location="cpu")
    # One tensor reduction instead of Python's builtin max(), which walks the tensor
    # element by element; guard against an empty batch file.
    num_graphs = int(batch.max()) + 1 if batch.numel() > 0 else 0
    x_all.extend(x_message[batch == g] for g in range(num_graphs))

# Every dataset structure should have exactly one embedding tensor.
assert len(x_all) == len(dataset), f"{len(x_all)} embeddings vs {len(dataset)} structures"
print(len(x_all))

92808


In [46]:
DIVERSITY_SEED = 0   # seeds the random starting atom so the selection is reproducible
POINTS_DIVISOR = 10  # select len(x) // POINTS_DIVISOR atoms per structure (at least 1)


def greedy_maxmin_indices(x, num_points, generator=None):
    """Greedy max-min (farthest-point) selection of diverse rows of ``x``.

    Starts from a random row, then repeatedly adds the row whose L2 distance to its
    nearest already-selected row is largest.

    Args:
        x: 2-D tensor, one embedding row per atom.
        num_points: number of indices to select. Values below 1 are treated as 1,
            matching the previous loop, which always kept the random start.
        generator: optional ``torch.Generator`` used for the random start.

    Returns:
        List of ``min(max(num_points, 1), len(x))`` distinct row indices in
        selection order; ``[]`` when ``x`` has no rows.
    """
    n = x.shape[0]
    if n == 0:
        return []
    num_points = min(max(num_points, 1), n)

    xf = x.float()
    # (n, n) pairwise distances without materialising an (n, n, dim) difference
    # tensor; the non-matmul mode avoids the matmul formulation's rounding error.
    dists = torch.cdist(xf, xf, compute_mode="donot_use_mm_for_euclid_dist")

    start = int(torch.randint(0, n, (1,), generator=generator))
    selected = [start]
    # min_dist[i] = distance from row i to its nearest selected row. Selected rows are
    # pinned to -inf so argmax never returns them again.
    min_dist = dists[start].clone()
    min_dist[start] = float("-inf")
    for _ in range(num_points - 1):
        next_idx = int(torch.argmax(min_dist))
        selected.append(next_idx)
        min_dist = torch.minimum(min_dist, dists[next_idx])
        min_dist[next_idx] = float("-inf")
    return selected


diversity_rng = torch.Generator().manual_seed(DIVERSITY_SEED)
diverse_idxs = [
    greedy_maxmin_indices(x, x.shape[0] // POINTS_DIVISOR, generator=diversity_rng)
    for x in x_all
]

print(len(diverse_idxs))

30
0


In [18]:
# Visualise one structure's selected atoms: selected atoms drawn opaque, the whole
# structure drawn faintly for context.
atom_num = 12
# Directory of per-structure .xyz files; must be entered manually.
# NOTE(review): this is hard-coded to plp300 eval and is not derived from set_num, so
# it only matches diverse_idxs when set_num selects that split.
xyz_dir = 'plp300_eval_data_unshuffled'
atoms = read(f'{xyz_dir}/{atom_num}.xyz')

closest_arr = diverse_idxs[atom_num]
# Catch a split mismatch between the .xyz files and the embeddings early.
assert all(i < len(atoms) for i in closest_arr), "selected index out of range for this structure"
comma_separated_string = ','.join(str(i) for i in closest_arr)

viewer = nv.show_ase(atoms)
viewer.clear()  # drop the default representations

# Selected atoms, chosen with NGL's '@' atom-index selection.
# NOTE(review): NGL documents '@' indices as 0-based (same as ASE); an earlier comment
# here claimed 1-based indexing — confirm visually.
viewer.add_representation('ball+stick', selection=f'@{comma_separated_string}', radius=0.2, bondRadius=0.01)
# The whole structure, mostly transparent, as context.
viewer.add_representation('ball+stick', selection='*', radius=0.2, bondRadius=0.01, opacity=0.15)

print(f'num atoms: {len(atoms)}')
display(viewer)

num atoms: 328


NGLWidget()

In [14]:
# Save the per-structure selections as one rectangular tensor. Structures have
# different atom counts, so the selections are ragged and torch.tensor cannot stack
# them directly; pad each row with -1 (the padding used in the saved diverse_idxs
# files inspected below).
PAD_IDX = -1
max_len = max((len(idxs) for idxs in diverse_idxs), default=0)
div_idxs_tensor = torch.tensor(
    [list(idxs) + [PAD_IDX] * (max_len - len(idxs)) for idxs in diverse_idxs],
    dtype=torch.long,
)

# Name the file after the selected split instead of hard-coding "plp300_eval".
# For info_path == 'plp300_eval_unshuffled' this reproduces the previous file name.
# NOTE(review): the "not_diverse_" prefix is kept from the original — confirm it is intended.
split_tag = info_path[:-len('_unshuffled')] if info_path.endswith('_unshuffled') else info_path
torch.save(div_idxs_tensor, f"not_diverse_idxs_{split_tag}.pt")

In [None]:
# Exploratory checks below; not part of the main pipeline above.

In [4]:
# Random baseline: NUM_RANDOM_ATOMS atom indices drawn uniformly without replacement
# for each of the first NUM_RANDOM_STRUCTURES structures, seeded for reproducibility.
NUM_RANDOM_STRUCTURES = 300
NUM_RANDOM_ATOMS = 15
RANDOM_SEED = 0
random_rng = torch.Generator().manual_seed(RANDOM_SEED)

random_indices_list = []
for i in range(min(NUM_RANDOM_STRUCTURES, len(dataset))):
    n = len(dataset[i].atomic_numbers)
    # torch.stack needs equal-length rows, and randperm(n)[:k] is shorter when n < k.
    assert n >= NUM_RANDOM_ATOMS, f"structure {i} has only {n} atoms (< {NUM_RANDOM_ATOMS})"
    indices = torch.randperm(n, generator=random_rng)[:NUM_RANDOM_ATOMS]
    random_indices_list.append(indices)

random_indices = torch.stack(random_indices_list)  # shape: [num_structures, NUM_RANDOM_ATOMS]
random_indices.shape

torch.Size([300, 15])

In [5]:
# torch.save(random_indices, 'random_idxs_plp300_eval.pt')

In [88]:
# Does any structure in the split have a truthy (non-zero) `charge`? Loads every sample.
has_charged_structure = any(dataset[idx].charge for idx in range(len(dataset)))
has_charged_structure

False

## Checking saved results

In [2]:
# Reload a previously saved selection tensor for inspection.
test = torch.load(
    os.path.join('/data/christine/esen_final_node_embedding', 'diverse_idxs_plp20k_train.pt')
)

In [5]:
# Peek at one saved row: valid atom indices first, then -1 padding (see output below).
test[1]

tensor([157,   8,  66,  44,  18,  61, 106, 109,  71,   9, 110,  94, 108, 112,
         60, 101,  74, 105,  -1,  -1,  -1,  -1,  -1,  -1,  -1,  -1,  -1,  -1,
         -1,  -1,  -1,  -1,  -1,  -1,  -1])

## Looking at Hessians (teacher force Jacobians)

In [22]:
# Teacher force-Jacobian labels for the plp300 toy set, stored in LMDB files.
teacher_labels_path = os.path.join(
    '/data/ishan-amin/OMOL/TOY/labels/ligand_pocket_300', 'force_jacobians'
)
labels = LmdbDataset(teacher_labels_path)
len(labels)

Total entries across all LMDB files jacs: 300


300

In [23]:
# Reshape the first sample's flat force Jacobian into 3x3 atom-pair blocks, presumably
# indexed (atom_i, xyz_i, atom_j, xyz_j). The atom count is derived from the tensor size
# instead of the hard-coded 338, so this works for any sample.
# Assumes labels[0] is a torch tensor (its slices print as tensors below).
flat_jac = labels[0]
n_jac_atoms = round((flat_jac.numel() / 9) ** 0.5)
assert 9 * n_jac_atoms ** 2 == flat_jac.numel(), (
    f"label has {flat_jac.numel()} entries, not (3N)^2 for any integer N"
)
jac = flat_jac.reshape(n_jac_atoms, 3, n_jac_atoms, 3)

In [30]:
# these entries are small, so the interactions between far apart atoms are not strong
jac[102, :, 14, :]

tensor([[ 0.0023, -0.0048, -0.0014],
        [-0.0008,  0.0006, -0.0011],
        [ 0.0011, -0.0059, -0.0027]])