In [1]:
import torch
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from torch.utils.data import random_split
import torch.nn as nn

# Original dataset
dataset = QM9(root='data/QM9')

# New dataset with molecules of less than 5 atoms
small_dataset = [data for data in dataset if data.num_nodes < 10]

unique_atoms = set()
for data in dataset:
    unique_atoms.update(data.z.tolist())

print(f"Number of unique atoms in the original dataset: {len(unique_atoms)}")
print(f"Unique atomic numbers: {sorted(list(unique_atoms))}")

# Original dataset statistics
total_size = len(dataset)
train_size = int(0.8 * total_size)
val_size = int(0.1 * total_size)
test_size = total_size - train_size - val_size
print(f"Original dataset - Train size: {train_size}, Val size: {val_size}, Test size: {test_size}")

# Small dataset statistics
small_total_size = len(small_dataset)
small_train_size = int(0.8 * small_total_size)
small_val_size = int(0.1 * small_total_size)
small_test_size = small_total_size - small_train_size - small_val_size
print(f"Small dataset - Train size: {small_train_size}, Val size: {small_val_size}, Test size: {small_test_size}")

# Split original dataset
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

# Split small dataset
small_train_dataset, small_val_dataset, small_test_dataset = random_split(small_dataset, [small_train_size, small_val_size, small_test_size])

batch_size = 128

# DataLoaders for original dataset
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

# DataLoaders for small dataset
small_train_loader = DataLoader(small_train_dataset, batch_size=batch_size, shuffle=True)
small_val_loader = DataLoader(small_val_dataset, batch_size=batch_size)
small_test_loader = DataLoader(small_test_dataset, batch_size=batch_size)

for batch in train_loader:
    break

for small_batch in small_train_loader:
    break

batch = small_batch

Number of unique atoms in the original dataset: 5
Unique atomic numbers: [1, 6, 7, 8, 9]
Original dataset - Train size: 104664, Val size: 13083, Test size: 13084
Small dataset - Train size: 224, Val size: 28, Test size: 28


In [3]:
batch[molecule_index]

Data(x=[9, 11], edge_index=[2, 18], edge_attr=[18, 4], y=[1, 19], pos=[9, 3], z=[9], smiles='[H][N-]C1ON=NC(=O)O1', name='gdb_21441', idx=[1])

In [7]:
def calculate_rmsd(molecule_index, r):
    # Calculate RMSD between predicted and actual positions using Kabsch algorithm
    # First center both sets of coordinates by subtracting mean
    r_centered = r - r.mean(dim=0, keepdim=True)
    pos_centered = batch[molecule_index].pos - batch[molecule_index].pos.mean(dim=0, keepdim=True)

    # Calculate covariance matrix
    covariance = torch.matmul(r_centered.t(), pos_centered)

    # Perform SVD
    U, S, Vt = torch.linalg.svd(covariance)

    # Calculate optimal rotation matrix
    # Handle reflection case by ensuring right-handed coordinate system
    d = torch.linalg.det(torch.matmul(Vt.t(), U.t()))
    reflection_matrix = torch.eye(3)
    reflection_matrix[-1, -1] = d
    R = torch.matmul(torch.matmul(Vt.t(), reflection_matrix), U.t())

    # Apply rotation to centered coordinates
    r_aligned = torch.matmul(r_centered, R)

    # Calculate RMSD
    squared_diff = torch.sum((r_aligned - pos_centered) ** 2, dim=1)
    rmsd = torch.sqrt(torch.mean(squared_diff))

    return rmsd

def bond_length_loss(pred_coords, target_coords, bond_indices):
    pred_bond_lengths = torch.norm(pred_coords[bond_indices[:, 0]] - pred_coords[bond_indices[:, 1]], dim=1)
    target_bond_lengths = torch.norm(target_coords[bond_indices[:, 0]] - target_coords[bond_indices[:, 1]], dim=1)
    return torch.mean((pred_bond_lengths - target_bond_lengths) ** 2)


In [14]:
net = MolecularShapePredictor(embedding_dim=256, num_layers=10)
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

for e in range(100):
    loss = 0
    for i in range(20):
        # Initialize positions by placing bonded atoms 1 unit apart at random angles
        initial_r = torch.zeros_like(batch[i].pos)
        edge_index = batch[i].edge_index
        placed = torch.zeros(batch[i].num_nodes, dtype=torch.bool)
        placed[0] = True  # Start with first atom at origin
        
        # Process each edge to position atoms
        # for src, dst in edge_index.t():
            # if placed[src] and not placed[dst]:
            #     # Generate random angles
            #     theta = torch.rand(1) * 2 * torch.pi  # Random angle in xy plane
            #     phi = torch.rand(1) * torch.pi        # Random angle from z axis
                
            #     # Calculate unit vector in random direction
            #     x = torch.sin(phi) * torch.cos(theta)
            #     y = torch.sin(phi) * torch.sin(theta)
            #     z = torch.cos(phi)
            #     direction = torch.tensor([x, y, z])
                
            #     # Place atom 1 unit away from connected atom
            #     initial_r[dst] = initial_r[src] + direction
            #     placed[dst] = True
            # elif placed[dst] and not placed[src]:
            #     # Same process but in reverse direction
            #     theta = torch.rand(1) * 2 * torch.pi
            #     phi = torch.rand(1) * torch.pi
                
            #     x = torch.sin(phi) * torch.cos(theta)
            #     y = torch.sin(phi) * torch.sin(theta)
            #     z = torch.cos(phi)
            #     direction = torch.tensor([x, y, z])
                
            #     initial_r[src] = initial_r[dst] + direction
            #     placed[src] = True
        initial_r = batch[i].pos + torch.randn_like(batch[i].pos) * 0.7
        adj_matrix = create_adjacency_matrix(batch, i)
        h, r = net(batch[i].z, initial_r, adj_matrix)
        loss += .3*calculate_rmsd(i, r) + bond_length_loss(r, batch[i].pos, batch[i].edge_index.t())

    loss = loss/100
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"Iteration {e+1}, Loss: {loss.item():.4f}")


Iteration 1, Loss: 0.2902
Iteration 2, Loss: 0.2790
Iteration 3, Loss: 0.3232
Iteration 4, Loss: 0.2780
Iteration 5, Loss: 0.3477
Iteration 6, Loss: 0.3159
Iteration 7, Loss: 0.3127
Iteration 8, Loss: 0.2904
Iteration 9, Loss: 0.2970
Iteration 10, Loss: 0.2879
Iteration 11, Loss: 0.3258
Iteration 12, Loss: 0.2964
Iteration 13, Loss: 0.3152
Iteration 14, Loss: 0.3152
Iteration 15, Loss: 0.3291
Iteration 16, Loss: 0.3552
Iteration 17, Loss: 0.2903
Iteration 18, Loss: 0.3010
Iteration 19, Loss: 0.3277
Iteration 20, Loss: 0.2853
Iteration 21, Loss: 0.3201
Iteration 22, Loss: 0.3156
Iteration 23, Loss: 0.2827
Iteration 24, Loss: 0.3286
Iteration 25, Loss: 0.3080
Iteration 26, Loss: 0.2961
Iteration 27, Loss: 0.3242
Iteration 28, Loss: 0.2762
Iteration 29, Loss: 0.2757
Iteration 30, Loss: 0.3336
Iteration 31, Loss: 0.3300
Iteration 32, Loss: 0.2903
Iteration 33, Loss: 0.3395
Iteration 34, Loss: 0.2941
Iteration 35, Loss: 0.2648
Iteration 36, Loss: 0.3087
Iteration 37, Loss: 0.2813
Iteration 

In [9]:
visualize_molecule(batch, 12)

SMILES code: [H]C1=C(N(=O)=O)ON=N1


tensor([8, 7, 8, 6, 6, 7, 7, 8, 1])

Graph information:
Number of nodes: 9
Number of edges: 18
Node features shape: torch.Size([9, 11])


In [10]:
initial_r = batch[12].pos + torch.randn_like(batch[12].pos) * 0.1
h, pos = net(batch[12].z, initial_r, create_adjacency_matrix(batch, 12))
visualize_molecule(batch, 12, pos)


SMILES code: [H]C1=C(N(=O)=O)ON=N1


tensor([8, 7, 8, 6, 6, 7, 7, 8, 1])

Graph information:
Number of nodes: 9
Number of edges: 18
Node features shape: torch.Size([9, 11])


In [11]:
calculate_rmsd(12, pos)

tensor(0.2878, grad_fn=<SqrtBackward0>)