## Binding Affinity Prediction with ACNN

In [None]:
import pandas as pd
from rdkit import Chem
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam

In [None]:
from dgllife.data import PDBBind

# Build the PDBBind dataset object from dgllife, selecting its 'refined' subset.
# NOTE(review): construction presumably downloads/preprocesses the data on
# first use — confirm against the dgllife documentation.
trainset = PDBBind(subset='refined')

In [None]:
import dgl

def collate(samples):
    """Combine a list of ``(graph, label)`` pairs into one mini-batch.

    Args:
        samples: Sequence of 2-item samples; the first item of each is a
            graph, the second its label.

    Returns:
        A pair ``(batched_graph, labels)`` where ``batched_graph`` is the
        result of ``dgl.batch_hetero`` over all graphs and ``labels`` is a
        tensor built from the collected labels.

    Raises:
        ValueError: If ``samples`` is empty or the samples do not have
            exactly two items each (the column unpacking fails).
    """
    # Transpose the samples into one column of graphs and one of labels.
    graphs, labels = (list(column) for column in zip(*samples))
    label_tensor = torch.tensor(labels)
    return dgl.batch_hetero(graphs), label_tensor

# Slice the whole dataset once (the original sliced it twice) and pick the
# fields this notebook trains on.
# NOTE(review): positions 3 and 4 are presumably the complex graphs and the
# affinity labels returned by PDBBind.__getitem__ — confirm against dgllife.
full_slice = trainset[:]
glist = full_slice[3]
label_list = full_slice[4]

# Pair every graph with the label at the same index.
trainset_list = [(graph, label_list[idx]) for idx, graph in enumerate(glist)]

# Freeze the pairs into a tuple so the DataLoader indexes an immutable sequence.
dataset = tuple(trainset_list)
data_loader = DataLoader(dataset, batch_size=32, shuffle=True,
                         collate_fn=collate)

In [None]:
from dgllife.model.model_zoo.acnn import ACNN

# Instantiate the ACNN model from dgllife with its default constructor arguments.
model = ACNN()

In [None]:
import torch.optim as optim
import torch.nn as nn
import torch 

# Train the ACNN model with mean-squared-error loss and Adam.
loss_func = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
model.train()

# Mean loss of every completed epoch, in order.
epoch_losses = []
for epoch in range(1000):
    epoch_loss = 0
    num_steps = 0
    for step, (bg, label) in enumerate(data_loader):
        prediction = model(bg)
        # Labels are reshaped to a column so they line up with the prediction.
        # NOTE(review): assumes the model output is (batch, 1) — confirm.
        loss = loss_func(prediction, label.reshape(-1, 1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Convert to a Python float once and reuse it for logging/averaging.
        step_loss = loss.detach().item()
        epoch_loss += step_loss
        num_steps += 1
        print("Epoch {} | Step {} | loss {} |".format(epoch, step, step_loss))

        # Intermediate checkpoint every 10 steps; the same file is overwritten.
        if step % 10 == 0:
            torch.save(model.state_dict(), 'COVID19_binding_affinity.pth')

    # An empty loader would otherwise surface as an opaque NameError on `step`.
    if num_steps == 0:
        raise RuntimeError('data_loader yielded no batches; nothing to train on')

    epoch_loss /= num_steps
    print('Epoch {}, loss {:.4f}'.format(epoch, epoch_loss))
    # End-of-epoch checkpoint, overwriting the intermediate one.
    torch.save(model.state_dict(), 'COVID19_binding_affinity.pth')
    epoch_losses.append(epoch_loss)

## Molecule Generation with DGMG

In [None]:
from rdkit.Chem import rdmolfiles, rdmolops
from utils import Subset
from utils import Optimizer
from dgl.data.chem import utils
from dgl.model_zoo.chem import load_pretrained
from dgl.model_zoo.chem.dgmg import MoleculeEnv
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
import copy

In [None]:
from rdkit import Chem

# Element symbols the molecule generator may place as atoms.
atom_types = [
    'C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'As',
    'Al', 'I', 'B', 'V', 'K', 'Tl', 'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti',
    'Zn', 'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', 'Zr', 'Cr', 'Pt', 'Hg',
    'Pb',
]

# RDKit bond types the generator may place between atoms.
bond_types = [
    Chem.rdchem.BondType.SINGLE,
    Chem.rdchem.BondType.DOUBLE,
    Chem.rdchem.BondType.TRIPLE,
    Chem.rdchem.BondType.AROMATIC,
]

# Model and training hyper-parameters.
node_hidden_size = 128
num_prop_rounds = 2
dropout = 0.2
# NOTE(review): the training cell in this notebook loops over a literal
# range(100) rather than reading nepochs — confirm which is intended.
nepochs = 400
batch_size = 1
lr = 1e-4

In [None]:
# Molecule environment restricted to the configured atom and bond vocabularies.
molecule_env = MoleculeEnv(atom_types=atom_types, bond_types=bond_types)

# Read the molecule table (first CSV column is the index) and take its
# 'SMILES' column as a plain Python list.
smile_data = pd.read_csv('COVID19_molecule.csv', index_col=0)
smiles = smile_data['SMILES'].tolist()

# Subset comes from the local `utils` module.
# NOTE(review): 'canonical' presumably selects the atom ordering used when
# turning each SMILES into a decision sequence — confirm in utils.Subset.
subs = Subset(smiles, 'canonical', molecule_env)
loader = DataLoader(subs, batch_size=batch_size)

In [None]:
from dgllife.model.model_zoo.dgmg import DGMG

# Build the DGMG generator from the hyper-parameters defined earlier in the
# notebook; this rebinds the global name `model`.
model = DGMG(
    atom_types=atom_types,
    bond_types=bond_types,
    node_hidden_size=node_hidden_size,
    num_prop_rounds=num_prop_rounds,
    dropout=dropout,
)

In [None]:
# Train DGMG by maximising the log-likelihood of each decision sequence.
# Optimizer is the helper imported from the local `utils` module; it is given
# the learning rate and an Adam instance over the model parameters.
optimizer = Optimizer(lr, Adam(model.parameters(), lr))
model.train()
# NOTE(review): `nepochs` (400) is defined with the other hyper-parameters but
# this loop runs a hard-coded 100 epochs — confirm which is intended.
for epoch in range(100):
    for step, data in enumerate(loader):
        optimizer.zero_grad()
        # Log-probability the model assigns to the given action sequence.
        logp = model(actions=data, compute_log_prob=True)
        prob = logp.detach().exp()
        # Negative log-likelihood is the quantity being minimised.
        loss_averaged = -logp
        # `prob` / `prob_averaged` are not read in this cell; they are kept so
        # the notebook globals stay available to later cells.
        prob_averaged = prob
        optimizer.backward_and_step(loss_averaged)
        # Checkpoint after every step; the same file is overwritten each time.
        torch.save(model.state_dict(), 'COVID19.pth')

        loss_value = loss_averaged.item()
        # Sampling a molecule for the log line needs no gradients, so skip
        # building an autograd graph that would be thrown away immediately.
        # NOTE(review): the model is still in train mode here, so the dropout
        # configured at construction presumably stays active while sampling —
        # confirm whether model.eval() is wanted for generation.
        with torch.no_grad():
            sampled = model(rdkit_mol=True, max_num_steps=400)
        print("Epoch {} | Step {} | loss_averaged {} | Output {} |".format(epoch, step, loss_value, sampled))