In [None]:
import glob
import os
import subprocess
from ast import literal_eval

import matplotlib.pyplot as plt
import MDAnalysis as mda
import pandas as pd
import seaborn as sns
from MDAnalysis.analysis import rms

# Analysis

In [2]:
# FrameFlow Version
# Motif segments per reference structure, written as (start, end) pairs of 0-based,
# inclusive offsets; the scoring cells add the reference PDB's first residue number
# to turn them into residue numbers. Segment order matters: segment j of the
# reference is compared with segment j of every generated scaffold.
_motif_segments = {
    "1prw": [(15, 34), (51, 70)],
    "1bcf": [(90, 98), (122, 129), (46, 53), (17, 24)],
    "5tpn": [(108, 126)],
    "3ixt": [(0, 23)],
    "4jhw": [(144, 159), (37, 43)],
    "4zyp": [(357, 371)],
    "5wn9": [(1, 20)],
    "5ius": [(88, 109), (34, 53)],
    "5yui": [(89, 93), (114, 116), (194, 196)],
    "6vw1": [(5, 23), (45, 63)],
    "1qjg": [(37, 37), (13, 13), (98, 98)],  # three single-residue segments
    "1ycr": [(2, 10)],
    "2kl8": [(0, 6), (27, 78)],
    "7mrx": [(25, 46)],
    "5trv": [(45, 69)],
    "6e6r": [(22, 34)],
    "6exz": [(25, 39)],
}
# Parallel lookups: reference key -> list of segment starts / list of segment ends.
start_idx_dict = {
    key: [start for start, _ in segments] for key, segments in _motif_segments.items()
}
end_idx_dict = {
    key: [end for _, end in segments] for key, segments in _motif_segments.items()
}
del _motif_segments

# Benchmark problem name (used for the scaffold file/directory names) -> reference key
# in the dictionaries above. Suffixed variants (e.g. 7MRX_60/_85/_128,
# 5TRV_short/_med/_long) share one reference, so the key is the lower-cased text
# before the first "_".
motif_name_mapping = {
    problem: problem.split("_")[0].lower()
    for problem in (
        "1PRW", "1BCF", "5TPN", "5IUS", "3IXT", "5YUI", "1QJG", "1YCR", "2KL8",
        "7MRX_60", "7MRX_85", "7MRX_128",
        "4JHW", "4ZYP", "5WN9",
        "5TRV_short", "5TRV_med", "5TRV_long",
        "6E6R_short", "6E6R_med", "6E6R_long",
        "6EXZ_short", "6EXZ_med", "6EXZ_long",
    )
}

import numpy as np
from Bio import PDB


def calculate_avg_plddt(pdb_file):
    """Return the mean B-factor over every CA atom in ``pdb_file``.

    The B-factor column is assumed to carry per-residue pLDDT, as in the predicted
    structures scored by this notebook. Atoms from all models and chains are
    pooled; residues without a CA atom are skipped.

    Raises:
        NotImplementedError: if the structure contains no CA atoms at all.
    """
    structure = PDB.PDBParser(QUIET=True).get_structure("protein", pdb_file)
    ca_bfactors = [
        residue["CA"].get_bfactor()
        for model in structure
        for chain in model
        for residue in chain
        if "CA" in residue
    ]
    if not ca_bfactors:
        raise NotImplementedError
    return np.mean(ca_bfactors)


def _select_motif_in_order(universe, starts, ends):
    """Select the CA atoms of each motif segment, concatenated segment by segment.

    ``starts[j]:ends[j]`` is an inclusive MDAnalysis ``resnum`` range. Concatenating
    one selection per segment keeps the atoms in segment order; a single combined
    selection would instead return them sorted by position in the structure.
    """
    motif = universe.select_atoms(f"name CA and resnum {starts[0]}:{ends[0]}")
    for j in range(1, len(starts)):
        motif = motif + universe.select_atoms(
            f"name CA and resnum {starts[j]}:{ends[j]}"
        )
    return motif


def _scaffold_index(filename):
    """Return ``i`` from a scaffold file name shaped like ``<prefix>_<i>.pdb``."""
    return int(filename.split("_")[1].split(".")[0])


def _run_tmscore(reference_PDB, predict_PDB, output_path):
    """Run the external ``./TMscore`` binary and return the reported TM-score.

    TMscore's stdout is written to ``temp_tmscores.txt`` inside ``output_path`` and
    then scanned for the first line whose first token is ``TM-score``; the score is
    the third token. Returns NaN when no such line exists (e.g. TMscore failed), so
    a missing score is never confused with another scaffold's value.
    """
    tmscore_txt = os.path.join(output_path, "temp_tmscores.txt")
    # Close the handle before reading the file back.
    with open(tmscore_txt, "w") as temp_file:
        subprocess.call(
            ["./TMscore", reference_PDB, predict_PDB, "-seq"], stdout=temp_file
        )
    with open(tmscore_txt, "r") as f:
        for line in f:
            fields = line.split()
            if len(fields) > 2 and fields[0] == "TM-score":
                return float(fields[2])
    return float("nan")


def calc_rmsd_tmscore(
    pdb_name,
    reference_PDB,
    scaffold_pdb_path=None,
    scaffold_info_path=None,
    ref_motif_starts=(30,),
    ref_motif_ends=(44,),
    output_path=None,
):
    """Score every generated scaffold of one motif-scaffolding problem.

    For each ``*.pdb`` file in ``scaffold_pdb_path`` this computes:
      * the CA RMSD between the scaffold's motif and the reference motif after
        centering and optimal superposition,
      * the TM-score reported by ``./TMscore`` against ``reference_PDB``,
      * the mean CA pLDDT via ``calculate_avg_plddt``.

    Args:
        pdb_name: Problem name; also the stem of ``{scaffold_info_path}/{pdb_name}.csv``.
        reference_PDB: Path to the reference structure that contains the motif.
        scaffold_pdb_path: Directory of predicted scaffold structures. File names must
            look like ``<prefix>_<i>.pdb``; ``i`` selects row ``i`` (by position) of the
            scaffold-info CSV.
        scaffold_info_path: Directory holding ``{pdb_name}.csv``, whose ``start_idxs`` /
            ``end_idxs`` columns are Python-literal lists of 0-based motif positions in
            each generated sequence, one entry per motif segment.
        ref_motif_starts: Inclusive start ``resnum`` of each reference motif segment.
        ref_motif_ends: Inclusive end ``resnum`` of each reference motif segment.
            Segment ``j`` here is paired with entry ``j`` of the CSV lists.
        output_path: Directory where TMscore's stdout is written (``temp_tmscores.txt``).

    Returns:
        List of ``(pdb_name, i, rmsd, plddt, tm_score)`` tuples in increasing ``i``;
        ``tm_score`` is a float, NaN when TMscore reported none.

    Raises:
        AssertionError: if the scaffold motif's length or residue names differ from
            the reference motif, which means the motif indices are wrong.
    """
    motif_df = pd.read_csv(
        os.path.join(scaffold_info_path, f"{pdb_name}.csv"), index_col=0
    )

    # The reference motif is the same for every scaffold, so select it once.
    ref = mda.Universe(reference_PDB)
    ref_motif = _select_motif_in_order(ref, ref_motif_starts, ref_motif_ends)
    n_segments = len(ref_motif_starts)

    # Visit scaffolds in numerical order of their index so results are deterministic.
    scaffold_files = sorted(
        (name for name in os.listdir(scaffold_pdb_path) if name.endswith(".pdb")),
        key=_scaffold_index,
    )

    results = []
    for pdb in scaffold_files:
        predict_PDB = os.path.join(scaffold_pdb_path, pdb)
        i = _scaffold_index(pdb)
        new_motif_starts = literal_eval(motif_df["start_idxs"].iloc[i])
        new_motif_ends = literal_eval(motif_df["end_idxs"].iloc[i])

        # Generated motif positions are 0-based; +1 maps them onto the prediction's
        # resnums (assumes the predicted PDB numbers residues from 1).
        u = mda.Universe(predict_PDB)
        u_motif = _select_motif_in_order(
            u,
            [new_motif_starts[j] + 1 for j in range(n_segments)],
            [new_motif_ends[j] + 1 for j in range(n_segments)],
        )

        # If either check fails, the motif indices are wrong: check chain/numbering
        # and PDB preprocessing.
        assert len(ref_motif.resnames) == len(u_motif.resnames), (
            f"{pdb_name}/{pdb}: motif lengths do not match, check PDB preprocessing "
            "for extra residues"
        )
        assert (ref_motif.resnames == u_motif.resnames).all(), (
            f"{pdb_name}/{pdb}: resnames for motif RMSD do not match, check indexing"
        )

        rmsd = rms.rmsd(
            u_motif.positions,  # coordinates to align
            ref_motif.positions,  # reference coordinates
            center=True,  # subtract the center of geometry
            superposition=True,  # superimpose coordinates
        )
        tm_score = _run_tmscore(reference_PDB, predict_PDB, output_path)
        plddt = calculate_avg_plddt(predict_PDB)
        results.append((pdb_name, i, rmsd, plddt, tm_score))
    return results

## DPLM

### Calculate motif RMSD, TM-score and pLDDT

In [None]:
scaffold_dir = "../generation-results/dplm_650m/motif_scaffold"
motif_pdb_dir = "../data-bin/scaffolding-pdbs"

output_dir = os.path.join(scaffold_dir, "scaffold_results")
os.makedirs(output_dir, exist_ok=True)

# Score every generated scaffold of every benchmark problem against its reference.
results = []
for ori_pdb, pdb in motif_name_mapping.items():
    print(ori_pdb)
    reference_PDB = os.path.join(motif_pdb_dir, pdb + "_reference.pdb")

    # The motif offsets are relative to the reference's first residue. Read that
    # residue number from the resSeq field (fixed PDB columns 23-26) of the first
    # ATOM/HETATM record instead of whitespace-splitting the first line, which breaks
    # on header records or a blank chain ID.
    with open(reference_PDB) as f:
        first_atom_line = next(
            (line for line in f if line.startswith(("ATOM", "HETATM"))), None
        )
    if first_atom_line is None:
        raise ValueError(f"No ATOM/HETATM records found in {reference_PDB}")
    ref_basenum = int(first_atom_line[22:26])

    ref_motif_starts = [num + ref_basenum for num in start_idx_dict[pdb]]
    ref_motif_ends = [num + ref_basenum for num in end_idx_dict[pdb]]
    results += calc_rmsd_tmscore(
        pdb_name=ori_pdb,
        reference_PDB=reference_PDB,
        scaffold_pdb_path=f"{scaffold_dir}/scaffold_fasta/esmfold_pdb/{ori_pdb}",
        scaffold_info_path=f"{scaffold_dir}/scaffold_info",
        ref_motif_starts=ref_motif_starts,
        ref_motif_ends=ref_motif_ends,
        output_path=output_dir,
    )

results = pd.DataFrame(
    results, columns=["pdb_name", "index", "rmsd", "plddt", "tmscore"]
)
results.to_csv(os.path.join(output_dir, "rmsd_tmscore.csv"), index=False)

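### Calculate success rate

A scaffold counts as a success when its motif RMSD is below 1.0 Å and its mean pLDDT is above 70; the table reports the number of successful scaffolds for each problem.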

In [None]:
# Success criterion (DPLM): motif RMSD < 1.0 and mean pLDDT > 70.
# Count successes and totals per problem in one vectorised pass. Grouping covers every
# problem, so problems without a successful scaffold appear with success_count 0.
rmsd_tmscore = pd.read_csv(os.path.join(output_dir, "rmsd_tmscore.csv"))
is_success = (rmsd_tmscore["rmsd"] < 1.0) & (rmsd_tmscore["plddt"] > 70)
results = (
    rmsd_tmscore.assign(success=is_success)
    .groupby("pdb_name")["success"]
    .agg(success_count="sum", total="size")
    .reset_index()
)
results.to_csv(os.path.join(output_dir, "result.csv"))
results

## DPLM-2

### Calculate motif RMSD, TM-score and pLDDT

In [None]:
scaffold_dir = "../generation-results/dplm2_650m/motif_scaffold"
motif_pdb_dir = "../data-bin/scaffolding-pdbs"

output_dir = os.path.join(scaffold_dir, "scaffold_results")
os.makedirs(output_dir, exist_ok=True)

# Score every generated scaffold of every benchmark problem against its reference.
# The CSV is written by the next cell, after the backbone eval metrics are merged in.
results = []
for ori_pdb, pdb in motif_name_mapping.items():
    print(ori_pdb)
    reference_PDB = os.path.join(motif_pdb_dir, pdb + "_reference.pdb")

    # The motif offsets are relative to the reference's first residue. Read that
    # residue number from the resSeq field (fixed PDB columns 23-26) of the first
    # ATOM/HETATM record instead of whitespace-splitting the first line, which breaks
    # on header records or a blank chain ID.
    with open(reference_PDB) as f:
        first_atom_line = next(
            (line for line in f if line.startswith(("ATOM", "HETATM"))), None
        )
    if first_atom_line is None:
        raise ValueError(f"No ATOM/HETATM records found in {reference_PDB}")
    ref_basenum = int(first_atom_line[22:26])

    ref_motif_starts = [num + ref_basenum for num in start_idx_dict[pdb]]
    ref_motif_ends = [num + ref_basenum for num in end_idx_dict[pdb]]
    results += calc_rmsd_tmscore(
        pdb_name=ori_pdb,
        reference_PDB=reference_PDB,
        scaffold_pdb_path=f"{scaffold_dir}/scaffold_fasta/{ori_pdb}/esmfold_pdb",
        scaffold_info_path=f"{scaffold_dir}/scaffold_info",
        ref_motif_starts=ref_motif_starts,
        ref_motif_ends=ref_motif_ends,
        output_path=output_dir,
    )

results = pd.DataFrame(
    results, columns=["pdb_name", "index", "rmsd", "plddt", "tmscore"]
)

In [4]:
# Attach the backbone metrics (bb_rmsd, bb_tmscore) from each problem's
# eval/all_top_samples.csv to the per-scaffold results. glob's order depends on the
# filesystem, so the paths are sorted to make the choice deterministic when several
# CSVs resolve to the same pdb_name (only the first one is kept).
eval_frames = []
seen_pdb_names = set()
for path in sorted(
    glob.glob(scaffold_dir + "/scaffold_fasta/**/**/**/eval/all_top_samples.csv")
):
    # pdb_name is the directory that directly contains eval/.
    pdb_name = os.path.basename(os.path.dirname(os.path.dirname(path)))
    if pdb_name in seen_pdb_names:
        continue
    seen_pdb_names.add(pdb_name)

    samples = pd.read_csv(path)
    samples["pdb_name"] = pdb_name
    # The scaffold index is parsed from the parent directory name of sample_path,
    # expected to look like "<prefix>_<i>[.<ext>]" — NOTE(review): confirm layout.
    samples["index"] = samples["sample_path"].apply(
        lambda p: int(p.split("/")[-2].split("_")[1].split(".")[0])
    )
    eval_frames.append(samples[["pdb_name", "index", "bb_rmsd", "bb_tmscore"]])

if not eval_frames:
    raise FileNotFoundError(
        f"No eval/all_top_samples.csv files found under {scaffold_dir}/scaffold_fasta"
    )
df_cat = pd.concat(eval_frames, axis=0)
# Inner join: scaffolds without eval metrics are dropped from rmsd_tmscore.csv.
results = pd.merge(results, df_cat, on=["pdb_name", "index"], how="inner")
results.to_csv(os.path.join(output_dir, "rmsd_tmscore.csv"), index=False)

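### Calculate success rate

For DPLM-2, a scaffold counts as a success when its motif RMSD is below 1.0 Å and its backbone TM-score (`bb_tmscore`) is above 0.8; the table reports the number of successful scaffolds for each problem.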

In [None]:
# Success criterion (DPLM-2): motif RMSD < 1.0 and backbone TM-score > 0.8.
# Count successes and totals per problem in one vectorised pass. Grouping covers every
# problem, so problems without a successful scaffold appear with success_count 0.
rmsd_tmscore = pd.read_csv(os.path.join(output_dir, "rmsd_tmscore.csv"))
is_success = (rmsd_tmscore["rmsd"] < 1.0) & (rmsd_tmscore["bb_tmscore"] > 0.8)
results = (
    rmsd_tmscore.assign(success=is_success)
    .groupby("pdb_name")["success"]
    .agg(success_count="sum", total="size")
    .reset_index()
)
results.to_csv(os.path.join(output_dir, "result.csv"))
results