In [None]:
from IPython.display import clear_output
# Dependency installation is left disabled; uncomment the next line to install PGL.
# !pip install pgl
# Clear this cell's output before printing the status line.
clear_output()
# Prints "installation succeeded".
# NOTE(review): this message is printed even though the install line above is commented out.
print("安装成功")

In [None]:
# NOTE(review): print("1") looks like a leftover debug marker — confirm and remove.
print("1")
import paddle
# Show which PaddlePaddle version the kernel is using.
print(paddle.__version__)

In [None]:
# Run PaddlePaddle's built-in installation self-check.
paddle.utils.run_check()

In [None]:
import rdkit
from rdkit import Chem

In [None]:
# Shell out to `tree` to display the layout of the ./data directory.
!tree data

In [None]:
# Dataset locations and the protein length setting used by the rest of the notebook.
# NOTE(review): `project_path` is not defined in any cell visible here — it must be set
# before this cell runs, otherwise a fresh "Restart & Run All" fails with NameError.
# NOTE(review): train_data and test_data resolve to the same directory — confirm this is intended.
train_data = project_path+'/apps/drug_target_interaction/graph_dta'
test_data =  project_path+'/apps/drug_target_interaction/graph_dta'
max_protein_len = 1000  # set -1 to use full sequence

In [None]:
from src.data_gen import DTADataset
# Build the train and test datasets from their data paths, forwarding the
# protein length setting to each.
train_dataset = DTADataset(train_data, max_protein_len=max_protein_len)
test_dataset = DTADataset(test_data, max_protein_len=max_protein_len)
# Report dataset sizes (the labels read "training set size" / "test set size").
print('训练集数量:', len(train_dataset))
print('测试集数量:', len(test_dataset))

In [None]:
# Hyper-parameters: optimizer learning rate plus the nested configuration dict
# that is passed to DTAModel and read when building the collate function.
lr = 0.0005                                         # learning rate
model_config = {
    "compound": {
        "atom_names": ["atomic_num", "chiral_tag"], # node features when the compound (drug) is represented as a graph
        "bond_names": ["bond_dir", "bond_type"],    # edge features when the compound (drug) is represented as a graph

        "gnn_type": "gin",                          # graph neural network type
        "dropout_rate": 0.2,                        # dropout probability inside the graph neural network

        "embed_dim": 32,                            # dimension of the atom-type embedding matrix
        "layer_num": 5,                             # number of graph convolution layers
        "hidden_size": 32,                          # hidden layer size of the graph convolution network
        "output_dim": 128                           # dimension of the compound (drug) representation vector
    },

    "protein": {
        "max_protein_len": max_protein_len,         # -1 means the full-length protein sequence is used as input
        "embed_dim": 128,                           # dimension of the amino-acid embedding matrix
        "num_filters": 32,                          # number of filters in the sequence convolution
        "output_dim": 128                           # dimension of the target-protein representation vector
    },

    "dropout_rate": 0.2                             # dropout probability inside the affinity prediction network
}

In [None]:
import pgl
import paddle
import numpy as np
from src.model import DTAModel, DTAModelCriterion

# Instantiate the model from the configuration dict, its training criterion,
# and an Adam optimizer over all model parameters using the learning rate `lr`.
model = DTAModel(model_config)
criterion = DTAModelCriterion()
optimizer = paddle.optimizer.Adam(
    learning_rate=lr,
    parameters=model.parameters())

In [None]:
from src.data_gen import DTACollateFunc

max_epoch = 2                     # a small number of epochs is used here to keep the demo short
batch_size = 64                  # batch size used during training
num_workers = 2                   # number of parallel workers for the PGL dataloader

# Collate function in training mode (is_inference=False), built from the compound
# feature names in the model config and the 'Log10_Kd' label name.
collate_fn = DTACollateFunc(
    model_config['compound']['atom_names'],
    model_config['compound']['bond_names'],
    is_inference=False,
    label_name='Log10_Kd')

# NOTE(review): no `shuffle` argument is passed for the training loader, so it uses
# whatever default get_data_loader defines — confirm that default is the intended one.
train_dataloader = train_dataset.get_data_loader(
    batch_size=batch_size,
    num_workers=num_workers,
    collate_fn=collate_fn)

# The test loader is explicitly unshuffled and uses a single worker.
test_dataloader = test_dataset.get_data_loader(
        batch_size=batch_size,
        num_workers=1,
        shuffle=False,
        collate_fn=collate_fn)

In [None]:
def train(model, criterion, optimizer, dataloader):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Args:
        model: Network called as ``model(graphs, proteins_token, proteins_mask)``;
            it is switched to training mode first.
        criterion: Callable mapping ``(predictions, labels)`` to a loss tensor.
        optimizer: Optimizer exposing ``step()`` and ``clear_grad()``.
        dataloader: Iterable yielding
            ``(graphs, proteins_token, proteins_mask, labels)`` batches.

    Returns:
        ``np.mean`` over the per-batch loss values.
    """
    model.train()
    batch_losses = []
    for batch in dataloader:
        mol_graphs, token_ids, token_mask, targets = batch
        mol_graphs = mol_graphs.tensor()
        # Convert the remaining batch members to Paddle tensors, in order.
        token_ids, token_mask, targets = (
            paddle.to_tensor(arr) for arr in (token_ids, token_mask, targets))

        batch_loss = criterion(model(mol_graphs, token_ids, token_mask), targets)

        # Back-propagate, apply the update, then reset gradients for the next batch.
        batch_loss.backward()
        optimizer.step()
        optimizer.clear_grad()
        batch_losses.append(batch_loss.numpy())
    return np.mean(batch_losses)

In [None]:
from src.utils import concordance_index

def evaluate(model, dataloader, prior_best_mse):
    """Evaluate ``model`` on ``dataloader`` and return ``(mse, ci)``.

    Args:
        model: Network called as ``model(graphs, proteins_token, proteins_mask)``;
            it is switched to evaluation mode first.
        dataloader: Iterable yielding
            ``(graphs, proteins_token, proteins_mask, labels)`` batches.
        prior_best_mse: Best MSE seen so far; the concordance index is only
            computed when the new MSE is strictly lower than this value.

    Returns:
        Tuple ``(mse, ci)``. ``mse`` is the mean squared error over all flattened
        predictions and labels. ``ci`` is the concordance index, or ``None`` when
        ``mse`` did not improve on ``prior_best_mse``.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()
    total_pred, total_label = [], []
    # Evaluation never back-propagates, so disable gradient tracking for the
    # forward passes instead of building an unused autograd graph per batch.
    with paddle.no_grad():
        for graphs, proteins_token, proteins_mask, labels in dataloader:
            graphs = graphs.tensor()
            proteins_token = paddle.to_tensor(proteins_token)
            proteins_mask = paddle.to_tensor(proteins_mask)

            preds = model(graphs, proteins_token, proteins_mask)
            total_pred.append(preds.numpy())
            # Labels are kept as yielded by the dataloader; they are only used in numpy math.
            total_label.append(labels)

    if not total_pred:
        # np.concatenate would raise a less descriptive ValueError on an empty list.
        raise ValueError('evaluate() received a dataloader that yielded no batches')

    total_pred = np.concatenate(total_pred, 0).flatten()
    total_label = np.concatenate(total_label, 0).flatten()
    mse = ((total_label - total_pred) ** 2).mean(axis=0)

    ci = None
    if mse < prior_best_mse:
        # Computing CI is time consuming, so only do it for an improved MSE.
        ci = concordance_index(total_label, total_pred)

    return mse, ci

In [None]:
# Prefer the GPU, but fall back to CPU when this Paddle build has no CUDA support,
# so the notebook still runs end-to-end on CPU-only installations.
# NOTE(review): the model and optimizer are created in an earlier cell, before the
# device is selected here — confirm that ordering is intended.
device = 'gpu' if paddle.is_compiled_with_cuda() else 'cpu'
paddle.set_device(device)

# Best-so-far tracking: lowest test MSE, its concordance index, and the epoch it occurred in.
best_mse, best_ci, best_ep = np.inf, 0, 0
best_model = 'best_model.pdparams'

# Human-readable summary of the best epoch; stays None until an epoch improves on best_mse.
metric = None
for epoch_id in range(1, max_epoch + 1):
    print('========== Epoch {} =========='.format(epoch_id))
    train_loss = train(model, criterion, optimizer, train_dataloader)
    print('Epoch: {}, Train loss: {}'.format(epoch_id, train_loss))
    # evaluate() only computes CI when the MSE beats best_mse; otherwise ci is None.
    mse, ci = evaluate(model, test_dataloader, best_mse)

    if mse < best_mse:
        best_mse, best_ci, best_ep = mse, ci, epoch_id
        # Checkpoint the weights of the best-performing epoch.
        paddle.save(model.state_dict(), best_model)
        metric = 'Epoch: {}, Best MSE: {}, Best CI: {}'.format(epoch_id, best_mse, best_ci)
        print(metric)
    else:
        print('No improvement in epoch {}'.format(epoch_id))
        print('Current best: ({})'.format(metric))

In [None]:
import pgl
from rdkit.Chem import AllChem
from pahelix.utils.protein_tools import ProteinTokenizer
from pahelix.utils.compound_tools import mol_to_graph_data

# Inputs for a single drug-protein pair: a protein sequence and a drug SMILES string.
# NOTE(review): both are empty placeholders here — fill them in before running this cell.
protein_example = ''
drug_example = ''

# Process the drug molecule
mol = AllChem.MolFromSmiles(drug_example)
if mol is None:
    # RDKit reports an unparsable SMILES by returning None instead of raising;
    # fail here with a clear message rather than deep inside the graph featurizer.
    raise ValueError('Invalid SMILES string: {!r}'.format(drug_example))
mol_graph = mol_to_graph_data(mol)
print(mol_graph.values())
# Process the protein sequence
tokenizer = ProteinTokenizer()
protein_token_ids = tokenizer.gen_token_ids(protein_example)

# Merge the drug-molecule features and the protein tokens into one sample
data = {k: v for k, v in mol_graph.items()}
data['protein_token_ids'] = np.array(protein_token_ids)

# When a maximum protein length is set, truncate or pad the token ids so the
# sequence has exactly that length
if max_protein_len > 0:
    protein_token_ids = np.zeros(max_protein_len, dtype=np.int64) + ProteinTokenizer.padding_token_id
    n = min(max_protein_len, data['protein_token_ids'].size)
    protein_token_ids[:n] = data['protein_token_ids'][:n]
    data['protein_token_ids'] = protein_token_ids

# Collate function in inference mode (is_inference=True)
infer_collate_fn = DTACollateFunc(
    model_config['compound']['atom_names'],
    model_config['compound']['bond_names'],
    is_inference=True,
    label_name='Log10_Kd')

# Convert the single sample into the inputs DTAModel accepts
join_graph, proteins_token, proteins_mask = infer_collate_fn([data])

In [None]:
# Convert the collated inputs to Paddle tensors, mirroring train()/evaluate().
# The graph conversion result was previously bound to a misspelled, unused name
# (`oin_graph`), so the model call depended on `join_graph` without an explicit rebind.
join_graph = join_graph.tensor()
proteins_token = paddle.to_tensor(proteins_token)
proteins_mask = paddle.to_tensor(proteins_mask)

# Switch to evaluation mode and predict the affinity for the single sample.
model.eval()
affinity_pred = model(join_graph, proteins_token, proteins_mask)
# The batch holds one sample, so take the single scalar at [0][0].
affinity_pred = affinity_pred.numpy()[0][0]