# Message-Passing Graph Variational Autoencoder

Here I use principles from message-passing graph neural networks to try to generate
molecules with a graph variational autoencoder. This is based on my solutions to
`geometric-gnn-dojo/geometric_gnn_101.ipynb`, and some of the code has also been
taken from there.

I might also want to consider an autoregressive model as an alternative.

In [1]:
import importlib
import logging
import time
from mygenai.models.graphvae import PropertyConditionedVAE
# importlib.reload(PropertyConditionedVAE)

import torch
torch.cuda.empty_cache()
import torch_geometric
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# Report the library versions this notebook was run against (one line each).
print(
    "PyTorch version {}".format(torch.__version__),
    "PyG version {}".format(torch_geometric.__version__),
    sep="\n",
)

PyTorch version 2.5.0+cu124
PyG version 2.6.1


In [2]:

debug_print = False

# Resolve the verbosity once; both modes share the same configuration otherwise.
log_level = logging.DEBUG if debug_print else logging.INFO
print("Debug mode is ON" if debug_print else "Debug mode is OFF")

logging.basicConfig(
    level=log_level,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    force=True,  # replace any handlers already installed so re-running the cell takes effect
)

# Named loggers get the same level as the root logger.
for logger_name in ['train_epoch', 'PropertyConditionedVAE', 'ConditionalDecoder', 'Encoder']:
    logging.getLogger(logger_name).setLevel(log_level)

Debug mode is OFF


In [3]:
import torch
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from sklearn.model_selection import train_test_split
import numpy as np
from mygenai.utils.transforms import CompleteGraph

dataset = QM9(root="../data/QM9", transform=CompleteGraph())

# Standardise every target column to mean = 0 and std = 1. The statistics are
# taken along dim 0, i.e. over all samples, separately for each target.
# NOTE(review): mean/std are computed over the full dataset before any
# train/val/test split is made — confirm this is intended.
targets = dataset.data.y
mean = targets.mean(dim=0, keepdim=True)
std = targets.std(dim=0, keepdim=True)
dataset.data.y = (targets - mean) / std

2025-04-18 16:05:52,169 - rdkit - INFO - Enabling RDKit 2024.09.6 jupyter extensions


In [4]:
# Split the sample indices 60/20/20 into train / validation / test.
all_idx = np.arange(len(dataset))

# First hold out 20% of all samples for testing ...
train_val_idx, test_idx = train_test_split(all_idx, test_size=0.2, random_state=42)
# ... then 25% of the remaining 80% (= 20% of the total) for validation.
train_idx, val_idx = train_test_split(train_val_idx, test_size=0.25, random_state=42)

# Only the training loader shuffles; evaluation order stays fixed.
loader_kwargs = {"batch_size": 128}
train_loader = DataLoader(dataset[train_idx], shuffle=True, **loader_kwargs)
val_loader = DataLoader(dataset[val_idx], shuffle=False, **loader_kwargs)
test_loader = DataLoader(dataset[test_idx], shuffle=False, **loader_kwargs)

In [5]:
# Model, optimiser and LR schedule.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

model = PropertyConditionedVAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Halve the learning rate after 5 epochs without improvement of the monitored
# value, never going below 1e-6.
plateau_kwargs = dict(mode='min', factor=0.5, patience=5, min_lr=1e-6)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **plateau_kwargs)



In [6]:
# Smoke test: push a single training batch through the model with gradient
# tracking disabled, just to confirm the forward pass runs.
batch = next(iter(train_loader)).to(device)
with torch.no_grad():
    outputs = model(batch)
print("Forward pass successful!")

Forward pass successful!


In [7]:
# # check if training sets are reasonably balanced

# def basic_homo_lumo_stats(loader, name):
#     total_nodes = 0
#     total_graphs = 0
#     prop_values = []

#     for batch in loader:
#         total_graphs += batch.batch.max().item() + 1
#         total_nodes += batch.x.shape[0]
#         prop_values.append(batch.y[:, 4].cpu().numpy())

#     prop_values = np.concatenate(prop_values)
#     print(f"{name} stats - graphs: {total_graphs}, avg. nodes: {total_nodes/total_graphs}")
#     print(f"{name} property stats - mean: {prop_values.mean():.4f}, std: {prop_values.std():.4f}")

# basic_homo_lumo_stats(train_loader, "Train")
# basic_homo_lumo_stats(test_loader, "Test")
# basic_homo_lumo_stats(val_loader, "Validation")

In [8]:


# TODO : move training etc. to mygenai
# logging.basicConfig(level=logging.DEBUG)

# Loss-term weights read by train_epoch() and validate(). For now only the
# reconstruction term is active, to check the model can at least reconstruct
# molecules; the KL and property terms are switched off.
recon_weight, kl_weight, property_weight = 1.0, 0.0, 0.0

def train_epoch(model, optimizer, train_loader, device, max_grad_norm=0.5):
    """Run one optimisation pass over ``train_loader`` and return the mean batch loss.

    The loss-term weights are not parameters: they are read at call time from
    the notebook-level globals ``recon_weight``, ``kl_weight`` and
    ``property_weight``.

    Args:
        model: Module whose forward pass returns the 6-tuple
            ``(node_features, positions, mu, log_var, property_pred, num_nodes)``
            and which provides ``loss_function(...)``.
        optimizer: Optimiser stepping ``model``'s parameters.
        train_loader: Iterable of batches exposing ``.to(device)``, ``.x``,
            ``.pos`` and ``.batch``; must also support ``len()``.
        device: Device every batch is moved to.
        max_grad_norm: Upper bound on the total gradient norm, passed to
            ``torch.nn.utils.clip_grad_norm_``. Defaults to 0.5, the value
            that used to be hard-coded.

    Returns:
        float: Sum of the per-batch losses divided by ``len(train_loader)``.

    Raises:
        ZeroDivisionError: If ``train_loader`` has length zero.
    """
    logger = logging.getLogger('train_epoch')
    model.train()
    total_loss = 0

    for batch_idx, batch in enumerate(train_loader):
        batch = batch.to(device)
        logger.debug(f"\nBatch {batch_idx}:")
        logger.debug(f"Batch properties: x={batch.x.shape}, pos={batch.pos.shape}, batch={batch.batch.shape}")
        optimizer.zero_grad()

        # Forward pass
        node_features, positions, mu, log_var, property_pred, num_nodes = model(batch)
        if batch_idx == 0:  # Check first batch only
            # Compare value ranges of the generated and the original tensors
            print(f"Generated node features: min={node_features.min().item():.4f}, max={node_features.max().item():.4f}")
            print(f"Original node features: min={batch.x.min().item():.4f}, max={batch.x.max().item():.4f}")
            print(f"Generated positions: min={positions.min().item():.4f}, max={positions.max().item():.4f}")
            print(f"Original positions: min={batch.pos.min().item():.4f}, max={batch.pos.max().item():.4f}")

        # Calculate loss (weights come from the notebook-level globals)
        loss = model.loss_function(
            node_features, positions, num_nodes,
            batch, mu, log_var, property_pred,
            recon_weight=recon_weight, kl_weight=kl_weight, property_weight=property_weight
        )

        # Backward pass; gradients are clipped before the optimiser step
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        optimizer.step()

        total_loss += loss.item()

    # Mean over batches (not over samples)
    return total_loss / len(train_loader)

def validate(model, val_loader, device, high_loss_threshold=1e6):
    """Evaluate ``model`` on ``val_loader`` and return the mean batch loss.

    Runs in eval mode with gradient tracking disabled. The loss-term weights
    are read at call time from the notebook-level globals ``recon_weight``,
    ``kl_weight`` and ``property_weight``.

    Args:
        model: Module whose forward pass returns the 6-tuple
            ``(node_features, positions, mu, log_var, property_pred, num_nodes)``
            and which provides ``loss_function(...)``.
        val_loader: Iterable of batches exposing ``.to(device)``; must also
            support ``len()``.
        device: Device every batch is moved to.
        high_loss_threshold: A batch whose loss exceeds this value triggers a
            per-graph diagnostic print. Defaults to 1e6, the value that used
            to be hard-coded.

    Returns:
        float: Sum of the per-batch losses divided by ``len(val_loader)``.

    Raises:
        ZeroDivisionError: If ``val_loader`` has length zero.
    """
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for batch_idx, batch in enumerate(val_loader):
            batch = batch.to(device)

            # Forward pass
            node_features, positions, mu, log_var, property_pred, num_nodes = model(batch)

            loss = model.loss_function(
                node_features, positions, num_nodes,
                batch, mu, log_var, property_pred,
                recon_weight=recon_weight, kl_weight=kl_weight, property_weight=property_weight
            )
            # Convert to a Python float once and reuse it below
            loss_value = loss.item()

            # Check for reasonable loss values
            if loss_value > high_loss_threshold:
                print(f"❌ Extremely high loss in validation batch {batch_idx}: {loss_value:.4f}")
                # Try to identify which graphs in the batch are problematic.
                # Each graph's rows are assumed to be laid out back to back in
                # the generated tensors, so a running offset locates its slice.
                start_idx = 0
                for i, n in enumerate(num_nodes):
                    end_idx = start_idx + n.item()
                    if end_idx > start_idx:
                        feat_norm = torch.norm(node_features[start_idx:end_idx]).item()
                        pos_norm = torch.norm(positions[start_idx:end_idx]).item()
                        print(f"  Graph {i}: Features norm={feat_norm:.2f}, Positions norm={pos_norm:.2f}")
                    start_idx = end_idx

            total_loss += loss_value

    # Mean over batches (not over samples)
    return total_loss / len(val_loader)


# Training loop with early stopping.
#
# The weights with the lowest validation loss are snapshotted in memory each
# time the validation loss improves. After the loop (whether it ended through
# early stopping or by exhausting n_epochs) that snapshot is restored, written
# to disk once, and used for the final test evaluation.
n_epochs = 100
best_val_loss = float('inf')
best_state = None  # copy of model.state_dict() taken at the best validation loss
patience = 10
patience_counter = 0

# Log the shapes of one batch as a sanity check before training starts
batch = next(iter(train_loader))
logging.info(f"Target property shape: {batch.y.shape}")
logging.info(f"Target property sample: {batch.y[0]}")
logging.info(f"Input node features shape: {batch.x.shape}")  # Should be (N, 11)
logging.info(f"Input positions shape: {batch.pos.shape}")    # Should be (N, 3)
logging.info(f"Number of nodes: {batch.num_nodes}")
logging.info(f"Batch size: {batch.batch.max().item() + 1}")
for epoch in range(n_epochs):
    # Train
    train_loss = train_epoch(model, optimizer, train_loader, device)

    # Validate
    val_loss = validate(model, val_loader, device)

    # Learning rate scheduling
    scheduler.step(val_loss)

    # Print progress (before the early-stopping check so the last epoch is reported too)
    print(f'Epoch {epoch:03d} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}')

    # Early stopping
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        # Clone the tensors: state_dict() entries share storage with the live
        # parameters and would otherwise keep changing as training continues.
        best_state = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f'Early stopping at epoch {epoch}')
            break

# Restore and save the best weights
if best_state is not None:
    model.load_state_dict(best_state)
    # Timestamped file name so earlier checkpoints are not overwritten
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    save_path = f'best_vae_model_{timestamp}.pt'
    torch.save(model.state_dict(), save_path)
    print(f'Saved best model to: {save_path}')
else:
    # No epoch ever improved on +inf (e.g. zero epochs or a NaN validation loss)
    logging.warning("No best model state was recorded; testing the current weights.")

# Test final model
test_loss = validate(model, test_loader, device)
print(f'Final Test Loss: {test_loss:.4f}')

2025-04-18 16:05:53,162 - root - INFO - Target property shape: torch.Size([128, 19])
2025-04-18 16:05:53,163 - root - INFO - Target property sample: tensor([-0.2801,  0.0696,  0.7379,  1.1327,  0.7810,  0.0062,  0.6108, -0.7373,
        -0.7373, -0.7373, -0.7374,  1.1148, -0.3726, -0.3733, -0.3755, -0.3464,
        -0.0043,  0.0844, -0.0351])
2025-04-18 16:05:53,164 - root - INFO - Input node features shape: torch.Size([2331, 11])
2025-04-18 16:05:53,164 - root - INFO - Input positions shape: torch.Size([2331, 3])
2025-04-18 16:05:53,164 - root - INFO - Number of nodes: 2331
2025-04-18 16:05:53,164 - root - INFO - Batch size: 128


Generated node features: min=4.0144, max=5.0291
Original node features: min=0.0000, max=9.0000
Generated positions: min=-0.7958, max=-0.7534
Original positions: min=-7.2859, max=6.3207
Epoch 000 | Train Loss: 5.8704 | Val Loss: 65.1172
Generated node features: min=-0.1438, max=9.0182
Original node features: min=0.0000, max=9.0000
Generated positions: min=-6.9755, max=6.9755
Original positions: min=-6.8449, max=6.9834
Epoch 001 | Train Loss: 22.6095 | Val Loss: 79.5893
Generated node features: min=-0.1438, max=9.0182
Original node features: min=0.0000, max=9.0000
Generated positions: min=-6.9755, max=6.9755
Original positions: min=-6.6080, max=6.3674
Epoch 002 | Train Loss: 35.3322 | Val Loss: 78.8041
Generated node features: min=-0.1438, max=9.0182
Original node features: min=0.0000, max=9.0000
Generated positions: min=-6.9755, max=6.9755
Original positions: min=-6.5136, max=6.9398
Epoch 003 | Train Loss: 35.6613 | Val Loss: 78.7592
Generated node features: min=-0.1438, max=9.0182
Orig

  model.load_state_dict(torch.load('best_vae_model.pt'))


RuntimeError: Error(s) in loading state_dict for PropertyConditionedVAE:
	Missing key(s) in state_dict: "decoder.node_scale", "decoder.node_shift", "decoder.pos_scale". 