# Example to use the pipeline with your own model

In [1]:
import os
import torch
import torch.nn.functional as F
from torch_geometric.nn import GINEConv

from molx.dataset import Molecule3D
from molx.mol3d import Mol3DTrainer, eval3d
from molx.model import make_mask

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Build your own model

### 1. Define a GNN to obtain node representations
In this example, we use a simple GINE model. You can use any PyTorch graph neural networks to compute node representations. This model takes a PyG Batch object as input, and outputs node representations as a torch tensor.

In [2]:
class GINENet(torch.nn.Module):
    def __init__(self, num_node_features, num_edge_features, hidden, dropout):
        super(GINENet, self).__init__()
        self.conv1 = GINEConv(torch.nn.Sequential(torch.nn.Linear(num_node_features, hidden)),eps = 0, train_eps = True)
        self.conv2 = GINEConv(torch.nn.Sequential(torch.nn.Linear(hidden, hidden)),eps = 0, train_eps = True)
        self.conv3 = GINEConv(torch.nn.Sequential(torch.nn.Linear(hidden, hidden)),eps = 0, train_eps = True)
        self.lin1 = torch.nn.Linear(num_edge_features, num_node_features, bias = True)
        self.lin2 = torch.nn.Linear(num_node_features, hidden, bias = True)
        self.dropout = dropout
        self.fc = torch.nn.Linear(in_features=hidden, out_features=1)

    def forward(self, batch_data):
        """
        Args:
            batch_data: A PyG Batch object describing a batch of graphs 
                        as one big (disconnected) graph.
        
        Return:
            xs: A torch tensor of shape (n, hidden), which denotes 
                node representations.
        """
        
        x, edge_index, edge_attr = batch_data.x, batch_data.edge_index, batch_data.edge_attr

        edge_attr = self.lin1(edge_attr.float())
        xs = F.relu(self.conv1(x, edge_index, edge_attr))
        xs = F.dropout(xs, p=self.dropout, training=self.training)
        edge_attr = self.lin2(edge_attr.float())
        xs = F.relu(self.conv2(xs, edge_index, edge_attr))
        xs = self.conv3(xs, edge_index, edge_attr)

        return xs

### 2. Define another model to compute pairwise distances
Then, you need to define a model for calculating pairwise distances between nodes(atoms). This model takes node representations as input, and outputs a distance matrix. The output should be a torch tensor of shape $(n, n)$, where $n$ denotes the total number of nodes in this batch. Here we use the elementwise max method proposed in Molecule3D as an example. You can develop your own method.

Note that only intra-molecular distances should be calculated. Inter-molecular distances are meaningless, they are thus ignored by a mask. You should also use the mask.

In [3]:
class DistNet(torch.nn.Module):
    def __init__(self, hidden, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')):
        super(DistNet, self).__init__()
        self.fc = torch.nn.Linear(in_features=hidden, out_features=1)
        self.device = device
    
    def forward(self, xs, batch, train=False):
        """
        Args:
            xs: A torch tensor of shape (n, hidden), which denotes 
                node representations.
            
        Return:
            mask_d_pred: A torch tensor of shape (n, n), where n
                        denotes the total number of nodes in 
                        this batch.
            mask: A torch tensor of shape (n, n). Value 1 indicates 
                    intra-molecular, while value 0 indicates 
                    inter-molecular.
            count: Total number of intra-molecular in this batch.
        """
        d_pred = self.fc(torch.max(xs.unsqueeze(0), xs.unsqueeze(1))).squeeze()
        mask, count = make_mask(batch, self.device)

        if train:
            mask_d_pred = d_pred * mask
        else:
            mask_d_pred = F.relu(d_pred * mask) # Enforce non-negative distances for evaluation
        return mask_d_pred, mask, count

### 3. Define your final model

In [4]:
class MyModel(torch.nn.Module):
    def __init__(self, num_node_features, num_edge_features, hidden, dropout):
        super(MyModel, self).__init__()
        self.node_embed = GINENet(num_node_features, num_edge_features, hidden, dropout)
        self.calc_dist = DistNet(hidden)
    
    def forward(self, batch_data, train=False):
        xs = self.node_embed(batch_data)
        mask_d_pred, mask, count = self.calc_dist(xs, batch_data.batch)
        return mask_d_pred, mask, count

### 4. Initiate your model

In [9]:
model = MyModel(num_node_features=9, num_edge_features=3, hidden=256, dropout=0).to(device) 

## Configurations

In [6]:
conf = {}
conf['epochs'] = 5
conf['early_stopping'] = 3
conf['lr'] = 0.0001
conf['lr_decay_factor'] = 0.8
conf['lr_decay_step_size'] = 10
conf['dropout'] = 0
conf['weight_decay'] = 0
conf['batch_size'] = 20
conf['save_ckpt'] = 'best_valid'
conf['out_path'] = 'results/exp0/'
conf['split'] = 'random' # or 'scaffold'
conf['criterion'] = 'mse'

## Load dataset

In [7]:
root_dir = os.getcwd() # Where your data folder is located

train_dataset = Molecule3D(root=root_dir, transform=None, split='train', split_mode=conf['split'])
val_dataset = Molecule3D(root=root_dir, transform=None, split='val', split_mode=conf['split'])
test_dataset = Molecule3D(root=root_dir, transform=None, split='test', split_mode=conf['split'])

# In this example, we use a subset of dataset for illustration
train_dataset = train_dataset[:1000]
val_dataset = val_dataset[:100]
test_dataset = test_dataset[:100]

## Training

In [10]:
trainer = Mol3DTrainer(train_dataset, val_dataset, conf,
                       device=device)
model = trainer.train(model)

epoch: 1; Train -- loss: 8.898
epoch: 1; Valid -- val_MAE: 2.163; val_RMSE: 2.861; val_Validity: 2.00%; val_Validity3D: 0.00%;
epoch: 2; Train -- loss: 6.685
epoch: 2; Valid -- val_MAE: 2.067; val_RMSE: 2.669; val_Validity: 3.00%; val_Validity3D: 0.00%;
epoch: 3; Train -- loss: 6.158
epoch: 3; Valid -- val_MAE: 1.965; val_RMSE: 2.568; val_Validity: 2.00%; val_Validity3D: 0.00%;
epoch: 4; Train -- loss: 5.696
epoch: 4; Valid -- val_MAE: 2.002; val_RMSE: 2.558; val_Validity: 2.00%; val_Validity3D: 0.00%;
epoch: 5; Train -- loss: 5.425
epoch: 5; Valid -- val_MAE: 1.866; val_RMSE: 2.468; val_Validity: 2.00%; val_Validity3D: 0.00%;
Best valid epoch is 5; Best val_MAE: 1.866; Best val_RMSE: 2.468; Best val_Validity: 2.00%; Best val_Validity3D: 0.00%


## Evaluation

In [11]:
print('load best val model...')
best_model_path = os.path.join(conf['out_path'], 'ckpt_best_val.pth')
checkpoint = torch.load(best_model_path)
model.load_state_dict(checkpoint['model'])

mae, rmse, test_validity, test_validity3d = eval3d(model, test_dataset)
print('epoch: {}; Test -- test_MAE: {:.3f}; test_RMSE: {:.3f}; % test_Validity: {:.2f}%;  % test_Validity3D: {:.2f}%;'
                           .format(checkpoint['epoch'], mae, rmse, test_validity*100, test_validity3d*100))

load best val model...
epoch: 5; Test -- test_MAE: 1.808; test_RMSE: 2.324; % test_Validity: 1.00%;  % test_Validity3D: 0.00%;
