# A Simple Variational Autoencoder

Here we explore using a simple variational autoencoder trained on the QM9 dataset to generate new molecules.

The goal of this notebook is to provide a proof of concept that you can use a VAE to generate new molecules. Some physical restrictions will also be enforced. The model will also be saved for easy reloading in later notebooks.

In [11]:
import torch
import torch.nn as nn
from torch_geometric.utils import to_dense_batch
import torch_geometric as pyg
import torch_geometric.nn as pyg_nn
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
import mygenai
from sklearn.model_selection import train_test_split
import numpy as np

dataset = QM9(root="../data/QM9")

In [12]:
# check if CUDA is available
print(torch.cuda.is_available())
print(torch.cuda.device_count())
print(torch.cuda.get_device_name(0))

True
1
NVIDIA GeForce RTX 3050 Ti Laptop GPU


In [16]:
class GraphVAE(nn.Module):
    def __init__(self, in_channels, hidden_dim, latent_dim):
        super(GraphVAE, self).__init__()
        self.in_channels = in_channels  # Store input feature dimension

        self.encoder = pyg_nn.GCNConv(in_channels, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)  # Mean
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)  # Log variance

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, in_channels)
        )

        # Edge Decoder (for adjacency matrix reconstruction)
        self.edge_decoder = nn.Linear(latent_dim, dataset.num_edge_features)

    def forward(self, x, edge_index):
        # Encode
        h = self.encoder(x, edge_index)
        mu, logvar = self.fc_mu(h), self.fc_logvar(h)
        z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)

        # Decode node features
        recon_x = self.decoder(z)

        # QM9 edges represent interactions between pairs of atoms, so need two nodes per edge
        edge_features = torch.cat([z[edge_index[0]], z[edge_index[1]]], dim=-1)  # Concatenating both node embeddings
        edge_pred = self.edge_decoder(edge_features)

        return recon_x, edge_pred, mu, logvar

In [17]:
def print_shapes(batch, node_pred, edge_pred, mu, logvar):
    """Print the shapes of the batch and predictions for debugging."""
    print("\nShape Information:")
    print(f"Batch features (batch.x): {batch.x.shape}")
    print(f"Batch edge index (batch.edge_index): {batch.edge_index.shape}")
    print(f"Batch edge attributes (batch.edge_attr): {batch.edge_attr.shape if hasattr(batch, 'edge_attr') else 'None'}")
    print(f"Node predictions: {node_pred.shape}")
    print(f"Edge predictions: {edge_pred.shape}")
    print(f"Mu: {mu.shape}")
    print(f"Logvar: {logvar.shape}")

In [18]:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Move model to GPU
vae = GraphVAE(in_channels=dataset.num_features, hidden_dim=64, latent_dim=32).to(device)

optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)

def loss_function(node_pred, edge_pred, node_true, edge_true, mu, logvar):
    node_loss = torch.nn.functional.cross_entropy(node_pred, node_true, reduction='sum')
    edge_loss = torch.nn.functional.binary_cross_entropy_with_logits(
        edge_pred, edge_true, reduction='sum'
    )
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    total_loss = node_loss + edge_loss + kl_loss

    return total_loss, {
        'node_loss': node_loss.item(),
        'edge_loss': edge_loss.item(),
        'kl_loss': kl_loss.item()
    }

# Data splitting
train_val_idx, test_idx = train_test_split(
    np.arange(len(dataset)),
    test_size=0.2,
    random_state=42
)
train_idx, val_idx = train_test_split(
    train_val_idx,
    test_size=0.25,
    random_state=42
)

train_loader = DataLoader(dataset[train_idx], batch_size=128, shuffle=True)
val_loader = DataLoader(dataset[val_idx], batch_size=128, shuffle=False)
test_loader = DataLoader(dataset[test_idx], batch_size=128, shuffle=False)

# Training loop with validation
best_val_loss = float('inf')
patience = 5
patience_counter = 0
n_epochs = 50

# Training Loop
for epoch in range(n_epochs):
    # Training phase
    vae.train()
    train_metrics = {'total_loss': 0, 'node_loss': 0, 'edge_loss': 0, 'kl_loss': 0}

    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        # Forward pass
        node_pred, edge_pred, mu, logvar = vae(batch.x, batch.edge_index)
        print_shapes(batch, node_pred, edge_pred, mu, logvar)

        # Calculate loss
        loss, metrics = loss_function(
            node_pred, edge_pred,
            batch.x, batch.edge_attr,
            mu, logvar
        )

        # Backward pass
        loss.backward()
        optimizer.step()

        # Update metrics
        train_metrics['total_loss'] += loss.item()
        for k, v in metrics.items():
            train_metrics[k] += v

    # Validation phase
    vae.eval()
    val_metrics = {'total_loss': 0, 'node_loss': 0, 'edge_loss': 0, 'kl_loss': 0}

    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(device)
            node_pred, edge_pred, mu, logvar = vae(batch.x, batch.edge_index)
            loss, metrics = loss_function(
                node_pred, edge_pred,
                batch.x, batch.edge_attr,
                mu, logvar
            )

            val_metrics['total_loss'] += loss.item()
            for k, v in metrics.items():
                val_metrics[k] += v

    print(f"\nEpoch {epoch+1}/{n_epochs}")
    print("Training metrics:")
    for k, v in train_metrics.items():
        print(f"{k}: {v/len(train_loader):.4f}")
    print("\nValidation metrics:")
    for k, v in val_metrics.items():
        print(f"{k}: {v/len(val_loader):.4f}")

    # Early stopping
    val_loss = val_metrics['total_loss'] / len(val_loader)
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        torch.save(vae.state_dict(), 'best_model.pt')
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("\nEarly stopping triggered!")
            break

Using device: cuda


RuntimeError: mat1 and mat2 shapes cannot be multiplied (4746x64 and 32x4)

In [None]:
# Load best model and evaluate on test set
vae.load_state_dict(torch.load('best_model.pt'))
vae.eval()
test_metrics = {'total_loss': 0, 'node_loss': 0, 'edge_loss': 0, 'kl_loss': 0}

with torch.no_grad():
    for batch in test_loader:
        batch = batch.to(device)
        node_pred, edge_pred, mu, logvar = vae(batch.x, batch.edge_index)
        loss, metrics = loss_function(
            node_pred, edge_pred,
            batch.x, batch.edge_attr,
            mu, logvar
        )

        test_metrics['total_loss'] += loss.item()
        for k, v in metrics.items():
            test_metrics[k] += v

print("\nTest metrics:")
for k, v in test_metrics.items():
    print(f"{k}: {v/len(test_loader):.4f}")

In [None]:
# Generate a random latent vector
latent_sample = torch.randn(1, 32).to(device)

# Decode the latent vector to get molecule features
with torch.no_grad():
    generated_features = vae.decoder(latent_sample)

# Print the generated features
print("Generated molecule features:", generated_features.shape)
print(generated_features)

# Note: This generates the feature vector for a molecule.
# To convert this into a proper molecular structure, you would need additional
# post-processing steps to convert the features back into a valid molecular graph

Generated molecule features: torch.Size([1, 11])
tensor([[ 9.7741e-01,  8.8419e-03,  8.6682e-03, -8.1913e-03, -2.4448e-04,
          1.1425e+00,  6.6321e-03, -3.7067e-04, -1.6801e-04, -3.3279e-05,
         -3.5893e-02]], device='cuda:0')
