# Message-Passing Graph Variational Autoencoder

Here I will use principles from message-passing graph neural networks to try to generate
molecules in a graph variational autoencoder. This is based on my solutions to 
`geometric-gnn-dojo/geometric_gnn_101.ipynb`
Some of the code has also been taken from there

might also want to consider autoregressive model

In [1]:
#@title [RUN] Import python modules

import logging

import os
import time
import random
import numpy as np

from scipy.stats import ortho_group

import torch
import torch.nn.functional as F
from torch.nn import Linear, ReLU, BatchNorm1d, Module, Sequential, Sigmoid

import torch_geometric
from torch_geometric.data import Data
from torch_geometric.data import Batch
from torch_geometric.datasets import QM9
import torch_geometric.transforms as T
from torch_geometric.utils import remove_self_loops, to_dense_adj, dense_to_sparse
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MessagePassing, global_mean_pool
from torch_geometric.datasets import QM9
from torch_scatter import scatter

import rdkit.Chem as Chem
from rdkit.Geometry.rdGeometry import Point3D
from rdkit.Chem import QED, Crippen, rdMolDescriptors, rdmolops
from rdkit.Chem.Draw import IPythonConsole

import py3Dmol
from rdkit.Chem import AllChem

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# from google.colab import files
from IPython.display import HTML

print("PyTorch version {}".format(torch.__version__))
print("PyG version {}".format(torch_geometric.__version__))

PyTorch version 2.5.0+cu124
PyG version 2.6.1


In [2]:
debug_print = False
if debug_print:
    print("Debug mode is ON")
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        force=True  # Ensure configuration is applied
    )
    for logger_name in ['train_epoch', 'PropertyConditionedVAE', 'ConditionalDecoder', 'Encoder']:
        logging.getLogger(logger_name).setLevel(logging.DEBUG)
else:
    print("Debug mode is OFF")
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        force=True  # Ensure configuration is applied
    )
    for logger_name in ['train_epoch', 'PropertyConditionedVAE', 'ConditionalDecoder', 'Encoder']:
        logging.getLogger(logger_name).setLevel(logging.INFO)

Debug mode is OFF


In [3]:
class CompleteGraph(object):
    """
    This transform adds all pairwise edges into the edge index per data sample,
    then removes self loops, i.e. it builds a fully connected or complete graph
    """
    def __call__(self, data):
        device = data.edge_index.device

        row = torch.arange(data.num_nodes, dtype=torch.long, device=device)
        col = torch.arange(data.num_nodes, dtype=torch.long, device=device)

        row = row.view(-1, 1).repeat(1, data.num_nodes).view(-1)
        col = col.repeat(data.num_nodes)
        edge_index = torch.stack([row, col], dim=0)

        edge_attr = None
        if data.edge_attr is not None:
            idx = data.edge_index[0] * data.num_nodes + data.edge_index[1]
            size = list(data.edge_attr.size())
            size[0] = data.num_nodes * data.num_nodes
            edge_attr = data.edge_attr.new_zeros(size)
            edge_attr[idx] = data.edge_attr

        edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
        data.edge_attr = edge_attr
        data.edge_index = edge_index

        return data

In [4]:
import torch
import torch.nn as nn
from torch_geometric.utils import to_dense_batch
import torch_geometric as pyg
import torch_geometric.nn as pyg_nn
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
import mygenai
from sklearn.model_selection import train_test_split
import numpy as np

dataset = QM9(root="../data/QM9", transform=CompleteGraph())
# Normalize targets per data sample to mean = 0 and std = 1.
mean = dataset.data.y.mean(dim=0, keepdim=True)
std = dataset.data.y.std(dim=0, keepdim=True)
dataset.data.y = (dataset.data.y - mean) / std
# mean, std = mean[:, target].item(), std[:, target].item()



## First, some review
This is the non-generative model that simply predicts properties based on input 

In [5]:
class EquivariantMPNNLayer(MessagePassing):
    def __init__(self, emb_dim=64, edge_dim=4, aggr='add'):
        """Message Passing Neural Network Layer

        This layer is equivariant to 3D rotations and translations.

        Args:
            emb_dim: (int) - hidden dimension `d`
            edge_dim: (int) - edge feature dimension `d_e`
            aggr: (str) - aggregation function `\oplus` (sum/mean/max)
        """
        # Set the aggregation function
        super().__init__(aggr=aggr)

        self.emb_dim = emb_dim
        self.edge_dim = edge_dim

        self.mlp_scalar = Sequential(
            Linear(2*emb_dim + edge_dim + 1, emb_dim),  # +1 for distance
            BatchNorm1d(emb_dim),
            ReLU()
        )
        self.mlp_vector = Sequential(
            Linear(2*emb_dim + edge_dim + 1, 1),  # Input: [h_i, h_j, edge_attr, dist]
            BatchNorm1d(1),
            ReLU()
        )

        # update MLPs
        self.mlp_h = Sequential(
            Linear(2*emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU(),
            Linear(emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU()
        )
        self.mlp_pos = Sequential(
            Linear(emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU(),
            Linear(emb_dim, 1), BatchNorm1d(1), ReLU()
        )

    def forward(self, h, pos, edge_index, edge_attr):
        """
        The forward pass updates node features `h` via one round of message passing.

        Args:
            h: (n, d) - initial node features
            pos: (n, 3) - initial node coordinates
            edge_index: (e, 2) - pairs of edges (i, j)
            edge_attr: (e, d_e) - edge features

        Returns:
            out: [(n, d),(n,3)] - updated node features
        """
        return self.propagate(edge_index, h=h, pos=pos, edge_attr=edge_attr)

    def message(self, h_i, h_j, pos_i, pos_j, edge_attr):
        r_ij = pos_j - pos_i # equivariant
        dist = torch.norm(r_ij, dim=-1, keepdim=True)  # invariant

        # Scalar message (invariant features)
        scalar_inputs = torch.cat([h_i, h_j, edge_attr, dist], dim=-1)
        scalar_msg = self.mlp_scalar(scalar_inputs)

        # Vector message (equivariant coordinates)
        # vector message should only depend on rotation/translation invariant quantities
        vector_inputs = torch.cat([h_i, h_j, edge_attr, dist], dim=-1)
        scale = self.mlp_vector(vector_inputs) # (e, 1)
        vector_msg = scale * r_ij # (e, 3)

        return scalar_msg, vector_msg

    def aggregate(self, inputs, index):
        scalar_msgs, vector_msgs = inputs
        scalar_aggr = scatter(scalar_msgs, index, dim=self.node_dim, reduce=self.aggr)
        vector_aggr = scatter(vector_msgs, index, dim=self.node_dim, reduce=self.aggr)
        return scalar_aggr, vector_aggr

    def update(self, aggr_out, h, pos):
        scalar_aggr, vector_aggr = aggr_out

        # Update node features (h)
        h_update = self.mlp_h(torch.cat([h, scalar_aggr], dim=-1))

        # Update node positions (pos)
        scale = self.mlp_pos(scalar_aggr) # (n, 1)
        pos_update = pos + scale * vector_aggr  # (n, 3)

        return h_update, pos_update

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})')


class EquivariantGNNPredictor(Module):
    def __init__(self, num_layers=4, emb_dim=64, in_dim=11, edge_dim=4, out_dim=1):
        """Message Passing Neural Network model for graph property prediction

        This model uses both node features and coordinates as inputs, and
        is invariant to 3D rotations and translations (the constituent MPNN layers
        are equivariant to 3D rotations and translations).

        Args:
            num_layers: (int) - number of message passing layers `L`
            emb_dim: (int) - hidden dimension `d`
            in_dim: (int) - initial node feature dimension `d_n`
            edge_dim: (int) - edge feature dimension `d_e`
            out_dim: (int) - output dimension (fixed to 1)
        """
        super().__init__()

        # Linear projection for initial node features
        # dim: d_n -> d
        self.lin_in = Linear(in_dim, emb_dim)

        # Stack of MPNN layers
        self.convs = torch.nn.ModuleList()
        for layer in range(num_layers):
            self.convs.append(EquivariantMPNNLayer(emb_dim, edge_dim, aggr='add'))

        # Global pooling/readout function `R` (mean pooling)
        # PyG handles the underlying logic via `global_mean_pool()`
        self.pool = global_mean_pool

        # Linear prediction head
        # dim: d -> out_dim
        self.lin_pred = Linear(emb_dim, out_dim)

    def forward(self, data):
        """
        Args:
            data: (PyG.Data) - batch of PyG graphs

        Returns:
            out: (batch_size, out_dim) - prediction for each graph
        """
        h = self.lin_in(data.x) # (n, d_n) -> (n, d)
        pos = data.pos

        for conv in self.convs:
            # Message passing layer
            h_update, pos_update = conv(h, pos, data.edge_index, data.edge_attr)

            # Update node features
            h = h + h_update # (n, d) -> (n, d)
            # Note that we add a residual connection after each MPNN layer

            # Update node coordinates
            pos = pos_update # (n, 3) -> (n, 3)

        h_graph = self.pool(h, data.batch) # (n, d) -> (batch_size, d)

        out = self.lin_pred(h_graph) # (batch_size, d) -> (batch_size, 1)

        return out.view(-1)

  """Message Passing Neural Network Layer


In [6]:
model = EquivariantGNNPredictor(num_layers=4, emb_dim=64, in_dim=11, edge_dim=4, out_dim=1)

In [None]:
# unit test
def random_orthogonal_matrix(dim=3):
  """Helper function to build a random orthogonal matrix of shape (dim, dim)
  """
  Q = torch.tensor(ortho_group.rvs(dim=dim)).float()
  return Q


def rot_trans_invariance_unit_test(module, dataloader):
    """Unit test for checking whether a module (GNN model/layer) is
    rotation and translation invariant.
    """
    it = iter(dataloader)
    data = next(it)

    # Forward pass on original example
    # Note: We have written a conditional forward pass so that the same unit
    #       test can be used for both the GNN model as well as the layer.
    #       The functionality for layers will be useful subsequently.
    # if isinstance(module, MPNNModel):
    out_1 = module(data)
    # else: # if ininstance(module, MessagePassing):
        # out_1 = module(data.x, data.pos, data.edge_index, data.edge_attr)

    Q = random_orthogonal_matrix(dim=3)
    t = torch.rand(3)
    # ============ YOUR CODE HERE ==============
    # Perform random rotation + translation on data.
    #
    data.pos = data.pos @ Q  + t
    # ==========================================

    # Forward pass on rotated + translated example
    # if isinstance(module, MPNNModel):
    out_2 = module(data)
    # else: # if ininstance(module, MessagePassing):
        # out_2 = module(data.x, data.pos, data.edge_index, data.edge_attr)

    # ============ YOUR CODE HERE ==============
    # Check whether output varies after applying transformations.
    #
    return torch.allclose(out_1, out_2, atol=1e-04)
    # ==========================================

def rot_trans_equivariance_unit_test(module, dataloader):
    """Unit test for checking whether a module (GNN layer) is
    rotation and translation equivariant.
    """
    it = iter(dataloader)
    data = next(it)

    out_1, pos_1 = module(data.x, data.pos, data.edge_index, data.edge_attr)

    Q = random_orthogonal_matrix(dim=3)
    t = torch.rand(3)
    # ============ YOUR CODE HERE ==============
    # Perform random rotation + translation on data.
    #
    data.pos = data.pos @ Q + t # row vectors => post-multiply Q
    # ==========================================

    # Forward pass on rotated + translated example
    out_2, pos_2 = module(data.x, data.pos, data.edge_index, data.edge_attr)

    # ============ YOUR CODE HERE ==============
    # Check whether output varies after applying transformations.
    return torch.allclose(out_1, out_2, atol=1e-04) and torch.allclose(pos_1 @ Q + t, pos_2, atol=1e-04)
    # ==========================================

In [8]:
# ============ YOUR CODE HERE ==============
# Instantiate temporary model, layer, and dataloader for unit testing.
# Remember that we are now unit testing the EquivariantGNNPredictor,
# which is  composed of the EquivariantMPNNLayer.
#
layer = EquivariantMPNNLayer(emb_dim=11, edge_dim=4)
model = EquivariantGNNPredictor(num_layers=4, emb_dim=64, in_dim=11, edge_dim=4, out_dim=1)
# ==========================================
dataloader = DataLoader(dataset[:1000], batch_size=1, shuffle=True)

# Rotation and translation invariance unit test for MPNN model
print(f"Is {type(model).__name__} rotation and translation invariant? --> {rot_trans_invariance_unit_test(model, dataloader)}!")

# Rotation and translation invariance unit test for MPNN layer
print(f"Is {type(layer).__name__} rotation and translation equivariant? --> {rot_trans_equivariance_unit_test(layer, dataloader)}!")

Is EquivariantGNNPredictor rotation and translation invariant? --> True!
Is EquivariantMPNNLayer rotation and translation equivariant? --> True!


In [None]:
# now trying an encoder-decoder architecture
class Encoder(Module):
    def __init__(self, emb_dim=64, in_dim=11, edge_dim=4, latent_dim=32):
        """Encoder module for graph property prediction

        Args:
            emb_dim: (int) - hidden dimension `d`
            in_dim: (int) - initial node feature dimension `d_n`
            edge_dim: (int) - edge feature dimension `d_e`
        """
        super().__init__()

        # Linear projection for initial node features
        # dim: d_n -> d
        self.lin_in = Linear(in_dim, emb_dim)

        # Stack of MPNN layers
        self.convs = torch.nn.ModuleList()
        for layer in range(2):
            self.convs.append(EquivariantMPNNLayer(emb_dim, edge_dim, aggr='add'))

        # Global pooling/readout function `R` (mean pooling)
        # PyG handles the underlying logic via `global_mean_pool()`
        self.pool = global_mean_pool

        # projections to latent space
        self.mu = Linear(emb_dim, latent_dim)
        self.log_var = Linear(emb_dim, latent_dim)

        # Property prediction (only one: homo-lumo gap)
        self.property_predictor = Sequential(
            Linear(latent_dim, emb_dim),
            ReLU(),
            BatchNorm1d(emb_dim),
            Linear(emb_dim, 1)
        )

    def forward(self, data):
        """
        Args:
            data: (PyG.Data) - batch of PyG graphs

        Returns:
            out: [(batch_size, d),(batch_size,3)] - updated node features
        """
        h = self.lin_in(data.x) # (n, d_n) -> (n, d)
        pos = data.pos

        for conv in self.convs:
            # Message passing layer
            h_update, pos_update = conv(h, pos, data.edge_index, data.edge_attr)

            # Update node features
            h = h + h_update # (n, d) -> (n, d)

            # Update node coordinates
            pos = pos_update # (n, 3) -> (n, 3)

        # Pool to graph level
        h_graph = self.pool(h, data.batch) # (n, d) -> (batch_size, d)

        # Get latent parameters and property prediction
        mu = self.mu(h_graph)
        log_var = self.log_var(h_graph)
        property_pred = self.property_predictor(mu)

        return mu, log_var, property_pred

class ConditionalDecoder(Module):
    def __init__(self, latent_dim=32, emb_dim=64, out_node_dim=11, out_edge_dim=4):
        super().__init__()
        self.logger = logging.getLogger(self.__class__.__name__)

        # Initial projection from latent+property space
        # expect one property
        self.lin_latent = Linear(latent_dim + 1, emb_dim)

        # Node feature generation
        self.node_decoder = Sequential(
            Linear(emb_dim, emb_dim),
            ReLU(),
            BatchNorm1d(emb_dim),
            Linear(emb_dim, out_node_dim)
        )

        # Position generation
        self.pos_decoder = Sequential(
            Linear(emb_dim, emb_dim),
            ReLU(),
            BatchNorm1d(emb_dim),
            Linear(emb_dim, 3)
        )

        # Number of nodes predictor
        self.num_nodes_predictor = Sequential(
            Linear(emb_dim, emb_dim),
            ReLU(),
            Linear(emb_dim, 1)
        )

        # Edge prediction
        self.edge_existence = Sequential(
            Linear(2 * emb_dim, emb_dim),
            ReLU(),
            BatchNorm1d(emb_dim),
            Linear(emb_dim, 1),
            Sigmoid()
        )

        self.edge_features = Sequential(
            Linear(2 * emb_dim, emb_dim),
            ReLU(),
            BatchNorm1d(emb_dim),
            Linear(emb_dim, out_edge_dim)
        )

    def forward(self, z, target_property, batch_size):
        self.logger.debug(f"Input shapes - z: {z.shape}, target_property: {target_property.shape}")

        # Make sure target_property has correct shape for concatenation
        if target_property.dim() == 3:
            target_property = target_property.squeeze(1)
        if target_property.dim() == 1:
            target_property = target_property.unsqueeze(1)

        z_cond = torch.cat([z, target_property], dim=1)
        h = self.lin_latent(z_cond)

        # Predict number of nodes per graph
        num_nodes = self.num_nodes_predictor(h).sigmoid() * 30 + 5  # 5-35 nodes
        num_nodes = num_nodes.long()

        node_features_list = []
        positions_list = []

        for i in range(batch_size):
            n = num_nodes[i].item()  # Convert tensor to integer
            h_expanded = h[i:i+1].expand(n, -1)

            # Generate node features and positions
            node_feat = self.node_decoder(h_expanded)
            pos = self.pos_decoder(h_expanded)

            node_features_list.append(node_feat)
            positions_list.append(pos)

        node_features = torch.cat(node_features_list, dim=0)
        positions = torch.cat(positions_list, dim=0)

        self.logger.debug(f"Output shapes - node_features: {node_features.shape}, positions: {positions.shape}")

        return node_features, positions, num_nodes

class PropertyConditionedVAE(Module):
    def __init__(self, num_layers=4, emb_dim=64, in_dim=11, edge_dim=4, latent_dim=32):
        super().__init__()
        self.logger = logging.getLogger(self.__class__.__name__)

        self.encoder = Encoder(emb_dim, in_dim, edge_dim, latent_dim)
        self.decoder = ConditionalDecoder(latent_dim, emb_dim, in_dim)
        self.latent_dim = latent_dim

    def reparameterize(self, mu, log_var):
        if self.training:
            std = torch.exp(0.5 * log_var)
            eps = torch.randn_like(std)
            return mu + eps * std
        return mu

    def forward(self, data, target_property=None):
        logger = logging.getLogger('PropertyConditionedVAE')

        # Encode
        mu, log_var, property_pred = self.encoder(data)
        logger.debug(f"Encoder outputs - mu: {mu.shape}, log_var: {log_var.shape}, property_pred: {property_pred.shape}")

        # Sample from latent space
        z = self.reparameterize(mu, log_var)
        logger.debug(f"Sampled z shape: {z.shape}")

        # Use predicted property if target not provided
        if target_property is None:
            target_property = property_pred
        else:
            # if all properties provided, extract just the HOMO-LUMO gap
            if target_property.size(1) != 1:
                target_property = target_property[:, 4:5]
        logger.debug(f"Target property shape before squeeze: {target_property.shape}")

        # # Ensure target_property is 2D
        # if len(target_property.shape) == 3:
        #     target_property = target_property.squeeze(1)
        # logger.debug(f"Target property shape after squeeze: {target_property.shape}")

        # Decode
        node_features, positions, num_nodes = self.decoder(
            z, target_property, data.batch.max().item() + 1
        )
        logger.debug(f"Decoder outputs - features: {node_features.shape}, positions: {positions.shape}")

        return node_features, positions, mu, log_var, property_pred, num_nodes

    def loss_function(self, node_features, positions, num_nodes, data, mu, log_var,
                    property_pred, property_weight=1.0):
        logger = logging.getLogger(self.__class__.__name__)

        # Log shapes for debugging
        logger.debug(f"Property prediction shape: {property_pred.shape}")
        logger.debug(f"Target property shape: {data.y.shape}")

        # Get batch size
        batch_size = data.batch.max().item() + 1

        # Reconstruction loss (with proper masking for variable size graphs)
        recon_loss = 0
        start_idx = 0
        total_nodes = 0

        for i, n in enumerate(num_nodes):
            n_orig = (data.batch == i).sum()
            n_gen = n.item()
            nodes_to_compare = min(n_gen, n_orig)
            total_nodes += nodes_to_compare

            if nodes_to_compare > 0:
                # Node feature reconstruction - use sum reduction
                recon_loss += F.mse_loss(
                    node_features[start_idx:start_idx + nodes_to_compare],
                    data.x[data.batch == i][:nodes_to_compare],
                    reduction='sum'  # Sum within each graph
                )

                # Position reconstruction - use sum reduction
                recon_loss += F.mse_loss(
                    positions[start_idx:start_idx + nodes_to_compare],
                    data.pos[data.batch == i][:nodes_to_compare],
                    reduction='sum'  # Sum within each graph
                )

            start_idx += n_gen

        # Normalize reconstruction loss by total nodes compared
        if total_nodes > 0:
            recon_loss = recon_loss / total_nodes

        # KL divergence (already normalized by batch size)
        kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) / batch_size

        # # Property prediction loss - ensure shapes match
        # if property_pred.shape != data.y.shape:
        #     logger.warning(f"Property shape mismatch: pred={property_pred.shape}, target={data.y.shape}")
        #     # Fix the shape of either property_pred or data.y to match
        #     if property_pred.size(1) == 1 and data.y.size(1) > 1:
        #         # Need to modify your model to output the correct shape (19 columns)
        #         # As a temporary fix, repeat the single value to match the target width
        #         property_pred = property_pred.expand(-1, data.y.size(1))

        target_property = data.y[:, 4:5]  # HOMO-LUMO gap, keep dimension as [batch_size, 1]
        prop_loss = F.mse_loss(property_pred, target_property, reduction='mean')

        # Combine losses with scaling factors
        # Use smaller coefficients to prevent overflow
        total_loss = recon_loss + 0.01 * kl_loss + 0.1 * property_weight * prop_loss

        # Add guard against NaN or Inf
        if not torch.isfinite(total_loss):
            logger.error(f"Non-finite loss detected! recon={recon_loss}, kl={kl_loss}, prop={prop_loss}")
            # Return a backup loss that won't break training
            return torch.tensor(1000.0, device=total_loss.device, requires_grad=True)

        # Log component values for debugging
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"Losses - recon: {recon_loss.item():.4f}, KL: {kl_loss.item():.4f}, prop: {prop_loss.item():.4f}, total: {total_loss.item():.4f}")

        return total_loss

    def generate_molecule(self, target_property, num_samples=1):
        self.eval()
        with torch.no_grad():
            # Sample from prior
            z = torch.randn(num_samples, self.latent_dim).to(next(self.parameters()).device)
            target = torch.ones(num_samples, 1).to(z.device) * target_property # single value for homo-lumo gap

            # Generate
            node_features, positions, num_nodes = self.decoder(z, target, num_samples)

            return node_features, positions, num_nodes

In [24]:
# Data splitting (60/20/20)
train_val_idx, test_idx = train_test_split(
    np.arange(len(dataset)),
    test_size=0.2,
    random_state=42
)
train_idx, val_idx = train_test_split(
    train_val_idx,
    test_size=0.25,
    random_state=42
)

train_loader = DataLoader(dataset[train_idx], batch_size=128, shuffle=True)
val_loader = DataLoader(dataset[val_idx], batch_size=128, shuffle=False)
test_loader = DataLoader(dataset[test_idx], batch_size=128, shuffle=False)

In [25]:
import time


# logging.basicConfig(level=logging.DEBUG)

def train_epoch(model, optimizer, train_loader, device):
    logger = logging.getLogger('train_epoch')
    model.train()
    total_loss = 0

    for batch_idx, batch in enumerate(train_loader):
        batch = batch.to(device)
        logger.debug(f"\nBatch {batch_idx}:")
        logger.debug(f"Batch properties: x={batch.x.shape}, pos={batch.pos.shape}, batch={batch.batch.shape}")
        optimizer.zero_grad()

        # Forward pass
        node_features, positions, mu, log_var, property_pred, num_nodes = model(batch)

        # Calculate loss
        loss = model.loss_function(
            node_features, positions, num_nodes,
            batch, mu, log_var, property_pred
        )

        # Backward pass
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(train_loader)

def validate(model, val_loader, device):
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(device)

            # Forward pass with all outputs
            node_features, positions, mu, log_var, property_pred, num_nodes = model(batch)

            # Calculate loss with all parameters
            loss = model.loss_function(
                node_features, positions, num_nodes,
                batch, mu, log_var, property_pred
            )
            total_loss += loss.item()

    return total_loss / len(val_loader)


# Training setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = PropertyConditionedVAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-6
)

# Training loop
n_epochs = 100
best_val_loss = float('inf')
patience = 10
patience_counter = 0

batch = next(iter(train_loader))
print(f"Target property shape: {batch.y.shape}")
print(f"Target property sample: {batch.y[0]}")
print(f"Input node features shape: {batch.x.shape}")  # Should be (N, 11)
print(f"Input positions shape: {batch.pos.shape}")    # Should be (N, 3)
print(f"Number of nodes: {batch.num_nodes}")
print(f"Batch size: {batch.batch.max().item() + 1}")
# TODO : implement gradient clipping?
for epoch in range(n_epochs):
    # Train
    train_loss = train_epoch(model, optimizer, train_loader, device)

    # Validate
    val_loss = validate(model, val_loader, device)

    # Learning rate scheduling
    scheduler.step(val_loss)

    # Early stopping
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f'Early stopping at epoch {epoch}')
            # Save best model
            # Save model with timestamp in a more readable format
            timestamp = time.strftime("%Y%m%d_%H%M%S")
            save_path = f'best_vae_model_{timestamp}.pt'
            torch.save(model.state_dict(), save_path)
            print(f'Saved best model to: {save_path}')
            break

    # Print progress
    print(f'Epoch {epoch:03d} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}')

# Load best model for testing
model.load_state_dict(torch.load('best_vae_model.pt'))

# Test final model
test_loss = validate(model, test_loader, device)
print(f'Final Test Loss: {test_loss:.4f}')

Target property shape: torch.Size([128, 19])
Target property sample: tensor([ 1.7801, -0.4565, -0.2499,  0.1850,  0.2999, -0.5561, -0.3152, -0.1797,
        -0.1798, -0.1798, -0.1797, -0.8097,  0.3219,  0.3204,  0.3207,  0.3240,
        -0.0039,  0.1084,  0.0737])
Input node features shape: torch.Size([2241, 11])
Input positions shape: torch.Size([2241, 3])
Number of nodes: 2241
Batch size: 128
Epoch 000 | Train Loss: 23.5831 | Val Loss: 235483165571.1219
Epoch 001 | Train Loss: 24.9270 | Val Loss: 286896091551754.0000
Epoch 002 | Train Loss: 46.5103 | Val Loss: 535562400536346.2500


KeyboardInterrupt: 