In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import DataStructs
from rdkit.Chem import rdMolDescriptors
from rdkit.Chem import rdmolops

In [42]:
# Let us make a Graph Neural Network
class GNN(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, num_output_layers):
        """"params: 
        input_dim: input feature dimension, 
        hidden_dim: hidden layer dimension, 
        output_dim: output dimension, 
        num_layers: number of hidden layers
        """
        super(GNN, self).__init__()
        self.fc1 = nn.Embedding(input_dim, hidden_dim)
        self.fcs = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for i in range(num_layers)])
        self.fc2 = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for i in range(num_output_layers)])

        self.property = nn.Linear(hidden_dim, 1)

    def forward(self, x, adj):
        """params: 
        x: input features, 
        adj: adjacency matrix
        output: predicted labels"""
        x = self.fc1(x)
        x = x.to(torch.float32)
        adj = adj.to(torch.float32)
        for i, fc in enumerate(self.fcs):
            x = torch.sparse.mm(adj, x)
            x = F.relu(fc(x))
        for i, fc in enumerate(self.fc2):
            x = torch.sparse.mm(adj, x)
            x = F.relu(fc(x))
        x = torch.sparse.mm(adj, x)
        x = self.property(x)
        return x
    
    def loss(self, pred, label):
        return F.mse_loss(pred, label)
    
    def accuracy(self, pred, label):
        return F.l1_loss(pred, label)
    
    def predict(self, x, adj):
        return self.forward(x, adj)
    
    def update(self, lr):
        optimizer = optim.Adam(self.parameters(), lr=lr)
        optimizer.zero_grad()
        optimizer.step()

In [6]:
# Get the data
data = pd.read_csv('data_no_qm.csv')
# Remove 'rep' column
data = data.drop(columns=['rep'])
# Take a sample of the data
data = data.sample(1000)
# Make sure the index is correct
data = data.reset_index(drop=True)

In [66]:
# Create a function to convert smiles to graph
def smiles_to_graph(smiles):
    mol = Chem.MolFromSmiles(smiles)
    # Get the adjacency matrix
    adj_mat = rdmolops.GetAdjacencyMatrix(mol)
    # Get the feature matrix
    atom_features = []
    for atom in mol.GetAtoms():
        atom_features.append(atom.GetAtomicNum())
    return adj_mat, atom_features



In [85]:
# Now we create the graphs for all the smiles
adjs = []
atom_features = []
for i in range(data.shape[0]):
    adj, atom_feature = smiles_to_graph(data['smiles'][i])
    adjs.append(adj)
    atom_features.append(atom_feature)

# Now we convert the data to tensors
adjs = [torch.tensor(adj) for adj in adjs]
atom_features = [torch.tensor(atom_feature) for atom_feature in atom_features]

# labels are the 7 first columns called 'storage', 'tbr', 'max_abs', 'osc_str', 'abs_prod', 'osc_prod', 'sce'
labels = data[['storage', 'tbr', 'max_abs', 'osc_str', 'abs_prod', 'osc_prod', 'sce']].values
labels = torch.tensor(labels)
labels = labels.to(torch.float32)
# Now we create the model
model = GNN(100, 100, 2, 2)

# Now we train the model
lr = 0.01

for epoch in range(10):
    for i in range(len(adjs)):
        pred = model.predict(atom_features[i], adjs[i]).unsqueeze(0)
        loss = model.loss(pred, labels[i])
        loss.backward()
        model.update(lr)
    print(f'Epoch {epoch}, Loss {loss}')

# Now we test the model
preds = []
for i in range(len(adjs)):
    pred = model.predict(atom_features[i], adjs[i])
    preds.append(pred)

  return F.mse_loss(pred, label)
  return F.mse_loss(pred, label)
  return F.mse_loss(pred, label)
  return F.mse_loss(pred, label)
  return F.mse_loss(pred, label)
  return F.mse_loss(pred, label)
  return F.mse_loss(pred, label)
  return F.mse_loss(pred, label)
  return F.mse_loss(pred, label)
  return F.mse_loss(pred, label)
  return F.mse_loss(pred, label)
  return F.mse_loss(pred, label)
  return F.mse_loss(pred, label)
  return F.mse_loss(pred, label)
  return F.mse_loss(pred, label)
  return F.mse_loss(pred, label)
  return F.mse_loss(pred, label)
  return F.mse_loss(pred, label)
  return F.mse_loss(pred, label)
  return F.mse_loss(pred, label)
  return F.mse_loss(pred, label)
  return F.mse_loss(pred, label)
  return F.mse_loss(pred, label)


Epoch 0, Loss 30226.0234375
Epoch 1, Loss 30226.0234375
Epoch 2, Loss 30226.0234375
Epoch 3, Loss 30226.0234375
Epoch 4, Loss 30226.0234375
Epoch 5, Loss 30226.0234375
Epoch 6, Loss 30226.0234375
Epoch 7, Loss 30226.0234375
Epoch 8, Loss 30226.0234375
Epoch 9, Loss 30226.0234375
