In [None]:
import os
os.environ["KERAS_BACKEND"] = "torch" # Comment out for tensorflow backend

from molexpress import layers
from molexpress.datasets import featurizers
from molexpress.datasets import encoders
from molexpress.ops.chem_ops import get_molecule

import torch

## 1. Featurizers

In [None]:
mol = get_molecule('C(C(=O)O)N')  # glycine

# Featurize the same atom (index 0 -- presumably the leading carbon of the SMILES,
# assuming get_molecule keeps SMILES atom order) under several vocabularies,
# toggling `oov` (presumably an extra out-of-vocabulary slot).
for atom_vocab, use_oov in [
    ({'O'}, False),
    ({'O'}, True),
    ({'C', 'O'}, False),
    ({'C', 'O', 'N'}, False),
    ({'C', 'O', 'N'}, True),
]:
    print(featurizers.AtomType(vocab=atom_vocab, oov=use_oov)(mol.GetAtoms()[0]))

## 2. Encoder

In [None]:
# Node features: element type (vocabulary C/O/N) followed by hybridization.
atom_featurizers = [featurizers.AtomType({'C', 'O', 'N'}), featurizers.Hybridization()]
# Edge features: bond type followed by the conjugation flag.
bond_featurizers = [featurizers.BondType(), featurizers.Conjugated()]

# TODO: masking is currently done by `masked_collate_fn` (used by the DataLoader
#       below); decide whether PeptideGraphEncoder should perform it instead.
# IMPORTANT: when fine-tuning a pretrained model, the PeptideGraphEncoder of the
#            fine-tuning task must also use supports_masking=True, otherwise the
#            feature dimensions will not match the pretrained weights.
peptide_graph_encoder = encoders.PeptideGraphEncoder(
    atom_featurizers=atom_featurizers,
    bond_featurizers=bond_featurizers,
    self_loops=False,       # True would add one feature dim to the edge state
    supports_masking=True,  # adds one feature dim to both node and edge state
)

# Quick check: encode glycine (`mol`, from the previous cell) together with alanine;
# the returned graph is this cell's displayed output.
mol2 = get_molecule('CC(C(=O)O)N')
peptide_graph_encoder([mol, mol2])

## 3. Dataset

In [None]:
# Toy peptides, each given as a list of amino-acid SMILES (presumably encoded as
# one peptide graph per list):
#   'C(C(=O)O)N' = glycine, 'CC(C(=O)O)N' = alanine, 'CC(C)C(C(=O)O)N' = valine.
x_dummy = [
    ['CC(C)C(C(=O)O)N', 'C(C(=O)O)N'],
    ['C(C(=O)O)N', 'CC(C(=O)O)N', 'C(C(=O)O)N'],
    ['CC(C(=O)O)N'],
]

class Dataset(torch.utils.data.Dataset):
    """Map-style dataset that encodes one sample into a graph on every access.

    Args:
        x: sequence of samples; item ``index`` is ``encoder(x[index])``.
        encoder: callable mapping one sample to a graph. When ``None`` (the
            default) the notebook-level ``peptide_graph_encoder`` is used, looked
            up at access time exactly as before, so ``Dataset(x)`` keeps working.
            Pass an encoder explicitly to reuse the class with a different one
            (e.g. a fine-tuning encoder).
    """

    def __init__(self, x, encoder=None):
        self.x = x
        self.encoder = encoder

    def __len__(self):
        return len(self.x)

    def __getitem__(self, index):
        # Late lookup of the global keeps the original behavior when no encoder is given.
        encoder = self.encoder if self.encoder is not None else peptide_graph_encoder
        return encoder(self.x[index])
        
torch_dataset = Dataset(x_dummy)

# Despite its name, `dataset` is a DataLoader: iterating it yields batches of two
# encoded peptides, collated (and masked) by the encoder's `masked_collate_fn`.
dataset = torch.utils.data.DataLoader(
    dataset=torch_dataset,
    batch_size=2,
    collate_fn=peptide_graph_encoder.masked_collate_fn,
)


## 4. Model

In [None]:
class GraphNeuralNetwork(torch.nn.Module):
    """Stack of ``layers.GINConv`` layers applied one after another.

    Args:
        dim: units passed to every ``layers.GINConv``.
        num_layers: number of GINConv layers (default 4, as before). Layers are
            registered as ``gcn1`` ... ``gcnN`` so parameter names / state_dict
            keys are identical to the original four-attribute version (matters
            when loading a pretrained checkpoint).

    Raises:
        ValueError: if ``num_layers`` < 1.
    """

    def __init__(self, dim, num_layers=4):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.num_layers = num_layers
        for i in range(1, num_layers + 1):
            # setattr goes through nn.Module.__setattr__, so each layer is registered
            # as a submodule exactly like `self.gcnK = ...` would be.
            setattr(self, f"gcn{i}", layers.GINConv(dim))

    def forward(self, x):
        # x is the collated graph batch; each GINConv's output feeds the next.
        # NOTE(review): presumably a graph dict in and out (NodePrediction reads
        # x['node_state'] from the result) -- confirm against layers.GINConv.
        for i in range(1, self.num_layers + 1):
            x = getattr(self, f"gcn{i}")(x)
        return x


class NodePrediction(torch.nn.Module):
    """Two-layer MLP head applied to ``graph['node_state']``.

    Args:
        input_dim: size of the last dimension of the node states.
        output_dim: number of outputs per node (no final activation, i.e. logits).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Attribute names are kept stable: they define the state_dict keys.
        self.linear1 = torch.nn.Linear(input_dim, input_dim)
        self.linear2 = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x):
        hidden = torch.relu(self.linear1(x['node_state']))
        return self.linear2(hidden)


class EdgePrediction(torch.nn.Module):
    """Two-layer MLP head that scores edges from their incident node states.

    Args:
        input_dim: feature size produced by ``layers.GatherIncident`` (the notebook
            passes ``2 * dim`` -- presumably the two endpoint node states
            concatenated; confirm).
        output_dim: number of outputs per edge (no final activation, i.e. logits).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_dim, input_dim)
        self.linear2 = torch.nn.Linear(input_dim, output_dim)
        self.gather_incident = layers.GatherIncident()

    def forward(self, x):
        # Edge states are deliberately ignored; edges are scored from the node
        # states gathered at their endpoints.
        incident = self.gather_incident(x)
        return self.linear2(torch.relu(self.linear1(incident)))

## 5. Fit

In [None]:
# Use the GPU when present instead of hard-coding 'cuda' (which crashes on CPU-only machines).
# NOTE(review): batches from `dataset` are never moved explicitly in the training loop;
# this assumes the collate_fn / Keras torch backend already places them on this same
# device -- confirm.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

HIDDEN_DIM = 32
# These must equal the last dim of graph['node_label'] / graph['edge_label'], since the
# element-wise loss compares predictions and labels shape-for-shape.
# NOTE(review): they depend on the featurizers and masking flag used by
# peptide_graph_encoder -- confirm / derive them from the encoder if possible.
NUM_NODE_LABELS = 11
NUM_EDGE_LABELS = 6

graph_model = GraphNeuralNetwork(HIDDEN_DIM).to(DEVICE)
node_pred_model = NodePrediction(HIDDEN_DIM, NUM_NODE_LABELS).to(DEVICE)
# The edge head consumes GatherIncident output, sized 2 * HIDDEN_DIM here
# (presumably both endpoint node states).
edge_pred_model = EdgePrediction(HIDDEN_DIM * 2, NUM_EDGE_LABELS).to(DEVICE)

optimizer = torch.optim.SGD(
    [
        *graph_model.parameters(),
        *node_pred_model.parameters(),
        *edge_pred_model.parameters(),
    ],
    lr=0.001,
    momentum=0.5,
)
# Use BCEWithLogitsLoss when node/edge labels (initial node/edge states) are multi-hot.
# It fuses the sigmoid into the loss and stays numerically stable for large |logits|.
# The former Sigmoid() + BCELoss pair saturates: BCELoss clamps log(0) to -100, which
# caps the loss and makes gradients vanish.
loss_fn = torch.nn.BCEWithLogitsLoss(reduction='none')
# loss_fn = torch.nn.CrossEntropyLoss(reduction='none')  # if labels are one-hot; it also takes logits, so weighted_loss needs no change.

def weighted_loss(pred, true, weight):
    """Mean of the element-wise ``loss_fn`` scaled by a per-row weight.

    Args:
        pred: raw logits (no sigmoid/softmax applied); presumably (rows, num_labels).
        true: targets accepted by ``loss_fn`` (same shape as ``pred`` for BCE).
        weight: 1-D per-row weights (e.g. graph['node_loss_weight']). They are broadcast
            over any trailing loss dims, so this works for a (rows, C) BCE loss as well
            as a (rows,) CrossEntropy loss.

    Returns:
        Scalar tensor: mean over *all* loss elements, zero-weighted ones included.
    """
    loss = loss_fn(pred, true)
    # (rows,) -> (rows, 1, ...) to match loss.dim(); identical to weight[:, None] for 2-D BCE,
    # and avoids an accidental (rows, rows) outer product when the loss is 1-D.
    weight = weight.reshape(weight.shape + (1,) * (loss.dim() - weight.dim()))
    # NOTE(review): dividing by all elements (not weight.sum()) scales the loss with the
    # fraction of weighted rows -- confirm that is intended.
    return torch.mean(loss * weight)
    
NUM_EPOCHS = 150
LOG_EVERY = 5  # print the summed epoch loss every LOG_EVERY epochs

for epoch in range(NUM_EPOCHS):

    # Accumulate a plain Python float. `loss_sum += loss` on the tensor would chain every
    # batch's autograd graph onto loss_sum and keep it alive for the whole epoch.
    loss_sum = 0.
    for batch in dataset:  # `dataset` is the DataLoader; batches arrive collated and masked
        graph = graph_model(batch)

        node_pred = node_pred_model(graph)
        edge_pred = edge_pred_model(graph)

        node_loss = weighted_loss(node_pred, graph['node_label'], graph['node_loss_weight'])
        edge_loss = weighted_loss(edge_pred, graph['edge_label'], graph['edge_loss_weight'])

        loss = node_loss + edge_loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        loss_sum += loss.item()

    if epoch % LOG_EVERY == 0:
        print(f"Iteration {epoch:<3} - Loss {loss_sum:.3f}")