- **Created by Zehao Li (Takuho Ri)**
- **Created on 2025-02-14 (Fri)  14:57:03 (+09:00)**

Transform all molecules from SMILES strings into graphs (for AttentiveFP).

In [12]:
import os
import sys
from tqdm.notebook import tqdm
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch_geometric.data import Data

# Project root = the part of the working directory before the first "/prep"
# (or the whole cwd when "/prep" does not occur); put it on the import path so
# the project-local `src.*` package can be imported below.
project_path = os.getcwd().partition("/prep")[0]
sys.path.append(project_path)

In [13]:
from src.AttentiveFP.featurizer import _prep_feats

In [14]:
# Input: one CSV per dataset. Output: one pickled list of graphs per dataset.
csv_dir, graph_dir = (
    os.path.join(project_path, "data", sub_dir)
    for sub_dir in ("tox_csv", "AttentiveFP_graphs")
)

# Create the output directory up front (no-op when it already exists).
os.makedirs(graph_dir, exist_ok=True)

### Prepare graphs for every dataset except BACE

In [5]:
def smiles_to_graphs(smiles, ys):
    """Featurize SMILES strings into PyG ``Data`` graphs for AttentiveFP.

    Args:
        smiles: sequence of SMILES strings, one per molecule.
        ys: array-like of targets aligned with ``smiles``; row ``i`` becomes the
            ``y`` of graph ``i``, cast to float and reshaped to ``(1, -1)``.

    Returns:
        list of ``Data`` objects whose ``x`` / ``edge_index`` / ``edge_attr``
        come from ``_prep_feats`` and whose ``y`` is the reshaped target row.

    Raises:
        AssertionError: if ``smiles`` and ``ys`` differ in length.
    """
    assert len(smiles) == len(ys), "different length between smiles and y !"
    graphs = []
    for smi, y in tqdm(zip(smiles, ys), total=len(smiles)):
        y = torch.tensor(y).float().view(1, -1)
        atom_feats, edge_idx, edge_feats = _prep_feats(smi)
        # Keyword arguments make the field mapping explicit instead of relying
        # on the positional order of Data.__init__.
        graphs.append(Data(x=atom_feats, edge_index=edge_idx, edge_attr=edge_feats, y=y))
    return graphs


# sorted() gives a reproducible processing order; the extension check skips
# non-CSV entries such as .ipynb_checkpoints that os.listdir also returns.
for fname in sorted(os.listdir(csv_dir)):
    stem, ext = os.path.splitext(fname)
    # BACE has two separate targets and is handled in its own section below.
    if ext != ".csv" or fname == "bace.csv":
        continue
    # Plain {stem} instead of a nested same-quote f-string, which is a
    # SyntaxError before Python 3.12.
    print(f"=== {stem} ===")
    df = pd.read_csv(os.path.join(csv_dir, fname), index_col=0)
    smiles = df["cano_smi"].to_numpy()
    # NOTE(review): assumes the first two non-index columns are not targets
    # (one of them presumably "cano_smi") and every remaining column is a task
    # label -- confirm against the CSV layout.
    ys = df[df.columns[2:]].to_numpy()
    graphs = smiles_to_graphs(smiles, ys)
    with open(os.path.join(graph_dir, f"{stem}.pkl"), "wb") as f:
        pickle.dump(graphs, f)

=== sider ===


  0%|          | 0/1384 [00:00<?, ?it/s]



=== herg_karim ===


  0%|          | 0/13445 [00:00<?, ?it/s]

=== tox21_M ===


  0%|          | 0/7811 [00:00<?, ?it/s]



=== cyp3a4_inhib ===


  0%|          | 0/12319 [00:00<?, ?it/s]

=== cyp2c9_inhib ===


  0%|          | 0/12083 [00:00<?, ?it/s]

=== toxcast_M ===


  0%|          | 0/8558 [00:00<?, ?it/s]



=== clintox_M ===


  0%|          | 0/1468 [00:00<?, ?it/s]

=== ld50 ===


  0%|          | 0/7385 [00:00<?, ?it/s]

### Prepare graphs for BACE (separate "Class" and "pIC50" targets)

In [7]:
print("=== bace ===")
df = pd.read_csv(os.path.join(csv_dir, "bace.csv"), index_col=0)
smiles = df["cano_smi"].to_numpy()
# BACE carries two targets, each saved as its own single-task dataset:
# "Class" -> bace_c.pkl and "pIC50" -> bace_r.pkl (the _c/_r suffixes
# presumably denote classification / regression).
ys_c = df[["Class"]].to_numpy()
ys_r = df[["pIC50"]].to_numpy()
assert len(smiles) == len(ys_c) == len(ys_r), "different length between smiles and y !"
graphs_c = []
graphs_r = []
for i, (smi, y_c, y_r) in enumerate(tqdm(zip(smiles, ys_c, ys_r), total=len(smiles))):
    y_c = torch.tensor(y_c).float().view(1, -1)
    y_r = torch.tensor(y_r).float().view(1, -1)
    # Featurize once; both graphs share the same feature tensors and differ only in y.
    atom_feats, edge_idx, edge_feats = _prep_feats(smi)
    graphs_c.append(Data(x=atom_feats, edge_index=edge_idx, edge_attr=edge_feats, y=y_c))
    graphs_r.append(Data(x=atom_feats, edge_index=edge_idx, edge_attr=edge_feats, y=y_r))
    if i == 0:
        # Sanity check on the first molecule: each label should be shaped (1, 1).
        print(y_c, y_c.shape)
        print(y_r, y_r.shape)
with open(os.path.join(graph_dir, "bace_c.pkl"), "wb") as f:
    pickle.dump(graphs_c, f)
with open(os.path.join(graph_dir, "bace_r.pkl"), "wb") as f:
    pickle.dump(graphs_r, f)

=== bace ===


  0%|          | 0/1513 [00:00<?, ?it/s]

tensor([[1.]]) torch.Size([1, 1])
tensor([[9.1549]]) torch.Size([1, 1])


In [11]:
# Spot-check one saved dataset. tox21_M stores a (1, n_tasks) label row per
# graph; missing labels show up as NaN. The file is read from graph_dir rather
# than a machine-specific absolute path.
with open(os.path.join(graph_dir, "tox21_M.pkl"), "rb") as f:
    test = pickle.load(f)
print(test[0].y, test[0].y.shape)

tensor([[0., 0., 1., nan, nan, 0., 0., 1., 0., 0., 0., 0.]]) torch.Size([1, 12])


### Add the small hERG dataset (for testing)

In [18]:
print("=== herg small ===")
# TDC hERG table: the "Drug" column holds SMILES and "Y" the label. The path is
# built from project_path instead of a machine-specific absolute path.
df = pd.read_csv(
    os.path.join(project_path, "data", "original_csv", "TDC", "Tox", "herg.csv"),
    index_col=0,
)
smiles = df["Drug"].to_numpy()
ys = df[["Y"]].to_numpy()
assert len(smiles) == len(ys), "different length between smiles and y !"
graphs = []
for i, (smi, y) in enumerate(tqdm(zip(smiles, ys), total=len(smiles))):
    y = torch.tensor(y).float().view(1, -1)
    atom_feats, edge_idx, edge_feats = _prep_feats(smi)
    graphs.append(Data(x=atom_feats, edge_index=edge_idx, edge_attr=edge_feats, y=y))
    if i == 0:
        # Sanity check on the first molecule: the label should be shaped (1, 1).
        print(y, y.shape)
with open(os.path.join(graph_dir, "herg_small.pkl"), "wb") as f:
    pickle.dump(graphs, f)

=== herg small ===


  0%|          | 0/655 [00:00<?, ?it/s]

tensor([[1.]]) torch.Size([1, 1])




### Add the hERG-central dataset (for testing)

In [None]:
print("=== herg central ===")
# TDC hERG-central table: the "X" column holds SMILES and "hERG_inhib" the label.
# The path is built from project_path instead of a machine-specific absolute path.
df = pd.read_csv(
    os.path.join(project_path, "data", "original_csv", "TDC", "Tox", "herg_central.csv"),
    index_col=0,
)
smiles = df["X"].to_numpy()
ys = df[["hERG_inhib"]].to_numpy()
assert len(smiles) == len(ys), "different length between smiles and y !"
# NOTE(review): this dataset is large (~300k rows in the recorded run) and all
# graphs are held in memory before pickling -- consider chunking if RAM is tight.
graphs = []
for i, (smi, y) in enumerate(tqdm(zip(smiles, ys), total=len(smiles))):
    y = torch.tensor(y).float().view(1, -1)
    atom_feats, edge_idx, edge_feats = _prep_feats(smi)
    graphs.append(Data(x=atom_feats, edge_index=edge_idx, edge_attr=edge_feats, y=y))
    if i == 0:
        # Sanity check on the first molecule: the label should be shaped (1, 1).
        print(y, y.shape)
with open(os.path.join(graph_dir, "herg_central.pkl"), "wb") as f:
    pickle.dump(graphs, f)

=== herg small ===


  0%|          | 0/306893 [00:00<?, ?it/s]

tensor([[0.]]) torch.Size([1, 1])
