# Message-Passing Graph Variational Autoencoder

Here I use principles from message-passing graph neural networks to try to generate
molecules with a graph variational autoencoder. This is based on my solutions to
`geometric-gnn-dojo/geometric_gnn_101.ipynb`.
Some of the code has also been taken from there.

TODO: it might also be worth considering an autoregressive model.

In [None]:
import logging
import time
from mygenai.models.graphvae import PropertyConditionedVAE

import torch
import torch_geometric
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# Record the library versions this notebook was run with.
print(f"PyTorch version {torch.__version__}")
print(f"PyG version {torch_geometric.__version__}")

PyTorch version 2.5.0+cu124
PyG version 2.6.1


In [2]:

debug_print = False

# Only the level and the banner differ between the two modes, so choose them
# once and share the rest of the logging setup.
if debug_print:
    log_level, banner = logging.DEBUG, "Debug mode is ON"
else:
    log_level, banner = logging.INFO, "Debug mode is OFF"

print(banner)
logging.basicConfig(
    level=log_level,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    force=True,  # Replace any handlers configured earlier in this kernel
)
# Apply the same level to the named loggers configured by this notebook.
for logger_name in ['train_epoch', 'PropertyConditionedVAE', 'ConditionalDecoder', 'Encoder']:
    logging.getLogger(logger_name).setLevel(log_level)

Debug mode is OFF


In [3]:
import torch
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
import mygenai
from sklearn.model_selection import train_test_split
import numpy as np
from mygenai.utils.transforms import CompleteGraph

dataset = QM9(root="../data/QM9", transform=CompleteGraph())

# Standardise every target column to mean = 0 and std = 1. The statistics are
# taken over all samples in the dataset (dim=0), i.e. before any train/val/test
# split is made.
targets = dataset.data.y
mean = targets.mean(dim=0, keepdim=True)
std = targets.std(dim=0, keepdim=True)
dataset.data.y = (targets - mean) / std

2025-04-18 15:10:40,013 - rdkit - INFO - Enabling RDKit 2024.09.6 jupyter extensions


In [None]:
# Data splitting (60/20/20): hold out 20% of the samples for testing, then take
# a quarter of the remaining 80% (= 20% of the total) for validation.
all_idx = np.arange(len(dataset))
train_val_idx, test_idx = train_test_split(all_idx, test_size=0.2, random_state=42)
train_idx, val_idx = train_test_split(train_val_idx, test_size=0.25, random_state=42)

# Only the training loader shuffles; validation and test keep a fixed order.
train_loader, val_loader, test_loader = (
    DataLoader(dataset[split_idx], batch_size=128, shuffle=is_train)
    for split_idx, is_train in ((train_idx, True), (val_idx, False), (test_idx, False))
)

In [None]:
# Smoke-test the forward pass on a single training batch.
#
# This cell sits above the training-setup cell, so it cannot rely on the
# `device` / `model` names created there (that would raise NameError on a fresh
# kernel). It builds its own device handle and a throwaway, untrained model
# under a separate name so that a trained `model` is never overwritten if this
# cell is re-run later.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
smoke_model = PropertyConditionedVAE().to(device)

batch = next(iter(train_loader))
batch = batch.to(device)
with torch.no_grad():  # no gradients needed for a shape/plumbing check
    outputs = smoke_model(batch)
print("Forward pass successful!")

In [5]:
# check if training sets are reasonably balanced

def basic_homo_lumo_stats(loader, name, target_idx=4):
    """Print graph-count, average-size and target statistics for one data split.

    Args:
        loader: Iterable of batches. Each batch must expose ``batch`` (graph
            index per node), ``x`` (node features) and ``y`` (2-D targets).
        name: Label used as the prefix of the printed lines (e.g. "Train").
        target_idx: Column of ``y`` to summarise. Defaults to 4, the column the
            original implementation used (presumably the HOMO-LUMO gap, going
            by the function name -- confirm against the dataset docs).

    Returns:
        None. Results are printed. For an empty loader a single notice is
        printed instead of raising.
    """
    total_nodes = 0
    total_graphs = 0
    prop_values = []

    for batch in loader:
        # Graph ids within a batch run from 0..B-1, so max + 1 is the graph count.
        total_graphs += batch.batch.max().item() + 1
        total_nodes += batch.x.shape[0]
        prop_values.append(batch.y[:, target_idx].cpu().numpy())

    # Guard: np.concatenate([]) raises and the average would divide by zero.
    if total_graphs == 0:
        print(f"{name} stats - graphs: 0 (empty loader, nothing to summarise)")
        return

    prop_values = np.concatenate(prop_values)
    print(f"{name} stats - graphs: {total_graphs}, avg. nodes: {total_nodes/total_graphs}")
    print(f"{name} property stats - mean: {prop_values.mean():.4f}, std: {prop_values.std():.4f}")

# Print the same summary for every split so the three can be compared directly.
for split_loader, split_name in (
    (train_loader, "Train"),
    (test_loader, "Test"),
    (val_loader, "Validation"),
):
    basic_homo_lumo_stats(split_loader, split_name)

Train stats - graphs: 78498, avg. nodes: 18.033746082702745
Train property stats - mean: 0.0014, std: 0.9989
Test stats - graphs: 26167, avg. nodes: 18.038445370122673
Test property stats - mean: -0.0037, std: 1.0037
Validation stats - graphs: 26166, avg. nodes: 18.0228158679202
Validation property stats - mean: -0.0005, std: 0.9995


In [None]:


# TODO : move training etc. to mygenai
# logging.basicConfig(level=logging.DEBUG)

# For now only the reconstruction term is active (KL and property terms are
# switched off) to check whether molecules can at least be reconstructed.
recon_weight, kl_weight, property_weight = 1.0, 0.0, 0.0

def train_epoch(model, optimizer, train_loader, device):
    """Run one optimisation pass over ``train_loader``.

    Each batch is moved to ``device``, pushed through the model, scored with
    ``model.loss_function`` (using the module-level ``recon_weight``,
    ``kl_weight`` and ``property_weight``) and used for one optimizer step with
    the gradient norm clipped to 1.0.

    Returns:
        The sum of the per-batch losses divided by the number of batches.
    """
    log = logging.getLogger('train_epoch')
    model.train()
    batch_losses = []

    for step, graphs in enumerate(train_loader):
        graphs = graphs.to(device)
        log.debug(f"\nBatch {step}:")
        log.debug(f"Batch properties: x={graphs.x.shape}, pos={graphs.pos.shape}, batch={graphs.batch.shape}")
        optimizer.zero_grad()

        # Forward pass; the model returns a 6-tuple.
        outputs = model(graphs)
        node_features, positions, mu, log_var, property_pred, num_nodes = outputs

        # Weighted loss for this batch.
        loss = model.loss_function(
            node_features,
            positions,
            num_nodes,
            graphs,
            mu,
            log_var,
            property_pred,
            recon_weight=recon_weight,
            kl_weight=kl_weight,
            property_weight=property_weight,
        )

        # Backward pass with gradient clipping, then parameter update.
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        batch_losses.append(loss.item())

    return sum(batch_losses) / len(train_loader)

def validate(model, val_loader, device):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    The model is put in eval mode and every batch is scored with
    ``model.loss_function`` using the module-level ``recon_weight``,
    ``kl_weight`` and ``property_weight``.

    Returns:
        The sum of the per-batch losses divided by the number of batches.
    """
    model.eval()
    batch_losses = []

    with torch.no_grad():
        for graphs in val_loader:
            graphs = graphs.to(device)

            # Forward pass; the model returns a 6-tuple.
            outputs = model(graphs)
            node_features, positions, mu, log_var, property_pred, num_nodes = outputs

            # Same weighted loss as used during training.
            loss = model.loss_function(
                node_features,
                positions,
                num_nodes,
                graphs,
                mu,
                log_var,
                property_pred,
                recon_weight=recon_weight,
                kl_weight=kl_weight,
                property_weight=property_weight,
            )
            batch_losses.append(loss.item())

    return sum(batch_losses) / len(val_loader)


# Training setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = PropertyConditionedVAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-6
)

# Training loop
n_epochs = 100
best_val_loss = float('inf')
patience = 10
patience_counter = 0

# One checkpoint file per run. It is (over)written every time the validation
# loss improves, so it always holds the best weights seen so far -- also when
# training runs for all n_epochs without triggering early stopping.
timestamp = time.strftime("%Y%m%d_%H%M%S")
save_path = f'best_vae_model_{timestamp}.pt'
best_epoch = None  # epoch index of the checkpoint on disk; None = nothing saved yet

# Log the shapes of one batch as a sanity check before training starts.
batch = next(iter(train_loader))
logging.info(f"Target property shape: {batch.y.shape}")
logging.info(f"Target property sample: {batch.y[0]}")
logging.info(f"Input node features shape: {batch.x.shape}")  # Should be (N, 11)
logging.info(f"Input positions shape: {batch.pos.shape}")    # Should be (N, 3)
logging.info(f"Number of nodes: {batch.num_nodes}")
logging.info(f"Batch size: {batch.batch.max().item() + 1}")

for epoch in range(n_epochs):
    # Train
    train_loss = train_epoch(model, optimizer, train_loader, device)

    # Validate
    val_loss = validate(model, val_loader, device)

    # Learning rate scheduling
    scheduler.step(val_loss)

    # Print progress (before the early-stopping check so the last epoch is shown too)
    print(f'Epoch {epoch:03d} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}')

    if val_loss < best_val_loss:
        # New best: checkpoint now, while the model still holds these weights.
        best_val_loss = val_loss
        best_epoch = epoch
        patience_counter = 0
        torch.save(model.state_dict(), save_path)
    else:
        # Early stopping
        patience_counter += 1
        if patience_counter >= patience:
            print(f'Early stopping at epoch {epoch}')
            break

# Load best model for testing (same path that was written above).
if best_epoch is not None:
    print(f'Saved best model to: {save_path} (epoch {best_epoch:03d}, val loss {best_val_loss:.4f})')
    model.load_state_dict(torch.load(save_path, map_location=device, weights_only=True))
else:
    # Happens only if the validation loss never dropped below inf (e.g. NaN).
    print('No checkpoint was saved; evaluating the model in its final state.')

# Test final model
test_loss = validate(model, test_loader, device)
print(f'Final Test Loss: {test_loss:.4f}')

2025-04-18 15:11:38,218 - root - INFO - Target property shape: torch.Size([128, 19])
2025-04-18 15:11:38,219 - root - INFO - Target property sample: tensor([-0.6957,  0.7926,  1.1932,  1.5083,  0.9399, -0.3079,  0.7014,  1.1254,
         1.1254,  1.1254,  1.1254, -0.4195, -0.6963, -0.7002, -0.7000, -0.6920,
        -0.0032, -0.0611,  0.0994])
2025-04-18 15:11:38,220 - root - INFO - Input node features shape: torch.Size([2349, 11])
2025-04-18 15:11:38,220 - root - INFO - Input positions shape: torch.Size([2349, 3])
2025-04-18 15:11:38,220 - root - INFO - Number of nodes: 2349
2025-04-18 15:11:38,220 - root - INFO - Batch size: 128


AttributeError: 'NoneType' object has no attribute 'shape'