In [1]:
import os
os.environ["KERAS_BACKEND"] = "torch" # Comment out for tensorflow backend

from molexpress import layers
from molexpress.datasets import featurizers
from molexpress.datasets import encoders

from rdkit import Chem

import torch

## 1. Featurizers

In [2]:
mol = Chem.MolFromSmiles('CCO')

print(featurizers.AtomType(vocab={'O'}, oov=False)(mol.GetAtoms()[0]))
print(featurizers.AtomType(vocab={'O'}, oov=True)(mol.GetAtoms()[0]))
print(featurizers.AtomType(vocab={'C', 'O'}, oov=False)(mol.GetAtoms()[0]))
print(featurizers.AtomType(vocab={'C', 'O', 'N'}, oov=False)(mol.GetAtoms()[0]))
print(featurizers.AtomType(vocab={'C', 'O', 'N'}, oov=True)(mol.GetAtoms()[0]))

[0.]
[0. 1.]
[1. 0.]
[1. 0. 0.]
[1. 0. 0. 0.]


## 2. Encoder

In [3]:
atom_featurizers = [
    featurizers.AtomType({'C', 'O', 'N'}),
    featurizers.Hybridization(),
]

bond_featurizers = [
    featurizers.BondType()
]

encoder = encoders.MolecularGraphEncoder(
    atom_featurizers=atom_featurizers, 
    bond_featurizers=bond_featurizers,
    self_loops=True # adds one dim to edge state
)

encoder(mol)

{'node_state': array([[1., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 1., 0., 0., 0.]], dtype=float32),
 'edge_src': array([0, 0, 1, 1, 1, 2, 2], dtype=int32),
 'edge_dst': array([0, 1, 0, 1, 2, 1, 2], dtype=int32),
 'edge_state': array([[0., 0., 0., 0., 1.],
        [0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 1.],
        [0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 1.]], dtype=float32)}

## 3. Dataset

In [4]:
x_dummy = ['CC', 'CC', 'CCO', 'CCCN']
y_dummy = [1., 2., 3., 4.]


class TinyDataset(torch.utils.data.Dataset):
    
    def __init__(self, x, y):
        self.x = x
        self.y = y
        
    def __len__(self):
        return len(self.x)
        
    def __getitem__(self, index):
        x = self.x[index]
        y = self.y[index]
        x = encoder(x)
        return x, y

torch_dataset = TinyDataset(x_dummy, y_dummy)

dataset = torch.utils.data.DataLoader(
    torch_dataset, batch_size=2, collate_fn=encoder._collate_fn)

for x, y in dataset:
    print(f'x = {x}\ny = {y}', end='\n' + '---' * 30 + '\n')

x = {'node_state': array([[1., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 1., 0., 0., 0.]], dtype=float32), 'edge_state': array([[0., 0., 0., 0., 1.],
       [0., 0., 1., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 0., 1.],
       [0., 0., 0., 0., 1.],
       [0., 0., 1., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 0., 1.]], dtype=float32), 'edge_src': array([0, 0, 1, 1, 2, 2, 3, 3]), 'edge_dst': array([0, 1, 0, 1, 2, 3, 2, 3]), 'graph_indicator': array([0, 0, 1, 1])}
y = [1. 2.]
------------------------------------------------------------------------------------------
x = {'node_state': array([[1., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 1., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [

## 4. Model

In [5]:
class TinyGCNModel(torch.nn.Module):

    def __init__(self):
        super().__init__()

        self.gcn1 = layers.GINConv(32)
        self.gcn2 = layers.GINConv(32)
        self.readout = layers.Readout()
        self.linear = torch.nn.Linear(32, 1)

    def forward(self, x):
        x = self.gcn1(x)
        x = self.gcn2(x)
        x = self.readout(x)
        x = self.linear(x)
        return x

model = TinyGCNModel().to('cuda')

## 5. Fit

In [7]:
optimizer = torch.optim.SGD(model.parameters(), lr=0.00001, momentum=0.9)
loss_fn = torch.nn.MSELoss()

for _ in range(30):
    loss_sum = 0.
    for x, y in dataset:
        optimizer.zero_grad()
    
        outputs = model(x)
    
        y = torch.tensor(y, dtype=torch.float32).to('cuda')
        loss = loss_fn(outputs, y[:, None])
        loss.backward()
        optimizer.step()

        loss_sum += loss
        
    print(loss_sum)

tensor(1.7158, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.7103, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.7018, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.6916, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.6810, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.6712, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.6627, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.6557, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.6499, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.6447, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.6398, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.6346, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.6289, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.6227, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.6158, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.6085, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.6009, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1.5930, device='cuda:0', grad_fn=<AddBack