In [None]:
# analysis paper
# import libraries
import sys
import seaborn as sns
import numpy as np
import scipy.stats as _stats
from functools import reduce
from pipeline.analysis import *
from pipeline.utils import *
from pipeline import *
import logging
import networkx as nx
import glob
from scipy.stats import sem as sem
from matplotlib import colormaps
import matplotlib.colors as mcolors
import matplotlib.patches as mpatches
from os.path import join

from matplotlib.ticker import MaxNLocator

import warnings

warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=RuntimeWarning)

logging.getLogger().setLevel(logging.INFO)

from cinnabar import wrangle as _wrangle

print(BSS.__file__)

In [None]:
def check_normal_dist(values, alpha=0.05):
    """Test whether ``values`` are consistent with a normal distribution.

    Fewer than 50 values are tested with Shapiro-Wilk; larger samples with a
    one-sample Kolmogorov-Smirnov test against a normal distribution whose mean
    and standard deviation are estimated from the sample.

    Args:
        values: one-dimensional sequence of numbers.
        alpha: significance level at which normality is rejected.

    Returns:
        bool: True if normality is NOT rejected (p >= alpha), False otherwise.
    """
    values = np.asarray(values, dtype=float)
    if len(values) < 50:
        _, p = _stats.shapiro(values)
    else:
        # kstest needs the reference distribution spelled out; because its
        # parameters are estimated from the same data, the p-value is conservative.
        _, p = _stats.kstest(
            values, "norm", args=(np.mean(values), np.std(values, ddof=1))
        )
    # a small p-value means the normality hypothesis is rejected
    return bool(p >= alpha)


def flatten_comprehension(matrix):
    """Concatenate the rows of ``matrix`` (an iterable of iterables) into one flat list."""
    flat = []
    for row in matrix:
        flat.extend(row)
    return flat

In [None]:
# define the analysis method to use
# Options shared by every analysis protocol; each entry of ana_dicts below is a
# copy of these with the listed keys overridden.
_base_analysis_options = {
    "estimator": "MBAR",
    "method": "alchemlyb",
    "check overlap": True,
    "try pickle": True,
    "save pickle": True,
    "auto equilibration": False,
    "statistical inefficiency": False,
    "truncate lower": 0,
    "truncate upper": 100,
    "name": None,
}

ana_dicts = {
    "plain": dict(_base_analysis_options),
    "subsampling": {**_base_analysis_options, "statistical inefficiency": True},
}
# Currently disabled protocols (base options unless stated otherwise):
#   "1ns":    "truncate upper": 25
#   "2ns":    "truncate upper": 50
#   "3ns":    "truncate upper": 75
#   "autoeq": "auto equilibration": True, "statistical inefficiency": True
#   "TI":     "estimator": "TI"
#   "single_0", "single_1", "single_2": identical to the base options

In [None]:
# internal protein key -> label used in tables and figures
prot_dict_name = dict(
    tyk2="TYK2",
    mcl1="MCL1",
    p38="P38α",
    syk="SYK",
    hif2a="HIF2A",
    cmet="CMET",
)
# internal engine / reference-data key -> label used in tables and figures
eng_dict_name = dict(
    AMBER="AMBER22",
    SOMD="SOMD1",
    GROMACS="GROMACS23",
    hahn="Hahn et al.",
    openfe="OpenFE",
    fepplus="Ross et al.",
)

# Colours keyed by the internal engine / reference-data key; the three external
# result sets get explicit colours here.
set_cols = pipeline.analysis.set_colours(
    colour_dict={"openfe": "cadetblue", "hahn": "thistle", "fepplus": "indigo"},
    other_results_names=["hahn", "openfe", "fepplus"],
)

# Same colours, re-keyed by the display label so they can be matched to
# DataFrame columns that use the display labels.
col_dict = {}
for eng, display_label in eng_dict_name.items():
    col_dict[display_label] = set_cols[eng]
col_dict

In [None]:
# Build one analysis_network object per (network, protein, analysis protocol).
# Layout: network_dict[network][protein][protocol name] -> analysis_network object,
# or None where that network is not analysed for that protein (see below).
network_dict = {}

bench_folder = "/home/anna/Documents/benchmark"

# the A-/D-optimal variants of the LOMAP and RBFENN networks
optimal_networks = (
    "lomap-a-optimal",
    "lomap-d-optimal",
    "rbfenn-a-optimal",
    "rbfenn-d-optimal",
)
# proteins that are only analysed with the LOMAP network file
lomap_only_proteins = ("syk", "cmet", "hif2a")

for network in [
    "lomap",
    "rbfenn",
    "flare",
    "combined",
    "lomap-a-optimal",
    "lomap-d-optimal",
    "rbfenn-a-optimal",
    "rbfenn-d-optimal",
]:
    ana_obj_dict = {}

    for protein in ["tyk2", "mcl1", "p38", "syk", "hif2a", "cmet"]:
        ana_obj_dict[protein] = {}

        # main_dir = f"{bench_folder}/reruns/{protein}"
        main_dir = f"/backup/{protein}"
        exp_file = f"{bench_folder}/inputs/experimental/{protein}.yml"
        output_folder = f"{main_dir}/outputs_extracted"

        # Pick the network file; None marks a combination that is not analysed.
        if protein in lomap_only_proteins:
            # both "lomap" and "combined" read the LOMAP network file here
            if network in ("lomap", "combined"):
                net_file = f"{main_dir}/execution_model/network_lomap.dat"
            else:
                net_file = None
        elif protein == "p38" and network in optimal_networks:
            net_file = None
        else:
            net_file = f"{main_dir}/execution_model/network_{network}.dat"

        for ana_dict in ana_dicts:
            ana_prot = analysis_protocol(ana_dicts[ana_dict])

            if net_file is None:
                ana_obj_dict[protein][ana_dict] = None
                continue

            # no protocol file is used; a name can be added afterwards if needed
            pipeline_prot = pipeline_protocol(auto_validate=True)

            # initialise the network object
            all_analysis_object = analysis_network(
                output_folder,
                exp_file=exp_file,
                net_file=net_file,
                analysis_prot=ana_prot,
            )

            # NOTE(review): the disabled single-repeat protocols are named
            # "single_0"/"single_1"/"single_2", which this equality test does not
            # match - confirm whether a prefix match was intended.
            if ana_dict == "single":
                all_analysis_object.file_ext = (
                    all_analysis_object.file_ext + f"_{ana_dict}"
                )

            # compute; keep going on failure so one broken system does not abort
            # the whole sweep, but report which one failed and why
            try:
                all_analysis_object.compute_results()
            except Exception as e:
                print(f"failed analysis for {network} {protein} {ana_dict}: {e}")

            # add ligands folder
            all_analysis_object.add_ligands_folder(
                f"{bench_folder}/inputs/reruns/{protein}/ligands_intermediates"
            )

            ana_obj_dict[protein][ana_dict] = all_analysis_object

    network_dict[network] = ana_obj_dict

In [None]:
# Defaults used by the cells below: the TYK2 "plain" analysis as the initial
# object, and the LOMAP network as the working set of analyses.
ana_obj = network_dict["lomap"]["tyk2"]["plain"]
ana_obj_dict = network_dict["lomap"]

In [None]:
# check maximum possible accuracy
# For each protein, estimate the largest R2 that is reachable given the
# experimental data: 1 - (mean experimental error / std of experimental values)^2
r2_dict = {}
r2_error_dict = {}
for prot in prot_dict_name:
    r2_dict[prot] = {}
    r2_error_dict[prot] = {}
    ana_obj = ana_obj_dict[prot]["plain"]
    print(prot, len(ana_obj.ligands))

    # each entry is indexed as (value, error)
    exper_entries = list(ana_obj.exper_val_dict.values())
    highest = max(exper_entries)[0]
    lowest = min(exper_entries)[0]
    print("max", highest, "min", lowest, "range", highest - lowest)

    avg = np.mean([entry[1] for entry in exper_entries])
    std = np.std([entry[0] for entry in exper_entries])
    print("mean of error", avg, "std of val", std)

    # experimental uncertainty is std of measurement error
    # max is measurement error / std dev of the affinity , squared
    # tyk2 mcl1 Ki 0.44
    # others IC50 0.75
    r2max = 1 - (avg / std) ** 2
    print(r2max)
    r2_dict[prot]["maximum"] = r2max
    r2_error_dict[prot]["maximum"] = (0, 0)

In [None]:
# make single vs triplicate results
# For every protein, turn the "single_<r>" analysis objects into single-repeat
# results that are stored under the plain engine names.
for prot in ana_obj_dict.keys():
    for r in range(3):
        protocol_name = f"single_{r}"
        # The single-repeat protocols are optional in ana_dicts (and an entry may
        # be None for a skipped network), so skip instead of raising KeyError.
        ana_obj = ana_obj_dict[prot].get(protocol_name)
        if ana_obj is None:
            print(f"{prot}: no '{protocol_name}' analysis available, skipping")
            continue

        # function for single dicts
        ana_obj.compute_single_repeat_results(repeat=r)
        for eng in ["AMBER", "SOMD", "GROMACS"]:
            print(prot, eng)
            # keep the original results under "<eng>_old" and promote the
            # single-repeat results to the engine name
            ana_obj.change_name(eng, f"{eng}_old")
            ana_obj.change_name(f"{eng}_single", eng)
            if eng not in ana_obj.engines:
                ana_obj.engines.append(eng)
            if eng in ana_obj.other_results_names:
                ana_obj.other_results_names.remove(eng)
        print(ana_obj.engines + ana_obj.other_results_names)
        # eng still holds the last engine of the loop above
        print(ana_obj.calc_pert_dict[eng])

# # error for a perturbation per single run

# uncertainty_dict_single = {}

# for eng in all_analysis_object.engines:
#     uncertainty_dict_single[eng] = {}
#     repeat = 0
#     for file in all_analysis_object._results_repeat_files[eng]:
#         uncertainty_dict_single[eng][repeat] = {}
#         calc_diff_dict = make_dict.comp_results(
#             file, all_analysis_object.perturbations, eng, name=None
#         )

#         for pert in calc_diff_dict.keys():
#             uncertainty_dict_single[eng][repeat][pert] = calc_diff_dict[pert][1]

#         repeat += 1

In [None]:
# identify any outliers and plot again if needed above
# Both dicts are keyed by display labels: [protein label][engine label].
failed_perts_dict_percen = {}
failed_perts_dict = {}

for prot in ana_obj_dict.keys():
    prot_label = prot_dict_name[prot]
    failed_perts_dict_percen[prot_label] = {}
    failed_perts_dict[prot_label] = {}
    ana_obj = ana_obj_dict[prot]["plain"]
    print(prot)
    for eng in ana_obj.engines:
        failed_pct = 100 - ana_obj.successful_perturbations(eng)[1]
        failed_perts_dict_percen[prot_label][eng_dict_name[eng]] = failed_pct

        failed_perts = ana_obj.failed_perturbations(eng)
        failed_perts_dict[prot_label][eng_dict_name[eng]] = failed_perts

        # summary line: percentage and absolute count of failed perturbations
        pct_for_print = 100 - ana_obj.successful_perturbations(eng)[1]
        n_failed = len(ana_obj.perturbations) - len(
            ana_obj.successful_perturbations(eng)[2]
        )
        n_total = len(ana_obj.perturbations)
        print(f"failed percentage for {eng}: {pct_for_print} ({n_failed} / {n_total})")

        print(f"{eng} failed perturbations: {ana_obj.failed_perturbations(engine=eng)}")
        print(f"{eng} disconnected ligands: {ana_obj.disconnected_ligands(engine=eng)}")
        print(f"outliers {eng}: {ana_obj.get_outliers(threshold=10, name=eng)}")

In [None]:
# list of the 2 fs and reverse runs
# failed_dict is an alias of failed_perts_dict from the failed-perturbation cell:
# [protein label][engine label] -> whatever ana_obj.failed_perturbations() returned.
failed_dict = failed_perts_dict
# Hand-maintained record of the "2 fs" runs: [protein label][engine label] ->
# list of perturbation names written as "<ligand>~<ligand>".
# NOTE(review): presumably these are the perturbations that were rerun with a
# 2 fs timestep - confirm against the run logs.
twofs_run_dict = {
    "TYK2": {
        "AMBER22": [
            "lig_ejm48~lig_ejm53",
            "lig_ejm31~lig_ejm54",
            "lig_ejm31~lig_ejm43",
        ],
        "SOMD1": [],
        "GROMACS23": [],
    },
    "MCL1": {
        "AMBER22": [
            "lig_27~lig_39",
            "lig_32~lig_34",
            "lig_65~lig_67",
            "lig_53~lig_58",
            "lig_53~lig_63",
            "lig_37~lig_65",
            "lig_37~lig_67",
            "lig_39~lig_67",
            "lig_61~lig_63",
            "lig_27~lig_65",
            "lig_27~lig_60",
            "lig_27~lig_63",
            "lig_34~lig_53",
            "lig_27~lig_37",
            "lig_50~lig_56",
        ],
        "SOMD1": [],
        "GROMACS23": [],
    },
    "P38α": {
        "AMBER22": [
            "lig_2y~lig_2w",
            "lig_2ee~lig_2a",
            "lig_2ff~lig_2g",
            "lig_2bb~lig_2z",
            "lig_2ee~lig_2l",
            "lig_2gg~lig_2j",
            "lig_2ee~lig_2j",
            "lig_2ee~lig_2w",
            "lig_2p~lig_2t",
            "lig_2ii~lig_2v",
            "lig_2m~lig_2x",
            "lig_2a~lig_2gg",
            "lig_2ii~Intermediate_6",
        ],
        "SOMD1": ["lig_2b~lig_2h", "lig_2dd~lig_2hh"],
        "GROMACS23": [
            "lig_2bb~lig_2w",
            "lig_2j~lig_2k",
            "lig_2n~lig_2a",
            "lig_2o~lig_2a",
            "lig_2p~lig_2a",
            "lig_2t~lig_2a",
            "lig_2y~lig_2w",
            "lig_2dd~lig_2w",
            "lig_2b~Intermediate_2",
            "lig_2d~Intermediate",
            "lig_2bb~lig_2v",
            "lig_2aa~lig_2v",
            "lig_2b~lig_2w",
            "lig_2b~lig_2z",
            "lig_2ii~lig_2v",
            "lig_2t~lig_2x",
            "lig_2l~lig_2x",
            "lig_2gg~lig_2k",
            "lig_2a~lig_2k",
        ],
    },
    "SYK": {
        "AMBER22": [
            "lig_CHEMBL3265005~lig_CHEMBL3265026",
            "lig_CHEMBL3265016~lig_CHEMBL3265018",
            "lig_CHEMBL3265018~lig_CHEMBL3265020",
            "lig_CHEMBL3265025~lig_CHEMBL3265026",
            "lig_CHEMBL3265017~lig_CHEMBL3265021",
        ],
        "SOMD1": [],
        "GROMACS23": [
            "lig_CHEMBL3259820~lig_CHEMBL3264999",
            "lig_CHEMBL3264996~lig_CHEMBL3264999",
            "lig_CHEMBL3265006~lig_CHEMBL3265009",
            "lig_CHEMBL3259820~lig_CHEMBL3265005",
            "lig_CHEMBL3265005~lig_CHEMBL3265026",
        ],
    },
    "HIF2A": {
        "AMBER22": [],
        "SOMD1": [],
        "GROMACS23": ["lig_235~lig_266", "lig_251~lig_256"],
    },
    "CMET": {
        "AMBER22": [
            "lig_CHEMBL3402753_200~lig_CHEMBL3402761_1",
            "lig_CHEMBL3402744_300~lig_CHEMBL3402745_200",
            "lig_CHEMBL3402747_3400~lig_CHEMBL3402751_2100",
            "lig_CHEMBL3402754_40~lig_CHEMBL3402755_4200",
            "lig_CHEMBL3402753_200~lig_CHEMBL3402754_40",
        ],
        "SOMD1": [],
        "GROMACS23": [
            "lig_CHEMBL3402744_300~lig_CHEMBL3402748_5300",
            "lig_CHEMBL3402744_300~lig_CHEMBL3402753_200",
        ],
    },
}
# Hand-maintained record of the "reverse" runs, same layout as the 2 fs record:
# [protein label][engine label] -> list of "<ligand>~<ligand>" perturbation names.
# Every combination starts empty; only MCL1 / AMBER22 has entries.
reverse_run_dict = {
    prot_label: {eng_label: [] for eng_label in ("AMBER22", "SOMD1", "GROMACS23")}
    for prot_label in ("TYK2", "MCL1", "P38α", "SYK", "HIF2A", "CMET")
}
reverse_run_dict["MCL1"]["AMBER22"] = [
    "lig_65~lig_67",
    "lig_37~lig_67",
    "lig_39~lig_67",
    "lig_61~lig_63",
]

def _count_runs(run_dict):
    """Return [protein label][engine label] -> number of entries in ``run_dict``.

    The protein and engine labels are fixed, so a missing label in ``run_dict``
    raises KeyError rather than being silently skipped.
    """
    protein_labels = ("TYK2", "MCL1", "P38α", "SYK", "HIF2A", "CMET")
    engine_labels = ("AMBER22", "SOMD1", "GROMACS23")
    return {
        prot_label: {
            eng_label: len(run_dict[prot_label][eng_label])
            for eng_label in engine_labels
        }
        for prot_label in protein_labels
    }


# number of failed / 2 fs / reverse perturbations per protein and engine
failed_actual_dict = _count_runs(failed_dict)
twofs_dict = _count_runs(twofs_run_dict)
reverse_dict = _count_runs(reverse_run_dict)

# rows: protein labels, columns: engine labels
df_failed = pd.DataFrame(failed_actual_dict).T
df_twofs = pd.DataFrame(twofs_dict).T
df_reverse = pd.DataFrame(reverse_dict).T

In [None]:
# plot the failed perturbations
# Absolute counts per protein and engine; the percentages are available in
# failed_perts_dict_percen if a relative plot is wanted instead.
ax = df_failed.plot(
    color=col_dict,
    kind="bar",
    xlabel="Protein System",
    ylabel="Number of failed perturbations",
)
ax.yaxis.set_major_locator(MaxNLocator(integer=True))
for p in ax.patches:
    # Centre each count above its bar. Scaling get_x() by a factor would shift
    # the label by an amount that depends on the bar's position on the axis.
    ax.annotate(
        str(p.get_height()),
        (p.get_x() + p.get_width() / 2.0, p.get_height()),
        ha="center",
        va="bottom",
    )

In [None]:
# Per engine: how many perturbations fall in each category. "4 fs" is the
# remainder after the 2 fs, reverse and failed perturbations are subtracted.
# total no of perturbations is 354
adaptive_protocol_dict = {}
for engine_label in ("AMBER22", "GROMACS23", "SOMD1"):
    adaptive_protocol_dict[engine_label] = {
        "4 fs": 354
        - (
            np.sum(df_reverse[engine_label])
            + np.sum(df_twofs[engine_label])
            + np.sum(df_failed[engine_label])
        ),
        "2 fs": np.sum(df_twofs[engine_label]),
        "2 fs reverse": np.sum(df_reverse[engine_label]),
        "failed": np.sum(df_failed[engine_label]),
    }

# rows: engines, columns: categories
df = pd.DataFrame(adaptive_protocol_dict).T.rename(eng_dict_name)

fig, ax = plt.subplots(figsize=(5, 5), dpi=500)
df.plot(
    kind="bar",
    ax=ax,
    width=0.8,
    color=["darkslateblue", "purple", "orchid", "lavender"],
    xlabel="MD engine",
    ylabel="Number of perturbations",
)
ax.yaxis.set_major_locator(MaxNLocator(integer=True))

# write each count centred above its bar
for bar in ax.patches:
    bar_top = bar.get_height()
    bar_centre = bar.get_x() + bar.get_width() / 2.0
    ax.annotate(
        str(bar_top),
        (bar_centre, bar_top),
        ha="center",
        va="bottom",
        fontsize=10,
        color="black",
    )

ax.legend(loc="center right", fontsize=10)
ax.set_xlabel("MD Engine", fontsize=12)
ax.set_ylabel("Number of perturbations", fontsize=12)
ax.tick_params(axis="x", labelsize=10, rotation=0)
ax.tick_params(axis="y", labelsize=10, rotation=0)

In [None]:
# exclude outliers
threshold = 10
outlier_suffix = f"_outliers{threshold}removed"
for prot in ana_obj_dict.keys():
    for name in ana_dicts.keys():
        print(prot, name)
        ana_obj = ana_obj_dict[prot][name]

        for eng in ana_obj.engines:
            # Tag the file extension once per object: appending on every engine
            # (and again on every re-run of this cell) would repeat the suffix.
            if not ana_obj.file_ext.endswith(outlier_suffix):
                ana_obj.file_ext = ana_obj.file_ext + outlier_suffix
            ana_obj.remove_outliers(threshold=threshold, name=eng)

In [None]:
def val_range(val_list):
    """Return the spread between the largest and smallest entry of ``val_list``.

    The entries themselves are compared, and the result is the difference of the
    first element of the largest and of the smallest entry.
    """
    smallest, largest = min(val_list), max(val_list)
    return largest[0] - smallest[0]

In [None]:
# mean spread between the repeats of a perturbation, per protein and engine
for prot in ana_obj_dict.keys():
    print(prot)
    ana_obj = ana_obj_dict[prot]["plain"]

    for eng in ana_obj.engines:
        ranges = []
        for pert in ana_obj._perturbations_dict[eng]:
            try:
                ra = val_range(
                    [ana_obj.calc_repeat_pert_dict[eng][r][pert] for r in [0, 1, 2]]
                )
            except Exception:
                # best effort: skip perturbations without a usable value in all
                # three repeats, but let KeyboardInterrupt/SystemExit through
                continue
            ranges.append(ra)
        clean_ranges = [x for x in ranges if not np.isnan(x)]
        print(f"{eng}, {np.mean(clean_ranges):.2f}")

In [None]:
# calculate the differences in SEM
# SEM differences


def _sem_mean_and_ci(raw_sems):
    """Drop NaN entries and summarise what is left.

    Returns (cleaned list, mean, lower bound, upper bound), where the bounds are
    the 95 % normal interval around the mean with the standard error as scale.
    """
    cleaned = [x for x in raw_sems if str(x) != "nan"]
    centre = np.mean(cleaned)
    low, high = _stats.norm.interval(
        confidence=0.95, loc=np.mean(cleaned), scale=_stats.sem(cleaned)
    )
    return cleaned, centre, low, high


# sem_dict[protocol][protein][engine] and sem_dict_name[protocol] both hold
# (mean, sample std, (lower CI, upper CI), list of SEM values)
sem_dict = {}
sem_dict_name = {}

for name in ana_dicts:
    sem_list_name = []
    sem_dict[name] = {}

    for prot in ana_obj_dict.keys():
        sem_dict[name][prot] = {}
        ana_obj = ana_obj_dict[prot][name]

        for eng in ana_obj.engines:
            # second element of every perturbation result
            sems = [entry[1] for entry in ana_obj.calc_pert_dict[eng].values()]
            sem_list_name.append(sems)

            sem_list, mean, lower_ci, upper_ci = _sem_mean_and_ci(sems)
            print(prot, name, eng, mean, lower_ci, upper_ci)
            sem_dict[name][prot][eng] = (
                mean,
                _stats.tstd(sem_list),
                (lower_ci, upper_ci),
                sem_list,
            )

    # pool every protein and engine of this protocol
    pooled = reduce(lambda acc, chunk: acc + chunk, sem_list_name)
    sem_list_name, mean, lower_ci, upper_ci = _sem_mean_and_ci(pooled)
    print(name, mean, lower_ci, upper_ci)
    sem_dict_name[name] = (
        mean,
        _stats.tstd(sem_list_name),
        (lower_ci, upper_ci),
        sem_list_name,
    )

In [None]:
# also calc mae perts
# mae_dict[protocol][protein][engine] holds the "experimental" entries of the
# three items returned by calc_mae_engines, in that order.
mae_dict = {}

for name in ana_dicts:
    mae_dict[name] = {}

    for prot in ana_obj_dict.keys():
        print(prot, name)

        mae_dict[name][prot] = {}

        ana_obj = ana_obj_dict[prot][name]

        try:
            mae = ana_obj.calc_mae_engines(pert_val="pert", recalculate=False)
        except Exception as e:
            # Skip this protein/protocol: carrying on would store the `mae` left
            # over from the previous iteration under this protein's name.
            print(e)
            print(f"could not compute MAE for {prot} {name}")
            continue

        for eng in ana_obj.engines:
            stats_string = ""
            try:
                mae_dict[name][prot][eng] = (
                    mae[0][eng]["experimental"],
                    mae[1][eng]["experimental"],
                    mae[2][eng]["experimental"],
                )
                stats_string += f"{eng} MAE: {mae[0][eng]['experimental']:.2f} +/- {mae[1][eng]['experimental']:.2f} kcal/mol, "

                if sem_dict[name][prot][eng][0]:
                    stats_string += f"SEM: {sem_dict[name][prot][eng][0]:.2f} +/- {sem_dict[name][prot][eng][1]:.2f} kcal/mol\n"
                elif name == "single":
                    errors = [val[1] for val in ana_obj.calc_pert_dict[eng].values()]
                    stats_string += f"error: {np.mean(errors):.2f} +/- {_stats.tstd(errors):.2f} kcal/mol\n"

                print(stats_string)

            except Exception as e:
                print(e)
                print(f"could not compute for {prot} {name} {eng}")

In [None]:
# graphs based on engine
# One panel per engine; within a panel, one bar group per protein and one bar
# per analysis protocol. Bar height is element 0 of the stored tuple, the error
# bar is element 1.
plotting_dict = sem_dict  # mae_dict or sem_dict
stats_name = "ΔΔG SEM"  # MAE or SEM

# legend labels for the analysis protocols; unknown names are used as they are
analysis_labels = {
    "plain": "Full data",
    "subsampling": "Subsampling",
    "autoeq": "Auto-equilibration",
    "1ns": "1 ns sampling",
    "2ns": "2 ns sampling",
    "3ns": "3 ns sampling",
}


def _engine_frame(engine, stat_index):
    """Build a (protein label x protocol label) frame for one engine.

    ``stat_index`` selects which element of each stored tuple in plotting_dict
    is used as the cell value.
    """
    frames = []
    for name in ana_dicts:
        frames.append(
            pd.DataFrame(plotting_dict[name])
            .applymap(lambda x: x[stat_index])
            .rename(prot_dict_name, axis=1)
            .T[[engine]]
            .rename({engine: analysis_labels.get(name, name)}, axis=1)
        )
    return reduce(
        lambda left, right: pd.merge(left, right, left_index=True, right_index=True),
        frames,
    )


fig, axes = plt.subplots(
    nrows=3, ncols=1, figsize=(20, 20), sharex=True, sharey=True, dpi=500
)
# NOTE: do not assign to plt.xlim / plt.ylim - that replaces the pyplot
# functions with plain values and sets no axis limits.

# ana_obj is whichever analysis object an earlier cell assigned last; only its
# list of engines is used here
for engine, pos in zip(ana_obj.engines, axes):
    df = _engine_frame(engine, 0)
    df_err = _engine_frame(engine, 1)

    # evenly spaced colours from the colormap, one per protocol
    # (linspace also copes with a single protocol, where max - min would be 0)
    my_cmap = plt.get_cmap("plasma")
    colors = [my_cmap(fraction) for fraction in np.linspace(0.0, 1.0, len(df.columns))]

    df.plot(
        kind="bar",
        color=colors,
        yerr=df_err,
        title=eng_dict_name[engine],
        ax=pos,
        xlabel="Protein System",
        ylabel=f"{stats_name} (kcal/mol)",
        legend=False,
    )

    # one legend is enough; it is identical for every panel
    if engine == "AMBER":
        pos.legend(loc="upper left", fontsize=12)

    # same tick styling on every panel, not only on the last one
    pos.tick_params(axis="x", labelsize=12, rotation=0)
    pos.tick_params(axis="y", labelsize=12, rotation=0)

In [None]:
# one graph for one method but that compared for each protein
# Bar height is element 0 of the stored tuple, the error bar is element 1.
plotting_dict = sem_dict  # mae_dict or sem_dict
stats_name = "ΔΔG SEM"  # MAE or SEM
fig, ax = plt.subplots(figsize=(5, 5), dpi=500)
# NOTE: do not assign to plt.xlim / plt.ylim - that replaces the pyplot
# functions with plain values and sets no axis limits. The lower limit is set
# on the axes after plotting instead.

name = "plain"
# rows: protein labels, columns: engine labels
df = (
    pd.DataFrame(plotting_dict[name])
    .applymap(lambda x: x[0])
    .rename(prot_dict_name, axis=1)
    .T.rename(eng_dict_name, axis=1)
)
df_err = (
    pd.DataFrame(plotting_dict[name])
    .applymap(lambda x: x[1])
    .rename(prot_dict_name, axis=1)
    .T.rename(eng_dict_name, axis=1)
)
ax = df.plot(
    kind="bar",
    color=col_dict,
    xlabel="Protein System",
    ylabel=f"{stats_name} (kcal/mol)",
    yerr=df_err,
    ax=ax,
)
ax.set_ylim(bottom=0)
ax.tick_params(axis="x", rotation=0)
ax.tick_params(axis="y", rotation=0)

In [None]:
# Check the statistical significance of the differences between engines.
# The per-engine lists built here (`first_err_vals`, `second_err_vals`,
# `third_err_vals`) are reused by the histogram and violin-plot cells.

stats_name = "ΔΔG SEM"  # one of "ΔΔG MAE", "ΔG MAE" or "ΔΔG SEM"

# the per-ligand statistic is read from the "lomap" network, the
# per-perturbation statistics from the "rbfenn" network
if stats_name == "ΔG MAE":
    ana_obj_dict = network_dict["lomap"]
    print("yay")
else:
    ana_obj_dict = network_dict["rbfenn"]

# checking for significance
eng1 = "AMBER"
eng2 = "SOMD"
eng3 = "GROMACS"
first_err_vals = []
second_err_vals = []
third_err_vals = []

for prot in ana_obj_dict:
    ana_obj = ana_obj_dict[prot]["plain"]
    calc_pert_dict = ana_obj.calc_pert_dict

    # perturbations that are present, and free of NaN, for all three engines
    filtered_keys = [
        key
        for key in calc_pert_dict[eng1]
        if key in calc_pert_dict[eng2]
        and key in calc_pert_dict[eng3]
        and not any(
            np.isnan(calc_pert_dict[eng][key]).any() for eng in (eng1, eng2, eng3)
        )
    ]

    if stats_name == "ΔΔG MAE":
        # absolute error of every perturbation against experiment
        pert_keys = [key for key in filtered_keys if "Intermediate" not in key]
        f_err_vals, s_err_vals, t_err_vals = (
            [
                abs(calc_pert_dict[eng][key][0] - ana_obj.exper_pert_dict[key][0])
                for key in pert_keys
            ]
            for eng in (eng1, eng2, eng3)
        )
    elif stats_name == "ΔG MAE":
        # absolute error of every ligand against experiment.
        # The same ligand list must be used for all three engines: the lists are
        # compared pairwise below, so filtering "lig_23" for only one engine (as
        # this cell used to do) gives lists of different length / misaligned pairs.
        # NOTE(review): the reason "lig_23" is excluded is not visible here - confirm.
        ligs = [
            lig
            for lig in ana_obj.ligands
            if "lig_23" not in lig and "Intermediate" not in lig
        ]
        f_err_vals, s_err_vals, t_err_vals = (
            [
                abs(
                    ana_obj.cinnabar_calc_val_dict[eng][lig][0]
                    - ana_obj.cinnabar_exper_val_dict[eng][lig][0]
                )
                for lig in ligs
            ]
            for eng in (eng1, eng2, eng3)
        )
    elif stats_name == "ΔΔG SEM":
        # magnitude of the error stored with every perturbation
        f_err_vals, s_err_vals, t_err_vals = (
            [abs(calc_pert_dict[eng][key][1]) for key in filtered_keys]
            for eng in (eng1, eng2, eng3)
        )
    else:
        # fail loudly rather than silently re-using values from a previous iteration
        raise ValueError(
            f"unknown stats_name {stats_name!r}; "
            "expected 'ΔΔG MAE', 'ΔG MAE' or 'ΔΔG SEM'"
        )

    first_err_vals.append(f_err_vals)
    second_err_vals.append(s_err_vals)
    third_err_vals.append(t_err_vals)

first_err_vals = flatten_comprehension(first_err_vals)
second_err_vals = flatten_comprehension(second_err_vals)
third_err_vals = flatten_comprehension(third_err_vals)

# the tests below are paired, so the three lists must line up element by element
assert len(first_err_vals) == len(second_err_vals) == len(third_err_vals)

# drop every position where any of the three engines has a NaN
valid_indices = [
    i
    for i in range(len(first_err_vals))
    if not (
        np.isnan(first_err_vals[i])
        or np.isnan(second_err_vals[i])
        or np.isnan(third_err_vals[i])
    )
]
first_err_vals = [first_err_vals[i] for i in valid_indices]
second_err_vals = [second_err_vals[i] for i in valid_indices]
third_err_vals = [third_err_vals[i] for i in valid_indices]

eng_list_dict = {
    eng1: first_err_vals,
    eng2: second_err_vals,
    eng3: third_err_vals,
}

# mean, 95 % confidence interval of the mean (normal approximation) and
# standard deviation for every engine
for eng, vals in eng_list_dict.items():
    mean = np.mean(vals)
    std = np.std(vals)
    ci = 1.96 * (std / np.sqrt(len(vals)))
    print(f"{eng}, {mean:.2f} ({mean-ci:.2f},{mean+ci:.2f}) , {std:.2f}")

stats_test_dict = {}

for enga in [eng1, eng2, eng3]:
    stats_test_dict[enga] = {}
    for engb in [eng1, eng2, eng3]:
        if enga == engb:
            # sentinel for "same engine, no test carried out"
            stats_test_dict[enga][engb] = 100
            continue

        # check normally distributed
        # NOTE(review): the Shapiro-Wilk test is applied to the *absolute* paired
        # differences, as before - confirm that this is the intended quantity.
        paired_diff = np.array(eng_list_dict[enga]) - np.array(eng_list_dict[engb])
        if _stats.shapiro(abs(paired_diff))[1] > 0.05:
            print(
                f"data is normally distributed for {enga} and {engb}!! Still carrying out wilcoxon signed rank ...."
            )

        # Always store a p-value. Previously nothing was stored when the normality
        # check passed, which left NaN in the table below; the "MAD between engines"
        # cell already runs the Wilcoxon test unconditionally.
        t, p = _stats.wilcoxon(eng_list_dict[enga], eng_list_dict[engb])
        stats_test_dict[enga][engb] = p

df = pd.DataFrame(stats_test_dict).applymap(lambda x: float(x))
print(f"statistical significance for the {stats_name} between engines")
df

In [None]:
# Overlaid, density-normalised histograms of the per-engine values collected in
# the significance cell above (all proteins pooled).
engine_colours = pipeline.analysis.set_colours()
for eng, vals in zip(
    [eng1, eng2, eng3], [first_err_vals, second_err_vals, third_err_vals]
):
    plt.hist(
        vals,
        density=True,
        color=engine_colours[eng],
        label=eng_dict_name[eng],
        alpha=0.5,
    )
plt.legend(loc="upper right")
# `stats_name` already names the quantity (e.g. "ΔΔG SEM"), so it is not prefixed
# with a second "ddG"; this matches the y label of the violin plot
plt.xlabel(f"{stats_name} (kcal/mol)")
plt.ylabel("Density")
plt.title("All proteins")
# plot for all proteins, and ind proteins too
# also check sem error sig again

In [None]:
# Violin plot of the pooled per-engine values: one violin per MD engine.
# Build a long-format table with one row per value and its engine label.
engine_column = []
for eng_key, eng_vals in (
    (eng1, first_err_vals),
    (eng2, second_err_vals),
    (eng3, third_err_vals),
):
    engine_column += [f"{eng_dict_name[eng_key]}"] * len(eng_vals)

data = {
    "engine": engine_column,
    "error": [*first_err_vals, *second_err_vals, *third_err_vals],
}
df = pd.DataFrame(data)

fig, ax = plt.subplots(figsize=(8, 5), dpi=500)
sns.violinplot(data=df, x="engine", y="error", inner="box", palette=col_dict)

# plt.title(f"{stats_name} Distribution for Different MD Engines across all protein systems")
for axis_name in ("x", "y"):
    plt.tick_params(axis=axis_name, labelsize=12, rotation=0)
plt.xlabel("MD Engine")
plt.ylabel(f"{stats_name} (kcal/mol)")

In [None]:
# MAD between engines: pool, over all proteins, the value ([0] entry) each engine
# calculated for the perturbations that all three engines have in common, then
# test every pair of engines with a Wilcoxon signed-rank test.

eng1, eng2, eng3 = "AMBER", "SOMD", "GROMACS"
first_err_vals, second_err_vals, third_err_vals = [], [], []

for prot in ana_obj_dict:
    ana_obj = ana_obj_dict[prot]["plain"]
    use_dict = ana_obj.calc_pert_dict

    # keep a perturbation only if every engine has it and none of its entries is NaN
    filtered_keys = []
    for key in use_dict[eng1]:
        if key not in use_dict[eng2] or key not in use_dict[eng3]:
            continue
        if (
            np.isnan(use_dict[eng1][key]).any()
            or np.isnan(use_dict[eng2][key]).any()
            or np.isnan(use_dict[eng3][key]).any()
        ):
            continue
        filtered_keys.append(key)

    # extend the pooled lists directly (same result as collecting one sub-list per
    # protein and flattening afterwards)
    first_err_vals.extend(use_dict[eng1][key][0] for key in filtered_keys)
    second_err_vals.extend(use_dict[eng2][key][0] for key in filtered_keys)
    third_err_vals.extend(use_dict[eng3][key][0] for key in filtered_keys)

# positions at which none of the three engines has a NaN value
valid_indices = [
    i
    for i, (first, second, third) in enumerate(
        zip(first_err_vals, second_err_vals, third_err_vals)
    )
    if not (np.isnan(first) or np.isnan(second) or np.isnan(third))
]
first_err_vals = [first_err_vals[i] for i in valid_indices]
second_err_vals = [second_err_vals[i] for i in valid_indices]
third_err_vals = [third_err_vals[i] for i in valid_indices]

# the tests are paired, so the lists have to line up
assert len(first_err_vals) == len(second_err_vals) == len(third_err_vals)

eng_list_dict = {
    eng1: first_err_vals,
    eng2: second_err_vals,
    eng3: third_err_vals,
}

stats_test_dict = {}

for enga in (eng1, eng2, eng3):
    stats_test_dict[enga] = {}
    for engb in (eng1, eng2, eng3):
        if enga == engb:
            # sentinel value: an engine is not tested against itself
            stats_test_dict[enga][engb] = 100
            continue

        # Shapiro-Wilk on the absolute paired differences; the outcome is only
        # reported, the Wilcoxon test is carried out either way
        abs_paired_diff = np.abs(
            np.asarray(eng_list_dict[enga]) - np.asarray(eng_list_dict[engb])
        )
        shapiro_p = _stats.shapiro(abs_paired_diff)[1]
        if shapiro_p > 0.05:
            print(
                f"data is normally distributed for {enga} and {engb}!! Still carrying out wilcoxon signed rank ...."
            )

        t, p = _stats.wilcoxon(eng_list_dict[enga], eng_list_dict[engb])
        stats_test_dict[enga][engb] = p

df_col = pd.DataFrame(stats_test_dict).applymap(lambda x: float(x))
print("statistical significance between the perturbations calculated")
df_col

In [None]:
# Engines compared to each other - MAD.
# Fills engine x engine tables with the "MUE" statistic between the pooled values
# of every pair of engines, together with its error and confidence interval.
engines = ana_obj.engines

df = pd.DataFrame(columns=engines, index=engines)
df_err = pd.DataFrame(columns=engines, index=engines)
df_ci = pd.DataFrame(columns=engines, index=engines)

# The loop variables are deliberately NOT called eng1 / eng2: those notebook-level
# names are read by the histogram and violin-plot cells, and re-using them here
# would leave them pointing at the last engine pair of this loop.
for eng_col, eng_row in it.product(engines, engines):
    res = stats_engines.compute_stats(
        x=eng_list_dict[eng_row], y=eng_list_dict[eng_col], statistic="MUE"
    )

    # res[0]: the computed statistic, res[1]: the stderr from bootstrapping,
    # res[2]: the confidence interval, indexed below as (lower, upper)
    # loc index, column
    df.loc[eng_row, eng_col] = res[0]
    df_err.loc[eng_row, eng_col] = res[1]
    df_ci.loc[eng_row, eng_col] = res[2]

mad = (df, df_err, df_ci)
# format for display; cells that are exactly 0 are left blank
mad[0].update(mad[0].applymap(lambda x: "" if x == 0 else f"{x:.2f}"))
mad[2].update(
    mad[2].applymap(lambda x: "" if x[0] == 0 else f"({x[0]:.2f}, {x[1]:.2f})")
)
# annotation text used by the heatmap cell: value on the first line, CI below
df_val = mad[0].astype(str) + "\n" + mad[2].astype(str)
df_val

In [None]:
# Heatmap of the engine-vs-engine table: the cell text comes from `df_val`, the
# cell colour from the p-values in `df_col`.


def _significance_code(p_value):
    """Map a `df_col` entry to a colour code.

    0: sentinel 100, i.e. same engine, no stats test
    1: p > 0.05, no statistically significant difference
    2: anything else, statistically significant difference
    """
    if p_value == 100:
        return 0
    if p_value > 0.05:
        return 1
    return 2


# colour names in the order of the codes returned by _significance_code
_code_to_label = ("white", "grey", "pink")

# threshold for colour labels
color_labels = df_col.applymap(lambda x: _code_to_label[_significance_code(x)])

# mapping dictionary
color_mapping = {"pink": "#FFC0CB", "white": "#FFFFFF", "grey": "#BEBEBE"}

# numerical version of the labels, with the display names on both axes
color_numeric = (
    df_col.applymap(_significance_code)
    .rename(eng_dict_name, axis=1)
    .rename(eng_dict_name, axis=0)
)

array_col_dict = {
    0: color_mapping["white"],
    1: color_mapping["grey"],
    2: color_mapping["pink"],
}

# only the colours of codes that actually occur are put in the colormap,
# as otherwise problem if only two colours
numeric_colours_list = color_numeric.to_numpy().ravel().tolist()
present_colours = []
for code, hex_colour in array_col_dict.items():
    if code in numeric_colours_list:
        present_colours.append(hex_colour)
cmap = mcolors.ListedColormap(present_colours)

# Plot heatmap using numeric mapping for colors
fig, ax = plt.subplots(figsize=(4, 4), dpi=500)
sns.heatmap(color_numeric, annot=df_val, fmt="s", cmap=cmap, cbar=False, ax=ax)
plt.xticks(rotation=0, fontsize=10)
plt.yticks(rotation=90, fontsize=10)

# legend entries (the white "same engine" cells get no entry)
legend_patches = [
    mpatches.Patch(color=color_mapping[colour_name], label=legend_label)
    for colour_name, legend_label in (("pink", "p ≤ 0.05"), ("grey", "p > 0.05"))
]

# Add legend to the plot
plt.legend(
    handles=legend_patches,
    loc="center left",
    title="Statistical significance",
    bbox_to_anchor=(1, 0.5),
)

plt.title("ΔΔG MAD (kcal/mol) between MD engines (95% CI)")

In [None]:
# 2D density plot of one engine against another - for all proteins combined
stats_name = "SEM"  # "SEM" or "MAE"

# The two branches below used to be swapped ("SEM" produced the absolute errors
# and "MAE" the SEMs), so the figure labels did not describe the plotted data.
# Unknown names are rejected up front instead of silently plotting stale values.
if stats_name not in ("SEM", "MAE"):
    raise ValueError(f"unknown stats_name {stats_name!r}; expected 'SEM' or 'MAE'")

for eng1, eng2 in (["AMBER", "SOMD"], ["AMBER", "GROMACS"], ["GROMACS", "SOMD"]):
    first_err_vals = []
    second_err_vals = []

    for prot in ana_obj_dict:
        ana_obj = ana_obj_dict[prot]["plain"]
        calc_pert_dict = ana_obj.calc_pert_dict

        # perturbations present, and free of NaN, for both engines of the pair
        filtered_keys = [
            key
            for key in calc_pert_dict[eng1]
            if key in calc_pert_dict[eng2]
            and not (
                np.isnan(calc_pert_dict[eng1][key]).any()
                or np.isnan(calc_pert_dict[eng2][key]).any()
            )
        ]

        if stats_name == "MAE":
            # absolute error against experiment
            pert_keys = [key for key in filtered_keys if "Intermediate" not in key]
            f_err_vals = [
                abs(calc_pert_dict[eng1][key][0] - ana_obj.exper_pert_dict[key][0])
                for key in pert_keys
            ]
            s_err_vals = [
                abs(calc_pert_dict[eng2][key][0] - ana_obj.exper_pert_dict[key][0])
                for key in pert_keys
            ]
        else:
            # SEM: magnitude of the error stored with every perturbation
            f_err_vals = [abs(calc_pert_dict[eng1][key][1]) for key in filtered_keys]
            s_err_vals = [abs(calc_pert_dict[eng2][key][1]) for key in filtered_keys]

        first_err_vals.append(f_err_vals)
        second_err_vals.append(s_err_vals)

    x = flatten_comprehension(first_err_vals)
    y = flatten_comprehension(second_err_vals)
    z = np.abs(np.array(x) - np.array(y))

    fig, ax = plt.subplots(figsize=(8, 6))

    # data points, coloured by the absolute difference between the two engines;
    # the filled KDE is drawn afterwards, semi-transparently, over them
    points = plt.scatter(x, y, c=z, cmap="Purples", edgecolors="black")
    sns.kdeplot(x=x, y=y, cmap="PuRd", fill=True, levels=10, thresh=0.05, alpha=0.7)
    # the scatter is passed explicitly so that the colorbar always describes `z`
    plt.colorbar(
        points, label=f"Absolute difference\n between engine {stats_name} (kcal/mol)"
    )

    # Labels and Title
    plt.xlabel(f"{eng1} {stats_name}")
    plt.ylabel(f"{eng2} {stats_name}")
    plt.title(f"{stats_name}")

In [None]:
# 2D density plot, one figure per reference engine (eng2): the absolute error of
# the two other engines minus the absolute error of the reference engine.
stats_name = "MAE (kcal/mol)"

for eng1, eng2, eng3 in [
    ("GROMACS", "SOMD", "AMBER"),
    ("SOMD", "AMBER", "GROMACS"),
    ("AMBER", "GROMACS", "SOMD"),
]:
    first_err_vals, second_err_vals, third_err_vals = [], [], []

    for prot in ana_obj_dict:
        ana_obj = ana_obj_dict[prot]["plain"]
        calc = ana_obj.calc_pert_dict

        # perturbations present, and free of NaN, for all three engines
        filtered_keys = []
        for key in calc[eng1]:
            if key not in calc[eng2] or key not in calc[eng3]:
                continue
            if (
                np.isnan(calc[eng1][key]).any()
                or np.isnan(calc[eng2][key]).any()
                or np.isnan(calc[eng3][key]).any()
            ):
                continue
            filtered_keys.append(key)

        # MAE: absolute error against experiment, perturbations with
        # "Intermediate" in their name are left out
        compared_keys = [key for key in filtered_keys if "Intermediate" not in key]
        abs_errors = {}
        for eng in (eng1, eng2, eng3):
            abs_errors[eng] = [
                abs(calc[eng][key][0] - ana_obj.exper_pert_dict[key][0])
                for key in compared_keys
            ]

        first_err_vals.append(abs_errors[eng1])
        second_err_vals.append(abs_errors[eng2])
        third_err_vals.append(abs_errors[eng3])

    reference_vals = flatten_comprehension(second_err_vals)
    x = np.array(flatten_comprehension(first_err_vals)) - np.array(reference_vals)
    y = np.array(flatten_comprehension(third_err_vals)) - np.array(reference_vals)
    z = np.abs(np.array(x) - np.array(y))  # not used by the plot below

    fig, ax = plt.subplots(figsize=(8, 6))

    # data points, coloured by the absolute error of the reference engine
    plt.scatter(x, y, c=reference_vals, cmap="Purples", edgecolors="black")
    plt.colorbar(label=f"{eng2} {stats_name}")
    sns.kdeplot(x=x, y=y, cmap="PuRd", fill=True, levels=10, thresh=0.05, alpha=0.7)
    for draw_zero_line in (ax.axhline, ax.axvline):
        draw_zero_line(0, color="black", linewidth=1)

    # Corner annotations, positioned in axes coordinates. x > 0 means eng1 has the
    # larger error (the reference is better than eng1), y > 0 the same for eng3.
    corner_notes = [
        (0.05, 0.95, f"better than {eng3}\nworse than {eng1}"),
        (0.70, 0.10, f"worse than {eng3}\nbetter than {eng1}"),
        (0.05, 0.10, f"worse than {eng3}\nworse than {eng1}"),
        (0.70, 0.95, f"better than {eng3}\nbetter than {eng1}"),
    ]
    for x_frac, y_frac, note in corner_notes:
        ax.text(
            x_frac,
            y_frac,
            note,
            transform=ax.transAxes,
            fontsize=9,
            color="black",
            ha="left",
            va="top",
        )

    # Labels and Title
    plt.xlabel(f"{eng1}-{eng2} {stats_name}")
    plt.ylabel(f"{eng3}-{eng2} {stats_name}")
    plt.title(f"{eng2} {stats_name.replace('(kcal/mol)','')}comparison")

In [None]:
# correlating the number of perturbed atoms with precision (SEM) and accuracy (MAE)


def _network_score(perturbation, score_dict):
    """Return the network score of `perturbation` ("ligA~ligB"), or NaN.

    The score is looked up for the perturbation as written and, failing that,
    for the reverse direction ("ligB~ligA"). Perturbations without a score
    (or without a "~") give NaN instead of raising.
    """
    if perturbation in score_dict:
        return score_dict[perturbation]
    ligands = perturbation.split("~")
    if len(ligands) < 2:
        return np.nan
    return score_dict.get(f"{ligands[1]}~{ligands[0]}", np.nan)


pert_overlap_dict = {}

for prot in ana_obj_dict:
    ana_obj = ana_obj_dict[prot]["plain"]

    df = ana_obj.perturbing_atoms_and_overlap(read_file=True)

    # read in all the lomap scores; the file sits in the `execution_model` folder
    # next to (one level above) the analysis output folder
    score_dict = {}
    scores_file = f"{join('/', *ana_obj.output_folder.split('/')[:-1])}/execution_model/network_scores.dat"
    with open(scores_file) as lfile:
        for line in lfile:
            if not line.strip():
                # tolerate blank lines (e.g. a trailing newline at the end of the file)
                continue
            # first two comma-separated fields: the ligands, last field: the score
            fields = [field.strip() for field in line.split(",")]
            score_dict[f"{fields[0]}~{fields[1]}"] = float(fields[-1])
    print(f"{prot}: read {len(score_dict)} network scores")

    # Keep only the perturbations of this analysis, then attach the score in one
    # vectorised step. This replaces dropping / writing rows one index label at a
    # time inside bare `except:` blocks, which hid every kind of error.
    df = df[df["perturbation"].isin(ana_obj.perturbations)].copy()
    df["score"] = (
        df["perturbation"]
        .map(lambda pert: _network_score(pert, score_dict))
        .astype(float)
    )

    pert_overlap_dict[prot] = df

    for eng in ana_obj.engines:
        df2 = df[df["engine"] == eng]
        print(prot, eng, len(df2))

In [None]:
# Scatter plots of the per-perturbation table: first for all engines together,
# then one figure per engine.

# raw column name -> label used in the figures
column_labels = {
    "perturbing_atoms": "Average number of perturbing atoms",
    "diff_to_exp": "MAE (kcal/mol)",
    "percen_overlap_okay": "Overlap > 0.03 (%)",
    "error": "SEM (kcal/mol)",
    "score": "LOMAP-score",
}
df_plot = pd.concat(pert_overlap_dict.values()).rename(columns=column_labels)

# x axis, y axis and the column that sets the marker colour
x, y, z = "LOMAP-score", "MAE (kcal/mol)", "SEM (kcal/mol)"

fig, ax = plt.subplots(figsize=(8, 6))
# sns.kdeplot(x=df_plot[x], y=df_plot[y], cmap="PuRd", fill=True, levels=10, thresh=0.05)
df_plot.plot.scatter(x=x, y=y, c=z, colormap="plasma", ax=ax)  # vmin=0, vmax=100

for eng in ana_obj.engines:
    df2 = df_plot[df_plot["engine"] == eng]
    print(len(df2))
    # no `ax` is passed, so every engine gets a figure of its own
    df2.plot.scatter(
        x=x, y=y, c=z, colormap="plasma", title=eng_dict_name[eng]
    )  # vmin=80, vmax=100

In [None]:
# Mean MAE (and number of perturbations) per engine for the perturbations whose
# overlap percentage lies below a series of upper thresholds.
# NOTE(review): the subsets are cumulative (every perturbation below the
# threshold), not separate bins, and the thresholds 120/100/80/40/20 are not the
# x positions 100/80/60/40/20 that the plotting cell uses - confirm the intent.
mae_list = {}
len_list = {}
for eng in ana_obj.engines:
    eng_label = eng_dict_name[eng]
    mae_list[eng_label] = []
    len_list[eng_label] = []

    df2 = df_plot[df_plot["engine"] == eng]
    print(len(df2))
    df_check_good = df2[df2["Overlap > 0.03 (%)"] >= 100]  # currently unused

    for percen_hi in (120, 100, 80, 40, 20):
        df_check_bad = df2[df2["Overlap > 0.03 (%)"] < percen_hi]
        n_below = len(df_check_bad)
        print(n_below)
        mae_list[eng_label].append(df_check_bad["MAE (kcal/mol)"].mean())
        len_list[eng_label].append(n_below)

    # sum over the cumulative subsets, followed by their mean MAEs
    print(eng, np.sum(len_list[eng_label]))
    print(mae_list[eng_label])

# doubel check w lomap score

In [None]:
# Line plot of the mean ΔΔG MAE per engine against the overlap threshold.
fig, ax1 = plt.subplots(figsize=(6, 6))

# x positions of the five values stored per engine in `mae_list`
# NOTE(review): the cell that fills `mae_list` uses the thresholds
# 120/100/80/40/20 - confirm that these positions describe them correctly.
overlap_positions = [100, 80, 60, 40, 20]

# Plot the MAE data
df = pd.DataFrame(mae_list)
for eng_label in ["AMBER22", "SOMD1", "GROMACS23"]:
    ax1.plot(
        overlap_positions,
        df[eng_label],
        color=col_dict[eng_label],
        label=eng_label,
    )
ax1.set_xlabel(
    "Amount of overlap off-diagonals greater\nthan 0.03 per perturbation (%)"
)
ax1.set_ylabel("ΔΔG MAE (kcal/mol)")

# A secondary y axis (`ax2`) with the number of perturbations from `len_list` was
# sketched here but is disabled. The legend used to be assembled from `ax1` and
# that non-existent `ax2`, which raised a NameError on a fresh kernel; it is now
# built from `ax1` alone.
ax1.legend(loc="lower left")

plt.show()

In [None]:
# switch the analysis objects used by the following cells to the "lomap" network
ana_obj_dict = network_dict["lomap"]

In [None]:
# comparing to the literature results for plotting

# Obtain the literature results (FEP+, Hahn et al., OpenFE) for every protein as
# {name: (value, error)} dictionaries, per ligand and per perturbation.

# Absolute, machine-specific location of the benchmark inputs; it is defined once
# here so that it only has to be changed in one place.
benchmark_inputs_dir = "/home/anna/Documents/benchmark/inputs"
other_computed_dir = f"{benchmark_inputs_dir}/other_computed"

KJ_TO_KCAL = 0.239006  # factor applied to the Hahn et al. values (kJ/mol -> kcal/mol)


def _fepplus_lig_name(raw_name):
    """Convert a ligand name from the FEP+ csv files into the `lig_<name>` form.

    The annotations are removed in the order given, which matters:
    " redocked" has to be removed before "docked ".
    """
    name = raw_name
    for old, new in (
        (" flip", ""),
        ("-charged-pKa-8.1", ""),
        (" redocked", ""),
        ("_n", ""),
        (" ground state", ""),
        ("docked ", ""),
        (" adjust", ""),
        ("ejm_", "ejm"),
        ("jmc_", "jmc"),
    ):
        name = name.replace(old, new)
    return f"lig_{name.strip()}"


def _openfe_lig_names(names):
    """Convert a Series of OpenFE ligand names into the `lig_<name>` form."""
    return "lig_" + names.str.replace("_redocked", "", regex=True).replace(
        "-charged-pKa-8.1", "", regex=True
    ).replace("-flip", "", regex=True).replace("ejm_", "ejm", regex=True).replace(
        "jmc_", "jmc", regex=True
    )


fepplus_ligs_dict = {}
fepplus_perts_dict = {}
hahn_ligs_dict = {}
hahn_perts_dict = {}
openfe_ligs_dict = {}
openfe_perts_dict = {}

for prot in ana_obj_dict:
    print(prot)

    ana_obj = ana_obj_dict[prot]["plain"]

    # ---- FEP+ ----
    # for perturbations
    file = f"{other_computed_dir}/fepplus/{prot}_perts.csv"
    df = pd.read_csv(file, delimiter=",")

    fepplus_perts_dict[prot] = {}
    for index, row in df.iterrows():
        pert = f"{_fepplus_lig_name(row['Lig 1'])}~{_fepplus_lig_name(row['Lig 2'])}"
        fepplus_perts_dict[prot][pert] = (
            row["Bennett ddG (kcal/mol)"],
            row["Bennett std. error (kcal/mol)"],
        )

    write_perts_file(
        fepplus_perts_dict[prot],
        # .csv
        file_path=f"{benchmark_inputs_dir}/{prot}/perts_file_fepplus_new",
    )

    # for ligands
    file = f"{other_computed_dir}/fepplus/{prot}_ligs.csv"
    df = pd.read_csv(file, delimiter=",")

    fepplus_ligs_dict[prot] = {}
    for index, row in df.iterrows():
        fepplus_ligs_dict[prot][_fepplus_lig_name(row["Ligand name"])] = (
            row["Pred. dG (kcal/mol)"],
            row["Pred. dG std. error (kcal/mol)"],
        )

    # centre the predicted values on their mean; the errors are kept as they are
    avg = np.mean([val[0] for val in fepplus_ligs_dict[prot].values()])
    normalised_ligs_dict = {
        lig: (val[0] - avg, val[1]) for lig, val in fepplus_ligs_dict[prot].items()
    }
    fepplus_ligs_dict[prot] = normalised_ligs_dict

    write_vals_file(
        fepplus_ligs_dict[prot],
        # .csv
        file_path=f"{benchmark_inputs_dir}/{prot}/ligs_file_fepplus_new",
    )

    # ---- Hahn et al ----
    file = f"{other_computed_dir}/hahn/{prot}.dat"

    # a multi-character delimiter is only supported by the python parser; naming
    # the engine avoids the ParserWarning pandas emits when it falls back to it
    df = pd.read_csv(file, delimiter="  ", engine="python")

    # the values are multiplied by KJ_TO_KCAL, need to convert into kcal/mol
    hahn_perts_dict[prot] = {}
    for index, row in df.iterrows():
        hahn_perts_dict[prot][f"{row['edge']}"] = (
            float(row["ddg"]) * KJ_TO_KCAL,
            float(row["ddg_err"]) * KJ_TO_KCAL,
        )

    write_perts_file(
        hahn_perts_dict[prot],
        # .csv
        file_path=f"{other_computed_dir}/hahn/perts_file_{prot}",
    )

    files = [f"{other_computed_dir}/hahn/perts_file_{prot}.csv"]

    calc_diff_dict = make_dict.comp_results(files)  # older method

    perts, ligs = get_info_network_from_dict(calc_diff_dict)
    exper_dict = ana_obj.exper_val_dict

    convert.cinnabar_file(
        files,
        exper_dict,
        f"{other_computed_dir}/hahn/cinnabar_{prot}",
        perturbations=perts,
        method=None,
    )

    # compute the per ligand for the network
    network = _wrangle.FEMap(f"{other_computed_dir}/hahn/cinnabar_{prot}.csv")

    # for self plotting of per ligand
    hahn_ligs_dict[prot] = make_dict.from_cinnabar_network_node(network, "calc")
    hahn_perts_dict[prot] = make_dict.from_cinnabar_network_edges(
        network, "calc", perts
    )

    # ---- OpenFE ----
    if prot == "syk":
        # no OpenFE data is read for syk: fill in NaN placeholders
        openfe_ligs_dict[prot] = {lig: (np.nan, np.nan) for lig in ana_obj.ligands}
        openfe_perts_dict[prot] = {
            lig: (np.nan, np.nan) for lig in ana_obj.perturbations
        }
    else:
        df_main = pd.read_csv(
            f"{other_computed_dir}/openfe/combined_pymbar3_edge_data.csv"
        )

        for rep in [0, 1, 2]:
            # work on a copy: columns are added below, and assigning into a
            # filtered slice of `df_main` triggers SettingWithCopyWarning
            df = df_main[df_main["system name"] == prot].copy()
            complex_dg = df[f"complex_repeat_{rep}_DG (kcal/mol)"]
            solvent_dg = df[f"solvent_repeat_{rep}_DG (kcal/mol)"]
            df["freenrg"] = complex_dg - solvent_dg
            # NOTE(review): "error" is the quadrature sum of the two DG *values*,
            # not of their uncertainties - this looks unintended; confirm which
            # columns of the csv hold the uncertainties and use those instead.
            df["dG_err_temp"] = complex_dg**2 + solvent_dg**2
            df["error"] = np.sqrt(df["dG_err_temp"])
            df["lig_0"] = _openfe_lig_names(df["ligand_A"])
            df["lig_1"] = _openfe_lig_names(df["ligand_B"])
            df.to_csv(
                f"{other_computed_dir}/openfe/{prot}_{rep}.csv",
                columns=["lig_0", "lig_1", "freenrg", "error"],
                index=False,
            )

        files = [f"{other_computed_dir}/openfe/{prot}_{rep}.csv" for rep in [0, 1, 2]]

        calc_diff_dict = make_dict.comp_results(files)  # older method

        perts, ligs = get_info_network_from_dict(calc_diff_dict)
        exper_dict = ana_obj.exper_val_dict

        convert.cinnabar_file(
            files,
            exper_dict,
            f"{other_computed_dir}/openfe/cinnabar_{prot}",
            perturbations=perts,
            method=None,
        )

        # compute the per ligand for the network
        print(f"{other_computed_dir}/openfe/cinnabar_{prot}.csv")
        network = _wrangle.FEMap(f"{other_computed_dir}/openfe/cinnabar_{prot}.csv")

        # for self plotting of per ligand
        openfe_ligs_dict[prot] = make_dict.from_cinnabar_network_node(network, "calc")
        openfe_perts_dict[prot] = make_dict.from_cinnabar_network_edges(
            network, "calc", perts
        )

In [None]:
# calculate perturbation statistics
#
# For every protein system this cell collects one statistic (chosen by
# `stats_name`) for each simulation engine and for three literature data sets
# (openfe, hahn, fepplus), then renders them as a formatted table.

stats_name = "KTAU"
print(stats_name)

# supported statistic -> (analysis-object method that computes it, cinnabar's name for it)
_per_protein_stats = {
    "MAE": ("calc_mae_engines", "MUE"),
    "RMSE": ("calc_rmse_engines", "RMSE"),
    "KTAU": ("calc_kendalls_rank_engines", "KTAU"),
}
if stats_name not in _per_protein_stats:
    # Fail loudly: carrying on would call an undefined (or stale) `func`.
    raise ValueError(
        f"unsupported stats_name {stats_name!r}; choose one of {sorted(_per_protein_stats)}"
    )
_stat_method_name, cinn_stats_name = _per_protein_stats[stats_name]

val_dict = {}

for prot in ana_obj_dict:
    print(prot)

    ana_obj = ana_obj_dict[prot]["plain"]
    func = getattr(ana_obj, _stat_method_name)

    val_dict[prot] = {}

    # `res` is indexed [0..2][engine]["experimental"]; element 0 is the statistic and
    # element 2 a (lower, upper) pair (see the table formatting at the end of the cell).
    # NOTE(review): element 1 is presumably the error of the statistic - confirm.
    res = func(pert_val="val", recalculate=False)  # TODO val/pert
    for eng in ana_obj.engines:
        val_dict[prot][eng_dict_name[eng]] = (
            res[0][eng]["experimental"],
            res[1][eng]["experimental"],
            res[2][eng]["experimental"],
        )

    # literature

    exper_dict = ana_obj.exper_val_dict

    # Experimental values shifted to zero mean. This does not depend on the
    # literature source, so it is built once per protein.
    avg = np.mean([val[0] for val in exper_dict.values()])
    exper_pert_dict = {
        lig: (exper_dict[lig][0] - avg, exper_dict[lig][1]) for lig in exper_dict
    }

    # Per-ligand literature values are compared here; a per-perturbation
    # comparison would use the *_perts_dict inputs instead.
    for lit_perts_dict, name in zip(
        [openfe_ligs_dict[prot], hahn_ligs_dict[prot], fepplus_ligs_dict[prot]],
        ["openfe", "hahn", "fepplus"],
    ):
        print(name)

        if prot == "syk" and name == "openfe":
            # placeholder entry so the table/plot keep the same shape for this pair
            val_dict[prot][eng_dict_name[name]] = (
                0,
                0,
                (0, 0),
            )
            continue

        x = []
        y = []
        xerr = []
        yerr = []

        for lig in ana_obj.ligands:
            # Resolve every value BEFORE appending, so a ligand missing from either
            # source can never leave x/xerr longer than y/yerr.
            try:
                lit_val = lit_perts_dict[lig][0]
                lit_err = lit_perts_dict[lig][1]
                exp_val, exp_err = exper_pert_dict[lig]
            except (KeyError, IndexError, TypeError):
                # ligand not available for this source: report it and skip
                print(lig)
                continue
            x.append(lit_val)
            xerr.append(lit_err)
            y.append(exp_val)
            yerr.append(exp_err)

        # calculate statistics

        res = stats_engines.compute_stats(
            x=x, xerr=xerr, y=y, yerr=yerr, statistic=cinn_stats_name
        )

        val_dict[prot][eng_dict_name[name]] = (
            res[0],
            res[1],
            res[2],
        )

df = pd.DataFrame(val_dict)
# df.to_markdown()
# DataFrame.map is the non-deprecated spelling of applymap (already used elsewhere in this notebook)
df = df.map(lambda x: f"{x[0]:.2f} ({x[2][0]:.2f},{x[2][1]:.2f})")
df

In [None]:
# Grouped bar plot of the per-protein statistic held in `val_dict`
# (rows = protein systems, one bar per engine / literature source).
val_frame = pd.DataFrame(val_dict).T
val_res = val_frame.map(lambda x: x[0])
fig, ax = plt.subplots(figsize=(8, 5), dpi=500)

# Element 2 of every entry is a (lower, upper) pair; convert it to distances
# from the plotted value, which is what `yerr` expects.
ci_bounds = val_frame.map(lambda x: x[2])
df_lower = val_res - ci_bounds.map(lambda x: x[0])
df_upper = ci_bounds.map(lambda x: x[1]) - val_res
# stacked to (n_columns, 2, n_rows): one [lower, upper] block per plotted column
df_err = np.stack([df_lower.T.values, df_upper.T.values], axis=1)

# Kendall's tau is dimensionless; every other statistic is reported in kcal/mol.
stat_axis_label = {"KTAU": "Kendall's Tau"}.get(stats_name, f"{stats_name} (kcal/mol)")

val_res.plot.bar(
    color=col_dict,
    yerr=df_err,
    xlabel="All Protein Systems",
    ylabel=stat_axis_label,
    ax=ax,
    legend=None,
)
# ax.legend(loc='center right', bbox_to_anchor=( # lower center (0.5,1)
#     1.35, 0.5), #fancybox=True, shadow=True
#     )
# Label each tick from the plotted index so labels cannot drift out of step
# with the bars if `prot_dict_name` holds extra or differently ordered systems.
ax.set_xticklabels(
    [prot_dict_name.get(prot, prot) for prot in val_res.index], rotation=0
)

In [None]:
# calculate perturbation statistics for all proteins together
#
# Perturbation values are pooled over every protein system, first for each
# simulation engine and then for each literature data set, and one statistic
# is computed per pool. To compare per-ligand values instead, swap in the
# *_ligs_dict inputs / ana_obj.cinnabar_calc_val_dict and mean-centre
# exper_dict, as done in the per-protein cell.
stats_name = "RMSE"
print(stats_name)

# statistic name used in this notebook -> name cinnabar uses for it
_cinnabar_stat_names = {"MAE": "MUE", "RMSE": "RMSE", "KTAU": "KTAU", "R2": "R2"}
if stats_name not in _cinnabar_stat_names:
    # Fail loudly rather than silently reusing a stale `cinn_stats_name`.
    raise ValueError(
        f"unsupported stats_name {stats_name!r}; choose one of {sorted(_cinnabar_stat_names)}"
    )
cinn_stats_name = _cinnabar_stat_names[stats_name]

val_dict = {"all": {}}

# `ana_obj` is the analysis object left over from the previous cell; only its
# engine list is needed here (the name is rebound inside the loop).
for eng in ana_obj.engines:
    x = []
    y = []
    xerr = []
    yerr = []

    for prot in ana_obj_dict:
        ana_obj = ana_obj_dict[prot]["plain"]

        exper_dict = ana_obj.exper_val_dict

        # keep perturbations that have a non-NaN computed value and are not intermediates
        perturbations = []
        for pert in ana_obj._perturbations_dict[eng]:
            try:
                if (
                    str(ana_obj.calc_pert_dict[eng][pert][0]) != "nan"
                    and "Intermediate" not in pert
                ):
                    perturbations.append(pert)
            except (KeyError, IndexError, TypeError):
                # no computed value for this perturbation: leave it out
                continue

        exper_pert_dict = make_dict.exper_from_perturbations(exper_dict, perturbations)

        for pert in perturbations:
            # Resolve all four numbers before appending so that a failed lookup
            # can never leave x/xerr longer than y/yerr.
            try:
                calc_entry = ana_obj.calc_pert_dict[eng][pert]
                exper_entry = exper_pert_dict[pert]
                calc_val, calc_err = calc_entry[0], calc_entry[1]
                exp_val, exp_err = exper_entry[0], exper_entry[1]
            except (KeyError, IndexError, TypeError):
                print(pert)
                continue
            x.append(calc_val)
            xerr.append(calc_err)
            y.append(exp_val)
            yerr.append(exp_err)

    res = stats_engines.compute_stats(
        x=x, xerr=xerr, y=y, yerr=yerr, statistic=cinn_stats_name
    )

    val_dict["all"][eng_dict_name[eng]] = (
        res[0],
        res[1],
        res[2],
    )

# literature data sets, pooled the same way
for lit_perts_dict, name in zip(
    [openfe_perts_dict, hahn_perts_dict, fepplus_perts_dict],
    ["openfe", "hahn", "fepplus"],
):
    print(name)

    x = []
    y = []
    xerr = []
    yerr = []

    for prot in ana_obj_dict:
        print(prot)

        if prot == "syk" and name == "openfe":
            continue

        ana_obj = ana_obj_dict[prot]["plain"]
        exper_dict = ana_obj.exper_val_dict

        # only perturbations whose two ligands both have an experimental value
        perturbations = [
            pert
            for pert in lit_perts_dict[prot]
            if pert.split("~")[0] in exper_dict and pert.split("~")[1] in exper_dict
        ]

        exper_pert_dict = make_dict.exper_from_perturbations(exper_dict, perturbations)

        # every entry of `perturbations` is a key of lit_perts_dict[prot] by construction
        for pert in perturbations:
            x.append(lit_perts_dict[prot][pert][0])
            xerr.append(lit_perts_dict[prot][pert][1])
            y.append(exper_pert_dict[pert][0])
            yerr.append(exper_pert_dict[pert][1])

    # calculate statistics

    res = stats_engines.compute_stats(
        x=x, xerr=xerr, y=y, yerr=yerr, statistic=cinn_stats_name
    )

    val_dict["all"][eng_dict_name[name]] = (
        res[0],
        res[1],
        res[2],
    )

df = pd.DataFrame(val_dict)
# df.to_markdown()
# DataFrame.map is the non-deprecated spelling of applymap (already used elsewhere in this notebook)
df = df.map(lambda x: f"{x[0]:.2f} ({x[2][0]:.2f},{x[2][1]:.2f})")
df

In [None]:
# for all proteins - radial plot
#
# One polar figure per engine. Every protein system gets its own angular
# sector; a point's radius is the experimental value of a perturbation and its
# angular offset from the sector centre is the signed error (predicted minus
# experimental) scaled by `max_angle_spread / 2`.

for eng in ana_obj.engines:
    categories = [prot_dict_name[key] for key in prot_dict_name]
    data = []

    for prot in prot_dict_name:
        ana_obj = ana_obj_dict[prot]["plain"]
        calc_perts = ana_obj.cinnabar_calc_pert_dict[eng]
        kept_perts = [pert for pert in calc_perts if "Intermediate" not in pert]
        x = np.array([ana_obj.exper_pert_dict[pert][0] for pert in kept_perts])
        y = np.array([calc_perts[pert][0] for pert in kept_perts])
        signed_error = y - x  # signed on purpose: the sign picks the side of the sector
        for xi, yi, err in zip(x, y, signed_error):
            data.append([prot_dict_name[prot], xi, yi, err])

    df = pd.DataFrame(
        data,
        columns=[
            "Category",
            "Experimental dG (kcal/mol)",
            "Predicted dG (kcal/mol)",
            "Error",
        ],
    )

    # Set up polar plot
    fig, ax = plt.subplots(figsize=(10, 10), subplot_kw={"projection": "polar"})

    # Convert categories to base angles (evenly spaced)
    base_angles = np.linspace(0, 2 * np.pi, len(categories), endpoint=False)
    category_to_angle = dict(zip(categories, base_angles))

    # Angular width (10 degrees) that an error of 1 kcal/mol spans across a sector centre
    max_angle_spread = np.radians(10)

    # Shaded wedges marking +/- each error threshold around every sector centre
    mae_thresholds = [0.5, 1, 2, 3]
    norm = mcolors.Normalize(vmin=0, vmax=len(mae_thresholds))
    shade_cmap = colormaps["Greys"]

    for i, threshold in enumerate(mae_thresholds):
        for category, angle in category_to_angle.items():
            ax.fill_between(
                # Need high res or you'll fill a triangle
                np.linspace(
                    angle - threshold * max_angle_spread / 2,
                    angle + threshold * max_angle_spread / 2,
                ),
                -4,
                4,
                alpha=0.1,
                color=shade_cmap(norm(i)),
            )

    # One "plasma" colour per category. The colormap registry replaces
    # cm.get_cmap, which no longer exists in recent matplotlib releases.
    cmap = colormaps["plasma"].resampled(len(categories))
    colour_steps = max(len(categories) - 1, 1)  # avoid 0/0 with a single category
    category_color_map = {
        cat: cmap(i / colour_steps) for i, cat in enumerate(categories)
    }

    # Plot each category
    for category, base_angle in category_to_angle.items():
        subset = df[df["Category"] == category]

        r = subset["Experimental dG (kcal/mol)"]  # Radial distance
        angles = (
            base_angle + subset["Error"] * max_angle_spread / 2
        )  # Spread points in segment

        ax.scatter(
            angles,
            r,
            # A single-row 2D colour: a bare RGBA tuple would be treated as four
            # values to colour-map whenever a category holds exactly four points.
            c=[category_color_map[category]],
            edgecolors=None,
            alpha=0.75,
        )

    # Customize the plot
    ax.set_xticks(base_angles)
    ax.set_xticklabels(categories, fontsize=12)
    ax.set_ylim(
        df["Experimental dG (kcal/mol)"].min() - 0.5,
        df["Experimental dG (kcal/mol)"].max() + 0.5,
    )  # Adjust radial range
    ax.set_title(
        f"Radial Error Plot for {eng_dict_name[eng]} (Angle = Error, Radius = Experimental ΔΔG)"
    )
    # ax.set_xlabel("Experimental dG (kcal/mol)")
    # hand-placed label for the radial axis
    ax.text(
        0.68,
        0.56,
        "Experimental ΔΔG (kcal/mol)",
        transform=ax.transAxes,
        fontsize=10,
        color="black",
        ha="left",
        va="top",
        rotation=21,
        rotation_mode="anchor",
    )

    legend_patches = [
        plt.Line2D(
            [0],
            [0],
            marker="o",
            color="w",
            markerfacecolor=category_color_map[cat],
            markersize=10,
            label=cat,
        )
        for cat in categories
    ]
    ax.legend(
        handles=legend_patches,
        loc="lower left",
        bbox_to_anchor=(-0.3, -0.1),  # fancybox=True, shadow=True
    )

In [None]:
# for all proteins engines together - radial plot
#
# Same layout as the per-engine radial plots, but every engine is drawn onto a
# single polar axis and coloured by engine (via `col_dict`) instead of by protein.

# Set up polar plot
fig, ax = plt.subplots(figsize=(10, 10), subplot_kw={"projection": "polar"})

categories = [prot_dict_name[key] for key in prot_dict_name]

# Convert categories to base angles (evenly spaced)
base_angles = np.linspace(0, 2 * np.pi, len(categories), endpoint=False)
category_to_angle = dict(zip(categories, base_angles))

# Angular width (10 degrees) that an error of 1 kcal/mol spans across a sector centre
max_angle_spread = np.radians(10)

# Shaded wedges marking +/- each error threshold around every sector centre
mae_thresholds = [0.5, 1, 2, 3]
norm = mcolors.Normalize(vmin=0, vmax=len(mae_thresholds))
shade_cmap = colormaps["Greys"]

for i, threshold in enumerate(mae_thresholds):
    for category, angle in category_to_angle.items():
        ax.fill_between(
            # Need high res or you'll fill a triangle
            np.linspace(
                angle - threshold * max_angle_spread / 2,
                angle + threshold * max_angle_spread / 2,
            ),
            -4,
            4,
            alpha=0.1,
            color=shade_cmap(norm(i)),
        )

# one frame per engine, kept so the radial range can cover every plotted point
engine_frames = []

for eng in ana_obj.engines:
    data = []

    for prot in prot_dict_name:
        ana_obj = ana_obj_dict[prot]["plain"]
        calc_perts = ana_obj.cinnabar_calc_pert_dict[eng]
        kept_perts = [pert for pert in calc_perts if "Intermediate" not in pert]
        x = np.array([ana_obj.exper_pert_dict[pert][0] for pert in kept_perts])
        y = np.array([calc_perts[pert][0] for pert in kept_perts])
        signed_error = y - x  # signed on purpose: the sign picks the side of the sector
        for xi, yi, err in zip(x, y, signed_error):
            data.append([prot_dict_name[prot], xi, yi, err])

    df = pd.DataFrame(
        data,
        columns=[
            "Category",
            "Experimental dG (kcal/mol)",
            "Predicted dG (kcal/mol)",
            "Error",
        ],
    )
    engine_frames.append(df)

    # Plot each category
    for category, base_angle in category_to_angle.items():
        subset = df[df["Category"] == category]

        r = subset["Experimental dG (kcal/mol)"]  # Radial distance
        angles = (
            base_angle + subset["Error"] * max_angle_spread / 2
        )  # Spread points in segment

        ax.scatter(angles, r, c=col_dict[eng_dict_name[eng]], edgecolors=None)

# Customize the plot
ax.set_xticks(base_angles)
ax.set_xticklabels(categories, fontsize=12)
# Radial range from ALL engines, not only the last one processed, so no
# engine's points can fall outside the axis.
all_experimental = pd.concat(engine_frames)["Experimental dG (kcal/mol)"]
ax.set_ylim(
    all_experimental.min() - 0.5,
    all_experimental.max() + 0.5,
)  # Adjust radial range
ax.set_title(
    "Radial Error Plot for all engines (Angle = Error, Radius = Experimental ΔΔG)"
)
# ax.set_xlabel("Experimental dG (kcal/mol)")
# hand-placed label for the radial axis
ax.text(
    0.68,
    0.56,
    "Experimental ΔΔG (kcal/mol)",
    transform=ax.transAxes,
    fontsize=10,
    color="black",
    ha="left",
    va="top",
    rotation=21,
    rotation_mode="anchor",
)

# legend_patches = [plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=col_dict,
#                             markersize=10, label=eng_dict_name[eng]) for eng in ana_obj.engines]
# ax.legend(handles=legend_patches, loc='lower left', bbox_to_anchor=(
#         -0.3,-0.1), #fancybox=True, shadow=True
#         )

In [None]:
# get the cinnabar stats into a dict
#
# For every analysis method, engine and protein, collect the absolute
# difference between the computed and experimental value of each perturbation,
# then show the distributions as a box plot grouped by engine and method.
net_ana_method_dict = {"method": [], "engine": [], "protein": [], "value": []}

for method in list(ana_dicts.keys()):  # + ["single_0", "single_1", "single_2"]:
    print(method)
    for eng in ana_obj.engines:
        overall_dg_list = []

        for prot in ana_obj_dict.keys():
            print(prot, eng)
            dg_list = []

            ana_obj = ana_obj_dict[prot][method]

            # Work on a LOCAL reference. Assigning the per-engine sub-dict back
            # onto `ana_obj.calc_pert_dict` would destroy the engine-keyed dict
            # on the shared analysis object, breaking the next engine here and
            # every later cell that reads ana_obj.calc_pert_dict[eng].
            if "single" in method:
                print(f"method is {method}!")
                # "single_<n>" selects repeat <n> of the engine's results
                engine_pert_dict = ana_obj.calc_repeat_pert_dict[eng][
                    int(method.split("_")[-1])
                ]
            else:
                engine_pert_dict = ana_obj.calc_pert_dict[eng]

            for key in engine_pert_dict.keys():
                if key not in ana_obj._perturbations_dict[eng]:
                    print(f"{key} not in pert dict")
                    continue
                try:
                    value = abs(
                        engine_pert_dict[key][0] - ana_obj.exper_pert_dict[key][0]
                    )
                except (KeyError, IndexError, TypeError):
                    # missing or unusable entry on either side: report and skip
                    print(f"{key} not in dict for {eng} {method}")
                    continue
                dg_list.append(value)

            # one label per collected value so the columns stay aligned once flattened
            net_ana_method_dict["method"].append([method] * len(dg_list))
            net_ana_method_dict["engine"].append([eng] * len(dg_list))
            net_ana_method_dict["protein"].append([prot] * len(dg_list))
            net_ana_method_dict["value"].append(list(dg_list))
            overall_dg_list.append(dg_list)

        # falsy entries (exact zeros) are left out of this printed mean
        print(
            f"{eng} {method} mean is {np.mean([dg for dg in flatten_comprehension(overall_dg_list) if dg])}"
        )


plotting_dict = {
    "method": flatten_comprehension(net_ana_method_dict["method"]),
    "MD engine": flatten_comprehension(net_ana_method_dict["engine"]),
    "MAE ddG (kcal/mol)": flatten_comprehension(net_ana_method_dict["value"]),
    "Protein": flatten_comprehension(net_ana_method_dict["protein"]),
}

df = pd.DataFrame(plotting_dict)
ax = sns.boxplot(
    df, x="MD engine", y="MAE ddG (kcal/mol)", hue="method", palette="plasma"
)
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))

In [None]:
# Histogram of the absolute ddG errors collected above, one colour per MD engine.
ax = sns.displot(
    data=df,
    x="MAE ddG (kcal/mol)",
    hue="MD engine",
    palette="plasma",
)
# ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))

In [None]:
# Mean absolute ddG error per engine and analysis method, with 95% confidence
# intervals, followed by a table of "mean (lower, upper)" read back off the plot.
ax = sns.barplot(
    data=df,
    x="MD engine",
    y="MAE ddG (kcal/mol)",
    hue="method",
    errorbar="ci",  # 95%
)
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))

method_names = list(ana_dicts.keys()) + ["single_0", "single_1", "single_2"]
val_dict = {method: {} for method in method_names}

# Bars and error-bar lines are paired, in drawing order, with
# (method, engine) combinations.
# NOTE(review): this relies on seaborn drawing bars hue-major in exactly this
# order - confirm if the seaborn version changes.
bar_labels = it.product(method_names, ana_obj.engines)
for bar, ci_line, (method_name, engine_name) in zip(ax.patches, ax.lines, bar_labels):
    ci_points = ci_line.get_xydata()
    ci_low = ci_points[0][1]
    ci_high = ci_points[1][1]
    val_dict[method_name][
        engine_name
    ] = f"{bar.get_height():.2f} ({ci_low:.2f}, {ci_high:.2f})"

df_val = pd.DataFrame(val_dict)
df_val.T

In [None]:
# get the perturbations that are 'well' converged
#
# Perturbations are split by the mean of their convergence values (> 0.75
# counts as converged) and the absolute errors against experiment of the two
# groups are compared as histograms.

splitby = "all"
ana_obj_dict = network_dict["combined"]

# which entries of a perturbation's convergence dict take part in the mean
_leg_filters = {
    "all": lambda key: True,
    "free": lambda key: "free" in key,
    "bound": lambda key: "bound" in key,
}
if splitby not in _leg_filters:
    # An unknown value would otherwise silently reuse stale data or mark every
    # perturbation as failed.
    raise ValueError(
        f"unsupported splitby {splitby!r}; choose one of {sorted(_leg_filters)}"
    )
_keep_leg = _leg_filters[splitby]


def _abs_errors_to_experiment(analysis, engine, perts):
    """Absolute computed-minus-experimental value for each non-intermediate perturbation."""
    return [
        abs(
            analysis.calc_pert_dict[engine][pert][0]
            - analysis.exper_pert_dict[pert][0]
        )
        for pert in perts
        if "Intermediate" not in pert
    ]


con_vals = []
noncon_vals = []
all_vals = []
for prot in ana_obj_dict:
    ana_obj = ana_obj_dict[prot]["plain"]
    ana_obj.check_convergence(compute_missing=False)

    for eng in ["AMBER"]:
        print(prot, eng)
        con_perts = []
        noncon_perts = []
        for pert in ana_obj.convergence_dict[eng]:
            try:
                leg_values = ana_obj.convergence_dict[eng][pert]
                con_arr = np.array(
                    [leg_values[key] for key in leg_values if _keep_leg(key)]
                )
                # drop missing (None) and zero entries before averaging
                con_arr = [c for c in con_arr if c]
                mean = np.mean(con_arr)
            except Exception:
                print(pert, "failed")
                continue

            # NOTE(review): a perturbation with no usable values has a NaN mean
            # and therefore lands in the non-converged group - confirm intended.
            if mean > 0.75:
                con_perts.append(pert)
            else:
                noncon_perts.append(pert)

        con_vals.append(_abs_errors_to_experiment(ana_obj, eng, con_perts))
        noncon_vals.append(_abs_errors_to_experiment(ana_obj, eng, noncon_perts))
        all_vals.append(
            _abs_errors_to_experiment(ana_obj, eng, ana_obj.convergence_dict[eng])
        )
        print(len(con_vals), len(noncon_vals), len(all_vals))

plt.figure(figsize=(6, 6), dpi=500)
plt.hist(
    flatten_comprehension(con_vals),
    density=True,
    color="magenta",
    label="Converged",
    alpha=0.5,
)
plt.hist(
    flatten_comprehension(noncon_vals),
    density=True,
    color="plum",
    label="Non-converged",
    alpha=0.5,
)
# plt.hist(flatten_comprehension(all_vals), density=True, color="darkblue", label="All", alpha=0.5)
# mean absolute error of the converged, non-converged and complete sets
print(
    np.mean(flatten_comprehension(con_vals)),
    np.mean(flatten_comprehension(noncon_vals)),
    np.mean(flatten_comprehension(all_vals)),
)
plt.legend(loc="upper right")
plt.xlabel("MAE (kcal/mol)")
plt.ylabel("Density")

***cycle closures***

In [None]:
# cycle closures
#
# Compute the cycle closures of every protein / analysis-method combination
# and plot, per protein system, the stored value with its interval.

cc_dict = {}

for prot in ana_obj_dict.keys():
    display_name = str(prot_dict_name[prot])
    for name in ana_dicts:
        print(prot, name)
        ana_obj = ana_obj_dict[prot][name]
        ana_obj.compute_cycle_closures()
        # NOTE(review): keyed by protein only, so each analysis method overwrites
        # the previous one and the plot shows the last entry of `ana_dicts` - confirm.
        cc_dict[display_name] = ana_obj.cycle_dict

# plot the cycle closures
# plot the errors
df = pd.DataFrame.from_dict(cc_dict).T

# every cell of `df` is a sequence: index 1 is the plotted value and index 3 a
# (lower, upper) pair used for the error bars
df_ci = df.map(lambda entry: entry[3]).rename(columns=eng_dict_name)
df_mean = df.map(lambda entry: entry[1]).fillna(0).rename(columns=eng_dict_name)
df_low = df_mean.sub(df_ci.map(lambda bounds: bounds[0]))
df_high = df_ci.map(lambda bounds: bounds[1]).sub(df_mean)
# stacked to (n_columns, 2, n_rows): one [lower, upper] block per plotted column
df_err = np.stack([df_low.T.to_numpy(), df_high.T.to_numpy()], axis=1)

fig, ax = plt.subplots(figsize=(8, 5), dpi=500)
df_mean.plot.bar(
    ax=ax,
    color=col_dict,
    yerr=df_err,
    xlabel="Protein System",
    ylabel="Cycle Closure Error (kcal/mol)",
)
ax.tick_params(axis="x", rotation=0)
ax.legend(loc="upper left")

***grow/shrink***

In [None]:
# directionality, data from denoting perturbations as grow or shrink
#
# For every engine, compares the computed uncertainty ("error") and the
# absolute deviation from experiment ("diff") between perturbations labelled
# "grow", "shrink" and "same" in grow_shrink_featurise.dat - first per protein
# system, then pooled over all systems (with histograms). The pooled grow and
# shrink groups are stored in `grow_shrink_dict` for the between-engine
# comparison cell.

ana_obj_dict = network_dict["combined"]


def _direction_groups(frame, column):
    """Return the (grow, shrink, same) subsets of `column` in `frame`."""
    return tuple(
        frame.loc[frame["grow/shrink"] == label][column]
        for label in ("grow", "shrink", "same")
    )


def _report_direction_stats(groups, what, eng, same_sep):
    """Print the grow-vs-shrink Mann-Whitney U result and the mean of each group."""
    grow, shrink, same = groups
    ustats, pvalue = _stats.mannwhitneyu(grow, shrink)
    print(f"mann u for {what} {eng}: {ustats, pvalue}")
    print(
        f"mean for {what} {eng} grow: {np.mean(grow)}, and for shrink: {np.mean(shrink)}{same_sep} and for same: {np.mean(same)}"
    )


def _plot_direction_hist(groups, eng, xlabel):
    """Overlay histograms (with KDE) of the grow, shrink and same groups."""
    grow, shrink, same = groups
    plt.figure(figsize=(10, 6))
    sns.histplot(grow, color="cadetblue", label="Grow", kde=True)
    sns.histplot(shrink, color="pink", label="Shrink", kde=True)
    sns.histplot(same, color="plum", label="Same", kde=True)
    plt.title(f"{eng_dict_name[eng]}")
    plt.xlabel(xlabel)
    plt.ylabel("Frequency")
    plt.legend()
    plt.show()


grow_shrink_dict = {eng: {} for eng in ana_obj.engines}
df_list = {eng: [] for eng in ana_obj.engines}

for prot in ana_obj_dict.keys():
    print(prot)
    ana_obj = ana_obj_dict[prot]["plain"]

    for eng in ana_obj.engines:
        # absolute deviation from experiment (intermediates excluded) ...
        diff_dict = {
            key: abs(
                ana_obj.calc_pert_dict[eng][key][0] - ana_obj.exper_pert_dict[key][0]
            )
            for key in ana_obj._perturbations_dict[eng]
            if "Intermediate" not in key
        }
        # ... and the magnitude of the computed uncertainty
        error_dict = {
            key: abs(ana_obj.calc_pert_dict[eng][key][1])
            for key in ana_obj._perturbations_dict[eng]
        }

        # the featurisation file lives next to the output folder (one level up)
        df = pd.read_csv(
            f"{join('/', *ana_obj.output_folder.split('/')[:-1])}/execution_model/grow_shrink_featurise.dat"
        )
        df[f"error_{eng}"] = df["pert"].map(error_dict)
        df[f"diff_{eng}"] = df["pert"].map(diff_dict)
        # perturbations without both values (e.g. intermediates) drop out here
        df = df.dropna()

        df_list[eng].append(df)

        error_groups = _direction_groups(df, f"error_{eng}")
        print(
            "grow ",
            len(error_groups[0]),
            "shrink ",
            len(error_groups[1]),
            "same ",
            len(error_groups[2]),
        )
        _report_direction_stats(error_groups, "error", eng, ",")

        diff_groups = _direction_groups(df, f"diff_{eng}")
        _report_direction_stats(diff_groups, "diff to exp", eng, "")

# across the systems
print("all")

for eng in ana_obj.engines:
    df = pd.concat(df_list[eng], ignore_index=True)

    error_groups = _direction_groups(df, f"error_{eng}")
    print(
        "grow ",
        len(error_groups[0]),
        "shrink ",
        len(error_groups[1]),
        "same ",
        len(error_groups[2]),
    )
    _report_direction_stats(error_groups, "error", eng, ",")
    _plot_direction_hist(error_groups, eng, "SEM (kcal/mol)")

    diff_groups = _direction_groups(df, f"diff_{eng}")
    _report_direction_stats(diff_groups, "diff to exp", eng, "")
    _plot_direction_hist(diff_groups, eng, "MAE (kcal/mol)")

    # Keep the pooled groups under the keys the between-engine significance
    # cell reads; previously this dict was created but never filled.
    grow_shrink_dict[eng]["grow_err"] = error_groups[0]
    grow_shrink_dict[eng]["shrink_err"] = error_groups[1]
    grow_shrink_dict[eng]["grow_diff"] = diff_groups[0]
    grow_shrink_dict[eng]["shrink_diff"] = diff_groups[1]

# a p-value below 0.05 indicates a significant difference (reject the null hypothesis)

In [None]:
# Is the difference between engines significant?
#
# For each grow/shrink group, run a Mann-Whitney U test between every ordered
# pair of distinct engines and keep the p-values in
# res_dict[group][engine_1][engine_2].

res_dict = {}

for size in ["grow_err", "shrink_err", "grow_diff", "shrink_diff"]:
    res_dict[size] = {eng: {} for eng in ana_obj.engines}

    for eng1, eng2 in it.product(grow_shrink_dict.keys(), repeat=2):
        if eng1 != eng2:
            group1 = grow_shrink_dict[eng1][size]
            group2 = grow_shrink_dict[eng2][size]

            ustats, pvalue = _stats.mannwhitneyu(group1, group2)
            print(f"{eng1, eng2}, {size}: {ustats, pvalue}")
            print(
                f"mean for {eng1}: {np.mean(group1)}, and for {eng2}: {np.mean(group2)}"
            )

            res_dict[size][eng1][eng2] = pvalue

***AUTOEQUILIBRATION***

In [None]:
# Collect every available value from `ac_dict` per protein and engine, print
# the mean and count for each, then the mean and count per engine over all proteins.
ac_dict_avg = {}

overall_array = {eng: np.array([]) for eng in ana_obj.engines}

for prot in ana_obj_dict.keys():
    print(prot)
    ana_obj = ana_obj_dict[prot]["plain"]
    ana_obj.check_Ac(compute_missing=False)

    ac_dict_avg[prot] = {}

    for eng in ana_obj.engines:
        engine_values = []
        for pert in ana_obj.ac_dict[eng].keys():
            try:
                pert_values = ana_obj.ac_dict[eng][pert].values()
            except (AttributeError, TypeError):
                # no per-key results stored for this perturbation: nothing to add
                continue
            # skip entries that were never computed
            engine_values.extend(val for val in pert_values if val is not None)

        ac_dict_avg[prot][eng] = engine_values
        print(prot, eng, np.mean(ac_dict_avg[prot][eng]), len(ac_dict_avg[prot][eng]))
        overall_array[eng] = np.concatenate(
            [overall_array[eng], ac_dict_avg[prot][eng]]
        )
for eng in ana_obj.engines:
    print(eng, np.mean(overall_array[eng]), len(overall_array[eng]))

In [None]:
# default autoeq time for each engine
#
# Gathers the mean equilibration time of every perturbation leg from the
# "autoeq" analysis objects, tabulates the per-protein / per-engine average
# (plus SEM and standard deviation), and computes a 95% normal-approximation
# confidence interval per engine and for all engines pooled.

# recalculate eq times if needed
for prot in ana_obj_dict.keys():
    try:
        print(prot)
        ana_obj = ana_obj_dict[prot]["autoeq"]
        ana_obj.compute_equilibration_times(compute_missing=False)
    except Exception as e:
        print(e)
        print(f"could not for {prot}")

# check equilibration times
eq_dict_avg = {}
std_dict_avg = {}
sem_dict_avg = {}
# all collected values per engine, pooled over the protein systems
overall_array = {
    eng: np.array([]) for eng in ana_obj_dict["tyk2"]["autoeq"].engines
}

for prot in ana_obj_dict.keys():
    print(prot)

    eq_dict_avg[prot] = {}
    std_dict_avg[prot] = {}
    sem_dict_avg[prot] = {}

    ana_obj = ana_obj_dict[prot]["autoeq"]

    for eng in ana_obj.engines:
        # every available "mean" entry of every perturbation, as one flat list
        eq_values = [
            leg["mean"]
            for pert_legs in ana_obj.eq_times_dict[eng].values()
            for leg in pert_legs.values()
            if leg["mean"] is not None
        ]
        print(prot, eng, eq_values)
        overall_array[eng] = np.concatenate([overall_array[eng], eq_values])

        sem_dict_avg[prot][eng] = _stats.sem(eq_values)
        std_dict_avg[prot][eng] = _stats.tstd(eq_values)
        eq_dict_avg[prot][eng] = np.mean(eq_values)

# NOTE(review): the factor 100 presumably converts a fraction into a percentage - confirm.
df = pd.DataFrame.from_dict(eq_dict_avg).transpose() * 100  # transpose for per engine
df_sem = (
    pd.DataFrame.from_dict(sem_dict_avg).transpose() * 100
)  # transpose for per engine
df_std = (
    pd.DataFrame.from_dict(std_dict_avg).transpose() * 100
)  # transpose for per engine
print(df)

dict_lower = {}
dict_higher = {}
for eng in ana_obj.engines:
    # check normally dist
    # if not check_normal_dist(overall_array[eng]):
    #     print("not normal distribution")

    mean = np.mean(overall_array[eng])
    lower_ci, upper_ci = _stats.norm.interval(
        confidence=0.95,
        loc=mean,
        scale=_stats.sem(overall_array[eng]),
    )
    print(eng, mean, lower_ci, upper_ci)
    dict_lower[eng] = [lower_ci * 100]
    dict_higher[eng] = [upper_ci * 100]

# Interval over all engines pooled. It is kept in its own variables: writing it
# to dict_lower[eng] / dict_higher[eng] here would overwrite the interval of
# whichever engine the loop above processed last.
pooled_values = flatten_comprehension(overall_array.values())
mean = np.mean(pooled_values)
lower_ci, upper_ci = _stats.norm.interval(
    confidence=0.95,
    loc=mean,
    scale=_stats.sem(pooled_values),
)
print("all", mean, lower_ci, upper_ci)
all_lower = [lower_ci * 100]
all_higher = [upper_ci * 100]

In [None]:
# plot the average for the engines and also per system
fig, axes = plt.subplots(
    nrows=1, ncols=3, figsize=(15, 5), sharex=False, sharey=True, dpi=500
)

# NOTE: df_plot is an alias of df, so the relabelling below also relabels df
# in place (kept because later cells read df).
df_plot = df
df_plot.rename(index=prot_dict_name, columns=eng_dict_name, inplace=True)
# pandas matches DataFrame error bars to the data by index/column label, so the
# SEM frame needs the same display labels as df_plot or its bars are dropped.
df_sem_plot = df_sem.rename(index=prot_dict_name, columns=eng_dict_name)

discarded_label = "Average amount of run\ndiscarded by auto-equilibration (%)"

# panel 1: mean over engines for each protein system
df_plot.T.mean().plot.bar(
    color="purple",
    yerr=df_plot.T.sem(),
    xlabel="Protein System",
    ylabel=discarded_label,
    ax=axes[0],
)
# tick/legend settings go through the panel's own Axes; plt.* would only
# touch whichever Axes is current
axes[0].tick_params(axis="x", rotation=0)

# panel 2: one bar per engine within each protein system
df_plot.plot.bar(
    color=col_dict.values(),
    yerr=df_sem_plot,
    xlabel="Protein System",
    ylabel=discarded_label,
    ax=axes[1],
)
axes[1].tick_params(axis="x", rotation=0)
axes[1].legend(loc="best")

# panel 3: mean over protein systems for each engine
df_plot.mean().plot.bar(
    color=col_dict.values(),
    yerr=df_plot.sem(),
    xlabel="MD engine",
    ylabel=discarded_label,
    ax=axes[2],
)
axes[2].tick_params(axis="x", rotation=0)

In [None]:
# Use the "protein" column as the index when the frame still carries one;
# otherwise df is used as it is.
if "protein" in df.columns:
    df = df.set_index("protein")

# plotting w the bars representing how much of the average is each engine:
# each row is rescaled so that its stacked segments add up to the row mean
engine_share_of_mean = df.div(df.sum(axis=1), axis=0).mul(df.mean(axis=1), axis=0)
engine_share_of_mean.plot(
    kind="bar",
    stacked=True,
    color=pipeline.analysis.set_colours(),
    xlabel="protein system",
    ylabel="Average amount of run\ndiscarded by auto-equilibration (%)",
)
# mean over engines with its standard error, drawn on top of the stacks
plt.errorbar(x=df.index, y=df.T.mean(), yerr=df.T.sem(), ecolor="black", linestyle="")

***Network Analysis***

In [None]:
# plotting statistics
# Fills plotting_dict[network][protein][engine] with a 3-tuple taken from the
# "experimental" entries of the chosen statistic's result. Later cells read
# element 0 as the value, element 1 as an error and element 2 as a
# (lower, upper) pair.

plotting_dict = {}

stats_name = "ΔG MAE (kcal/mol)"
print(stats_name)

# statistic label -> (analysis-object method name, pert_val argument)
stats_dispatch = {
    "ΔG MAE (kcal/mol)": ("calc_mae_engines", "val"),
    "ΔΔG MAE (kcal/mol)": ("calc_mae_engines", "pert"),
    "$r^2$": ("calc_r2_engines", "val"),
    "Kendall's Tau": ("calc_kendalls_rank_engines", "val"),
}
if stats_name not in stats_dispatch:
    # fail loudly instead of silently reusing a function from an earlier run
    raise ValueError(f"no statistic function defined for {stats_name!r}")
method_name, pert_val = stats_dispatch[stats_name]

for network in [
    "lomap",
    "rbfenn",
    "flare",
    "combined",
    "lomap-a-optimal",
    "lomap-d-optimal",
    "rbfenn-a-optimal",
    "rbfenn-d-optimal",
]:
    plotting_dict[network] = {}

    for prot in ["tyk2", "mcl1"]:  # ana_obj_dict.keys():
        print(network, prot)
        plotting_dict[network][prot] = {}

        try:
            ana_obj = network_dict[network][prot]["plain"]
            func = getattr(ana_obj, method_name)
            stat_results = func(pert_val=pert_val, recalculate=False)

            for eng in ana_obj.engines:
                plotting_dict[network][prot][eng] = (
                    stat_results[0][eng]["experimental"],
                    stat_results[1][eng]["experimental"],
                    stat_results[2][eng]["experimental"],
                )

        except Exception as e:
            # best effort: keep the plots working with zero placeholders, but
            # say which network/protein is missing and why
            print(f"could not compute {stats_name} for {network} {prot}: {e}")
            for eng in network_dict["lomap"]["tyk2"]["plain"].engines:
                plotting_dict[network][prot][eng] = (0, 0, [0, 0])

In [None]:
name = "lomap"

# raw frame: columns are proteins, rows are engines, cells are the stat tuples
network_stats = pd.DataFrame(plotting_dict[name])

# element 0 of each tuple is plotted as the bar height, element 1 as its error
df = (
    network_stats.applymap(lambda entry: entry[0])
    .rename(eng_dict_name)
    .T.rename(prot_dict_name)
)
df_err = (
    network_stats.applymap(lambda entry: entry[1])
    .rename(eng_dict_name)
    .T.rename(prot_dict_name)
)

fig, ax = plt.subplots(figsize=(5, 5), dpi=500)

# one group of bars per protein system, one bar per engine
ax = df.plot(kind="bar", color=col_dict, yerr=df_err, ax=ax)
ax.legend(loc="upper left")
ax.set_xlabel("Protein System")
ax.set_ylabel(f"{stats_name}")
ax.tick_params(axis="x", rotation=0)
ax.tick_params(axis="y")

In [None]:
# graphs based on engine: one panel per engine comparing the LOMAP and RBFENN
# networks for each protein system
ana_obj = network_dict["lomap"]["tyk2"]["plain"]
fig, axes = plt.subplots(
    nrows=1, ncols=3, figsize=(15, 5), sharex=True, sharey=True, dpi=500
)

# networks shown in every panel -> legend label
network_labels = {"lomap": "LOMAP-score", "rbfenn": "RBFENN-score"}

# engine colours: [colour for the first network, colour for the second network]
engine_shades = {
    "AMBER": ["orange", "moccasin"],
    "SOMD": ["darkturquoise", "paleturquoise"],
    "GROMACS": ["orchid", "plum"],
}


def _engine_panel_frame(engine, extract):
    """Return a proteins x networks frame for one engine.

    `extract` is applied to each stored stat tuple of plotting_dict to pick
    the number that ends up in the frame.
    """
    per_network = {
        label: pd.DataFrame(plotting_dict[network]).loc[engine].map(extract)
        for network, label in network_labels.items()
    }
    return pd.DataFrame(per_network).rename(prot_dict_name)


for engine, pos in zip(ana_obj.engines, axes):
    # bar heights: element 0 of the stat tuple
    df = _engine_panel_frame(engine, lambda stat: stat[0])
    # error bars: half the width of the (lower, upper) pair stored as element 2
    df_err = _engine_panel_frame(engine, lambda stat: (stat[2][1] - stat[2][0]) / 2)

    df.plot(
        kind="bar",
        color=engine_shades[engine],
        yerr=df_err,
        title=eng_dict_name[engine],
        ax=pos,
        xlabel="Protein System",
        ylabel=f"{stats_name}",
    )
    pos.tick_params(axis="x", rotation=0)
    pos.legend(loc="upper right")
# fig.suptitle(f'{stats_name} perturbations for LOMAP/RBFENN-score')

In [None]:
# number of perts in each network normalised by the no of ligands

# every network generated here uses the same three ligand-set sizes; only the
# "kuhn" entry differs
generated_networks = (
    "lomap",
    "rbfenn",
    "flare",
    "combined",
    "lomap-a-optimal",
    "lomap-d-optimal",
    "rbfenn-a-optimal",
    "rbfenn-d-optimal",
)
no_ligands_dict = {
    "kuhn": [16, 42, 34],
    **{net_name: [17, 15, 34] for net_name in generated_networks},
}

# perturbation counts, listed in the same order as the ligand counts above
no_perts_dict = dict(
    [
        ("kuhn", [23, 70, 54]),
        ("lomap", [24, 17, 49]),
        ("rbfenn", [30, 20, 51]),
        ("flare", [24, 19, 62]),
        ("combined", [62, 53, 132]),
        ("lomap-a-optimal", [24, 17, 49]),
        ("lomap-d-optimal", [24, 17, 49]),
        ("rbfenn-a-optimal", [30, 20, 51]),
        ("rbfenn-d-optimal", [30, 20, 51]),
    ]
)

# perturbations per ligand, element-wise for every network
normalised_dict = {
    net_name: [
        n_perts / n_ligands
        for n_perts, n_ligands in zip(no_perts_dict[net_name], ligand_counts)
    ]
    for net_name, ligand_counts in no_ligands_dict.items()
}

# [1.4,1.7,1.3,1.4,1.4,1.4,1.8,1.3,1.5,2.6,2.5,2.6]

In [None]:
# for a specific engine: compare the network-generation methods per protein
# can adjust which networks are getting plotted via `plotted_networks`

fig, ax = plt.subplots(figsize=(5, 3.25), dpi=500)

ana_obj = network_dict["lomap"]["tyk2"]["plain"]
engine = "GROMACS"

# plotting_dict key -> legend label
network_labels = {
    "lomap": "LOMAP-score",
    "rbfenn": "RBFENN-score",
    "flare": "Flare",
    "combined": "Combined",
    "lomap-a-optimal": "LOMAP-A-optimal",
    "lomap-d-optimal": "LOMAP-D-optimal",
    "rbfenn-a-optimal": "RBFENN-A-optimal",
    "rbfenn-d-optimal": "RBFENN-D-optimal",
}
# networks drawn, in legend order (any key of network_labels can be used)
plotted_networks = ["lomap", "rbfenn", "flare", "combined"]


def _single_engine_frame(element):
    """Return a proteins x networks frame for `engine`.

    `element` is the position read from each stored stat tuple of
    plotting_dict (0 is plotted as the value, 1 as its error bar).
    """
    per_network = {
        network_labels[network]: pd.DataFrame(plotting_dict[network])
        .loc[engine]
        .map(lambda stat: stat[element])
        for network in plotted_networks
    }
    return pd.DataFrame(per_network).rename(prot_dict_name)


df = _single_engine_frame(0)
# df.insert(0, "Kuhn et al.", [0.70, 0.94, 1.18], True) # dG
# df.insert(0, "Kuhn et al.", [0.54, 0.56, 0.54], True)  # ktau

# TODO consider which error to show (element 1 here; element 2 holds a lower/upper pair)
df_err = _single_engine_frame(1)
# df_err.insert(0, "Kuhn et al.", [0.14, 0.14, 0.15], True) # dG
# df_err.insert(0, "Kuhn et al.", [0.28, 0.28, 0.27], True)  # ktau

df.plot(
    kind="bar",
    colormap="plasma",
    yerr=df_err,
    xlabel="Protein System",
    # stats_name already carries its unit, so none is appended here
    ylabel=f"{stats_name}",
    ax=ax,
    width=0.9,
)
ax.yaxis.set_major_locator(MaxNLocator(integer=True))
# kuhn, lomap, rbfenn
# normalised lig pert ratio [1.4,1.7,1.3,1.4,1.4,1.4,1.8,1.3,1.5,2.6,2.5,2.6]
# number of perts [23, 70, 54, 24, 48, 48, 30, 45, 51, 44, 84, 88]
ax.legend(fontsize=8, loc="lower right")  # , bbox_to_anchor=(0.15, 1.4))
ax.set_xlabel("Protein System", fontsize=10)
ax.set_ylabel(f"{stats_name}", fontsize=10)
ax.tick_params(axis="both", which="major", labelsize=10, rotation=0)
ax.set_title(f"{eng_dict_name[engine]}")
# ax.set_ylim(top=1)

In [None]:
# other plot format - the network variants next to each other on the same graph
# ie graphs based on protein system (one panel per protein)

ana_obj = network_dict["lomap"]["tyk2"]["plain"]

proteins = ["tyk2", "mcl1"]

# one panel per plotted protein, so no empty panel is left over
fig, axes = plt.subplots(
    nrows=1,
    ncols=len(proteins),
    figsize=(6 * len(proteins), 5),
    sharex=True,
    sharey=True,
    squeeze=False,
)
axes = axes.ravel()

# plotting_dict key -> x tick label; swap in the rbfenn / flare / combined keys
# to compare other networks
plotted_networks = {
    "lomap": "LOMAP",
    "lomap-a-optimal": "LOMAP-A-optimal",
    "lomap-d-optimal": "LOMAP-D-optimal",
}

# engine colours: one shade per network method (row) for each engine
# TODO some way diff colour for lomap or rbfenn
re_col_dict = {
    "AMBER": ["orange", "moccasin", "oldlace"],
    "SOMD": ["darkturquoise", "paleturquoise", "azure"],
    "GROMACS": ["orchid", "plum", "pink"],
}


def _protein_panel_frame(prot, element):
    """Return a networks x engines frame for one protein.

    `element` is the position read from each stored stat tuple of
    plotting_dict (0 is plotted as the value, 1 as its error bar).
    """
    per_network = {
        label: pd.DataFrame(plotting_dict[network])[prot].map(
            lambda stat: stat[element]
        )
        for network, label in plotted_networks.items()
    }
    return pd.DataFrame(per_network).T


for prot, pos in zip(proteins, axes):
    df = _protein_panel_frame(prot, 0)
    df_err = _protein_panel_frame(prot, 1)

    df.plot(
        kind="bar",
        color=re_col_dict,
        yerr=df_err,
        title=prot_dict_name[prot],
        ax=pos,
        xlabel="Network Method",
        ylabel=f"{stats_name}",
        legend=True,
    )
    pos.tick_params(axis="x", rotation=20)
    # key = pipeline.analysis.set_colours()
    # key.pop("experimental")
    # pos.legend(key)

In [None]:
# comparing the ddG MAE and SEM between the networks

# SEM differences collect first
# Every stored summary is a 4-tuple:
#   (mean, standard deviation, (lower, upper) 95 % CI of the mean, values used)
sem_dict = {}  # {network: {protein: {engine: summary}}}
sem_dict_name = {}  # {network: summary over all proteins and engines}
sem_dict_eng = {}  # {network: {engine: summary over all proteins}}


def _summarise_sems(values):
    """Drop NaN entries from `values` and return the summary 4-tuple."""
    values = [x for x in values if str(x) != "nan"]
    mean = np.mean(values)
    lower_ci, upper_ci = _stats.norm.interval(
        confidence=0.95, loc=mean, scale=_stats.sem(values)
    )
    return (mean, _stats.tstd(values), (lower_ci, upper_ci), values)


for network in network_dict:
    sem_dict[network] = {}
    sem_dict_eng[network] = {}
    network_sems = []

    for prot in ["tyk2", "mcl1"]:
        sem_dict[network][prot] = {}

        ana_obj = network_dict[network][prot]["plain"]

        for eng in ana_obj.engines:
            # element 1 of every calc_pert_dict entry is used as that perturbation's SEM
            sems = [val[1] for val in ana_obj.calc_pert_dict[eng].values()]
            network_sems.extend(sems)

            summary = _summarise_sems(sems)
            print(prot, network, eng, summary[0], summary[2][0], summary[2][1])
            sem_dict[network][prot][eng] = summary

    # for all the network
    summary = _summarise_sems(network_sems)
    print(network, summary[0], summary[2][0], summary[2][1])
    sem_dict_name[network] = summary

    # for per engine (engine list of the last analysis object visited above)
    for eng in ana_obj.engines:
        engine_sems = []
        for prot in ["tyk2", "mcl1"]:
            engine_sems.extend(sem_dict[network][prot][eng][3])

        # the spread is taken over this engine's values only, not over the
        # whole network's values
        summary = _summarise_sems(engine_sems)
        print(network, eng, summary[0], summary[2][0], summary[2][1])
        sem_dict_eng[network][eng] = summary

In [None]:
# same format as above but w the SEM

# rows become networks, columns engines; each cell of sem_dict_eng is a summary tuple
sem_summary = pd.DataFrame(sem_dict_eng)
df = sem_summary.applymap(lambda entry: entry[0]).T

# confidence intervals: element 2 of each tuple is a (lower, upper) pair and the
# error bar is half of its width
ci_bounds = sem_summary.applymap(lambda entry: entry[2]).T
df_lower = ci_bounds.applymap(lambda bounds: bounds[0])
df_upper = ci_bounds.applymap(lambda bounds: bounds[1])
df_err = (df_upper - df_lower) / 2

col_dict = dict(
    AMBER=["orange", "moccasin"],
    SOMD=["darkturquoise", "paleturquoise"],
    GROMACS=["orchid", "plum"],
)
df.plot(
    kind="bar",
    color=col_dict,
    yerr=df_err,
    xlabel="Network score method",
    ylabel="ddG SEM (kcal/mol)",
    legend=True,
)
plt.legend(title=False)
# key = pipeline.analysis.set_colours()
# key.pop("experimental")
# pos.legend(key)

In [None]:
# get the cinnabar stats into a dict
# net_ana_method_dict collects, per engine/protein pair, one list per column:
# the analysis method label, engine label, protein label and absolute errors
net_ana_method_dict = {col: [] for col in ("method", "engine", "protein", "value")}
network = "lomap"

for eng in ["AMBER", "SOMD", "GROMACS"]:
    for prot in ana_obj_dict.keys():
        print(prot)

        ana_obj = network_dict[network][prot]["plain"]

        # absolute difference between element 0 of the cinnabar calculated and
        # experimental entries, for every key of the calculated dict
        dg_list = []
        for key in ana_obj.cinnabar_calc_val_dict[eng].keys():
            calc_val = ana_obj.cinnabar_calc_val_dict[eng][key][0]
            exper_val = ana_obj.cinnabar_exper_val_dict[eng][key][0]
            value = abs(calc_val - exper_val)
            dg_list.append(value)
            if value > 5:
                # report large deviations
                print(prot, eng, key, value)

        # one label per collected value so the four columns stay aligned
        net_ana_method_dict["method"].append(["cinnabar" for _ in dg_list])
        net_ana_method_dict["engine"].append([eng_dict_name[eng] for _ in dg_list])
        net_ana_method_dict["protein"].append([prot_dict_name[prot] for _ in dg_list])
        net_ana_method_dict["value"].append(list(dg_list))

In [None]:
# also want to compare fwf and cinnabar
# NOTE(review): machine-specific absolute path - point this at the local
# freenrgworkflows checkout (ideally from a config cell) before running.
fwf_path = (
    "/home/anna/Documents/september_2022_workshops/freenrgworkflows/networkanalysis"
)

# NOTE(review): `network`, `ana_obj` and `net_ana_method_dict` are not set in
# this cell; they are expected to exist from the cinnabar cell.
for eng in ana_obj.engines:
    for prot in ana_obj_dict.keys():
        print(eng, prot)
        dg_list = []

        ana_obj = network_dict[network][prot]["plain"]

        # add path for fwf
        ana_obj._add_fwf_path(fwf_path)
        ana_obj._get_exp_fwf()

        try:
            fwf_dict = ana_obj._get_ana_fwf(engine=eng, use_repeat_files=True)
        except Exception as e:
            # no retry without repeat files is attempted; report why it failed
            print(f"{prot} {eng} did not fwf w repeat files: {e}")

        try:
            # _fwf_computed_DGs[eng] is iterated as single-entry dicts; merge
            # them into one {key: value} mapping
            di2 = {}
            for di in ana_obj._fwf_computed_DGs[eng]:
                first_key = next(iter(di))
                di2[first_key] = di[first_key]
            # experimental computed normally outside of fwf and normalised
            for key in di2.keys():
                value = abs(di2[key] - ana_obj.normalised_exper_val_dict[key][0])
                dg_list.append(value)
                if value > 5:
                    print(prot, eng, key, value)
        except Exception as e:
            print("did not fwf at all")
            print(e)

        # one label per collected value so the four columns stay aligned
        net_ana_method_dict["method"].append(["fen" for _ in dg_list])
        net_ana_method_dict["engine"].append([eng_dict_name[eng] for _ in dg_list])
        net_ana_method_dict["protein"].append([prot_dict_name[prot] for _ in dg_list])
        net_ana_method_dict["value"].append(list(dg_list))

In [None]:
from time import sleep

# For each protein system, run the html check on the "plain" lomap analysis
# object for all of its engines.
for prot in ana_obj_dict.keys():
    print(prot)
    # NOTE(review): fixed 5 s pause between systems; reason not evident here.
    sleep(5)
    ana_obj = network_dict["lomap"][prot]["plain"]
    engines_to_check = ana_obj.engines
    ana_obj.check_html_exists(engines_to_check)

In [None]:
# Run the MBARNet analysis for a single system / engine and print the result.
# `prot` is set explicitly so the failure message names the system that was
# actually analysed rather than a leftover loop variable from another cell.
prot = "syk"
ana_obj = network_dict["lomap"][prot]["plain"]

for eng in ["AMBER"]:
    try:
        ana_obj.analyse_mbarnet(
            compute_missing=False,
            write_xml=False,
            run_xml_py=False,
            use_experimental=True,
            overwrite=False,
            engines=[eng],
            normalise=True,
        )
        print(ana_obj._mbarnet_computed_DGs[eng])
    except Exception as e:
        print(e)
        print(f"failed for {prot} {eng}")

In [None]:
# mbarnet

# Step 1: run the MBARNet analysis for every protein system and engine.
for prot in ana_obj_dict.keys():
    print(prot)
    ana_obj = network_dict["lomap"][prot]["plain"]

    for eng in ana_obj.engines:
        try:
            ana_obj.analyse_mbarnet(
                compute_missing=False,
                write_xml=False,
                run_xml_py=False,
                use_experimental=True,
                overwrite=False,
                engines=[eng],
                normalise=True,
            )
        except Exception as e:
            print(e)
            print(f"failed for {prot} {eng}")

# Step 2: collect |computed - experimental| per (engine, protein) and append
# the labelled lists to net_ana_method_dict (re-running appends again).
# The engine list is taken from the last analysis object bound above.
for eng in ana_obj.engines:
    for prot in ana_obj_dict.keys():
        dg_list = []
        print(prot, eng)

        ana_obj = network_dict["lomap"][prot]["plain"]

        try:
            for key, computed in ana_obj._mbarnet_computed_DGs[eng].items():
                value = abs(computed[0] - ana_obj.normalised_exper_val_dict[key][0])
                dg_list.append(value)
                if value > 5:
                    # flag large deviations for manual inspection
                    print(prot, eng, key, value)
        except Exception as e:
            # Keep going (values gathered so far are still recorded), but say
            # why this combination is incomplete instead of hiding it.
            print(f"could not collect MBARNet values for {prot} {eng}: {e}")

        net_ana_method_dict["method"].append(["MBARNet" for _ in dg_list])
        net_ana_method_dict["engine"].append([eng_dict_name[eng] for _ in dg_list])
        net_ana_method_dict["protein"].append([prot_dict_name[prot] for _ in dg_list])
        net_ana_method_dict["value"].append(list(dg_list))

In [None]:
# Summary statistic versus experiment for each network-analysis method,
# stored as df_dict[method][protein][engine] = (value, error).
# Each getter takes (analysis object, engine) and returns (df, df_err, df_ci).
stat_getters = {
    "cinnabar": lambda obj, engine: obj.calc_mae_engines(
        engines=[engine], pert_val="val", recalculate=False
    ),
    "fen": lambda obj, engine: obj._get_stats_fwf(engines=[engine], statistic="MUE"),
    "mbarnet": lambda obj, engine: obj._get_stats_mbarnet(
        engines=[engine], statistic="MUE"
    ),
}

df_dict = {}
for method, get_stats in stat_getters.items():
    df_dict[method] = {}
    for prot in ana_obj_dict.keys():
        print(method, prot)
        df_dict[method][prot] = {}
        ana_obj = network_dict[network][prot]["plain"]

        if method == "mbarnet":
            # Diagnostic: number of perturbations vs number of MBARNet results.
            for eng in ana_obj.engines:
                try:
                    print(
                        prot,
                        eng,
                        len(ana_obj._perturbations_dict[eng]),
                        len(ana_obj._mbarnet_computed_DGs[eng]),
                    )
                except Exception as e:
                    print(e)

        for eng in ana_obj.engines:
            try:
                df, df_err, df_ci = get_stats(ana_obj, eng)
                df_dict[method][prot][eng] = (
                    df[eng]["experimental"],
                    df_err[eng]["experimental"],
                )
            except Exception as e:
                # Failed combinations are recorded as (0, 0) so every
                # (protein, engine) entry exists; report which one and why.
                print(f"{method} stats failed for {prot} {eng}: {e}")
                df_dict[method][prot][eng] = (0, 0)

In [None]:
# compare stats
# One bar-chart panel per engine: the statistic for each method (cinnabar,
# fen, mbarnet) per protein system, with the stored error as error bars.


def _method_frame(method, engine, position):
    """Return a one-column frame (index: protein, column: ``method``).

    ``position`` selects which element of each stored tuple in
    ``df_dict[method][protein][engine]`` is extracted (0: value, 1: error).
    Columns for every other engine in ``ana_obj.engines`` are dropped.
    """
    other_engines = [eng for eng in ana_obj.engines if eng != engine]
    return (
        pd.DataFrame(df_dict[method])
        .applymap(lambda x: x[position])
        .T.drop(labels=other_engines, axis=1)
        .rename({engine: method}, axis=1)
    )


def _merge_methods(engine, position):
    """Join the per-method frames for ``engine`` on their protein index."""
    frames = [
        _method_frame(method, engine, position)
        for method in ("cinnabar", "fen", "mbarnet")
    ]
    return reduce(
        lambda left, right: pd.merge(left, right, left_index=True, right_index=True),
        frames,
    )


# Axis limits are left to matplotlib. Nothing is assigned to plt.xlim or
# plt.ylim, since doing so would replace the pyplot functions themselves.
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 5), sharex=True, sharey=True)
for engine, pos in zip(ana_obj.engines, axes):
    df = _merge_methods(engine, 0)
    df_err = _merge_methods(engine, 1)

    df.plot(
        kind="bar",
        color=["purple", "orchid", "lavender"],
        yerr=df_err,
        title=engine,
        ax=pos,
        xlabel="protein system",
        ylabel="MAE",
    )

In [None]:
# plotting per system per engine

# Axis limits are left to matplotlib. Nothing is assigned to plt.xlim or
# plt.ylim, since doing so would replace the pyplot functions themselves.
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(10, 10), sharex=True, sharey=True)
plotting_dict = {
    "method": flatten_comprehension(net_ana_method_dict["method"]),
    "MD engine": flatten_comprehension(net_ana_method_dict["engine"]),
    "MAE dG (kcal/mol)": flatten_comprehension(net_ana_method_dict["value"]),
    "Protein": flatten_comprehension(net_ana_method_dict["protein"]),
}

df = pd.DataFrame(plotting_dict)
# Exclude the fen and MBARNet rows. The comparison is case-insensitive
# because the MBARNet values are labelled "MBARNet", which an exact match
# against "mbarnet" would never remove.
excluded_methods = {"fen", "mbarnet"}
df = df[~df["method"].str.lower().isin(excluded_methods)]

sns.barplot(
    data=df,
    x="Protein",
    y="MAE dG (kcal/mol)",
    hue="MD engine",
    palette=["darkorange", "turquoise", "orchid"],
    ax=axes,
).set_title("MAE")

In [None]:
# Box plot of the collected per-ligand absolute errors, grouped by MD engine
# and coloured by network-analysis method.
plotting_dict = {
    column: flatten_comprehension(net_ana_method_dict[source_key])
    for column, source_key in (
        ("method", "method"),
        ("MD engine", "engine"),
        ("MAE dG (kcal/mol)", "value"),
        ("Protein", "protein"),
    )
}
df = pd.DataFrame(plotting_dict)

fig, ax = plt.subplots(figsize=(3.25, 3.25), dpi=500)
sns.boxplot(
    data=df,
    x="MD engine",
    y="MAE dG (kcal/mol)",
    hue="method",
    palette=["purple", "orchid", "lavender"],
    ax=ax,
)
# modify individual font size of elements
ax.legend(fontsize=10)
ax.set_xlabel("MD Engine", fontsize=10)
ax.set_ylabel("MAE ΔG (kcal/mol)", fontsize=10)
ax.tick_params(axis="both", which="major", labelsize=10)
# cap the y axis at 12; values above that are not shown
ax.set_ylim(top=12)

In [None]:
# plotting per system per engine

# Axis limits are left to matplotlib. Nothing is assigned to plt.xlim or
# plt.ylim, since doing so would replace the pyplot functions themselves.
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(10, 10), sharex=True, sharey=True)
plotting_dict = {
    "method": flatten_comprehension(net_ana_method_dict["method"]),
    "MD engine": flatten_comprehension(net_ana_method_dict["engine"]),
    "MAE dG (kcal/mol)": flatten_comprehension(net_ana_method_dict["value"]),
    "Protein": flatten_comprehension(net_ana_method_dict["protein"]),
}

df = pd.DataFrame(plotting_dict)
# Exclude the fen and MBARNet rows. The comparison is case-insensitive
# because the MBARNet values are labelled "MBARNet", which an exact match
# against "mbarnet" would never remove.
excluded_methods = {"fen", "mbarnet"}
df = df[~df["method"].str.lower().isin(excluded_methods)]

sns.barplot(
    data=df,
    x="Protein",
    y="MAE dG (kcal/mol)",
    hue="MD engine",
    palette=["darkorange", "turquoise", "orchid"],
    ax=axes,
).set_title("MAE")

In [None]:
# plotting per system
# One box-plot panel per engine: absolute errors per protein, coloured by
# network-analysis method.

# Axis limits are left to matplotlib. Nothing is assigned to plt.xlim or
# plt.ylim, since doing so would replace the pyplot functions themselves.
fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(20, 20), sharex=True, sharey=True)

# The full table does not depend on the engine, so build it once.
plotting_dict = {
    "method": flatten_comprehension(net_ana_method_dict["method"]),
    "MD engine": flatten_comprehension(net_ana_method_dict["engine"]),
    "MAE dG (kcal/mol)": flatten_comprehension(net_ana_method_dict["value"]),
    "Protein": flatten_comprehension(net_ana_method_dict["protein"]),
}
all_values_df = pd.DataFrame(plotting_dict)

for engine, pos in zip(ana_obj.engines, axes):
    print(engine)
    # Remove the rows belonging to every other engine (non-mutating filter).
    other_engine_names = [
        eng_dict_name[eng] for eng in ana_obj.engines if eng != engine
    ]
    df = all_values_df[~all_values_df["MD engine"].isin(other_engine_names)]

    sns.boxplot(
        data=df,
        x="Protein",
        y="MAE dG (kcal/mol)",
        hue="method",
        palette=["purple", "orchid", "lavender"],
        ax=pos,
    ).set_title(eng_dict_name[engine])

In [None]:
# One box-plot panel per protein system: absolute errors per MD engine,
# coloured by network-analysis method.
# Six panels are created; zip() below stops at the shorter of the protein
# list and the axes, so systems beyond the sixth would not be drawn.
# Axis limits are left to matplotlib. Nothing is assigned to plt.xlim or
# plt.ylim, since doing so would replace the pyplot functions themselves.
fig, axes = plt.subplots(nrows=1, ncols=6, figsize=(25, 5), sharex=True, sharey=True)

# The full table does not depend on the protein, so build it once.
plotting_dict = {
    "method": flatten_comprehension(net_ana_method_dict["method"]),
    "MD engine": flatten_comprehension(net_ana_method_dict["engine"]),
    "MAE dG (kcal/mol)": flatten_comprehension(net_ana_method_dict["value"]),
    "Protein": flatten_comprehension(net_ana_method_dict["protein"]),
}
all_values_df = pd.DataFrame(plotting_dict)

for prot, pos in zip(ana_obj_dict.keys(), axes):
    # Remove the rows belonging to every other protein (non-mutating filter).
    other_protein_names = [
        prot_dict_name[other] for other in ana_obj_dict.keys() if other != prot
    ]
    df = all_values_df[~all_values_df["Protein"].isin(other_protein_names)]

    sns.boxplot(
        data=df,
        x="MD engine",
        y="MAE dG (kcal/mol)",
        hue="method",
        palette=["purple", "orchid", "lavender"],
        ax=pos,
    ).set_title(prot.upper())

In [None]:
# Bar plot of the error column per protein, split by MD engine, coloured
# with the values returned by the pipeline's set_colours().
# NOTE(review): `df` is whatever frame the previously executed cell left
# bound to that name; confirm which table is intended here.
engine_palette = pipeline.analysis.set_colours().values()
sns.barplot(
    data=df,
    x="Protein",
    y="MAE dG (kcal/mol)",
    hue="MD engine",
    palette=engine_palette,
)