# This is an example of reproducing training results
Prior to this step, please execute `download_weights.py` to download the weights for all pre-trained models.

In [1]:
# import packages
from torch.utils.data import DataLoader
from readData import readData
import shutil
from utils import *
import torch
from torch import nn
from model import LABind
from func_help import setALlSeed,get_std_opt
from tqdm import tqdm
import pickle as pkl
from sklearn.model_selection import KFold
import gc

In [2]:
# config
# Global settings shared by every cell below: target device, dataset location,
# feature-normalisation arrays, model hyper-parameters and training options.
DEVICE = torch.device('cuda:0')
root_path = getRootPath()
dataset = 'Unseen' # DS1:LigBind, DS2:GPSite, DS3:Unseen

# Load the ligand dictionary through a context manager so the file handle is
# closed deterministically.
# NOTE: unpickling can execute arbitrary code; only load a ligand.pkl that
# comes from a trusted source.
with open(f'{root_path}/tools/ligand.pkl', 'rb') as lig_file:
    ligand_dict = pkl.load(lig_file)

nn_config = {
    # dataset 
    'train_file': f'{root_path}/{dataset}/label/training.fa',
    'test_file': f'{root_path}/{dataset}/label/test.fa',
    'valid_file': f'{root_path}/{dataset}/label/picking.fa',
    'proj_dir': f'{root_path}/{dataset}/',
    'lig_dict': ligand_dict,
    'pdb_class':'source', # source or omegafold or esmfold
    # Per-feature max/min arrays stored under tools/
    # (presumably used for min-max normalisation elsewhere; confirm in readData).
    'dssp_max_repr': np.load(f'{root_path}/tools/dssp_max_repr.npy'),
    'dssp_min_repr': np.load(f'{root_path}/tools/dssp_min_repr.npy'),
    'ankh_max_repr': np.load(f'{root_path}/tools/ankh_max_repr.npy'),
    'ankh_min_repr': np.load(f'{root_path}/tools/ankh_min_repr.npy'),
    'ion_max_repr': np.load(f'{root_path}/tools/ion_max_repr.npy'),
    'ion_min_repr': np.load(f'{root_path}/tools/ion_min_repr.npy'),
    # model parameters
    
    'rfeat_dim':1556,
    'ligand_dim':768, 
    'hidden_dim':256, 
    'heads':4, 
    'augment_eps':0.05, 
    'rbf_num':8, 
    'top_k':30, 
    'attn_drop':0.1, 
    'dropout':0.1, 
    'num_layers':4, 
    'lr':0.0004, 
    
    # training parameters 
    # You can modify it according to the actual situation. 
    # Since it involves mapping the entire protein, it will consume a large amount of GPU memory.
    'batch_size':15,
    'max_patience':10,
    'device_ids':[0,1], # 2*A100-40G
}
pretrain_path = { # Please modify 
    'esmfold_path': '../tools/esmfold_v1', # esmfold path
    'ankh_path': '../tools/ankh-large/', # ankh path
    'molformer_path': '../tools/MoLFormer-XL-both-10pct/', # molformer path
    'model_path':f'{root_path}/model/Unseen/' # based on Unseen
}

### Download dataset
Download the file from https://zenodo.org/records/13938443 and place it in the root directory.


### Retrieve the features

#### Ankh

In [None]:
# Compute Ankh embeddings for every sequence of the dataset and cache them
# as one .npy file per FASTA record id.
from transformers import AutoTokenizer, T5EncoderModel 
from Bio import SeqIO

ankh_dir = pretrain_path['ankh_path']
tokenizer = AutoTokenizer.from_pretrained(ankh_dir)
model     = T5EncoderModel.from_pretrained(ankh_dir)
model.to(DEVICE)
model.eval()

out_path = f"{root_path}/{dataset}/ankh/"
makeDir(out_path)

# Read the FASTA files with Biopython
fasta_path = f"{root_path}/{dataset}/fasta/"
for file_class in os.listdir(fasta_path):
    class_dir = f"{fasta_path}/{file_class}"
    for fasta_base_file in os.listdir(class_dir):
        fasta_file = f"{class_dir}/{fasta_base_file}"
        for record in tqdm(SeqIO.parse(fasta_file, "fasta")):
            emb_file = out_path+f'{record.id}.npy'
            # Skip records whose embedding is already on disk.
            if os.path.exists(emb_file):
                continue
            residues = list(record.seq)
            encoded = tokenizer.batch_encode_plus(
                [residues],
                add_special_tokens=True,
                padding=True,
                is_split_into_words=True,
                return_tensors="pt",
            )
            with torch.no_grad():
                output = model(
                    input_ids=encoded['input_ids'].to(DEVICE),
                    attention_mask=encoded['attention_mask'].to(DEVICE),
                )
                # Keep only the first len(sequence) positions of the single batch entry.
                emb = output.last_hidden_state[0,:len(residues)].cpu().numpy()
            np.save(emb_file,emb)

# Release the language model before the next stages need the GPU.
del model
gc.collect()
torch.cuda.empty_cache()

#### DSSP

In [None]:
# Run DSSP on every PDB file and store one feature row per residue
# (9-way one-hot secondary structure + the remaining DSSP fields).
from Bio.PDB import PDBParser
from Bio.PDB.DSSP import DSSP

# One-hot encoding of the DSSP secondary-structure code; ' ' maps to all zeros.
mapSS = {' ':[0,0,0,0,0,0,0,0,0],
        '-':[1,0,0,0,0,0,0,0,0],
        'H':[0,1,0,0,0,0,0,0,0],
        'B':[0,0,1,0,0,0,0,0,0],
        'E':[0,0,0,1,0,0,0,0,0],
        'G':[0,0,0,0,1,0,0,0,0],
        'I':[0,0,0,0,0,1,0,0,0],
        'P':[0,0,0,0,0,0,1,0,0],
        'T':[0,0,0,0,0,0,0,1,0],
        'S':[0,0,0,0,0,0,0,0,1]}
p = PDBParser(QUIET=True)
pdb_path = f"{root_path}/{dataset}/pdb/"
dssp_path = "../tools/mkdssp"
pdb_class = nn_config['pdb_class']
dssp_out_dir = f"{root_path}/{dataset}/{pdb_class}_dssp/"
makeDir(dssp_out_dir)
for pdb_file_name in tqdm(os.listdir(pdb_path),desc='DSSP running',ncols=80,unit='proteins'):
    pdb_file = pdb_path+pdb_file_name
    # Build the output path from the file name only. A blanket
    # str.replace('pdb', ...) on the full path would also rewrite any "pdb"
    # occurring in the root path or inside the protein id.
    save_file = dssp_out_dir + pdb_file_name.replace('.pdb','.npy')
    if os.path.exists(save_file):
        continue
    structure = p.get_structure("tmp", pdb_file)
    model = structure[0]
    try:
        dssp = DSSP(model, pdb_file, dssp=dssp_path)
        # A set gives O(1) membership tests in the residue loop below.
        keys = set(dssp.keys())
    except Exception:
        # Best effort: if DSSP fails for this structure, every residue
        # falls back to the all-zero feature row.
        dssp = None
        keys = set()
    res_np = []
    for chain in model:
        for residue in chain:
            # The hetero flag is forced to ' ' when building the lookup key,
            # so only keys of standard residues can match.
            res_key = (chain.id,(' ', residue.id[1], residue.id[2]))
            if res_key in keys:
                tuple_dssp = dssp[res_key]
                # 9 one-hot values + DSSP fields from index 3 onward
                # (expected to total 20, matching the zero fallback; confirm
                # against the Bio.PDB.DSSP tuple layout).
                res_np.append(mapSS[tuple_dssp[2]] + list(tuple_dssp[3:]))
            else:
                res_np.append(np.zeros(20))
    np.save(save_file, np.array(res_np))

#### Position

In [None]:
# Build per-residue coordinate features from each PDB file: backbone atoms
# (N, CA, C, O), a side-chain centre computed with calMass, and the residue
# atom closest to the MSMS surface.
from Bio.PDB import PDBParser  # imported here so this cell can run on its own
from Bio.PDB.ResidueDepth import get_surface
from scipy.spatial import cKDTree

pdb_path = f"{root_path}/{dataset}/pdb/"
msms_path = "../tools/msms"
pdb_class = nn_config['pdb_class']
pos_out_dir = f"{root_path}/{dataset}/{pdb_class}_pos/"
makeDir(pos_out_dir)
# Loop-invariant objects are created once instead of once per protein.
parser = PDBParser(QUIET=True)
chain_atom = ['N', 'CA', 'C', 'O']
for pdb_file_name in tqdm(os.listdir(pdb_path),desc='MSMS running',ncols=80,unit='proteins'):
    pdb_file = pdb_path+pdb_file_name
    # Build the output path from the file name only. A blanket
    # str.replace('pdb', ...) on the full path would also rewrite any "pdb"
    # occurring in the root path or inside the protein id.
    save_file = pos_out_dir + pdb_file_name.replace('.pdb','.npy')
    if os.path.exists(save_file):
        continue
    X = []
    model = parser.get_structure('model', pdb_file)[0]
    # Only the first chain of the first model is processed.
    chain = next(model.get_chains())
    try:
        surf = get_surface(chain,MSMS=msms_path)
        surf_tree = cKDTree(surf)
    except Exception:
        # Best effort: without a surface, fall back to the last atom below.
        surf = np.empty(0)
        surf_tree = None
    has_surface = surf_tree is not None and surf.size > 0
    for residue in chain:
        line = []
        atoms_coord = np.array([atom.get_coord() for atom in residue])
        if has_surface:
            # Pick the atom of this residue that lies closest to the surface.
            # (The previous check `surf.size == 0` was inverted: it queried
            # the tree only when no surface had been computed.)
            dist, _ = surf_tree.query(atoms_coord)
            closest_atom = np.argmin(dist)
            closest_pos = atoms_coord[closest_atom]
        else:
            closest_pos = atoms_coord[-1]
        atoms = list(residue.get_atoms())
        ca_pos= residue['CA'].get_coord()
        pos_s = 0
        un_s = 0
        for atom in atoms:
            if atom.name in chain_atom:
                line.append(atom.get_coord())
            else:
                pos_s += calMass(atom,True)
                un_s += calMass(atom,False)
        # line is expected to hold 4 backbone atoms; pad with CA if some are missing
        if len(line) != 4:
            line = line + [list(ca_pos)]*(4-len(line))
        if un_s == 0:
            # No non-backbone atoms: use the CA position as the side-chain centre.
            R_pos = ca_pos
        else:
            R_pos = pos_s / un_s
        line.append(R_pos)  
        line.append(closest_pos) # add the position of the atom nearest to the surface
        X.append(line) 
    np.save(save_file, X)

### Train

In [4]:
def valid(model, valid_list, fold_idx=None):
    """Evaluate `model` on `valid_list` and return the AUPR.

    Args:
        model: network to evaluate; it is moved to DEVICE and put in eval mode.
        valid_list: entries passed to `readData` as `name_list`.
        fold_idx: accepted for compatibility with `train()`, which passes the
            fold index as a third argument; it is not used here.

    Returns:
        `average_precision_score` over all positions where `mask == 1`.
    """
    model.to(DEVICE)
    model.eval()
    valid_data = readData(
        name_list=valid_list, 
        proj_dir=nn_config['proj_dir'], 
        lig_dict=nn_config['lig_dict'],
        true_file=nn_config['train_file']) # If 5-fold cross-validation is not used, it needs to be changed to valid_file.
    valid_loader = DataLoader(valid_data, batch_size=nn_config['batch_size'],shuffle=True, collate_fn=valid_data.collate_fn, num_workers=5)
    all_y_score = []
    all_y_true = []
    with torch.no_grad():
        for rfeat, ligand, xyz,  mask, y_true in valid_loader:
            tensors = [rfeat, ligand, xyz,  mask, y_true]
            tensors = [tensor.to(DEVICE) for tensor in tensors]
            rfeat, ligand, xyz, mask, y_true = tensors
            logits = model(rfeat, ligand, xyz,  mask).sigmoid() # [N]
            # Keep only the positions selected by the mask.
            logits = torch.masked_select(logits, mask==1)
            y_true = torch.masked_select(y_true, mask==1)
            all_y_score.extend(logits.cpu().detach().numpy())
            all_y_true.extend(y_true.cpu().detach().numpy())
        # The AUPR value drives early stopping in train()
        aupr_value = average_precision_score(all_y_true, all_y_score)
    return aupr_value

def train(train_list,valid_list=None,model=None,epochs=50,fold_idx=None):
    """Train `model` on `train_list`, with optional AUPR-based early stopping.

    Args:
        train_list: entries passed to `readData` as `name_list`.
        valid_list: optional validation entries. When given, the model is
            evaluated with `valid()` after every epoch, the best checkpoint
            is saved, and training stops once the AUPR has not improved for
            `nn_config['max_patience']` consecutive epochs.
        model: network to train (required despite the `None` default).
        epochs: maximum number of epochs.
        fold_idx: index used in the checkpoint file name `fold{fold_idx}.ckpt`.

    Side effects:
        Writes `{root_path}/Output/{dataset}/fold{fold_idx}.ckpt` whenever the
        validation AUPR improves. When more than one CUDA device is visible
        the model is wrapped in `nn.DataParallel`, so the saved state_dict
        keys then carry that wrapper's prefix.
    """
    import os  # local import: used only to guarantee the checkpoint directory exists

    model.to(DEVICE)
    train_data = readData(
        name_list=train_list, 
        proj_dir=nn_config['proj_dir'], 
        lig_dict=nn_config['lig_dict'],
        true_file=nn_config['train_file'])
    train_loader = DataLoader(train_data, batch_size=nn_config['batch_size'],shuffle=True, collate_fn=train_data.collate_fn, num_workers=5)
    # Element-wise loss so padded positions can be masked out before averaging.
    loss_fn = nn.BCELoss(reduction='none')
    optimizer = get_std_opt(len(train_list),nn_config['batch_size'], model.parameters(), nn_config['hidden_dim'], nn_config['lr'])
    v_max_aupr = 0
    patience = 0
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model,device_ids=nn_config['device_ids'])
    if valid_list is not None:
        # torch.save does not create missing directories.
        os.makedirs(f'{root_path}/Output/{dataset}', exist_ok=True)
    for epoch in range(epochs):
        all_loss = 0
        all_cnt = 0
        model.train()
        for rfeat, ligand, xyz,  mask, y_true in tqdm(train_loader):
            tensors = [rfeat, ligand, xyz,  mask, y_true]
            tensors = [tensor.to(DEVICE) for tensor in tensors]
            rfeat, ligand, xyz, mask, y_true = tensors
            optimizer.zero_grad()
            logits = model(rfeat, ligand, xyz, mask).sigmoid() # [N]
            # Compute the loss for all ligands, zeroing out masked positions
            loss = loss_fn(logits, y_true) * mask
            loss = loss.sum() / mask.sum()
            all_loss += loss.item()
            all_cnt += 1
            loss.backward()
            optimizer.step()
        # Guard against an empty loader instead of dividing by zero.
        mean_loss = all_loss / all_cnt if all_cnt else float('nan')
        # Early stopping based on the validation AUPR
        if valid_list is not None:
            # valid() takes (model, valid_list); passing fold_idx positionally
            # as a third argument raised a TypeError.
            v_aupr = valid(model,valid_list)
            print(f'Epoch {epoch} Loss: {mean_loss}', f'Epoch Valid {epoch} AUPR: {v_aupr}')
            if v_aupr > v_max_aupr:
                v_max_aupr = v_aupr
                patience = 0
                torch.save(model.state_dict(), f'{root_path}/Output/{dataset}/fold{fold_idx}.ckpt')
            else:
                patience += 1
            if patience >= nn_config['max_patience']:
                break


In [None]:
setALlSeed(11)
# 5-fold cross-validation: train one model per fold, early-stopped on the
# held-out fold.
kf = KFold(n_splits=5, shuffle=True, random_state = 42)
# NOTE(review): this reads label/train/train.fa, while nn_config['train_file']
# (used as the label file inside train()/valid()) points to label/training.fa.
# Confirm both files exist and describe the same proteins.
data_list = readDataList(f'{root_path}/{dataset}/label/train/train.fa',skew=1)
# train() writes its checkpoints to Output/{dataset}/, so that directory must
# exist too. The later threshold/test cells expect the 5-fold checkpoints
# under Output/{dataset}_5fold/.
ckpt_dir = f'{root_path}/Output/{dataset}/'
cv_dir = f'{root_path}/Output/{dataset}_5fold/'
makeDir(ckpt_dir)
makeDir(cv_dir)
for fold_idx, (train_idx, valid_idx) in enumerate(kf.split(data_list)):
    train_list = [data_list[i] for i in train_idx]
    valid_list = [data_list[j] for j in valid_idx]
    model = LABind(
    rfeat_dim=nn_config['rfeat_dim'], ligand_dim=nn_config['ligand_dim'], hidden_dim=nn_config['hidden_dim'], heads=nn_config['heads'], augment_eps=nn_config['augment_eps'], 
    rbf_num=nn_config['rbf_num'],top_k=nn_config['top_k'], attn_drop=nn_config['attn_drop'], dropout=nn_config['dropout'], num_layers=nn_config['num_layers'])
    train(train_list,valid_list,model,epochs=70,fold_idx=fold_idx)
    # Keep a copy of this fold's best checkpoint in the 5-fold directory so a
    # later single-split run (which also writes fold0.ckpt) cannot overwrite it.
    fold_ckpt = f'{ckpt_dir}fold{fold_idx}.ckpt'
    if os.path.exists(fold_ckpt):
        shutil.copy2(fold_ckpt, cv_dir)

In [None]:
setALlSeed(11)
# Single train/validation split: fit on training.fa and early-stop on picking.fa.
train_list = readDataList(f'{root_path}/{dataset}/label/training.fa',skew=1)
valid_list = readDataList(f'{root_path}/{dataset}/label/picking.fa',skew=1)
makeDir(f'{root_path}/Output/{dataset}/')

# LABind constructor arguments are taken one-to-one from nn_config.
model_keys = (
    'rfeat_dim', 'ligand_dim', 'hidden_dim', 'heads', 'augment_eps',
    'rbf_num', 'top_k', 'attn_drop', 'dropout', 'num_layers',
)
model = LABind(**{key: nn_config[key] for key in model_keys})
train(train_list, valid_list, model, epochs=70, fold_idx=0)

In [None]:
# Determine the best threshold for MCC based on the validation set.
# Loads the trained checkpoint(s), averages their sigmoid outputs on the
# validation set and appends the selected threshold to Best_Threshold.txt.
model_path = f'{root_path}/Output/{dataset}/' # if 5-fold cross-validation, {dataset}_5fold
print(model_path)
print(nn_config['pdb_class'])

models = []
for fold in range(1): # if 5-fold cross-validation, set to 5
    state_dict = torch.load(model_path + 'fold%s.ckpt'%fold, map_location=DEVICE)
    # Checkpoints saved from an nn.DataParallel wrapper prefix every key with
    # 'module.'; checkpoints saved from a bare model (single GPU) do not.
    # Strip the prefix so both kinds load into the bare model.
    if state_dict and all(key.startswith('module.') for key in state_dict):
        state_dict = {key[len('module.'):]: value for key, value in state_dict.items()}
    # DEVICE is the constant defined in the config cell (`device` is not defined in this notebook).
    model = LABind(
        rfeat_dim=nn_config['rfeat_dim'], ligand_dim=nn_config['ligand_dim'], hidden_dim=nn_config['hidden_dim'], heads=nn_config['heads'], augment_eps=nn_config['augment_eps'], 
        rbf_num=nn_config['rbf_num'],top_k=nn_config['top_k'], attn_drop=nn_config['attn_drop'], dropout=nn_config['dropout'], num_layers=nn_config['num_layers']).to(DEVICE)
    model.load_state_dict(state_dict)
    model = nn.DataParallel(model,device_ids=nn_config['device_ids'])
    model.eval()
    models.append(model)
    
valid_list = readDataList(f'{root_path}/{dataset}/label/picking.fa',skew=1)
valid_data = readData(
    name_list=valid_list, 
    proj_dir=nn_config['proj_dir'], 
    lig_dict=nn_config['lig_dict'],
    true_file=f'{root_path}/{dataset}/label/picking.fa')
valid_loader = DataLoader(valid_data, batch_size=nn_config['batch_size'], collate_fn=valid_data.collate_fn)
# Print the dataset length
print(f'valid data length: {len(valid_data)}')
all_y_score = []
all_y_true = []
with torch.no_grad():
    for rfeat, ligand, xyz,  mask, y_true in valid_loader:
        tensors = [rfeat, ligand, xyz,  mask, y_true]
        tensors = [tensor.to(DEVICE) for tensor in tensors]
        rfeat, ligand, xyz, mask, y_true = tensors
        
        # Ensemble: average the sigmoid outputs of all loaded models.
        logits = [m(rfeat, ligand, xyz, mask).sigmoid() for m in models]
        logits = torch.stack(logits,0).mean(0)
        
        # Keep only the positions selected by the mask.
        logits = torch.masked_select(logits, mask==1)
        y_true = torch.masked_select(y_true, mask==1)
        all_y_score.extend(logits.cpu().detach().numpy())
        all_y_true.extend(y_true.cpu().detach().numpy())

best_threshold,best_mcc,best_pred = getBestThreshold(all_y_true, all_y_score)
appendText(f'{model_path}/Best_Threshold.txt',f'{best_threshold} {best_mcc}\n')

### Test

In [None]:
# Evaluate the trained model(s) on every per-ligand test file and collect the
# metrics into test.csv.
import pandas as pd
model_path = f'{root_path}/Output/{dataset}/' # if 5-fold cross-validation, {dataset}_5fold
print(model_path)
print(nn_config['pdb_class'])

models = []
for fold in range(1): # if 5-fold cross-validation, set to 5
    state_dict = torch.load(model_path + 'fold%s.ckpt'%fold, map_location=DEVICE)
    # Checkpoints saved from an nn.DataParallel wrapper prefix every key with
    # 'module.'; checkpoints saved from a bare model (single GPU) do not.
    # Strip the prefix so both kinds load into the bare model.
    if state_dict and all(key.startswith('module.') for key in state_dict):
        state_dict = {key[len('module.'):]: value for key, value in state_dict.items()}
    # DEVICE is the constant defined in the config cell (`device` is not defined in this notebook).
    model = LABind(
        rfeat_dim=nn_config['rfeat_dim'], ligand_dim=nn_config['ligand_dim'], hidden_dim=nn_config['hidden_dim'], heads=nn_config['heads'], augment_eps=nn_config['augment_eps'], 
        rbf_num=nn_config['rbf_num'],top_k=nn_config['top_k'], attn_drop=nn_config['attn_drop'], dropout=nn_config['dropout'], num_layers=nn_config['num_layers']).to(DEVICE)
    model.load_state_dict(state_dict)
    model = nn.DataParallel(model,device_ids=nn_config['device_ids'])
    model.eval()
    models.append(model)

# One single-row frame per ligand; concatenated once after the loop instead of
# growing the DataFrame on every iteration.
result_frames = []
# Sorted so the row order of test.csv is reproducible across file systems.
for ionic in sorted(os.listdir(f'{root_path}/{dataset}/label/test/')):
    ionic = ionic.split('.')[0]
    test_list = readDataList(f'{root_path}/{dataset}/label/test/{ionic}.fa',skew=1)
    test_data = readData(
        name_list=test_list, 
        proj_dir=nn_config['proj_dir'], 
        lig_dict=nn_config['lig_dict'],
        true_file=f'{root_path}/{dataset}/label/test/{ionic}.fa')
    test_loader = DataLoader(test_data, batch_size=nn_config['batch_size'], collate_fn=test_data.collate_fn)
    # Print the dataset length
    print(f'{ionic} test data length: {len(test_data)}')
    all_y_score = []
    all_y_true = []
    with torch.no_grad():
        for rfeat, ligand, xyz,  mask, y_true in test_loader:
            tensors = [rfeat, ligand, xyz,  mask, y_true]
            tensors = [tensor.to(DEVICE) for tensor in tensors]
            rfeat, ligand, xyz, mask, y_true = tensors
            
            # Ensemble: average the sigmoid outputs of all loaded models.
            logits = [m(rfeat, ligand, xyz, mask).sigmoid() for m in models]
            logits = torch.stack(logits,0).mean(0)
            
            # Keep only the positions selected by the mask.
            logits = torch.masked_select(logits, mask==1)
            y_true = torch.masked_select(y_true, mask==1)
            all_y_score.extend(logits.cpu().detach().numpy())
            all_y_true.extend(y_true.cpu().detach().numpy())
    # best_threshold comes from the threshold-selection cell; run that cell
    # first or set the value manually.
    data_dict = calEval(all_y_true, all_y_score, best_th = best_threshold) # please set best threshold.
    data_dict['ligand'] = ionic
    result_frames.append(pd.DataFrame(data_dict,index=[0]))
# The leading empty frame fixes the column order of the report.
df = pd.concat(
    [pd.DataFrame(columns=['ligand','Rec','SPE','Acc','Pre','F1','MCC','AUC','AUPR']), *result_frames],
    ignore_index=True)
df.to_csv(f'{model_path}test.csv',index=False)
df