In [1]:
from rdkit.Chem import rdFingerprintGenerator
from rdkit.Chem.Draw import IPythonConsole
from rdkit import Chem
import numpy as np
from collections.abc import Generator
import pandas as pd
from sklearn import cluster as sk_clustering
from sklearn import datasets as sk_datasets
from sklearn import model_selection as sk_model_selection
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils import data as torch_data
if torch.cuda.is_available():
    print("CUDA AVAILABLE")
from neuralfingerprint import featurizer
from neuralfingerprint import datasets

CUDA AVAILABLE


In [2]:
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")


Using cuda device


In [3]:
df = pd.read_csv("delaney-processed.csv")

In [4]:
df["mol"] = df["smiles"].apply(Chem.MolFromSmiles)
df["num_atoms"] = df["mol"].apply(lambda x : x.GetNumAtoms())
df_filtered = df[df["num_atoms"]>=6].reset_index(drop=True)

In [5]:
train_df, test_df = sk_model_selection.train_test_split(df_filtered, test_size=0.2, random_state=32)
train_df, valid_df = sk_model_selection.train_test_split(train_df, test_size=0.15, random_state=32)

In [6]:
train_df.shape, valid_df.shape, test_df.shape, df.shape

((694, 12), (123, 12), (205, 12), (1128, 12))

# Convolutional Networks on Graphs for Learning Molecular Fingerprints

In [10]:
def molecule_collate_fn(batch):
    atom_batch = []
    bond_batch = []
    labels = []


    max_atoms = max(atom.shape[0] for atom, _, _ in batch)
    max_bonds = max(bond.shape[0] for _, bond, _ in batch)

    for atom_features, bond_features, label in batch:
        atoms_to_pad = max_atoms - atom_features.shape[0]
        bonds_to_pad = max_bonds - bond_features.shape[0]
        
        atom_features_padded = F.pad(atom_features, pad=(0, 0,  0, atoms_to_pad), value=0)
        bond_features_padded = F.pad(bond_features, pad=(0, 0, 0, bonds_to_pad), value=0)
        atom_batch.append(atom_features_padded)
        bond_batch.append(bond_features_padded)
        labels.append(label)

    return torch.stack(atom_batch), torch.stack(bond_batch), torch.tensor(labels)

In [11]:
train_moldataset = datasets.NeuralFingerprintDataset(smiles=tuple(train_df["smiles"]),
                                       targets=tuple(train_df["ESOL predicted log solubility in mols per litre"]))

test_moldataset = datasets.NeuralFingerprintDataset(smiles=tuple(test_df["smiles"]),
                                      targets=tuple(test_df["ESOL predicted log solubility in mols per litre"]))

valid_moldataset = datasets.NeuralFingerprintDataset(smiles=tuple(valid_df["smiles"]),
                                       targets=tuple(valid_df["ESOL predicted log solubility in mols per litre"]))

train_dataloader = torch_data.DataLoader(train_moldataset,
                                         batch_size=64,
                                         shuffle=True,
                                         collate_fn=molecule_collate_fn)
test_dataloader = torch_data.DataLoader(test_moldataset,
                                        batch_size=64, shuffle=False,
                                        collate_fn=molecule_collate_fn)
valid_dataloader = torch_data.DataLoader(valid_moldataset,
                                         batch_size=64, shuffle=False,
                                         collate_fn=molecule_collate_fn)

In [12]:
class VanillaNet(nn.Module):
    def __init__(self, n_atom_features: int, n_bond_features: int, n_out_features: int=100):
        super().__init__()
        self.atom_layer = nn.Linear(n_atom_features, n_out_features)
        self.bond_layer = nn.Linear(n_bond_features, n_out_features)
        self.activation = nn.ReLU()
        self.output_layer = nn.Linear(n_out_features*2, 1)
    def forward(self, x):
        atom_features, bond_features = x
        atom_x = torch.mean(self.atom_layer(atom_features), dim=1)
        bond_x = torch.mean(self.bond_layer(bond_features), dim=1)
        
        concat_x = self.activation(torch.cat((atom_x, bond_x), dim=-1))

        return self.output_layer(concat_x).squeeze(-1)


In [13]:
model = VanillaNet(n_atom_features=5, n_bond_features=3)
print(model)

VanillaNet(
  (atom_layer): Linear(in_features=5, out_features=100, bias=True)
  (bond_layer): Linear(in_features=3, out_features=100, bias=True)
  (activation): ReLU()
  (output_layer): Linear(in_features=200, out_features=1, bias=True)
)


In [14]:
def train_model(model, train_loader, val_loader, epochs=20, lr=1e-3, device='cpu'):

    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    best_val_loss = float('inf')

    train_losses = []
    val_losses = []

    for epoch in range(epochs):
        model.train()
        total_train_loss = 0.0

        for batch_idx, (atom_batch, bond_batch, labels) in enumerate(train_loader):
            atom_batch = atom_batch.to(device)
            bond_batch = bond_batch.to(device)
            labels = labels.to(device).float()

            optimizer.zero_grad()
            preds = model((atom_batch, bond_batch))

            if preds.shape != labels.shape:
                preds = preds.view_as(labels)

            loss = criterion(preds, labels)
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item() * len(labels)

        avg_train_loss = total_train_loss / len(train_loader.dataset)
        train_losses.append(avg_train_loss)

        # Validation
        model.eval()
        total_val_loss = 0.0
        with torch.no_grad():
            for atom_batch, bond_batch, labels in val_loader:
                atom_batch = atom_batch.to(device)
                bond_batch = bond_batch.to(device)
                labels = labels.to(device).float()

                preds = model((atom_batch, bond_batch))
                if preds.shape != labels.shape:
                    preds = preds.view_as(labels)

                loss = criterion(preds, labels)
                total_val_loss += loss.item() * len(labels)

        avg_val_loss = total_val_loss / len(val_loader.dataset)
        val_losses.append(avg_val_loss)

        print(f"Epoch {epoch+1}/{epochs} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), "best_model.pt")
            print("📦 Saved best model.")

    print("✅ Training complete.")
    return train_losses, val_losses


In [15]:
train_loss, val_loss = train_model(model, train_dataloader, valid_dataloader, epochs=30, device="cpu")

Epoch 1/30 | Train Loss: 9.2170 | Val Loss: 7.3471
📦 Saved best model.
Epoch 2/30 | Train Loss: 5.9606 | Val Loss: 4.2785
📦 Saved best model.
Epoch 3/30 | Train Loss: 3.6604 | Val Loss: 2.4460
📦 Saved best model.
Epoch 4/30 | Train Loss: 2.3322 | Val Loss: 1.6507
📦 Saved best model.
Epoch 5/30 | Train Loss: 1.8016 | Val Loss: 1.5055
📦 Saved best model.
Epoch 6/30 | Train Loss: 1.6073 | Val Loss: 1.5604
Epoch 7/30 | Train Loss: 1.7667 | Val Loss: 1.5399
Epoch 8/30 | Train Loss: 1.6551 | Val Loss: 1.5027
📦 Saved best model.
Epoch 9/30 | Train Loss: 1.6674 | Val Loss: 1.4964
📦 Saved best model.
Epoch 10/30 | Train Loss: 1.6194 | Val Loss: 1.4866
📦 Saved best model.
Epoch 11/30 | Train Loss: 1.6230 | Val Loss: 1.4845
📦 Saved best model.
Epoch 12/30 | Train Loss: 1.6064 | Val Loss: 1.4832
📦 Saved best model.
Epoch 13/30 | Train Loss: 1.6636 | Val Loss: 1.4978
Epoch 14/30 | Train Loss: 1.5627 | Val Loss: 1.4994
Epoch 15/30 | Train Loss: 1.5518 | Val Loss: 1.5130
Epoch 16/30 | Train Loss: 1.6

# Refactor

## One hot encode atom features

In [17]:
smi = "O=C1OC(CN1c1ccc(cc1)N1CCOCC1=O)CNC(=O)c1ccc(s1)Cl"
molecule = Chem.MolFromSmiles(smi)

In [45]:
feats = featurizer.featurize_atoms(molecule)
atomic_number = feats[:, 0]
degree = feats[:, 1]
valence = feats[:, 2]
num_hydrogens = feats[:, 3]
is_aromatic = feats[:, 4].unsqueeze(1)



In [170]:

from typing import Final

from rdkit.Chem import rdchem

MAX_DEGREE: Final[int] = 6
MAX_VALENCE: Final[int] = 6
MAX_HYDROGEN_ATOMS: Final[int] = 5
MAX_ATOMIC_NUMBER: Final[int] = 118
MAX_BOND_TYPES: Final[int] = 22


BOND_TYPES: dict[rdchem.BondType, int] = {
    v: k for k, v in rdchem.BondType.values.items()
}


def featurize_bonds(molecule: Chem.Mol) -> torch.Tensor:
    """Generates a tensor of bond features.
    This function returns a tensor representing bond features for every
    bond in the input molecule. Bond features consists of bond type,
    conjugation and information whether its on a ring.

    We use `bond.GetBondTypeAsDouble()` to get a numeric value for each type:

    Single: 1.0
    Aromatic: 1.5
    Double: 2.0
    Triple: 3.0

    Args:
        molecule: a Chem.Mol object.

    Returns:
        A tensor of bond features of shape (N, 3), where N is the number of
        bonds in the molecule.


    """
    bonds = molecule.GetBonds()
    raw_features = []
    for bond in bonds:
        is_conjugated = torch.tensor(int(bond.GetIsConjugated())).unsqueeze(-1)
        bond_type = F.one_hot(
            torch.tensor(BOND_TYPES[bond.GetBondType()]),
            num_classes=MAX_BOND_TYPES,
        )
        
        is_in_ring = torch.tensor(int(bond.IsInRing())).unsqueeze(-1)

        bond_feat = torch.cat([bond_type, is_conjugated, is_in_ring], -1)
        raw_features.append(bond_feat)

    return torch.stack(raw_features)

In [175]:
def featurize_atoms(molecule: Chem.Mol) -> torch.Tensor:
    """Generates a tensor of atom features.

    Atom features consist of:
    - Atomic number
    - Degree (number of directly bonded atoms)
    - Implicit valence
    - Total number of hydrogen atoms
    - Aromaticity flag

    For the original implementation, see:
    https://github.com/HIPS/neural-fingerprint/blob/master/neuralfingerprint/features.py
    Args:
        molecule: a Chem.Mol object.

    Returns:
        A tensor of atom features of shape (N, 135), where N is the number of
        bonds in the molecule.


    """
    atoms = molecule.GetAtoms()
    raw_features = []
    for atom in atoms:
        atomic_number_one_hot = F.one_hot(
            torch.tensor(atom.GetAtomicNum()),
            num_classes=MAX_ATOMIC_NUMBER,
        )
        degree_one_hot = F.one_hot(
            torch.tensor(atom.GetDegree()),
            num_classes=MAX_DEGREE,
        )
        valence_one_hot = F.one_hot(
            torch.tensor(atom.GetImplicitValence()),
            num_classes=MAX_VALENCE,
        )
        num_hydrogens_one_hot = F.one_hot(
            torch.tensor(atom.GetTotalNumHs()),
            num_classes=MAX_HYDROGEN_ATOMS,
        )
        is_aromatic = torch.tensor(atom.GetIsAromatic()).unsqueeze(-1)

        atom_feat = torch.cat(
            [
                atomic_number_one_hot,
                degree_one_hot,
                valence_one_hot,
                num_hydrogens_one_hot,
                is_aromatic,
            ],
            -1,
        )
        raw_features.append(atom_feat)

    return torch.stack(raw_features)

In [176]:
feats_bonds = featurize_bonds(molecule)
feats_atoms = featurize_atoms(molecule)
test_feats = feats_bonds[0]

In [177]:
feats_atoms.shape

torch.Size([29, 136])

torch.Size([1, 24])

In [30]:
atomic_number_one_hot = F.one_hot(atomic_number, num_classes=118)
degree_one_hot = F.one_hot(degree, num_classes=6)
valence_one_hot = F.one_hot(valence, num_classes=6)
num_hydrogens_one_hot = F.one_hot(num_hydrogens, num_classes=5)

In [36]:
atomic_number_one_hot.shape, degree_one_hot.shape, valence_one_hot.shape, num_hydrogens_one_hot.shape

(torch.Size([29, 118]),
 torch.Size([29, 6]),
 torch.Size([29, 6]),
 torch.Size([29, 5]))

In [48]:
torch.cat([atomic_number_one_hot, degree_one_hot, valence_one_hot, num_hydrogens_one_hot, is_aromatic], dim=1)

tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 1],
        [0, 0, 0,  ..., 0, 0, 1],
        [0, 0, 0,  ..., 0, 0, 0]])

In [33]:
F.one_hot(arr, num_classes=118)

tensor([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [49]:
bond = molecule.GetBonds()[0]

In [64]:
F.one_hot(torch.tensor(bond.GetBondType()))

tensor([0, 0, 1])

In [61]:
def bond_features(bond):
    bt = bond.GetBondType()
    return np.array([bt == Chem.rdchem.BondType.SINGLE,
                     bt == Chem.rdchem.BondType.DOUBLE,
                     bt == Chem.rdchem.BondType.TRIPLE,
                     bt == Chem.rdchem.BondType.AROMATIC,
                     # bond.GetIsConjugated(),
                     # bond.IsInRing()
                    ])


In [66]:
F.one_hot(torch.tensor(bond_features(bond)))

RuntimeError: one_hot is only applicable to index tensor of type LongTensor.

In [63]:
featurizer.featurize_bonds(molecule)[0]

tensor([2, 1, 0])

In [60]:
is_conjugated = bond.GetIsConjugated()
bond_type = int(bond.GetBondTypeAsDouble())
is_in_ring = int(bond.IsInRing())
bond_type

2

In [90]:
bond_type_name.as_integer_ratio

(12, 1)

In [98]:
bond_type_name

rdkit.Chem.rdchem.BondType.AROMATIC

In [None]:
from rdkit.Chem import rdchem

In [97]:
for k, v in Chem.BondType.values.items():
    print(k.as_integer_ratio()[0], k.bit_count(), v)

0 0 UNSPECIFIED
1 1 SINGLE
2 1 DOUBLE
3 2 TRIPLE
4 1 QUADRUPLE
5 2 QUINTUPLE
6 2 HEXTUPLE
7 3 ONEANDAHALF
8 1 TWOANDAHALF
9 2 THREEANDAHALF
10 2 FOURANDAHALF
11 3 FIVEANDAHALF
12 2 AROMATIC
13 3 IONIC
14 3 HYDROGEN
15 4 THREECENTER
16 1 DATIVEONE
17 2 DATIVE
18 2 DATIVEL
19 3 DATIVER
20 2 OTHER
21 3 ZERO


In [68]:
for bond in molecule.GetBonds():
    bond_type = bond.GetBondTypeAsDouble()
    bond_type_name = bond.GetBondType()
    print(int(bond_type), bond_type_name)

2 DOUBLE
1 SINGLE
1 SINGLE
1 SINGLE
1 SINGLE
1 SINGLE
1 AROMATIC
1 AROMATIC
1 AROMATIC
1 AROMATIC
1 AROMATIC
1 SINGLE
1 SINGLE
1 SINGLE
1 SINGLE
1 SINGLE
1 SINGLE
2 DOUBLE
1 SINGLE
1 SINGLE
1 SINGLE
2 DOUBLE
1 SINGLE
1 AROMATIC
1 AROMATIC
1 AROMATIC
1 AROMATIC
1 SINGLE
1 SINGLE
1 AROMATIC
1 SINGLE
1 AROMATIC
