## Inpaint a region of a protein to design new sequences and fold them 

### For a sequence of interest with the region to be inpainted enclosed in "[]"
 - e.g. input: CASTTFG\[DASTT\]KFTDYD
 1. Use **ESMfold** to predict the structure of the original sequence
 2. Use **RFdiffusion** to inpaint the region selected (3 designs)
 3. Use **ProteinMPNN** to infer sequences of these new designs (3 sequences for each design, 9 total sequences)
 4. Use **ESMfold** to predict the new structures for each designed sequence (9 new structures)

In [0]:
%pip install biopython==1.79
%pip install py3Dmol
%pip install mlflow>=2.18
dbutils.library.restartPython()

In [0]:
import os
import requests
import json
import mlflow

In [0]:
# Read the API token and API URL from the notebook context exposed by dbutils;
# both are used by hit_model_endpoint to build authenticated requests.
_notebook_context = dbutils.notebook.entry_point.getDbutils().notebook().getContext()
DATABRICKS_TOKEN = _notebook_context.apiToken().get()
DATABRICKS_URL = _notebook_context.apiUrl().get()

def hit_model_endpoint(endpoint_name, inputs, timeout=None):
    """POST ``inputs`` to a serving endpoint and return the decoded JSON body.

    Args:
        endpoint_name: Name of the serving endpoint; interpolated into
            ``{DATABRICKS_URL}/serving-endpoints/<name>/invocations``.
        inputs: JSON-serialisable payload, sent under the ``"inputs"`` key.
        timeout: Optional timeout forwarded to ``requests`` (seconds, or a
            ``(connect, read)`` tuple). The default ``None`` waits
            indefinitely, which is what this function always did before the
            parameter existed.

    Returns:
        The response body parsed as JSON.

    Raises:
        Exception: If the endpoint answers with a status code other than 200.
    """
    url = f'{DATABRICKS_URL}/serving-endpoints/{endpoint_name}/invocations'
    headers = {'Authorization': f'Bearer {DATABRICKS_TOKEN}', 'Content-Type': 'application/json'}
    # Serialise explicitly (rather than via requests' ``json=``) so NaN values
    # in the payload are still accepted by the encoder.
    payload = json.dumps({'inputs': inputs}, allow_nan=True)
    response = requests.post(url, headers=headers, data=payload, timeout=timeout)
    if response.status_code != 200:
        raise Exception(f'Request failed with status {response.status_code}, {response.text}')
    return response.json()

@mlflow.trace(span_type="LLM")
def hit_esmfold(sequence):
    """Send one sequence to the 'esmfold' endpoint and return its first prediction."""
    # The endpoint takes a batch, so the single sequence is wrapped in a list.
    response = hit_model_endpoint('esmfold', [sequence])
    predictions = response['predictions']
    return predictions[0]

@mlflow.trace(span_type="TOOL")
def hit_rfdiffusion(input_dict):
    """Send one request dict to the 'rfdiffusion_inpaint' endpoint and return its first prediction."""
    # The request dict is wrapped in a list because the endpoint takes a batch.
    response = hit_model_endpoint('rfdiffusion_inpaint', [input_dict])
    predictions = response['predictions']
    return predictions[0]

@mlflow.trace(span_type="TOOL")
def hit_proteinmpnn(pdb_str):
    """Send one PDB string to the 'proteinmpnn' endpoint and return all of its predictions."""
    # Unlike the other wrappers, the complete 'predictions' value is returned,
    # not only its first element.
    response = hit_model_endpoint('proteinmpnn', [pdb_str])
    return response['predictions']


In [0]:
from Bio.PDB import PDBList
from Bio.PDB import PDBParser
from Bio import PDB
import tempfile

@mlflow.trace(span_type="TOOL")
def extract_chain_reindex(structure, chain_id='A'):
    """Return PDB text holding a single chain of ``structure``, renumbered from 1.

    Only the first model (index 0) is read. Hetero residues (those whose id
    has a non-blank first field) are dropped, and the remaining residues are
    numbered consecutively 1, 2, 3, ... with blank insertion codes.

    Args:
        structure: A Bio.PDB structure; ``structure[0][chain_id]`` must exist.
        chain_id: Id of the chain to extract (default ``'A'``).

    Returns:
        The new single-chain structure serialised by ``PDBIO`` as a string.

    Raises:
        KeyError: If model 0 or ``chain_id`` is not present in ``structure``.
    """
    from io import StringIO  # function-scope import keeps the block self-contained

    source_chain = structure[0][chain_id]

    # Build an empty structure/model/chain hierarchy to receive the residues.
    new_structure = PDB.Structure.Structure('new_structure')
    new_model = PDB.Model.Model(0)
    new_chain = PDB.Chain.Chain(chain_id)

    # Count only the residues that are kept, so skipped hetero residues do not
    # leave gaps in the new numbering.
    next_number = 1
    for residue in source_chain:
        if residue.id[0] != ' ':  # skip HETATM records
            continue
        # Work on a detached copy: changing the id of a residue that is still
        # attached to its chain can clash with a sibling's id, and would also
        # modify the caller's structure.
        renumbered = residue.copy()
        renumbered.id = (' ', next_number, ' ')
        new_chain.add(renumbered)
        next_number += 1

    new_model.add(new_chain)
    new_structure.add(new_model)

    # Serialise straight into memory instead of a temporary file.
    writer = PDB.PDBIO()
    writer.set_structure(new_structure)
    buffer = StringIO()
    writer.save(buffer)
    return buffer.getvalue()

In [0]:
@mlflow.trace(span_type="TOOL")
def parse_sequence(sequence):
    """Split a bracket-annotated sequence into the raw sequence and region bounds.

    The region to inpaint is written as ``PREFIX[REGION]SUFFIX``.

    Args:
        sequence: Sequence string containing exactly one "[" followed by
            exactly one "]".

    Returns:
        A dict with:
          * ``"sequence"``: the input with both brackets removed.
          * ``"start_idx"``: position of "[" in the annotated string, i.e. the
            number of residues before the region.
          * ``"end_idx"``: position of "]" in the annotated string. Because
            the "[" character is counted too, this equals the 1-based number
            of the first residue after the region in the raw sequence.

    Raises:
        ValueError: If the brackets are missing, repeated, or out of order.
    """
    if sequence.count("[") != 1 or sequence.count("]") != 1:
        raise ValueError(
            'sequence must contain exactly one "[" and one "]" marking the region to inpaint'
        )

    # Positions are taken in the annotated string (brackets included).
    start_idx = sequence.find("[")
    end_idx = sequence.find("]")
    if end_idx < start_idx:
        raise ValueError('"[" must come before "]" in sequence')

    raw_sequence = sequence.replace("[", "").replace("]", "")
    return {
        "sequence": raw_sequence,
        "start_idx": start_idx,
        "end_idx": end_idx,
    }

@mlflow.trace(span_type="TOOL")
def make_designs(sequence, n_designs=3):
    """Inpaint the bracketed region of ``sequence`` and fold the redesigned sequences.

    Steps:
      1. Fold the raw sequence (brackets removed) with ``hit_esmfold``.
      2. Keep chain 'A' of that prediction, renumbered from 1.
      3. Call ``hit_rfdiffusion`` ``n_designs`` times with that structure and
         the region bounds from ``parse_sequence``.
      4. Collect every item returned by ``hit_proteinmpnn`` for each design.
      5. Fold each of those sequences with ``hit_esmfold``.

    Args:
        sequence: Sequence with the region to inpaint enclosed in "[" and "]".
        n_designs: Number of RFdiffusion calls to make (default 3).

    Returns:
        A dict with ``'initial'`` (the ESMfold output for the original
        sequence) and ``'designed'`` (a list of ESMfold outputs, one per
        redesigned sequence).
    """
    from io import StringIO  # function-scope import keeps the block self-contained

    seq_details = parse_sequence(sequence)
    esmfold_initial = hit_esmfold(seq_details['sequence'])

    # Parse the predicted PDB text directly from memory; no temp file needed.
    structure = PDBParser().get_structure("esmfold", StringIO(esmfold_initial))

    # Keep only chain A, renumbered from 1. The structure itself is not edited
    # here; the region bounds are passed to the endpoint alongside it.
    modified_pdb_text = extract_chain_reindex(structure, chain_id='A')

    # Pass that structure to RFdiffusion as a string, once per requested design.
    designed_pdb_strs = [
        hit_rfdiffusion({
            'pdb': modified_pdb_text,
            'start_idx': seq_details['start_idx'],
            'end_idx': seq_details['end_idx'],
        })
        for _ in range(n_designs)
    ]

    # Gather every sequence proposed for every design into one flat list.
    all_seqs = []
    for pdb_ in designed_pdb_strs:
        all_seqs.extend(hit_proteinmpnn(pdb_))

    # Fold each proposed sequence.
    all_pdbs = [hit_esmfold(s) for s in all_seqs]

    return {
        'initial': esmfold_initial,
        'designed': all_pdbs
    }

In [0]:
# Example input: SARS-CoV-2 sequence with the region to inpaint enclosed in [].
# The bracketed region is a loop near the ACE2 binding site.

cov2_rbd = "PTESIVRFPNITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFSTFKCYGVSPTKLNDLCFTNVYADSFVIRGDEVRQIAPGQTGKIADYNYKLPDDFTGCVIAWNSNNLDSKVGGNYNYLYRLFRKSNLKPF[ERDISTEIYQAGSTPCNGVEGFNCYFPLQSY]GFQPTNGVGYQPYRVVVLSFELLHAPATVCGPKKSTNLVKNKCVNF"

In [0]:
# Run the whole pipeline on the example sequence; the result is a dict with
# the 'initial' prediction and the list of 'designed' predictions.
designed_pdbs = make_designs(cov2_rbd)

In [0]:
import Bio.PDB as bp
from Bio import pairwise2
from Bio.pairwise2 import format_alignment
import numpy as np
import os
import tempfile

class ChainSelect(bp.Select):
    """Selector that lets through a single chain and drops hetero residues."""

    def __init__(self, chain):
        # Id of the only chain that accept_chain will let through.
        self.chain = chain

    def accept_residue(self, residue):
        """Return 1 for residues with a blank hetero field, 0 otherwise (removes HETATM)."""
        is_standard = residue.id[0] == " "
        return int(is_standard)

    def accept_chain(self, chain):
        """Return 1 when the chain's id equals the selected id, 0 otherwise."""
        return int(chain.get_id() == self.chain)
        

def get_seq_alignment(structure1, structure2):
    """Locally align the residue sequences of two structures.

    Each residue contributes exactly one character: its one-letter amino-acid
    code, or "X" when the residue name has no one-letter code. Keeping one
    character per residue lets callers map alignment columns back to residues.

    Args:
        structure1: Bio.PDB Structure, Model or Chain (anything providing
            ``get_residues()``).
        structure2: Same kind of object as ``structure1``.

    Returns:
        The list of alignments produced by ``pairwise2.align.localxx``.
    """
    from Bio.SeqUtils import seq1  # function-scope import keeps the block self-contained

    def _one_letter(residue):
        # seq1 converts three-letter names; names that are not exactly one
        # three-letter code (e.g. shorter ligand names) would yield a string
        # whose length is not 1, so fall back to "X" in that case.
        code = seq1(residue.get_resname())
        return code if len(code) == 1 else "X"

    seq_a = "".join(_one_letter(residue) for residue in structure1.get_residues())
    seq_b = "".join(_one_letter(residue) for residue in structure2.get_residues())

    return pairwise2.align.localxx(seq_a, seq_b)

def get_backbone_atoms(structure):
    """Return the atoms of ``structure`` named 'N', 'CA' or 'C', in iteration order.

    ``structure`` may be any object exposing ``get_atoms()`` whose items
    expose ``get_name()``.
    """
    wanted_names = ('N', 'CA', 'C')
    return [
        candidate
        for candidate in structure.get_atoms()
        if candidate.get_name() in wanted_names
    ]

def get_overlapping_backbone(structure1, structure2):
    """Collect paired backbone atoms for residues aligned between two structures.

    The first alignment returned by ``get_seq_alignment`` is walked column by
    column. For every column where neither sequence has a gap, the backbone
    atoms (N, CA, C) of the two residues are appended to the output lists.
    A residue pair is skipped when the two residues do not expose the same
    backbone atom names in the same order, so that ``atoms1`` and ``atoms2``
    always stay the same length and pair up one-to-one.

    Args:
        structure1: Bio.PDB Structure, Model or Chain.
        structure2: Bio.PDB Structure, Model or Chain.

    Returns:
        A dict with ``'atoms1'`` and ``'atoms2'`` (equal-length lists of
        atoms from ``structure1`` and ``structure2``) and ``'seq'`` (the
        characters of the aligned columns where both sequences agree).
    """
    best = get_seq_alignment(structure1, structure2)[0]

    residues1 = list(structure1.get_residues())
    residues2 = list(structure2.get_residues())
    atoms1 = []
    atoms2 = []
    matched_chars = []

    # pos1/pos2 track which residue each non-gap character corresponds to.
    pos1 = 0
    pos2 = 0
    for char1, char2 in zip(best.seqA, best.seqB):
        has_res1 = char1 != '-'
        has_res2 = char2 != '-'
        if has_res1 and has_res2:
            if char1 == char2:
                matched_chars.append(char1)
            backbone1 = get_backbone_atoms(residues1[pos1])
            backbone2 = get_backbone_atoms(residues2[pos2])
            names1 = [atom.get_name() for atom in backbone1]
            names2 = [atom.get_name() for atom in backbone2]
            # Only keep pairs with identical backbone atom sets; otherwise the
            # two lists would drift out of step for every later residue.
            if names1 == names2:
                atoms1.extend(backbone1)
                atoms2.extend(backbone2)
        if has_res1:
            pos1 += 1
        if has_res2:
            pos2 += 1

    return {'atoms1': atoms1, 'atoms2': atoms2, 'seq': "".join(matched_chars)}

def pdb_to_str(fpath):
    """Return the entire text content of the file at ``fpath`` as one string."""
    with open(fpath, 'r') as handle:
        content = handle.read()
    return content

def select_and_align(true_structure, af_structure):
    """Superimpose ``af_structure`` onto the best-matching chain of ``true_structure``.

    The chain of ``true_structure`` with the highest sequence-alignment score
    against ``af_structure`` is selected (hetero residues removed via
    ``ChainSelect``). Backbone atoms of the aligned residues are then used to
    fit ``af_structure`` onto that chain.

    Note: ``af_structure`` is modified in place, because the fitted
    transformation is applied to its atoms.

    Args:
        true_structure: Reference Bio.PDB structure, possibly multi-chain.
        af_structure: Bio.PDB structure to be moved onto the reference.

    Returns:
        A tuple ``(true_structure_str, af_structure_str)`` of PDB-format
        strings: the selected reference chain and the transformed
        ``af_structure``.
    """
    from io import StringIO  # function-scope import keeps the block self-contained

    # Score every reference chain against the structure to be aligned.
    alignment_scores = dict()
    for chain in true_structure.get_chains():
        ali = get_seq_alignment(af_structure, chain)[0]
        alignment_scores[chain.id] = ali.score
    max_overlapping_chain = max(alignment_scores, key=alignment_scores.get)

    writer = bp.PDBIO()

    # Round-trip through PDB text (in memory) to obtain a structure holding
    # just the chain of interest.
    writer.set_structure(true_structure)
    single_chain_buffer = StringIO()
    writer.save(single_chain_buffer, ChainSelect(max_overlapping_chain))
    single_chain_buffer.seek(0)
    true_structure = bp.PDBParser().get_structure('true', single_chain_buffer)

    alignment_dict = get_overlapping_backbone(true_structure, af_structure)
    ref_backbone_atoms = alignment_dict['atoms1']
    af_backbone_atoms = alignment_dict['atoms2']

    # Fit on the paired backbone atoms, then move every atom of af_structure.
    imposer = bp.Superimposer()
    imposer.set_atoms(
        ref_backbone_atoms,
        af_backbone_atoms
    )
    imposer.apply(af_structure.get_atoms())

    # Serialise both aligned structures to PDB text.
    af_buffer = StringIO()
    writer.set_structure(af_structure)
    writer.save(af_buffer)

    true_buffer = StringIO()
    writer.set_structure(true_structure)
    writer.save(true_buffer)

    return true_buffer.getvalue(), af_buffer.getvalue()

    



In [0]:
# plot 3D with py3Dmol
# original plus one of the designs (can change)
# or save pdbs to UC and then download for local analysis etc

import py3Dmol

def html_for_proteins(structure0, structure1):
    """Build py3Dmol HTML showing two structures together as cartoons.

    ``structure0`` is added as model 0 and coloured red; ``structure1`` is
    added as model 1 and coloured white. No superposition is performed here:
    the structures are drawn with the coordinates they already have.

    Args:
        structure0: Bio.PDB structure shown in red.
        structure1: Bio.PDB structure shown in white.

    Returns:
        The HTML produced by the py3Dmol view.
    """
    from io import StringIO  # function-scope import keeps the block self-contained

    def _as_pdb_text(structure):
        # Serialise in memory; the previous temp-file approach read the files
        # back with open(...).read() and never closed the handles.
        buffer = StringIO()
        writer = bp.PDBIO()
        writer.set_structure(structure)
        writer.save(buffer)
        return buffer.getvalue()

    view = py3Dmol.view(width=800, height=300)

    view.addModel(_as_pdb_text(structure0), 'pdb')
    view.setStyle(
        {'model': 0},
        {'cartoon': {'color': 'red', 'opacity': 0.9}}
    )

    view.addModel(_as_pdb_text(structure1), 'pdb')
    view.setStyle(
        {'model': 1},
        {'cartoon': {'color': 'white', 'opacity': 0.9}}
    )

    view.zoomTo()
    # NOTE(review): _make_html is a private py3Dmol method; kept because the
    # original relied on it.
    return view._make_html()

In [0]:
# Write the first designed prediction and the initial prediction to a scratch
# directory, parse both with PDBParser, and build the comparison HTML.
with tempfile.TemporaryDirectory() as tmpdir:
    designed_path = os.path.join(tmpdir, "structure0.pdb")
    initial_path = os.path.join(tmpdir, "structure1.pdb")

    with open(designed_path, 'w') as f:
        f.write(designed_pdbs['designed'][0])
    with open(initial_path, 'w') as f:
        f.write(designed_pdbs['initial'])

    # structure0 (the design) is drawn in red, structure1 (the original) in white.
    structure0 = PDBParser().get_structure("designed", designed_path)
    structure1 = PDBParser().get_structure("esmfold_initial", initial_path)
    html = html_for_proteins(structure0, structure1)

In [0]:
# Hand the generated HTML to displayHTML so the viewer is shown in the notebook.
displayHTML(html)

In [0]:
# Echo the raw HTML string as the cell output.
html

In [0]:
# Step-by-step version of make_designs, one stage per cell, for inspection.
# Stage 1: parse the bracketed sequence and fold the raw sequence.
seq_details = parse_sequence(cov2_rbd)
esmfold_initial = hit_esmfold(seq_details['sequence'])

In [0]:
# Inspect the initial prediction returned by the endpoint.
esmfold_initial

In [0]:
# Stage 2: write the prediction to a temporary .pdb file so PDBParser can read
# it, then keep only chain A renumbered from 1. The structure itself is not
# edited here; the region bounds are sent to RFdiffusion separately below.
with tempfile.NamedTemporaryFile(suffix='.pdb') as f:
    with open(f.name, 'w') as fw:
        fw.write(esmfold_initial)
    structure = PDBParser().get_structure("esmfold", f.name)

modified_pdb_text = extract_chain_reindex(
    structure, 
    chain_id='A'
)

In [0]:
# Stage 3a: test RFdiffusion with a single request.
designed_pdb = hit_rfdiffusion({
    'pdb': modified_pdb_text,
    'start_idx': seq_details['start_idx'],
    'end_idx': seq_details['end_idx'],
})

In [0]:
# Stage 3b: pass the same structure to RFdiffusion 3 times and collect all results.
designed_pdb_strs = []
for i in range(3):
    designed_pdb = hit_rfdiffusion({
        'pdb': modified_pdb_text,
        'start_idx': seq_details['start_idx'],
        'end_idx': seq_details['end_idx'],
    })
    designed_pdb_strs.append(designed_pdb)


In [0]:
# Inspect the collected designs.
designed_pdb_strs

In [0]:
# Stage 4: gather every ProteinMPNN prediction for every design into one flat list.
all_seqs = []
for pdb_ in designed_pdb_strs:
    all_seqs.extend(hit_proteinmpnn(pdb_))

In [0]:
# Stage 5: fold each proposed sequence, one request at a time.
# TODO: consider issuing these requests concurrently.
all_pdbs = []
for s in all_seqs:
    all_pdbs.append(hit_esmfold(s))

In [0]:
# Number of folded designs.
len(all_pdbs)

In [0]:
# Inspect the first folded design.
all_pdbs[0]