In [20]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [21]:
from molfeat.trans.graph.adj import PYGGraphTransformer
from molfeat.calc.atom import AtomCalculator
from molfeat.calc.bond import EdgeMatCalculator

### Dataset
Define dataset

In [22]:
import pandas as pd

df = pd.read_csv("https://raw.githubusercontent.com/deepchem/deepchem/master/datasets/delaney-processed.csv")

In [23]:
df.head() 

Unnamed: 0,Compound ID,ESOL predicted log solubility in mols per litre,Minimum Degree,Molecular Weight,Number of H-Bond Donors,Number of Rings,Number of Rotatable Bonds,Polar Surface Area,measured log solubility in mols per litre,smiles
0,Amigdalin,-0.974,1,457.432,7,3,7,202.32,-0.77,OCC3OC(OCC2OC(OC(C#N)c1ccccc1)C(O)C(O)C2O)C(O)...
1,Fenfuram,-2.885,1,201.225,1,2,2,42.24,-3.3,Cc1occc1C(=O)Nc2ccccc2
2,citral,-2.579,1,152.237,0,0,4,17.07,-2.06,CC(C)=CCCC(C)=CC(=O)
3,Picene,-6.618,2,278.354,0,5,0,0.0,-7.87,c1ccc2c(c1)ccc3c2ccc4c5ccccc5ccc43
4,Thiophene,-2.232,2,84.143,0,1,0,0.0,-1.33,c1ccsc1


In [24]:
df = df.rename(columns={"Compound ID": "id", "measured log solubility in mols per litre":"solubility"})[["smiles", "id", "solubility"]]

In [25]:
df.head()

Unnamed: 0,smiles,id,solubility
0,OCC3OC(OCC2OC(OC(C#N)c1ccccc1)C(O)C(O)C2O)C(O)...,Amigdalin,-0.77
1,Cc1occc1C(=O)Nc2ccccc2,Fenfuram,-3.3
2,CC(C)=CCCC(C)=CC(=O),citral,-2.06
3,c1ccc2c(c1)ccc3c2ccc4c5ccccc5ccc43,Picene,-7.87
4,c1ccsc1,Thiophene,-1.33


### Featurizer
Define featurizer

In [26]:
trans = PYGGraphTransformer(atom_featurizer=AtomCalculator(), bond_featurizer=EdgeMatCalculator())

### Network + Training
Define GNN

In [27]:
import torch
import numpy as np
import torch.nn.functional as F
from torch_geometric.nn.models import AttentiveFP
from tqdm.auto import tqdm
DEVICE = "cpu"

In [40]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

class DTset(Dataset):
    def __init__(self, smiles, y, featurizer):
        super().__init__()
        self.smiles = smiles
        self.featurizer = featurizer
        self.featurizer.auto_self_loop()
        self.y = torch.tensor(y).unsqueeze(-1).float()
        self.transformed_mols = self.featurizer(smiles)

    @property
    def num_atom_features(self):
        return self.featurizer.atom_dim

    @property
    def num_output(self):
        return self.y.shape[-1]
    
    def __len__(self):
        return len(self.transformed_mols)
    
    @property
    def num_bond_features(self):
        return self.featurizer.bond_dim
    
    def collate_fn(self, **kwargs):
        return self.featurizer.get_collate_fn(**kwargs)
    
    def __getitem__(self, index):
        return self.transformed_mols[index], self.y[index]

In [41]:
dataset = DTset(df.smiles.values, df.solubility.values, trans)

In [48]:
model = AttentiveFP(in_channels=dataset.num_atom_features, 
                    hidden_channels=128, 
                    out_channels=dataset.num_output, 
                    num_layers=2, num_timesteps=2,
                    dropout=0.2, 
                    edge_dim=dataset.num_bond_features)
model = model.to(DEVICE).float()

loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=dataset.collate_fn(return_pair=False))

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

with tqdm(range(2)) as pbar:
    for epoch in pbar:
        losses = []
        for data in loader:
            data = data.to(DEVICE)
            optimizer.zero_grad()
            out = model(data.x, data.edge_index, data.edge_attr, data.batch)
            loss = F.mse_loss(out, data.y)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
        pbar.set_description(f"Epoch {epoch} - Loss {np.mean(losses):.3f}")


  0%|          | 0/2 [00:00<?, ?it/s]

  loss = F.mse_loss(out, data.y)
  loss = F.mse_loss(out, data.y)
