In [1]:
!pip install rdkit
!pip install selfies
!pip install datamol
!pip install scikit-learn
!pip install torch-geometric
!pip install torch_scatter



In [2]:
import os
import random
import time
from pathlib import Path
import numpy as np
import pandas as pd
import tensorflow as tf
import sklearn

from sklearn.model_selection import train_test_split
from sklearn.metrics import (ConfusionMatrixDisplay, confusion_matrix,
                             roc_auc_score, precision_score, recall_score,
                             f1_score, roc_curve)

import matplotlib.pyplot as plt

from rdkit import Chem, RDLogger
from rdkit.Chem.rdmolops import GetAdjacencyMatrix
from rdkit.Chem import Draw
import datamol as dm
import selfies as sf

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Bilinear, Linear, Parameter, Sequential

from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_add_pool
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import glorot, reset

from torch_scatter import scatter
from torch_geometric.data import Data

dm.disable_rdkit_log()

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Sets seed for reproducibility
np.random.seed(0)
torch.manual_seed(0)
random.seed(0)

In [4]:
def preprocess(smiles_string):
    """

    Preprocesses smiles strings/mols for molecular standardization

    """
    

    mol = dm.to_mol(smiles_string, ordered=True)
    if mol is None:
      return None

    mol = dm.fix_mol(mol)
    mol = dm.sanitize_mol(mol, sanifix=True, charge_neutral=False)
    mol = dm.standardize_mol(
        mol,
        disconnect_metals = False,
        normalize = True,
        reionize = True,
        uncharge = False,
        stereo = True,
    )

    standard_smiles = dm.standardize_smiles(dm.to_smiles(mol))
    # row["selfies"] = dm.to_selfies(mol)
    # row["inchi"] = dm.to_inchi(mol)
    # row["inchikey"] = dm.to_inchikey(mol)

    return standard_smiles

In [5]:
def one_hot_encoding(x, permitted_list):

    """

    Maps input elements x which are not in the permitted list to the last element
    of the permitted list.

    """
    if x not in permitted_list:
        x = permitted_list[-1]
    binary_encoding = [int(boolean_value) for boolean_value in list(map(lambda s: x == s, permitted_list))]
    return binary_encoding

def process_atoms(mol):
  """

  Makes sure there are no none atoms

  """
  atom_features_list = []

  for atom in mol.GetAtoms():
    atom_features = get_atom_features(atom)

    if atom is None:
      atom_features_list.append(atom_features)
      print(f"Skipping invalid atom with index: {atom.GetIdx()}")
      # append with zero vector with len of the feature vector
      atom_features_list.append(np.zeros(len(atom_features)))
    else:
       atom_features_list.append(atom_features)

  return atom_features_list

def get_atom_features(atom,
                      use_chirality=True,
                      hydrogens_implicit=True):
    """

    Takes an RDKit atom object as input and gives a 1d-numpy array of atom features as output.

    """
    try:
        if atom is None:
            # Handle the case where the atom is None
            return np.zeros(0)  # Returning a zero vector as a fallback

        # Proceed with feature extraction
        permitted_list_of_atoms = ['C','N','O','S','F','Si','P','Cl','Br','Mg','Na','Ca','Fe','As','Al','I', 'B','V','K','Tl','Yb','Sb','Sn','Ag','Pd','Co','Se','Ti','Zn', 'Li','Ge','Cu','Au','Ni','Cd','In','Mn','Zr','Cr','Pt','Hg','Pb','Unknown']

        if hydrogens_implicit == False:
            permitted_list_of_atoms = ['H'] + permitted_list_of_atoms

        atom_type_enc = one_hot_encoding(str(atom.GetSymbol()), permitted_list_of_atoms)
        n_heavy_neighbors_enc = one_hot_encoding(int(atom.GetDegree()), [0, 1, 2, 3, 4, "MoreThanFour"])
        formal_charge_enc = one_hot_encoding(int(atom.GetFormalCharge()), [-3, -2, -1, 0, 1, 2, 3, "Extreme"])
        hybridisation_type_enc = one_hot_encoding(str(atom.GetHybridization()), ["S", "SP", "SP2", "SP3", "SP3D", "SP3D2", "OTHER"])
        is_in_a_ring_enc = [int(atom.IsInRing())]
        is_aromatic_enc = [int(atom.GetIsAromatic())]
        atomic_mass_scaled = [float((atom.GetMass() - 10.812) / 116.092)]
        vdw_radius_scaled = [float((Chem.GetPeriodicTable().GetRvdw(atom.GetAtomicNum()) - 1.5) / 0.6)]
        covalent_radius_scaled = [float((Chem.GetPeriodicTable().GetRcovalent(atom.GetAtomicNum()) - 0.64) / 0.76)]

        atom_feature_vector = atom_type_enc + n_heavy_neighbors_enc + formal_charge_enc + hybridisation_type_enc + is_in_a_ring_enc + is_aromatic_enc + atomic_mass_scaled + vdw_radius_scaled + covalent_radius_scaled

        if use_chirality:
            chirality_type_enc = one_hot_encoding(str(atom.GetChiralTag()), ["CHI_UNSPECIFIED", "CHI_TETRAHEDRAL_CW", "CHI_TETRAHEDRAL_CCW", "CHI_OTHER"])
            atom_feature_vector += chirality_type_enc

        if hydrogens_implicit:
            n_hydrogens_enc = one_hot_encoding(int(atom.GetTotalNumHs()), [0, 1, 2, 3, 4, "MoreThanFour"])
            atom_feature_vector += n_hydrogens_enc

        return np.array(atom_feature_vector)

    except Exception as e:
        print(f"Error in processing atom {atom}: {e}")
        # Return a zero vector if any exception is encountered
        return np.zeros(0)

def get_bond_features(bond,
                      use_stereochemistry = True):
    """

    Takes an RDKit bond object as input and gives a 1d-numpy array of bond features as output.

    """
    if bond is None:
      return np.zeros(4)

    permitted_list_of_bond_types = [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE, Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
    bond_type_enc = one_hot_encoding(bond.GetBondType(), permitted_list_of_bond_types)

    bond_is_conj_enc = [int(bond.GetIsConjugated())]

    bond_is_in_ring_enc = [int(bond.IsInRing())]

    bond_feature_vector = bond_type_enc + bond_is_conj_enc + bond_is_in_ring_enc

    if use_stereochemistry == True:
        stereo_type_enc = one_hot_encoding(str(bond.GetStereo()), ["STEREOZ", "STEREOE", "STEREOANY", "STEREONONE"])
        bond_feature_vector += stereo_type_enc

    return np.array(bond_feature_vector)

def smiles_to_data(smiles):
    """
    Inputs:

    x_smiles = [smiles_1, smiles_2, ....] ... a list of SMILES strings

    Outputs:

    data_list = [G_1, G_2, ...] ... a list of torch_geometric.data.Data objects which represent molecular graphs that can readily be used for machine learning

    """
    print("smiles:", smiles)
    data_list = []

    # convert SMILES to RDKit mol object
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
      print(f"Invalid SMILES string: {smiles}")
      return None

    #processing atoms to make sure there are no NoneType atoms
    atom_features_list = process_atoms(mol)

    # get feature dimensions
    n_nodes = len(atom_features_list)
    n_edges = 2 * mol.GetNumBonds()
    unrelated_smiles = "O=O"
    unrelated_mol = Chem.MolFromSmiles(unrelated_smiles)

    n_node_features = len(get_atom_features(unrelated_mol.GetAtomWithIdx(0)))

    n_edge_features = len(get_bond_features(unrelated_mol.GetBondBetweenAtoms(0, 1)))

    # construct node feature matrix X_features of shape (n_nodes, n_node_features)
    X_features = np.zeros((n_nodes, n_node_features))

    for atom in mol.GetAtoms():
        atom_features = get_atom_features(atom)
        X_features[atom.GetIdx(), :] = get_atom_features(atom)

    X_features = torch.tensor(X_features, dtype = torch.float)

    # construct edge index array of shape (2, n_edges)
    (rows, cols) = np.nonzero(GetAdjacencyMatrix(mol))

    torch_rows = torch.from_numpy(rows.astype(np.int64)).to(torch.long)
    torch_cols = torch.from_numpy(cols.astype(np.int64)).to(torch.long)
    edge_index = torch.stack([torch_rows, torch_cols], dim = 0)

    # construct edge feature array of shape (n_edges, n_edge_features)
    edge_attr = np.zeros((n_edges, n_edge_features))

    for (k, (i,j)) in enumerate(zip(rows, cols)):
        edge_attr[k] = get_bond_features(mol.GetBondBetweenAtoms(int(i),int(j)))

    edge_attr = torch.tensor(edge_attr, dtype = torch.float)

    # construct Pytorch Geometric data object and append to data list
    data_list.append(Data(x = X_features, edge_index = edge_index, edge_attr = edge_attr))

    return data_list

In [6]:
# ---------------------------------------
# Attention layer
# ---------------------------------------
class FeatureAttention(nn.Module):
    def __init__(self, channels, reduction):
        super().__init__()
        self.mlp = Sequential(
            Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            Linear(channels // reduction, channels, bias=False),
        )

        self.reset_parameters()

    def reset_parameters(self):
        reset(self.mlp)

    def forward(self, x, batch, size=None):
        max_result = scatter(x, batch, dim=0, dim_size=size, reduce="max")
        sum_result = scatter(x, batch, dim=0, dim_size=size, reduce="sum")
        max_out = self.mlp(max_result)
        sum_out = self.mlp(sum_result)
        y = torch.sigmoid(max_out + sum_out)
        y = y[batch]
        return x * y

# ---------------------------------------
# Neural tensor networks conv
# ---------------------------------------
class NTNConv(MessagePassing):
    def __init__(
        self, in_channels, out_channels, slices, dropout, edge_dim=None, **kwargs
    ):
        kwargs.setdefault("aggr", "add")
        super(NTNConv, self).__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.slices = slices
        self.dropout = dropout
        self.edge_dim = edge_dim

        self.weight_node = Parameter(torch.Tensor(in_channels, out_channels))
        if edge_dim is not None:
            self.weight_edge = Parameter(torch.Tensor(edge_dim, out_channels))
        else:
            self.weight_edge = self.register_parameter("weight_edge", None)

        self.bilinear = Bilinear(out_channels, out_channels, slices, bias=False)

        if self.edge_dim is not None:
            self.linear = Linear(3 * out_channels, slices)
        else:
            self.linear = Linear(2 * out_channels, slices)

        self._alpha = None

        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.weight_node)
        glorot(self.weight_edge)
        self.bilinear.reset_parameters()
        self.linear.reset_parameters()

    def forward(self, x, edge_index, edge_attr=None, return_attention_weights=None):
        x = torch.matmul(x, self.weight_node)

        if self.weight_edge is not None:
            assert edge_attr is not None
            edge_attr = torch.matmul(edge_attr, self.weight_edge)

        out = self.propagate(edge_index, x=x, edge_attr=edge_attr)

        alpha = self._alpha
        self._alpha = None

        if isinstance(return_attention_weights, bool):
            assert alpha is not None
            return out, (edge_index, alpha)
        else:
            return out

    def message(self, x_i, x_j, edge_attr):
        score = self.bilinear(x_i, x_j)
        if edge_attr is not None:
            vec = torch.cat((x_i, edge_attr, x_j), 1)
            block_score = self.linear(vec)  # bias already included
        else:
            vec = torch.cat((x_i, x_j), 1)
            block_score = self.linear(vec)
        scores = score + block_score
        alpha = torch.tanh(scores)
        self._alpha = alpha
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        dim_split = self.out_channels // self.slices
        out = torch.max(x_j, edge_attr).view(-1, self.slices, dim_split)

        out = out * alpha.view(-1, self.slices, 1)
        out = out.view(-1, self.out_channels)
        return out


class CustomGNN(torch.nn.Module):
    def __init__(
        self,
        in_channels,
        hidden_channels,
        out_channels,
        edge_dim,
        num_layers,
        slices,
        dropout,
        f_att=False,
        r=4,
    ):
        super(CustomGNN, self).__init__()

        self.hidden_channels = hidden_channels
        self.num_layers = num_layers
        self.dropout = dropout

        self.f_att = f_att

        # atom feature transformation
        self.lin_a = Linear(in_channels, hidden_channels)
        self.lin_b = Linear(edge_dim, hidden_channels)

        # convs block
        self.atom_convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = NTNConv(
                hidden_channels,
                hidden_channels,
                slices=slices,
                dropout=dropout,
                edge_dim=hidden_channels,
            )
            self.atom_convs.append(conv)

        self.lin_gate = Linear(3 * hidden_channels, hidden_channels)

        if self.f_att:
            self.feature_att = FeatureAttention(channels=hidden_channels, reduction=r)

        self.out = Linear(hidden_channels, out_channels)
        self.reset_parameters()

    def reset_parameters(self):
        self.lin_a.reset_parameters()
        self.lin_b.reset_parameters()

        for conv in self.atom_convs:
            conv.reset_parameters()

        self.lin_gate.reset_parameters()

        if self.f_att:
            self.feature_att.reset_parameters()

        self.out.reset_parameters()

    def forward(self, data):
        # get mol input
        x = data.x
        edge_index = data.edge_index
        edge_attr = data.edge_attr
        batch = data.batch

        x = F.relu(self.lin_a(x))  # (N, 46) -> (N, hidden_channels)
        edge_attr = F.relu(self.lin_b(edge_attr))  # (N, 10) -> (N, hidden_channels)

        # mol conv block
        for i in range(0, self.num_layers):
            h = F.relu(self.atom_convs[i](x, edge_index, edge_attr))
            beta = self.lin_gate(torch.cat([x, h, x - h], 1)).sigmoid()
            x = beta * x + (1 - beta) * h
            if self.f_att:
                x = self.feature_att(x, batch)

        mol_vec = global_add_pool(x, batch).relu_()
        out = F.dropout(mol_vec, p=self.dropout, training=self.training)
        return self.out(out)


class FocalLoss(nn.Module):
    """
    Multi-class Focal Loss
    """

    def __init__(self, gamma=2):
        super(FocalLoss, self).__init__()
        self.gamma = gamma

    def forward(self, input, target):
        """
        input: [N, C], float32
        target: [N, ], int64
        """
        eps = 1e-10  # Avoid inf for log
        p = F.softmax(input, dim=1)
        logits = (1 - p) ** self.gamma * torch.log(p + eps)
        loss = F.nll_loss(logits, target)
        return loss

In [7]:
def get_data_distribution(df):
    best_dist = []
    for task in list(df.columns.values)[1:]:
        # print(f"Processing {task}...")
        count = dict(df[task].value_counts(dropna=True))
        # print(count)
        zeros = count[0.0]
        ones = count[1.0]

        dist = ones / (ones + zeros)
        # print(f"Total: {ones+zeros}, ones: {ones}, ratio: {dist}")
        if dist > 0.5 and dist < 0.6:
            best_dist.append((task, dist))
    # print(f"Best task: {_label}, distribution: {best_dist}")
    return best_dist

In [11]:
df = pd.read_csv("./toxcast_data.csv")

print(f"Number of rows in original dataset: {df.shape[0]}")
print(f"Number of tasks: {df.shape[1] - 1}")
print(f"Some Tasks: {list(df.columns.values)[1:5]}")

best_dist = get_data_distribution(df)
print(f"Number of tasks that aren't imbalanced: {len(best_dist)}")
print(f"Some Tasks: {best_dist[:5]}")
task = ["CLD_CYP1A2_48hr"]# best_dist[0]
sub_df = df[[col for col in df.columns if col in [task[0], "smiles"]]]
sub_df = sub_df.dropna(axis=0)
sub_df



Number of rows in original dataset: 8597
Number of tasks: 617
Some Tasks: ['ACEA_T47D_80hr_Negative', 'ACEA_T47D_80hr_Positive', 'APR_HepG2_CellCycleArrest_24h_dn', 'APR_HepG2_CellCycleArrest_24h_up']
Number of tasks that aren't imbalanced: 32
Some Tasks: [('CLD_CYP1A2_48hr', 0.5148514851485149), ('CLD_CYP2B6_24hr', 0.5761589403973509), ('CLD_CYP2B6_48hr', 0.5412541254125413), ('NCCT_HEK293T_CellTiterGLO', 0.5386666666666666), ('NCCT_QuantiLum_inhib_dn', 0.5517241379310345)]


Unnamed: 0,smiles,CLD_CYP1A2_48hr
46,OB(O)O,0.0
71,CCNC1=NC(N)=NC(Cl)=N1,1.0
75,COC(=O)C1=C(N(C)N=C1Cl)S(=O)(=O)NC(=O)NC1=NC(O...,1.0
92,ClC1=CC=CC=C1NC1=NC(Cl)=NC(Cl)=N1,0.0
94,CC(OC1=CC(Cl)=CC=C1)C(O)=O,0.0
...,...,...
8400,CSC(=O)C1=C(N=C(C(C(=O)SC)=C1CC(C)C)C(F)(F)F)C...,1.0
8458,CCOP(=O)(SC(C)CC)N1CCSC1=O,1.0
8468,CC1=NC2=NC(=NN2C=C1)S(=O)(=O)NC1=C(F)C=CC=C1F,0.0
8483,NC1=C(Cl)C=C(C=C1Cl)[N+]([O-])=O,1.0


In [43]:
result_dir = "tasks-final-2"

### Load model
unrelated_smiles = "O=O"
unrelated_mol = Chem.MolFromSmiles(unrelated_smiles)
n_node_features = len(get_atom_features(unrelated_mol.GetAtomWithIdx(0)))
model = CustomGNN(
    in_channels=n_node_features,
    hidden_channels=128,
    out_channels=2,
    edge_dim=10,
    num_layers=3,
    dropout=0.3,
    slices=2,
    f_att=True,
    r=4,
)
model.load_state_dict(
    torch.load(
        os.path.join(result_dir, task[0], "best_models/best_auc_model.pt"),
        weights_only=True,
    )
)

<All keys matched successfully>

In [59]:
def predict_one_sample(smiles):
    
    # smiles = sub_df.sample(n=1)["smiles"].values[0]
    
    # print("original smiles:", smiles)
    device = torch.device("cpu")
    
    standard_smiles = preprocess(smiles)
    # print("standard smiles:", standard_smiles)
    data_obj = smiles_to_data(standard_smiles)[0]
    
    print("data_obj:", data_obj)
    
    data_obj.batch = torch.tensor([0], dtype=torch.long)
    data_obj = data_obj.to(device)
    res = model(data_obj)
    probs = F.softmax(res, dim=1)
    return probs
    



In [60]:
# smiles = "CCNC1=NC(N)=NC(Cl)=N1"

ytrue, ypred = [], []
for idx, row in sub_df.iterrows():
    # print(row)
    smiles, target = row['smiles'], row[task[0]]
    smiles
    if "FAIL" in smiles:
        continue
    probs = predict_one_sample(smiles)
    predict = torch.argmax(probs, dim=1)[0]
    ytrue.append(target)
    ypred.append(predict)
    print(f"Target: {target}, predict: {predict}, probs: {np.round(probs[0].detach().numpy(), 3)}")
    # print(probs)
    # break

# predict_one_sample(smiles)

smiles: OB(O)O
data_obj: Data(x=[4, 79], edge_index=[2, 6], edge_attr=[6, 10])
Target: 0.0, predict: 0, probs: [0.995 0.005]
smiles: CCNc1nc(N)nc(Cl)n1
data_obj: Data(x=[11, 79], edge_index=[2, 22], edge_attr=[22, 10])
Target: 1.0, predict: 1, probs: [0.002 0.998]
smiles: COC(=O)c1c(Cl)nn(C)c1S(=O)(=O)NC(=O)Nc1nc(OC)cc(OC)n1
data_obj: Data(x=[28, 79], edge_index=[2, 58], edge_attr=[58, 10])
Target: 1.0, predict: 1, probs: [0.003 0.997]
smiles: Clc1nc(Cl)nc(Nc2ccccc2Cl)n1
data_obj: Data(x=[16, 79], edge_index=[2, 34], edge_attr=[34, 10])
Target: 0.0, predict: 1, probs: [0. 1.]
smiles: CC(Oc1cccc(Cl)c1)C(=O)O
data_obj: Data(x=[13, 79], edge_index=[2, 26], edge_attr=[26, 10])
Target: 0.0, predict: 0, probs: [1. 0.]
smiles: COC(=O)c1ccccc1S(=O)(=O)NC(=O)N(C)c1nc(C)nc(OC)n1
data_obj: Data(x=[27, 79], edge_index=[2, 56], edge_attr=[56, 10])
Target: 1.0, predict: 1, probs: [0.106 0.894]
smiles: CC(C)OC(=O)Nc1cccc(Cl)c1
data_obj: Data(x=[14, 79], edge_index=[2, 28], edge_attr=[28, 10])
Target:

In [61]:
auc = roc_auc_score(ytrue, ypred)
print(auc)


0.9197012138188609


In [51]:
from sklearn.manifold import TSNE
target = torch.randint(0, 2, (1000,))
feats = torch.matmul(F.one_hot(target).float(), criterion_cent.centers).detach().numpy()
X_embedded = TSNE(n_components=2, learning_rate='auto', init='random', perplexity=50).fit_transform(feats)
target = target.reshape(-1, 1)

# Define colors for each class
colors = ['red' if label == 0 else 'blue' for label in target]
plt.scatter(x=X_embedded[:, 0], y=X_embedded[:, 1], c=colors, alpha=0.7, edgecolors='k')
plt.xlabel("x1-axis")
plt.ylabel("x2-axis")
plt.title("plot")
plt.show()

NameError: name 'criterion_cent' is not defined