# A Simple Variational Autoencoder

Here we explore using a simple variational autoencoder trained on the QM9 dataset to generate new molecules.

The goal of this notebook is to provide a proof of concept that you can use a VAE to generate new molecules. Some physical restrictions will also be enforced. The model will also be saved for easy reloading in later notebooks.

TODO: 
- [ ] add brief theory summary
- [ ] move what makes sense to mygenai src
- [ ] create nn configuration section at top for better maintainability
- [ ] use proper logging instead of print statements
- [ ] make the model take edge_index optionally and construct a full graph if not given
- [ ] have the model predict the number of nodes (atoms)

In [2]:
import torch
import torch.nn as nn
from torch_geometric.utils import to_dense_batch
import torch_geometric as pyg
import torch_geometric.nn as pyg_nn
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
import mygenai
from sklearn.model_selection import train_test_split
import numpy as np

dataset = QM9(root="../data/QM9")

In [3]:
# check if CUDA is available
print(torch.cuda.is_available())
print(torch.cuda.device_count())
print(torch.cuda.get_device_name(0))

True
1
NVIDIA GeForce RTX 3050 Ti Laptop GPU


In [4]:
class GraphVAE(nn.Module):
    def __init__(self, in_channels, hidden_dim, latent_dim):
        super(GraphVAE, self).__init__()
        self.in_channels = in_channels  # Store input feature dimension

        self.encoder = pyg_nn.GCNConv(in_channels, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)  # Mean
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)  # Log variance

        self.node_decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, in_channels)
        )

        # Edge Decoder (for adjacency matrix reconstruction, must handle concatenated node embedding)
        self.edge_decoder = nn.Sequential(
            nn.Linear(2 * latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, dataset.num_edge_features)
        )
        # self.edge_decoder = nn.Linear(latent_dim, dataset.num_edge_features)

    # def generate_complete_graph(self, num_nodes):
    #     # Generate a complete graph with the given number of nodes
    #     edge_index = torch.combinations(torch.arange(num_nodes), r=2).t().contiguous()
    #     edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=-1) # add reverse edges for undirected graph
    #     return edge_index.to(self.device) # TODO this doesn't work, no device set

    def decode(self, z, edge_index=None):
        # Decode node features
        node_pred = self.node_decoder(z)

        # QM9 edges represent interactions between pairs of atoms, so need two nodes per edge
        edge_features = torch.cat([z[edge_index[0]], z[edge_index[1]]], dim=-1)
        edge_pred = self.edge_decoder(edge_features)
        return node_pred, edge_pred

    def forward(self, x, edge_index):
        # Encode
        h = self.encoder(x, edge_index)
        mu, logvar = self.fc_mu(h), self.fc_logvar(h)
        # latent space sampling with reparametrisation trick
        z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)

        node_pred, edge_pred = self.decode(z, edge_index)

        return node_pred, edge_pred, mu, logvar

In [5]:
def print_shapes(batch, node_pred, edge_pred, mu, logvar):
    """Print the shapes of the batch and predictions for debugging."""
    print("\nShape Information:")
    print(f"Batch features (batch.x): {batch.x.shape}")
    print(f"Batch edge index (batch.edge_index): {batch.edge_index.shape}")
    print(f"Batch edge attributes (batch.edge_attr): {batch.edge_attr.shape if hasattr(batch, 'edge_attr') else 'None'}")
    print(f"Node predictions: {node_pred.shape}")
    print(f"Edge predictions: {edge_pred.shape}")
    print(f"Mu: {mu.shape}")
    print(f"Logvar: {logvar.shape}")

In [6]:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Move model to GPU
vae = GraphVAE(in_channels=dataset.num_features, hidden_dim=64, latent_dim=32).to(device)


Using device: cuda


In [7]:
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)

def loss_function(node_pred, edge_pred, node_true, edge_true, mu, logvar):
    node_loss = torch.nn.functional.cross_entropy(node_pred, node_true, reduction='sum')
    edge_loss = torch.nn.functional.binary_cross_entropy_with_logits(
        edge_pred, edge_true, reduction='sum'
    )
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    total_loss = node_loss + edge_loss + kl_loss

    return total_loss, {
        'node_loss': node_loss.item(),
        'edge_loss': edge_loss.item(),
        'kl_loss': kl_loss.item()
    }

# Data splitting (60/20/20)
train_val_idx, test_idx = train_test_split(
    np.arange(len(dataset)),
    test_size=0.2,
    random_state=42
)
train_idx, val_idx = train_test_split(
    train_val_idx,
    test_size=0.25,
    random_state=42
)

train_loader = DataLoader(dataset[train_idx], batch_size=128, shuffle=True)
val_loader = DataLoader(dataset[val_idx], batch_size=128, shuffle=False)
test_loader = DataLoader(dataset[test_idx], batch_size=128, shuffle=False)

# Training loop with validation
try:
    vae.load_state_dict(torch.load('best_basic_model.pt'))
    print("Loaded existing model from best_basic_model.pt")
except FileNotFoundError:
    print("No existing model found, starting training...")
    best_val_loss = float('inf')
    patience = 5
    patience_counter = 0
    n_epochs = 50

    for epoch in range(n_epochs):
        # Training phase
        vae.train()
        train_metrics = {'total_loss': 0, 'node_loss': 0, 'edge_loss': 0, 'kl_loss': 0}

        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()

            # Forward pass
            node_pred, edge_pred, mu, logvar = vae(batch.x, batch.edge_index)
            # print_shapes(batch, node_pred, edge_pred, mu, logvar)

            # Calculate loss
            loss, metrics = loss_function(
                node_pred, edge_pred,
                batch.x, batch.edge_attr,
                mu, logvar
            )

            # Backward pass
            loss.backward()
            optimizer.step()

            # Update metrics
            train_metrics['total_loss'] += loss.item()
            for k, v in metrics.items():
                train_metrics[k] += v

        # Validation phase
        vae.eval()
        val_metrics = {'total_loss': 0, 'node_loss': 0, 'edge_loss': 0, 'kl_loss': 0}

        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(device)
                node_pred, edge_pred, mu, logvar = vae(batch.x, batch.edge_index)
                loss, metrics = loss_function(
                    node_pred, edge_pred,
                    batch.x, batch.edge_attr,
                    mu, logvar
                )

                val_metrics['total_loss'] += loss.item()
                for k, v in metrics.items():
                    val_metrics[k] += v

        print(f"\nEpoch {epoch+1}/{n_epochs}")
        print("Training metrics:")
        for k, v in train_metrics.items():
            print(f"{k}: {v/len(train_loader):.4f}")
        print("\nValidation metrics:")
        for k, v in val_metrics.items():
            print(f"{k}: {v/len(val_loader):.4f}")

        # Early stopping
        val_loss = val_metrics['total_loss'] / len(val_loader)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save(vae.state_dict(), 'best_basic_model.pt')
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("\nEarly stopping triggered!")
                break

Loaded existing model from best_basic_model.pt


In [13]:
# evaluate model on test set
vae.eval()
test_metrics = {'total_loss': 0, 'node_loss': 0, 'edge_loss': 0, 'kl_loss': 0}

with torch.no_grad():
    for batch in test_loader:
        batch = batch.to(device)
        node_pred, edge_pred, mu, logvar = vae(batch.x, batch.edge_index)
        loss, metrics = loss_function(
            node_pred, edge_pred,
            batch.x, batch.edge_attr,
            mu, logvar
        )

        test_metrics['total_loss'] += loss.item()
        for k, v in metrics.items():
            test_metrics[k] += v

print("\nTest metrics:")
for k, v in test_metrics.items():
    print(f"{k}: {v/len(test_loader):.4f}")


Test metrics:
total_loss: 13272.6769
node_loss: 9239.2612
edge_loss: 1599.3260
kl_loss: 2434.0897


In [None]:
from torch_geometric.data import Data

def generate_molecule(vae, num_nodes=9, device='cuda'):
    """Generate a random molecule with fixed number of nodes."""
    # Sample from latent space
    latent_sample = torch.randn(1, 32).to(device)
    print(f"Latent sample shape: {latent_sample.shape}")

    # Expand latent vector for each node
    z = latent_sample.repeat(num_nodes, 1)
    print(f"Expanded z shape: {z.shape}")

    # Create complete graph
    edge_index = torch.combinations(torch.arange(num_nodes), r=2).t().contiguous()
    edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=-1).to(device)
    print(f"Edge index shape: {edge_index.shape}")

    # Generate features
    with torch.no_grad():
        node_features, edge_features = vae.decode(z, edge_index)
        print(f"Node features shape: {node_features.shape}")
        print(f"Edge features shape: {edge_features.shape}")

        # Convert to probabilities
        node_probs = torch.softmax(node_features, dim=-1)
        edge_probs = torch.sigmoid(edge_features)
        print(f"Edge probs shape: {edge_probs.shape}")

        # Get discrete structure
        node_types = torch.argmax(node_probs, dim=-1)

        # Take maximum probability across edge features
        # This assumes the edge with highest probability determines existence
        edge_exists = edge_probs.max(dim=1).values > 0.5
        print(f"Edge existence mask shape: {edge_exists.shape}")

        # Select edges based on mask
        final_edge_index = edge_index[:, edge_exists]

    # Create PyG Data object
    molecule = Data(
        x=node_types,
        edge_index=final_edge_index,
        edge_attr=edge_features[edge_exists]  # Include edge features for existing edges
    )

    return molecule

# Generate and inspect molecule
generated_mol = generate_molecule(vae, num_nodes=9)
print(f"Generated molecule with {generated_mol.num_nodes} nodes and {generated_mol.num_edges} edges")
print(f"Node feature shape: {generated_mol.x.shape}")
print(f"Edge index shape: {generated_mol.edge_index.shape}")

Latent sample shape: torch.Size([1, 32])
Expanded z shape: torch.Size([9, 32])
Edge index shape: torch.Size([2, 72])
Node features shape: torch.Size([9, 11])
Edge features shape: torch.Size([72, 4])
Edge probs shape: torch.Size([72, 4])
Edge mask shape: torch.Size([72, 4])


IndexError: too many indices for tensor of dimension 2

tensor([[  6.7315,  -3.1818,  -2.9390,  -1.9929,  -8.6633,   6.7304, -18.2955,
         -18.0176, -18.5171, -17.8618,  -1.4803]], device='cuda:0')
