In [1]:
import torch
from torch_geometric.data import Data
from rdkit import Chem

In [7]:
import pandas as pd

# Cleaned QS-inhibitor table (one molecule per row).
DATA_PATH = "qs_inhibitors_cleaned.csv"
# Columns the graph-building cells below index directly.
REQUIRED_COLUMNS = {"smiles_canonical", "activity_label"}

df = pd.read_csv(DATA_PATH)

# Fail fast here rather than with a KeyError deep inside the conversion loop.
missing_columns = REQUIRED_COLUMNS - set(df.columns)
assert not missing_columns, (
    f"{DATA_PATH} is missing required columns: {sorted(missing_columns)}"
)

print(df.shape)
print(df.columns)

(168, 8)
Index(['smiles_canonical', 'activity_label', 'MolWt', 'LogP', 'TPSA', 'HBD',
       'HBA', 'RB'],
      dtype='object')


In [8]:
def atom_features(atom):
    """Return a 6-element numeric feature vector for an RDKit atom.

    Features, in order:
      0. atomic number (``GetAtomicNum``)
      1. degree (``GetDegree``: number of directly bonded neighbours in the
         RDKit graph; implicit hydrogens are presumably not counted — confirm)
      2. formal charge (``GetFormalCharge``)
      3. 1 if the atom is aromatic, else 0
      4. 1 if the atom is in any ring, else 0
      5. hybridization, as the integer value of RDKit's ``HybridizationType``

    NOTE(review): these are raw ordinal values, not one-hot encodings, so a
    GNN will treat e.g. atomic number as a magnitude. One-hot encoding may
    work better, but it would change the feature width of saved graphs.
    """
    return [
        atom.GetAtomicNum(),
        atom.GetDegree(),
        atom.GetFormalCharge(),
        int(atom.GetIsAromatic()),
        int(atom.IsInRing()),
        # HybridizationType is an int-valued enum; int() extracts its value
        # explicitly.
        int(atom.GetHybridization()),
    ]

In [10]:
def bond_features(bond):
    """Encode an RDKit bond as a 6-element 0/1 feature vector.

    Layout: one-hot bond order over (SINGLE, DOUBLE, TRIPLE, AROMATIC),
    followed by an is-conjugated flag and an is-in-ring flag. A bond whose
    type is none of those four gets zeros in all four order slots.
    """
    bond_type_enum = Chem.rdchem.BondType
    observed = bond.GetBondType()
    order_one_hot = [
        int(observed == candidate)
        for candidate in (
            bond_type_enum.SINGLE,
            bond_type_enum.DOUBLE,
            bond_type_enum.TRIPLE,
            bond_type_enum.AROMATIC,
        )
    ]
    topology_flags = [int(bond.GetIsConjugated()), int(bond.IsInRing())]
    return order_one_hot + topology_flags


In [11]:
def smiles_to_graph(smiles, label):
    """Convert a SMILES string into a PyTorch Geometric ``Data`` graph.

    Nodes are atoms (features from ``atom_features``). Each chemical bond
    becomes two directed edges (i->j and j->i) sharing the same
    ``bond_features`` vector.

    Parameters
    ----------
    smiles : str
        Molecule in SMILES notation.
    label : int
        Class label, stored as ``y`` (a 1-element long tensor).

    Returns
    -------
    torch_geometric.data.Data or None
        ``Data(x, edge_index, edge_attr, y)``. Returns ``None`` if ``smiles``
        is not a string (e.g. a NaN cell from pandas), cannot be parsed by
        RDKit, or parses to a molecule with no atoms.
    """
    # pandas represents missing cells as float NaN; RDKit raises on non-str.
    if not isinstance(smiles, str):
        return None
    mol = Chem.MolFromSmiles(smiles)
    # MolFromSmiles returns None on parse failure, but an empty Mol for "".
    if mol is None or mol.GetNumAtoms() == 0:
        return None

    x = torch.tensor(
        [atom_features(atom) for atom in mol.GetAtoms()], dtype=torch.float
    )

    edge_index = []
    edge_attr = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        bf = bond_features(bond)
        # Both directions, so message passing treats the bond as undirected.
        edge_index += [[i, j], [j, i]]
        edge_attr += [bf, bf]

    if edge_index:
        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(edge_attr, dtype=torch.float)
    else:
        # Bond-free molecules (single atoms such as "O", or disconnected ions)
        # would otherwise give 1-D empty tensors. Keep the 2-D layout instead:
        # edge_index [2, 0] and edge_attr [0, num_bond_features]. The width is
        # taken from a reference bond so it always matches bond_features.
        n_bond_features = len(
            bond_features(Chem.MolFromSmiles("CC").GetBondWithIdx(0))
        )
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, n_bond_features), dtype=torch.float)

    y = torch.tensor([label], dtype=torch.long)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)


In [13]:
# Convert every row to a graph. Rows that smiles_to_graph rejects are
# recorded so dropped molecules are visible instead of silently vanishing.
graphs = []
skipped_rows = []

for row_idx, smiles, label in zip(
    df.index, df["smiles_canonical"], df["activity_label"]
):
    g = smiles_to_graph(smiles, label)
    if g is None:
        skipped_rows.append(row_idx)
    else:
        graphs.append(g)

print("Number of graphs:", len(graphs))
if skipped_rows:
    print(f"Skipped {len(skipped_rows)} rows with unusable SMILES: {skipped_rows}")
assert len(graphs) + len(skipped_rows) == len(df)

Number of graphs: 168


In [14]:
# Inspect one example graph, then check tensor-shape invariants on all graphs.
assert graphs, "No graphs were built; check the SMILES column and conversion."
g = graphs[0]
print(g)
print("Node feature shape:", g.x.shape)
print("Edge index shape:", g.edge_index.shape)
print("Edge attr shape:", g.edge_attr.shape)
print("Label:", g.y)

# Every graph must have 2-D tensors with consistent widths so they can be
# batched together by PyG.
for graph in graphs:
    assert graph.x.dim() == 2 and graph.x.size(1) == g.x.size(1)
    assert graph.edge_index.dim() == 2 and graph.edge_index.size(0) == 2
    assert graph.edge_attr.dim() == 2
    assert graph.edge_attr.size(0) == graph.edge_index.size(1)


Data(x=[17, 6], edge_index=[2, 34], edge_attr=[34, 6], y=[1])
Node feature shape: torch.Size([17, 6])
Edge index shape: torch.Size([2, 34])
Edge attr shape: torch.Size([34, 6])
Label: tensor([0])


In [15]:
import torch

# Persist the list of PyG ``Data`` graphs to disk.
# NOTE(review): recent torch versions default ``torch.load`` to
# ``weights_only=True``, which may refuse these pickled Data objects; loading
# may need ``weights_only=False`` (only for trusted files) — confirm.
torch.save(obj=graphs, f="qs_graphs.pt")
