# Visualise NN-generated molecules

In [1]:
import torch
import matplotlib.pyplot as plt
from IPython.display import display, HTML
from torch_geometric.loader import DataLoader
from torch_geometric.datasets import QM9

from mygenai.models.graphvae import PropertyConditionedVAE
from mygenai.utils.visualisation import to_rdkit, visualise_molecule, moltosvg
from mygenai.utils.transforms import CompleteGraph
import logging


In [2]:
# Build the VAE with the hyper-parameters of the saved run and restore its weights.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = PropertyConditionedVAE(num_layers=4, emb_dim=64, edge_dim=4, latent_dim=32)
# map_location: lets a checkpoint written on a GPU load on a CPU-only machine
#   (the line above explicitly allows the CPU fallback).
# weights_only: the file is a plain state_dict, so restrict unpickling to tensors
#   instead of running the full (arbitrary-code) pickle loader.
model.load_state_dict(
    torch.load(
        'best_vae_model_20250420_220525.pt',
        map_location=device,
        weights_only=True,
    )
)

  model.load_state_dict(torch.load('best_vae_model_20250420_220525.pt'))


<All keys matched successfully>

In [3]:
# Move the network to the selected device and switch it to inference mode
# (both calls return the module itself), then echo its structure.
model = model.to(device).eval()
model

PropertyConditionedVAE(
  (encoder): Encoder(
    (lin_in): Linear(in_features=5, out_features=64, bias=True)
    (convs): ModuleList(
      (0-1): 2 x EquivariantMPNNLayer(emb_dim=64, aggr=add)
    )
    (mu): Linear(in_features=64, out_features=32, bias=True)
    (log_var): Linear(in_features=64, out_features=32, bias=True)
    (property_predictor): Sequential(
      (0): Linear(in_features=32, out_features=64, bias=True)
      (1): ReLU()
      (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Linear(in_features=64, out_features=1, bias=True)
    )
  )
  (decoder): ConditionalDecoder(
    (lin_latent): Linear(in_features=33, out_features=64, bias=True)
    (node_decoder): Sequential(
      (0): Linear(in_features=64, out_features=64, bias=True)
      (1): ReLU()
      (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Linear(in_features=64, out_features=5, bias=True)
      (4): Sigmoid()
    )

In [4]:
# Load QM9 from ../data/QM9; the CompleteGraph transform is applied to each sample on access.
dataset = QM9(root="../data/QM9", transform=CompleteGraph())
# Standardise the targets: mean and std are taken over dim 0 (across all rows of y),
# so each target column is shifted/scaled to mean = 0 and std = 1 over the dataset.
# `mean` and `std` stay in the namespace so predictions can be mapped back later.
mean = dataset.data.y.mean(dim=0, keepdim=True)
std = dataset.data.y.std(dim=0, keepdim=True)
dataset.data.y = (dataset.data.y - mean) / std
# Keep only the first five node-feature columns (the one-hot atom-type encoding),
# for simplicity for now. This matches the encoder's 5-feature input layer.
dataset.data.x = dataset.data.x[:, :5]

# Distance scale constant. Nothing in this cell applies it to the data.
# NOTE(review): confirm where (or whether) it is consumed.
fixed_max_distance = 2.0

# One molecule per batch, shuffled; no seed is set in this cell, so the order varies between runs.
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)



In [None]:
from torch_geometric.data import Batch

# Turn on DEBUG logging so the model and decoder report tensor shapes during the pass.
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logging.getLogger('PropertyConditionedVAE').setLevel(logging.DEBUG)
logging.getLogger('ConditionalDecoder').setLevel(logging.DEBUG)

# Loss-term weights (intended to be the same as during training).
recon_weight = 1.0
kl_weight = 0.01
property_weight = 0.

# Collate the first 100 molecules into one batch on the model's device.
batch_data = dataset[:100]
batch = Batch.from_data_list(batch_data).to(device)

# Diagnostic pass only: keep both the forward pass and the loss evaluation under
# no_grad so that nothing in this cell builds an autograd graph.
with torch.no_grad():
    outputs = model(batch)
    node_features, distances, directions, edge_features, num_nodes, mu, log_var, property_pred = outputs

    # Compute the loss
    loss = model.loss_function(
        node_features=node_features,
        distances=distances,
        directions=directions,
        edge_features=edge_features,
        num_nodes=num_nodes,
        data=batch,
        mu=mu,
        log_var=log_var,
        property_pred=property_pred,
        property_weight=property_weight,
        recon_weight=recon_weight,
        kl_weight=kl_weight
    )

2025-04-20 22:52:08,820 - PropertyConditionedVAE - DEBUG - Input data - batch_size: 100, nodes: 1005
2025-04-20 22:52:08,820 - PropertyConditionedVAE - DEBUG - Forward called without target_property (None)
2025-04-20 22:52:08,823 - PropertyConditionedVAE - DEBUG - Encoder outputs - mu: torch.Size([100, 32]), log_var: torch.Size([100, 32]), property_pred: torch.Size([100, 1])
2025-04-20 22:52:08,823 - PropertyConditionedVAE - DEBUG - Sampled z shape: torch.Size([100, 32])
2025-04-20 22:52:08,824 - PropertyConditionedVAE - DEBUG - Using encoder prediction for property, shape: torch.Size([100, 1])
2025-04-20 22:52:08,824 - ConditionalDecoder - DEBUG - Input shapes - z: torch.Size([100, 32]), target_property: torch.Size([100, 1])
2025-04-20 22:52:08,825 - ConditionalDecoder - DEBUG - Direction decoder output norm: 1.000000
2025-04-20 22:52:08,826 - ConditionalDecoder - DEBUG - Output shapes - node_features: torch.Size([1005, 5]), distances: torch.Size([10070, 1]), directions: torch.Size([1

In [9]:
# 2-norm over every element of `directions` (all entries treated as one flat vector).
# If each row is a unit vector, this equals sqrt(number of rows): with the 10070
# edges reported in the decoder log that is ~100.35, which is what the output shows.
torch.linalg.vector_norm(directions)

tensor(100.3494, device='cuda:0')

In [None]:
def reconstruct_molecule(model, data):
    """Run a single molecule through the VAE and return a copy holding the decoded output.

    Args:
        model: The trained VAE. Calling ``model(data)`` must return the 6-tuple
            ``(node_features, positions, mu, log_var, property_pred, num_nodes)``.
        data: Graph object for one molecule, with ``x`` and ``edge_index`` and
            optionally ``edge_attr``, ``z`` and ``batch``.

    Returns:
        A clone of ``data`` whose ``x`` and ``pos`` are replaced by the decoded
        node features and positions. At most as many nodes as the input are
        kept; if the model predicts fewer, edges and per-node attributes are
        trimmed to the surviving nodes. Every other field (e.g. ``smiles``) is
        copied from the input unchanged; it is not a prediction.
    """
    # Follow the model's own device instead of depending on a notebook-level global.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        data = data.to(first_param.device)

    with torch.no_grad():
        # NOTE(review): the diagnostic cell in this notebook unpacks eight values from
        # model(batch) (distances/directions rather than positions). Confirm that the
        # loaded model still returns this 6-tuple before relying on this helper.
        node_features, positions, _mu, _log_var, _property_pred, num_nodes = model(data)

    # Work on a copy so the caller's molecule keeps its original features.
    recon_data = data.clone()

    # Never keep more nodes than the input has; coerce to int (the prediction may be
    # a float tensor) and clamp at zero so it is always a valid slice bound.
    n_orig = data.x.size(0)
    n_gen = max(0, min(int(num_nodes[0].item()), n_orig))

    recon_data.x = node_features[:n_gen]
    recon_data.pos = positions[:n_gen]

    if n_gen < n_orig:
        # One mask, built from the ORIGINAL edge list, selects the surviving edges.
        # It is applied to both edge_index and edge_attr so they stay aligned.
        src, dst = data.edge_index
        keep = (src < n_gen) & (dst < n_gen)
        recon_data.edge_index = data.edge_index[:, keep]

        edge_attr = getattr(recon_data, 'edge_attr', None)
        if edge_attr is not None:
            recon_data.edge_attr = edge_attr[keep]

        # Keep the remaining per-node tensors the same length as x / pos.
        for key in ('z', 'batch'):
            value = getattr(recon_data, key, None)
            if torch.is_tensor(value) and value.dim() > 0 and value.size(0) == n_orig:
                setattr(recon_data, key, value[:n_gen])

    return recon_data

In [35]:
# Take the first molecule and give it an all-zero batch vector (every atom belongs
# to graph 0), created on the same device as its node features.
test_mol = dataset[0]
test_mol.batch = test_mol.x.new_zeros(test_mol.x.size(0), dtype=torch.long)
visualise_molecule(test_mol)

<py3Dmol.view at 0x7fb1ae185160>

In [36]:
# Pass the test molecule through the VAE; the result is a clone of test_mol whose
# x / pos have been replaced by the model's output.
recon_mol = reconstruct_molecule(model, test_mol)

In [37]:
# Side-by-side dump of the input and its reconstruction.
# SMILES and Z on the reconstruction are copied from the input by clone(), so they
# always match; only `pos` (and `x`) reflect what the model produced.
# In the saved output every reconstructed coordinate has the same value, i.e. all
# atoms were placed on a single point.
for _heading, _mol in (
    ("Original molecule:", test_mol),
    ("Reconstructed molecule:", recon_mol),
):
    print(_heading)
    print("SMILES:", _mol.smiles)
    print("Z", _mol.z)
    print("pos", _mol.pos)


Original molecule:
SMILES: [H]C([H])([H])[H]
Z tensor([6, 1, 1, 1, 1], device='cuda:0')
pos tensor([[-1.2700e-02,  1.0858e+00,  8.0000e-03],
        [ 2.2000e-03, -6.0000e-03,  2.0000e-03],
        [ 1.0117e+00,  1.4638e+00,  3.0000e-04],
        [-5.4080e-01,  1.4475e+00, -8.7660e-01],
        [-5.2380e-01,  1.4379e+00,  9.0640e-01]], device='cuda:0')
Reconstructed molecule:
SMILES: [H]C([H])([H])[H]
Z tensor([6, 1, 1, 1, 1], device='cuda:0')
pos tensor([[6.9755, 6.9755, 6.9755],
        [6.9755, 6.9755, 6.9755],
        [6.9755, 6.9755, 6.9755],
        [6.9755, 6.9755, 6.9755],
        [6.9755, 6.9755, 6.9755]], device='cuda:0')


In [39]:
# Print the field/shape summary of the reconstructed object, then show its edge
# list as the cell's output.
print(recon_mol)
recon_mol.edge_index

Data(x=[5, 11], edge_index=[2, 20], edge_attr=[20, 4], y=[1, 19], pos=[5, 3], z=[5], smiles='[H]C([H])([H])[H]', name='gdb_1', idx=[1], batch=[5])


tensor([[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4],
        [1, 2, 3, 4, 0, 2, 3, 4, 0, 1, 3, 4, 0, 1, 2, 4, 0, 1, 2, 3]],
       device='cuda:0')