# Predict the ProteinMPNN-optimised sequences with AF2 using partially masked templates
*This notebook runs the ColabDesign implementation of AF2.*

## 1.0 Libraries

In [None]:
import jax, os, copy, pickle, glob, datetime, shutil
import jax.numpy as jnp
import numpy as np

from colabdesign import mk_af_model

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from tqdm.notebook import tqdm

from colabdesign.af.alphafold.common import protein
from colabdesign.shared.protein import renum_pdb_str
from colabdesign.af.alphafold.common import residue_constants

from colabdesign.af.loss import *
from colabdesign.af.loss import _get_pw_loss

from Bio.PDB import PDBParser
from Bio import SeqUtils
from scipy.spatial.distance import cdist

## 2.0 Functions

In [None]:
def rank_array_predict(input_array):
    """Return the indices of `input_array` ordered from the largest to the smallest value."""
    # argsort yields ascending order; flip it before converting to a python list
    ascending = np.argsort(input_array)
    return list(ascending[::-1])

def rank_and_write_pdb_predict(af_model, name, write_all=False, renum_pdb = True):
    """
    Rank the predicted models by their mean pLDDT and write them to disk as PDB files
    :params:
        - af_model  : the model object after prediction; af_model.aux['all'] and af_model._lengths are read
        - name      : output path prefix, the files are named '{name}_model_{n}_rank_{m}.pdb'
        - write_all : write every model if truthy, otherwise only the top ranked one, DEFAULT = False
        - renum_pdb : pass the PDB string through renum_pdb_str before writing, DEFAULT = True
    :returns:
        list with the paths of the written PDB files, best rank first
    """
    aux = af_model.aux["all"]

    # model indices sorted by mean pLDDT, highest first
    ranking = rank_array_predict(np.mean(aux["plddt"], -1))
    if not write_all:
        ranking = ranking[:1]

    p = {k:aux[k] for k in ["aatype","residue_index","atom_positions","atom_mask"]}
    # store the pLDDT (scaled by 100) in the b-factor column of every existing atom
    p["b_factors"] = 100 * p["atom_mask"] * aux["plddt"][...,None]

    def to_pdb_str(x):
        p_str = protein.to_pdb(protein.Protein(**x))
        # drop the first line and the last two lines produced by protein.to_pdb
        p_str = "\n".join(p_str.splitlines()[1:-2])
        if renum_pdb: p_str = renum_pdb_str(p_str, af_model._lengths)
        return p_str

    pdbs_out = []

    for m, n in enumerate(ranking, start=1):
        # jax.tree_util.tree_map is the supported spelling; the jax.tree_map alias is deprecated
        p_str = to_pdb_str(jax.tree_util.tree_map(lambda x: x[n], p))
        # the joined string has no trailing newline; add one so END sits on its own line
        if not p_str.endswith("\n"):
            p_str += "\n"
        p_str += "END\n"

        out_path = f"{name}_model_{n}_rank_{m}.pdb"
        with open(out_path, 'w') as f:
            f.write(p_str)
        pdbs_out.append(out_path)
    return pdbs_out

In [None]:
def get_interface_info(pdb:str, binder_chain:str, thresh:float=5.0) -> tuple:
    """
    This function is to return the hotspot residues of a certain structure together with
    the chain information needed to describe the binder and the target
    :params:
        - pdb          : the pdb to be processed (only the first model is read)
        - binder_chain : the binder chain ID(s), one character per chain; every other chain is treated as target
        - thresh       : Distance threshold below which a residue would be considered as a hotspot, DEFAULT = 5.0 A
    :returns: a tuple with, in this order
        - hotspots             : sorted list of binder residue numbers at the interface
        - binder_first_res_num : residue number of the first residue of the first binder chain
        - binder_chain_length  : number of residues of the first binder chain
        - binder chain ID      : the first binder chain ID
        - target chain ID      : the first chain that is not a binder chain
        - target_first_res_num : residue number of the first residue of that target chain
        - target_last_res_num  : residue number of the last residue of that target chain
    """

    # load the pdb file; the first model is used for chains, numbering and atoms alike
    model = PDBParser(QUIET=True).get_structure(os.path.basename(pdb), pdb)[0]

    # define the chains
    binder_chains = list(binder_chain)
    target_chains = [x.id for x in model.get_chains() if x.id not in binder_chains]

    # get binder length
    binder_chain_length = len(model[binder_chains[0]])

    # get the first/last residue numbers
    binder_res = list(model[binder_chains[0]].get_residues())
    target_res = list(model[target_chains[0]].get_residues())
    binder_first_res_num = int(binder_res[0].id[1])
    target_first_res_num = int(target_res[0].id[1])
    target_last_res_num = int(target_res[-1].id[1])

    def heavy_atoms(chain_ids):
        # Hydrogens are recognised by their element. Testing for 'H' in the atom name
        # also discarded heavy atoms such as Tyr OH, Arg NH1/NH2 or Trp CH2.
        return [atom
                for chain in model.get_chains() if chain.id in chain_ids
                for atom in chain.get_atoms() if atom.element != 'H']

    target_heavy = heavy_atoms(target_chains)
    binder_heavy = heavy_atoms(binder_chains)

    # get the atom coords
    target_atoms = np.array([atom.get_coord() for atom in target_heavy])
    binder_atoms = np.array([atom.get_coord() for atom in binder_heavy])

    # map binder atoms to their residue numbers
    binder_residues = np.array([atom.get_parent().id[1] for atom in binder_heavy])

    # generate the distance matrix (rows: target atoms, columns: binder atoms)
    dists = cdist(target_atoms, binder_atoms)

    # for every target atom, the index of its closest binder atom
    closest_binder_atoms_calc = np.argmin(dists, axis=1)

    # keep the closest binder atoms that are within the targeted threshold
    within_thresh = dists[np.arange(len(dists)), closest_binder_atoms_calc] < thresh
    closest_binder_atoms = closest_binder_atoms_calc[within_thresh]

    # unique, sorted residue numbers of those atoms as plain python ints
    hotspots = sorted({int(item) for item in binder_residues[closest_binder_atoms]})

    return hotspots, binder_first_res_num, binder_chain_length, binder_chains[0], target_chains[0], target_first_res_num, target_last_res_num

In [None]:
def extract_target_info(path:str, target_chain:str):
    """ This function is to return the sequence of chain `target_chain` (first model) of the pdb at `path` """
    structure = PDBParser(QUIET=True).get_structure(os.path.basename(path), path)
    chain = structure[0][target_chain]
    # concatenate the residue names and let SeqUtils.seq1 convert them
    res_names = "".join(residue.get_resname() for residue in chain.get_residues())
    return SeqUtils.seq1(res_names)

In [None]:
def plot_af_scores(df:pd.DataFrame, outpath:str, iden:str, export:bool=False, show:bool=False):
    """
    This function is to draw one violin plot per AF2 multimer score, side by side
    :params:
        - df      : dataframe holding the columns multimer_plddt, multimer_ptm, multimer_iptm and multimer_complex_rmsd
        - outpath : folder in which the figures are saved when export is set
        - iden    : identifier shown in the title (after the number of rows) and used in the file name
        - export  : save the figure as png and svg, DEFAULT = False
        - show    : show the figure, otherwise it is closed, DEFAULT = False
    """

    # column name -> y axis label
    labels = {
        'multimer_plddt':'pLDDT',
        'multimer_ptm' : 'pTM',
        'multimer_iptm': 'Interface pTM',
        'multimer_complex_rmsd':'RMSD (A)'
               }

    fig, axs = plt.subplots(nrows=1, ncols=len(labels), figsize=[15,6])

    for ax, (column, y_label) in zip(axs, labels.items()):
        sns.violinplot(y=column, data=df, ax=ax, inner='points', color='#DCDCDC')
        ax.set_ylabel(y_label, fontsize=12, labelpad=10.0)

    plt.suptitle(f"""{df.shape[0]} {iden}
    """, fontsize=16, fontweight='bold')
    plt.tight_layout(h_pad=5.0, w_pad=3.0)

    if export:
        # both files share the same dated stem
        stem = f'{outpath}/{datetime.date.today()}_{iden.lower().replace(" ", "_")}'
        plt.savefig(f'{stem}.png', format="png", dpi=300, transparent=True)
        plt.savefig(f'{stem}.svg', format="svg", dpi=300, transparent=False)

    if not show:
        plt.close()
    else:
        plt.show()

    return None

## 3.0 Variables to set

In [None]:
# set variables (Modify this only)
# folder with the ProteinMPNN output; its 'analysis' sub-folder is looked up in section 4.0
mpnn_output_folder = "{PATH_to_'sol_mpnn_{strategy_name}_output'}, e.g., ./sol_mpnn_fixed_interface/"
# folder with the input models; every '*.pdb' file directly inside it is parsed
target_models      = "{PATH_TO_ROSETTA_INPUT_MODELS}"
# handed to mk_af_model as `data_dir`
params_path        = "{PATH_TO_AF2_'params'_FOLDER}"

# binder chain ID, handed to get_interface_info as `binder_chain`
binder_chain_input = 'B'
# distance cutoff, handed to get_interface_info as `thresh`
interface_thresh   = 5.0
# designs with multimer_iptm >= this value are kept as best designs
ipTM_thresh        = 0.7

## 4.0 I/O

In [None]:
# find the mpnn output analysis folder
# Only the sub-folder name is tested: checking the whole path would match every
# sub-folder as soon as a parent directory contains "analysis".
mpnn_output_analysis_folder = [
    p for p in glob.iglob(os.path.join(mpnn_output_folder, "*"))
    if os.path.isdir(p) and "analysis" in os.path.basename(p)
]
assert len(mpnn_output_analysis_folder) == 1, "File finding failure"

mpnn_output_analysis_folder = mpnn_output_analysis_folder[0]
print(mpnn_output_analysis_folder)

In [None]:
# output tree:
#   af2_mulltimer_pred/                  (main output folder)
#     af2_multimer_pred_output/          (af pred output)
#       pdbs/                            (pdb output)
#       scores/                          (score output)
#     af2_multimer_pred_analysis/        (af pred analysis output)
#       best_pdbs/                       (best pdbs out)
output           = os.path.join(mpnn_output_analysis_folder, "af2_mulltimer_pred")
af_pred          = os.path.join(output, 'af2_multimer_pred_output')
pdb_out          = os.path.join(af_pred, 'pdbs')
sc_out           = os.path.join(af_pred, 'scores')
af_pred_analysis = os.path.join(output, 'af2_multimer_pred_analysis')
best_pdb_out     = os.path.join(af_pred_analysis, 'best_pdbs')

# create the folders parent first; exist_ok is off, so an already existing folder raises FileExistsError
for folder in (output, af_pred, pdb_out, sc_out, af_pred_analysis, best_pdb_out):
    os.makedirs(folder, exist_ok=False)

## 5.0 Predicting

In [None]:
# collect the input structures: every '*.pdb' file directly inside `target_models`
pdb_pattern = os.path.join(target_models, '*.pdb')
target_models_paths = [found.strip() for found in glob.iglob(pdb_pattern)]
print(f'{len(target_models_paths)} Parsed input files')

In [None]:
# locate the af monomer selection csv, one folder level below the analysis folder
af_monomer_selection_file = []
for candidate in glob.iglob(os.path.join(mpnn_output_analysis_folder, "*", "*.csv")):
    # both markers have to be present in the path
    if "af2_monomer_pred_analysis" in candidate and "af_monomer_selection" in candidate:
        af_monomer_selection_file.append(candidate.strip())
assert len(af_monomer_selection_file) == 1, "File finding error"

# the first csv column is the index
df_parsed = pd.read_csv(af_monomer_selection_file[0], index_col=0)
print(f'{df_parsed.shape[0]} Parsed sequences')

In [None]:
# map the mpnn to the source models
# map_dic: absolute path of the source model -> list of (label without '>', sequence)
map_dic = {}

for parent, group in tqdm(df_parsed.groupby(by='parent'), desc="Mapping models"):
    # the parent name (with 'x' replaced by '_') is searched as a substring of the model paths;
    # the first hit is used
    matches = [os.path.abspath(p) for p in target_models_paths if parent.replace('x','_') in p]
    assert matches, f"No input model found for parent '{parent}'"
    path_ = matches[0]
    mpnn_seqs = [(d.replace('>',''),s) for d,s in zip(group.label, group.sequence)]
    assert group.shape[0] == len(mpnn_seqs), 'Parsing error'
    map_dic[path_] = mpnn_seqs

# `len(...) != None` could never fail; check that at least one model was mapped
assert len(map_dic) != 0, "Model mapping failure"

In [None]:
# Predict
print('Predicting...')
for pdb, mpnned in tqdm(map_dic.items()):
    # get target information
    (design_pos, binder_first_res_num, binder_chain_length, binder_chain,
     target_chain, target_first_res_num, target_last_res_num) = get_interface_info(
        pdb=pdb, binder_chain=binder_chain_input, thresh=interface_thresh)

    # the target sequence is the same for every design of this model, so it is extracted once
    target_seq = extract_target_info(path=pdb, target_chain=target_chain)

    # the positions we want to keep fixed to AF2seq
    # The binder range ends at first + length - 1, so it also holds when the binder
    # numbering does not start at 1 (assumes contiguous residue numbering - TODO confirm).
    design_set = set(design_pos)
    binder_last_res_num = binder_first_res_num + binder_chain_length - 1
    fix_pos = [f"{binder_chain}{i}"
               for i in range(binder_first_res_num, binder_last_res_num + 1)
               if i not in design_set]
    fix_pos += [f"{target_chain}{i}" for i in range(target_first_res_num, target_last_res_num + 1)]
    # get the fixed positions
    fix_pos = ','.join(fix_pos)

    # initialise af
    af_model = mk_af_model(protocol='fixbb', use_templates=True, initial_guess=False,data_dir=params_path)
    af_model.prep_inputs(pdb_filename=pdb, chain=f'{target_chain},{binder_chain}', fix_pos=fix_pos)

    # Mask design positions from template
    # The inputs are prepared in the order target,binder, so a binder residue sits
    # len(target) rows after its position within the binder chain. The former `j-1`
    # index is only equivalent when the binder numbering continues the target numbering.
    # NOTE(review): assumes one row per residue of the target sequence and contiguous binder numbering - confirm.
    binder_offset = len(target_seq)
    for j in design_pos:
        row = binder_offset + (j - binder_first_res_num)
        af_model._inputs['batch']['all_atom_mask'][row,:] = np.zeros_like(af_model._inputs['batch']['all_atom_mask'][row,:])

    for name_, seq_ in mpnned:
        # get the complex sequence from the pdb Target + Design
        cmplx_seq = target_seq + seq_
        # run the prediction
        af_model.set_seq(cmplx_seq)
        af_model.predict(num_recycles=3, models = [0,1], num_models=2)
        pdbs = rank_and_write_pdb_predict(af_model, name=os.path.join(pdb_out,name_))

        # get af scores (each value is the mean over everything stored in aux['all'])
        aux_all = af_model.aux['all']
        af_sc = pd.DataFrame({'design':name_,
                              'multimer_plddt':aux_all['plddt'].mean(),
                              'multimer_ptm':aux_all['ptm'].mean(),
                              'multimer_iptm':aux_all['i_ptm'].mean(),
                              'multimer_complex_rmsd':aux_all['losses']['rmsd'].mean()}, index=[0])
        af_sc.to_csv(os.path.join(sc_out, name_+'_scores.csv'))

## 6.0 Predictions analysis

In [None]:
# Parse af model scores: read every per-design csv, stack them and sort by iptm (best first)
sc_paths = [csv_path.strip() for csv_path in glob.iglob(os.path.join(sc_out, '*.csv'))]
score_frames = [pd.read_csv(csv_path, index_col=0) for csv_path in sc_paths]
df_af_sc = (
    pd.concat(score_frames)
    .sort_values(by='multimer_iptm', ascending=False)
    .reset_index(drop=True)
)
df_af_sc.to_csv(os.path.join(af_pred_analysis, 'all_af2_multimer_models_scores.csv'))
print(f'{df_af_sc.shape[0]} Parsed Designs')

In [None]:
# Get best designs based on the iptm scores (iptm at or above the threshold)
passes_iptm = df_af_sc.multimer_iptm >= ipTM_thresh
df_af_sel = df_af_sc.loc[passes_iptm].copy().reset_index(drop=True)
df_af_sel.to_csv(os.path.join(af_pred_analysis, 'best_af2_multimer_models_scores.csv'))
print(f'{df_af_sel.shape[0]} Best designs')

In [None]:
# merge with all scores and export a full info file

# formatting: rebuild the '>'-prefixed label used as merge key
df_af_sel["label"] = [">" + design for design in df_af_sel["design"]]

# Merging: every selected design has to be found among the parsed sequences
df_merged = df_parsed.merge(df_af_sel, on="label", how="inner").reset_index(drop=True)
assert df_merged.shape[0] == df_af_sel.shape[0], "Merging error"

df_merged.to_csv(os.path.join(af_pred_analysis, 'af2_multimer_selection.csv'))

In [None]:
# Get the best designs
pdb_paths = list(glob.iglob(os.path.join(pdb_out,'*.pdb')))

for des in tqdm(df_af_sel.design, desc='Copying best designs models'):
    # The predicted files are named '{design}_model_{n}_rank_{m}.pdb', so the file name has to
    # start with '{design}_model_'. A plain substring test would let a design 'x_1' pick up
    # the file of 'x_10'.
    prefix = f"{des}_model_"
    matches = sorted(p for p in pdb_paths if os.path.basename(p).startswith(prefix))
    assert matches, f"No predicted model found for design '{des}'"
    path_ = matches[0]
    shutil.copy(path_, os.path.join(best_pdb_out, os.path.basename(path_)))

In [None]:
# Plot scores: first all parsed designs, then the selected ones
print('Ploting...')
for frame, title in ((df_af_sc, 'Parsed MPNN Designs'), (df_af_sel, 'Best MPNN Designs')):
    plot_af_scores(df=frame, iden=title, outpath=af_pred_analysis, export=True, show=True)