# Visualise NN-generated molecules

In [7]:
import torch
import matplotlib.pyplot as plt
from IPython.display import display, HTML
from torch_geometric.loader import DataLoader
from torch_geometric.datasets import QM9

from mygenai.models.graphvae import PropertyConditionedVAE
from mygenai.utils.data_exploration import to_rdkit, visualise_molecule, moltosvg
from mygenai.utils.transforms import CompleteGraph


In [8]:
# load model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = PropertyConditionedVAE(num_layers=4, emb_dim=64, in_dim=11, edge_dim=4, latent_dim=32)
model.load_state_dict(torch.load('best_vae_model_recon_only.pt'))

  model.load_state_dict(torch.load('best_vae_model_recon_only.pt'))


<All keys matched successfully>

In [9]:
model = model.to(device)
model.eval()

PropertyConditionedVAE(
  (encoder): Encoder(
    (lin_in): Linear(in_features=11, out_features=64, bias=True)
    (convs): ModuleList(
      (0-1): 2 x EquivariantMPNNLayer(emb_dim=64, aggr=add)
    )
    (mu): Linear(in_features=64, out_features=32, bias=True)
    (log_var): Linear(in_features=64, out_features=32, bias=True)
    (property_predictor): Sequential(
      (0): Linear(in_features=32, out_features=64, bias=True)
      (1): ReLU()
      (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Linear(in_features=64, out_features=1, bias=True)
    )
  )
  (decoder): ConditionalDecoder(
    (lin_latent): Linear(in_features=33, out_features=64, bias=True)
    (node_decoder): Sequential(
      (0): Linear(in_features=64, out_features=64, bias=True)
      (1): ReLU()
      (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Linear(in_features=64, out_features=11, bias=True)
      (4): Tanh()
    )


In [12]:
dataset = QM9(root="../data/QM9", transform=CompleteGraph())
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

In [13]:
def reconstruct_molecule(model, data):
    data = data.to(device)
    with torch.no_grad():
        # Forward pass
        node_features, positions, mu, log_var, property_pred, num_nodes = model(data)

        # Create a copy of the data object for the reconstruction
        recon_data = data.clone()

        # Replace features and positions with reconstructed ones
        # We use only the first graph since batch_size=1
        n = num_nodes[0].item()
        recon_data.x = node_features[:n]
        recon_data.pos = positions[:n]

        return recon_data

In [17]:
test_mol = dataset[0]
test_mol.batch = torch.zeros(test_mol.x.size(0), dtype=torch.long, device=test_mol.x.device)
visualise_molecule(test_mol)

<py3Dmol.view at 0x7fb1ae184d40>

In [19]:
recon_mol = reconstruct_molecule(model, test_mol)

In [25]:
print("Original molecule:")
print("SMILES:", test_mol.smiles)
print("Z", test_mol.z)
print("pos", test_mol.pos)
print("Reconstructed molecule:")
print("SMILES:", recon_mol.smiles)
print("Z", recon_mol.z)
print("pos", recon_mol.pos)


Original molecule:
SMILES: [H]C([H])([H])[H]
Z tensor([6, 1, 1, 1, 1], device='cuda:0')
pos tensor([[-1.2700e-02,  1.0858e+00,  8.0000e-03],
        [ 2.2000e-03, -6.0000e-03,  2.0000e-03],
        [ 1.0117e+00,  1.4638e+00,  3.0000e-04],
        [-5.4080e-01,  1.4475e+00, -8.7660e-01],
        [-5.2380e-01,  1.4379e+00,  9.0640e-01]], device='cuda:0')
Reconstructed molecule:
SMILES: [H]C([H])([H])[H]
Z tensor([6, 1, 1, 1, 1], device='cuda:0')
pos tensor([[6.9755, 6.9755, 6.9755],
        [6.9755, 6.9755, 6.9755],
        [6.9755, 6.9755, 6.9755],
        [6.9755, 6.9755, 6.9755],
        [6.9755, 6.9755, 6.9755],
        [6.9755, 6.9755, 6.9755],
        [6.9755, 6.9755, 6.9755],
        [6.9755, 6.9755, 6.9755],
        [6.9755, 6.9755, 6.9755],
        [6.9755, 6.9755, 6.9755],
        [6.9755, 6.9755, 6.9755],
        [6.9755, 6.9755, 6.9755],
        [6.9755, 6.9755, 6.9755],
        [6.9755, 6.9755, 6.9755],
        [6.9755, 6.9755, 6.9755],
        [6.9755, 6.9755, 6.9755],
    