<a href="https://colab.research.google.com/github/breimanntools/aaanalysis/blob/master/protein_embeddings_ProtT5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Extensive Tutorial: https://colab.research.google.com/drive/1TUj-ayG3WO52n5N50S7KH9vtt6zRkdmj?usp=sharing#scrollTo=ET2v51slC5ui


In [8]:
#@title Install requirements
!pip install sentencepiece
!pip install h5py



In [9]:
# Mount Google Drive into the Colab filesystem at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [10]:
#@title Load Dependencies
import torch
import h5py
from pathlib import Path
from transformers import T5EncoderModel, T5Tokenizer
from tqdm import tqdm

In [11]:
def fasta_parser(fasta_path):
    """Lazily parse a FASTA file into ``(header, sequence)`` pairs.

    Parameters
    ----------
    fasta_path : str or path-like
        Path of the FASTA file to read (opened in text mode).

    Yields
    ------
    tuple of (str, str)
        The header line without its leading ``>`` (the full line is kept, not
        just the first word) and the upper-cased sequence with all of its
        lines concatenated.

    Notes
    -----
    * Blank lines are ignored.
    * A header that is not followed by any sequence line is skipped, both in
      the middle and at the end of the file; an empty file yields nothing.
    * Sequence lines appearing before the first header are yielded under the
      empty header ``""``.
    """
    uniprot_id, record = "", []
    with open(fasta_path, "r") as fasta_f:
        for line in fasta_f:
            line = line.rstrip()
            if not line:
                continue
            if line.startswith(">"):
                # A new header closes the previous record, if it had any sequence.
                if record:
                    yield uniprot_id, "".join(record)
                uniprot_id = line[1:]
                record = []
            else:
                record.append(line.upper())
    # Flush the last record only when it actually holds sequence data, so that
    # empty files and trailing headers do not produce empty sequences.
    if record:
        yield uniprot_id, "".join(record)

In [12]:
#@title Set variables
# Input FASTA file and output HDF5 file (paths are relative to the working directory).
fasta_file = "seq10.fasta"
embedding_out_file = "seq10.h5"
# Hugging Face identifier of the ProtT5 encoder checkpoint to load.
prott5_model = "Rostlab/prot_t5_xl_half_uniref50-enc"
# Always a torch.device (first GPU if available, otherwise CPU), so the
# variable has the same type regardless of the hardware.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [13]:
#@title Read Data in
# Stream (header, sequence) records from the FASTA file.
fasta_generator = fasta_parser(fasta_file)
# Map the residue letters U, Z and O to X before embedding.
translation_table = str.maketrans("UZO", "XXX")
seq_dict = dict(
    (uniprot_id, seq.translate(translation_table))
    for uniprot_id, seq in fasta_generator
)
nr_seqs = len(seq_dict)
# Longest sequences first (original rationale: less padding, faster embedding).
seqs_sorted = sorted(
    seq_dict.items(),
    key=lambda item: len(item[1]),
    reverse=True,
)

In [14]:
#@title load model
# Fetch the pretrained encoder, move it to the selected device and switch it
# to evaluation mode in a single chained expression.
model = T5EncoderModel.from_pretrained(prott5_model).to(device).eval()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/656 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/2.42G [00:00<?, ?B/s]

In [15]:
# Load the tokenizer that ships with the selected checkpoint; lower-casing is
# disabled and the legacy tokenizer behaviour is requested explicitly.
tokenizer = T5Tokenizer.from_pretrained(
    prott5_model,
    do_lower_case=False,
    legacy=True,
)

tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/238k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/1.79k [00:00<?, ?B/s]

In [16]:
#@title Generate embeddings
# Embed one sequence at a time and store one mean-pooled vector per sequence
# in the HDF5 file, keyed by the FASTA header.
# NOTE: the file is opened with mode "w", so an existing output file is overwritten.
with h5py.File(embedding_out_file, "w") as hdf:
    for identifier, seq in tqdm(seqs_sorted, desc="Embedding"):
        # Separate the residues with spaces before tokenizing.
        seq = ' '.join(list(seq))
        # Calling the tokenizer directly replaces the deprecated encode_plus().
        ids = tokenizer(seq, add_special_tokens=True)
        tokenized_sequences = torch.tensor([ids["input_ids"]]).to(model.device)
        attention_mask = torch.tensor([ids["attention_mask"]]).to(model.device)
        try:
            with torch.no_grad():
                embbeddings = model(input_ids=tokenized_sequences, attention_mask=attention_mask)
            # Drop the last position (presumably the end-of-sequence token added
            # by add_special_tokens=True) and average over the remaining positions.
            emb = embbeddings.last_hidden_state[:, :-1]
            emb = emb.mean(dim=1).detach().cpu().numpy().squeeze()
            hdf.create_dataset(name=identifier, data=emb)
        except Exception as err:
            # Stay best-effort (one failing sequence must not abort a long run),
            # but report what was skipped instead of silently dropping it.
            # KeyboardInterrupt is deliberately not caught, so the run can be stopped.
            tqdm.write(f"Skipping {identifier!r}: {type(err).__name__}: {err}")


Embedding: 100%|██████████| 4480/4480 [52:03<00:00,  1.43it/s]


In [17]:
# Space-separate the residues of every sequence, keeping the length-sorted order.
batch = [" ".join(seq) for _, seq in seqs_sorted]
# Tokenize all sequences in one call, padding each to the longest one.
ids = tokenizer.batch_encode_plus(batch, padding="longest", add_special_tokens=True)

In [18]:
# Show which fields the model output holds; `embbeddings` is the output left
# over from the most recent successful forward pass of the embedding loop.
embbeddings.keys()

odict_keys(['last_hidden_state'])