In [None]:
#@title Install dependencies
!pip3 install torch torchvision torchaudio transformers sentencepiece accelerate --extra-index-url https://download.pytorch.org/whl/cu116

In [None]:
#@title Import dependencies
# Tokenizer and encoder-only T5 model classes (used for ProtT5), plus torch and re; this cell also selects the compute device.
from transformers import T5Tokenizer, T5EncoderModel
import torch
import re
# Run on the first CUDA GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
  device = torch.device('cuda:0')
else:
  device = torch.device('cpu')
print("Using device: {}".format(device))

In [None]:
#@title Load encoder-part of ProtT5 in half-precision
# Load the encoder-only ProtT5-XL-U50 checkpoint and its tokenizer, then move
# the model to the device selected earlier and switch it to inference mode.
transformer_link = "Rostlab/prot_t5_xl_half_uniref50-enc"
print("Loading: {}".format(transformer_link))
model = T5EncoderModel.from_pretrained(transformer_link)
# Keep half precision on GPU only; cast to float32 when running on CPU.
# `device` is a torch.device, so test its `.type` — comparing the object itself
# to the string 'cpu' does not compare equal. The full-precision cast is
# `.float()` (nn.Module has no `.full()` method).
model = model.float() if device.type == 'cpu' else model.half()
model = model.to(device)
model = model.eval()  # eval mode (disables training-only behaviour such as dropout)
# do_lower_case=False: amino-acid letters are kept in upper case.
tokenizer = T5Tokenizer.from_pretrained(transformer_link, do_lower_case=False)

In [None]:
#@title Set paths
# All input/output files live in one directory; change DATA_DIR to relocate them.
# The resulting path strings are identical to the previously hard-coded ones.
DATA_DIR = '/sample_data'
# Protein sequences in FASTA format (input).
SEQUENCE_PATH = DATA_DIR + '/362663.protein.sequences.v11.5.fa'
# Space-separated protein-protein links table (input).
LINKS_PATH = DATA_DIR + '/362663.protein.links.v11.5.txt'
# Where the per-protein embedding tensor is saved and later reloaded.
EMBEDDING_PATH = DATA_DIR + '/embedding.pt'

In [None]:
#@title Load sequence data and protein names
# Read the FASTA file and split it on the '>' record marker. The context
# manager closes the file handle once the text has been read.
with open(SEQUENCE_PATH) as f:
  sequence_examples = f.read().split('>')

# For every record, the first line is the header (stored in sequence_names)
# and the remaining lines are concatenated into a single sequence string.
# NOTE(review): index 0 (the text before the first '>', normally an empty
# string) is intentionally left in place, so sequence_examples has one more
# entry than sequence_names and sequence_names[k] corresponds to
# sequence_examples[k + 1]. Confirm downstream code accounts for this offset.
sequence_names = []
for i in range(1, len(sequence_examples)):
  header, *residue_lines = sequence_examples[i].split("\n")
  sequence_names.append(header)
  sequence_examples[i] = ''.join(residue_lines)

In [None]:
#@title Generate input ids and attention mask
# Replace the rare/ambiguous residues U, Z, O and B with X, then put a single
# space between consecutive residues.
# Note: this cell overwrites sequence_examples, so re-running it without
# re-running the loading cell would space out the already-spaced strings.
spaced_sequences = []
for raw_sequence in sequence_examples:
  cleaned_sequence = re.sub(r"[UZOB]", "X", raw_sequence)
  spaced_sequences.append(" ".join(cleaned_sequence))
sequence_examples = spaced_sequences

# Tokenize every sequence in one call, padding each to the longest one.
ids = tokenizer.batch_encode_plus(
  sequence_examples,
  add_special_tokens=True,
  padding="longest",
)

# Turn the token ids and the padding mask into tensors on the compute device.
input_ids = torch.tensor(ids['input_ids'])
input_ids = input_ids.to(device)
attention_mask = torch.tensor(ids['attention_mask'])
attention_mask = attention_mask.to(device)

In [None]:
#@title Run ProtT5 on input_ids and generate hidden layer
# Compute one embedding row per tokenized sequence by running the encoder on
# one sequence at a time and averaging the last hidden state over tokens.
N = len(input_ids)
# One 1024-wide row per sequence, kept on the CPU in float32.
# Presumably 1024 matches the encoder's hidden size — TODO confirm.
Z = torch.zeros(N, 1024)
# Progress counter. It must be initialised here: without it, `c += 1` raised a
# NameError on the first iteration, which the old bare `except:` reported as a
# crash at i = 0, so only the first row of Z was ever filled.
c = 0
for i in range(N):
  # Only the forward pass is guarded, so a failure (e.g. out of memory on a
  # very long sequence) is reported with its cause and the rows computed so
  # far are kept.
  try:
    with torch.no_grad():
      hidden = model(input_ids=input_ids[i:(i+1)], attention_mask=attention_mask[i:(i+1)]).last_hidden_state
      # NOTE(review): only the first 7 token positions are averaged, regardless
      # of the actual sequence length. This looks like it was carried over from
      # a 7-residue example — confirm whether pooling over the whole
      # (unpadded) sequence is intended.
      Z[i] = hidden[:,:7].mean(dim=1)
  except Exception as err:
    print("Crashed at i = ", i, "-", repr(err))
    break
  # Print the percentage done roughly every N/500 sequences.
  c += 1
  if c > N/500:
    print(i/N*100)
    c = 0

In [None]:
#@title Save node embeddings
# Serialize the embedding tensor Z to EMBEDDING_PATH so it can be reloaded
# later without recomputing it.
torch.save(obj=Z, f=EMBEDDING_PATH)

In [None]:
#@title Load node embeddings
# Read the previously saved embedding tensor back from EMBEDDING_PATH into Z.
Z = torch.load(f=EMBEDDING_PATH)

In [None]:
#@title Load protein links
# Read the links table: discard the first line (presumably a column header —
# verify against the file) and keep the first two space-separated fields of
# every remaining line as one edge. The context manager closes the file handle
# when done, and lines are streamed instead of being read into a list first.
with open(LINKS_PATH) as f:
  f.readline()
  edges = [line.split(' ')[0:2] for line in f]