In [1]:
import mygenai
from mygenai.models.graphvae import GraphVAE
from mygenai.utils.transforms import CompleteGraph, SetTarget, PadToFixedSize, ExtractFeatures
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from sklearn.model_selection import train_test_split

import torch
import torch_geometric
import torch_geometric.transforms
import numpy as np

print("PyTorch version {}".format(torch.__version__))
print("PyG version {}".format(torch_geometric.__version__))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device: {}".format(device))


PyTorch version 2.5.0+cu124
PyG version 2.6.1
Using device: cpu


In [2]:
# Transforms which are applied during data loading:
# (1) Fully connect the graphs, (2) Select the target/label

transform = torch_geometric.transforms.Compose([
        ExtractFeatures(),
        PadToFixedSize(),
        CompleteGraph(),
        SetTarget()
    ])
target = 4

# Load the QM9 dataset with the transforms defined
dataset = QM9("../data/QM9", transform=transform)

# Normalize targets per data sample to mean = 0 and std = 1.
mean = dataset.data.y.mean(dim=0, keepdim=True)
std = dataset.data.y.std(dim=0, keepdim=True)
dataset.data.y = (dataset.data.y - mean) / std
mean, std = mean[:, target].item(), std[:, target].item()
# dataset = dataset[1000]
print(mean)

6.858491897583008




In [3]:
print(f"Total number of samples: {len(dataset)}.")

# let's just use the first 3000 samples

# Split datasets (our 3K subset)
train_dataset = dataset[:1000]
val_dataset = dataset[1000:2000]
test_dataset = dataset[2000:3000]
print(f"Created dataset splits with {len(train_dataset)} training, {len(val_dataset)} validation, {len(test_dataset)} test samples.")

# Create dataloaders with batch size = 32
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

Total number of samples: 130831.
Created dataset splits with 1000 training, 1000 validation, 1000 test samples.


In [4]:
water = dataset[2]

BOND_TYPES = {
    0: "single",
    1: "double",
    2: "triple",
    3: "aromatic",
    4: "no bond"
}

water = water.to(device)
# print("Water molecule edge attributes: ", water.edge_attr)
# print("Water molecule node attributes: ", water.x)
# print("Water molecule edge indices: ", water.edge_index)
from torch_geometric.utils import to_dense_adj
# create ground-truth adjacency matrix
adj = to_dense_adj(water.edge_index, batch=water.batch, edge_attr=water.edge_attr)
# print("Ground-truth adjacency matrix: ", adj)
# water.x
water.edge_attr

# check the raw edge_index and edge_attr for water
# for some reason H1 and H2 are connected with a single bond...
print("Raw water molecule edges:")
for i in range(water.edge_index.shape[1]):
    src = water.edge_index[0, i].item()
    dst = water.edge_index[1, i].item()
    attr = water.edge_attr[i].argmax().item()
    bond_type = BOND_TYPES[attr]
    if src < 3 and dst < 3:  # Only real atoms
        print(f"Atom {src} - Atom {dst}: {bond_type}")

# Look at dense adjacency matrix directly
dense_adj = to_dense_adj(water.edge_index, batch=water.batch, edge_attr=water.edge_attr)[0]
print("Dense adjacency for H1-H2 connection:")
print(dense_adj[1, 2])  # Should be [0,0,0,0,1] for "no bond"

water.num_real_atoms

Raw water molecule edges:
Atom 0 - Atom 1: single
Atom 0 - Atom 2: single
Atom 1 - Atom 0: single
Atom 1 - Atom 2: no bond
Atom 2 - Atom 0: single
Atom 2 - Atom 1: no bond
Dense adjacency for H1-H2 connection:
tensor([0., 0., 0., 0., 1.])


3

In [5]:
model = GraphVAE().to(device)

In [6]:
# test forward passs
batch = next(iter(train_loader))
batch = batch.to(device)
with torch.no_grad():
    outputs = model(batch)
print("Forward pass successful!")

Forward pass successful!


In [7]:
import mygenai.training.training as training
training.train_model(model, train_loader, val_loader, device)

Epoch 000 | Train Loss: 0.5110 | Val Loss: 0.2155
Epoch 001 | Train Loss: 0.1824 | Val Loss: 0.1464
Epoch 002 | Train Loss: 0.1570 | Val Loss: 0.1345
Epoch 003 | Train Loss: 0.1393 | Val Loss: 0.1220
Epoch 004 | Train Loss: 0.1240 | Val Loss: 0.1195
Epoch 005 | Train Loss: 0.1205 | Val Loss: 0.1177
Epoch 006 | Train Loss: 0.1189 | Val Loss: 0.1183
Epoch 007 | Train Loss: 0.1186 | Val Loss: 0.1156
Epoch 008 | Train Loss: 0.1184 | Val Loss: 0.1169
Epoch 009 | Train Loss: 0.1187 | Val Loss: 0.1172
Epoch 010 | Train Loss: 0.1176 | Val Loss: 0.1208
Epoch 011 | Train Loss: 0.1171 | Val Loss: 0.1181
Epoch 012 | Train Loss: 0.1168 | Val Loss: 0.1183
Epoch 013 | Train Loss: 0.1172 | Val Loss: 0.1184
Epoch 014 | Train Loss: 0.1177 | Val Loss: 0.1219
Epoch 015 | Train Loss: 0.1172 | Val Loss: 0.1189
Epoch 016 | Train Loss: 0.1178 | Val Loss: 0.1198
Epoch 017 | Train Loss: 0.1172 | Val Loss: 0.1180
Early stopping triggered at epoch 18
Training finished


In [8]:
import torch.nn.functional as F

# feed water to the model
with torch.no_grad():
    outputs = model(water)
edge_attr_logits, mu, logvar, property_pred = outputs

edge_probs = F.softmax(edge_attr_logits, dim=-1)

# Now you can look at the probabilities for each edge
print("Probabilities for edge between atoms 0 and 1:")
print(edge_probs[0, 0, 1])  # First batch item, edge from atom 0 to 1

# Get the most likely bond type
most_likely_type = edge_probs[0, 0, 1].argmax().item()
print(f"Most likely bond type: {most_likely_type}")

# Print confidence in prediction
confidence = edge_probs[0, 0, 1, most_likely_type].item()
print(f"Confidence: {confidence:.4f}")

print("Probabilities for edge between (padding) atoms 5 and 10:")
print(edge_probs[0, 5, 10])

print("Probabilities for edge between (padding) atom 8 and (real) atom 1:")
print(edge_probs[0, 8, 1])

Probabilities for edge between atoms 0 and 1:
tensor([8.5438e-01, 7.7687e-02, 5.1896e-02, 2.7549e-04, 1.5765e-02])
Most likely bond type: 0
Confidence: 0.8544
Probabilities for edge between (padding) atoms 5 and 10:
tensor([0., 0., 0., 0., 1.])
Probabilities for edge between (padding) atom 8 and (real) atom 1:
tensor([0., 0., 0., 0., 1.])


In [9]:

# first check that each of the "no atom" nodes have "no bond" edges
# first three nodes are O, H, H
for i in range(3, water.x.shape[0]):
    for j in range(water.x.shape[0]):
        # should always be [0., 0., 0., 0., 0., 1.]
        most_likely_type = edge_probs[0, i, j].argmax().item()
        if most_likely_type != 4: # no bond
            print(f"Unexpected bond type ({i}-{j}): {BOND_TYPES[most_likely_type]}")


In [10]:
# check single bonds between O and each H
o_to_h_1 = edge_probs[0, 0, 1].argmax().item()
o_to_h_2 = edge_probs[0, 0, 2].argmax().item()
h_to_o_1 = edge_probs[0, 1, 0].argmax().item()
h_to_o_2 = edge_probs[0, 2, 0].argmax().item()
print(f"Bond type between O and H1: {BOND_TYPES[o_to_h_1]}")
print(f"Bond type between O and H2: {BOND_TYPES[o_to_h_2]}")
print(f"Bond type between H1 and O: {BOND_TYPES[h_to_o_1]}")
print(f"Bond type between H2 and O: {BOND_TYPES[h_to_o_2]}")

# check no bond between H and H
# looks like the model is predicting a single bond between H1 and H2 :(
h_to_h = edge_probs[0, 1, 2].argmax().item()
print(f"Bond type between H1 and H2: {BOND_TYPES[h_to_h]}")

Bond type between O and H1: single
Bond type between O and H2: no bond
Bond type between H1 and O: single
Bond type between H2 and O: no bond
Bond type between H1 and H2: single


In [11]:
# sanity check: intensely train the model on *only* the water molecule
# this should make the model overfit to it and predict the correct bond types
# if it does that, we know the problem is in the model
# the most likely culprit is the overwhelming number of "no atom" nodes and
# "no bond" edges
import copy
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F

# Create a dataloader with just the water molecule (repeating it to match API)
def create_water_loader(water):
    from torch_geometric.data import Batch
    # Create a batch with just the water molecule repeated
    water_batch = Batch.from_data_list([water] * 4)  # Batch size of 4

    # Create a simple loader that just returns this batch
    class WaterLoader:
        def __iter__(self):
            yield water_batch
        def __len__(self):
            return 1

    return WaterLoader()

water_loader = create_water_loader(water)

# Get the ground truth adjacency matrix
def get_ground_truth_bonds(molecule, natom):
    print("Ground truth adjacency matrix for water:")
    edges = molecule.edge_index.cpu().numpy()
    attrs = molecule.edge_attr.cpu().numpy()

    # Print in human-readable format
    print("Edges:")
    for i in range(min(10, edges.shape[1])):
        src, dst = edges[0, i], edges[1, i]
        if src < natom and dst < natom:  # Only real atoms
            attr_idx = attrs[i].argmax()
            bond_type = BOND_TYPES[attr_idx]
            print(f"  Atom {src} - Atom {dst}: {bond_type}")

    return edges, attrs

get_ground_truth_bonds(water, 3) # Water has 3 atoms: O, H, H

print("\nOverfitting model to water molecule...")
water_model = copy.deepcopy(model)
training.train_model(
    water_model,
    water_loader,
    water_loader,
    device,
    n_epochs=200,
    patience=1000           # prevent early stopping
)

# Check predictions from overfitted model
with torch.no_grad():
    edge_logits, mu, logvar, property_pred = water_model(water)
    edge_probs = F.softmax(edge_logits, dim=-1)

print("\nBond predictions after overfitting:")
print(f"Bond type between O and H1: {BOND_TYPES[edge_probs[0, 0, 1].argmax().item()]}")
print(f"Bond type between O and H2: {BOND_TYPES[edge_probs[0, 0, 2].argmax().item()]}")
print(f"Bond type between H1 and H2: {BOND_TYPES[edge_probs[0, 1, 2].argmax().item()]}")
print("Checking for spurious bonding to non-atom nodes...")
for i in range(3, water.x.shape[0]):
    for j in range(water.x.shape[0]):
        # should always be [0., 0., 0., 0., 0., 1.]
        most_likely_type = edge_probs[0, i, j].argmax().item()
        if most_likely_type != 4: # no bond
            print(f"Unexpected bond type ({i}-{j}): {BOND_TYPES[most_likely_type]}")

# Show probabilities for key bonds
print("\nProbabilities:")
print(f"O-H1: {[round(x.item(), 4) for x in edge_probs[0, 0, 1]]}")
print(f"O-H2: {[round(x.item(), 4) for x in edge_probs[0, 0, 2]]}")
print(f"H1-H2: {[round(x.item(), 4) for x in edge_probs[0, 1, 2]]}")

# Simple visualization of predictions vs ground truth
def compare_water_bonds(edge_probs):
    """Compare predicted bonds against water's known structure"""
    structure = [
        ["", "O", "H1", "H2"],
        ["O", "-", "", ""],
        ["H1", "", "-", ""],
        ["H2", "", "", "-"]
    ]

    # Fill in predicted bonds
    for i in range(1, 4):
        for j in range(1, 4):
            if i != j:
                src, dst = i-1, j-1
                bond_type = BOND_TYPES[edge_probs[0, src, dst].argmax().item()]
                if bond_type != "no bond":
                    structure[i][j] = bond_type

    # Print table
    print("\nPredicted water molecule structure:")
    for row in structure:
        print("  ".join(f"{cell:8s}" for cell in row))

compare_water_bonds(edge_probs)

Ground truth adjacency matrix for water:
Edges:
  Atom 0 - Atom 1: single
  Atom 0 - Atom 2: single

Overfitting model to water molecule...
Epoch 000 | Train Loss: 0.6018 | Val Loss: 0.4967
Epoch 001 | Train Loss: 0.5026 | Val Loss: 0.4254
Epoch 002 | Train Loss: 0.4168 | Val Loss: 0.3621
Epoch 003 | Train Loss: 0.3712 | Val Loss: 0.3225
Epoch 004 | Train Loss: 0.3278 | Val Loss: 0.2888
Epoch 005 | Train Loss: 0.2918 | Val Loss: 0.2629
Epoch 006 | Train Loss: 0.2591 | Val Loss: 0.2288
Epoch 007 | Train Loss: 0.2308 | Val Loss: 0.2001
Epoch 008 | Train Loss: 0.1983 | Val Loss: 0.1697
Epoch 009 | Train Loss: 0.1727 | Val Loss: 0.1440
Epoch 010 | Train Loss: 0.1426 | Val Loss: 0.1137
Epoch 011 | Train Loss: 0.1156 | Val Loss: 0.0925
Epoch 012 | Train Loss: 0.0942 | Val Loss: 0.0711
Epoch 013 | Train Loss: 0.0716 | Val Loss: 0.0522
Epoch 014 | Train Loss: 0.0523 | Val Loss: 0.0378
Epoch 015 | Train Loss: 0.0383 | Val Loss: 0.0267
Epoch 016 | Train Loss: 0.0271 | Val Loss: 0.0180
Epoch 017 