# Neural molecular fingerprints

## Imports

In [33]:
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch_geometric.loader import DataLoader
from torch_geometric.datasets import QM9
from torch_geometric.nn import GCNConv, MessagePassing, global_mean_pool

In [5]:
qm9 = QM9("/Volumes/OXYTOCIN/datasets/qm9-pyg")

Downloading https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/molnet_publish/qm9.zip
Extracting /Volumes/OXYTOCIN/datasets/qm9-pyg/raw/qm9.zip
Downloading https://ndownloader.figshare.com/files/3195404
Processing...
100%|█████████████████████████████████████████████████████████████████████████████████████████████| 133885/133885 [05:01<00:00, 444.16it/s]
Done!


## Simple GCN

In [63]:
class ResGCNConv(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, out_dim)

    def forward(self, x, edge_index):
        out = F.relu(self.conv1(x, edge_index))
        out = self.conv2(out, edge_index)
        out += x
        
        return F.relu(out)

class MoleculeEncoder(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        
        self.in_features = in_features
        self.out_features = out_features
        
        self.conv1 = ResGCNConv(in_features, 512, 512)
        self.conv2 = ResGCNConv(512, 512, 512)
        self.conv3 = ResGCNConv(512, 512, 512)
        self.lin = nn.Linear(512, out_features)
    
    def forward(self, x, edge_index, batch):
        # x is node features
        # edge_index is connectivity
        # batch assigns each node to its graph index
        
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, training=self.training)
        x = F.relu(self.conv2(x, edge_index))
        x = F.dropout(x, training=self.training)
        x = F.relu(self.conv3(x, edge_index))
        
        # takes the average over all node embeddings
        # if we implement this ourselves, we need to account for batch
        # it's not automatic
        x = global_mean_pool(x, batch)
        x = F.dropout(x, training=self.training)
        x = self.lin(x)
        
        return x

In [64]:
def training_loop(model, optimizer, train_loader, n_epochs=200, device="cpu"):
    model.train()
    model.to(device)
    
    for epoch in range(1, n_epochs + 1):
        epoch_loss = 0
        epoch_batches = 0
        
        for batch in train_loader:
            batch = batch.to(device)
            
            pred = model(batch.x[:, 5:6], batch.edge_index, batch.batch)
            # output: shape=(batch x 2); (homo, lumo)
            pred_homo = pred[:, 0]
            pred_lumo = pred[:, 1]
            target_homo = batch.y[:, 2]
            target_lumo = batch.y[:, 3]
            
            optimizer.zero_grad()
            loss_homo = F.mse_loss(pred_homo, target_homo)
            loss_lumo = F.mse_loss(pred_lumo, target_lumo)
            loss = loss_homo + loss_lumo
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.detach().item()
            epoch_batches += 1
        
        now = datetime.now()
        now_str = now.strftime("%H:%M:%S.%f")
        print(f"[{now_str}] Epoch {epoch}, Mean loss {epoch_loss / epoch_batches}")

In [65]:
model = MoleculeEncoder(1, 2)
train_loader = DataLoader(qm9, batch_size=32, shuffle=True)
optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

training_loop(model, optimizer, train_loader, n_epochs=50, device="cpu")

[13:21:49.096665] Epoch 1, Mean loss 1960.1026913108817
[13:27:14.504866] Epoch 2, Mean loss 1.5525841934373659
[13:32:44.430286] Epoch 3, Mean loss 1.3654323521264464
[13:38:23.712392] Epoch 4, Mean loss 1.2323828374415338
[13:43:57.504363] Epoch 5, Mean loss 1.950079728277775
[13:49:19.284198] Epoch 6, Mean loss 2.0053588697979694
[13:54:47.643061] Epoch 7, Mean loss 1.9961259295404554
[14:00:11.689897] Epoch 8, Mean loss 1.9954121644294591
[14:05:51.317388] Epoch 9, Mean loss 1.9942465832251504
[14:11:19.547738] Epoch 10, Mean loss 1.994700126439209
[14:16:37.335621] Epoch 11, Mean loss 1.9957061108064174
[14:22:06.349430] Epoch 12, Mean loss 1.993931232702181
[14:27:34.861998] Epoch 13, Mean loss 1.994260351941824
[14:33:06.627932] Epoch 14, Mean loss 1.994554471199948
[14:38:41.660058] Epoch 15, Mean loss 1.9953444549377928
[14:44:01.246647] Epoch 16, Mean loss 1.9955293163748293
[14:49:18.981081] Epoch 17, Mean loss 1.9939039585554794
[14:54:34.318535] Epoch 18, Mean loss 1.99685

KeyboardInterrupt: 