In [4]:
# import mdtraj
import tqdm
import glob
import os
import numpy as np
import itertools
from moleculekit.molecule import Molecule

**Requirements:**

`conda install transformers sentencepiece`

In [5]:
from transformers import T5EncoderModel, T5Tokenizer
import torch
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print("Using {}".format(device))

Using cuda:0


In [6]:
def get_T5_model():
    model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc")
    model = model.to(device) # move model to GPU
    model = model.eval() # set model to evaluation model
    tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)

    return model, tokenizer

model, tokenizer = get_T5_model()

In [10]:
def gen_T5_embeddings(molecule, model, tokenizer):
    # Generate the FASTA sequence for each chain (assumes 1 bead per residue)
    fasta = "".join([i[-1] for i in molecule.atomtype])
    fasta_list = []
    for i in [len(list(i[1])) for i in itertools.groupby(molecule.segid)]:
        fasta_list.append(fasta[:i])
        fasta = fasta[i:]

    # print("Generating ProtT5 embeddings for the sequences:", " ".join(fasta_list))

    embedding_list = []
    for fasta in fasta_list:
        fasta = " ".join(fasta)

        token_encoding = tokenizer(fasta, add_special_tokens=True)
        input_ids      = torch.tensor(token_encoding['input_ids']).to(device)
        attention_mask = torch.tensor(token_encoding['attention_mask']).to(device)

        with torch.no_grad():
            # The model expects a input of shape [batch, max_len]
            embedding_repr = model(input_ids[None,:], attention_mask=attention_mask[None,:])
            # Output has shape [batch, max_len, embedding_len]
            # We also need to trim off the termination token, in the original script this was done with [:s_len] to
            # also trim off the batch padding
            embedding_list.append(embedding_repr.last_hidden_state[0][:-1])

    embedding = torch.cat(embedding_list).cpu().numpy()
    return embedding

In [11]:
out_dir = "/home/argon/Stuff/prot_trans/cg_raz081724_CA_lj_angleXCX_dihedralX_V1_opt"
for i in tqdm.tqdm(glob.glob(os.path.join(out_dir, "*"))):
    pdbid = os.path.basename(i)
    if not os.path.exists(os.path.join(out_dir, f"{pdbid}/raw/")):
        tqdm.tqdm.write(f"<skip> {pdbid}")
        continue
    # This assumes 1 the mapping uses bead per residue
    molecule = Molecule(os.path.join(out_dir, f"{pdbid}/processed/{pdbid}_processed.psf"))
    embedding = gen_T5_embeddings(molecule, model, tokenizer)

    outpath = os.path.join(out_dir, f"{pdbid}/raw/protT5_embedding.npy")
    np.save(outpath, embedding)

 18%|█▊        | 182/987 [00:03<00:17, 46.40it/s]

<skip> result


 34%|███▎      | 331/987 [00:06<00:12, 54.50it/s]

<skip> priors.yaml


 53%|█████▎    | 522/987 [00:10<00:09, 47.00it/s]

<skip> pdb_list.pkl


 55%|█████▌    | 543/987 [00:11<00:09, 45.48it/s]

<skip> prior_fit_plots


 61%|██████    | 602/987 [00:12<00:07, 49.74it/s]

<skip> prior_builder.pkl


 67%|██████▋   | 661/987 [00:13<00:06, 51.96it/s]

<skip> prior_params.json


100%|██████████| 987/987 [00:20<00:00, 48.29it/s]
