# Message-Passing Graph Variational Autoencoder

Here I will use principles from message-passing graph neural networks to try to generate
molecules in a graph variational autoencoder. This is based on my solutions to 
`geometric-gnn-dojo/geometric_gnn_101.ipynb`
Some of the code has also been taken from there

might also want to consider autoregressive model

In [1]:
import importlib
import logging
import time
from mygenai.models.graphvae import PropertyConditionedVAE
# importlib.reload(PropertyConditionedVAE)

import torch
torch.cuda.empty_cache()
import torch_geometric
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

print("PyTorch version {}".format(torch.__version__))
print("PyG version {}".format(torch_geometric.__version__))

PyTorch version 2.5.0+cu124
PyG version 2.6.1


In [2]:

debug_print = False
if debug_print:
    print("Debug mode is ON")
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        force=True  # Ensure configuration is applied
    )
    for logger_name in ['train_epoch', 'PropertyConditionedVAE', 'ConditionalDecoder', 'Encoder']:
        logging.getLogger(logger_name).setLevel(logging.DEBUG)
else:
    print("Debug mode is OFF")
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        force=True  # Ensure configuration is applied
    )
    for logger_name in ['train_epoch', 'PropertyConditionedVAE', 'ConditionalDecoder', 'Encoder']:
        logging.getLogger(logger_name).setLevel(logging.INFO)

Debug mode is OFF


In [None]:
import torch
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from sklearn.model_selection import train_test_split
import numpy as np
from mygenai.utils.transforms import CompleteGraph
from mygenai.utils.transforms import AddEdgeExistence

dataset = QM9(root="../data/QM9", transform=AddEdgeExistence()) #, transform=CompleteGraph())
# Normalize targets per data sample to mean = 0 and std = 1.
mean = dataset.data.y.mean(dim=0, keepdim=True)
std = dataset.data.y.std(dim=0, keepdim=True)
dataset.data.y = (dataset.data.y - mean) / std
# focus on just using the one-hot encoding of the atomic number, for simplicity for now
dataset.data.x = dataset.data.x[:, :5]

# Normalize distances in the dataset
fixed_max_distance = 2.0

# TODO in the future maybe use a Z matrix representation for positions
# def normalize_distances(dataset, max_distance):
#     for data in dataset:
#         pos = data.pos  # (n, 3) - absolute coordinates
#         src, dst = data.edge_index  # (2, num_edges) - edge indices
#         relative_positions = pos[dst] - pos[src]  # (num_edges, 3)
#         distances = torch.norm(relative_positions, dim=1)  # (num_edges,)
#         data.normalized_distances = distances / max_distance  # Normalize distances
#     return dataset

# dataset = normalize_distances(dataset, fixed_max_distance)
# min_atomic_number = 1
# max_atomic_number = 9
# dataset.data.z = (dataset.data.z - min_atomic_number) / (max_atomic_number - min_atomic_number) # doesn't actually matter because it's not used (this information is determined by data.x)



In [4]:
# Data splitting (60/20/20)
train_val_idx, test_idx = train_test_split(
    np.arange(len(dataset)),
    test_size=0.2,
    random_state=42
)
train_idx, val_idx = train_test_split(
    train_val_idx,
    test_size=0.25,
    random_state=42
)

train_loader = DataLoader(dataset[train_idx], batch_size=128, shuffle=True)
val_loader = DataLoader(dataset[val_idx], batch_size=128, shuffle=False)
test_loader = DataLoader(dataset[test_idx], batch_size=128, shuffle=False)

In [5]:
# Training setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = PropertyConditionedVAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-6
)

In [6]:
# test forward passs
batch = next(iter(train_loader))
batch = batch.to(device)
with torch.no_grad():
    outputs = model(batch)
print("Forward pass successful!")

Forward pass successful!


In [7]:
# # check if training sets are reasonably balanced

# def basic_homo_lumo_stats(loader, name):
#     total_nodes = 0
#     total_graphs = 0
#     prop_values = []

#     for batch in loader:
#         total_graphs += batch.batch.max().item() + 1
#         total_nodes += batch.x.shape[0]
#         prop_values.append(batch.y[:, 4].cpu().numpy())

#     prop_values = np.concatenate(prop_values)
#     print(f"{name} stats - graphs: {total_graphs}, avg. nodes: {total_nodes/total_graphs}")
#     print(f"{name} property stats - mean: {prop_values.mean():.4f}, std: {prop_values.std():.4f}")

# basic_homo_lumo_stats(train_loader, "Train")
# basic_homo_lumo_stats(test_loader, "Test")
# basic_homo_lumo_stats(val_loader, "Validation")

In [8]:


# TODO : move training etc. to mygenai
# logging.basicConfig(level=logging.DEBUG)

# for now, check to see if can at least reconstruct molecules
recon_weight = 1.0
kl_weight = 0. #0.001
property_weight = 0.

def train_epoch(model, optimizer, train_loader, device):
    logger = logging.getLogger('train_epoch')
    model.train()
    total_loss = 0

    for batch_idx, batch in enumerate(train_loader):
        batch = batch.to(device)
        logger.debug(f"\nBatch {batch_idx}:")
        logger.debug(f"Batch properties: x={batch.x.shape}, pos={batch.pos.shape}, batch={batch.batch.shape}")
        optimizer.zero_grad()

        # Forward pass
        node_features, distances, directions, edge_features, num_nodes, edge_existence, mu, log_var, property_pred = model(batch)

        if batch_idx == 0:  # Check first batch only
            # Print statistics about generated values
            print(f"Generated node features: min={node_features.min().item():.4f}, max={node_features.max().item():.4f}")
            # print(f"Original node features: min={batch.x.min().item():.4f}, max={batch.x.max().item():.4f}") # always in [0,1] because one-hot encoding
            print(f"Generated distances: min={distances.min().item():.4f}, max={distances.max().item():.4f}")
            # print(f"Generated directions: min={directions.min().item():.4f}, max={directions.max().item():.4f}")

        # Calculate loss
        loss = model.loss_function(
            node_features, distances, directions, edge_features, num_nodes,
            edge_existence, batch, mu, log_var, property_pred,
            recon_weight=recon_weight, kl_weight=kl_weight, property_weight=property_weight
        )

        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.)
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(train_loader)


def validate(model, val_loader, device):
    model.eval()
    total_loss = 0
    batch_losses = []

    with torch.no_grad():
        for batch_idx, batch in enumerate(val_loader):
            batch = batch.to(device)

            # Forward pass
            node_features, distances, directions, edge_features, num_nodes, mu, log_var, property_pred = model(batch)

            # Calculate loss
            loss = model.loss_function(
                node_features, distances, directions, edge_features, num_nodes,
                batch, mu, log_var, property_pred,
                recon_weight=recon_weight, kl_weight=kl_weight, property_weight=property_weight
            )

            batch_losses.append(loss.item())
            total_loss += loss.item()

    return total_loss / len(val_loader)


# Training loop
n_epochs = 100
best_val_loss = float('inf')
patience = 10
patience_counter = 0

# TODO normalise for a minimum distance!!
# TODO !!!!! The minimal and maximal distances are from preprocessing WITHOUT CompleteGraph >____<<< !!!!!!
#   No wonder you get fucked up distances...
# TODO don't use a complete graph, check edge detection and use softmax (since it is one-hot) for bond type!
batch = next(iter(train_loader))
logging.info(f"Target property shape: {batch.y.shape}")
logging.info(f"Target property sample: {batch.y[0]}")
logging.info(f"Input node features shape: {batch.x.shape}")  # Should be (N, 11)
logging.info(f"Input positions shape: {batch.pos.shape}")    # Should be (N, 3)
logging.info(f"Number of nodes: {batch.num_nodes}")
logging.info(f"Batch size: {batch.batch.max().item() + 1}")
for epoch in range(n_epochs):
    # Train
    train_loss = train_epoch(model, optimizer, train_loader, device)

    # Validate
    val_loss = validate(model, val_loader, device)

    # Learning rate scheduling
    scheduler.step(val_loss)

    # Early stopping
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f'Early stopping at epoch {epoch}')
            # Save best model
            # Save model with timestamp in a more readable format
            timestamp = time.strftime("%Y%m%d_%H%M%S")
            save_path = f'best_vae_model_{timestamp}.pt'
            torch.save(model.state_dict(), save_path)
            print(f'Saved best model to: {save_path}')
            break

    # Print progress
    print(f'Epoch {epoch:03d} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}')

# Load best model for testing
# model.load_state_dict(torch.load('best_vae_model.pt'))

# Test final model
test_loss = validate(model, test_loader, device)
print(f'Final Test Loss: {test_loss:.4f}')

2025-04-21 13:29:35,288 - root - INFO - Target property shape: torch.Size([128, 19])
2025-04-21 13:29:35,289 - root - INFO - Target property sample: tensor([ 1.1939e+00,  1.2649e+00, -2.1483e+00, -2.4554e+00, -1.4376e+00,
         1.9015e+00, -1.4796e+00,  2.8287e-01,  2.8287e-01,  2.8287e-01,
         2.8283e-01, -4.2958e-01,  8.9549e-01,  9.0289e-01,  9.0640e-01,
         8.4868e-01, -2.3879e-03, -4.6962e-01, -4.6332e-01])
2025-04-21 13:29:35,289 - root - INFO - Input node features shape: torch.Size([2279, 5])
2025-04-21 13:29:35,289 - root - INFO - Input positions shape: torch.Size([2279, 3])
2025-04-21 13:29:35,290 - root - INFO - Number of nodes: 2279
2025-04-21 13:29:35,290 - root - INFO - Batch size: 128


Generated node features: min=0.0174, max=0.6122
Generated distances: min=1.0282, max=1.7611


AttributeError: 'GlobalStorage' object has no attribute 'edge_existence'

In [None]:
test_molecule = dataset[0]
test_molecule = test_molecule.to(device)

from torch_geometric.data import Batch
# Create a batch with a single molecule
test_batch = Batch.from_data_list([test_molecule])
outputs = model(test_batch)
node_features, distances, directions, edge_features, num_nodes, mu, log_var, property_pred = outputs

In [None]:
# Print the generated node features
print("Generated node features:")
print(node_features)
print("Generated distances:")
print(distances)
print("Generated directions:")
print(directions)
print("Generated edge features:")
print(edge_features)
print("Generated num_nodes:")
print(num_nodes)
print("Generated mu:")
print(mu)
print("Generated log_var:")
print(log_var)
print("Generated property prediction:")
print(property_pred)

Generated node features:
tensor([[0.5131, 0.3238, 0.0774, 0.0667, 0.0009],
        [0.5131, 0.3238, 0.0774, 0.0667, 0.0009],
        [0.5131, 0.3238, 0.0774, 0.0667, 0.0009],
        [0.5131, 0.3238, 0.0774, 0.0667, 0.0009],
        [0.5131, 0.3238, 0.0774, 0.0667, 0.0009]], device='cuda:0',
       grad_fn=<SigmoidBackward0>)
Generated distances:
tensor([[0.9756],
        [0.9756],
        [0.9756],
        [0.9756],
        [0.9756],
        [0.9756],
        [0.9756],
        [0.9756],
        [0.9756],
        [0.9756],
        [0.9756],
        [0.9756],
        [0.9756],
        [0.9756],
        [0.9756],
        [0.9756],
        [0.9756],
        [0.9756],
        [0.9756],
        [0.9756]], device='cuda:0', grad_fn=<AddBackward0>)
Generated directions:
tensor([[ 0.5079,  0.3173, -0.8008],
        [ 0.5079,  0.3173, -0.8008],
        [ 0.5079,  0.3173, -0.8008],
        [ 0.5079,  0.3173, -0.8008],
        [ 0.5079,  0.3173, -0.8008],
        [ 0.5079,  0.3173, -0.8008],
     

In [None]:
from mygenai.utils.visualisation import visualise_molecule
visualise_molecule(test_molecule)

<py3Dmol.view at 0x7fd40204b620>

In [None]:
# import importlib
# import mygenai.models.graphvae
# from mygenai.models.graphvae import PropertyConditionedVAE
# importlib.reload(mygenai.models.graphvae)
logging.getLogger('PropertyConditionedVAE').setLevel(logging.DEBUG)
batch_data = dataset[:100]
batch = Batch.from_data_list(batch_data).to(device)
with torch.no_grad():  # Disable gradient computation
    outputs = model(batch)
node_features, distances, directions, edge_features, num_nodes, mu, log_var, property_pred = outputs

# Compute the loss
loss = model.loss_function(
    node_features=node_features,
    distances=distances,
    directions=directions,
    edge_features=edge_features,
    num_nodes=num_nodes,
    data=batch,
    mu=mu,
    log_var=log_var,
    property_pred=property_pred,
    property_weight=property_weight,  # Use the same weights as during training
    recon_weight=recon_weight,
    kl_weight=kl_weight
)

2025-04-20 22:15:35,154 - PropertyConditionedVAE - DEBUG - Input data - batch_size: 100, nodes: 1005
2025-04-20 22:15:35,154 - PropertyConditionedVAE - DEBUG - Forward called without target_property (None)
2025-04-20 22:15:35,157 - PropertyConditionedVAE - DEBUG - Encoder outputs - mu: torch.Size([100, 32]), log_var: torch.Size([100, 32]), property_pred: torch.Size([100, 1])
2025-04-20 22:15:35,158 - PropertyConditionedVAE - DEBUG - Sampled z shape: torch.Size([100, 32])
2025-04-20 22:15:35,158 - PropertyConditionedVAE - DEBUG - Using encoder prediction for property, shape: torch.Size([100, 1])
2025-04-20 22:15:35,161 - PropertyConditionedVAE - DEBUG - Decoder outputs - node_features: torch.Size([1005, 5]), distances: torch.Size([10070, 1]), directions: torch.Size([10070, 3]), edge_features: torch.Size([10070, 4]), num_nodes: tensor([ 5.0497,  4.1476,  3.2143,  3.9463,  2.7335,  3.9900,  8.1959,  6.0648,
         7.1128,  6.0964,  7.1026,  6.1230, 11.0618,  9.1959,  9.2088,  9.1748,
  