In [None]:
import os
os.environ["KERAS_BACKEND"] = "torch" # Comment out for tensorflow backend

from molexpress import layers
from molexpress.datasets import featurizers
from molexpress.datasets import encoders

from rdkit import Chem

import torch

## 1. Featurizers

In [None]:
# Build a small molecule to featurize.
mol = Chem.MolFromSmiles('CCO')

# Each entry is a (vocab, oov) pair; the same atom (index 0) is featurized
# under every setting so the printed encodings can be compared line by line.
atom_type_settings = [
    ({'O'}, False),
    ({'O'}, True),
    ({'C', 'O'}, False),
    ({'C', 'O', 'N'}, False),
    ({'C', 'O', 'N'}, True),
]

for vocab, oov in atom_type_settings:
    featurizer = featurizers.AtomType(vocab=vocab, oov=oov)
    print(featurizer(mol.GetAtoms()[0]))

## 2. Encoder

In [None]:
# Featurizers applied to every atom: atom type (vocabulary C/O/N) and hybridization.
atom_featurizers = [featurizers.AtomType({'C', 'O', 'N'}), featurizers.Hybridization()]

# Featurizers applied to every bond: bond type only.
bond_featurizers = [featurizers.BondType()]

# Collect the encoder options in one place before building the encoder.
encoder_options = dict(
    atom_featurizers=atom_featurizers,
    bond_featurizers=bond_featurizers,
    self_loops=True,  # adds one dim to edge state
)
encoder = encoders.MolecularGraphEncoder(**encoder_options)

# Encode the molecule created in the featurizer section; as the last
# expression of the cell, the result is displayed.
encoder(mol)

## 3. Dataset

In [None]:
# Toy regression data: four SMILES strings, each paired with a float target
# (1.0 through 4.0, in order).
x_dummy = [
    'CC',
    'CC',
    'CCO',
    'CCCN',
]
y_dummy = [float(target) for target in range(1, 5)]


class TinyDataset(torch.utils.data.Dataset):
    """Map-style dataset that encodes each input lazily, on access.

    Args:
        x: Indexable, sized collection of inputs (SMILES strings in this
            notebook).
        y: Indexable, sized collection of targets; must have the same
            length as `x`.
        encoder: Optional callable applied to an input in `__getitem__`.
            When omitted, the notebook-level `encoder` is looked up at
            access time, which keeps `TinyDataset(x, y)` working as before.

    Raises:
        ValueError: If `x` and `y` do not have the same length.
    """

    def __init__(self, x, y, encoder=None):
        # Fail early: a mismatch would otherwise surface as an IndexError
        # mid-epoch (or silently drop trailing targets).
        if len(x) != len(y):
            raise ValueError(
                f'x and y must have the same length, got {len(x)} and {len(y)}'
            )
        self.x = x
        self.y = y
        self.encoder = encoder

    def __len__(self):
        """Return the number of (input, target) pairs."""
        return len(self.x)

    def __getitem__(self, index):
        """Return `(encoded_input, target)` for the sample at `index`."""
        # Inside this method the bare name `encoder` refers to the
        # notebook-level encoder, used only when none was passed in.
        encode = self.encoder if self.encoder is not None else encoder
        return encode(self.x[index]), self.y[index]

# Wrap the toy data in the dataset defined above.
torch_dataset = TinyDataset(x_dummy, y_dummy)

# Batches of two samples; the encoder's collate function assembles each batch.
loader_options = dict(batch_size=2, collate_fn=encoder._collate_fn)
dataset = torch.utils.data.DataLoader(torch_dataset, **loader_options)

# Print every batch followed by a separator line.
divider = '---' * 30
for x, y in dataset:
    print(f'x = {x}\ny = {y}')
    print(divider)

## 4. Model

In [None]:
class TinyGCNModel(torch.nn.Module):
    """Two graph-convolution layers, a readout, and a linear regression head.

    Args:
        units: Width passed to both `GINConv` layers and used as the input
            size of the final linear layer. Defaults to 32.
    """

    def __init__(self, units=32):
        super().__init__()

        # One shared width keeps the conv layers and the linear head in sync.
        self.gcn1 = layers.GINConv(units)
        self.gcn2 = layers.GINConv(units)
        self.readout = layers.Readout()
        self.linear = torch.nn.Linear(units, 1)

    def forward(self, x):
        """Apply conv -> conv -> readout -> linear and return the result.

        `x` is a batch as produced by the data loader; the last layer maps
        to a single output feature.
        """
        x = self.gcn1(x)
        x = self.gcn2(x)
        x = self.readout(x)
        x = self.linear(x)
        return x

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of raising on `.to('cuda')`.
model = TinyGCNModel().to('cuda' if torch.cuda.is_available() else 'cpu')

## 5. Fit

In [None]:
# Put targets on whatever device the model lives on rather than assuming CUDA.
device = next(model.parameters()).device

# NOTE(review): if the graph layers create their weights lazily on first call,
# those weights are not yet in `model.parameters()` here and would not be
# optimized — confirm against the layer implementation.
optimizer = torch.optim.SGD(model.parameters(), lr=0.00001, momentum=0.9)
loss_fn = torch.nn.MSELoss()

for _ in range(30):
    # Running sum of the batch losses for this epoch, kept as a plain float.
    loss_sum = 0.
    for x, y in dataset:
        optimizer.zero_grad()

        outputs = model(x)

        # `as_tensor` accepts lists/arrays and, unlike `torch.tensor`, does not
        # warn about copy-construction when `y` is already a tensor.
        y = torch.as_tensor(y, dtype=torch.float32, device=device)
        # Add a trailing axis so the targets line up with the model output
        # (presumably shaped (batch, 1) given the final Linear(…, 1)).
        loss = loss_fn(outputs, y[:, None])
        loss.backward()
        optimizer.step()

        # `.item()` extracts a Python float; accumulating the tensor itself
        # would keep autograd history attached across iterations.
        loss_sum += loss.item()

    print(loss_sum)