In [1]:
## Boilerplate imports

import pandas as pd
from pathlib import Path
import json
import pyrosetta
import pyrosetta_help as ph
from types import ModuleType
# Better than star imports:
# Short aliases for the PyRosetta namespaces used below.
prc: ModuleType = pyrosetta.rosetta.core
prp: ModuleType = pyrosetta.rosetta.protocols
pru: ModuleType = pyrosetta.rosetta.utility
prn: ModuleType = pyrosetta.rosetta.numeric
prs: ModuleType = pyrosetta.rosetta.std
pr_conf: ModuleType = prc.conformation
pr_scoring: ModuleType = prc.scoring
pr_res: ModuleType = prc.select.residue_selector
pr_options: ModuleType = pyrosetta.rosetta.basic.options

logger = ph.configure_logger()

# Rosetta command-line flags, turned into an option string by pyrosetta_help.
# NOTE(review): flags given the value None (ex1, ex2) are presumably emitted as bare flags — confirm.
rosetta_flags = dict(no_optH=False,
                     ex1=None,
                     ex2=None,
                     ignore_unrecognized_res=False,
                     load_PDB_components=True,
                     ignore_waters=True,
                     )
pyrosetta.distributed.maybe_init(extra_options=ph.make_option_string(**rosetta_flags))

In [18]:
import sys

# Make the local ``protein-design-analysis`` checkout importable.
# The path is resolved relative to the notebook's current working directory.
# Guarded so that re-running this cell does not append the same entry to sys.path again.
_pda_path = Path('../GCN2/protein-design-analysis').resolve().as_posix()
if _pda_path not in sys.path:
    sys.path.append(_pda_path)
from protein_design_analysis.utils import create_design_tf

In [5]:
vanilla_pose = pyrosetta.pose_from_file('andre.min.pdb')
pose = vanilla_pose

# The triad residue numbers below are relative to the start of chain 2:
# adding the length of chain 1 converts them to pose numbering.
offset = len(pose.chain_sequence(1))
triad = pr_res.ResidueIndexSelector()
for resi, resn in [(36, 'HIS'), (60, 'ASP'), (120, 'SER')]:
    observed = pose.residue(resi + offset).name3()
    # sanity check that the numbering really lands on the expected residue type
    assert observed == resn, (observed, resn)
    triad.append_index(resi + offset)

# Residues in close contact with the triad (threshold 6, presumably Å).
neigh = pr_res.CloseContactResidueSelector()
neigh.central_residue_group_selector(triad)
neigh.threshold(6)

In [63]:
def get_mutations(poseA, poseB):
    """
    List per-chain sequence differences between two poses.

    Chains are visited in pose order (1, 2, ...) and labelled positionally
    ``A``, ``B``, ... (not read from the PDB chain IDs). Residues are numbered
    from 1 within each chain. Each difference is reported as
    ``'{old}{number}{new}:{chain}'``, e.g. ``'D3K:A'``.

    Only the chains of ``poseA`` are iterated. Within a chain, residues are
    paired up to the shorter of the two chain sequences.

    :param poseA: reference pose.
    :param poseB: pose compared against ``poseA``.
    :return: list of mutation strings, in chain order then residue order.
    """
    mutations = []
    # Use poseA's own chain count (not a notebook global).
    for chain_idx in range(1, poseA.num_chains() + 1):
        # positional label: 1 -> 'A', 2 -> 'B', ... (no 6-chain limit)
        chain_letter = chr(ord('A') + chain_idx - 1)
        paired = zip(poseA.chain_sequence(chain_idx), poseB.chain_sequence(chain_idx))
        for idx0, (old, new) in enumerate(paired):
            if old != new:
                mutations.append(f'{old}{idx0 + 1}{new}:{chain_letter}')
    return mutations

def tune(pose: pyrosetta.Pose,
         design_selector: pr_res.ResidueSelector,
         filename: str,
         design_cycles: int = 5,
         res_type_constraint: float = 1,
         native_residue_bonus: float = 1) -> tuple:
    """
    Run iterative FastDesign (FastRelax with design enabled) on a copy of ``pose``.
    Hacked from ``protein_design_analysis`` module of mine.

    Each cycle runs a single FastRelax round with design on a clone of the best
    design so far. A cycle is accepted only if:

    1. no residue outside ``design_selector`` changed identity,
    2. the complex score did not increase, and
    3. the chain-1 score did not increase.

    Otherwise the cycle is reverted. Every accepted design is written to
    ``filename``, overwriting the previous one.

    :param pose: starting pose. It is not modified; a clone receives the constraints.
    :param design_selector: residues allowed to be redesigned.
    :param filename: path the latest accepted design is dumped to (parent dirs are created).
    :param design_cycles: number of one-round FastRelax design cycles.
    :param res_type_constraint: weight of the ``res_type_constraint`` term in the design
        scorefunction, which scores the FavorNativeResidue constraints (0 disables them).
    :param native_residue_bonus: bonus passed to ``FavorNativeResidue``.
    :return: ``(best_design, info)``. ``info`` maps ``design_cycle{N}_*`` keys to the
        sequence, complex/monomer scores, mutations vs ``pose``, and outcome of each cycle.
    """
    info = {}
    vanilla_scorefxn: pr_scoring.ScoreFunction = pyrosetta.get_fa_scorefxn()
    weighted_scorefxn: pr_scoring.ScoreFunction = pyrosetta.create_score_function('ref2015')
    weighted_scorefxn.set_weight(pr_scoring.ScoreType.coordinate_constraint, 1)
    weighted_scorefxn.set_weight(pr_scoring.ScoreType.atom_pair_constraint, 1)
    weighted_scorefxn.set_weight(pr_scoring.ScoreType.res_type_constraint, res_type_constraint)

    # Work on a clone: FavorNativeResidue adds constraints to the pose it is given.
    # Applying it to the caller's pose would stack another set on every call.
    starting_pose = pose.clone()
    prp.protein_interface_design.FavorNativeResidue(starting_pose, native_residue_bonus)

    # Reference sequence with designable positions masked as '-'.
    # Needed to safeguard against Relax changing non-designable residues.
    design_positions = set(pr_res.ResidueVector(design_selector.apply(starting_pose)))
    ref_seq = ''.join('-' if idx0 + 1 in design_positions else resn
                      for idx0, resn in enumerate(starting_pose.sequence()))

    # dump_pdb does not create missing directories
    Path(str(filename)).parent.mkdir(parents=True, exist_ok=True)

    best_design = starting_pose  # returned as-is when design_cycles == 0
    best_complex_dG = vanilla_scorefxn(best_design)
    best_mono_dG = vanilla_scorefxn(best_design.split_by_chain(1))
    for cycle in range(design_cycles):
        candidate = best_design.clone()
        task_factory: prc.pack.task.TaskFactory = create_design_tf(candidate,
                                                                   design_sele=design_selector,
                                                                   distance=0)
        relax = prp.relax.FastRelax(weighted_scorefxn, 1)  # one cycle at a time
        relax.set_enable_design(True)
        relax.set_task_factory(task_factory)
        relax.apply(candidate)
        complex_dG = vanilla_scorefxn(candidate)
        mono_dG = vanilla_scorefxn(candidate.split_by_chain(1))
        info[f'design_cycle{cycle}_seq'] = candidate.sequence()
        info[f'design_cycle{cycle}_dG_complex'] = complex_dG
        info[f'design_cycle{cycle}_dG_monomer'] = mono_dG
        info[f'design_cycle{cycle}_mutations'] = get_mutations(pose, candidate)
        if any(have != expected and expected != '-'
               for have, expected in zip(candidate.sequence(), ref_seq)):
            # this is a weird glitch in Relax that happens rarely: revert
            info[f'design_cycle{cycle}_outcome'] = 'mismatched'
        elif complex_dG > best_complex_dG:
            info[f'design_cycle{cycle}_outcome'] = 'worse complex'
        elif mono_dG > best_mono_dG:
            info[f'design_cycle{cycle}_outcome'] = 'worse monomer'
        else:
            info[f'design_cycle{cycle}_outcome'] = 'success'
            # Track both reference scores, so later cycles are judged against the accepted design.
            best_design, best_complex_dG, best_mono_dG = candidate, complex_dG, mono_dG
            best_design.dump_pdb(str(filename))
    return best_design, info

In [66]:
# vanilla_pose = pyrosetta.pose_from_file('andre.min.pdb')

In [64]:
# One info dict per design run is appended here; the list is dumped to JSON after every run.
data: list = []
# Second starting structure, designed alongside vanilla_pose.
snatched_pose = pyrosetta.pose_from_file('snatched.min.pdb')

In [None]:
import json

def save(data, path='fastdesign.json'):
    """
    Write ``data`` to ``path`` as JSON, replacing any previous file.

    The JSON is first written to a sibling ``<name>.tmp`` file and then moved
    into place. This function is called after every long design run, so an
    interruption or serialisation error mid-write leaves the previous complete
    file intact instead of a truncated one.

    :param data: JSON-serialisable object (here: the list of per-run info dicts).
    :param path: destination file; defaults to ``fastdesign.json`` in the working directory.
    """
    target = Path(path)
    tmp = target.with_name(target.name + '.tmp')
    with open(tmp, 'w') as fh:
        json.dump(data, fh)
    tmp.replace(target)  # atomic rename over the old file


# Redesign everything except the triad neighbourhood, 9 runs per starting structure.
for label, start_pose in (('vanilla', vanilla_pose), ('snatched', snatched_pose)):
    for i in range(1, 10):
        run_name = f'{label}_fastdesign{i}'
        tuned, info = tune(pose=start_pose,
                           design_selector=pr_res.NotResidueSelector(neigh),
                           filename=f'variants/{run_name}.pdb',
                           res_type_constraint=0.5,
                           design_cycles=15)
        info['name'] = run_name
        data.append(info)
        save(data)  # checkpoint after every run

# Focused design: allow drift (this selection is designed with res_type_constraint=0).
#
# Intended selection, as a PyMOL selection in PDB numbering:
#   (resi 57+58+103+105+107+102 and chain B) or (chain A and resi 34+33+37+31+28+40)
# Chain B has its own residue list. PDB numbers are mapped to pose indices via PDBInfo,
# so this does not assume each chain is numbered from 1.
FOCUSED_PDB_RESIDUES = {'A': (34, 33, 37, 31, 28, 40),
                        'B': (57, 58, 103, 105, 107, 102)}
manual_selector = pr_res.ResidueIndexSelector()
focused_pdb_info = vanilla_pose.pdb_info()
for chain_letter, pdb_resis in FOCUSED_PDB_RESIDUES.items():
    for pdb_resi in pdb_resis:
        pose_resi = focused_pdb_info.pdb2pose(chain_letter, pdb_resi)
        # pdb2pose returns 0 when the residue is not present
        assert pose_resi, f'{pdb_resi}{chain_letter} not found in vanilla_pose'
        manual_selector.append_index(pose_resi)

# Focused redesign of the manually chosen residues, no native-residue bias.
for i in range(1, 10):
    run_name = f'focused_fastdesign{i}'
    tuned, info = tune(pose=vanilla_pose,
                       design_selector=manual_selector,
                       filename=f'variants/{run_name}.pdb',
                       design_cycles=15,
                       res_type_constraint=0)
    info['name'] = run_name
    data.append(info)
    save(data)  # checkpoint after every run