In [1]:
import warnings
from rdkit import RDLogger

warnings.simplefilter(action='ignore', category=FutureWarning)
RDLogger.DisableLog('rdApp.*')

In [2]:
import pandas as pd
import numpy as np
import torch
import os
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors, rdDepictor, rdDistGeom, MACCSkeys, rdMolDescriptors
from torch_geometric.data import InMemoryDataset, download_url, extract_gz, Data, DataLoader, Batch
from torch.utils.data import Dataset
from torch_geometric import utils as pyg_utils
from deepchem.feat.smiles_tokenizer import BasicSmilesTokenizer

# 필요한 별도 파일
from splitters import random_split, scaffold_split ## split
from download_preprocess import CustomMoleculeNet, atom_features, EDGE_FEATURES

No normalization for SPS. Feature removed!
No normalization for AvgIpc. Feature removed!
2024-11-30 16:32:01.098732: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-11-30 16:32:01.103520: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-11-30 16:32:01.113874: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1732984321.131128   94737 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1732984321.136195   94737 cuda_blas.cc:1418] Unable to register cuBLAS f

Instructions for updating:
experimental_relax_shapes is deprecated, use reduce_retracing instead


Skipped loading modules with pytorch-geometric dependency, missing a dependency. No module named 'dgl'
Skipped loading modules with pytorch-lightning dependency, missing a dependency. No module named 'lightning'
Skipped loading some Jax models, missing a dependency. No module named 'jax'


### Data Download

In [3]:
# dataset 다운로드 - smiles와 label을 저장
dataset_name = 'bace'
dataset = CustomMoleculeNet('dataset', name=dataset_name.upper())

Downloading https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/bace.csv
Processing...


Data processed and saved successfully.


Done!


In [4]:
data_list = torch.load('dataset/' + dataset_name + '/processed/smiles_labels.pt')

In [5]:
len(data_list), data_list[0]['smiles'], data_list[0]['label']

(1513,
 'O1CC[C@@H](NC(=O)[C@@H](Cc2cc3cc(ccc3nc2N)-c2ccccc2C)C)CC1(C)C',
 [[1.0]])

### Graph Data로 변환

In [7]:
# smiles를 graph data로 변환
def smiles_to_graph(data_list: list, with_hydrogen: bool = False, kekulize: bool = False) :

    graph_list = []
    for data in data_list:
        smiles = data['smiles']
        label = data['label']
        
        mol = Chem.MolFromSmiles(smiles)
        
        # smiles가 객체로 변환되지 않는 경우 filtering
        if mol is None :
            continue
        else: 
            # mol -> graph
            if with_hydrogen:
                mol = Chem.AddHs(mol)
            if kekulize:
                Chem.Kekulize(mol)
        
            xs: List[List[int]] = []
            for atom in mol.GetAtoms():
                current_atom_feat = atom_features(atom)
                xs.append(current_atom_feat)
        
            x = torch.tensor(xs, dtype=torch.long).view(-1, 133)
        
            edge_indices, edge_attrs = [], []
            for bond in mol.GetBonds():
                i = bond.GetBeginAtomIdx()
                j = bond.GetEndAtomIdx()
        
                edge_feature = [EDGE_FEATURES['possible_bonds'].index(bond.GetBondType())] + [
                    EDGE_FEATURES['possible_bond_dirs'].index(bond.GetBondDir())]
        
                edge_indices += [[i, j], [j, i]]
                edge_attrs += [edge_feature, edge_feature]
        
            edge_index = torch.tensor(edge_indices)
            edge_index = edge_index.t().to(torch.long).view(2, -1)
            edge_attr = torch.tensor(edge_attrs, dtype=torch.long).view(-1, 2)
        
            if edge_index.numel() > 0:  # Sort indices.
                perm = (edge_index[0] * x.size(0) + edge_index[1]).argsort()
                edge_index, edge_attr = edge_index[:, perm], edge_attr[perm]

            # label -> y
            y = torch.tensor([label], dtype=torch.float).view(1, -1)
            graph_list.append(Data(x=x, edge_index=edge_index, edge_attr=edge_attr, smiles=smiles, y=y))
    
    return graph_list

In [8]:
graph_list = smiles_to_graph(data_list)

In [9]:
len(data_list),len(graph_list)

(1513, 1513)

In [10]:
seed = 1
train_dataset, valid_dataset, test_dataset = random_split(graph_list, null_value=0, frac_train=0.8, frac_valid=0.1, frac_test=0.1, seed=seed)

In [11]:
def loader_dataset(data_list, batch_size, shuffle=False):
    """
    DataLoader로 변경
    """
    collate = Batch.from_data_list(data_list)
    loader = DataLoader(data_list, batch_size=batch_size, collate_fn=collate, shuffle=shuffle)

    return loader

In [12]:
batch_size = 256
train_loader = loader_dataset(data_list=train_dataset, batch_size=batch_size, shuffle=False)
valid_loader = loader_dataset(data_list=valid_dataset, batch_size=batch_size, shuffle=False) 
test_loader = loader_dataset(data_list=test_dataset, batch_size=batch_size, shuffle=False)

In [13]:
# loader 확인
for batch in train_loader:
    print(batch)
    break

DataBatch(x=[8647, 133], edge_index=[2, 18706], edge_attr=[18706, 2], y=[256, 1], smiles=[256], batch=[8647], ptr=[257])


### 문자열 token으로 변환

In [14]:
def smiles_to_token(data_list):
    # token 종류 수집
    vocab = []
    max_len = 0
    tokenizer = BasicSmilesTokenizer()
    for data in data_list:
        tokens = tokenizer.tokenize(data['smiles'])
        max_len = max(max_len, len(tokens))
        vocab += tokens
        
    uniq_vocab = sorted(set(vocab))
    smiles_vocab = {v: i for i, v in enumerate(uniq_vocab)}
    smiles_vocab['Unk'] = len(smiles_vocab)

    # token으로 변환
    tokens_list = []
    for data in data_list :
        label = data['label']
        tokens = [smiles_vocab[token] for token in tokenizer.tokenize(data['smiles'])]
        pad_len = max_len -len(tokens)
        tokens = tokens + ([0]*pad_len) 
        
        x = torch.tensor(tokens, dtype=torch.float).unsqueeze(1)
        y = torch.tensor([label], dtype=torch.float).view(1, -1)
        tokens_list.append(Data(x=x ,y=y))
        
    return tokens_list

In [15]:
tokens_list = smiles_to_token(data_list)
for tokens in tokens_list:
    print(f"length of each tokens: {len(tokens.x)}")
    print(tokens)
    break

length of each tokens: 178
Data(x=[178, 1], y=[1, 1])


In [16]:
seed = 1
train_dataset, valid_dataset, test_dataset = random_split(tokens_list, null_value=0, frac_train=0.8, frac_valid=0.1, frac_test=0.1, seed=seed)

In [17]:
batch_size = 256
train_loader = loader_dataset(data_list=train_dataset, batch_size=batch_size, shuffle=False)
valid_loader = loader_dataset(data_list=valid_dataset, batch_size=batch_size, shuffle=False) 
test_loader = loader_dataset(data_list=test_dataset, batch_size=batch_size, shuffle=False)

In [18]:
# loader 확인
for batch in train_loader:
    print(batch)
    break

DataBatch(x=[45568, 1], y=[256, 1], batch=[45568], ptr=[257])


### Fingerprint로 변환

In [19]:
# choice : 'rdkit', 'maccs', 'morgan' 중 변환할  fingerprint 선택
def smiles_to_fingerprint(data_list, choice):
    fp_list = []
    for data in data_list:
        smiles = data['smiles']
        label = data['label']
        
        molecule = Chem.MolFromSmiles(smiles)
        
        if molecule is None :
            continue
        else:     
            if choice == 'rdkit':
                rdkit_fp = Chem.RDKFingerprint(molecule)
                x = rdkit_fp
            elif choice == 'maccs':
                maccs_fp = MACCSkeys.GenMACCSKeys(molecule)
                x = maccs_fp
            elif choice == 'morgan':
                morgan_fp = AllChem.GetMorganFingerpirntAsBitVect(molecule, radius=2, nBits=1024)
                x = morgan_fp
                
            x = torch.tensor(x, dtype=torch.float).unsqueeze(1)
            y = torch.tensor([label], dtype=torch.float).view(1, -1)
            fp_list.append(Data(x=x, y=y))
                           
    return fp_list

In [20]:
# 데이터 불러오기
fp_list = smiles_to_fingerprint(data_list, 'rdkit')

In [21]:
seed = 1
train_dataset, valid_dataset, test_dataset = random_split(fp_list, null_value=0, frac_train=0.8, frac_valid=0.1, frac_test=0.1, seed=seed)

In [22]:
batch_size = 256
train_loader = loader_dataset(data_list=train_dataset, batch_size=batch_size, shuffle=False)
valid_loader = loader_dataset(data_list=valid_dataset, batch_size=batch_size, shuffle=False) 
test_loader = loader_dataset(data_list=test_dataset, batch_size=batch_size, shuffle=False)

In [23]:
# loader 확인
for batch in train_loader:
    print(batch)
    break

DataBatch(x=[524288, 1], y=[256, 1], batch=[524288], ptr=[257])


### Descriptors로 변환

In [25]:
def smiles_to_descriptors(data_list):

    des_list = []
    for data in data_list:
        smiles = data['smiles']
        label = data['label']
        
        molecule = Chem.MolFromSmiles(smiles)

        # filtering
        if molecule is None :
            continue
        else:
            descriptors_dict= Descriptors.CalcMolDescriptors(molecule)
            descriptor_vec = np.array([value for value in descriptors_dict.values()]) # dictionary의 value만 추출하여 vector 생성
            x = torch.tensor(descriptor_vec, dtype=torch.float).unsqueeze(1)
            y = torch.tensor([label], dtype=torch.float).view(1, -1)
            des_list.append(Data(x=x, y=y))
        
    return des_list

In [26]:
des_list = smiles_to_descriptors(data_list)

In [27]:
des_list[0]

Data(x=[210, 1], y=[1, 1])

In [28]:
seed = 1
train_dataset, valid_dataset, test_dataset = random_split(des_list, null_value=0, frac_train=0.8, frac_valid=0.1, frac_test=0.1, seed=seed)

In [29]:
batch_size = 256
train_loader = loader_dataset(data_list=train_dataset, batch_size=batch_size, shuffle=False)
valid_loader = loader_dataset(data_list=valid_dataset, batch_size=batch_size, shuffle=False) 
test_loader = loader_dataset(data_list=test_dataset, batch_size=batch_size, shuffle=False)

In [30]:
for batch in train_loader:
    print(batch)
    break

DataBatch(x=[53760, 1], y=[256, 1], batch=[53760], ptr=[257])


### 3D Graph로 변환

In [31]:
def molecule_to_3d(data_list):
    
    graph3d_list = []
    for data in data_list:
        smiles = data['smiles']
        label = data['label']
        
        molecule = Chem.MolFromSmiles(smiles)

        # filtering
        if molecule is None :
            continue
        else:              
            atom_info = [(atom.GetIdx(), atom.GetSymbol()) for atom in molecule.GetAtoms()]             
            status = rdDistGeom.EmbedMolecule(molecule)
            
            # 3d graph 변환 filtering
            if status != 0: # 0이 아닌 경우 변환 실패
                continue
            else:
                conf = molecule.GetConformer()
                pos = np.array([conf.GetAtomPosition(idx) for idx, symbol in atom_info])
            
                graph_data = pyg_utils.from_smiles(smiles)
                graph_data.pos = pos
                graph_data.y = torch.tensor([label], dtype=torch.float).view(1, -1)
                graph3d_list.append(graph_data)
        
    return graph3d_list

In [32]:
# 데이터 불러오기
graph3d_list = molecule_to_3d(data_list)

In [33]:
seed = 1
train_dataset, valid_dataset, test_dataset = random_split(graph3d_list, null_value=0, frac_train=0.8, frac_valid=0.1, frac_test=0.1, seed=seed)

In [34]:
batch_size = 256
train_loader = loader_dataset(data_list=train_dataset, batch_size=batch_size, shuffle=False)
valid_loader = loader_dataset(data_list=valid_dataset, batch_size=batch_size, shuffle=False) 
test_loader = loader_dataset(data_list=test_dataset, batch_size=batch_size, shuffle=False)

In [35]:
for batch in train_loader:
    print(batch)
    break

DataBatch(x=[8588, 9], edge_index=[2, 18586], edge_attr=[18586, 3], smiles=[256], pos=[256], y=[256, 1], batch=[8588], ptr=[257])
