In [1]:
import matplotlib.pylab as pl
import matplotlib.pyplot as plt

import gnn
from rdkit import Chem
from rdkit.Chem import Draw
import utils
import util_metrics

import torch
import torch.nn as nn
import torch.nn.functional as F

import gnn_decoder
from torch_geometric.data import Data,Batch
from torch_geometric.utils import to_dense_adj,to_dense_batch,get_embeddings
from torch.utils.data import random_split
from torch_geometric.loader import DataLoader
import vgae
import copy
import visualize
import wandb

Load the canonical QM9 SMILES and convert them into a PyTorch Geometric graph dataset.

In [None]:
# Load the canonical QM9 SMILES list and turn it into graph data objects.
# NOTE: pickle.load can execute arbitrary code -- only load files you produced or trust.
import pickle

file_path = '../data/QM9_canonical.pkl'
with open(file_path, mode='rb') as smiles_file:
    dataset_smiles = pickle.load(smiles_file)

# SMILES -> PyG Data objects (edge_index form), then the project's hydrogen filter.
dataset = utils.filter_hydrogen(utils.smiles2data_edge_index(dataset_smiles))

Set hyperparameters and start a Weights & Biases run that records them.

In [None]:
# Authenticate without hard-coding a secret in the notebook: wandb.login() uses the
# WANDB_API_KEY environment variable or a previously saved login, and prompts otherwise.
wandb.login()
wandb.init(
    project="Graph Generation",
    config={
        "epochs": 20,
        "batch_size": 32,
        "encoder_norm": "layer",
        "lr": 1e-4,
        "dropout": 0.1,
        "sample_size": 100,
        "loss": "MSE",
        "eval": False,
        "hard": True,
        "soft": True,
        # Record the size actually trained on instead of a hand-typed count.
        "dataset_size": len(dataset),
        "latent_dim": 30,
        "seed": 42,
        })

# Seed torch's RNG (weight init, dropout, DataLoader shuffling) so runs are repeatable.
torch.manual_seed(wandb.config["seed"])

Initialize the data loader and the VGAE model (GINE encoder + generator decoder).

In [None]:
# Build the training loader and the VGAE (GINE encoder -> graph-level latent -> decoder).
latent_dims = wandb.config["latent_dim"]
# Input feature sizes for the encoder; assumed to match what utils.smiles2data_edge_index
# produces -- TODO confirm.
edge_dim = 4
node_dim = 4
train_loader = DataLoader(dataset, wandb.config["batch_size"], shuffle=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Use the configured dropout and the edge_dim defined above, so the logged config
# matches the model that is actually trained.
encoder = gnn.GINE(in_channels=node_dim, hidden_channels=[16, 32, 64, 128], out_channels=128,
                   norm=wandb.config["encoder_norm"], num_layers=4, act="leakyrelu",
                   edge_dim=edge_dim, dropout=wandb.config["dropout"],
                   add_self_loops=True, affine=False).to(device)
# NOTE(review): the decoder uses 5 node/edge channels vs. the encoder's 4 -- presumably one
# extra "no node / no edge" class; confirm against gnn_decoder.Generator_Decoder.
decoder = gnn_decoder.Generator_Decoder(latent_dim=latent_dims, node_dim=5, edge_dim=5,
                                        num_layers=4).to(device)
vgae_model = vgae.VGAE(encoder=encoder, decoder=decoder, latent_dims=latent_dims,
                       embedding="graph", eval=wandb.config["eval"]).to(device)
vgae_model.set_encoder()

Train the model with the configuration above, logging the losses and per-batch reconstruction quality to W&B.

In [None]:
# Train the VGAE for the configured number of epochs. Each batch is reconstructed and
# scored against the input with the model's Sinkhorn solver (soft_topk), and the
# parameters get one update on reconstruction loss + KL_WEIGHT * KL term.
KL_WEIGHT = 0.1

vgae_model.train()
optimizer = torch.optim.Adam(vgae_model.parameters(), lr=wandb.config["lr"], weight_decay=1e-5)

for epoch in range(wandb.config["epochs"]):
    vgae_model.set_encoder()
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        data_rec, mu, logvar, _ = vgae_model(data)
        features, edges, ot_dists = vgae_model.sinkhorn_solver.soft_topk(
            data, data_rec, loss=wandb.config["loss"],
            hard=wandb.config["hard"], scatter=True, soft=wandb.config["soft"])
        reconstruction_loss = features + edges
        kl_loss = vgae_model.sinkhorn_solver.kl_loss(mu, logvar)
        loss = reconstruction_loss + KL_WEIGHT * kl_loss
        loss.backward()
        # Exactly one optimizer step per batch; a second step() here would apply
        # another update from the same gradients.
        optimizer.step()

        # Log batch losses and reconstruction quality.
        valid_rec, num_rec = util_metrics.evaluate_reconstruction(data, data_rec)
        wandb.log({"feature_loss": features, "edge_loss": edges,
                   "reconstruction_loss": reconstruction_loss, "loss": loss,
                   "ot_dist": ot_dists, "kl_loss": kl_loss,
                   "valid_rec": valid_rec, "frac_rec": num_rec})

# The run is deliberately NOT finished here: wandb.finish() resets wandb.config, and the
# evaluation cell still reads wandb.config["sample_size"]. Call wandb.finish() once
# nothing else needs the run.


Evaluate the model: score a small generated sample for validity, uniqueness, and novelty.

In [None]:
# Score a small generated sample against the training SMILES. Sampling needs no
# gradients, so autograd is disabled while generating.
vgae_model.eval()
with torch.no_grad():
    sample = vgae_model.sample(wandb.config["sample_size"])
validity, unique, novelty, _ = util_metrics.compute_metrics(sample, dataset_smiles)

# Attach the scores to the active run and close it. wandb.config is unreadable after
# wandb.finish(), so this must happen after the last config lookup.
if wandb.run is not None:
    wandb.log({"validity": validity, "uniqueness": unique, "novelty": novelty})
    wandb.finish()

In [None]:
# Persist only the learned parameters (the state_dict), not the module object itself.
# Replace the placeholder with a real destination before running this cell.
PATH = "YOUR/PATH/HERE"
model_state = vgae_model.state_dict()
torch.save(model_state, PATH)

Generate samples and compute the fractions of valid, unique, and novel molecules, returning the SMILES strings of all novel molecules.
Then plot the generated molecules.

In [None]:
# Generate a larger sample for inspection and plotting. compute_metrics returns three
# scores plus a collection of SMILES (presumably the novel molecules -- confirm in
# util_metrics). No gradients are needed for sampling.
N_GENERATED = 1000

vgae_model.eval()
with torch.no_grad():
    sample = vgae_model.sample(N_GENERATED)
validity, unique, novel, smiles_list = util_metrics.compute_metrics(sample, dataset_smiles)

In [None]:
from rdkit.Chem.Draw import IPythonConsole  # imported for its side effect: notebook rendering hooks
from rdkit.Chem import AllChem as Chem

# Draw the first N_DRAW generated molecules in a grid, 10 per row. MolFromSmiles
# returns None for strings RDKit cannot parse, so those are dropped before drawing.
N_DRAW = 50
mol_list = [Chem.MolFromSmiles(smiles) for smiles in list(smiles_list)]
mol_list = [mol for mol in mol_list if mol is not None]
Draw.MolsToGridImage(mol_list[:N_DRAW], molsPerRow=10, subImgSize=(400, 400), maxMols=100)