In [1]:
import os
import torch
import nglview as nv
from ase.io import read
from fairchem.core.datasets.ase_datasets import AseDBDataset



## Protein ligand pocket 300 val

In [2]:
a2g_args = {  
    # "molecule_cell_size": 120.0,
    "r_energy": True,
    "r_forces": True,
}
path = '/data/ishan-amin/OMOL/TOY/ligand_pocket_300/val'
dataset = AseDBDataset({
                        "src": path,
                        "a2g_args": a2g_args,
                        })
print(len(dataset))

300


In [3]:
for i in range(15):
    print(len(dataset[i].atomic_numbers))

306
297
273
273
247
247
205
205
147
147
100
100
328
266
266


In [4]:
random_indices_list = []
for i in range(300):
    n = len(dataset[i].atomic_numbers)
    indices = torch.randperm(n)[:15]
    random_indices_list.append(indices)

random_indices = torch.stack(random_indices_list)  # Shape: [300, 15]
random_indices.shape

torch.Size([300, 15])

In [5]:
# torch.save(random_indices, 'random_idxs_plp300_eval.pt')

In [88]:
any(dataset[i].charge for i in range(len(dataset)))

False

In [6]:
num_batches = 75
counter = 0
x_all = []
for b in range(num_batches):
    x_message = torch.load(f"plp300_eval_unshuffled/x_message_{b}.pt", map_location="cpu")
    batch = torch.load(f"plp300_eval_unshuffled/batch_{b}.pt", map_location="cpu")
    x_parts = [x_message[batch == i] for i in range(max(batch) + 1)]
    x_all.extend(x_parts)

print(len(x_all))

300


In [7]:
diverse_idxs = []
num_points = 15
for x in x_all:
    # Step 2: Compute pairwise L2 distance matrix [201, 201]
    # (Optional: use other distances like cosine if preferred)
    diffs = x[:, None, :] - x[None, :, :]  # shape (201, 201, 128)
    dists = torch.norm(diffs, dim=2)       # shape (201, 201)

    # Step 3: Greedy Max-Min Diversity selection
    selected = [torch.randint(0, len(x), (1,)).item()]  # Start from a random point
    remaining = set(range(len(x))) - set(selected)

    for _ in range(num_points - 1):
        # For each remaining candidate, get the distance to the closest already selected point
        candidate_to_min_dist = [
            (i, torch.min(dists[i, selected]).item()) for i in remaining
        ]
        # Pick the candidate with the **maximum** of these minimum distances
        next_idx = max(candidate_to_min_dist, key=lambda t: t[1])[0]
        selected.append(next_idx)
        remaining.remove(next_idx)

    diverse_idxs.append(selected)

print(len(diverse_idxs))

300


In [18]:
atom_num = 12
atoms = read(f'plp300_eval_data_unshuffled/{atom_num}.xyz')

closest_arr = diverse_idxs[atom_num]
comma_separated_string = ','.join(str(i) for i in closest_arr)

viewer = nv.show_ase(atoms)

# Clear default representations
viewer.clear()

# Add spacefill for magnesium
# viewer.add_representation('spacefill', selection='_Mg', radius=radius, opacity=0.5)

# Add ball+stick for all atoms except index 4 (NGLView uses 1-based indexing, so index 4 -> 5)
arr = str([_ for _ in range(100)])

viewer.add_representation('ball+stick', selection=f'@{comma_separated_string}', radius=0.2, bondRadius=0.01)
viewer.add_representation('ball+stick', selection='*', radius=0.2, bondRadius=0.01, opacity=0.15)

print(f'num atoms: {len(atoms)}')
# Display the viewer
display(viewer)

num atoms: 328


NGLWidget()

In [14]:
div_idxs_tensor = torch.tensor(diverse_idxs)
torch.save(div_idxs_tensor, "not_diverse_idxs_plp300_eval.pt")

## Solvated

In [19]:
a2g_args = {  
    # "molecule_cell_size": 120.0,
    "r_energy": True,
    "r_forces": True,
}
path = '/data/ishan-amin/OMOL/TOY/solvated/val'
dataset = AseDBDataset({
                        "src": path,
                        "a2g_args": a2g_args,
                        })
print(len(dataset))

143


In [20]:
for i in range(15):
    print(len(dataset[i].atomic_numbers))

89
96
85
88
93
91
88
89
89
85
85
89
83
89
92


In [22]:
num_batches = 13
counter = 0
x_all = []
for b in range(num_batches):
    x_message = torch.load(f"solvated_unshuffled/x_message_{b}.pt", map_location="cpu")
    batch = torch.load(f"solvated_unshuffled/batch_{b}.pt", map_location="cpu")
    x_parts = [x_message[batch == i] for i in range(max(batch) + 1)]
    x_all.extend(x_parts)

print(len(x_all))

143


In [23]:
diverse_idxs = []
num_points = 10
for x in x_all:
    # Step 2: Compute pairwise L2 distance matrix [201, 201]
    # (Optional: use other distances like cosine if preferred)
    diffs = x[:, None, :] - x[None, :, :]  # shape (201, 201, 128)
    dists = torch.norm(diffs, dim=2)       # shape (201, 201)

    # Step 3: Greedy Max-Min Diversity selection
    selected = [torch.randint(0, len(x), (1,)).item()]  # Start from a random point
    remaining = set(range(len(x))) - set(selected)

    for _ in range(num_points - 1):
        # For each remaining candidate, get the distance to the closest already selected point
        candidate_to_min_dist = [
            (i, torch.min(dists[i, selected]).item()) for i in remaining
        ]
        # Pick the candidate with the **maximum** of these minimum distances
        next_idx = max(candidate_to_min_dist, key=lambda t: t[1])[0]
        selected.append(next_idx)
        remaining.remove(next_idx)

    diverse_idxs.append(selected)

print(len(diverse_idxs))

143


In [27]:
atom_num = 1
atoms = read(f'solvated_data_unshuffled/{atom_num}.xyz')

closest_arr = diverse_idxs[atom_num]
comma_separated_string = ','.join(str(i) for i in closest_arr)

viewer = nv.show_ase(atoms)

# Clear default representations
viewer.clear()

# Add spacefill for magnesium
# viewer.add_representation('spacefill', selection='_Mg', radius=radius, opacity=0.5)

# Add ball+stick for all atoms except index 4 (NGLView uses 1-based indexing, so index 4 -> 5)
arr = str([_ for _ in range(100)])

viewer.add_representation('ball+stick', selection=f'@{comma_separated_string}', radius=0.2, bondRadius=0.01)
viewer.add_representation('ball+stick', selection='*', radius=0.2, bondRadius=0.01, opacity=0.15)

print(f'num atoms: {len(atoms)}')
# Display the viewer
display(viewer)

num atoms: 96


NGLWidget()

In [26]:
atoms.get_chemical_formula()

'C9H54N4O22'

In [28]:
div_idxs_tensor = torch.tensor(diverse_idxs)
torch.save(div_idxs_tensor, "solvated_eval_diverse_idxs.pt")

## protein ligand pocket 300 train

In [28]:
a2g_args = {  
    # "molecule_cell_size": 120.0,
    "r_energy": True,
    "r_forces": True,
}
path = '/data/ishan-amin/OMOL/TOY/ligand_pocket_300/train'
dataset = AseDBDataset({
                        "src": path,
                        "a2g_args": a2g_args,
                        })
print(len(dataset))

300


In [29]:
for i in range(15):
    print(len(dataset[i].atomic_numbers))

338
187
90
158
72
190
244
244
302
78
250
273
95
327
195


In [30]:
random_indices_list = []
for i in range(300):
    n = len(dataset[i].atomic_numbers)
    indices = torch.randperm(n)[:15]
    random_indices_list.append(indices)

random_indices = torch.stack(random_indices_list)  # Shape: [300, 15]
random_indices.shape

torch.Size([300, 15])

In [31]:
torch.save(random_indices, "random_idxs_plp300_train.pt")

In [4]:
num_batches = 69
counter = 0
x_all = []
for b in range(num_batches):
    x_message = torch.load(f"plp300_train_unshuffled/x_message_{b}.pt", map_location="cpu")
    batch = torch.load(f"plp300_train_unshuffled/batch_{b}.pt", map_location="cpu")
    x_parts = [x_message[batch == i] for i in range(max(batch) + 1)]
    x_all.extend(x_parts)

print(len(x_all))

300


In [None]:
diverse_idxs = []
num_points = 15
for x in x_all:
    # Step 2: Compute pairwise L2 distance matrix [201, 201]
    # (Optional: use other distances like cosine if preferred)
    diffs = x[:, None, :] - x[None, :, :]  # shape (201, 201, 128)
    dists = torch.norm(diffs, dim=2)       # shape (201, 201)

    # Step 3: Greedy Max-Min Diversity selection
    selected = [torch.randint(0, len(x), (1,)).item()]  # Start from a random point
    remaining = set(range(len(x))) - set(selected)

    for _ in range(num_points - 1):
        # For each remaining candidate, get the distance to the closest already selected point
        candidate_to_min_dist = [
            (i, torch.min(dists[i, selected]).item()) for i in remaining
        ]
        # Pick the candidate with the **maximum** of these minimum distances
        next_idx = max(candidate_to_min_dist, key=lambda t: t[1])[0] # change this to min for the least diverse
        selected.append(next_idx)
        remaining.remove(next_idx)

    diverse_idxs.append(selected)

print(len(diverse_idxs))

300


In [6]:
atom_num = 1
atoms = read(f'plp300_train_data_unshuffled/{atom_num}.xyz')

closest_arr = diverse_idxs[atom_num]
comma_separated_string = ','.join(str(i) for i in closest_arr)

viewer = nv.show_ase(atoms)

# Clear default representations
viewer.clear()

# Add spacefill for magnesium
# viewer.add_representation('spacefill', selection='_Mg', radius=radius, opacity=0.5)

# Add ball+stick for all atoms except index 4 (NGLView uses 1-based indexing, so index 4 -> 5)
arr = str([_ for _ in range(100)])

viewer.add_representation('ball+stick', selection=f'@{comma_separated_string}', radius=0.2, bondRadius=0.01)
viewer.add_representation('ball+stick', selection='*', radius=0.2, bondRadius=0.01, opacity=0.15)

print(f'num atoms: {len(atoms)}')
# Display the viewer
display(viewer)

num atoms: 187


NGLWidget()

In [7]:
div_idxs_tensor = torch.tensor(diverse_idxs)
torch.save(div_idxs_tensor, "not_diverse_idxs_plp300_train.pt")

In [25]:
div_idxs_tensor

tensor([[318, 101, 316,  ..., 191, 291, 309],
        [ 99, 142,  66,  ..., 108,  89,  60],
        [ 16,  12,   9,  ...,  43,  54,  23],
        ...,
        [142, 117,  55,  ..., 140, 195, 150],
        [121,  59, 104,  ...,   2, 110, 125],
        [ 10, 271, 122,  ..., 188, 247, 252]])

In [59]:
atoms = read(f'plp300_train_data_unshuffled/1.xyz')

In [65]:
res = []
for i in range(300):
    atoms = read(f'plp300_train_data_unshuffled/{i}.xyz')
    res.append(any(atoms.get_initial_charges()))
any(res)

False

In [75]:
res = []
folder = 'solvated_data_unshuffled'
for i in range(len(os.listdir(folder))):
    atoms = read(f'{folder}/{i}.xyz')
    res.append(any(atoms.get_initial_charges()))
any(res)

False

In [89]:
atoms.get_partial_charges()

AttributeError: 'Atoms' object has no attribute 'get_partial_charges'