# Message-Passing Graph Variational Autoencoder

Here I use principles from message-passing graph neural networks to try to generate
molecules with a graph variational autoencoder. This is based on my solutions to
`geometric-gnn-dojo/geometric_gnn_101.ipynb`.
Some of the code has also been taken from there.

TODO: an autoregressive model might also be worth considering.

In [1]:
import importlib
import logging
import time
from mygenai.models.graphvae import PropertyConditionedVAE
# importlib.reload(PropertyConditionedVAE)

import torch
torch.cuda.empty_cache()
import torch_geometric
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# Report the library versions this notebook is running against.
print(f"PyTorch version {torch.__version__}")
print(f"PyG version {torch_geometric.__version__}")

PyTorch version 2.5.0+cu124
PyG version 2.6.1


In [2]:

# Flip to True to get DEBUG-level output from the training helper and model loggers.
debug_print = False

# Both modes run the same setup; only the banner and the log level differ.
log_level = logging.DEBUG if debug_print else logging.INFO
print("Debug mode is ON" if debug_print else "Debug mode is OFF")
logging.basicConfig(
    level=log_level,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    force=True,  # Ensure configuration is applied even if logging was configured before
)
for logger_name in ('train_epoch', 'PropertyConditionedVAE', 'ConditionalDecoder', 'Encoder'):
    logging.getLogger(logger_name).setLevel(log_level)

Debug mode is OFF


In [None]:
import torch
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from sklearn.model_selection import train_test_split
import numpy as np
from mygenai.utils.transforms import CompleteGraph

# Load QM9 from ../data/QM9; CompleteGraph() is applied as the dataset transform.
dataset = QM9(root="../data/QM9", transform=CompleteGraph())
# Standardize every target column to mean = 0 and std = 1. The statistics are taken
# along dim=0 over the whole dataset, i.e. before the train/val/test split in the next
# cell, so validation and test molecules contribute to them as well.
# NOTE(review): recent PyG releases warn on direct `dataset.data` access — confirm whether
# this is acceptable here or whether the storage should be accessed differently.
mean = dataset.data.y.mean(dim=0, keepdim=True)
std = dataset.data.y.std(dim=0, keepdim=True)
dataset.data.y = (dataset.data.y - mean) / std
# Keep only the first five node-feature columns:
# focus on just using the one-hot encoding of the atomic number, for simplicity for now
dataset.data.x = dataset.data.x[:, :5]

# Distance scale for the distance normalization below, which is currently disabled
# (in this cell the value is only referenced from commented-out code).
fixed_max_distance = 2.0

# TODO in the future maybe use a Z matrix representation for positions
# def normalize_distances(dataset, max_distance):
#     for data in dataset:
#         pos = data.pos  # (n, 3) - absolute coordinates
#         src, dst = data.edge_index  # (2, num_edges) - edge indices
#         relative_positions = pos[dst] - pos[src]  # (num_edges, 3)
#         distances = torch.norm(relative_positions, dim=1)  # (num_edges,)
#         data.normalized_distances = distances / max_distance  # Normalize distances
#     return dataset

# dataset = normalize_distances(dataset, fixed_max_distance)
# min_atomic_number = 1
# max_atomic_number = 9
# dataset.data.z = (dataset.data.z - min_atomic_number) / (max_atomic_number - min_atomic_number) # doesn't actually matter because it's not used (this information is determined by data.x)



KeyboardInterrupt: 

In [None]:
# 60/20/20 train/validation/test split: hold out 20% for testing, then take a quarter
# of the remaining 80% (another 20% of the total) for validation.
all_indices = np.arange(len(dataset))
train_val_idx, test_idx = train_test_split(all_indices, test_size=0.2, random_state=42)
train_idx, val_idx = train_test_split(train_val_idx, test_size=0.25, random_state=42)

# All loaders share the batch size; only the training loader shuffles.
loader_kwargs = dict(batch_size=128)
train_loader = DataLoader(dataset[train_idx], shuffle=True, **loader_kwargs)
val_loader = DataLoader(dataset[val_idx], shuffle=False, **loader_kwargs)
test_loader = DataLoader(dataset[test_idx], shuffle=False, **loader_kwargs)

In [None]:
# Training setup: use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model = PropertyConditionedVAE()
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# The scheduler is stepped with the validation loss in the training loop
# (mode='min'): factor 0.5, patience 5, learning-rate floor 1e-6.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=5,
    min_lr=1e-6,
)

In [None]:
# Smoke test: push one training batch through the model without tracking gradients.
# NOTE(review): no model.eval() call precedes this, so the model is presumably still
# in training mode here — confirm that is intended.
batch = next(iter(train_loader)).to(device)
with torch.no_grad():
    outputs = model(batch)
print("Forward pass successful!")

Forward pass successful!


In [None]:
# # check if training sets are reasonably balanced

# def basic_homo_lumo_stats(loader, name):
#     total_nodes = 0
#     total_graphs = 0
#     prop_values = []

#     for batch in loader:
#         total_graphs += batch.batch.max().item() + 1
#         total_nodes += batch.x.shape[0]
#         prop_values.append(batch.y[:, 4].cpu().numpy())

#     prop_values = np.concatenate(prop_values)
#     print(f"{name} stats - graphs: {total_graphs}, avg. nodes: {total_nodes/total_graphs}")
#     print(f"{name} property stats - mean: {prop_values.mean():.4f}, std: {prop_values.std():.4f}")

# basic_homo_lumo_stats(train_loader, "Train")
# basic_homo_lumo_stats(test_loader, "Test")
# basic_homo_lumo_stats(val_loader, "Validation")

In [None]:


# TODO: move the training utilities into mygenai
# logging.basicConfig(level=logging.DEBUG)

# Loss-term weights, read as globals by train_epoch() and validate().
# For now only the reconstruction term is active, to check whether the model can
# at least reconstruct molecules (alternative kl_weight value: 0.01).
recon_weight, kl_weight, property_weight = 1.0, 0.0, 0.0

def train_epoch(model, optimizer, train_loader, device):
    """Run one optimisation pass over ``train_loader`` and return the mean batch loss.

    Args:
        model: Module whose forward pass returns the eight values unpacked below
            and which provides ``loss_function``.
        optimizer: Optimizer that updates ``model``'s parameters.
        train_loader: Iterable of batches; must support ``len()``.
        device: Device every batch is moved to before the forward pass.

    Returns:
        The sum of the per-batch ``loss.item()`` values divided by
        ``len(train_loader)``.

    The loss weights are taken from the module-level ``recon_weight``,
    ``kl_weight`` and ``property_weight``.
    """
    log = logging.getLogger('train_epoch')
    model.train()
    running_loss = 0

    for batch_idx, batch in enumerate(train_loader):
        batch = batch.to(device)
        log.debug(f"\nBatch {batch_idx}:")
        log.debug(f"Batch properties: x={batch.x.shape}, pos={batch.pos.shape}, batch={batch.batch.shape}")
        optimizer.zero_grad()

        # Forward pass
        (node_features, distances, directions, edge_features,
         num_nodes, mu, log_var, property_pred) = model(batch)

        if batch_idx == 0:
            # Value-range diagnostics, printed for the first batch of the epoch only.
            # (A matching line for `directions` is intentionally not printed.)
            diagnostics = (
                ("Generated node features", node_features),
                ("Original node features", batch.x),
                ("Generated distances", distances),
            )
            for label, values in diagnostics:
                print(f"{label}: min={values.min().item():.4f}, max={values.max().item():.4f}")

        # Loss with the globally configured term weights
        loss_weights = dict(
            recon_weight=recon_weight,
            kl_weight=kl_weight,
            property_weight=property_weight,
        )
        loss = model.loss_function(
            node_features, distances, directions, edge_features, num_nodes,
            batch, mu, log_var, property_pred,
            **loss_weights
        )

        # Backward pass; the gradient norm is clipped to 1 before the update
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.)
        optimizer.step()

        running_loss += loss.item()

    return running_loss / len(train_loader)


def validate(model, val_loader, device):
    """Evaluate ``model`` on ``val_loader`` and return the mean batch loss.

    The model is put into eval mode and gradients are disabled for the whole pass.

    Args:
        model: Module whose forward pass returns the eight values unpacked below
            and which provides ``loss_function``.
        val_loader: Iterable of batches; must support ``len()``.
        device: Device every batch is moved to before the forward pass.

    Returns:
        The sum of the per-batch ``loss.item()`` values divided by
        ``len(val_loader)``.

    The loss weights are taken from the module-level ``recon_weight``,
    ``kl_weight`` and ``property_weight``.
    """
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(device)

            # Forward pass
            node_features, distances, directions, edge_features, num_nodes, mu, log_var, property_pred = model(batch)

            # Calculate loss
            loss = model.loss_function(
                node_features, distances, directions, edge_features, num_nodes,
                batch, mu, log_var, property_pred,
                recon_weight=recon_weight, kl_weight=kl_weight, property_weight=property_weight
            )

            # Convert to a Python float once per batch
            total_loss += loss.item()

    return total_loss / len(val_loader)


# Training loop
n_epochs = 100
best_val_loss = float('inf')
patience = 10
patience_counter = 0

# One checkpoint file per run. It is overwritten every time the validation loss improves,
# so it always holds the best weights seen so far — also when the run is interrupted or
# finishes all epochs without triggering early stopping.
timestamp = time.strftime("%Y%m%d_%H%M%S")
save_path = f'best_vae_model_{timestamp}.pt'
best_model_saved = False

batch = next(iter(train_loader))
logging.info(f"Target property shape: {batch.y.shape}")
logging.info(f"Target property sample: {batch.y[0]}")
logging.info(f"Input node features shape: {batch.x.shape}")  # Should be (N, 5) after slicing x to 5 columns
logging.info(f"Input positions shape: {batch.pos.shape}")    # Should be (N, 3)
logging.info(f"Number of nodes: {batch.num_nodes}")
logging.info(f"Batch size: {batch.batch.max().item() + 1}")
for epoch in range(n_epochs):
    # Train
    train_loss = train_epoch(model, optimizer, train_loader, device)

    # Validate
    val_loss = validate(model, val_loader, device)

    # Learning rate scheduling
    scheduler.step(val_loss)

    # Print progress (before the early-stopping check so the last epoch is reported too)
    print(f'Epoch {epoch:03d} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}')

    # Checkpointing and early stopping
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        # Persist the weights of the best epoch, not of the epoch training stops at
        torch.save(model.state_dict(), save_path)
        best_model_saved = True
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f'Early stopping at epoch {epoch}')
            break

# Load best model for testing (skipped if no epoch ever improved on the initial value,
# e.g. n_epochs == 0 or a non-finite validation loss throughout)
if best_model_saved:
    model.load_state_dict(torch.load(save_path, map_location=device, weights_only=True))
    print(f'Saved best model to: {save_path}')

# Test final model
test_loss = validate(model, test_loader, device)
print(f'Final Test Loss: {test_loss:.4f}')

2025-04-20 18:51:03,065 - root - INFO - Target property shape: torch.Size([128, 19])
2025-04-20 18:51:03,066 - root - INFO - Target property sample: tensor([-0.5368,  0.6629, -0.0587, -0.1267, -0.0984,  0.4885,  1.0045, -0.3360,
        -0.3360, -0.3360, -0.3360,  0.8287, -0.9686, -0.9696, -0.9700, -0.9614,
        -0.0038, -0.1471, -0.2082])
2025-04-20 18:51:03,066 - root - INFO - Input node features shape: torch.Size([2360, 5])
2025-04-20 18:51:03,066 - root - INFO - Input positions shape: torch.Size([2360, 3])
2025-04-20 18:51:03,067 - root - INFO - Number of nodes: 2360
2025-04-20 18:51:03,067 - root - INFO - Batch size: 128


Generated node features: min=0.0807, max=0.9049
Original node features: min=0.0000, max=1.0000
Generated distances: min=0.0000, max=1.9338
Generated directions: min=-0.9967, max=0.9990
Epoch 000 | Train Loss: 74.0588 | Val Loss: 9.8723
Generated node features: min=0.0030, max=0.6648
Original node features: min=0.0000, max=1.0000
Generated distances: min=0.0000, max=3.7843
Generated directions: min=-0.9946, max=0.9916
Epoch 001 | Train Loss: 9.8253 | Val Loss: 5.7595
Generated node features: min=0.0000, max=0.6689
Original node features: min=0.0000, max=1.0000
Generated distances: min=0.0000, max=4.0604
Generated directions: min=-0.9901, max=0.9926
Epoch 002 | Train Loss: 2.9047 | Val Loss: 2.7684
Generated node features: min=0.0014, max=0.6640
Original node features: min=0.0000, max=1.0000
Generated distances: min=2.8320, max=4.4453
Generated directions: min=-0.9961, max=0.9937
Epoch 003 | Train Loss: 2.7186 | Val Loss: 2.6102
Generated node features: min=0.0005, max=0.6837
Original no

In [None]:
# Display the model's module hierarchy (the notebook shows the repr of the cell's last expression).
model

PropertyConditionedVAE(
  (encoder): Encoder(
    (lin_in): Linear(in_features=5, out_features=64, bias=True)
    (convs): ModuleList(
      (0-1): 2 x EquivariantMPNNLayer(emb_dim=64, aggr=add)
    )
    (mu): Linear(in_features=64, out_features=32, bias=True)
    (log_var): Linear(in_features=64, out_features=32, bias=True)
    (property_predictor): Sequential(
      (0): Linear(in_features=32, out_features=64, bias=True)
      (1): ReLU()
      (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Linear(in_features=64, out_features=1, bias=True)
    )
  )
  (decoder): ConditionalDecoder(
    (lin_latent): Linear(in_features=33, out_features=64, bias=True)
    (node_decoder): Sequential(
      (0): Linear(in_features=64, out_features=64, bias=True)
      (1): ReLU()
      (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Linear(in_features=64, out_features=5, bias=True)
      (4): Sigmoid()
    )