In [1]:
import sys
sys.path.append("/home/zengxin/fpk/pycharm_project/GNN-DDAS")
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import matthews_corrcoef,roc_auc_score,f1_score, cohen_kappa_score, roc_curve, auc, roc_auc_score, average_precision_score,accuracy_score
import pickle
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import DataStructs
import numpy as np
from torch import tensor
from utils.graph_dataset import SMILESDataset
from utils.resample import resampled
np.random.seed(42)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# SMILES原子格式
smiles_dict = {"#": 29, "%": 30, ")": 31, "(": 1, "+": 32, "-": 33, "/": 34, ".": 2, 
        "1": 35, "0": 3, "3": 36, "2": 4, "5": 37, "4": 5, "7": 38, "6": 6, 
        "9": 39, "8": 7, "=": 40, "A": 41, "@": 8, "C": 42, "B": 9, "E": 43, 
        "D": 10, "G": 44, "F": 11, "I": 45, "H": 12, "K": 46, "M": 47, "L": 13, 
        "O": 48, "N": 14, "P": 15, "S": 49, "R": 16, "U": 50, "T": 17, "W": 51, 
        "V": 18, "Y": 52, "[": 53, "Z": 19, "]": 54, "\\": 20, "a": 55, "c": 56, 
        "b": 21, "e": 57, "d": 22, "g": 58, "f": 23, "i": 59, "h": 24, "m": 60, 
        "l": 25, "o": 61, "n": 26, "s": 62, "r": 27, "u": 63, "t": 28, "y": 64,"*":65}
smiles_dict

{'#': 29,
 '%': 30,
 ')': 31,
 '(': 1,
 '+': 32,
 '-': 33,
 '/': 34,
 '.': 2,
 '1': 35,
 '0': 3,
 '3': 36,
 '2': 4,
 '5': 37,
 '4': 5,
 '7': 38,
 '6': 6,
 '9': 39,
 '8': 7,
 '=': 40,
 'A': 41,
 '@': 8,
 'C': 42,
 'B': 9,
 'E': 43,
 'D': 10,
 'G': 44,
 'F': 11,
 'I': 45,
 'H': 12,
 'K': 46,
 'M': 47,
 'L': 13,
 'O': 48,
 'N': 14,
 'P': 15,
 'S': 49,
 'R': 16,
 'U': 50,
 'T': 17,
 'W': 51,
 'V': 18,
 'Y': 52,
 '[': 53,
 'Z': 19,
 ']': 54,
 '\\': 20,
 'a': 55,
 'c': 56,
 'b': 21,
 'e': 57,
 'd': 22,
 'g': 58,
 'f': 23,
 'i': 59,
 'h': 24,
 'm': 60,
 'l': 25,
 'o': 61,
 'n': 26,
 's': 62,
 'r': 27,
 'u': 63,
 't': 28,
 'y': 64,
 '*': 65}

In [3]:
def generate_ecfp6_fingerprint(smiles):
    smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles), isomericSmiles=True)
    mol = Chem.MolFromSmiles(smiles)
    fingerprint = AllChem.GetMorganFingerprintAsBitVect(mol, 6, nBits=1024)
    fingerprint = fingerprint.ToBitString()
    return fingerprint

def smiles_onehot(smiles=None):
    smiles_one_hot = np.zeros((len(smiles),65))
    for i, amino_acid in enumerate(smiles):
        smiles_one_hot[i, smiles_dict[amino_acid]] = 1
    return np.array(smiles_one_hot)
def smiles_string(data, max_len):
    toks_list = []
    mask_attn_list = []
    toks = [smiles_dict[char] for char in data]
    if len(toks) > max_len:
        toks = toks[:max_len]
        mask_attn = [1]*max_len
    else:
        toks = toks + [0] * (max_len - len(toks))
        mask_attn = [1] * len(data) + [0] * (max_len - len(data))
    return toks,mask_attn
# 生成ECFP6分子指纹、ont-hot编码
def data_processed(data):
    fingerprint_list = []
    one_hot_list = []
    lable_list = []
    toks_list = []
    for idx, data in enumerate(data):
        fingerprint = generate_ecfp6_fingerprint(data.smiles)
        integer_value = int(fingerprint, 2)
        fingerprint = np.array([int(bit) for bit in bin(integer_value)[2:].zfill(len(fingerprint))])
        fingerprint_list.append(fingerprint)
        lable_list.append(int(data.y))
    return fingerprint_list,lable_list

In [4]:
train_root = '/home/zengxin/fpk/pycharm_project/GNN-DDAS/data/merge/merge_data/train_data'
test_root = '/home/zengxin/fpk/pycharm_project/GNN-DDAS/data/merge/merge_data/test_data'
temp_train_root = '/home/zengxin/fpk/pycharm_project/GNN-DDAS/data/merge/merge_data/resample/train'
raw_train_root = '/home/zengxin/fpk/pycharm_project/GNN-DDAS/data/merge/merge_data/train_data/raw/train_data.csv'
temp_test_root = '/home/zengxin/fpk/pycharm_project/GNN-DDAS/data/merge/merge_data/resample/test'
raw_test_root = '/home/zengxin/fpk/pycharm_project/GNN-DDAS/data/merge/merge_data/test_data/raw/test_data.csv'

train_set = SMILESDataset(root=train_root,raw_dataset='train_data.csv',processed_data='train.pt',max_node_num=125)
test_set = SMILESDataset(root=test_root,raw_dataset='test_data.csv',processed_data='test.pt',max_node_num=125)
# train_set = resampled(temp_train_root=temp_train_root,raw_train_root=raw_train_root,ratio=1)
# test_set = resampled(temp_train_root=temp_test_root,raw_train_root=raw_test_root,ratio=1)
# train_set, test_set = train_test_split(train_set,test_size=0.1,random_state=42)
print(len(train_set),len(test_set))

2764 684


In [5]:
def knn(rf=None,X_train=None,y_train=None,X_test=None,y_test=None,param=None,cv=5,model_path=None,data_name=None):

    gs = GridSearchCV(rf,param_grid=param,cv=cv)
    gs.fit(X_train,y_train)
    res = gs.score(X_test,y_test)
    y_pred = gs.predict(X_test)
    y_prob = gs.predict_proba(X_test)[:, 1]
    f1 = f1_score(y_test,y_pred)
    ck = cohen_kappa_score(y_test,y_pred)
    mcc = matthews_corrcoef(y_test,y_pred)
    auprc = average_precision_score(y_test,y_pred)
    acc = accuracy_score(y_test,y_pred)
    best_params = gs.best_params_
    # 输出最佳的n_estimators和max_depth
    best_n_neighbors = best_params["n_neighbors"]
    print(f'f1:{f1},mcc:{mcc},acc:{acc},ck:{ck},auprc:{auprc}')
    print("最佳 _n_neighbors:", best_n_neighbors)
    filename = model_path + 'knn_' + data_name + '.pkl'
    pickle.dump(rf, open(filename, 'wb'))

In [6]:
train_fingerprint,train_label = data_processed(train_set)
test_fingerprint,test_lables = data_processed(test_set)
train_fingerprint = np.array(train_fingerprint)

In [7]:
knn_ecfp = KNeighborsClassifier()

In [8]:
param = {'n_neighbors':[4,5,6,7,8,9,10,12,14,16,18,20]}
model_path = '/home/zengxin/fpk/pycharm_project/GNN-DDAS/save_model/ML/merge/'
knn(knn_ecfp,param=param,X_train=train_fingerprint,y_train=train_label,X_test=test_fingerprint,y_test=test_lables,cv=5,model_path=model_path,data_name='merge_ecfp')

Roc:0.6678928693141952,f1:0.49196277107881337,mcc:0.145720616625586,acc:0.881578947368421,ck:0.04158593966233026,auprc:0.14251743817374762
最佳 _n_neighbors: 12


In [9]:
predictions_data = '/home/zengxin/fpk/pycharm_project/GNN-DDAS/data/new_data/raw'
predictions_data = SMILESDataset(root=predictions_data,raw_dataset='test_data.csv',processed_data='test.pt')
predictions_fingerprint,test_lables = data_processed(predictions_data)
predictions_model = '/home/zengxin/fpk/pycharm_project/GNN-DDAS/save_model/ML/merge/knn_merge_ecfp.pkl'
with open(predictions_model, 'rb') as f:
    loaded_model = pickle.load(f)

# 拟合模型
loaded_model.fit(train_fingerprint, train_label)

In [10]:
# 进行预测
predictions = loaded_model.predict(predictions_fingerprint)

# 输出预测结果
print("Predictions:", predictions)

Predictions: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
