<a href="https://colab.research.google.com/github/mheinzinger/ProstT5/blob/main/notebooks/ProstT5_inverseFolding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [10]:
#@title Install dependencies. { display-mode: "form" }
%%capture
!pip -q install sentencepiece
!pip -q install git+https://github.com/huggingface/transformers.git # had a -U before
!pip -q install git+https://github.com/huggingface/peft.git
!pip -q install git+https://github.com/huggingface/accelerate.git
!pip -q install omegaconf pytorch_lightning biopython ml_collections einops py3Dmol
!wget -q -nc https://mmseqs.com/foldseek/foldseek-linux-avx2.tar.gz; tar xzf foldseek-linux-avx2.tar.gz; export PATH=$(pwd)/foldseek/bin/:$PATH

In [3]:
#@title Import dependencies. { display-mode: "form" }
from transformers import T5Tokenizer, AutoModelForSeq2SeqLM
import torch
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
if not torch.cuda.is_available():
  print("Warning: You are running this notebook without GPU which will be slow.")

In [11]:
#@title Load ProstT5. { display-mode: "form" }
# Load the tokenizer
tokenizer = T5Tokenizer.from_pretrained('Rostlab/ProstT5_fp16', do_lower_case=False, legacy=True)

# Load the model
model = AutoModelForSeq2SeqLM.from_pretrained("Rostlab/ProstT5_fp16",low_cpu_mem_usage=True, device_map="auto",torch_dtype=torch.float16)
model = model.eval()
model = model.half()

In [5]:
#@title Example: how to get 3Di from PDB/AFDB
%%capture
query_ID = 'A0A6G0XC32' #@param {type:"string"}
!mkdir -p $query_ID
!wget -q -O $query_ID/AF-$query_ID-F1-model_v4.pdb https://alphafold.ebi.ac.uk/files/AF-$query_ID-F1-model_v4.pdb
!foldseek/bin/foldseek createdb $query_ID/ $query_ID/queryDB
!foldseek/bin/foldseek lndb $query_ID/queryDB_h $query_ID/queryDB_ss_h
!foldseek/bin/foldseek convert2fasta $query_ID/queryDB_ss $query_ID/queryDB_ss.fasta

In [6]:
#@title Read in file in FASTA format. { display-mode: "form" }
def read_fasta( in_path, is_3Di ):
    '''
        Reads in fasta file containing a single or multiple sequences.
        Returns dictionary.
    '''

    sequences = dict()
    with open( in_path, 'r' ) as fasta_f:
        for line in fasta_f:
            # get uniprot ID from header and create new entry
            if line.startswith('>'):
                uniprot_id = line.split(" ")[0].replace('>', '').replace(".pdb","").strip()
                sequences[ uniprot_id ] = ''
            else:
                # repl. all whie-space chars and join seqs spanning multiple lines
                if is_3Di:
                    sequences[ uniprot_id ] += ''.join( line.split() ).replace("-","").lower() # drop gaps and cast to lower-case
                else:
                    sequences[ uniprot_id ] += ''.join( line.split() ).replace("-","")

    example = sequences[uniprot_id]

    print("##########################")
    print(f"Input is 3Di: {is_3Di}")
    print(f"Example sequence: >{uniprot_id}\n{example}")
    print("##########################")

    return sequences

In [12]:
#@title Generate new sequence and predict 3D structure using ESMFold API. { display-mode: "form" }
import os
import locale
import time
import requests
locale.getpreferredencoding = lambda: "UTF-8"
!mkdir -p $query_ID/gen_seqs

seq_dict = read_fasta(f"{query_ID}/queryDB_ss.fasta",is_3Di=True)

gen_kwargs =  {
            "do_sample": True,
            "top_p" : 0.85,
            "temperature" : 1.0,
            "top_k" : 3,
            "repetition_penalty" : 1.2,
            }

generated_sequences=dict()
for seq_idx, (fasta_id, seq) in enumerate(seq_dict.items(),1): # for each sequence in the FASTA file
    seq_len = len(seq)
    seq = seq.replace('U','X').replace('Z','X').replace('O','X').replace("B","X")
    seq = " ".join(list(seq))

    max_len=seq_len+1
    min_len=seq_len+1

    # starting point tokens
    start_encoding = tokenizer.batch_encode_plus( ["<fold2AA>" + " " + seq],
                                       add_special_tokens=True,
                                       return_tensors='pt'
                                       ).to(device)
    print(seq)
    with torch.no_grad():
      # forward translation tokens
      target = model.generate(
                            start_encoding.input_ids,
                            attention_mask=start_encoding.attention_mask,
                            max_length=max_len, # max length of generated text
                            min_length=min_len, # minimum length of the generated text
                            length_penalty=1.0, # import for correct normalization of scores
                            num_return_sequences=10, # return only a single sequence
                            **gen_kwargs
                            )
    t_strings = tokenizer.batch_decode( target, skip_special_tokens=True )
    for gen_seq_idx, t in enumerate(t_strings,0):
      time.sleep(5)
      gen_seq = "".join( t.split(" "))
      gen_seq_id = fasta_id + f"_{gen_seq_idx}"

      esmfold_api_url = 'https://api.esmatlas.com/foldSequence/v1/pdb/'
      r = requests.post(esmfold_api_url, data=gen_seq, timeout=10)
      while not r.status_code == 200:
        print("Internal Server error of ESMFold API. Sleeping 6s and then trying again.")
        time.sleep(6)
        r = requests.post(esmfold_api_url, data=gen_seq, timeout=60)

      structure = r.text
      with open(f"{query_ID}/gen_seqs/{gen_seq_id}.pdb","w") as out_f:
        out_f.write(structure)
      print("Success")

##########################
Input is 3Di: True
Example sequence: >AF-A0A6G0XC32-F1-model_v4
ddfdaedepacccpdqedadagarhqyyhhanvqnhaadaynyqnhqyyepenhanhqlvrhlvnhdhqnhayyehalhhhpdddplvsllvnlqshqnhqhyehalpqaelssvvsclpsnpnhqeyeyedapprpphhhhddpvsvvvscvvvvshdydyd
##########################
d d f d a e d e p a c c c p d q e d a d a g a r h q y y h h a n v q n h a a d a y n y q n h q y y e p e n h a n h q l v r h l v n h d h q n h a y y e h a l h h h p d d d p l v s l l v n l q s h q n h q h y e h a l p q a e l s s v v s c l p s n p n h q e y e y e d a p p r p p h h h h d d p v s v v v s c v v v v s h d y d y d
Success
Internal Server error of ESMFold API. Sleeping 6s and then trying again.
Internal Server error of ESMFold API. Sleeping 6s and then trying again.
Internal Server error of ESMFold API. Sleeping 6s and then trying again.
Internal Server error of ESMFold API. Sleeping 6s and then trying again.
Internal Server error of ESMFold API. Sleeping 6s and then trying again.
Internal Server err

In [13]:
#@title Compute RMSD between generated sequences (ESMFold) and groundtruth (AFDB). { display-mode: "form" }
# https://colab.research.google.com/github/pb3lab/ibm3202/blob/master/tutorials/lab02_molviz.ipynb
import Bio.PDB
from pathlib import Path
# Start the parser
pdb_parser = Bio.PDB.PDBParser(QUIET = True)

# Get the structures
ref_structure = pdb_parser.get_structure("reference", f"{query_ID}/AF-{query_ID}-F1-model_v4.pdb")
best_RMSD = None
for p in Path(f"{query_ID}/gen_seqs").rglob("*.pdb"):
  sample_structure = pdb_parser.get_structure("sample", p)

  # Use the first model in the pdb-files for alignment
  # Change the number 0 if you want to align to another structure
  ref_model    = ref_structure[0]
  sample_model = sample_structure[0]

  # Make a list of the atoms (in the structures) you wish to align.
  # In this case we use CA atoms whose index is in the specified range
  ref_atoms = []
  sample_atoms = []

  # Iterate of all chains in the model in order to find all residues
  for ref_chain in ref_model:
    # Iterate of all residues in each model in order to find proper atoms
    for ref_res in ref_chain:
        # Append CA atom to list
        ref_atoms.append(ref_res['CA'])

  # Do the same for the sample structure
  for sample_chain in sample_model:
    for sample_res in sample_chain:
        sample_atoms.append(sample_res['CA'])

  # Now we initiate the superimposer:
  super_imposer = Bio.PDB.Superimposer()
  super_imposer.set_atoms(ref_atoms, sample_atoms)
  super_imposer.apply(sample_model.get_atoms())

  # Print RMSD:
  print(f'The calculated RMSD for {p} is: {super_imposer.rms}Å')
  if best_RMSD is None or best_RMSD > super_imposer.rms:
    best_RMSD = super_imposer.rms
    # Save the aligned version
    print("Saving better structure.")
    io = Bio.PDB.PDBIO()
    io.set_structure(sample_structure)
    io.save(f"{query_ID}/aligned.pdb")

The calculated RMSD for A0A6G0XC32/gen_seqs/AF-A0A6G0XC32-F1-model_v4_3.pdb is: 2.138284079518373Å
Saving better structure.
The calculated RMSD for A0A6G0XC32/gen_seqs/AF-A0A6G0XC32-F1-model_v4_2.pdb is: 2.4801749000225137Å
The calculated RMSD for A0A6G0XC32/gen_seqs/AF-A0A6G0XC32-F1-model_v4_5.pdb is: 2.9416590374543845Å
The calculated RMSD for A0A6G0XC32/gen_seqs/AF-A0A6G0XC32-F1-model_v4_9.pdb is: 2.6099442129021235Å
The calculated RMSD for A0A6G0XC32/gen_seqs/AF-A0A6G0XC32-F1-model_v4_4.pdb is: 2.373950990804826Å
The calculated RMSD for A0A6G0XC32/gen_seqs/AF-A0A6G0XC32-F1-model_v4_0.pdb is: 2.383053581381057Å
The calculated RMSD for A0A6G0XC32/gen_seqs/AF-A0A6G0XC32-F1-model_v4_7.pdb is: 2.2018923792594367Å
The calculated RMSD for A0A6G0XC32/gen_seqs/AF-A0A6G0XC32-F1-model_v4_6.pdb is: 3.4498973487792317Å
The calculated RMSD for A0A6G0XC32/gen_seqs/AF-A0A6G0XC32-F1-model_v4_1.pdb is: 2.206357915434644Å
The calculated RMSD for A0A6G0XC32/gen_seqs/AF-A0A6G0XC32-F1-model_v4_8.pdb is:

In [9]:
#@title Display superposition of generated sequence with lowest RMSD. { display-mode: "form" }
# https://colab.research.google.com/github/pb3lab/ibm3202/blob/master/tutorials/lab02_molviz.ipynb
import py3Dmol
view=py3Dmol.view()
view.addModel(open(f"{query_ID}/AF-{query_ID}-F1-model_v4.pdb", 'r').read(),'pdb')
view.addModel(open(f'{query_ID}/aligned.pdb', 'r').read(),'pdb')
view.zoomTo()
view.setBackgroundColor('white')
view.setStyle({'model':-1},{'cartoon': {'color':'green'}})
view.setStyle({'model':-2},{'cartoon': {'color':'red'}})
view.show()
print("Reference structure (AFDB) shown in green; ESMFold prediction of generated sequence shown in red.")