In [1]:
import json
import os
import pickle
import random
from collections import OrderedDict
from math import sqrt
import networkx as nx
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from rdkit import Chem
from rdkit.Chem import AllChem, MolFromSmiles
from scipy import stats
from sklearn.model_selection import train_test_split
from torch import Tensor
from torch.utils.data import Dataset as TorchDataset
from torch.utils.data import TensorDataset, DataLoader
from torch_geometric import data as DATA
from torch_geometric.data import Dataset, Data, DataLoader, InMemoryDataset, Batch
from torch_geometric.nn import GATConv, GCNConv, global_max_pool as gmp, global_mean_pool as gap

In [2]:
# Original version of the atom featuriser.
def atom_features(atom, explicit_H=False, use_chirality=True):
    """Encode an RDKit atom as a flat one-hot feature vector.

    Blocks, concatenated in this order (sizes in parentheses):
        radical electrons (3), element symbol (37), degree (6),
        explicit valence (7), implicit valence (7), formal charge (3),
        hybridization (3), aromatic flag (1),
        total H count (5)                          -- only if ``explicit_H`` is False,
        CIP code R/S (2) + '_ChiralityPossible' (1) -- only if ``use_chirality``.
    With the default arguments the vector has 75 entries.

    Args:
        atom: RDKit atom to featurise.
        explicit_H: if False (default), append a one-hot of the atom's total
            hydrogen count (``GetTotalNumHs``).
        use_chirality: if True (default), append the chirality block.

    Returns:
        np.ndarray holding the concatenated features.

    Raises:
        Whatever ``one_of_k_encoding`` raises when the degree (0-5) or the
        explicit/implicit valence (0-6) falls outside its listed range.
    """
    # Element symbol, 37 slots; unknown elements fall into the last slot ('Zr').
    # NOTE: 'Sb' is listed three times and 'Te' twice, so those elements switch on
    # several slots. The list is kept as-is because de-duplicating it would change
    # the 75-dim feature length that the saved arrays were built with.
    symbol_one_hot = one_of_k_encoding_unk(
        atom.GetSymbol(),
        ['Al', 'Sb', 'Cl', 'Te', 'Si', 'Br', 'Cd', 'S', 'Mn', 'Ba',
         'Ga', 'Cr', 'I', 'Mo', 'B', 'Te', 'As', 'Sb', 'N', 'V',
         'Sn', 'P', 'Sb', 'Ni', 'Pb', 'Se', 'In', 'Be', 'F', 'Ti',
         'O', 'Hg', 'H', 'C', 'Co', 'Fe', 'Zr'])

    # Strict encodings: out-of-range degree / valence values raise.
    degree_one_hot = one_of_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5])
    formal_charge_one_hot = one_of_k_encoding_unk(atom.GetFormalCharge(), [-1, 0, 1])
    # NOTE(review): newer RDKit releases deprecate GetExplicitValence /
    # GetImplicitValence in favour of GetValence(...) — confirm against the
    # installed version.
    explicit_valence_one_hot = one_of_k_encoding(atom.GetExplicitValence(), [0, 1, 2, 3, 4, 5, 6])
    implicit_valence_one_hot = one_of_k_encoding(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6])
    hybridization_one_hot = one_of_k_encoding_unk(atom.GetHybridization(), [
        Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3])
    aromatic_one_hot = [atom.GetIsAromatic()]
    radical_one_hot = one_of_k_encoding_unk(atom.GetNumRadicalElectrons(), [0, 1, 2])

    results = (radical_one_hot + symbol_one_hot + degree_one_hot
               + explicit_valence_one_hot + implicit_valence_one_hot
               + formal_charge_one_hot + hybridization_one_hot + aromatic_one_hot)

    # When explicit_H is False, encode the total hydrogen count (not only the
    # explicit hydrogens).
    if not explicit_H:
        total_num_hs_one_hot = one_of_k_encoding_unk(atom.GetTotalNumHs(), [0, 1, 2, 3, 4])
        results = results + total_num_hs_one_hot

    if use_chirality:
        # Atoms without a '_CIPCode' property get [False, False], exactly what the
        # former bare `except:` produced — but other errors are no longer hidden.
        # Any CIP code other than 'R' is mapped onto the 'S' slot by
        # one_of_k_encoding_unk.
        if atom.HasProp('_CIPCode'):
            chirality_one_hot = one_of_k_encoding_unk(atom.GetProp('_CIPCode'), ['R', 'S'])
        else:
            chirality_one_hot = [False, False]
        results = results + chirality_one_hot + [atom.HasProp('_ChiralityPossible')]

    return np.array(results)
    

def one_of_k_encoding(x, allowable_set):
    """One-hot encode ``x`` against ``allowable_set`` (strict version).

    Args:
        x: value to encode.
        allowable_set: ordered candidate values; ``x`` is compared to each with ``==``.

    Returns:
        list[bool] that is True at every position equal to ``x``.

    Raises:
        ValueError: if ``x`` is not in ``allowable_set``. ValueError subclasses
            Exception, so existing ``except Exception`` handlers still catch it.
    """
    if x not in allowable_set:
        raise ValueError("input {0} not in allowable set {1}".format(x, allowable_set))
    return [x == s for s in allowable_set]

def one_of_k_encoding_unk(x, allowable_set):
    """One-hot encode ``x``; any value outside ``allowable_set`` counts as its last element."""
    # The final entry doubles as the "unknown" bucket.
    key = x if x in allowable_set else allowable_set[-1]
    return [key == candidate for candidate in allowable_set]


def smile_to_graph(smile):
    """Convert a SMILES string into ``(num_atoms, atom_feature_list, edge_index)``.

    Each atom feature vector is divided by its own sum. Every bond appears twice
    in ``edge_index`` (once per direction), as ``[src, dst]`` pairs.

    Returns None when RDKit cannot parse ``smile`` or when the molecule has no
    bonds (e.g. a single atom).
    """
    molecule = Chem.MolFromSmiles(smile)
    if molecule is None:  # RDKit failed to parse the SMILES
        return None

    num_atoms = molecule.GetNumAtoms()

    node_features = []
    for atom in molecule.GetAtoms():
        raw = atom_features(atom)
        node_features.append(raw / sum(raw))

    bond_pairs = [[bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()]
                  for bond in molecule.GetBonds()]
    directed = nx.Graph(bond_pairs).to_directed()
    edge_index = [[src, dst] for src, dst in directed.edges]

    # Bond-less molecules are rejected.
    if len(edge_index) == 0:
        return None

    return num_atoms, node_features, edge_index



In [3]:

# def atom_features(atom, explicit_H = False, use_chirality=True):
#     symbol_one_hot = one_of_k_encoding_unk(
#       atom.GetSymbol(), #37
#       [ 'Sb', 'Cl', 'Te', 'Si', 'Br', 'Cd', 'S',
#        'Ga', 'I', 'Mo', 'Te', 'As', 'N', 'V',
#        'Sn', 'P', 'Sb', 'Ni',  'F','Ti',
#        'O', 'Hg', 'C', 'Co'])
    
#     degree_one_hot = one_of_k_encoding(atom.GetDegree(), [1, 2, 3, 4 ]) #차수
#     formal_charge_one_hot = one_of_k_encoding_unk(atom.GetFormalCharge(),[-1, 0, 1]) #형식전하
#     explicit_valence_one_hot = one_of_k_encoding(atom.GetExplicitValence(), [1, 2, 3, 4, 5, 6]) #명시적원자가
#     implicit_valence_one_hot = one_of_k_encoding(atom.GetImplicitValence(), [0, 1, 2, 3])
#     hybridization_one_hot = one_of_k_encoding_unk(atom.GetHybridization(), [
#                 Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
#                 Chem.rdchem.HybridizationType.SP3]) #혼성화
#     aromatic_one_hot = [atom.GetIsAromatic()]

#     radical_one_hot = one_of_k_encoding_unk(atom.GetNumRadicalElectrons(), [0, 1])

#     results = radical_one_hot +symbol_one_hot + degree_one_hot + explicit_valence_one_hot +implicit_valence_one_hot+formal_charge_one_hot + hybridization_one_hot + aromatic_one_hot
# #    results = radical_one_hot +symbol_one_hot + degree_one_hot + explicit_valence_one_hot +formal_charge_one_hot + hybridization_one_hot + aromatic_one_hot

#     #false인 경우 명시적 수소수가 아니라 총 수소수를 반환
#     if not explicit_H:
#         total_num_hs_one_hot = one_of_k_encoding_unk(atom.GetTotalNumHs(), [0, 1, 2, 3])
#         results = results + total_num_hs_one_hot
        
#     if use_chirality:
#         try:
# #             print(atom.GetProp('_CIPCode'))  # 카이랄성 정보 출력
#             chirality_one_hot = one_of_k_encoding_unk(
#                 atom.GetProp('_CIPCode'),
#                 ['R', 'S'])
#             results = results + chirality_one_hot + [atom.HasProp('_ChiralityPossible')]
#         except:
# #             print("Chirality information not available.")  # 카이랄성 정보가 없는 경우
#             results = results + [False, False] + [atom.HasProp('_ChiralityPossible')]
    
#     return np.array(results)
    

# def one_of_k_encoding(x, allowable_set):
#     if x not in allowable_set:
#         raise Exception("input {0} not in allowable set{1}:".format(x, allowable_set))
#     return list(map(lambda s: x == s, allowable_set))

# def one_of_k_encoding_unk(x, allowable_set):
#     """Maps inputs not in the allowable set to the last element."""
#     if x not in allowable_set:
#         x = allowable_set[-1]
#     return list(map(lambda s: x == s, allowable_set))


# def smile_to_graph(smile):
#     mol = Chem.MolFromSmiles(smile)

#     if mol is None:  # if the molecule is not parsed correctly by RDKit, return None
#         return None
    
#     c_size = mol.GetNumAtoms()
    
#     features = []
#     for atom in mol.GetAtoms():
#         feature = atom_features(atom)
#         features.append(feature / sum(feature))
# #     print("Length of atom feature:", len(feature))  # 각 원자의 feature vector의 차원을 출력

#     edges = []
#     for bond in mol.GetBonds():
#         edges.append([bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()])
#     g = nx.Graph(edges).to_directed()    
#     edge_index = []
#     for e1, e2 in g.edges:
#         edge_index.append([e1, e2])

#     if not edge_index:  # check if edge_index is empty
#         return None

#     return c_size, features, edge_index



In [4]:
import pandas as pd

# Merged dataset: a 'SMILES' column plus one label column per tissue
# (liv, lun, sto, mgl).
file_path = '/data/home/dbswn0814/2025JCM/data/multi task/tissue_merged_data.csv'

# Read the CSV and keep only rows that actually carry a SMILES string.
df = pd.read_csv(file_path).dropna(subset=['SMILES'])
print(df)

     Unnamed: 0                                             SMILES  liv  lun  \
0             0                                       C(/C=C/Cl)Cl  1.0  1.0   
1             1                                       C(C(CBr)Br)O  1.0  1.0   
2             2            C(C(CBr)Br)OP(=O)(OCC(CBr)Br)OCC(CBr)Br  1.0  1.0   
3             3                                      C(C(CCl)Cl)Cl  1.0  0.0   
4             4                                  C(C(CO)(CBr)CBr)O  1.0  1.0   
..          ...                                                ...  ...  ...   
338         338        C1(=C(C(=C(C(=C1Cl)Cl)Cl)Cl)Cl)[N+](=O)[O-]  0.0  0.0   
339         339                   C1(=C(C(=NC(=C1Cl)Cl)C(=O)O)Cl)N  0.0  0.0   
340         340                    C1(=C(C(C(=C1Cl)Cl)(Cl)Cl)Cl)Cl  0.0  0.0   
341         341                               C1(=NC(=NC(=N1)N)N)N  0.0  0.0   
342         342  C1(=O)C2(C3(C4(C1(C5(C2(C3(C(C45Cl)(Cl)Cl)Cl)C...  0.0  0.0   

     sto  mgl  
0    1.0  0.0  
1    1.

In [5]:
# # 데이터에서 사용할 원소 기호 리스트
# allowed_symbols = set(['Al', 'Sb', 'Cl', 'Te', 'Si', 'Br', 'Cd', 'S', 'Mn', 'Ba',
#        'Ga', 'Cr', 'I', 'Mo', 'B', 'Te', 'As', 'Sb', 'N', 'V',
#        'Sn', 'P', 'Sb', 'Ni', 'Pb', 'Se', 'In', 'Be', 'F','Ti',
#        'O', 'Hg', 'H', 'C', 'Co', 'Fe', 'Zr'])

# # 데이터셋에서 SMILES 가져오기
# smiles_list = df['SMILES'].tolist()

# # SMILES 데이터에서 실제로 등장하는 원소 찾기
# found_symbols = set()

# for smiles in smiles_list:
#     mol = Chem.MolFromSmiles(smiles)  # SMILES를 RDKit 분자로 변환
#     if mol:
#         for atom in mol.GetAtoms():
#             found_symbols.add(atom.GetSymbol())  # 원자 기호 추출

# # ✅ 실제 데이터에 등장하지 않는 원소 찾기
# unused_symbols = allowed_symbols - found_symbols  # `allowed_symbols` 중 SMILES에 없는 원소

# # 결과 출력
# if not unused_symbols:
#     print("`allowed_symbols`에 불필요한 원소가 없음.")
# else:
#     print("`allowed_symbols`에 포함되어 있지만 데이터에 없는 원소 발견!")
#     print(" 데이터에서 사용되지 않은 원소:", unused_symbols)

    
# # 허용된 값 리스트들
# allowed_degrees = set([0, 1, 2, 3, 4 ,5])  # 원자 차수
# allowed_formal_charges = set([-1, 0, 1])  # 형식 전하
# allowed_explicit_valences = set( [0, 1, 2, 3, 4, 5, 6])  # 명시적 원자가
# allowed_implicit_valences = set( [0, 1, 2, 3, 4, 5, 6])  # 암묵적 원자가
# allowed_hybridizations = set([
#     Chem.rdchem.HybridizationType.SP, 
#     Chem.rdchem.HybridizationType.SP2, 
#     Chem.rdchem.HybridizationType.SP3
# ])  # 혼성화
# allowed_radical_electrons = set([0, 1, 2])  # 라디칼 전자 수
# allowed_total_hs = set([0, 1, 2, 3, 4 ])  # 총 수소 수
# allowed_chirality = set(['R', 'S'])  # 카이랄성

# # 실제 데이터에서 등장하는 값 저장
# found_degrees = set()
# found_formal_charges = set()
# found_explicit_valences = set()
# found_implicit_valences = set()
# found_hybridizations = set()
# found_radical_electrons = set()
# found_total_hs = set()
# found_chirality = set()

# # 데이터셋에서 SMILES 가져오기
# smiles_list = df['SMILES'].tolist()

# for smiles in smiles_list:
#     mol = Chem.MolFromSmiles(smiles)
#     if mol:
#         for atom in mol.GetAtoms():
#             # 등장한 값 저장
#             found_degrees.add(atom.GetDegree())
#             found_formal_charges.add(atom.GetFormalCharge())
#             found_explicit_valences.add(atom.GetExplicitValence())
#             found_implicit_valences.add(atom.GetImplicitValence())
#             found_hybridizations.add(atom.GetHybridization())
#             found_radical_electrons.add(atom.GetNumRadicalElectrons())
#             found_total_hs.add(atom.GetTotalNumHs())
#             if atom.HasProp('_CIPCode'):
#                 found_chirality.add(atom.GetProp('_CIPCode'))

# # 허용된 값과 실제 등장한 값 비교
# unused_degrees = allowed_degrees - found_degrees
# unused_formal_charges = allowed_formal_charges - found_formal_charges
# unused_explicit_valences = allowed_explicit_valences - found_explicit_valences
# unused_implicit_valences = allowed_implicit_valences - found_implicit_valences
# unused_hybridizations = allowed_hybridizations - found_hybridizations
# unused_radical_electrons = allowed_radical_electrons - found_radical_electrons
# unused_total_hs = allowed_total_hs - found_total_hs
# unused_chirality = allowed_chirality - found_chirality

# # 결과 출력
# print("데이터에 등장하지 않은 값들:")
# if unused_degrees:
#     print(f"  사용되지 않은 원자 차수 (Degree): {unused_degrees}")
# if unused_formal_charges:
#     print(f"  사용되지 않은 형식 전하 (Formal Charge): {unused_formal_charges}")
# if unused_explicit_valences:
#     print(f"  사용되지 않은 명시적 원자가 (Explicit Valence): {unused_explicit_valences}")
# if unused_implicit_valences:
#     print(f"  사용되지 않은 암묵적 원자가 (Implicit Valence): {unused_implicit_valences}")
# if unused_hybridizations:
#     print(f"  사용되지 않은 혼성화 (Hybridization): {unused_hybridizations}")
# if unused_radical_electrons:
#     print(f"  사용되지 않은 라디칼 전자 수 (Radical Electrons): {unused_radical_electrons}")
# if unused_total_hs:
#     print(f"  사용되지 않은 총 수소 수 (Total Num Hs): {unused_total_hs}")
# if unused_chirality:
#     print(f"  사용되지 않은 카이랄성 값 (Chirality): {unused_chirality}")

# if not (unused_degrees or unused_formal_charges or unused_explicit_valences or unused_implicit_valences or unused_hybridizations or unused_radical_electrons or unused_total_hs or unused_chirality):
#     print(" 모든 허용된 값들이 실제 데이터에서 사용됨.")


In [6]:
import numpy as np
from rdkit import Chem

# Sanity check: featurise one atom (the carbon of methanol, "CO") and report the
# length of its feature vector, i.e. the per-node input dimension.
methanol = Chem.MolFromSmiles("CO")
atom = methanol.GetAtomWithIdx(0)
features = atom_features(atom)
feature_dim = len(features)
print("Atom feature dimension:", feature_dim)


Atom feature dimension: 75


In [7]:
# Tissue label columns, in the order used for the label arrays below.
list_tissue = ['liv', 'lun', 'sto', 'mgl']
# Drop SMILES-less rows again (a no-op if already done) so positions in
# `smiles_list` line up with rows of `df`.
df = df.dropna(subset=['SMILES'])
smiles_list = df['SMILES'].tolist()

In [8]:
# Featurise every SMILES of the merged dataset and collect its per-tissue labels.
valid_dataX = []
valid_smiles_list = []
valid_labels = [[] for _ in range(len(list_tissue))]  # one label list per task (tissue)
invalid_smiles_list = []  # SMILES for which smile_to_graph returned None

# Keep only SMILES that yield a graph, together with their labels.
for i, smiles in enumerate(smiles_list):
    g = smile_to_graph(smiles)
    if g is not None:
        valid_dataX.append(g)
        valid_smiles_list.append(smiles)
        for idx, tissue in enumerate(list_tissue):
            # `i` indexes `df` positionally; this relies on `smiles_list` having
            # been built from this same `df`.
            label = df[tissue].iloc[i]
            if pd.isnull(label):
                valid_labels[idx].append(-1)  # missing label -> -1
            else:
                valid_labels[idx].append(int(label))
    else:
        invalid_smiles_list.append(smiles)

print("Invalid SMILES:", invalid_smiles_list)

val_data = np.array(valid_dataX, dtype=object)
val_label = np.array(valid_labels)  # shape (num_tasks, num_valid_samples)

# Save the whole (single) dataset used for cross-validation; np.save does not
# create missing directories, so make sure the output folder exists first.
val_dir = '/data/home/dbswn0814/2025JCM/data/multi task/val'
os.makedirs(val_dir, exist_ok=True)
np.save(os.path.join(val_dir, 'multi_refined_data.npy'), val_data)
np.save(os.path.join(val_dir, 'multi_refined_labels.npy'), val_label)

print("Length of data:", len(val_data))

Invalid SMILES: []
Length of data: 343


In [9]:
# Print, for each tissue, how many labels are 0, 1 and -1 (missing).
def print_label_counts(labels, data_type, tissues=None):
    """Print per-tissue counts of 0 / 1 / -1 labels.

    Args:
        labels: labels in (tasks, samples) layout — ``labels[idx]`` holds the
            labels of the idx-th tissue. A NumPy array or a nested list.
        data_type: tag included in every printed line.
        tissues: tissue names matching the rows of ``labels``. Defaults to the
            notebook-level ``list_tissue`` (the previous behaviour).
    """
    if tissues is None:
        tissues = list_tissue
    for idx, tissue in enumerate(tissues):
        # Convert each row: comparing a plain Python list with `== 0` yields a
        # single False instead of an element-wise mask, silently giving 0 counts.
        labels_for_tissue = np.asarray(labels[idx])
        count_0 = np.sum(labels_for_tissue == 0)
        count_1 = np.sum(labels_for_tissue == 1)
        count_neg1 = np.sum(labels_for_tissue == -1)
        print(f"Tissue {tissue} in {data_type}: 0={count_0}, 1={count_1}, -1={count_neg1}")
    print("------------------------------------------------------------------------")

# Summarise the featurised dataset: sample count, then label balance per tissue.
n_val = len(val_data)
print("Length of val_data:", n_val)

print_label_counts(val_label, "cross validation data")

Length of val_data: 343
Tissue liv in cross validation data: 0=166, 1=177, -1=0
Tissue lun in cross validation data: 0=159, 1=184, -1=0
Tissue sto in cross validation data: 0=157, 1=186, -1=0
Tissue mgl in cross validation data: 0=169, 1=174, -1=0
------------------------------------------------------------------------


In [10]:
# Write the SMILES that survived featurisation, one per line, in the same order
# as the saved graph data.
val_path = '/data/home/dbswn0814/2025JCM/data/multi task/val/val_smiles.txt'
with open(val_path, 'w', encoding='utf-8') as f:
    f.writelines("%s\n" % item for item in valid_smiles_list)

### Combination data: featurise each multi-tissue subset file (`<tissues>_data.csv`)

In [11]:
import pandas as pd
import numpy as np
from itertools import combinations
import os

# Featurise every multi-tissue subset file (e.g. 'liv_lun_sto_data.csv') and save
# graphs, labels and SMILES under '<base_dir>/val/'.

# List of tissues
tissues = ['liv', 'lun', 'sto', 'mgl']

# All combinations of two or more tissues (pairs, then triples, then all four).
all_combinations = [
    combo
    for size in range(2, len(tissues) + 1)
    for combo in combinations(tissues, size)
]

# Base directory for loading and saving data
base_dir = '/data/home/dbswn0814/2025JCM/data/multi task'
val_dir = os.path.join(base_dir, 'val')


def _print_combination_label_counts(labels, tissue_names, data_type):
    """Print per-tissue counts of 0 / 1 / -1 labels for one combination.

    ``labels[idx]`` holds the labels of ``tissue_names[idx]`` ((tasks, samples) layout).
    """
    for idx, tissue in enumerate(tissue_names):
        labels_for_tissue = np.asarray(labels[idx])
        count_0 = np.sum(labels_for_tissue == 0)
        count_1 = np.sum(labels_for_tissue == 1)
        count_neg1 = np.sum(labels_for_tissue == -1)
        print(f"Tissue {tissue} in {data_type}: 0={count_0}, 1={count_1}, -1={count_neg1}")
    print("------------------------------------------------------------------------")


# The loop variable is deliberately NOT named `list_tissue`, so this cell no
# longer overwrites the task list (or `print_label_counts`) used by earlier cells.
for tissue_combo in all_combinations:
    # File name derived from the combination, e.g. 'liv_lun_data.csv'
    tissue_comb_name = '_'.join(tissue_combo)
    file_name = f"{tissue_comb_name}_data.csv"
    file_path = os.path.join(base_dir, file_name)

    if not os.path.exists(file_path):
        print(f"File {file_name} not found. Skipping...")
        continue

    # Load the CSV and drop rows without a SMILES string.
    df = pd.read_csv(file_path)
    df = df.dropna(subset=['SMILES'])

    # Label columns in subset files are named 'result_<tissue>'. Check them once
    # per file, before any featurisation, instead of once per SMILES.
    label_columns = [f"result_{name}" for name in tissue_combo]
    for column_name in label_columns:
        if column_name not in df.columns:
            raise KeyError(f"Column {column_name} not found in {file_name}")

    smiles_list = df['SMILES'].tolist()

    valid_dataX = []
    valid_smiles_list = []
    valid_labels = [[] for _ in range(len(tissue_combo))]
    invalid_smiles_list = []

    # Keep only SMILES that yield a graph, together with their labels.
    for i, smiles in enumerate(smiles_list):
        g = smile_to_graph(smiles)
        if g is None:
            invalid_smiles_list.append(smiles)
            continue
        valid_dataX.append(g)
        valid_smiles_list.append(smiles)
        for idx, column_name in enumerate(label_columns):
            label = df[column_name].iloc[i]
            # Missing labels are encoded as -1.
            valid_labels[idx].append(-1 if pd.isnull(label) else int(label))

    val_data = np.array(valid_dataX, dtype=object)
    val_label = np.array(valid_labels)  # shape (num_tasks, num_valid_samples)

    # np.save does not create directories, so ensure the output folder exists.
    os.makedirs(val_dir, exist_ok=True)
    np.save(os.path.join(val_dir, f"{tissue_comb_name}_data.npy"), val_data)
    np.save(os.path.join(val_dir, f"{tissue_comb_name}_labels.npy"), val_label)

    # Save the valid SMILES, one per line, in the same order as the saved data.
    val_path = os.path.join(val_dir, f"{tissue_comb_name}_smiles.txt")
    with open(val_path, 'w', encoding='utf-8') as f:
        for item in valid_smiles_list:
            f.write("%s\n" % item)

    print(f"Processed {file_name}:")
    print(f"  Length of data: {len(val_data)}")
    print(f"  Invalid SMILES: {invalid_smiles_list}")
    _print_combination_label_counts(val_label, tissue_combo, tissue_comb_name)

File liv_lun_data.csv not found. Skipping...
File liv_sto_data.csv not found. Skipping...
File liv_mgl_data.csv not found. Skipping...
File lun_sto_data.csv not found. Skipping...
File lun_mgl_data.csv not found. Skipping...
File sto_mgl_data.csv not found. Skipping...
Processed liv_lun_sto_data.csv:
  Length of data: 10
  Invalid SMILES: []
Tissue liv in liv_lun_sto: 0=5, 1=5, -1=0
Tissue lun in liv_lun_sto: 0=8, 1=2, -1=0
Tissue sto in liv_lun_sto: 0=5, 1=5, -1=0
------------------------------------------------------------------------
Processed liv_lun_mgl_data.csv:
  Length of data: 12
  Invalid SMILES: []
Tissue liv in liv_lun_mgl: 0=7, 1=5, -1=0
Tissue lun in liv_lun_mgl: 0=12, 1=0, -1=0
Tissue mgl in liv_lun_mgl: 0=7, 1=5, -1=0
------------------------------------------------------------------------
Processed liv_sto_mgl_data.csv:
  Length of data: 5
  Invalid SMILES: []
Tissue liv in liv_sto_mgl: 0=3, 1=2, -1=0
Tissue sto in liv_sto_mgl: 0=3, 1=2, -1=0
Tissue mgl in liv_sto_mgl: