In [128]:
import shutil
import subprocess
from pathlib import Path
from tempfile import mkdtemp

import numpy as np
import pandas as pd

In [147]:
def parse_score(score_sc_fn):
    """Return the mean ``total_score`` found in a Rosetta ``score.sc`` file.

    The first line of the file is skipped; the following line is used as the
    whitespace-separated column header.

    Args:
        score_sc_fn: Path (or any input accepted by ``pandas.read_csv``) of
            the score file.

    Returns:
        Mean of the ``total_score`` column, computed on ``float32`` values.

    Raises:
        KeyError: If the file has no ``total_score`` column.
    """
    # sep=r"\s+" is the documented equivalent of the deprecated
    # delim_whitespace=True keyword, so parsing is unchanged.
    df = pd.read_csv(score_sc_fn, sep=r"\s+", skiprows=1)

    total_scores = df["total_score"].values.astype(np.float32)
    return np.mean(total_scores)


def write_flags_mutate(pdb_fn, out_fn, database="/home/hunter/install/rosetta/rosetta.binary.linux.release-332/main/database"):
    """Write the Rosetta flags file used for the mutation run.

    Args:
        pdb_fn: Structure file, written after the ``-s`` flag.
        out_fn: Destination of the flags file; opened in write mode, so any
            existing content is replaced.
        database: Rosetta database location, written after ``-database``.
    """
    flag_lines = (
        f"-s {pdb_fn}\n",
        f"-database {database}\n",
        "-parser:protocol mutate.xml\n",
        "-ignore_unrecognized_res  \n",
        "-restore_talaris_behavior\n",
        # the final flag is written without a trailing newline
        "-nstruct 1",
    )
    with open(out_fn, 'w') as flags_fh:
        flags_fh.writelines(flag_lines)


def write_flags_score(pdb_fn, tempdir, out_fn, database="/home/hunter/install/rosetta/rosetta.binary.linux.release-332/main/database"):
    """Write the input list and flags file for the scoring run.

    Two files are produced:
      * ``<tempdir>/list.txt`` containing ``pdb_fn`` (no trailing newline),
        referenced from the flags file via ``-l``;
      * ``out_fn`` containing the flags themselves.

    Args:
        pdb_fn: Structure file to score.
        tempdir: Directory that receives ``list.txt`` and is used as the
            location of the ``energy.txt`` silent output file.
        out_fn: Destination of the flags file; overwritten if it exists.
        database: Rosetta database location, written after ``-database``.
    """
    # Use the ``tempdir`` parameter throughout; the body previously read an
    # undefined name ``tmpdir`` and only worked if such a global existed.
    list_fn = f"{tempdir}/list.txt"
    with open(list_fn, 'w') as outfh:
        outfh.write(f"{pdb_fn}")

    with open(out_fn, 'w') as outfh:
        outfh.write(f"-l {list_fn}\n")
        outfh.write(f"-database {database}\n")
        outfh.write("-restore_talaris_behavior\n")
        outfh.write(f"-out:file:silent {tempdir}/energy.txt")


def gen_resfile_str(variant, template_fn=Path("./mutation_template.resfile")):
    """Build the resfile text for a variant from a template file.

    Every mutation in ``variant`` becomes one line of the form
    ``residue_number chain PIKAA replacement_AA``; the chain is always ``A``.
    The lines are joined with newlines and substituted into the template's
    single positional ``{}`` placeholder via ``str.format``.

    Args:
        variant: Comma-separated mutations such as ``"A33K"`` or
            ``"A33K,D40A"``. The first character of each mutation is ignored,
            the last character is the replacement residue, and everything in
            between is the residue number. Whitespace around each mutation is
            ignored.
        template_fn: Path of the resfile template.

    Returns:
        The template text with the mutation lines filled in.

    Raises:
        ValueError: If a residue number cannot be parsed as an integer.
    """
    mutation_lines = []
    for mutation in variant.split(","):
        # Strip so that "A33K, D40A" parses; without this a leading space
        # breaks the residue number and a trailing space would be taken as
        # the replacement residue.
        mutation = mutation.strip()
        resnum = int(mutation[1:-1])
        new_aa = mutation[-1]
        mutation_lines.append(f"{resnum} A PIKAA {new_aa}")

    # one mutation per line
    mutation_block = "\n".join(mutation_lines)

    # load the template
    with open(template_fn, "r") as f:
        template_str = f.read()

    return template_str.format(mutation_block)


def gen_rosetta_script_str(variant, template_fn=Path('./mutate_template.xml')):
    """Build the RosettaScripts XML text for a variant from a template file.

    The XML only needs the mutated residue numbers and the chain: each
    mutation contributes ``<residue_number>A`` and the entries are joined with
    commas, then substituted into the template's single positional ``{}``
    placeholder via ``str.format``.

    Args:
        variant: Comma-separated mutations such as ``"A33K"`` or
            ``"A33K,D40A"``. The first and last character of each mutation are
            ignored here; everything in between is the residue number.
            Whitespace around each mutation is ignored.
        template_fn: Path of the XML template.

    Returns:
        The template text with the residue selection filled in.

    Raises:
        ValueError: If a residue number cannot be parsed as an integer.
    """
    # Note: ROSETTA USES 1-BASED INDEXING
    resnums = []
    for mutation in variant.split(","):
        # Strip so that whitespace around the commas ("A33K, D40A") does not
        # end up inside the residue-number slice.
        mutation = mutation.strip()
        resnum = int(mutation[1:-1])
        resnums.append(f"{resnum}A")

    resnum_str = ",".join(resnums)

    # load the template
    with open(template_fn, "r") as f:
        template_str = f.read()

    # fill in the template
    return template_str.format(resnum_str)



def score_variant_rosetta(variant, pdb_fn):
    """Mutate a structure with Rosetta and return the resulting score.

    All intermediate files live in a fresh temporary directory that is
    removed before the function returns, whether or not it succeeds:

      1. write ``mutate.xml``, ``mutation.resfile`` and ``flags_mutate``;
      2. run ``rosetta_scripts`` with ``@flags_mutate``;
      3. write ``flags_score`` for the mutated structure and run
         ``residue_energy_breakdown`` with ``@flags_score``;
      4. parse ``score.sc`` from the temporary directory with ``parse_score``.

    Args:
        variant: Comma-separated mutations, e.g. ``"A33K"``.
        pdb_fn: Input PDB file (``str`` or ``Path``). Relative paths are made
            absolute because Rosetta runs with the temporary directory as its
            working directory.

    Returns:
        The value ``parse_score`` returns for ``score.sc``.

    Raises:
        subprocess.CalledProcessError: If a Rosetta executable exits with a
            non-zero status.
        FileNotFoundError: If the mutation run did not produce the expected
            ``<stem>_0001.pdb`` file.
    """
    ROSETTA_BIN = "/home/hunter/install/rosetta/rosetta.binary.linux.release-332/main/source/bin"

    # Accept plain strings and relative paths; the subprocesses below run in
    # a different working directory, so the path must be absolute.
    pdb_fn = Path(pdb_fn).absolute()

    rosetta_script = gen_rosetta_script_str(variant)
    resfile = gen_resfile_str(variant)

    tmpdir = Path(mkdtemp())
    try:
        write_flags_mutate(pdb_fn, tmpdir / 'flags_mutate')

        with open(tmpdir / 'mutate.xml', 'w') as script_fh:
            script_fh.write(rosetta_script)

        with open(tmpdir / 'mutation.resfile', 'w') as resfile_fh:
            resfile_fh.write(resfile)

        # Argument lists instead of shell=True: no shell parsing is involved,
        # so paths are passed to the executable verbatim.
        mutate_cmd = [
            f"{ROSETTA_BIN}/rosetta_scripts.static.linuxgccrelease",
            "@flags_mutate",
            "-out:level",
            "200",
        ]
        subprocess.check_output(mutate_cmd, cwd=tmpdir)

        output_pdb = tmpdir / f'{pdb_fn.stem}_0001.pdb'
        # Explicit check rather than ``assert`` so it still runs under -O.
        if not output_pdb.exists():
            raise FileNotFoundError(
                f"Rosetta did not produce the expected output file {output_pdb}"
            )

        write_flags_score(output_pdb, tmpdir, out_fn=tmpdir / 'flags_score')

        score_cmd = [
            f"{ROSETTA_BIN}/residue_energy_breakdown.static.linuxgccrelease",
            "@flags_score",
            "-out:level",
            "200",
        ]
        subprocess.check_output(score_cmd, cwd=tmpdir)

        # NOTE(review): flags_score directs the silent output to energy.txt,
        # but the returned value is read from score.sc — confirm which file
        # is the intended source of the score.
        return parse_score(tmpdir / 'score.sc')
    finally:
        # mkdtemp() never cleans up after itself; remove the directory so
        # repeated calls do not accumulate Rosetta output on disk.
        shutil.rmtree(tmpdir, ignore_errors=True)

In [148]:
# Input PDB file. It is resolved to an absolute path because
# score_variant_rosetta runs the Rosetta commands with a temporary directory
# as the working directory.
pdb_fn = Path('./2qmt.pdb').resolve()
# Variant to score: comma-separated mutations. For each mutation the parsers
# skip the first character (presumably the wild-type residue — confirm), read
# the digits in between as the residue number, and take the last character as
# the replacement residue.
variant = "A33K"

# Run the mutate-then-score pipeline; requires the Rosetta binaries and the
# template files referenced by the helper functions.
score = score_variant_rosetta(variant, pdb_fn)
