# A Simple Variational Autoencoder

Here we explore using a simple variational autoencoder trained on the QM9 dataset to generate new molecules.

The goal of this notebook is to provide a proof of concept that a VAE can be used to generate new molecules. Some physical restrictions will also be enforced, and the trained model will be saved so that it can be easily reloaded in later notebooks.

In [None]:
import torch
import torch.nn as nn
from torch_geometric.utils import to_dense_batch
import torch_geometric as pyg
import torch_geometric.nn as pyg_nn
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
import mygenai

# Load the QM9 dataset via torch_geometric, using ../data/QM9 as its root directory.
dataset = QM9(root="../data/QM9")

In [10]:
# check if CUDA is available
cuda_available = torch.cuda.is_available()
print(cuda_available)
print(torch.cuda.device_count())
# Only ask for the device name when a GPU is actually present:
# get_device_name(0) raises on a CPU-only machine, which would abort a
# "Restart & Run All" here even though training can fall back to the CPU.
if cuda_available:
    print(torch.cuda.get_device_name(0))
else:
    print("No CUDA device available")

True
1
NVIDIA GeForce RTX 3050 Ti Laptop GPU


In [None]:
# TODO the decoder currently only has node features (atoms),
#   but we also need edge features (bonds) to reconstruct the graph
class GraphVAE(nn.Module):
    """Variational autoencoder over the node (atom) features of a graph.

    A single ``GCNConv`` layer turns the input features into a hidden
    representation, from which two linear heads predict the mean and the
    log-variance of a diagonal Gaussian in latent space. A small MLP decoder
    maps a latent sample back to ``in_channels`` features. Only node features
    are reconstructed; edges (bonds) are not decoded.

    Args:
        in_channels: Size of the input feature vectors (also the size of the
            decoder output).
        hidden_dim: Width of the hidden representation used by both the
            encoder and the decoder.
        latent_dim: Size of the latent vectors.
    """

    def __init__(self, in_channels, hidden_dim, latent_dim):
        super().__init__()

        # Encoder: graph convolution followed by the two Gaussian heads.
        self.encoder = pyg_nn.GCNConv(in_channels, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

        # Decoder: latent -> hidden -> reconstructed features.
        decoder_layers = [
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, in_channels),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def _encode(self, x, edge_index):
        """Return the latent mean and log-variance for the given graph input."""
        hidden = self.encoder(x, edge_index)
        return self.fc_mu(hidden), self.fc_logvar(hidden)

    @staticmethod
    def _reparameterize(mu, logvar):
        """Draw ``mu + std * eps`` with ``eps ~ N(0, I)`` and ``std = exp(logvar / 2)``."""
        std = (0.5 * logvar).exp()
        noise = torch.randn_like(mu)
        return mu + std * noise

    def forward(self, x, edge_index):
        """Encode, sample a latent vector, and decode it.

        Returns:
            A tuple ``(recon_x, mu, logvar)``: the decoder output for the
            sampled latent, and the latent mean and log-variance.
        """
        mu, logvar = self._encode(x, edge_index)
        latent = self._reparameterize(mu, logvar)
        return self.decoder(latent), mu, logvar


In [None]:
# Seed PyTorch's random number generators before anything stochastic happens
# (weight initialisation, batch shuffling, the VAE's sampling noise) so that
# repeated runs start from the same state. Some GPU kernels may still be
# non-deterministic, so results can differ slightly between runs.
SEED = 42
torch.manual_seed(SEED)

# Hyper-parameters, named in one place instead of being scattered as literals.
BATCH_SIZE = 128
HIDDEN_DIM = 64
LATENT_DIM = 32
LEARNING_RATE = 1e-3

loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Build the model and move it to the selected device (GPU when available)
vae = GraphVAE(in_channels=dataset.num_features, hidden_dim=HIDDEN_DIM, latent_dim=LATENT_DIM).to(device)

optimizer = torch.optim.Adam(vae.parameters(), lr=LEARNING_RATE)

def loss_function(recon_x, x, mu, logvar, beta=1.0):
    """Compute the VAE training loss: reconstruction error plus weighted KL term.

    Args:
        recon_x: Reconstructed features produced by the decoder.
        x: Target features, same shape as ``recon_x``.
        mu: Mean of the approximate posterior.
        logvar: Log-variance of the approximate posterior, same shape as ``mu``.
        beta: Weight applied to the KL term. The default of 1.0 reproduces the
            plain (unweighted) sum of both terms.

    Returns:
        A scalar tensor. Both terms are summed (not averaged) over all
        elements, so the value scales with the number of elements passed in.
    """
    # Squared error summed over every element of the reconstruction.
    mse_loss = torch.nn.functional.mse_loss(recon_x, x, reduction="sum")
    # Closed-form KL divergence between N(mu, exp(logvar)) and N(0, I),
    # summed over every latent element.
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return mse_loss + beta * kl_divergence

# TODO: training, test and validation splits

# Training Loop
# Makes 20 passes over `loader`. After each pass it prints the sum of the
# per-batch losses, i.e. a total over the epoch rather than an average.
for epoch in range(20):
    loss_total = 0
    for batch in loader:
        # Move batch data to the selected device (GPU when available, otherwise CPU)
        batch = batch.to(device)

        # Clear gradients from the previous step before the new backward pass
        optimizer.zero_grad()
        recon_x, mu, logvar = vae(batch.x, batch.edge_index)
        loss = loss_function(recon_x, batch.x, mu, logvar)
        loss.backward()
        optimizer.step()
        # .item() turns the scalar loss tensor into a Python number for accumulation
        loss_total += loss.item()
    print(f"Epoch {epoch+1}, Loss: {loss_total}")

Using device: cuda
Epoch 1, Loss: 12982132.532470703
Epoch 2, Loss: 6653357.563659668
Epoch 3, Loss: 5738879.426879883
Epoch 4, Loss: 5044610.07434082
Epoch 5, Loss: 4919992.088256836
Epoch 6, Loss: 4884896.159057617
Epoch 7, Loss: 4875195.200317383
Epoch 8, Loss: 4868426.782836914
Epoch 9, Loss: 4862542.235900879
Epoch 10, Loss: 4864959.352783203
Epoch 11, Loss: 4861159.000854492
Epoch 12, Loss: 4856974.647644043
Epoch 13, Loss: 4855976.251159668
Epoch 14, Loss: 4855054.975708008
Epoch 15, Loss: 4851494.086303711
Epoch 16, Loss: 4856495.376403809
Epoch 17, Loss: 4849909.075866699
Epoch 18, Loss: 4855618.200927734
Epoch 19, Loss: 4847263.398742676
Epoch 20, Loss: 4856062.029968262


In [None]:
# Generate a random latent vector
# The latent size is read from the model's mean head rather than hard-coded,
# so this cell keeps working if latent_dim is changed where the model is built.
latent_dim = vae.fc_mu.out_features
latent_sample = torch.randn(1, latent_dim, device=device)

# Decode the latent vector to get the generated features
with torch.no_grad():
    generated_features = vae.decoder(latent_sample)

# Print the generated features
print("Generated molecule features:", generated_features.shape)
print(generated_features)

# Note: the decoder outputs node (atom) features, so a single latent sample
# yields one feature vector, not a complete molecule.
# To convert this into a proper molecular structure, you would need additional
# post-processing steps to convert the features back into a valid molecular graph

Generated molecule features: torch.Size([1, 11])
tensor([[ 9.7741e-01,  8.8419e-03,  8.6682e-03, -8.1913e-03, -2.4448e-04,
          1.1425e+00,  6.6321e-03, -3.7067e-04, -1.6801e-04, -3.3279e-05,
         -3.5893e-02]], device='cuda:0')
