# A Simple Variational Autoencoder

Here we explore using a simple variational autoencoder trained on the QM9 dataset to generate new molecules.

The goal of this notebook is to provide a proof of concept that a variational autoencoder (VAE) can be used to generate new molecules. Some physical restrictions will also be enforced, and the trained model will be saved for easy reloading in later notebooks.

In [None]:
import torch
import torch.nn as nn
from torch_geometric.utils import to_dense_batch
import torch_geometric as pyg
import torch_geometric.nn as pyg_nn
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
import mygenai

# Load the QM9 dataset through torch_geometric, using ../data/QM9
# (relative to the notebook's working directory) as its root folder.
dataset = QM9(root="../data/QM9")

In [10]:
# Check if CUDA is available and report what was found.
print(torch.cuda.is_available())
print(torch.cuda.device_count())
# get_device_name(0) raises when no CUDA device is present, so only query it
# when a GPU exists; otherwise the notebook would stop here on CPU-only machines.
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
else:
    print("No CUDA device found; training will fall back to the CPU")

True
1
NVIDIA GeForce RTX 3050 Ti Laptop GPU


In [8]:
class GraphVAE(nn.Module):
    """Variational autoencoder built around a single graph-convolution encoder.

    The encoder output is mapped by two linear heads to the mean and
    log-variance of a diagonal Gaussian; a latent vector is sampled from it
    and passed through a small MLP that maps back to the input feature size.
    """

    def __init__(self, in_channels, hidden_dim, latent_dim):
        """Create the encoder, the two latent heads and the decoder.

        Args:
            in_channels: Size of the input feature vectors; also the size of
                the decoder's output.
            hidden_dim: Output width of the GCN layer and width of the
                decoder's hidden layer.
            latent_dim: Size of the latent vector.
        """
        super().__init__()

        # Encoder: one GCN layer feeding two parallel linear heads.
        self.encoder = pyg_nn.GCNConv(in_channels, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)      # mean head
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)  # log-variance head

        # Decoder: latent -> hidden -> input feature size.
        decoder_layers = [
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, in_channels),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x, edge_index):
        """Encode, sample a latent vector, and decode.

        Args:
            x: Feature matrix handed to the GCN layer.
            edge_index: Graph connectivity handed to the GCN layer.

        Returns:
            A tuple ``(recon_x, mu, logvar)``: the decoder output and the
            mean / log-variance produced by the two latent heads.
        """
        hidden = self.encoder(x, edge_index)
        mu = self.fc_mu(hidden)
        logvar = self.fc_logvar(hidden)

        # Reparameterisation: z = mu + sigma * eps, with eps ~ N(0, I).
        std = torch.exp(0.5 * logvar)
        noise = torch.randn_like(mu)
        z = mu + std * noise

        return self.decoder(z), mu, logvar


In [None]:
# Seed torch's global RNG so weight initialisation, batch shuffling and the
# latent noise start from the same state on every Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Shuffled mini-batches of 128 graphs
loader = DataLoader(dataset, batch_size=128, shuffle=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Build the model on the selected device (GPU when available, otherwise CPU)
vae = GraphVAE(in_channels=dataset.num_features, hidden_dim=64, latent_dim=32).to(device)

optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)

def loss_function(recon_x, x, mu, logvar):
    """Return the VAE training loss: reconstruction error plus KL divergence.

    Args:
        recon_x: Decoder output.
        x: Target tensor the reconstruction is compared against.
        mu: Mean of the approximate posterior.
        logvar: Log-variance of the approximate posterior.

    Returns:
        A scalar tensor: summed squared error between ``recon_x`` and ``x``
        plus the KL term summed over every element of ``mu`` / ``logvar``.
    """
    reconstruction = torch.nn.functional.mse_loss(recon_x, x, reduction="sum")

    # Closed-form KL divergence between N(mu, exp(logvar)) and N(0, 1),
    # accumulated over all elements.
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl = -0.5 * kl_terms.sum()

    return reconstruction + kl

# Training Loop
# Makes 20 passes over `loader`, taking one optimizer (Adam) step per mini-batch.
for epoch in range(20):
    # Running sum of the batch losses seen in this epoch
    loss_total = 0
    for batch in loader:
        # Move batch data to the selected device (GPU when available, else CPU)
        batch = batch.to(device)

        # Clear gradients accumulated by the previous step
        optimizer.zero_grad()
        recon_x, mu, logvar = vae(batch.x, batch.edge_index)
        loss = loss_function(recon_x, batch.x, mu, logvar)
        loss.backward()
        optimizer.step()
        # .item() converts the scalar loss tensor to a Python float
        loss_total += loss.item()
    # Reported value is the sum of the batch losses over the epoch, not an average
    print(f"Epoch {epoch+1}, Loss: {loss_total}")

Using device: cuda
Epoch 1, Loss: 814.581787109375
Epoch 2, Loss: 714.4912109375
Epoch 3, Loss: 606.0233154296875


KeyboardInterrupt: 