<a href="https://colab.research.google.com/github/ibmm-unibe-ch/msa-tests/blob/master/FrankenFold.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# FrankenMSA
## [Deep learning protein folding models predict alternative protein conformations with informative sequence alignments](https://github.com/ibmm-unibe-ch/msa-tests)

First, fill out the form in the first cell, then run it by clicking the play button.
Enter "Upload none" for an optional input to skip it entirely. For the structure to inverse-fold, enter "Upload" to upload your own file, or a PDB identifier to fetch it directly from the PDB.

Currently, only monomers are supported, and inverse folding always uses chain "A" of the structure.

In [None]:
from google.colab import output
output.enable_custom_widget_manager()
import ipywidgets as widgets
from ipywidgets import interactive, VBox, HBox, Text, Button
from IPython.display import display
import functools
import requests
import hashlib
import tarfile
import time
import os
from typing import Tuple, List
import random
from tqdm import tqdm
import numpy as np
import logging
logger = logging.getLogger(__name__)
from string import ascii_uppercase,ascii_lowercase
from pathlib import Path
import math
from google.colab import files
import uuid

#@markdown #Input Options
# Query sequence the MSA is built for; whitespace is removed further below.
Query_sequence = '' #@param {type:"string"}
# Number of rows (sequences) in the generated MSA.
MSA_depth = 128 #@param {type:"integer"}

def get_pdb(pdb_code=""):
  """Return a local path to the structure selected in the form, or None.

  * "Upload none" -> None (no structure is used).
  * None or "Upload" -> open the Colab upload dialog, save the first uploaded
    file as ``<uuid>.pdb`` and return that name.
  * An existing local file path -> returned unchanged.
  * Anything else is treated as a PDB identifier and downloaded from RCSB to
    ``<pdb_code>.pdb``; an existing file of that name is reused.

  Raises:
    ValueError: if nothing was uploaded or the identifier is not a plain
      alphanumeric PDB code.
    requests.HTTPError: if RCSB does not serve the identifier.
  """
  if pdb_code == "Upload none":
    return None
  if pdb_code is None or pdb_code == "Upload":
    upload_dict = files.upload()
    if not upload_dict:
      raise ValueError("No PDB file was uploaded.")
    pdb_bytes = next(iter(upload_dict.values()))
    filename = str(uuid.uuid4()) + ".pdb"
    with open(filename, "wb") as out:
      out.write(pdb_bytes)
    return filename
  if os.path.isfile(pdb_code):
    return pdb_code
  pdb_code = pdb_code.strip()
  # Only accept identifier characters: the code becomes part of a URL and a
  # local file name, so separators or shell metacharacters must not pass.
  if not pdb_code or not all(ch.isascii() and (ch.isalnum() or ch == "_") for ch in pdb_code):
    raise ValueError(f"Invalid PDB identifier: {pdb_code!r}")
  filename = f"{pdb_code}.pdb"
  # Like `wget -nc`: never re-download a file that is already present.
  if not os.path.isfile(filename):
    response = requests.get(f"https://files.rcsb.org/view/{pdb_code}.pdb", timeout=30)
    response.raise_for_status()
    with open(filename, "wb") as out:
      out.write(response.content)
  return filename

def get_msa(msa_code=""):
  """Return a local path to a user-supplied MSA, or None when none is wanted.

  * "Upload none" (the form default) -> None; no upload dialog is shown.
  * None or "Upload" -> open the Colab upload dialog, save the first uploaded
    file as ``<uuid>.a3m`` and return that name.
  * Any other value -> None.

  Raises:
    ValueError: if the upload dialog returned no file.
  """
  # Mirrors get_pdb(): "Upload none" must skip the upload instead of
  # triggering it.
  if msa_code is None or msa_code == "Upload":
    upload_dict = files.upload()
    if not upload_dict:
      raise ValueError("No MSA file was uploaded.")
    msa_bytes = next(iter(upload_dict.values()))
    filename = str(uuid.uuid4()) + ".a3m"
    with open(filename, "wb") as out:
      out.write(msa_bytes)
    return filename
  return None

# Structure for inverse folding, resolved to a local file (or None) by get_pdb().
PDB_to_inverse_fold='Upload none' #@param {type:"string"}
pdb_path = get_pdb(PDB_to_inverse_fold)

# Optional own MSA in a3m format, resolved to a local file (or None) by get_msa().
Own_MSA='Upload none' #@param {type:"string"}
msa_path = get_msa(Own_MSA)

# ProteinMPNN checkpoint stem (a .pt file in ProteinMPNN/vanilla_model_weights).
ProteinMPNN_model_name = "v_48_020" #@param ["v_48_002", "v_48_010", "v_48_020", "v_48_030"]
model_name = ProteinMPNN_model_name

# remove all whitespace (spaces, tabs, newlines) from the pasted query
query_sequence = "".join(Query_sequence.split())

# Shared widgets: the Confirm button plus the part-specific option inputs.
new_part_button = widgets.Button(description='Confirm')
mmseqs_type_widget = widgets.Dropdown(
  options=['Full', 'Part'], value='Full', description='Query:', disabled=False)
inverse_folding_options_widget = widgets.BoundedFloatText(
  value=1, min=0, max=10, step=0.1, description='Temperature', disabled=False)
repeat_text_widget = widgets.Text(
  value='', placeholder='Empty == query', description='Repeated text:', disabled=False)

# Mutable state shared by the part-definition callbacks below.
box = None       # row of widgets currently being edited
part_type = ""   # part kind selected in the dropdown
parts = []       # confirmed (part_type, start, end, other) tuples
start = 0        # start of the next part (end of the last confirmed one)

seq_len = len(query_sequence)
print(f"Protein has length: {seq_len}")
end = seq_len

def set_box(box_object):
  """Store `box_object` in the global `box` (the row being edited) and return it."""
  global box
  box = box_object
  return box_object

def create_start_widget(end, start:int=0):
  """Integer input for a part's start, bounded to [0, end] and preset to `start`."""
  return widgets.BoundedIntText(value=start,min=0,max=end,step=1,description='Start',disabled=False)

def create_end_widget(end, start:int=0):
  """Integer input for a part's exclusive end, bounded to [0, end].

  Preset to `end` so a new row covers the rest of the sequence by default,
  rather than to `start`, which would make the default range empty.
  """
  return widgets.BoundedIntText(value=end,min=0,max=end,step=1,description='End',disabled=False)

def _part_box(end:int, start:int, *extra_widgets):
  """Build and register the row [Start, End, *extra_widgets, Confirm].

  on_new_button_clicked() reads children 0/1 as start/end and child 2 as the
  part-specific option, so the order must stay fixed.
  """
  return set_box(widgets.HBox([create_start_widget(end, start), create_end_widget(end, start), *extra_widgets, new_part_button]))

def get_mmseqs_widget(end:int, start:int=0):
  """Row for an MMseqs part, with the Full/Part query dropdown."""
  return _part_box(end, start, mmseqs_type_widget)

def get_inverse_folding_widget(end:int, start:int=0):
  """Row for an inverse-folding part, with the sampling-temperature input."""
  return _part_box(end, start, inverse_folding_options_widget)

def get_repeat_widget(end:int, start:int=0):
  """Row for a repeat part, with the optional custom-text input."""
  return _part_box(end, start, repeat_text_widget)

def get_gap_widget(end:int, start:int=0):
  """Row for a gap part (start/end only)."""
  return _part_box(end, start)

def get_own_msa_widget(end:int, start:int=0):
  """Row for an own-MSA part (start/end only)."""
  return _part_box(end, start)

def on_new_button_clicked(*args):
  """Record the part described by the current row and offer the next one.

  Appends ``(part_type, start, end, other)`` to the global `parts`, where
  `other` is the MMseqs query mode, the inverse-folding temperature, the
  custom repeat text (None when left empty), or None for other part kinds.
  The next part starts where this one ends.
  """
  global box, part_type, parts, start
  part_start = box.children[0].value
  part_end = box.children[1].value
  if part_type == "MMseqs":
    other = box.children[2].value
    addendum = f" query: {other}"
  elif part_type == "Inverse folding":
    other = box.children[2].value
    addendum = f" temperature: {other}"
  elif part_type == "Repeat":
    text = box.children[2].value
    other = text if text not in ["", 'Empty == query'] else None
    addendum = f" sequence: {other}" if other else ""
  else:
    other = None
    addendum = ""
  print(f"[{part_start}:{part_end}], {part_type}{addendum}")
  parts.append((part_type, part_start, part_end, other))
  start = part_end
  # ipywidgets ignores the return value of on_click handlers, so the selector
  # for the next part must be displayed explicitly to appear at all.
  next_selector = interactive(callback, Part_type=type_options)
  display(next_selector)
  return next_selector

new_part_button.on_click(on_new_button_clicked)

type_options = ['Gaps', 'Inverse folding', 'MMseqs', 'Own MSA', 'Repeat']

def callback(Part_type):
  """Show the input row for the chosen part kind, starting at the current `start`."""
  global part_type, start
  part_type = Part_type
  row_builder = {
    "Gaps": get_gap_widget,
    'Inverse folding': get_inverse_folding_widget,
    'MMseqs': get_mmseqs_widget,
    'Own MSA': get_own_msa_widget,
    'Repeat': get_repeat_widget,
  }[Part_type]
  row = row_builder(seq_len, start)
  display(row)

interactive(callback, Part_type=type_options)

In [None]:
#@title Download hhsuite
%%capture
# Download and unpack the HH-suite 3.3.0 AVX2 Linux build into the working
# directory; process_msa() later calls /content/scripts/reformat.pl and
# /content/bin/hhfilter by absolute path.
# NOTE(review): the `export PATH=...` only applies to the shell spawned for this
# `!` line, so it likely has no effect on later cells — confirm.
!wget https://github.com/soedinglab/hh-suite/releases/download/v3.3.0/hhsuite-3.3.0-AVX2-Linux.tar.gz; tar xvfz hhsuite-3.3.0-AVX2-Linux.tar.gz; export PATH="$(pwd)/bin:$(pwd)/scripts:$PATH"

In [None]:
# From https://github.com/sokrypton/ColabFold/blob/main/colabfold/colabfold.py
#@title MMSeqs

# Chain identifiers A-Z then a-z, assigned in this order by read_pdb_renum().
alphabet_list = list(ascii_uppercase+ascii_lowercase)
# One-letter codes of the 20 standard amino acids.
aatypes = set('ACDEFGHIKLMNPQRSTVWY')
# tqdm bar format used by run_mmseqs2() while waiting for the MSA server.
TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'

def run_mmseqs2(x, prefix, use_env=True, use_filter=True,
                use_templates=False, filter=None, use_pairing=False, pairing_strategy="greedy",
                host_url="https://api.colabfold.com",
                user_agent: str = "") -> Tuple[List[str], List[str]]:
  """Query the ColabFold MMseqs2 server and return a3m text per input sequence.

  Args:
    x: one sequence or a list of sequences.
    prefix: results are cached in the directory ``f"{prefix}_{mode}"``; an
      existing ``out.tar.gz`` there is reused instead of querying the server.
    use_env: also use the environmental database result
      (bfd.mgnify30.metaeuk30.smag30.a3m).
    use_filter: select a filtered server mode; `filter` overrides it if given.
    use_templates: also download template hits per sequence.
    use_pairing: request a paired MSA (disables templates); `pairing_strategy`
      is "greedy" or "complete".
    host_url: MMseqs2 API server.
    user_agent: sent as the User-Agent header; a warning is logged when empty.

  Returns:
    A list with one a3m string per entry of `x` (duplicates share a result),
    or ``(a3m_lines, template_paths)`` when `use_templates` is set, where
    `template_paths` holds a directory or None per entry of `x`.

  Raises:
    Exception: if the server reports an error or maintenance.
    The underlying request exception after more than five non-timeout
    failures of a single request.
  """
  submission_endpoint = "ticket/pair" if use_pairing else "ticket/msa"
  headers = {}
  if user_agent != "":
    headers['User-Agent'] = user_agent
  else:
    logger.warning("No user agent specified. Please set a user agent (e.g., 'toolname/version contact@email') to help us debug in case of problems. This warning will become an error in the future.")

  def request_with_retry(method, url, timeout_message, server_name, **kwargs):
    """Call ``method(url, ...)`` until it returns a response.

    Timeouts are retried indefinitely. Any other exception is retried after a
    5 s pause and re-raised once it has happened more than five times; the
    counter lives outside the loop so that limit is actually reachable.
    """
    error_count = 0
    while True:
      try:
        # https://requests.readthedocs.io/en/latest/user/advanced/#advanced
        # "good practice to set connect timeouts to slightly larger than a multiple of 3"
        return method(url, timeout=6.02, headers=headers, **kwargs)
      except requests.exceptions.Timeout:
        logger.warning(timeout_message)
      except Exception as e:
        error_count += 1
        logger.warning(f"Error while fetching result from {server_name}. Retrying... ({error_count}/5)")
        logger.warning(f"Error: {e}")
        time.sleep(5)
        if error_count > 5:
          raise

  def parse_json(res):
    """Decode a JSON reply; a non-JSON reply is logged and mapped to status ERROR."""
    try:
      return res.json()
    except ValueError:
      logger.error(f"Server didn't reply with json: {res.text}")
      return {"status": "ERROR"}

  def submit(seqs, mode, N=101):
    # Sequences are numbered N, N+1, ... in the query FASTA.
    query = "".join(f">{n}\n{seq}\n" for n, seq in enumerate(seqs, start=N))
    res = request_with_retry(requests.post, f'{host_url}/{submission_endpoint}',
                             "Timeout while submitting to MSA server. Retrying...", "MSA server",
                             data={'q': query, 'mode': mode})
    return parse_json(res)

  def status(ID):
    res = request_with_retry(requests.get, f'{host_url}/ticket/{ID}',
                             "Timeout while fetching status from MSA server. Retrying...", "MSA server")
    return parse_json(res)

  def download(ID, path):
    res = request_with_retry(requests.get, f'{host_url}/result/download/{ID}',
                             "Timeout while fetching result from MSA server. Retrying...", "MSA server")
    with open(path, "wb") as out:
      out.write(res.content)

  # process input x
  seqs = [x] if isinstance(x, str) else x

  # compatibility to old option
  if filter is not None:
    use_filter = filter

  # setup mode
  if use_filter:
    mode = "env" if use_env else "all"
  else:
    mode = "env-nofilter" if use_env else "nofilter"

  if use_pairing:
    use_templates = False
    mode = ""
    # greedy is default, complete was the previous behavior
    if pairing_strategy == "greedy":
      mode = "pairgreedy"
    elif pairing_strategy == "complete":
      mode = "paircomplete"
    if use_env:
      mode = mode + "-env"

  # define path
  path = f"{prefix}_{mode}"
  os.makedirs(path, exist_ok=True)

  # call mmseqs2 api
  tar_gz_file = f'{path}/out.tar.gz'
  N, REDO = 101, True

  # deduplicate (keeping first-seen order) and map every input to its number
  seqs_unique = list(dict.fromkeys(seqs))
  number_of = {seq: N + i for i, seq in enumerate(seqs_unique)}
  Ms = [number_of[seq] for seq in seqs]

  if not os.path.isfile(tar_gz_file):
    TIME_ESTIMATE = 150 * len(seqs_unique)
    with tqdm(total=TIME_ESTIMATE, bar_format=TQDM_BAR_FORMAT) as pbar:
      while REDO:
        pbar.set_description("SUBMIT")

        # Resubmit job until it goes through
        out = submit(seqs_unique, mode, N)
        while out["status"] in ["UNKNOWN", "RATELIMIT"]:
          sleep_time = 5 + random.randint(0, 5)
          logger.error(f"Sleeping for {sleep_time}s. Reason: {out['status']}")
          time.sleep(sleep_time)
          out = submit(seqs_unique, mode, N)

        if out["status"] == "ERROR":
          raise Exception(f'MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence. If error persists, please try again an hour later.')

        if out["status"] == "MAINTENANCE":
          raise Exception(f'MMseqs2 API is undergoing maintenance. Please try again in a few minutes.')

        # wait for job to finish
        ID, TIME = out["id"], 0
        pbar.set_description(out["status"])
        while out["status"] in ["UNKNOWN", "RUNNING", "PENDING"]:
          t = 5 + random.randint(0, 5)
          logger.error(f"Sleeping for {t}s. Reason: {out['status']}")
          time.sleep(t)
          out = status(ID)
          pbar.set_description(out["status"])
          if out["status"] == "RUNNING":
            TIME += t
            pbar.update(n=t)

        if out["status"] == "COMPLETE":
          if TIME < TIME_ESTIMATE:
            pbar.update(n=(TIME_ESTIMATE-TIME))
          REDO = False

        if out["status"] == "ERROR":
          REDO = False
          raise Exception(f'MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence. If error persists, please try again an hour later.')

      # Download results
      download(ID, tar_gz_file)

  # prep list of a3m files
  if use_pairing:
    a3m_files = [f"{path}/pair.a3m"]
  else:
    a3m_files = [f"{path}/uniref.a3m"]
    if use_env:
      a3m_files.append(f"{path}/bfd.mgnify30.metaeuk30.smag30.a3m")

  # extract a3m files
  if any(not os.path.isfile(a3m_file) for a3m_file in a3m_files):
    with tarfile.open(tar_gz_file) as tar_gz:
      tar_gz.extractall(path)

  # templates: first column of pdb70.m8 is the sequence number, second the hit
  if use_templates:
    templates = {}
    with open(f"{path}/pdb70.m8", "r") as m8_file:
      for line in m8_file:
        p = line.rstrip().split()
        templates.setdefault(int(p[0]), []).append(p[1])

    template_paths = {}
    for k, TMPL in templates.items():
      TMPL_PATH = f"{path}/templates_{k}"
      if not os.path.isdir(TMPL_PATH):
        os.mkdir(TMPL_PATH)
        TMPL_LINE = ",".join(TMPL[:20])
        response = request_with_retry(requests.get, f"{host_url}/template/{TMPL_LINE}",
                                      "Timeout while submitting to template server. Retrying...",
                                      "template server", stream=True)
        with tarfile.open(fileobj=response.raw, mode="r|gz") as tar:
          tar.extractall(path=TMPL_PATH)
        os.symlink("pdb70_a3m.ffindex", f"{TMPL_PATH}/pdb70_cs219.ffindex")
        with open(f"{TMPL_PATH}/pdb70_cs219.ffdata", "w") as f:
          f.write("")
      template_paths[k] = TMPL_PATH

  # gather a3m lines, keyed by the sequence number in the block's first header
  a3m_lines = {}
  for a3m_file in a3m_files:
    update_M, M = True, None
    with open(a3m_file, "r") as a3m_handle:
      for line in a3m_handle:
        # a NUL byte starts a new block: the next header number becomes the key
        if "\x00" in line:
          line = line.replace("\x00", "")
          update_M = True
        if line.startswith(">") and update_M:
          M = int(line[1:].rstrip())
          update_M = False
          a3m_lines.setdefault(M, [])
        a3m_lines[M].append(line)

  # return results in input order
  a3m_lines = ["".join(a3m_lines[n]) for n in Ms]

  if use_templates:
    template_paths = [template_paths.get(n) for n in Ms]
    return a3m_lines, template_paths
  return a3m_lines


#########################################################################
# utils
#########################################################################
def get_hash(x):
  """Return the hex SHA-1 digest of the string `x` (UTF-8 encoded)."""
  digest = hashlib.sha1(x.encode())
  return digest.hexdigest()

def homooligomerize(msas, deletion_matrices, homooligomer=1):
  """Tile MSAs block-diagonally for a homo-oligomer with `homooligomer` copies.

  For copy k, every row is padded with "-" for the other copies and every
  deletion-matrix row with zeros. The copy index is the outer loop, so the
  result lists all MSAs for copy 0, then all for copy 1, and so on. With one
  copy the inputs are returned as-is (same objects).
  """
  if homooligomer == 1:
    return msas, deletion_matrices
  tiled_msas, tiled_mtxs = [], []
  for copy_idx in range(homooligomer):
    for msa, mtx in zip(msas, deletion_matrices):
      width = len(msa[0])
      pad_left = width * copy_idx
      pad_right = width * (homooligomer - copy_idx - 1)
      tiled_msas.append(["-" * pad_left + row + "-" * pad_right for row in msa])
      tiled_mtxs.append([[0] * pad_left + row + [0] * pad_right for row in mtx])
  return tiled_msas, tiled_mtxs

# Misspelled alias of homooligomerize(), kept for cross-compatibility.
def homooliomerize(msas, deletion_matrices, homooligomer=1):
  return homooligomerize(msas, deletion_matrices, homooligomer=homooligomer)

def homooligomerize_heterooligomer(msas, deletion_matrices, lengths, homooligomers):
  '''
  Expand MSAs of a hetero-complex so component k appears homooligomers[k] times.

  ----- inputs -----
  msas: list of msas; each row is the components concatenated in `lengths` order
  deletion_matrices: list of deletion matrices, parallel to `msas`
  lengths: list of lengths for each component in complex
  homooligomers: list of number of homooligomeric copies for each component
  ----- outputs -----
  (msas, deletion_matrices)

  Each row is cut into per-component fragments. Row 0 (the query) has each
  fragment repeated once per copy. A row with non-gap residues in exactly one
  component is emitted once per copy of that component, with gaps elsewhere.
  A row covering several components is emitted once per (copy of a, copy of b)
  for every pair a < b of covered components. All-gap rows are dropped.
  NOTE(review): a component of length 0 makes max() below fail on an empty
  list — confirm lengths are always positive.
  '''
  if max(homooligomers) == 1:
    return msas, deletion_matrices

  elif len(homooligomers) == 1:
    return homooligomerize(msas, deletion_matrices, homooligomers[0])

  else:
    # frag_ij[k] = [start, end) column range of component k within a row
    frag_ij = [[0,lengths[0]]]
    for length in lengths[1:]:
      j = frag_ij[-1][-1]
      frag_ij.append([j,j+length])

    # for every msa
    mod_msas, mod_mtxs = [],[]
    for msa, mtx in zip(msas, deletion_matrices):
      mod_msa, mod_mtx = [],[]
      # for every sequence
      for n,(s,m) in enumerate(zip(msa,mtx)):
        # split sequence; _ok[k] is True when fragment k has any non-gap residue
        _s,_m,_ok = [],[],[]
        for i,j in frag_ij:
          _s.append(s[i:j]); _m.append(m[i:j])
          _ok.append(max([o != "-" for o in _s[-1]]))

        if n == 0:
          # if first query sequence
          mod_msa.append("".join([x*h for x,h in zip(_s,homooligomers)]))
          mod_mtx.append(sum([x*h for x,h in zip(_m,homooligomers)],[]))

        elif sum(_ok) == 1:
          # elif one fragment: copy each fragment to every homooligomeric copy
          a = _ok.index(True)
          for h_a in range(homooligomers[a]):
            # [x]*h repeats one object h times; slots are only reassigned below,
            # never mutated in place, so the sharing is harmless.
            _blank_seq = [["-"*l]*h for l,h in zip(lengths,homooligomers)]
            _blank_mtx = [[[0]*l]*h for l,h in zip(lengths,homooligomers)]
            _blank_seq[a][h_a] = _s[a]
            _blank_mtx[a][h_a] = _m[a]
            mod_msa.append("".join(["".join(x) for x in _blank_seq]))
            mod_mtx.append(sum([sum(x,[]) for x in _blank_mtx],[]))
        else:
          # else: copy fragment pair to every homooligomeric copy pair
          # (rows with no covered fragment fall through without output)
          for a in range(len(lengths)-1):
            if _ok[a]:
              for b in range(a+1,len(lengths)):
                if _ok[b]:
                  for h_a in range(homooligomers[a]):
                    for h_b in range(homooligomers[b]):
                      _blank_seq = [["-"*l]*h for l,h in zip(lengths,homooligomers)]
                      _blank_mtx = [[[0]*l]*h for l,h in zip(lengths,homooligomers)]
                      for c,h_c in zip([a,b],[h_a,h_b]):
                        _blank_seq[c][h_c] = _s[c]
                        _blank_mtx[c][h_c] = _m[c]
                      mod_msa.append("".join(["".join(x) for x in _blank_seq]))
                      mod_mtx.append(sum([sum(x,[]) for x in _blank_mtx],[]))
      mod_msas.append(mod_msa)
      mod_mtxs.append(mod_mtx)
    return mod_msas, mod_mtxs

def chain_break(idx_res, Ls, length=200):
  """Shift residue indices after every chain boundary to mark chain breaks.

  With chain lengths `Ls`, indices after the first k chains are offset by
  ``k * length``. `idx_res` is modified in place (it must support slice
  ``+=``, e.g. a NumPy array) and also returned. (Minkyung's code.)
  """
  for boundary in np.cumsum(Ls[:-1]):
    idx_res[boundary:] += length
  return idx_res

def read_pdb_renum(pdb_filename, Ls=None):
  """Return the ATOM records of a PDB file with residues renumbered 1..N.

  A new residue starts whenever the chain ID (column 22), residue number
  (columns 23-26) or insertion code (column 27) changes. The first residue is
  always 1. The insertion-code column is blanked because the new numbers are
  already unique. Only lines starting with "ATOM" are kept.

  Args:
    pdb_filename: path to a PDB-format file.
    Ls: optional chain lengths; when given, residue k (0-based, in file
      order) gets the chain ID from `alphabet_list` that these lengths assign
      to it, replacing the file's chain IDs.

  Returns:
    The rewritten ATOM lines concatenated into one string.
  """
  if Ls is not None:
    new_chain = {}
    L_init = 0
    for L, c in zip(Ls, alphabet_list):
      new_chain.update({i: c for i in range(L_init, L_init + L)})
      L_init += L

  n, pdb_out = 0, []
  prev_key = None
  with open(pdb_filename, "r") as handle:
    for line in handle:
      if line[:4] != "ATOM":
        continue
      # Compare the raw columns instead of int(): int("  52A") would raise on
      # insertion codes, and the sentinel makes the first residue number 1.
      key = (line[21:22], line[22:26].strip(), line[26:27])
      if key != prev_key:
        prev_key = key
        n += 1
      if Ls is None:
        pdb_out.append("%s%4i %s" % (line[:22], n, line[27:]))
      else:
        pdb_out.append("%s%s%4i %s" % (line[:21], new_chain[n-1], n, line[27:]))
  return "".join(pdb_out)


def kabsch(a, b, weights=None, return_v=False):
  """Least-squares rotation R with ``a @ R ≈ b`` (Kabsch algorithm).

  `a` and `b` are 2-D arrays of row vectors with matching row counts. No
  centering is done here. `weights` gives one weight per row (default: all
  ones). The sign of the last left-singular vector is flipped when needed so
  the result is a proper rotation rather than a reflection.

  Returns:
    The rotation ``u @ vh``, or only ``u`` when `return_v` is true.
  """
  a = np.asarray(a, float)
  b = np.asarray(b, float)
  w = np.ones(len(b)) if weights is None else np.asarray(weights, float)
  covariance = np.einsum('ji,jk->ik', w[:, None] * a, b)
  u, _, vh = np.linalg.svd(covariance)
  if np.linalg.det(u @ vh) < 0:
    u[:, -1] *= -1
  return u if return_v else u @ vh

In [None]:
#@title MMseqs clean-up

from typing import List

def get_mmseqs_from_string(seq:str):
  """Run the MMseqs2 server search for `seq` and return its a3m split into lines."""
  a3m_per_sequence = run_mmseqs2(seq, "FrankenMSA", user_agent="FrankenMSA/jannik.gut@unibe.ch")
  first_a3m = a3m_per_sequence[0]
  return first_a3m.split("\n")

def dump_msa_to_file(msa:List[str], path:Path):
  """Write the lines of `msa` to `path`, newline-separated, without a trailing newline."""
  with open(path, "w") as handle:
    handle.write("\n".join(msa))

def read_msa_from_file(path:Path):
  """Return the whitespace-stripped, non-empty lines of the file at `path`."""
  with open(path, "r") as handle:
    content = handle.read()
  stripped = (raw.strip() for raw in content.split("\n"))
  return [row for row in stripped if row]

def sequences_from_msa(msa:List[str]):
  """Drop header lines (those starting with ">") and return the remaining lines."""
  return list(filter(lambda row: not row.startswith(">"), msa))

def process_msa(msa:List[str], depth:int=128):
  """Reformat and filter an MSA with HH-suite, keeping at most `depth` sequences.

  Args:
    msa: a3m lines (headers and sequences) or a Path to an a3m file. Line 1
      (the query) sets the value passed to reformat.pl's ``-l`` option.
    depth: passed to ``hhfilter -diff``; the first ``2*depth`` lines
      (header/sequence pairs) of the filtered file are kept.

  Returns:
    The processed a3m as a list of stripped, non-empty lines.

  Raises:
    subprocess.CalledProcessError: if any pipeline step fails. The output of
      all steps is written to the file ``out``.
  """
  import subprocess  # local import: only this helper shells out from Python

  if isinstance(msa, Path):
    msa = read_msa_from_file(msa)
  length = len(msa[1])
  input_path = "msa_input.a3m"
  formated_path = "msa_formated.a3m"
  cut_path = "msa_cut.a3m"
  output_path = "msa_output.a3m"
  # The file names are fixed, so remove leftovers from an earlier call; a
  # failing step must not silently hand back a previous result.
  for stale in (formated_path, cut_path, output_path):
    if os.path.exists(stale):
      os.remove(stale)
  dump_msa_to_file(msa, Path(input_path))
  # `set -e` aborts on the first failing step so the exit status reports it.
  sh = f"""set -e
perl /content/scripts/reformat.pl -r -l {length} {input_path} {formated_path}
/content/bin/hhfilter -M first -diff {depth} -i {formated_path} -o {cut_path}
head -n {2*depth} {cut_path}>{output_path}
"""
  with open('process_msa_script.sh', 'w') as file:
    file.write(sh)
  with open("out", "w") as log:
    subprocess.run(["bash", "process_msa_script.sh"], stdout=log, stderr=subprocess.STDOUT, check=True)
  return read_msa_from_file(output_path)

def run_mmseqs_pipeline(seq:str, depth:int=128):
  """Build an MMseqs2 MSA for `seq` and return exactly ``2*depth`` a3m lines.

  The server result is filtered by process_msa(). If fewer than `depth`
  header/sequence pairs remain, the MSA is repeated from the top and then
  cut to ``2*depth`` lines.

  Raises:
    ValueError: if the processed MSA is empty.
  """
  temp_msa = get_mmseqs_from_string(seq)
  msa = process_msa(temp_msa, depth=depth)
  if not msa:
    raise ValueError("MMseqs2 pipeline produced an empty MSA.")
  target_lines = 2 * depth
  # Ceiling division gives enough whole copies; every path returns a list,
  # including an MSA that is already at full depth.
  copies = -(-target_lines // len(msa))
  return (msa * copies)[:target_lines]

In [None]:
#@title Simpler operations

def concat_horizontally(msa1:List[str], msa2:List[str]):
  """Append row i of `msa2` to row i of `msa1` in place and return `msa1`.

  Extra rows in `msa2` are ignored; a shorter `msa2` raises IndexError.
  """
  for row_idx, _ in enumerate(msa1):
    msa1[row_idx] += msa2[row_idx]
  return msa1

def select_horizontally(msa:List[str], start:int, end:int):
  """Return a new list holding the column slice ``[start:end]`` of every row."""
  return [row[start:end] for row in msa]

def repeat_until(msa:List[str], depth:int=128) -> List[str]:
  """Repeat the a3m lines `msa` from the top and cut to ``2*depth`` lines."""
  pairs_available = len(msa) / 2
  copies = int(np.ceil(depth / pairs_available))
  target = 2 * depth
  return (msa * copies)[:target]

def fill_gaps_until(msa:List[str], depth:int=128) -> List[str]:
  """Pad `msa` with all-gap entries until it holds `depth` header/sequence pairs.

  Each padding pair is ">gap_seq" plus a run of "-" as long as ``msa[1]``.
  Always returns a new list; nothing is added if `msa` is already deep enough.
  """
  width = len(msa[1])
  missing = int(math.ceil(depth - len(msa) / 2))
  padding = [">gap_seq", '-' * width] * missing
  return msa + padding

In [None]:
#@title Download ProteinMPNN
import json, time, os, sys, glob

# Clone ProteinMPNN only if the directory is not there yet, then add the
# checkout to sys.path so the next cell can import protein_mpnn_utils.
if not os.path.isdir("ProteinMPNN"):
  os.system("git clone -q https://github.com/dauparas/ProteinMPNN.git")
sys.path.append('/content/ProteinMPNN')

In [None]:
#@title ProteinMPNN
import torch
import copy
from protein_mpnn_utils import tied_featurize, parse_PDB, ProteinMPNN, StructureDataset, StructureDatasetPDB, _scores, _S_to_seq
from pathlib import Path
from typing import Tuple, List


def load_model(model_name=None):
  """Load a vanilla ProteinMPNN checkpoint and return the model in eval mode.

  Args:
    model_name: checkpoint stem in /content/ProteinMPNN/vanilla_model_weights
      (e.g. "v_48_020"). If omitted, the notebook-level `model_name` (set from
      the form) is used, falling back to "v_48_020".

  Returns:
    The ProteinMPNN model on "cuda:0" when CUDA is available, else on CPU.
  """
  if model_name is None:
    # Honour the form selection instead of always loading one checkpoint.
    model_name = globals().get("model_name", "v_48_020")
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  backbone_noise = 0.00
  hidden_dim = 128
  num_layers = 3
  model_folder_path = '/content/ProteinMPNN/vanilla_model_weights'
  checkpoint_path = os.path.join(model_folder_path, f'{model_name}.pt')
  # weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
  checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
  model = ProteinMPNN(num_letters=21, node_features=hidden_dim, edge_features=hidden_dim, hidden_dim=hidden_dim, num_encoder_layers=num_layers, num_decoder_layers=num_layers, augment_eps=backbone_noise, k_neighbors=checkpoint['num_edges'])
  model.to(device)
  model.load_state_dict(checkpoint['model_state_dict'])
  model.eval()
  return model

def make_tied_positions_for_homomers(pdb_dict_list):
  """Tie each residue position across all chains of every parsed structure.

  For each entry, the chains are the sorted last characters of its
  "seq_chain*" keys. Positions 1..len(first sorted chain) each map to
  ``{chain: [position] for every chain}``.

  Returns:
    ``{entry['name']: [per-position dicts]}``.
  """
  tied = {}
  for entry in pdb_dict_list:
    chains = sorted(key[-1:] for key in entry if key[:9] == 'seq_chain')
    first_len = len(entry[f"seq_chain_{chains[0]}"])
    tied[entry['name']] = [{chain: [pos] for chain in chains} for pos in range(1, first_len + 1)]
  return tied

## proteinMPNN params
max_length=20000  # passed as max_length to StructureDatasetPDB
BATCH_COPIES = 1  # batch size: clones of the structure sampled per model.sample call
pssm_threshold = 0  # cutoff used to build pssm_log_odds_mask from pssm_log_odds_all
pssm_multi = 0  # forwarded to model.sample
pssm_log_odds_flag =0  # forwarded to model.sample as a bool
pssm_bias_flag = 0  # forwarded to model.sample as a bool

def _order_designed_chains(seq, masked_chain_length_list, masked_list):
  """Split `seq` into per-chain segments, sort them by `masked_list`, join with "/".

  `seq` is the designed chains concatenated in `masked_list` order, with
  segment lengths given by `masked_chain_length_list`. A single designed
  chain is returned unchanged.
  """
  segments, offset = [], 0
  for chain_len in masked_chain_length_list:
    segments.append(seq[offset:offset + chain_len])
    offset += chain_len
  return "/".join(segments[i] for i in np.argsort(masked_list))

def generate_proteinMPNN(input_path:Path, input_chain_list:List[str]=None, sampling_temp:float=1.0, num_seqs:int=128):
  """Sample sequences for a structure with ProteinMPNN and return them as a3m lines.

  Args:
    input_path: PDB file to design.
    input_chain_list: chains to design; defaults to ["A"]. No chain is fixed.
    sampling_temp: ProteinMPNN sampling temperature.
    num_seqs: number of sequences, rounded down to a multiple of BATCH_COPIES.

  Returns:
    A flat list alternating headers
    ``">T=<temp>, sample=<i>, seq_recovery=<rate>"`` and sampled sequences,
    for the first structure in the parsed dataset.

  Raises:
    ValueError: if no structure could be parsed for the requested chains.
  """
  if input_chain_list is None:
    input_chain_list = ["A"]
  num_batches = num_seqs // BATCH_COPIES
  alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
  bias_AAs_np = np.zeros(len(alphabet))
  # never sample the unknown residue "X"
  omit_AAs_np = np.array([AA in "X" for AA in alphabet]).astype(np.float32)
  tied_positions_dict = None  # no tying of positions across chains
  with torch.no_grad():
    model = load_model()
    device = next(model.parameters()).device
    pdb_dict_list = parse_PDB(input_path, input_chain_list=input_chain_list)
    if not pdb_dict_list:
      raise ValueError(f"No structure parsed from {input_path} for chains {input_chain_list}")
    dataset_valid = StructureDatasetPDB(pdb_dict_list, truncate=None, max_length=max_length)
    # (designed chains, fixed chains)
    chain_id_dict = {pdb_dict_list[0]['name']: (input_chain_list, [])}
    for protein in dataset_valid:
      batch_clones = [copy.deepcopy(protein) for _ in range(BATCH_COPIES)]
      (X, S, mask, lengths, chain_M, chain_encoding_all, chain_list_list,
       visible_list_list, masked_list_list, masked_chain_length_list_list,
       chain_M_pos, omit_AA_mask, residue_idx, dihedral_mask,
       tied_pos_list_of_lists_list, pssm_coef, pssm_bias, pssm_log_odds_all,
       bias_by_res_all, tied_beta) = tied_featurize(
          batch_clones, device, chain_id_dict, None, None, tied_positions_dict, None, None)
      pssm_log_odds_mask = (pssm_log_odds_all > pssm_threshold).float()  # 1.0 for true, 0.0 for false
      mask_for_loss = mask*chain_M*chain_M_pos
      sample_kwargs = dict(
          mask=mask, temperature=sampling_temp, omit_AAs_np=omit_AAs_np, bias_AAs_np=bias_AAs_np,
          chain_M_pos=chain_M_pos, omit_AA_mask=omit_AA_mask, pssm_coef=pssm_coef,
          pssm_bias=pssm_bias, pssm_multi=pssm_multi,
          pssm_log_odds_flag=bool(pssm_log_odds_flag), pssm_log_odds_mask=pssm_log_odds_mask,
          pssm_bias_flag=bool(pssm_bias_flag), bias_by_res=bias_by_res_all)
      out = []
      for _ in range(num_batches):
        randn_2 = torch.randn(chain_M.shape, device=X.device)
        if tied_positions_dict is None:
          sample_dict = model.sample(X, randn_2, S, chain_M, chain_encoding_all, residue_idx, **sample_kwargs)
        else:
          sample_dict = model.tied_sample(X, randn_2, S, chain_M, chain_encoding_all, residue_idx,
                                          tied_pos=tied_pos_list_of_lists_list[0], tied_beta=tied_beta,
                                          **sample_kwargs)
        S_sample = sample_dict["S"]
        for b_ix in range(BATCH_COPIES):
          # fraction of designed positions where the sample equals the native residue
          seq_recovery_rate = torch.sum(torch.sum(torch.nn.functional.one_hot(S[b_ix], 21)*torch.nn.functional.one_hot(S_sample[b_ix], 21),axis=-1)*mask_for_loss[b_ix])/torch.sum(mask_for_loss[b_ix])
          seq = _S_to_seq(S_sample[b_ix], chain_M[b_ix])
          seq = _order_designed_chains(seq, masked_chain_length_list_list[b_ix], masked_list_list[b_ix])
          seq_rec_print = np.format_float_positional(np.float32(seq_recovery_rate.detach().cpu().numpy()), unique=False, precision=4)
          out.append(f">T={sampling_temp}, sample={b_ix}, seq_recovery={seq_rec_print}")
          out.append(seq)
      return out
  raise ValueError(f"No structure parsed from {input_path} for chains {input_chain_list}")

In [None]:
#@title Generate things

# Assemble the FrankenMSA: each confirmed part yields MSA_depth row fragments
# for its [start:end) columns, which are appended column-wise to output_msa.
full_mmseqs = None
full_inverse_folding = {}
output_msa = [""]*MSA_depth
prepared_own_msa = None
# Loop variables are named part_* so they do not overwrite the globals
# `part_type`/`start`/`end` that the part-definition widgets rely on.
for (part_kind, part_start, part_end, other) in parts:
  query_part = query_sequence[part_start:part_end]
  if part_kind == "Gaps":
    curr_msa = fill_gaps_until([None, query_part], MSA_depth)[1::2]
  elif part_kind == "Repeat":
    if other is None:
      curr_msa = repeat_until([None, query_part], MSA_depth)[1::2]
    else:
      # query fragment on row 0, the custom text on all remaining rows
      curr_msa = [query_part] + repeat_until([other], MSA_depth-1)[1::2]
  elif part_kind == "MMseqs":
    if other == "Full":
      # search the full query once and reuse it for every "Full" part
      if full_mmseqs is None:
        full_mmseqs = run_mmseqs_pipeline(query_sequence, MSA_depth)
      curr_msa = select_horizontally(full_mmseqs, part_start, part_end)[1::2]
    else:
      curr_msa = run_mmseqs_pipeline(query_part, MSA_depth)[1::2]
  elif part_kind == "Inverse folding":
    if pdb_path is None:
      raise ValueError("An 'Inverse folding' part needs a structure: set PDB_to_inverse_fold in the first cell.")
    # one ProteinMPNN run per distinct temperature
    if full_inverse_folding.get(other) is None:
      full_inverse_folding[other] = generate_proteinMPNN(pdb_path, ["A"], other, MSA_depth)
    curr_msa = select_horizontally(full_inverse_folding[other], part_start, part_end)[1::2]
  elif part_kind == "Own MSA":
    if msa_path is None:
      raise ValueError("An 'Own MSA' part needs an uploaded MSA: set Own_MSA in the first cell.")
    if prepared_own_msa is None:
      # filter to the requested depth and repeat short MSAs so every row exists
      prepared_own_msa = repeat_until(process_msa(Path(msa_path), depth=MSA_depth), MSA_depth)
    curr_msa = select_horizontally(prepared_own_msa, part_start, part_end)[1::2]
  else:
    raise ValueError(f"Unknown part type {part_kind}")
  output_msa = concat_horizontally(output_msa, curr_msa)
output_msa