## Example of GraphEBM: Random Generation

In [18]:
import os
import torch
from torch_geometric.data import DenseDataLoader
from rdkit import RDLogger

from dig.ggraph.dataset import QM9, ZINC250k
from dig.ggraph.method import GraphEBM
from dig.ggraph.evaluation import Rand_Gen_Evaluator

In [19]:
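# Use the first GPU when available; otherwise fall back to the CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')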

#### Prepare Dataset

In [5]:
dataset = ZINC250k(one_shot=True, root='./')
splits = dataset.get_split_idx()
train_set = dataset[splits['train_idx']]
train_dataloader = DenseDataLoader(train_set, batch_size=128, shuffle=True, num_workers=0)

#### Training

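Before training, we create `graphebm`, an instance of the `GraphEBM` class. The arguments match ZINC250k: molecules have at most 38 atoms (`n_atom`), `n_atom_type=10` counts the dataset's atom types plus the virtual atom type, and `n_edge_type=4` counts single, double, and triple bonds plus the virtual (no-bond) type.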

**Skip training**: Alternatively, download our models trained on [ZINC250k](https://github.com/divelab/DIG_storage/blob/main/ggraph/GraphEBM_zinc250k_uncond.pt) and [QM9](https://github.com/divelab/DIG_storage/blob/main/ggraph/GraphEBM_qm9_uncond.pt).

In [6]:
graphebm = GraphEBM(n_atom=38, n_atom_type=10, n_edge_type=4, hidden=64, device=device)

In [10]:
graphebm.train_rand_gen(train_dataloader, lr=1e-4, wd=0, max_epochs=20, c=0, ld_step=150, ld_noise=0.005, ld_step_size=30, clamp=True, alpha=1, save_interval=1, save_dir='./checkpoints')

  0%|          | 6/1755 [00:08<43:19,  1.49s/it]


KeyboardInterrupt: 

#### Generation

To construct molecules from the generated node matrices and adjacency tensors, we need `atomic_num_list`, which maps each atom-type dimension of the node matrix to an atomic number. The entry `0` denotes the virtual atom type, which marks an empty atom slot.
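The role of `atomic_num_list` can be illustrated with a small pure-Python sketch. It is not the library's decoding routine: the helper `decode_atoms`, the shortened `example_atomic_num_list`, and the layout of one row per atom slot and one column per atom type are all assumptions made for illustration.

```python
def decode_atoms(node_matrix, atomic_num_list):
    """Map each row (one atom slot) to an atomic number, skipping virtual atoms.

    Hypothetical helper: assumes rows are atom slots and columns are atom types.
    """
    atoms = []
    for row in node_matrix:
        # Pick the atom type with the highest score in this slot
        type_idx = max(range(len(row)), key=lambda i: row[i])
        atomic_num = atomic_num_list[type_idx]
        # Atomic number 0 is the virtual atom type, i.e. an empty slot
        if atomic_num != 0:
            atoms.append(atomic_num)
    return atoms


# Illustrative values only: carbon, nitrogen, oxygen, then the virtual type
example_atomic_num_list = [6, 7, 8, 0]
example_node_matrix = [
    [0.90, 0.05, 0.03, 0.02],  # highest score in column 0 -> carbon
    [0.10, 0.10, 0.70, 0.10],  # highest score in column 2 -> oxygen
    [0.00, 0.10, 0.10, 0.80],  # highest score in column 3 -> virtual, dropped
]

decoded = decode_atoms(example_node_matrix, example_atomic_num_list)
print(decoded)
```

Appending `0` as the last entry of the list, as the next cell does, makes the final atom-type dimension stand for "no atom", so slots assigned to it are dropped when the molecule is built.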

In [7]:
# Suppress error and warning logs from RDKit
RDLogger.DisableLog('rdApp.error') 
RDLogger.DisableLog('rdApp.warning')

atomic_num_list = dataset.atom_list+[0]
gen_mols = graphebm.run_rand_gen(checkpoint_path='./GraphEBM_zinc250k_uncond.pt', n_samples=10000, c=0, ld_step=150, ld_noise=0.005, ld_step_size=30, clamp=True, atomic_num_list=atomic_num_list)

Loading paramaters from ./checkpoints_zinc/epoch_1.pt
Initializing samples...
Generating samples...


KeyboardInterrupt: 

#### Evaluations

In [15]:
train_smiles = [data.smile for data in dataset[splits['train_idx']]]
res_dict = {'mols':gen_mols, 'train_smiles': train_smiles}
evaluator = Rand_Gen_Evaluator()
results = evaluator.eval(res_dict)
# Print the evaluation metrics, not the full input dictionary of molecules and SMILES
print(results)


