In [None]:
import os
os.environ["KERAS_BACKEND"] = "torch" # Comment out for tensorflow backend

from molexpress import layers
from molexpress.datasets import features
from molexpress.datasets import encoders

from rdkit import Chem

import keras 
import torch

## 1. Features

In [None]:
# Parse ethanol from SMILES and take its first atom as the probe for every
# AtomType configuration below.
mol = Chem.MolFromSmiles('CCO')
first_atom = mol.GetAtoms()[0]

# (allowable_set, oov) pairs, printed in this order so the effect of the
# vocabulary and of the out-of-vocabulary flag can be compared line by line.
atom_type_configs = [
    ({'O'}, False),
    ({'O'}, True),
    ({'C', 'O'}, False),
    ({'C', 'O', 'N'}, False),
    ({'C', 'O', 'N'}, True),
]

for allowable_set, oov in atom_type_configs:
    atom_type = features.AtomType(allowable_set=allowable_set, oov=oov)
    print(atom_type(first_atom))

## 2. Featurizer

In [None]:
# Atom-level featurizer: combines an atom-type feature with a hybridization
# feature into a single callable.
atom_features = [
    features.AtomType({'C', 'O', 'N'}),
    features.Hybridization(),
]
atom_featurizer = features.Featurizer(atom_features)

# Bond-level featurizer: bond type only.
bond_features = [features.BondType()]
bond_featurizer = features.Featurizer(bond_features)

# Show the featurizer output next to the RDKit property it is derived from,
# using the first atom and the first bond of `mol`.
first_atom = mol.GetAtoms()[0]
first_bond = mol.GetBonds()[0]
print(first_atom.GetSymbol(), atom_featurizer(first_atom))
print(first_bond.GetBondType(), bond_featurizer(first_bond))

## 3. Encoder

In [None]:
# Combine the atom and bond featurizers into one molecule -> graph encoder.
# Per the original author's note, enabling self loops adds one dimension to
# the edge state.
encoder = encoders.MolecularGraphEncoder(
    self_loops=True,
    atom_featurizer=atom_featurizer,
    bond_featurizer=bond_featurizer,
)

# Keep this as the last bare expression so the notebook displays the encoding.
encoder(mol)

## 4. Dataset

In [None]:
# Toy inputs (SMILES strings). 'CC' appears twice on purpose or by accident;
# either way it is kept exactly as in the original data.
x_dummy = [
    'CC',
    'CC',
    'CCO',
    'CCCN',
]
# One scalar target per entry of x_dummy, in the same order.
y_dummy = [
    1.,
    2.,
    3.,
    4.,
]


class TinyDataset(torch.utils.data.Dataset):
    """Map-style dataset that pairs raw inputs with targets and encodes lazily.

    Each access returns ``(encode(x[index]), y[index])``; nothing is encoded
    up front.

    Args:
        x: Indexable, sized collection of raw inputs (SMILES strings in this
            notebook).
        y: Indexable, sized collection of targets aligned with ``x``.
        encoder: Optional callable applied to a single raw input on access.
            When omitted, the notebook-level ``encoder`` variable is looked up
            at access time, which matches the original behaviour.

    Raises:
        ValueError: If ``x`` and ``y`` do not have the same length.
    """

    def __init__(self, x, y, encoder=None):
        # Fail early: a mismatch would otherwise surface as an IndexError in
        # the middle of an epoch, or as silently ignored targets.
        if len(x) != len(y):
            raise ValueError(
                f'x and y must have the same length, got {len(x)} and {len(y)}'
            )
        self.x = x
        self.y = y
        self.encoder = encoder

    def __len__(self):
        """Return the number of (input, target) pairs."""
        return len(self.x)

    def __getitem__(self, index):
        """Return the encoded input and its target for position ``index``."""
        raw, target = self.x[index], self.y[index]
        # Inside this method the bare name `encoder` is the module-level
        # variable (the constructor argument lives on `self.encoder`).
        encode = self.encoder if self.encoder is not None else encoder
        return encode(raw), target

# Wrap the toy data, then batch it two samples at a time. Batches are
# assembled by the encoder's own collate function.
torch_dataset = TinyDataset(x=x_dummy, y=y_dummy)

dataset = torch.utils.data.DataLoader(
    dataset=torch_dataset,
    batch_size=2,
    collate_fn=encoder._collate_fn,
)

# Print every batch, followed by a horizontal rule.
separator = '\n' + '---' * 30 + '\n'
for x, y in dataset:
    print(f'x = {x}\ny = {y}', end=separator)

## 5. Model

In [None]:
class TinyGCNModel(torch.nn.Module):
    """Small regression model: two GCN layers, a readout, and a linear head.

    Args:
        units: Width passed to both ``GCNConv`` layers and used as the input
            size of the final linear layer. Keeping it in one place guarantees
            the three values stay in sync. Defaults to 64, the value that was
            previously hard-coded.
    """

    def __init__(self, units: int = 64):
        super().__init__()

        self.gcn1 = layers.GCNConv(units, skip_connection=False)
        self.gcn2 = layers.GCNConv(units, skip_connection=False)
        # NOTE(review): presumably pools node states into one vector per
        # graph while keeping the feature width -- confirm in molexpress.
        self.readout = layers.Readout()
        self.linear = torch.nn.Linear(units, 1)

    def forward(self, x):
        """Apply gcn1 -> gcn2 -> readout -> linear to an encoded batch.

        Returns the output of the linear head, whose last dimension is 1.
        """
        x = self.gcn1(x)
        x = self.gcn2(x)
        x = self.readout(x)
        x = self.linear(x)
        return x

# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
# NOTE(review): with KERAS_BACKEND="torch", Keras picks its own default device
# for layer weights -- confirm it agrees with `device` on non-CUDA accelerators.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = TinyGCNModel().to(device)

## 6. Fit

In [None]:
# Plain SGD with momentum on a mean-squared-error objective.
# NOTE(review): if the molexpress layers create their weights lazily on the
# first forward pass, those weights do not exist yet when model.parameters()
# is read here and would never be updated -- confirm, or run one forward pass
# before building the optimizer.
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)
loss_fn = torch.nn.MSELoss()

for _ in range(10):
    # Running total of the batch losses for this epoch, kept as a Python float.
    loss_sum = 0.
    for x, y in dataset:
        optimizer.zero_grad()

        outputs = model(x)

        # Put the targets on whatever device the predictions live on instead
        # of hard-coding 'cuda'. as_tensor also avoids the copy (and warning)
        # that torch.tensor() triggers when y is already a tensor.
        y = torch.as_tensor(y, dtype=torch.float32, device=outputs.device)
        # y[:, None] turns shape (batch,) into (batch, 1) to match the head.
        loss = loss_fn(outputs, y[:, None])
        loss.backward()
        optimizer.step()

        # .item() extracts the scalar; adding the tensor itself would make
        # the running total carry autograd history from every step.
        loss_sum += loss.item()

    print(loss_sum)