<a href="https://colab.research.google.com/github/engelberger/intro-ai-pd/blob/master/notebooks/3_colab_proteinmpnn_enzyme_design.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# AI-based protein design, UDLA 2023
## **Using ProteinMPNN to design plastic degrading enzymes**
#### **Base notebook Authors:**
- **[Sergey Ovchinnikov](https://www.solab.org/)**
- **[Simon Kozlov](https://twitter.com/sim0nsays?lang=en)**


---

fixbb monomer design:
 - `pdb="6MRR" chains="A"`

fixbb homooligomer design:
 - `pdb="5XZK" chains="A,B,C" homooligomer=True`

binder design:
 - `pdb="1SSC" chains="A,B" fix_pos="A"`

---


In [None]:
#@title Install colabdesign
import os
try:
  import colabdesign
except ImportError:
  # colabdesign is not available in this runtime yet: install the pinned
  # release and link the installed package into the working directory.
  # Only ImportError is handled so that interrupts and unrelated failures
  # are not silently turned into a reinstall.
  os.system("pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.1.1")
  os.system("ln -s /usr/local/lib/python3.7/dist-packages/colabdesign colabdesign")

from colabdesign.mpnn import mk_mpnn_model, clear_mem
from colabdesign.shared.protein import pdb_to_string

import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import HTML
import pandas as pd
import tqdm.notebook
TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'

from google.colab import files
from google.colab import data_table
data_table.enable_dataframe_formatter()

def get_pdb(pdb_code=""):
  """Return a local path to a PDB file selected by ``pdb_code``.

  Resolution order:
    * ``None`` or ``""``: prompt for an upload through ``files.upload()`` and
      store the first uploaded file as ``tmp.pdb``.
    * an existing file path: returned unchanged.
    * a 4-character code: downloaded from files.rcsb.org as ``<code>.pdb``.
    * anything else: treated as an AlphaFold DB accession and downloaded as
      ``AF-<code>-F1-model_v3.pdb``.

  Downloads are best effort: the expected file name is returned even when
  ``wget`` fails, exactly as before.
  """
  import shlex  # local import: only needed to quote the download URL

  if pdb_code is None or pdb_code == "":
    upload_dict = files.upload()
    pdb_string = upload_dict[list(upload_dict.keys())[0]]
    with open("tmp.pdb","wb") as out: out.write(pdb_string)
    return "tmp.pdb"
  if os.path.isfile(pdb_code):
    return pdb_code

  if len(pdb_code) == 4:
    url = f"https://files.rcsb.org/view/{pdb_code}.pdb"
    filename = f"{pdb_code}.pdb"
  else:
    url = f"https://alphafold.ebi.ac.uk/files/AF-{pdb_code}-F1-model_v3.pdb"
    filename = f"AF-{pdb_code}-F1-model_v3.pdb"

  # pdb_code is typed in by the user; quoting keeps shell metacharacters in it
  # from being executed as part of the command line.
  os.system(f"wget -qnc {shlex.quote(url)}")
  return filename


# First, update the package list and install necessary packages
# (ncbi-blast+ provides the makeblastdb and psiblast commands used by evoFilter.get_pssm)
!sudo apt-get update && sudo apt-get install -y wget curl ncbi-blast+

# Since Google Colab is based on Ubuntu, we can download and install the MAFFT package directly
# Define the version number as a variable for easier updates
# ({mafft_version} in the shell lines below is filled in from this Python variable)
mafft_version = "7.475"

# Download and install MAFFT, then remove the downloaded .deb file
!wget https://mafft.cbrc.jp/alignment/software/mafft_{mafft_version}-1_amd64.deb
!sudo dpkg -i mafft_{mafft_version}-1_amd64.deb
!rm mafft_{mafft_version}-1_amd64.deb

# Install specific Python packages
!pip install --quiet 'colabfold[alphafold-minus-jax]@git+https://github.com/sokrypton/ColabFold'
!pip install --upgrade dm-haiku
!pip install git+https://github.com/jonathanking/BioPython-A3MIO

# Create a data directory (if needed)
!mkdir -p /content/data


from os.path import exists
"""
This script was obtained from ProteinMPNN helper scripts/other tools. However, additional modifications had to be made to
fit this particular tool
"""

import pandas as pd
import numpy as np
import json


def softmax(x, T):
    """Return the temperature-scaled softmax of ``x`` over its last axis.

    Args:
        x: array-like of logits; normalisation runs over the last axis.
        T: temperature; logits are divided by ``T`` before exponentiation.

    Returns:
        Array of the same shape as ``x`` whose last axis sums to 1.
    """
    scaled = np.asarray(x, dtype=float) / T
    # Shifting by the per-row maximum does not change the result mathematically
    # but keeps np.exp from overflowing to inf (and the ratio from becoming nan).
    shifted = scaled - np.max(scaled, axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=-1, keepdims=True)

def parse_pssm(path, seq_len):
    """Read the numeric part of the first ``seq_len`` rows of a PSSM text file.

    The file is loaded with ``pd.read_csv(path, skiprows=2)``. For each of the
    first ``seq_len`` rows, the text of the first column from character 8
    onwards is split on whitespace and every token is converted to ``float``.

    Args:
        path: location of the PSSM text file.
        seq_len: number of rows to read.

    Returns:
        ``np.ndarray`` built from the per-row lists of floats.
    """
    table = pd.read_csv(path, skiprows=2)
    rows = table.values
    parsed_rows = [
        # Characters 0-7 of each row are skipped; the rest holds the numbers.
        [float(token) for token in rows[row_idx][0][8:].split()]
        for row_idx in range(seq_len)
    ]
    return np.array(parsed_rows)


def make_dict(seq_len,
              pssm_path='.temp/pssm.txt',
              parsed_pdbs_path='.temp/parsed_pdbs.jsonl',
              output_path='.temp/pssm_dict.jsonl'):
    """Convert a PSI-BLAST PSSM into the per-chain PSSM dictionary file.

    The log-odds columns of the PSSM are reordered from the PSSM column order
    (``ARNDCQEGHILKMFPSTWYV``) to the 21-letter MPNN alphabet
    (``ACDEFGHIKLMNPQRSTVWYX``). For every structure listed in the parsed-PDB
    JSONL file and every ``seq_chain*`` entry in it, three lists are stored:
    ``pssm_coef``, ``pssm_bias`` and ``pssm_log_odds``.

    Args:
        seq_len: number of PSSM rows to read.
        pssm_path: PSSM text file to parse (default keeps the original location).
        parsed_pdbs_path: JSONL file with one parsed structure per line.
        output_path: where the resulting dictionary is written as one JSON line.
    """
    np_lines = parse_pssm(pssm_path, seq_len)

    mpnn_alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
    input_alphabet = 'ARNDCQEGHILKMFPSTWYV'

    # permutation_matrix[i, j] is 1 where PSSM column i and MPNN column j hold
    # the same amino-acid letter; the 'X' column stays all zero.
    permutation_matrix = np.zeros([len(input_alphabet), len(mpnn_alphabet)])
    for row, letter in enumerate(input_alphabet):
        permutation_matrix[row, mpnn_alphabet.index(letter)] = 1.

    pssm_log_odds = np_lines[:,:20] @ permutation_matrix

    # Selects only the trailing 'X' column.
    X_mask = np.concatenate([np.zeros([1,20]), np.ones([1,1])], -1)

    # Kept local on purpose: a later notebook cell rebinds the global name
    # `softmax` to scipy's version, which takes an axis instead of a temperature.
    def softmax(x, T):
        return np.exp(x/T)/np.sum(np.exp(x/T), -1, keepdims=True)

    # The huge negative offset drives the probability of 'X' to zero.
    pssm_bias = softmax(pssm_log_odds-X_mask*1e8, 1.0).tolist()
    pssm_log_odds_list = pssm_log_odds.tolist()

    # Load parsed PDBs (one JSON document per line).
    with open(parsed_pdbs_path, 'r') as json_file:
        json_list = list(json_file)

    my_dict = {}
    for json_str in json_list:
        result = json.loads(json_str)
        pssm_dict = {}
        for key in [item for item in list(result) if item[:9]=='seq_chain']:
            chain = key[-1:]
            pssm_dict[chain] = {
                # Weight given to the PSSM, between 0.0 and 1.0, one value per
                # residue of THIS chain (previously the length of chain A was
                # used for every chain, failing when no chain A exists).
                'pssm_coef': np.ones(len(result[key])).tolist(),
                # [length, 21] rows that each sum to 1.0.
                'pssm_bias': pssm_bias,
                'pssm_log_odds': pssm_log_odds_list,
            }
        my_dict[result['name']] = pssm_dict

    # Write the dictionary as a single JSON line.
    with open(output_path, 'w') as f:
        f.write(json.dumps(my_dict) + '\n')



from colabfold.colabfold import run_mmseqs2
from Bio import SeqIO, AlignIO
from colabdesign.mpnn import mk_mpnn_model
from colabdesign.mpnn.model import aa_order

import os
from os import listdir
import json
import subprocess
import pandas as pd
import time
import shutil
import requests
from io import StringIO
from collections import defaultdict
from multiprocessing.pool import ThreadPool

#import make_pssm_dict as prep_mpnn

import requests
import time

class evoFilter():
    
    def __init__(self, pdb_id: str):
        self.name = pdb_id.lower() 
        self.native = ""
        self.seq_len = 0
        
        
    def run(self):
        """Prepare the evolutionary inputs for design and return the trimmed alignment.

        The method creates the working directories, downloads the 6ane FASTA
        and PDB files into ``benchmarked/``, reads the native sequence from
        ``benchmarked/{self.name}.fasta``, collects related sequences with
        ``run_mmseqs2``, re-aligns them with MAFFT, removes every alignment
        column in which the first (reference) row has a gap, and finally calls
        ``get_pssm`` on the result.

        Side effects:
            Sets ``self.native`` and ``self.seq_len``; creates ``{self.name}/``,
            ``{self.name}/results/``, ``.temp/``, ``temp/`` and ``benchmarked/``.

        Returns:
            The alignment parsed from MAFFT's output. Each record's ``seq`` has
            been replaced by a plain ``str`` without the reference-gap columns.
        """
        print("Creating temporary directories...")
        os.makedirs(f"{self.name}", exist_ok=True)
        os.makedirs(f"{self.name}/results", exist_ok=True)
        os.makedirs(".temp", exist_ok=True)

        # Update native sequence and length.
        print("Reading fasta file...")
        os.makedirs("benchmarked", exist_ok=True)
        # NOTE(review): only the 6ane inputs are downloaded here, while the
        # FASTA read below depends on self.name; other ids need their
        # benchmarked/<name>.fasta to exist already.
        url = "https://raw.githubusercontent.com/engelberger/intro-ai-pd/master/fasta/6ane.fasta"
        os.system(f"wget {url} -O benchmarked/6ane.fasta")
        url2 = "https://raw.githubusercontent.com/engelberger/intro-ai-pd/master/pdbs/6ane_A.pdb"
        os.system(f"wget {url2} -O benchmarked/6ane.pdb")

        with open(f"benchmarked/{self.name}.fasta", "r") as fasta:
            # The sequence is taken from the last line of the FASTA file.
            self.native = fasta.readlines()[-1].replace("\n","")
            self.seq_len = len(self.native)

        # Retrieve related sequences; the resulting a3m is turned into an
        # unaligned multi-FASTA, re-aligned, and later used by psiblast to
        # build a PSSM.
        print("Running MMSEQS2 to get related evolutionary sequences...")
        run_mmseqs2(self.native, self.name)

        # TODO : Implement a redundancy check with HHblits
        # similar to the one in predict.ipynb by sergey in ColabDesign

        # Rewrite the a3m records as FASTA text with all gap characters removed.
        print("Parsing a3m file...")
        records = SeqIO.parse(f"{self.name}_env/uniref.a3m", "fasta")
        seq_str = "".join(
            ">" + str(record.id) + '\n' + str(record.seq).replace("-","") + '\n'
            for record in records
        )

        # Align the sequences with MAFFT, reading from stdin and writing to stdout.
        print("Running MAFFT for sequence alignment...")
        child = subprocess.Popen(['mafft', '--quiet', '-'], stdin=subprocess.PIPE, stdout=subprocess.PIPE)
        # communicate() feeds stdin and drains stdout together, which avoids the
        # pipe deadlock a separate stdin.write() can run into, and it closes
        # stdin itself.
        child_out = child.communicate(input=seq_str.encode())[0].decode('utf8')
        alignment = AlignIO.read(StringIO(child_out), 'fasta')

        # Columns where the reference (first) sequence has a gap are removed
        # from every sequence, so all rows are indexed like the reference.
        print("Processing sequence alignment results...")
        reference = alignment[0].seq
        gap_columns = {col for col, residue in enumerate(reference) if residue == "-"}
        for align in alignment:
            # Stored as a plain str, which is what get_pssm writes to the
            # BLAST database file.
            align.seq = "".join(
                residue for col, residue in enumerate(align.seq) if col not in gap_columns
            )

        # Build the PSSM from the trimmed alignment.
        print("Running Blast...")
        os.makedirs("temp",  exist_ok=True)
        self.get_pssm(alignment=alignment)

        # NOTE: the downstream steps (run_mpnn, get_muts, get_seqs, writing
        # {self.name}/results/* and removing .temp and {self.name}_env) are
        # currently not invoked from here.

        return alignment

    # TODO: get fasta from user given pdb file
    def get_fasta(self, pdb_id, output_file):
        """Download the RCSB FASTA entry for ``pdb_id`` into ``output_file``.

        A failed request or an empty response is reported with ``print`` and
        nothing is written; the method returns ``None`` in every case.
        """
        endpoint = f"https://www.rcsb.org/fasta/entry/{pdb_id}"

        try:
            reply = requests.get(endpoint)
            # Turn HTTP error statuses into exceptions handled just below.
            reply.raise_for_status()
        except requests.exceptions.RequestException as e:
            print(f"Failed to retrieve FASTA for PDB ID {pdb_id}: {e}")
            return

        payload = reply.text.strip()
        if not payload:
            print(f"No FASTA data found for PDB ID {pdb_id}")
            return

        with open(output_file, "w") as handle:
            handle.write(payload)
            

    def get_fasta_retry(self, pdb_id, output_file, max_retries=5):
        """Download the FASTA entry for ``pdb_id``, retrying with exponential back-off.

        Args:
            pdb_id: identifier appended to the base URL.
            output_file: path the FASTA text is written to on success.
            max_retries: maximum number of request attempts.

        Returns:
            ``None``. Success, an empty response, and exhausted retries are all
            reported with ``print``.
        """
        # base_url is optional on the instance; without it, fall back to the
        # RCSB endpoint that get_fasta() uses instead of raising AttributeError.
        base_url = getattr(self, "base_url", "https://www.rcsb.org/fasta/entry/")
        url = f"{base_url}{pdb_id}"
        wait_time = 1  # initial wait time in seconds, doubled after each failure

        for attempt in range(1, max_retries + 1):
            try:
                response = requests.get(url)
                response.raise_for_status()
                fasta_text = response.text.strip()

                if fasta_text:
                    with open(output_file, "w") as f:
                        f.write(fasta_text)
                    print(f"FASTA data for PDB ID {pdb_id} saved to {output_file}")
                    return
                else:
                    print(f"No FASTA data found for PDB ID {pdb_id}")
                    return
            except requests.exceptions.RequestException as e:
                print(f"Attempt {attempt}: Failed to retrieve FASTA for PDB ID {pdb_id}. Error: {e}")
                # No point in waiting once the last attempt has failed.
                if attempt < max_retries:
                    time.sleep(wait_time)
                    wait_time *= 2
        print("Max retries reached. Failed to retrieve FASTA data.")
        
    def get_pssm(self, alignment):
        """Get a PSSM matrix by running psi blast against MSA object; process it for MPNN"""
        
        # Write MSA to FASTA file and turn it into a database
        db = f".temp/database.fasta"
        seq_records = [SeqIO.SeqRecord(seq=record.seq, id=record.id, description="") for record in alignment]
        tmp = open(db, 'w')
        for records in seq_records:
            tmp.write(">" + records.id + "\n")
            tmp.write(records.seq + "\n")
        tmp.close()
        subprocess.run("makeblastdb -in .temp/database.fasta -dbtype prot -out .temp/db", shell=True)

        # prepare for psi blast pssm
        query = f"benchmarked/{self.name}.fasta"

        # obtain a PSSM matrix using psi blast
        subprocess.run(f'psiblast -query {query} -db .temp/db -num_iterations 3 -out_ascii_pssm \
                       ./temp/pssm.txt -outfmt 0', shell=True)
        
        return "PSSM matrix has been successfully produced and saved in temp/pssm.txt"

    def _task(self, i):
        """Design position ``i`` (0-based) with ProteinMPNN and record the top-scoring mutant.

        The chain "A" bias rows are loaded from ``.temp/pssm_dict.jsonl``, every
        position except ``i + 1`` is fixed, a fresh MPNN model is prepared from
        ``{self.name}/{self.name}.pdb``, row ``i`` of the model bias is
        overwritten with the PSSM bias, and 5 sequences are sampled. For every
        sampled sequence, each position that differs from ``self.native`` adds
        one record; the record with the highest score is written without header
        to ``.temp/csv/{self.name}_{i}.csv``.

        Returns:
            A fixed status string.
        """
        # Worker body used by run_mpnn through a thread pool.
        with open('.temp/pssm_dict.jsonl','r') as pssm:
            bias = json.load(pssm)[self.name]["A"]["pssm_bias"]

        bias_matrix = bias[i]
        pos = i + 1  # 1-based index of the position being designed

        # fix_pos string covering every 1-based position except `pos`.
        if i < 1:
            fixed = f"{pos+1}-{len(bias)}"
        elif pos >= len(bias):
            fixed = f"1-{i}"
        elif i == 1:
            fixed = f"1, {pos+1}-{len(bias)}"
        elif ((pos+1) == len(bias)):
            fixed = f"1-{i},{len(bias)}"
        else:
            fixed = f"1-{i},{pos+1}-{len(bias)}"

        # Create the MPNN model and load the structure.
        mpnn_model = mk_mpnn_model()
        print(f"Processing {self.name}/{self.name}.pdb")
        mpnn_model.prep_inputs(pdb_filename=f"{self.name}/{self.name}.pdb",
                            fix_pos=fixed,
                            chain="A",)

        # Overwrite the bias of the designed position with the PSSM-derived
        # values, mapping each letter to the model's own column order.
        alphabet = 'ACDEFGHIKLMNPQRSTVWY'
        for col, letter in enumerate(alphabet):
            mpnn_model._inputs["bias"][i,aa_order[letter]] = bias_matrix[col]

        samples = mpnn_model.sample_parallel(batch=5)

        # One record per (sample, mismatching position) pair.
        pos_data = defaultdict(list)
        for ctr, sequence in enumerate(samples["seq"]):
            for j in range(self.seq_len):
                if self.native[j] == sequence[j]:
                    continue
                pos_data["sequence"].append(sequence)
                mut = f"{self.native[j]}    {sequence[j]}" # native residue, then what it was mutated to
                pos_data["mutation"].append(mut)
                pos_data["score"].append(samples["score"][ctr])

        if pos_data:
            df = pd.DataFrame(pos_data)
            df['score'] = pd.to_numeric(df['score'])
            ranked = df.sort_values(by='score', ascending=False).reset_index(drop=True)
            # Only the first row after the descending sort is kept.
            ranked.head(1).to_csv(f".temp/csv/{self.name}_{i}.csv", index=False, header=False)

        return "One position has been designed, and the amino acid mutation that led \
            to the greatest improvement in thermostability has been recorded."

    def run_mpnn(self):
        """Run the helper scripts and ProteinMPNN over every position.

        Parses the structures in ``{self.name}`` with the ProteinMPNN helper
        script, builds the PSSM dictionary with ``make_dict`` and then calls
        ``_task`` once per sequence position.

        Returns:
            A status string containing the elapsed time in seconds.
        """
        # _task writes to and get_muts reads from ".temp/csv", so that is the
        # directory that has to exist (previously "./temp/csv" was created).
        os.makedirs(".temp/csv", exist_ok=True)

        # Run helper scripts to set up for MPNN.
        # NOTE(review): assumes a ProteinMPNN checkout in the working
        # directory; nothing in this file clones it. Confirm.
        subprocess.run(f"python ProteinMPNN/helper_scripts/parse_multiple_chains.py \
                        --input_path {self.name} \
                        --output_path .temp/parsed_pdbs.jsonl", shell=True)

        # make_dict is defined at module level in this file; the `prep_mpnn`
        # module alias it used to be called through is not imported.
        make_dict(self.seq_len)

        start = time.time()

        # Run MPNN on the sequence, mutating and scoring one position at a time.
        print("Mutating positions one by one...")
        # A single worker thread keeps the load low; raise for production use.
        with ThreadPool(1) as pool:
            pool.map(self._task, range(self.seq_len))

        end = time.time()

        return f"MPNN has finished running in {end-start} seconds, and all data has been recorded in .temp/csv."
   
    def get_muts(self):
        
        all_data = {
            "Position": [],
            "Sequence": [],
            "Wildtype": [],
            "Mutant": [],
            "Score": []
        }
        
        # iterate through the csv directory
        filenames = listdir(f".temp/csv")
        files = [filename for filename in filenames if filename.endswith(".csv")]
        
        for file in files:
            all_data["Position"].append(int(file.replace(".csv","").split("_")[-1]) + 1)
            with open(os.path.join(f".temp/csv", file), "r") as f:
                data = f.readlines()[0].split(",")
                all_data["Sequence"].append(data[0])
                all_data["Wildtype"].append(data[1].split(" ")[0])
                all_data["Mutant"].append(data[1].split(" ")[-1])
                all_data["Score"].append(float(data[2].replace("\n","")))
        
        df = pd.DataFrame.from_dict(all_data)
        df.sort_values("Score", ascending=False, inplace=True,ignore_index=True)
        
        return df.head(10)
    
    def get_seqs(self, data):
        print(data)
        # create sequence with the individual mutations
        to_mutate = dict(zip(data['Position'], data['Mutant']))
        mut_seq = list(self.native)
        for i in range(len(self.native)):
            if (i+1) in to_mutate:
                mut_seq[i] = f"<{to_mutate[i+1]}>"
        return "".join(mut_seq)

    def parse_pssm(file_path):
      with open(file_path, 'r') as file:
          lines = file.readlines()
      
      # Skip the header and find the start of the matrix
      start_index = 0
      for i, line in enumerate(lines):
          if line.strip() and all(c.isalpha() or c.isspace() for c in line.strip()):
              start_index = i + 1
              break
      
      # Parse the matrix
      pssm = []
      for line in lines[start_index:]:
          if line.strip():  # Skip empty lines
              parts = line.split()
              try:
                  position = int(parts[0])
              except ValueError:
                  continue  # Skip lines that don't start with an integer
              
              residue = parts[1]
              scores = [int(x) for x in parts[2:22]]  # Assuming there are 20 amino acids
              other_data = parts[22:]  # Capture any additional data at the end of the line
              pssm.append({
                  'position': position,
                  'residue': residue,
                  'scores': scores,
                  'other_data': other_data
              })
      
      return pssm

def filter_positions_with_fewer_negatives(pssm_data, max_negative_scores):
    """Return the positions whose count of negative scores is below a threshold.

    Args:
        pssm_data: iterable of dicts with ``'position'`` and ``'scores'`` keys.
        max_negative_scores: exclusive upper bound on the number of scores < 0.

    Returns:
        List of ``'position'`` values, in input order, for the entries that
        have fewer than ``max_negative_scores`` negative scores.
    """
    def count_negatives(scores):
        return len([value for value in scores if value < 0])

    return [
        entry['position']
        for entry in pssm_data
        if count_negatives(entry['scores']) < max_negative_scores
    ]

def print_positions_in_format(positions):
    """Print and return ``positions`` compressed into a range string.

    Consecutive values (each exactly one above the previous) are merged into
    ``start-end``; isolated values are written as a single number; pieces are
    joined with commas, e.g. ``[1, 2, 3, 5]`` becomes ``"1-3,5"``.

    Args:
        positions: sequence of integer positions, in the order to be merged.

    Returns:
        The formatted string (empty for empty input).
    """
    # Each run is a [start, end] pair; a value extends the latest run only
    # when it directly follows that run's end.
    runs = []
    if positions:
        for pos in positions:
            if runs and pos == runs[-1][1] + 1:
                runs[-1][1] = pos
            else:
                runs.append([pos, pos])

    formatted_positions = ",".join(
        f"{start}-{end}" if start != end else f"{start}" for start, end in runs
    )
    print(f"Positions to keep fixed in the sequence: {formatted_positions}")
    return formatted_positions
    

In [None]:
# Run the evoFilter pipeline for each listed structure id; `x` and `alignment`
# keep the values from the last id processed.
example_run = ["6ane"]
for example in example_run:
    x = evoFilter(example)
    alignment = x.run()
# Usage
# Parse the PSSM that evoFilter.get_pssm asked psiblast to write to ./temp/pssm.txt.
pssm_file_path = './temp/pssm.txt'
pssm_data = evoFilter.parse_pssm(pssm_file_path)

In [None]:
# Positions with fewer than n negative PSSM scores are selected and reported
# as the positions to keep fixed.
n = 18  # Set the maximum number of negative scores allowed
filtered_positions = filter_positions_with_fewer_negatives(pssm_data, n)
print(f"Fixing {len(filtered_positions)} positions")
# Prints the range string and keeps it in formatted_positions.
formatted_positions = print_positions_in_format(filtered_positions)

In [None]:
%%time
#@title Run ProteinMPNN to design new sequences for given backbone

import warnings, os, re
warnings.simplefilter(action='ignore', category=FutureWarning)

os.system("mkdir -p output")

# USER OPTIONS
#@markdown #### ProteinMPNN options
model_name = "v_48_020" #@param ["v_48_002", "v_48_010", "v_48_020", "v_48_030"]
#@markdown #### Input Options
pdb='benchmarked/6ane.pdb' #@param {type:"string"}
#@markdown - leave blank to get an upload prompt
chains = "A" #@param {type:"string"}
homooligomer = False #@param {type:"boolean"}
#@markdown #### Design constraints
fix_pos = "1,4-6,9,12-19,21,23-28,30-33,40-42,44-47,49-52,55-56,59-67,69,71-73,76,81-84,86-89,91,93-94,96-97,99-101,103-121,123-124,126,128-129,131,138-149,151-155,157-168,170-177,179-180,183-186,188,190-198,200,202-208,210-228,230,232-234,236,238,241-242,244,246,248-260,262" #@param {type:"string"}
#@markdown - specify which positions to keep fixed in the sequence (example: `1,2-10`)
#@markdown - you can also specify chain specific constraints (example: `A1-10,B1-20`)
#@markdown - you can also specify to fix entire chain(s) (example: `A`)
inverse = False #@param {type:"boolean"}
#@markdown - inverse the `fix_pos` selection (define position to "free" [or design] instead of "fix")
rm_aa = "" #@param {type:"string"}
#@markdown - specify amino acid(s) to exclude (example: `C,A,T`)

#@markdown #### Design Options
num_seqs = 32 #@param ["32", "64", "128", "256", "512", "1024"] {type:"raw"}
sampling_temp = 0.1 #@param ["0.0001", "0.1", "0.15", "0.2", "0.25", "0.3", "0.5", "1.0"] {type:"raw"}
#@markdown - Sampling temperature for amino acids, T=0.0 means taking argmax, T>>1.0 means sample randomly.

#@markdown Note: designed sequences are saved to `design.fasta`

# cleaning user options
# Any run of non-letter characters in `chains` becomes a single comma.
chains = re.sub("[^A-Za-z]+",",", chains)
if fix_pos == "": fix_pos = None
# Keep only the letters A-Z (after upper-casing) and join them with commas.
rm_aa = ",".join(list(re.sub("[^A-Z]+","",rm_aa.upper())))
if rm_aa == "": rm_aa = None

pdb_path = get_pdb(pdb)
# The model is created only when no `mpnn_model` name exists yet.
# NOTE(review): changing model_name after the first run therefore has no
# effect until `mpnn_model` is deleted or the runtime is restarted.
if "mpnn_model" not in dir():
  mpnn_model = mk_mpnn_model(model_name)

mpnn_model.prep_inputs(pdb_filename=pdb_path,
                       chain=chains, homooligomer=homooligomer,
                       fix_pos=fix_pos, inverse=inverse,
                       rm_aa=rm_aa, verbose=True)
# num_seqs//32 rounds with a batch of 32 each; every num_seqs option is a
# multiple of 32.
out = mpnn_model.sample(num=num_seqs//32, batch=32,
                        temperature=sampling_temp,
                        rescore=homooligomer)

# One FASTA record per design; the header carries score and sequence identity.
with open("design.fasta","w") as fasta:
  for n in range(num_seqs):
    line = f'>score:{out["score"][n]:.3f}_seqid:{out["seqid"][n]:.3f}\n{out["seq"][n]}'
    fasta.write(line+"\n")

# Same results as a table: saved to CSV and shown as an interactive data table.
labels = ["score","seqid","seq"]
data = [[out[k][n] for k in labels] for n in range(num_seqs)]

df = pd.DataFrame(data, columns=labels)
df.to_csv('output/mpnn_results.csv')
data_table.DataTable(df.round(3))

In [None]:
#@title ### Get amino acid probabilities from ProteinMPNN (optional)
mode = "unconditional" #@param ["unconditional", "conditional", "conditional_fix_pos"]
#@markdown - `unconditional` - P(sequence | structure) 
#@markdown - `conditional` - P(sequence | structure, sequence)
#@markdown - `conditional_fix_pos` - P(sequence[not_fixed] | structure, sequence[fix_pos])
show = "all" 
import plotly.express as px
from scipy.special import softmax
from colabdesign.mpnn.model import residue_constants
L = sum(mpnn_model._lengths)
# Indices of fixed positions as stored by the model (empty when none are fixed);
# free_pos is every other index.
fix_pos = mpnn_model._inputs.get("fix_pos",[])
free_pos = np.delete(np.arange(L),fix_pos)

# The three modes are mutually exclusive, so they form a single if/elif/else
# chain. With a separate `if` for "conditional_fix_pos", the "conditional"
# logits were overwritten by the unconditional branch.
if mode == "conditional":
  # Every position sees all other positions but not itself.
  ar_mask = 1-np.eye(L)
  logits = mpnn_model.score(ar_mask=ar_mask)["logits"]
  pdb_labels = None
elif mode == "conditional_fix_pos":
  assert "fix_pos" in mpnn_model._inputs, "no positions fixed"
  ar_mask = 1-np.eye(L)
  p = np.delete(np.arange(L),mpnn_model._inputs["fix_pos"])
  # Free positions must not see each other, only the fixed ones.
  ar_mask[free_pos[:,None],free_pos[None,:]] = 0
  logits = mpnn_model.score(ar_mask=ar_mask)["logits"]
  logits = logits[free_pos]
  pdb_labels = np.array([f"{i}_{c}" for c,i in zip(mpnn_model.pdb["idx"]["chain"], mpnn_model.pdb["idx"]["residue"])])
  pdb_labels = pdb_labels[free_pos]
else:
  # Unconditional: no position sees any other position.
  ar_mask = np.zeros((L,L))
  logits = mpnn_model.score(ar_mask=ar_mask)["logits"]
  pdb_labels = None

# Per-position amino-acid probabilities, saved and plotted as a heat map.
pssm = softmax(logits,-1)
np.savetxt("output/pssm.txt",pssm)

fig = px.imshow(np.array(pssm).T,
               labels=dict(x="positions", y="amino acids", color="probability"),
               y=residue_constants.restypes + ["X"],
               x=pdb_labels,
               zmin=0,
               zmax=1,
               template="simple_white",
              )
fig.update_xaxes(side="top")
fig.show()

In [None]:
#@title Run AlphaFold Prediction on ProteinMPNN sequences (optional)
#@markdown ###AlphaFold Options
num_models = 1 #@param ["1","2","3","4","5"] {type:"raw"}
num_recycles = 1 #@param ["0","1","2","3"] {type:"raw"}
use_multimer = False #@param {type:"boolean"}
use_templates = False #@param {type:"boolean"}
rm_template_interchain = False #@param {type:"boolean"}
# Download and unpack the AlphaFold parameters once; skipped when ./params exists.
if not os.path.isdir("params"):
  os.system("mkdir params")
  os.system("apt-get install aria2 -qq")
  os.system("aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar")
  os.system("tar -xf alphafold_params_2022-12-06.tar -C params")

# where pdb files will be saved (emptied when the directory already exists):
if not os.path.isdir("output/all_pdb"):
  os.system("mkdir output/all_pdb")
else:
  os.system("rm output/all_pdb/*")

from colabdesign.af import mk_af_model
# The AlphaFold model is rebuilt only when one of these settings differs from
# the ones it was last prepared with.
af_args = [pdb_path, chains, homooligomer,
           use_multimer, use_templates]
if "af_arg_current" not in dir() or af_args != af_arg_current:
  af_model = mk_af_model(use_multimer=use_multimer,
                         use_templates=use_templates,
                         best_metric="dgram_cce")
  af_model.prep_inputs(pdb_path,chains,homooligomer=homooligomer)
  af_arg_current = [x for x in af_args]

af_model.restart()
af_model.set_opt("template", rm_ic=rm_template_interchain)

# Predict a structure for every sampled sequence in out["S"].
with tqdm.notebook.tqdm(total=out["S"].shape[0], bar_format=TQDM_BAR_FORMAT) as pbar:
  for n,S in enumerate(out["S"]):
    # Residue index per position, limited to the length of the AlphaFold model
    # (presumably S holds per-residue scores/one-hot rows — TODO confirm).
    seq = S[:af_model._len].argmax(-1)
    af_model.predict(seq=seq,
                    num_recycles=num_recycles,
                    num_models=num_models,
                    verbose=False)
    (rmsd, ptm, plddt) = (af_model.aux["log"][k] for k in ["rmsd","ptm","plddt"])
    # Extra metric stored next to the ones reported by the model.
    af_model.aux["log"]["composite"] = ptm * plddt
    af_model._save_results(save_best=True, verbose=False)
    af_model.save_current_pdb(f"output/all_pdb/n{n}.pdb")
    af_model._k += 1
    pbar.update(1)

af_model.save_pdb(f"output/best.pdb")

# Combine the AlphaFold metrics with the ProteinMPNN score, identity and
# sequence of the same design index.
data = []
labels = ["dgram_cce","plddt","ptm","i_ptm","rmsd","composite","mpnn","seqid","seq"]
for n,af in enumerate(af_model._tmp["log"]):
  data.append([af["dgram_cce"],
               af["plddt"],
               af["ptm"],
               af["i_ptm"],
               af["rmsd"],
               af["composite"],
               out["score"][n],
               out["seqid"][n],
               out["seq"][n]])

df = pd.DataFrame(data, columns=labels)
df.to_csv('output/alphafold_results.csv')
data_table.DataTable(df.sort_values("dgram_cce").round(3))
#@markdown Note: designed pdbs are saved to `output/all_pdb/`

In [None]:
#@title download predictions (optional)
from google.colab import files
# Zip the whole output/ directory and hand the archive to the browser.
os.system(f"zip -r output.zip output/")
files.download(f'output.zip')

In [None]:
#@title display protein (optional) {run: "auto"}
show_best = True #@param {type:"boolean"}
show_idx = 0 #@param {type:"integer"}
#@markdown - Enter index of protein to show, if `show_best` is disabled.
#@markdown - Note: these are NOT sorted and correspond to 
#@markdown the index in pandas dataframe above.
color = "pLDDT" #@param ["chain", "pLDDT", "rainbow"]
show_sidechains = False #@param {type:"boolean"}
show_mainchains = False #@param {type:"boolean"}
color_HP = False #@param {type:"boolean"}
animate = True #@param {type:"boolean"}
#@markdown - if `num_models` > 1, will iterate through the models when `animate` is enabled.
# With show_best disabled the chosen design is loaded from output/all_pdb;
# otherwise pdb_str stays None and plot_pdb is called without an explicit
# structure.
if not show_best:
  pdb_str = pdb_to_string(f"output/all_pdb/n{show_idx}.pdb")
else:
  pdb_str = None
af_model.plot_pdb(show_sidechains=show_sidechains,
                  show_mainchains=show_mainchains,
                  color=color, color_HP=color_HP,
                  animate=animate, pdb_str=pdb_str)

In [None]:
#@title animate (optional)
#@markdown Note: animation frames are sorted worst to best design
def sort_traj(self, metric="dgram_cce"):
  """Return the stored trajectory reordered from best to worst by ``metric``.

  The last ``len(self._tmp["traj"]["seq"])`` log entries are ranked. The
  metrics plddt, ptm, i_ptm, seqid and composite are ranked in descending
  order, every other metric in ascending order.

  Returns:
    Dict with the same keys as ``self._tmp["traj"]`` and reordered lists as
    values, or ``None`` when the first ranked log entry lacks ``metric``.
  """
  descending = metric in ["plddt","ptm","i_ptm","seqid","composite"]
  traj = self._tmp["traj"]
  # Only the log entries that belong to the stored trajectory frames.
  recent_log = self._tmp["log"][-len(traj["seq"]):]
  if metric not in recent_log[0]:
    return None

  order = np.array([entry[metric] for entry in recent_log]).argsort()
  if descending:
    order = order[::-1]
  return {key: [frames[idx] for idx in order] for key, frames in traj.items()}

# Frames ordered by the default metric (dgram_cce, ascending).
# NOTE(review): sort_traj returns None when the metric is missing from the
# log, which would make the .items() call below fail.
sub_traj= sort_traj(af_model)

color_by = "plddt" #@param ["chain", "plddt", "rainbow"]
dpi = 100 #@param {type:"integer"}
# Each list is reversed so the animation runs in the opposite order of sort_traj's output.
HTML(af_model.animate(traj={k:v[::-1] for k,v in sub_traj.items()}, color_by=color_by, dpi=dpi))
