### Enter your fasta file path and prediction folder path here

In [1]:
prediction_folder = "" #TODO
query_fasta = "" #TODO
embedding_batchsize = 30 #TODO. Decrease in case of CUDA-Out-Of-Memory errors<

In [2]:
from config import FileSetter, FileManager
from bindNode23.bindNode23 import BindNode23
from pathlib import Path
import gc
from transformers import T5Tokenizer, T5Model, T5EncoderModel
import re
import torch
import h5py
from tqdm import tqdm
import os
import config
import bindNode23.bindNode23
from config import FileSetter, FileManager

  from .autonotebook import tqdm as notebook_tqdm


##### This code box sets up the protein language model ProtT5. Could trigger CUDA-errors on machines with less than 4GB (NVIDIA) GPU VRAM

In [3]:
tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)
model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc")
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model = model.eval()
gc.collect()
torch.cuda.empty_cache()

In [6]:
# parse fasta file, space and substitute unknown AAs, tokenize
query_sequences = FileManager.read_fasta(query_fasta)
sequences = query_sequences.values()
headers = query_sequences.keys()
sequences = [" ".join(sequence) for sequence in sequences]
sequences = [re.sub(r"[UZOB]", "X", sequence) for sequence in sequences]
ids = tokenizer.batch_encode_plus(sequences, add_special_tokens=True, padding="longest")

##### Embed the sequences in your fasta file. If you encouter CUDA-Out-Of-Memory warnings, consider using a lower batchsize in top cell.

In [5]:
bsize = embedding_batchsize
maxsize = len(sequences)
start = 0
end = bsize
features = []

for i in tqdm(range(bsize, maxsize, bsize)):
    start = i - bsize
    end = i
    batch_ids = ids.input_ids[start:end]
    batch_masks = ids.attention_mask[start:end]
    input_ids = torch.tensor(batch_ids).to(device)
    attention_mask = torch.tensor(batch_masks).to(device)
    with torch.no_grad():
      embedding = model(input_ids=input_ids,attention_mask=attention_mask)
    embedding = embedding.last_hidden_state.cpu().numpy()
    for e in embedding:
      features.append(e)
    torch.cuda.empty_cache()
    
# Add last batch
batch_ids = ids.input_ids[end:]
batch_masks = ids.attention_mask[end:]
input_ids = torch.tensor(batch_ids).to(device)
attention_mask = torch.tensor(batch_masks).to(device)
with torch.no_grad():
  embedding = model(input_ids=input_ids,attention_mask=attention_mask)
embedding = embedding.last_hidden_state.cpu().numpy()
for e in embedding:
  features.append(e)

# sanity check for the created data
assert len(features)==len(sequences)

print("Making file...")
h5f = h5py.File(os.path.join(prediction_folder, 'embeddings_halfprec_inference.h5'), 'w')
for i,header in enumerate(headers):
  h5f.create_dataset(header, data=features[i][:len(query_sequences[header])])
h5f.close()
print("Finished")

100%|█████████████████████████████████████████████| 9/9 [00:35<00:00,  3.93s/it]


Making file...
Finished


##### now, there is an .h5 file created in your prediction directory. We will use it to predict your proteins.

In [7]:
model_prefix="trained_models/trained_model"
cutoff=0.5
ri=False
result_folder=prediction_folder
ids = list(headers)
fasta_file=query_fasta

proteins = BindNode23.GCN_prediction_pipeline(
   os.path.join(prediction_folder, 'embeddings_halfprec_inference.h5'), model_prefix, cutoff, result_folder, ids, fasta_file, ri)

Prepare data
Load model
Calculate predictions
Load model
Calculate predictions
Load model
Calculate predictions
Load model
Calculate predictions
Load model
Calculate predictions


#### For each protein, you will yield an output file (.bindPredict_out) with the following content:
##### Binding prediction (b/nb), respective probability for each class

In [54]:
allpredictionfiles = list(Path(prediction_folder).rglob("*.bindPredict_out"))
with open(allpredictionfiles[0], 'r') as displayFile:
    results = displayFile.readlines()
    
print("Protein name:\n")
print(allpredictionfiles[0].name.split(".")[0], '\n')
print(results[0])
for k in range(1,15):
    print(results[k].replace('\t', '\t\t'))

Protein name:

P32643 

Position	Metal.Proba	Metal.Class	Nuclear.Proba	Nuclear.Class	Small.Proba	Small.Class	Any.Class

1		0.002		nb		0.004		nb		0.077		nb		nb

2		0.009		nb		0.012		nb		0.354		nb		nb

3		0.029		nb		0.037		nb		0.328		nb		nb

4		0.021		nb		0.014		nb		0.428		nb		nb

5		0.005		nb		0.010		nb		0.206		nb		nb

6		0.009		nb		0.016		nb		0.277		nb		nb

7		0.011		nb		0.032		nb		0.172		nb		nb

8		0.100		nb		0.006		nb		0.332		nb		nb

9		0.046		nb		0.009		nb		0.566		b		b

10		0.038		nb		0.003		nb		0.223		nb		nb

11		0.047		nb		0.009		nb		0.280		nb		nb

12		0.048		nb		0.001		nb		0.112		nb		nb

13		0.022		nb		0.014		nb		0.286		nb		nb

14		0.124		nb		0.023		nb		0.633		b		b

