In [None]:
%matplotlib inline

# analysis per system

# import libraries
import sys
import seaborn as sns
import numpy as np
import scipy.stats as _stats
from scipy.interpolate import griddata
from functools import reduce
from pipeline.analysis import *
from pipeline.utils import *
from pipeline import *
import logging
import networkx as nx
import glob
from scipy.stats import sem as sem
from matplotlib import colormaps
import matplotlib.colors as mcolors
import matplotlib.patches as mpatches

from matplotlib.ticker import MaxNLocator

import warnings

# silence FutureWarning / RuntimeWarning messages for the rest of the session
warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=RuntimeWarning)
# warnings.simplefilter(action='ignore', category=SettingWithCopyWarning)

# show INFO-level (and above) messages emitted through the root logger
logging.getLogger().setLevel(logging.INFO)

from cinnabar import wrangle as _wrangle


# NOTE(review): BSS is not imported explicitly in this cell; presumably it is
# provided by one of the star imports above — confirm.
print(BSS.__file__)

In [None]:
def check_normal_dist(values):
    """Return True when ``values`` are consistent with a normal distribution.

    Samples with fewer than 50 points are tested with Shapiro-Wilk; larger
    samples use a Kolmogorov-Smirnov test against a normal distribution
    parameterised with the sample mean and standard deviation.

    Parameters
    ----------
    values : sequence of float
        The sample to test.

    Returns
    -------
    bool
        True if normality is NOT rejected at the 5 % level (p > 0.05),
        False otherwise.
    """
    values = np.asarray(values, dtype=float)

    if len(values) < 50:
        _, p = _stats.shapiro(values)
    else:
        # kstest needs a reference distribution; compare against a normal with
        # the sample's own location and scale. Because these are estimated from
        # the data the p-value is only approximate (conservative).
        _, p = _stats.kstest(
            values, "norm", args=(np.mean(values), np.std(values, ddof=1))
        )

    # a small p-value rejects the null hypothesis of normality
    return bool(p > 0.05)


def flatten_comprehension(matrix):
    """Flatten one level of nesting: return a list of every item of every row."""
    flat = []
    for row in matrix:
        flat.extend(row)
    return flat

In [None]:
# define the analysis methods to use
# every protocol shares the same option set; only the listed overrides differ
_default_analysis_options = {
    "estimator": "MBAR",
    "method": "alchemlyb",
    "check overlap": True,
    "try pickle": True,
    "save pickle": True,
    "auto equilibration": False,
    "statistical inefficiency": False,
    "truncate lower": 0,
    "truncate upper": 100,
    "name": None,
}

_analysis_overrides = {
    "plain": {},
    "subsampling": {"statistical inefficiency": True},
    "1ns": {"truncate upper": 25},
    "2ns": {"truncate upper": 50},
    "3ns": {"truncate upper": 75},
    "autoeq": {"auto equilibration": True, "statistical inefficiency": True},
    # disabled protocols (uncomment to include them again):
    # "TI": {"estimator": "TI"},
    # "single_0": {},
    # "single_1": {},
    # "single_2": {},
}

# each protocol gets its own independent copy of the defaults
ana_dicts = {
    protocol_name: {**_default_analysis_options, **overrides}
    for protocol_name, overrides in _analysis_overrides.items()
}

In [None]:
# display names for the protein systems
prot_dict_name = dict(
    tyk2="TYK2",
    mcl1="MCL1",
    p38="P38α",
    syk="SYK",
    hif2a="HIF2A",
    cmet="CMET",
)

# display names for the engines and the literature result sets
eng_dict_name = dict(
    AMBER="AMBER22",
    SOMD="SOMD1",
    GROMACS="GROMACS23",
    hahn="Hahn et al.",
    openfe="OpenFE",
    fepplus="FEP+",
)

set_cols = pipeline.analysis.set_colours(
    other_results_names=["hahn", "openfe", "fepplus"]
)

# same colours as set_cols, but keyed by the display name instead of the raw key
col_dict = {}
for eng, display_name in eng_dict_name.items():
    col_dict[display_name] = set_cols[eng]

In [None]:
# protein system analysed by the cells below (used to build input/output paths)
protein = "tyk2"

In [None]:
# Build one analysis_network object per (perturbation network, analysis protocol).
# network_dict[network][protocol name] holds the analysed object, or None when
# that network is not available for the chosen protein.
network_dict = {}

# locations of the inputs / outputs (independent of network and protocol)
bench_folder = "/home/anna/Documents/benchmark"
# main_dir = f"{bench_folder}/reruns/{protein}"
main_dir = f"/backup/{protein}"
exp_file = f"{bench_folder}/inputs/experimental/{protein}.yml"
output_folder = f"{main_dir}/outputs_extracted"

# # if need size of protein
# try:
#     prot = BSS.IO.readMolecules(
#         [
#             f"{bench_folder}/inputs/{protein}/{protein}_prep/{protein}.gro",
#             f"{bench_folder}/inputs/{protein}/{protein}_prep/{protein}.top",
#         ]
#     )[0]
# except:
#     prot = BSS.IO.readMolecules(
#         [
#             f"{bench_folder}/inputs/{protein}/{protein}_parameterised.prm7",
#             f"{bench_folder}/inputs/{protein}/{protein}_parameterised.rst7",
#         ]
#     )[0]

# print(f"no of residues in the protein: {prot.nResidues()}")

# networks that are skipped for p38
optimal_networks = (
    "lomap-a-optimal",
    "lomap-d-optimal",
    "rbfenn-a-optimal",
    "rbfenn-d-optimal",
)

# all the options:
# "combined", "lomap", "rbfenn", "flare", "lomap-a-optimal", "lomap-d-optimal", "rbfenn-a-optimal", "rbfenn-d-optimal"
for network in ["combined", "lomap", "rbfenn", "flare"]:
    ana_obj_dict = {}

    # choose location for the network file; None marks a network that is not
    # available for this protein
    if protein in ("syk", "cmet", "hif2a"):
        # only the lomap network
        if network == "lomap":
            net_file = f"{main_dir}/execution_model/network_lomap.dat"
        else:
            net_file = None
    elif protein == "p38" and network in optimal_networks:
        net_file = None
    else:
        net_file = f"{main_dir}/execution_model/network_{network}.dat"

    for ana_name, ana_options in ana_dicts.items():
        ana_prot = analysis_protocol(ana_options)

        if net_file is None:
            # previously written to ana_obj_dict[protein][...], which raised a
            # KeyError; the entry belongs directly under the protocol name
            ana_obj_dict[ana_name] = None
            continue

        # prot_file = f"{main_dir}/execution_model/protocol.dat" # no protocol used , name added after if needed
        pipeline_prot = pipeline_protocol(auto_validate=True)
        # pipeline_prot.name("")

        # initialise the network object
        all_analysis_object = analysis_network(
            output_folder,
            exp_file=exp_file,
            net_file=net_file,
            analysis_prot=ana_prot,
            # method=pipeline_prot.name(),  # if the protocol had a name
            # engines=pipeline_prot.engines(),
        )

        # compute; keep going on failure (best effort) but report what failed
        try:
            all_analysis_object.compute_results()
        except Exception as e:
            print(
                f"failed analysis for network '{network}', protocol '{ana_name}': {e}"
            )

        # add ligands folder
        all_analysis_object.add_ligands_folder(
            f"{bench_folder}/inputs/reruns/{protein}/ligands_intermediates"
        )

        ana_obj_dict[ana_name] = all_analysis_object

    network_dict[network] = ana_obj_dict

In [None]:
# set the network for the perturbation analysis
# (ana_obj is the analysis object computed with the "plain" protocol for that network)
network = "combined"
ana_obj = network_dict[network]["plain"]

In [None]:
# identify failed perturbations, disconnected ligands and outliers per engine
failed_perts_dict_percen = {}  # engine -> percentage of failed perturbations
failed_perts_dict = {}  # engine -> failed perturbations as returned by ana_obj

n_perturbations = len(ana_obj.perturbations)

for eng in ana_obj.engines:  # ana_obj.engines
    # query once per engine instead of once per use
    success = ana_obj.successful_perturbations(eng)
    failed_percentage = 100 - success[1]
    failed_perts = ana_obj.failed_perturbations(eng)

    failed_perts_dict_percen[eng] = failed_percentage
    failed_perts_dict[eng] = failed_perts

    print(
        f"failed percentage for {eng}: {failed_percentage} ({n_perturbations - len(success[2])} / {n_perturbations})"
    )
    print(f"{eng} failed perturbations: {failed_perts}")
    print(f"{eng} disconnected ligands: {ana_obj.disconnected_ligands(engine=eng)}")
    print(f"outliers 10 {eng}: {ana_obj.get_outliers(threshold=10, name=eng)}")
    print(f"outliers 5 {eng}: {ana_obj.get_outliers(threshold=5, name=eng)}")

    # ana_obj.remove_outliers(threshold=5, name=eng)

In [None]:
# analysis of the perturbation analysis methods


def _summarise_errors(values):
    """Summarise a list of per-perturbation errors.

    NaN entries are dropped first. Returns a tuple
    ``(mean, standard deviation, (lower 95% CI, upper 95% CI), cleaned values)``.
    """
    cleaned = [x for x in values if str(x) != "nan"]
    mean = np.mean(cleaned)
    lower_ci, upper_ci = _stats.norm.interval(
        confidence=0.95, loc=mean, scale=_stats.sem(cleaned)
    )
    return (mean, _stats.tstd(cleaned), (lower_ci, upper_ci), cleaned)


mae_dict = {}  # protocol -> protein -> engine -> entries 0, 1 and 2 of calc_mae_engines
sem_dict = {}  # protocol -> protein -> engine -> _summarise_errors tuple
sem_dict_name = {}  # protocol -> _summarise_errors tuple over all engines

for name in ana_dicts:
    ana_obj = network_dict[network][name]

    # --- SEM of the perturbations, per engine and pooled over the engines ---
    sem_dict[name] = {protein: {}}
    sem_list_name = []

    for eng in ana_obj.engines:
        sems = [val[1] for val in ana_obj.calc_pert_dict[eng].values()]
        sem_list_name.extend(sems)

        # if not check_normal_dist(sem_list):
        #     print(f"{prot} {name} not normally dist")

        sem_dict[name][protein][eng] = _summarise_errors(sems)

    sem_dict_name[name] = _summarise_errors(sem_list_name)

    # --- MAE to experiment ---
    mae_dict[name] = {protein: {}}

    try:
        mae = ana_obj.calc_mae_engines(pert_val="pert", recalculate=True)
    except Exception as e:
        # skip this protocol: carrying on would reuse the MAE of the previous
        # protocol (or fail on an undefined name for the first one)
        print(e)
        print(f"could not compute MAE for {protein} {name}")
        continue

    for eng in ana_obj.engines:
        stats_string = ""
        try:
            mae_dict[name][protein][eng] = (
                mae[0][eng]["experimental"],
                mae[1][eng]["experimental"],
                mae[2][eng]["experimental"],
            )
            stats_string += f"{eng} MAE: {mae[0][eng]['experimental']:.2f} +/- {mae[1][eng]['experimental']:.2f} kcal/mol, "

            if sem_dict[name][protein][eng][0]:
                stats_string += f"SEM: {sem_dict[name][protein][eng][0]:.2f} +/- {sem_dict[name][protein][eng][1]:.2f} kcal/mol\n"
            # NOTE(review): no protocol is called exactly "single" (the disabled
            # ones are "single_0", ...), so this branch looks unreachable — confirm
            elif name == "single":
                errors = [val[1] for val in ana_obj.calc_pert_dict[eng].values()]
                stats_string += f"error: {np.mean(errors):.2f} +/- {_stats.tstd(errors):.2f} kcal/mol\n"

            # print(stats_string)

        except Exception as e:
            print(e)
            print(f"could not compute for {protein} {name} {eng}")

In [None]:
# graphs based on engine: one bar-plot panel per engine, one bar per analysis protocol
plotting_dict = sem_dict  # mae_dict or sem_dict
stats_name = "ΔΔG SEM"  # MAE or SEM

# display names of the analysis protocols
analysis_labels = {
    "plain": "Full data",
    "subsampling": "Subsampling",
    "autoeq": "Auto-equilibration",
    "1ns": "1 ns sampling",
    "2ns": "2 ns sampling",
    "3ns": "3 ns sampling",
}

# optional per-engine colour maps (kept under their own name so that the global
# display-name palette `col_dict` is not overwritten by this cell)
engine_cmaps = {
    "AMBER": plt.get_cmap("autumn"),
    "SOMD": plt.get_cmap("cool"),
    "GROMACS": plt.get_cmap("viridis"),
}


def _engine_frame(engine, component):
    """Collect one entry of the stored tuples for a single engine.

    Returns a DataFrame with one row per protein and one column per analysis
    protocol; ``component`` is the tuple index to extract (0 for the value,
    1 for its spread).
    """
    other_engines = [eng for eng in ana_obj.engines if eng != engine]
    frames = [
        pd.DataFrame(plotting_dict[name])
        .applymap(lambda x: x[component])
        .rename(prot_dict_name, axis=1)
        .T.drop(labels=other_engines, axis=1)
        .rename({engine: name}, axis=1)
        .rename(analysis_labels, axis=1)
        for name in ana_dicts
    ]
    return reduce(
        lambda left, right: pd.merge(left, right, left_index=True, right_index=True),
        frames,
    )


fig, axes = plt.subplots(
    nrows=len(ana_obj.engines), ncols=1, figsize=(20, 20), sharex=True, sharey=True
)
axes = np.atleast_1d(axes)  # a single engine gives a bare Axes instead of an array

for engine, pos in zip(ana_obj.engines, axes):
    df = _engine_frame(engine, 0)
    df_err = _engine_frame(engine, 1)

    # df_lower = df_err.applymap(lambda x: x[0])
    # df_upper = df_err.applymap(lambda x: x[1])
    # df_err = (df_upper - df_lower) / 2

    print(df)
    print(engine)
    print(df.mean())
    print(df.sem())
    print(df_err)

    # one colour per analysis protocol, evenly spaced over the colour map
    n_columns = len(df.columns)
    my_cmap = plt.get_cmap("plasma")  # engine_cmaps[engine]
    colors = [my_cmap(i / max(n_columns - 1, 1)) for i in range(n_columns)]

    df.plot(
        kind="bar",
        color=colors,
        yerr=df_err,
        title=eng_dict_name[engine],
        ax=pos,
        xlabel="Protein System",
        ylabel=f"{stats_name} (kcal/mol)",
    )
# fig.suptitle(f'{stats_name} perturbations for LOMAP/RBFENN-score')

In [None]:
# set the network for the perturbation analysis
# (resets ana_obj to the "plain" protocol; the summary loop above rebinds it)
network = "combined"
ana_obj = network_dict[network]["plain"]

In [None]:
# check the statistical significance of the per-perturbation errors between engines
# (the resulting error lists are also used for the violin plot and histogram below)

stats_name = "SEM"  # "MAE" or "SEM"

# checking for significance
eng1 = "AMBER"
eng2 = "SOMD"
eng3 = "GROMACS"
engines = [eng1, eng2, eng3]

if stats_name not in ("MAE", "SEM"):
    raise ValueError(f"wrong name: '{stats_name}' (expected 'MAE' or 'SEM')")

# only perturbations that every engine has can be compared pairwise
shared_keys = [
    key
    for key in ana_obj.calc_pert_dict[eng1]
    if key in ana_obj.calc_pert_dict[eng2] and key in ana_obj.calc_pert_dict[eng3]
]
if stats_name == "MAE":
    # the MAE branch skips keys containing "Intermediate"
    shared_keys = [key for key in shared_keys if "Intermediate" not in key]


def _pert_error(engine, key):
    """Error of one perturbation for one engine, according to ``stats_name``."""
    if stats_name == "MAE":
        return abs(
            ana_obj.calc_pert_dict[engine][key][0] - ana_obj.exper_pert_dict[key][0]
        )
    return abs(ana_obj.calc_pert_dict[engine][key][1])


# keep a perturbation only when all three engines give a number for it, so the
# three lists stay aligned (paired samples)
error_rows = []
for key in shared_keys:
    row = [_pert_error(eng, key) for eng in engines]
    if not any(np.isnan(val) for val in row):
        error_rows.append(row)

first_err_vals = [row[0] for row in error_rows]
second_err_vals = [row[1] for row in error_rows]
third_err_vals = [row[2] for row in error_rows]

assert len(first_err_vals) == len(second_err_vals) == len(third_err_vals)

eng_list_dict = {
    eng1: first_err_vals,
    eng2: second_err_vals,
    eng3: third_err_vals,
}

stats_test_dict = {}

for enga in engines:
    stats_test_dict[enga] = {}
    for engb in engines:
        if enga == engb:
            # sentinel value: same engine, no test carried out
            stats_test_dict[enga][engb] = 100
            continue

        # check normally distributed
        # NOTE(review): the normality check is applied to the absolute
        # differences, as in the original — confirm this is intended
        abs_diffs = np.abs(
            np.array(eng_list_dict[enga]) - np.array(eng_list_dict[engb])
        )
        if _stats.shapiro(abs_diffs)[1] > 0.05:
            print(
                f"data is normally distributed for {enga} and {engb}!! Still carrying out wilcoxon signed rank ...."
            )

        # always store a p-value so that the table has no missing entries
        t, p = _stats.wilcoxon(
            eng_list_dict[enga], eng_list_dict[engb]
        )  # absolute error  # ttest_rel
        stats_test_dict[enga][engb] = p
        # print(enga, engb, t, p)

df = pd.DataFrame(stats_test_dict).applymap(lambda x: float(x))
print(
    f"statistical significance for the {stats_name} for the perturbations between engines"
)
df

In [None]:
# MAD between engines

compared_engines = (eng1, eng2, eng3)

# perturbations present for every engine and free of NaN values
filtered_keys = [
    key
    for key in ana_obj.calc_pert_dict[eng1]
    if all(key in ana_obj.calc_pert_dict[other] for other in compared_engines[1:])
    and not any(
        np.isnan(ana_obj.calc_pert_dict[each][key]).any() for each in compared_engines
    )
]

# statistical sig

# only for the same perturbations
eng_list_dict = {
    each: [ana_obj.calc_pert_dict[each][key][0] for key in filtered_keys]
    for each in compared_engines
}

stats_test_dict = {}

for enga in compared_engines:
    row_results = {}
    stats_test_dict[enga] = row_results
    for engb in compared_engines:
        if enga == engb:
            # sentinel value: same engine, no test carried out
            row_results[engb] = 100
            continue

        # check normally distributed
        abs_diffs = np.abs(np.subtract(eng_list_dict[enga], eng_list_dict[engb]))
        shapiro_p = _stats.shapiro(abs_diffs)[1]
        if shapiro_p > 0.05:
            print(
                f"data is normally distributed for {enga} and {engb}!! Still carrying out wilcoxon signed rank ...."
            )

        t, p = _stats.wilcoxon(
            eng_list_dict[enga], eng_list_dict[engb]
        )  # absolute error  # ttest_rel
        row_results[engb] = p
        # print(enga, engb, t, p)

df_col = pd.DataFrame(stats_test_dict).applymap(lambda x: float(x))
print("statistical significance between the perturbations calculated")
df_col

In [None]:
# MAD of the engines compared to each other, formatted as annotation strings
mad = ana_obj.calc_mad_engines(pert_val="pert", recalculate=False)

# Build new string frames instead of updating mad[0] / mad[2] in place: the
# returned frames may be shared with ana_obj (recalculate=False), and
# overwriting them with strings makes this cell fail when it is run again.
mad_values = mad[0].applymap(lambda x: "" if x == 0 else f"{x:.2f}")
mad_intervals = mad[2].applymap(
    lambda x: "" if x[0] == 0 else f"({x[0]:.2f}, {x[1]:.2f})"
)

# one cell per engine pair: value on the first line, interval on the second
df_val = mad_values.astype(str) + "\n" + mad_intervals.astype(str)
df_val

In [None]:
# plotting the stats test


def _significance_code(p_value):
    """Map a p-value to a colour code: 0 same engine, 1 p > 0.05, 2 otherwise."""
    if p_value == 100:
        return 0  # same engine, no stats test
    if p_value > 0.05:
        return 1  # no statistically significant difference
    return 2  # statistically significant difference


# mapping dictionary
color_mapping = {"pink": "#FFC0CB", "white": "#FFFFFF", "grey": "#BEBEBE"}

# colour name for each numeric code
_code_to_label = {0: "white", 1: "grey", 2: "pink"}

# threshold for colour labels
color_labels = df_col.applymap(lambda x: _code_to_label[_significance_code(x)])

# Convert text labels to a numerical array
color_numeric = df_col.applymap(_significance_code)

# below as otherwise problem if only two colours
array_col_dict = {
    0: color_mapping["white"],
    1: color_mapping["grey"],
    2: color_mapping["pink"],
}

numeric_colours_list = flatten_comprehension(color_numeric.values.tolist())

# only the colours whose code is actually present in the table
present_colours = []
for code, hex_colour in array_col_dict.items():
    if code in numeric_colours_list:
        present_colours.append(hex_colour)
cmap = mcolors.ListedColormap(present_colours)

# Plot heatmap using numeric mapping for colors
fig, ax = plt.subplots(figsize=(4, 4))
sns.heatmap(color_numeric, annot=df_val, fmt="s", cmap=cmap, cbar=False, ax=ax)

# mpatches.Patch(color=color_mapping["white"], label="") is left out of the legend
legend_patches = [
    mpatches.Patch(color=color_mapping[colour_name], label=legend_label)
    for colour_name, legend_label in (("pink", "p ≤ 0.05"), ("grey", "p > 0.05"))
]

# Add legend to the plot
plt.legend(
    handles=legend_patches,
    loc="center left",
    title="Statistical significance",
    bbox_to_anchor=(1, 0.5),
)

plt.title("MAD (kcal/mol) between engines (95% CI)")

In [None]:
# distribution of the perts values to investigate significant difference if present in the above plot
# histogram of the values for the second and third engine
# (add eng1 to the tuple below to include the first engine as well)
for hist_engine in (eng2, eng3):
    plt.hist(
        eng_list_dict[hist_engine],
        density=True,
        color=pipeline.analysis.set_colours()[hist_engine],
        label=eng_dict_name[hist_engine],
        alpha=0.5,
    )

plt.legend(loc="upper right")
plt.xlabel("ddG value (kcal/mol)")
plt.ylabel("Density")
# plt.title(prot_name_dict[protein])

In [None]:
# violin plots of the per-perturbation errors for each engine
errors_per_engine = [
    (eng_dict_name[eng1], first_err_vals),
    (eng_dict_name[eng2], second_err_vals),
    (eng_dict_name[eng3], third_err_vals),
]

# long format: one row per (engine display name, error value)
data = {
    "engine": flatten_comprehension(
        [[label] * len(values) for label, values in errors_per_engine]
    ),
    "error": flatten_comprehension([values for _, values in errors_per_engine]),
}

df = pd.DataFrame(data)

# Build the palette here, keyed by the display names used on the x axis. The
# global col_dict is overwritten by the per-engine bar-plot cell, so relying on
# it makes this cell depend on which cells were run before.
violin_palette = {eng_dict_name[eng]: set_cols[eng] for eng in (eng1, eng2, eng3)}

plt.figure(figsize=(8, 5))
sns.violinplot(x="engine", y="error", data=df, inner="box", palette=violin_palette)

# plt.title("MAE Distribution for Different MD Engines")
plt.xlabel("MD Engine")
plt.ylabel(f"{stats_name} (kcal/mol)")

In [None]:
# histogram of the per-perturbation errors for each engine
for hist_engine, hist_values in (
    (eng1, first_err_vals),
    (eng2, second_err_vals),
    (eng3, third_err_vals),
):
    plt.hist(
        hist_values,
        density=True,
        color=pipeline.analysis.set_colours()[hist_engine],
        label=eng_dict_name[hist_engine],
        alpha=0.5,
    )

plt.legend(loc="upper right")
plt.xlabel(f"ddG {stats_name} (kcal/mol)")
plt.ylabel("Density")
# the display-name dictionary is called prot_dict_name (prot_name_dict is undefined)
plt.title(prot_dict_name[protein])

In [None]:
# histograms of SEM for the legs or repeats

ana_obj.plot_histogram_repeats()
ana_obj.plot_histogram_legs()

# ana_obj.plot_histogram_sem()

In [None]:
# 2d contour plot: per-perturbation error of one engine against another
stats_name = "MAE"  # "MAE" or "SEM"

for eng1, eng2 in (["AMBER", "SOMD"], ["AMBER", "GROMACS"], ["GROMACS", "SOMD"]):
    # perturbations present for both engines and free of NaN values
    filtered_keys = [
        key
        for key in ana_obj.calc_pert_dict[eng1]
        if key
        in ana_obj.calc_pert_dict[eng2]  # and key in ana_obj.calc_pert_dict[eng3]
        and not (
            np.isnan(ana_obj.calc_pert_dict[eng1][key]).any()
            or np.isnan(ana_obj.calc_pert_dict[eng2][key]).any()  # or
            # np.isnan(ana_obj.calc_pert_dict[eng3][key]).any()
        )
    ]

    # the two branches were swapped before ("SEM" computed the error to
    # experiment and "MAE" the SEM), so the axes were labelled with the wrong
    # statistic
    if stats_name == "MAE":
        # absolute error to experiment; keys containing "Intermediate" are skipped
        exper_keys = [key for key in filtered_keys if "Intermediate" not in key]
        f_err_vals = [
            abs(ana_obj.calc_pert_dict[eng1][key][0] - ana_obj.exper_pert_dict[key][0])
            for key in exper_keys
        ]
        s_err_vals = [
            abs(ana_obj.calc_pert_dict[eng2][key][0] - ana_obj.exper_pert_dict[key][0])
            for key in exper_keys
        ]
    elif stats_name == "SEM":
        f_err_vals = [
            abs(ana_obj.calc_pert_dict[eng1][key][1]) for key in filtered_keys
        ]
        s_err_vals = [
            abs(ana_obj.calc_pert_dict[eng2][key][1]) for key in filtered_keys
        ]
    else:
        raise ValueError(f"wrong name: '{stats_name}' (expected 'MAE' or 'SEM')")

    x = f_err_vals
    y = s_err_vals
    z = np.abs(np.array(x) - np.array(y))  # colour of the points

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.kdeplot(x=x, y=y, cmap="PuRd", fill=True, levels=10, thresh=0.05)
    # Scatter plot on top to show data points
    plt.scatter(x, y, c=z, cmap="Purples", edgecolors="black")
    plt.colorbar(label=f"Absolute difference\n between engine {stats_name} (kcal/mol)")

    # Labels and Title
    plt.xlabel(f"{eng1} {stats_name}")
    plt.ylabel(f"{eng2} {stats_name}")
    plt.title(f"{stats_name}")

In [None]:
# 2d contour plot: error of the reference engine (eng2) relative to the other two
stats_name = "MAE (kcal/mol)"

for eng1, eng2, eng3 in [
    ("GROMACS", "SOMD", "AMBER"),
    ("SOMD", "AMBER", "GROMACS"),
    ("AMBER", "GROMACS", "SOMD"),
]:
    # perturbations present for all three engines, free of NaN values;
    # keys containing "Intermediate" are skipped
    filtered_keys = [
        key
        for key in ana_obj.calc_pert_dict[eng1]
        if key in ana_obj.calc_pert_dict[eng2]
        and key in ana_obj.calc_pert_dict[eng3]
        and not (
            np.isnan(ana_obj.calc_pert_dict[eng1][key]).any()
            or np.isnan(ana_obj.calc_pert_dict[eng2][key]).any()
            or np.isnan(ana_obj.calc_pert_dict[eng3][key]).any()
        )
    ]
    exper_keys = [key for key in filtered_keys if "Intermediate" not in key]

    # MAE: absolute error to experiment
    f_err_vals = [
        abs(ana_obj.calc_pert_dict[eng1][key][0] - ana_obj.exper_pert_dict[key][0])
        for key in exper_keys
    ]
    s_err_vals = [
        abs(ana_obj.calc_pert_dict[eng2][key][0] - ana_obj.exper_pert_dict[key][0])
        for key in exper_keys
    ]
    t_err_vals = [
        abs(ana_obj.calc_pert_dict[eng3][key][0] - ana_obj.exper_pert_dict[key][0])
        for key in exper_keys
    ]

    # SEM
    # f_err_vals = [abs(ana_obj.calc_pert_dict[eng1][key][1]) for key in filtered_keys]
    # s_err_vals = [abs(ana_obj.calc_pert_dict[eng2][key][1]) for key in filtered_keys]
    # t_err_vals = [abs(ana_obj.calc_pert_dict[eng3][key][1]) for key in filtered_keys]

    # kept as single-element nested lists, as before
    first_err_vals = [f_err_vals]
    second_err_vals = [s_err_vals]
    third_err_vals = [t_err_vals]

    # positive x / y: eng2 has the smaller error, i.e. it is better than eng1 / eng3
    x = np.array(f_err_vals) - np.array(s_err_vals)
    y = np.array(t_err_vals) - np.array(s_err_vals)

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.kdeplot(x=x, y=y, cmap="PuRd", fill=True, levels=10, thresh=0.05)
    # Scatter plot on top to show data points; colour by the flat list of eng2
    # errors (the nested list has shape (1, n), which matplotlib reads as a
    # single RGB(A) colour when n is 3 or 4)
    plt.scatter(x, y, c=s_err_vals, cmap="Purples", edgecolors="black")
    plt.colorbar(label=f"{eng2} {stats_name}")
    ax.axhline(0, color="black", linewidth=1)
    ax.axvline(0, color="black", linewidth=1)

    # quadrant annotations (axes-fraction coordinates), describing eng2
    quadrant_labels = [
        (0.05, 0.95, f"better than {eng3}\nworse than {eng1}"),
        (0.70, 0.10, f"worse than {eng3}\nbetter than {eng1}"),
        (0.05, 0.10, f"worse than {eng3}\nworse than {eng1}"),
        (0.70, 0.95, f"better than {eng3}\nbetter than {eng1}"),
    ]
    for x_pos, y_pos, label in quadrant_labels:
        ax.text(
            x_pos,
            y_pos,
            label,
            transform=ax.transAxes,
            fontsize=9,
            color="black",
            ha="left",
            va="top",
        )

    # Labels and Title
    plt.xlabel(f"{eng1}-{eng2} {stats_name}")
    plt.ylabel(f"{eng3}-{eng2} {stats_name}")
    plt.title(f"{eng2} {stats_name.replace('(kcal/mol)','')}comparison")

In [None]:
# obtain the literature results

# FEP+ ligand names carry annotations that are stripped (in this order) so the
# names match the "lig_<name>" convention used for the other result sets
_FEPPLUS_NAME_SUBSTITUTIONS = (
    (" flip", ""),
    ("-charged-pKa-8.1", ""),
    (" redocked", ""),
    ("_n", ""),
    (" ground state", ""),
    ("docked ", ""),
    (" adjust", ""),
    ("ejm_", "ejm"),
    ("jmc_", "jmc"),
)


def _fepplus_lig_name(raw_name):
    """Return ``lig_<cleaned name>`` for a ligand name from the FEP+ csv files."""
    cleaned = raw_name
    for old, new in _FEPPLUS_NAME_SUBSTITUTIONS:
        cleaned = cleaned.replace(old, new)
    return f"lig_{cleaned.strip()}"


# for perturbations
file = (
    f"/home/anna/Documents/benchmark/inputs/other_computed/fepplus/{protein}_perts.csv"
)
df = pd.read_csv(file, delimiter=",")

fepplus_perts_dict = {}
for index, row in df.iterrows():
    pert_name = f"{_fepplus_lig_name(row['Lig 1'])}~{_fepplus_lig_name(row['Lig 2'])}"
    fepplus_perts_dict[pert_name] = (
        row["Bennett ddG (kcal/mol)"],
        row["Bennett std. error (kcal/mol)"],
    )

write_perts_file(
    fepplus_perts_dict,
    # .csv
    file_path=f"/home/anna/Documents/benchmark/inputs/{protein}/perts_file_fepplus_new",
)

# for ligands
file = (
    f"/home/anna/Documents/benchmark/inputs/other_computed/fepplus/{protein}_ligs.csv"
)
df = pd.read_csv(file, delimiter=",")

fepplus_ligs_dict = {}
for index, row in df.iterrows():
    fepplus_ligs_dict[_fepplus_lig_name(row["Ligand name"])] = (
        row["Pred. dG (kcal/mol)"],
        row["Pred. dG std. error (kcal/mol)"],
    )

# shift the predicted values so that their mean is zero; errors are unchanged
avg = np.mean([val[0] for val in fepplus_ligs_dict.values()])
normalised_ligs_dict = {
    lig: (value - avg, error) for lig, (value, error) in fepplus_ligs_dict.items()
}

fepplus_ligs_dict = normalised_ligs_dict

write_vals_file(
    fepplus_ligs_dict,
    # .csv
    file_path=f"/home/anna/Documents/benchmark/inputs/{protein}/ligs_file_fepplus_new",
)


# Hahn et al

file = f"/home/anna/Documents/benchmark/inputs/other_computed/hahn/{protein}.dat"

df = pd.read_csv(file, delimiter="  ")

# need to convert into kcal/mol
# (0.239006 is presumably the kJ/mol -> kcal/mol factor — confirm the input units)
hahn_perts_dict = {
    f"{row['edge']}": (float(row["ddg"]) * 0.239006, float(row["ddg_err"]) * 0.239006)
    for index, row in df.iterrows()
}

hahn_perts_path = (
    f"/home/anna/Documents/benchmark/inputs/other_computed/hahn/perts_file_{protein}"
)
write_perts_file(
    hahn_perts_dict,
    # .csv
    file_path=hahn_perts_path,
)

files = [
    f"/home/anna/Documents/benchmark/inputs/other_computed/hahn/perts_file_{protein}.csv"
]

calc_diff_dict = make_dict.comp_results(files)  # older method

perts, ligs = get_info_network_from_dict(calc_diff_dict)
exper_dict = ana_obj.exper_val_dict

convert.cinnabar_file(
    files,
    exper_dict,
    f"/home/anna/Documents/benchmark/inputs/other_computed/hahn/cinnabar_{protein}",
    perturbations=perts,
    method=None,
)

# compute the per ligand for the network
# NOTE(review): this rebinds `network` (the network name string used to index
# network_dict in earlier cells) to a FEMap object — confirm nothing later
# relies on the string.
network = _wrangle.FEMap(
    f"/home/anna/Documents/benchmark/inputs/other_computed/hahn/cinnabar_{protein}.csv"
)

# for self plotting of per ligand
hahn_ligs_dict = make_dict.from_cinnabar_network_node(network, "calc")
hahn_perts_dict = make_dict.from_cinnabar_network_edges(network, "calc", perts)


# OpenFE
if protein == "syk":
    # syk is not read from the OpenFE results file: use NaN placeholders keyed by
    # the ligands / perturbations of the current analysis object instead
    openfe_ligs_dict = {lig: (np.nan, np.nan) for lig in ana_obj.ligands}
    openfe_perts_dict = {pert: (np.nan, np.nan) for pert in ana_obj.perturbations}
else:
    df_main = pd.read_csv(
        "/home/anna/Documents/benchmark/inputs/other_computed/openfe/combined_pymbar3_edge_data.csv"
    )

    # Rows for this protein only. The filter does not depend on the repeat, so it
    # is done once; .copy() makes the new columns land on an independent frame
    # instead of on a slice of df_main.
    df_protein = df_main[df_main["system name"] == protein].copy()

    # Bring the OpenFE ligand names in line with the "lig_<name>" keys used in
    # this notebook. The substitutions are literal text, applied in this order.
    openfe_name_subs = [
        ("_redocked", ""),
        ("-charged-pKa-8.1", ""),
        ("-flip", ""),
        ("ejm_", "ejm"),
        ("jmc_", "jmc"),
    ]
    for lig_col, source_col in (("lig_0", "ligand_A"), ("lig_1", "ligand_B")):
        lig_names = df_protein[source_col]
        for old, new in openfe_name_subs:
            lig_names = lig_names.str.replace(old, new, regex=False)
        df_protein[lig_col] = "lig_" + lig_names

    # one results file per repeat
    for rep in [0, 1, 2]:
        complex_dg = df_protein[f"complex_repeat_{rep}_DG (kcal/mol)"]
        solvent_dg = df_protein[f"solvent_repeat_{rep}_DG (kcal/mol)"]

        # NOTE(review): "error" is the root sum of squares of the two DG values
        # themselves, not of their uncertainties. This mirrors the previous
        # behaviour but looks unintended - confirm which columns of the OpenFE
        # file hold the per-leg errors and use those here.
        dg_err_temp = complex_dg**2 + solvent_dg**2
        df = df_protein.assign(
            freenrg=complex_dg - solvent_dg,
            dG_err_temp=dg_err_temp,
            error=np.sqrt(dg_err_temp),
        )
        df.to_csv(
            f"/home/anna/Documents/benchmark/inputs/other_computed/openfe/{protein}_{rep}.csv",
            columns=["lig_0", "lig_1", "freenrg", "error"],
            index=False,
        )

    files = [
        f"/home/anna/Documents/benchmark/inputs/other_computed/openfe/{protein}_{rep}.csv"
        for rep in [0, 1, 2]
    ]

    calc_diff_dict = make_dict.comp_results(files)  # older method

    perts, ligs = get_info_network_from_dict(calc_diff_dict)
    exper_dict = ana_obj.exper_val_dict

    convert.cinnabar_file(
        files,
        exper_dict,
        f"/home/anna/Documents/benchmark/inputs/other_computed/openfe/cinnabar_{protein}",
        perturbations=perts,
        method=None,
    )

    # compute the per ligand for the network
    network = _wrangle.FEMap(
        f"/home/anna/Documents/benchmark/inputs/other_computed/openfe/cinnabar_{protein}.csv"
    )

    # for self plotting of per ligand
    openfe_ligs_dict = make_dict.from_cinnabar_network_node(network, "calc")
    openfe_perts_dict = make_dict.from_cinnabar_network_edges(network, "calc", perts)

In [None]:
# calculate perturbation statistics

stats_name = "MAE"
print(stats_name)

# pick the analysis method and the name of the same statistic for compute_stats
if stats_name == "MAE":
    func = ana_obj.calc_mae_engines
    cinn_stats_name = "MUE"

elif stats_name == "RMSE":
    func = ana_obj.calc_rmse_engines
    cinn_stats_name = "RMSE"
else:
    # stop here: carrying on would use an undefined func, or a stale one left
    # over from another cell
    raise ValueError(
        f"unsupported stats_name {stats_name!r}, expected 'MAE' or 'RMSE'"
    )

val_dict = {protein: {}}

# res is indexed as res[i][engine]["experimental"]; the three entries are stored
# together as one tuple per engine
res = func(pert_val="pert", recalculate=False)  # mae / rmse
for eng in ana_obj.engines:
    val_dict[protein][eng_dict_name[eng]] = (
        res[0][eng]["experimental"],
        res[1][eng]["experimental"],
        res[2][eng]["experimental"],
    )
# print(res)

# literature

exper_dict = ana_obj.exper_val_dict

for lit_perts_dict, name in zip(
    [openfe_perts_dict, hahn_perts_dict, fepplus_perts_dict],
    ["openfe", "hahn", "fepplus"],
):  #
    # keep the perturbations whose two ligands both have an experimental value
    perturbations = []
    for pert in lit_perts_dict:
        pert_ligs = pert.split("~")
        if pert_ligs[0] in exper_dict and pert_ligs[1] in exper_dict:
            perturbations.append(pert)
    incl = len(perturbations)
    excl = len(lit_perts_dict) - incl

    print(
        "only including perturbations that also have the same ligands used. Not necessarily the same perturbations."
    )
    # an empty literature dict would otherwise divide by zero
    percent_incl = incl / (incl + excl) * 100 if (incl + excl) else float("nan")
    print(
        f"{excl} perturbations excluded for {name}, {incl} included: {percent_incl}"
    )

    # additionally only if there are the same perturbations, check what this would be
    # use_perts = []
    # reverse_perts = []
    # for pert in ana_obj.perturbations:
    #     if pert in perturbations:
    #         use_perts.append(pert)
    #     if f"{pert.split('~')[1]}~{pert.split('~')[0]}" in perturbations:
    #         reverse_perts.append(pert)
    # print(f"{len(use_perts)+len(reverse_perts)} perturbations of these would be the same/reverse perturbations ({(len(use_perts)+len(reverse_perts))/(perturbations)*100} %)")
    # perturbations = flatten_comprehension([use_perts, reverse_perts])

    exper_pert_dict = make_dict.exper_from_perturbations(exper_dict, perturbations)

    # x: literature values, y: experimental values, each with its error
    x = [lit_perts_dict[pert][0] for pert in perturbations]
    xerr = [lit_perts_dict[pert][1] for pert in perturbations]
    y = [exper_pert_dict[pert][0] for pert in perturbations]
    yerr = [exper_pert_dict[pert][1] for pert in perturbations]

    # calculate statistics

    res = stats_engines.compute_stats(
        x=x, xerr=xerr, y=y, yerr=yerr, statistic=cinn_stats_name
    )
    # print("cinnabar", name, res)

    val_dict[protein][eng_dict_name[name]] = (
        res[0],
        res[1],
        res[2],
    )

df = pd.DataFrame(val_dict)
# df.to_markdown()
# DataFrame.map is the current name of the deprecated applymap and is what the
# plotting cells of this notebook already use
df = df.map(lambda x: f"{x[0]:.2f} ({x[2][0]:.2f},{x[2][1]:.2f})")
df

In [None]:
# One row per protein, one column per method; every cell holds a tuple whose
# first entry is plotted as the bar height and whose second entry is the error bar.
_stats_frame = pd.DataFrame(val_dict).T
val_res = _stats_frame.map(lambda entry: entry[0])
_val_err = _stats_frame.map(lambda entry: entry[1])

ax = val_res.plot.bar(
    color=col_dict,
    yerr=_val_err,
    xlabel="protein",
    ylabel=f"{stats_name} (kcal/mol)",
)
# legend centred above the axes
ax.legend(loc="lower center", bbox_to_anchor=(0.5, 1.0))

In [None]:
# correlating the number of perturbed atoms with precision (SEM) and accuracy (MAE)

pert_overlap_dict = {}

df = ana_obj.perturbing_atoms_and_overlap(read_file=True)

# read in all the lomap scores: the first two comma-separated fields of a line are
# the ligands, the last field is the score
score_dict = {}
# print(f"{main_dir}/execution_model/network_scores.dat")
with open(f"{main_dir}/execution_model/network_scores.dat") as lfile:
    for line in lfile:
        fields = [field.strip() for field in line.split(",")]
        score_dict[f"{fields[0]}~{fields[1]}"] = float(fields[-1])


def _score_for(perturbation):
    """Return the score of a perturbation, or NaN when it has none.

    The perturbation is looked up as written first and then with its two
    ligands swapped.
    """
    if perturbation in score_dict:
        return score_dict[perturbation]
    pert_ligs = perturbation.split("~")
    if len(pert_ligs) < 2:
        return np.nan
    return score_dict.get(f"{pert_ligs[1]}~{pert_ligs[0]}", np.nan)


# Keep only the perturbations of the current analysis object, selected in one
# pass instead of dropping rows one by one while iterating. .copy() so the score
# column is added to an independent frame.
df = df[df["perturbation"].isin(ana_obj.perturbations)].copy()
df["score"] = df["perturbation"].map(_score_for).astype(float)

pert_overlap_dict[protein] = df

In [None]:
# sorting the df below to check things
# NOTE: df_plot and its "Overlap > 0.03 (%)" column are only created by the next
# cell, so on a fresh kernel that cell has to be run before this one.
df_plot.sort_values(by="Overlap > 0.03 (%)")[:10]

In [None]:
# df_plot is an alias of df (not a copy), so the in-place rename below also
# changes df and the frame stored in pert_overlap_dict[protein].
df_plot = df  # df.dropna()
df_plot.rename(
    columns={
        "perturbing_atoms": "Average number of perturbing atoms",
        "diff_to_exp": "MAE (kcal/mol)",
        "percen_overlap_okay": "Overlap > 0.03 (%)",
        "error": "SEM (kcal/mol)",
    },
    inplace=True,
)

# columns used for the x axis, the y axis and the point colour
x = "Average number of perturbing atoms"
y = "MAE (kcal/mol)"
z = "SEM (kcal/mol)"

fig, ax = plt.subplots(figsize=(8, 6))
# sns.kdeplot(x=df_plot[x], y=df_plot[y], cmap="PuRd", fill=True, levels=10, thresh=0.05)
# Draw on the axes created above. Without ax=ax pandas opens a second figure and
# the (8, 6) one is left empty.
df_plot.plot.scatter(
    x,
    y,
    c=z,
    colormap="plasma",  # vmin=0, vmax=100
    ax=ax,
)

# one further scatter plot (each on its own figure) per engine
for eng in ana_obj.engines:
    df2 = df[df["engine"] == eng]
    df2.plot.scatter(
        x,
        y,
        c=z,
        colormap="plasma",
        title=eng_dict_name[eng],  # vmin=80, vmax=100
    )

In [None]:
# Look at one perturbation in detail: its value per engine, how far each engine is
# from experiment, and which networks contain it.
pert = "lig_ejm42~lig_ejm43"

val_dict = {"value": {}, "difference": {}}
for eng in ["SOMD", "AMBER", "GROMACS"]:
    _calc = ana_obj.calc_pert_dict[eng][pert]
    val_dict["value"][eng] = f"{_calc[0]:.2f} ({_calc[1]:.2f})"
    val_dict["difference"][eng] = f"{abs(_calc[0] - ana_obj.exper_pert_dict[pert][0]):.2f}"

# the experimental entry only exists in the "value" column
_exper = ana_obj.exper_pert_dict[pert]
val_dict["value"]["experimental"] = f"{_exper[0]:.2f} ({_exper[1]:.2f})"
ana_obj.draw_perturbations([pert])

# networks whose "plain" analysis includes this perturbation
nets = [
    net_key
    for net_key in network_dict.keys()
    if pert in network_dict[net_key]["plain"].perturbations
]
print(pert)
print(nets)

# spread (largest minus smallest value) across the engines of ana_obj
_engine_values = [
    ana_obj.calc_pert_dict[engine][pert][0] for engine in ana_obj.engines
]
print(f"difference: {max(_engine_values) - min(_engine_values):.2f}")

df = pd.DataFrame(val_dict)
df

In [None]:
# For every engine, list the five perturbations that are furthest from experiment.

for eng in ana_obj.engines:
    # absolute difference to experiment; perturbations with "Intermediate" in
    # their name are left out
    diff_dict = {
        pert_name: abs(
            ana_obj.calc_pert_dict[eng][pert_name][0]
            - ana_obj.exper_pert_dict[pert_name][0]
        )
        for pert_name in ana_obj.calc_pert_dict[eng]
        if "Intermediate" not in pert_name
    }

    print(eng)
    # ascending by difference, ties broken by name
    sorted_items = sorted(diff_dict.items(), key=lambda item: (item[1], item[0]))

    df = pd.DataFrame(sorted_items, columns=["perturbation", "mae"])
    df["results"] = df["perturbation"].map(
        lambda pert_name: ana_obj.calc_pert_dict[eng][pert_name]
    )
    df["experimental"] = df["perturbation"].map(
        lambda pert_name: ana_obj.exper_pert_dict[pert_name]
    )
    print(df.nlargest(5, "mae"))

    # ana_obj.draw_perturbations([sorted_items[-1][0], sorted_items[-2][0], sorted_items[-3][0]])

In [None]:
# Rank the perturbations by how much the three engines disagree with each other.

# perturbations that all three engines have a result for
keys = (
    ana_obj.calc_pert_dict["AMBER"].keys()
    & ana_obj.calc_pert_dict["SOMD"].keys()
    & ana_obj.calc_pert_dict["GROMACS"].keys()
)

_engine_labels = ("AMBER", "SOMD", "GROMACS")

# spread per perturbation: largest engine value minus smallest engine value
diffs = {}
for key in keys:
    _per_engine = [
        ana_obj.calc_pert_dict[label][key][0] for label in _engine_labels
    ]
    diffs[key] = max(_per_engine) - min(_per_engine)

# largest spread first
sorted_keys = sorted(diffs, key=lambda pert_key: diffs[pert_key], reverse=True)

for key in sorted_keys:
    print(key, diffs[key])
    # engine label followed by that engine's entry, then the experimental entry
    _report = []
    for label in _engine_labels:
        _report.extend([label, ana_obj.calc_pert_dict[label][key]])
    _report.extend(["experimental", ana_obj.exper_pert_dict[key]])
    print(*_report)

# ana_obj.draw_perturbations(sorted_keys[:5])
# ana_obj.draw_perturbations(sorted_keys[-5:])

In [None]:
# ligand analysis
# Point ana_obj at the "plain" analysis of the "combined" network; the cells that
# follow read from ana_obj.
ana_obj = network_dict["combined"]["plain"]

In [None]:
# ana_obj.remove_perturbations(["lig_CHEMBL3402754_40~lig_CHEMBL3402755_4200"])
# ana_obj.compute_results()
# for eng in ana_obj.engines:
#     ana_obj.disconnected_ligands(eng)

In [None]:
# for literature results
# calculate value statistics

# normalise exper dict: shift the experimental values so that their mean is zero
# (errors unchanged). This does not depend on the statistic, so it is done once
# rather than on every pass of the loop below.
exper_dict = ana_obj.exper_val_dict
avg = np.mean([val[0] for val in exper_dict.values()])
normalised_exper_dict = {
    lig: (exper_dict[lig][0] - avg, exper_dict[lig][1]) for lig in exper_dict
}

for stats_name, stats, func in zip(
    ["MAE (kcal/mol)", "Kendall's Rank", "R2"],
    ["MUE", "KTAU", "R2"],
    [
        ana_obj.calc_mae_engines,
        ana_obj.calc_kendalls_rank_engines,
        ana_obj.calc_r2_engines,
    ],
):
    print(stats_name)

    val_dict = {protein: {}}

    # res is indexed as res[i][engine]["experimental"]; the three entries are
    # stored together as one tuple per engine
    res = func(pert_val="val", recalculate=False)  # mae / kendalls_rank / r2
    for eng in ana_obj.engines:
        val_dict[protein][eng_dict_name[eng]] = (
            res[0][eng]["experimental"],
            res[1][eng]["experimental"],
            res[2][eng]["experimental"],
        )
        # print("cinnabar", eng, res[0][eng]["experimental"], res[1][eng]["experimental"], res[2][eng]["experimental"])

    # add the literature values, compared against the normalised experimental ones
    for lit_val_dict, name in zip(
        [openfe_ligs_dict, hahn_ligs_dict, fepplus_ligs_dict],
        ["openfe", "hahn", "fepplus"],
    ):  #
        x = []
        y = []
        xerr = []
        yerr = []
        for lig in ana_obj.ligands:
            # skip ligands the literature set does not have, or has only as NaN
            if lig in lit_val_dict and not np.isnan(lit_val_dict[lig][0]):
                x.append(lit_val_dict[lig][0])
                xerr.append(lit_val_dict[lig][1])
                y.append(normalised_exper_dict[lig][0])
                yerr.append(normalised_exper_dict[lig][1])

        # calculate statistics

        res = stats_engines.compute_stats(
            x=x, xerr=xerr, y=y, yerr=yerr, statistic=stats
        )  # MUE, KTAU, R2
        # print("cinnabar", name, res)

        val_dict[protein][eng_dict_name[name]] = (
            res[0],
            res[1],
            res[2],
        )

    # bar chart: first tuple entry as bar height, second as error bar
    val_res = pd.DataFrame(val_dict).T.map(lambda x: x[0])
    ax = val_res.plot.bar(
        color=col_dict,
        yerr=pd.DataFrame(val_dict).T.map(lambda x: x[1]),
        xlabel="protein",
        ylabel=f"{stats_name}",
    )
    ax.legend(
        loc="lower center",
        bbox_to_anchor=(0.5, 1.0),  # fancybox=True, shadow=True
    )

    # text table: "value (third-entry lower,third-entry upper)".
    # DataFrame.map replaces the deprecated applymap and is already used above.
    df = pd.DataFrame(val_dict)
    df = df.map(lambda x: f"{x[0]:.2f} ({x[2][0]:.2f},{x[2][1]:.2f})")
    print(df)

In [None]:
# plot the a/d optimal networks together w their counterpart

stats_name = "MAE"
net = "lomap"

plotting_dict = {}
for network in [f"{net}", f"{net}-a-optimal", f"{net}-d-optimal"]:
    plotting_dict[network] = {}
    aj = network_dict[network]["plain"]

    # statistic name -> method of this network's analysis object
    stats_funcs = {
        "MAE": aj.calc_mae_engines,
        "KTAU": aj.calc_kendalls_rank_engines,
        "R2": aj.calc_r2_engines,
    }
    if stats_name not in stats_funcs:
        # fail with a clear message instead of trying to call None
        raise ValueError(
            f"unsupported stats_name {stats_name!r}, expected one of {list(stats_funcs)}"
        )
    func = stats_funcs[stats_name]

    vals = func(pert_val="val", recalculate=False)

    for eng in aj.engines:
        # vals is indexed as vals[i][engine]["experimental"]
        data_point = (
            vals[0][eng]["experimental"],
            vals[1][eng]["experimental"],
            vals[2][eng]["experimental"],
        )

        # print(network, protein, eng, data_point)

        plotting_dict[network][eng] = data_point

# print(plotting_dict)

# rows: networks, columns: engines. First tuple entry is the bar height, second
# the error bar. DataFrame.map replaces the deprecated applymap.
df = pd.DataFrame(plotting_dict).map(lambda x: x[0]).T
df_err = pd.DataFrame(plotting_dict).map(lambda x: x[1]).T

# engine colours: one shade per network (row) for each engine
# NOTE(review): this rebinds the notebook-wide col_dict, which other plotting cells
# pass as color=col_dict - re-run the cell that defines it before re-running those.
col_dict = {
    "AMBER": ["orange", "moccasin", "oldlace"],
    "SOMD": ["darkturquoise", "paleturquoise", "azure"],
    "GROMACS": ["orchid", "plum", "pink"],
}
df.plot(
    kind="bar",
    color=col_dict,
    yerr=df_err,
    title=prot_dict_name[protein],
    xlabel="Network Method",
    ylabel=f"{stats_name}",
    legend=True,
)
# key = pipeline.analysis.set_colours()
# key.pop("experimental")
# pos.legend(key)

# text table of the same numbers
df = pd.DataFrame(plotting_dict)
df = df.map(lambda x: f"{x[0]:.2f} ({x[2][0]:.2f},{x[2][1]:.2f})")
df

In [None]:
# Call plot_scatter_dG on the current ana_obj with use_cinnabar=True.
# NOTE(review): presumably a computed-vs-experimental dG scatter built from the
# cinnabar values - confirm in pipeline.analysis.
ana_obj.plot_scatter_dG(use_cinnabar=True)

In [None]:
# Call plot_outliers once per engine, with pert_val="val" (the setting this
# notebook uses for its per-ligand statistics) and no_outliers=5.
for eng in ana_obj.engines:
    ana_obj.plot_outliers(engines=[eng], pert_val="val", no_outliers=5)

In [None]:
# ligand analysis
# Point ana_obj at the "plain" analysis of the "rbfenn" network; the cells that
# follow read from ana_obj.
ana_obj = network_dict["rbfenn"]["plain"]

In [None]:
# checking ligands from below
# Look at one ligand in detail: its value per engine, how far each engine is from
# the normalised experimental value, and the perturbations it takes part in.
val = "lig_2k"

val_dict = {}
val_dict["value"] = {}
val_dict["difference"] = {}

for eng in ["SOMD", "AMBER", "GROMACS"]:
    calc_val = ana_obj.cinnabar_calc_val_dict[eng][val]
    val_dict["value"][eng] = f"{calc_val[0]:.2f} ({calc_val[1]:.2f})"
    val_dict["difference"][
        eng
    ] = f"{abs(calc_val[0] - ana_obj.normalised_exper_val_dict[val][0]):.2f}"
    # Perturbations of this engine that have the ligand as one of their two
    # endpoints. Compare whole names: a substring test would also pick up ligands
    # whose name merely starts with this one.
    perts = [
        pert
        for pert in ana_obj._perturbations_dict[eng]
        if val in pert.split("~")
    ]
    print(eng, val, perts)

# row label spelt "experimental", as in the other tables of this notebook
val_dict["value"][
    "experimental"
] = f"{ana_obj.normalised_exper_val_dict[val][0]:.2f} ({ana_obj.normalised_exper_val_dict[val][1]:.2f})"
ana_obj.draw_ligands([val])

df = pd.DataFrame(val_dict)
df

In [None]:
# Compare one ligand across every network: which perturbations involve it and what
# value each engine gives for it.
# NOTE: `val` (the ligand name) is not set here; it has to be defined by an earlier
# cell before this one is run.
aj_dict = {}
val_dict = {}

for net in network_dict:
    aj = network_dict[net]["plain"]
    # Perturbations with the ligand as one of their two endpoints. Compare whole
    # names: a substring test would also pick up ligands whose name merely starts
    # with this one.
    perts = [pert for pert in aj.perturbations if val in pert.split("~")]
    # wrapped in a list so that the frame below has a single column (0) per network
    aj_dict[net] = [perts]

    val_dict[net] = {}
    for eng in ["SOMD", "AMBER", "GROMACS"]:
        calc_val = aj.cinnabar_calc_val_dict[eng][val]
        val_dict[net][eng] = f"{calc_val[0]:.2f} ({calc_val[1]:.2f})"
    exper_val = aj.normalised_exper_val_dict[val]
    val_dict[net]["experimental"] = f"{exper_val[0]:.2f} ({exper_val[1]:.2f})"

df_perts = pd.DataFrame(aj_dict).T
# unique perturbations over all networks, sorted so the plot order is the same on
# every run
plot_perts = sorted(set(flatten_comprehension(df_perts[0])))
ana_obj.plot_bar_ddG(values=plot_perts)

# show the full perturbation lists instead of truncated cells
pd.set_option("display.max_colwidth", None)
print(df_perts)

df = pd.DataFrame(val_dict)
df

In [None]:
# find the greatest difference to experimental
# For every engine, list the five ligands furthest from the normalised experimental
# value and the perturbations that the three worst ones take part in.

for eng in ana_obj.engines:
    # Absolute difference to experiment; ligands with "Intermediate" in their name
    # are left out. The comprehension keeps its variable local, so the ligand
    # stored in `val` by an earlier cell is not overwritten.
    diff_dict = {
        lig_name: abs(
            ana_obj.cinnabar_calc_val_dict[eng][lig_name][0]
            - ana_obj.normalised_exper_val_dict[lig_name][0]
        )
        for lig_name in ana_obj.cinnabar_calc_val_dict[eng]
        if "Intermediate" not in lig_name
    }

    print(eng)
    # ascending by difference, ties broken by name
    sorted_items = sorted(diff_dict.items(), key=lambda kv: (kv[1], kv[0]))

    df = pd.DataFrame(sorted_items, columns=["ligand", "mae"])
    df["results"] = df["ligand"].map(lambda x: ana_obj.cinnabar_calc_val_dict[eng][x])
    df["experimental"] = df["ligand"].map(
        lambda x: ana_obj.normalised_exper_val_dict[x]
    )
    print(df.nlargest(5, "mae"))

    # worst ligand first; the slice also copes with fewer than three ligands
    for lig, _ in reversed(sorted_items[-3:]):
        # Perturbations with the ligand as one of their two endpoints. Compare
        # whole names: a substring test would also pick up ligands whose name
        # merely starts with this one.
        perts = [
            pert
            for pert in ana_obj._perturbations_dict[eng]
            if lig in pert.split("~")
        ]
        print(lig, perts)

    # ana_obj.draw_ligands([sorted_items[-1][0], sorted_items[-2][0], sorted_items[-3][0]])

In [None]:
# ligand analysis
# Point ana_obj at the "plain" analysis of the "rbfenn" network for the cells that
# follow.
ana_obj = network_dict["rbfenn"]["plain"]

In [None]:
# the greatest difference between engines
# Rank the ligands by how much the three engines disagree, then list the
# perturbations of the three most and three least divergent ligands.

_engine_labels = ("AMBER", "SOMD", "GROMACS")

# Get shared keys
keys = (
    ana_obj.cinnabar_calc_val_dict["AMBER"].keys()
    & ana_obj.cinnabar_calc_val_dict["SOMD"].keys()
    & ana_obj.cinnabar_calc_val_dict["GROMACS"].keys()
)

# Compute max difference for each key (largest minus smallest engine value)
diffs = {}
for key in keys:
    per_engine = [
        ana_obj.cinnabar_calc_val_dict[label][key][0] for label in _engine_labels
    ]
    diffs[key] = max(per_engine) - min(per_engine)

# Sort keys by difference in descending order
sorted_keys = sorted(diffs, key=diffs.get, reverse=True)

for key in sorted_keys:
    print(key, diffs[key])
    report = []
    for label in _engine_labels:
        report.extend([label, ana_obj.cinnabar_calc_val_dict[label][key]])
    # the experimental entry is only looked up for ligands without "Intermediate"
    # in their name
    if "Intermediate" not in key:
        report.extend(["experimental", ana_obj.normalised_exper_val_dict[key]])
    print(*report)


def _print_ligand_perturbations(ligands):
    """Print, per ligand and engine, the perturbations the ligand is an endpoint of.

    Every engine is listed explicitly, so the output does not depend on a loop
    variable left over from another cell. Names are compared whole: a substring
    test would also pick up ligands whose name merely starts with this one.
    """
    for lig_name in ligands:
        for label in _engine_labels:
            lig_perts = [
                pert
                for pert in ana_obj._perturbations_dict[label]
                if lig_name in pert.split("~")
            ]
            print(lig_name, label, lig_perts)


# the three ligands the engines disagree on most
_print_ligand_perturbations(sorted_keys[:3])

# ana_obj.draw_ligands(sorted_keys[:3])

# the three ligands the engines agree on best
_print_ligand_perturbations(sorted_keys[-3:])

# ana_obj.draw_ligands(sorted_keys[-3:])

In [None]:
# Call plot_bar_ddG with its default arguments (no `values` selection is passed).
ana_obj.plot_bar_ddG()

In [None]:
# Call plot_bar_dG with its default arguments.
ana_obj.plot_bar_dG()

In [None]:
# investigating different network analysis methods:
# `network` names the entry of network_dict to use; its "plain" analysis becomes
# ana_obj for the cells that follow.
network = "combined"
ana_obj = network_dict[network]["plain"]

In [None]:
# get the cinnabar stats into a dict
# net_ana_method_dict holds four parallel lists; each engine/method pass appends
# one list per key, all of the same length (one entry per ligand value).
net_ana_method_dict = {"method": [], "engine": [], "protein": [], "value": []}

for eng in ["AMBER", "SOMD", "GROMACS"]:
    dg_list = []

    # absolute difference between the cinnabar computed and experimental values
    for key in ana_obj.cinnabar_calc_val_dict[eng].keys():
        value = abs(
            ana_obj.cinnabar_calc_val_dict[eng][key][0]
            - ana_obj.cinnabar_exper_val_dict[eng][key][0]
        )
        dg_list.append(value)
        if value > 5:
            print(protein, eng, key, value)

    net_ana_method_dict["method"].append(["cinnabar"] * len(dg_list))
    net_ana_method_dict["engine"].append([eng_dict_name[eng]] * len(dg_list))
    net_ana_method_dict["protein"].append([prot_dict_name[protein]] * len(dg_list))
    net_ana_method_dict["value"].append(list(dg_list))

# also want to compare fwf and cinnabar
fwf_path = (
    "/home/anna/Documents/september_2022_workshops/freenrgworkflows/networkanalysis"
)

for eng in ana_obj.engines:
    print(eng)
    dg_list = []

    # add path for fwf
    ana_obj._add_fwf_path(fwf_path)
    ana_obj._get_exp_fwf()

    # Best effort: a failing fwf analysis is reported and the engine then simply
    # contributes whatever values are available below.
    try:
        ana_obj._get_ana_fwf(engine=eng, use_repeat_files=True)
    except Exception as err:
        # the retry without repeat files is currently switched off (see below),
        # so the message only reports the failure
        print(f"{protein} {eng} did not fwf w repeat files: {err!r}")
        # try:
        #     fwf_dict = ana_obj._get_ana_fwf(engine=eng, use_repeat_files=False)
        # except:
        #     print("non repeat files also failed")

    try:
        # _fwf_computed_DGs[eng] is iterated as a sequence of dicts; the first
        # key/value pair of each is collected into one flat dict
        di2 = {}
        for di in ana_obj._fwf_computed_DGs[eng]:
            first_key = next(iter(di))
            di2[first_key] = di[first_key]
        # experimental computed normally outside of fwf and normalised
        for key in di2.keys():
            value = abs(di2[key] - ana_obj.normalised_exper_val_dict[key][0])
            dg_list.append(value)
            if value > 5:
                print(protein, eng, key, value)
    except Exception as err:
        print(f"did not fwf at all: {err!r}")

    net_ana_method_dict["method"].append(["fen"] * len(dg_list))
    net_ana_method_dict["engine"].append([eng_dict_name[eng]] * len(dg_list))
    net_ana_method_dict["protein"].append([prot_dict_name[protein]] * len(dg_list))
    net_ana_method_dict["value"].append(list(dg_list))

In [None]:
# Run analyse_mbarnet for SOMD only, with the compute / write / run / overwrite
# flags all switched on, then display the resulting _mbarnet_computed_DGs.
# NOTE(review): overwrite=True presumably regenerates existing MBARNet output -
# confirm before re-running on finished results.
ana_obj.analyse_mbarnet(
    compute_missing=True,
    write_xml=True,
    run_xml_py=True,
    use_experimental=True,
    overwrite=True,
    engines=["SOMD"],
    normalise=True,
)
ana_obj._mbarnet_computed_DGs

In [None]:
# mbarnet
# Add the MBARNet absolute errors to net_ana_method_dict, which has to exist
# already (it is created by the cinnabar / fwf cell).

# compute all first
# (nothing is recomputed, written or run here: all of those flags are off)
for eng in ana_obj.engines:
    try:
        ana_obj.analyse_mbarnet(
            compute_missing=False,
            write_xml=False,
            run_xml_py=False,
            use_experimental=True,
            overwrite=False,
            engines=[eng],
            normalise=True,
        )
    except Exception as e:
        print(e)
        print(f"failed for {eng}")

for eng in ana_obj.engines:
    dg_list = []
    print(eng)

    try:
        # absolute difference between the MBARNet value and the normalised
        # experimental value of each ligand
        for key in ana_obj._mbarnet_computed_DGs[eng].keys():
            value = abs(
                ana_obj._mbarnet_computed_DGs[eng][key][0]
                - ana_obj.normalised_exper_val_dict[key][0]
            )
            dg_list.append(value)
            if value > 5:
                print(eng, key, value)
    except Exception as e:
        # Still best effort (values gathered so far are kept), but say that the
        # engine is missing or incomplete instead of hiding it.
        print(f"could not collect MBARNet values for {eng}: {e!r}")

    net_ana_method_dict["method"].append(["MBARNet"] * len(dg_list))
    net_ana_method_dict["engine"].append([eng_dict_name[eng]] * len(dg_list))
    net_ana_method_dict["protein"].append([prot_dict_name[protein]] * len(dg_list))
    net_ana_method_dict["value"].append(list(dg_list))

In [None]:
# Compare the per-engine statistics obtained with the three network analysis
# methods (cinnabar, fen, mbarnet): one text table per method and one bar chart
# per statistic.

_ANALYSIS_METHODS = ("cinnabar", "fen", "mbarnet")
# placeholder (value, error, (ci lower, ci upper)) used when a statistic is missing
_NO_STATS = (0, 0, (0, 0))


def _merge_method_stats(stats_by_method, component):
    """Return an engine x method frame holding one entry of every stats tuple.

    stats_by_method maps method -> {protein: {engine: tuple}}; `component` is the
    index into each tuple. Each method's single protein column is renamed to the
    method name and the three frames are joined on their engine index.
    """
    frames = [
        pd.DataFrame(stats_by_method[method])
        .map(lambda x: x[component])
        .rename({protein: method}, axis=1)
        for method in _ANALYSIS_METHODS
    ]
    return reduce(
        lambda left, right: pd.merge(left, right, left_index=True, right_index=True),
        frames,
    )


for stats_name, stats, func in zip(
    ["MAE (kcal/mol)", "Kendall's Rank", "R2"],
    ["MUE", "KTAU", "R2"],
    [
        ana_obj.calc_mae_engines,
        ana_obj.calc_kendalls_rank_engines,
        ana_obj.calc_r2_engines,
    ],
):
    df_dict = {method: {protein: {}} for method in _ANALYSIS_METHODS}

    print("cinnabar")
    for eng in ana_obj.engines:
        try:
            df, df_err, df_ci = func(engines=[eng], pert_val="val", recalculate=False)
            df_dict["cinnabar"][protein][eng] = (
                df[eng]["experimental"],
                df_err[eng]["experimental"],
                df_ci[eng]["experimental"],
            )
            # print(stats_name, eng, df[eng]["experimental"], df_ci[eng]["experimental"])
        except Exception as err:
            # best effort: report the failure and plot a zero placeholder
            print(f"cinnabar {stats} failed for {eng}: {err!r}")
            df_dict["cinnabar"][protein][eng] = _NO_STATS

    print("fen")
    for eng in ana_obj.engines:
        try:
            df, df_err, df_ci = ana_obj._get_stats_fwf(engines=[eng], statistic=stats)
            df_dict["fen"][protein][eng] = (
                df[eng]["experimental"],
                df_err[eng]["experimental"],
                df_ci[eng]["experimental"],
            )
            # print(stats_name, eng, df[eng]["experimental"], df_ci[eng]["experimental"])
        except Exception as err:
            print("ooft", repr(err))
            df_dict["fen"][protein][eng] = _NO_STATS

    print("mbarnet")
    # how many perturbations went in and how many MBARNet values came out
    for eng in ana_obj.engines:
        try:
            print(
                protein,
                eng,
                len(ana_obj._perturbations_dict[eng]),
                len(ana_obj._mbarnet_computed_DGs[eng]),
            )
        except Exception as e:
            print(e)

    # the MBARNet statistics are currently switched off: every engine gets the
    # zero placeholder
    for eng in ana_obj.engines:
        # try:
        #     df, df_err, df_ci = ana_obj._get_stats_mbarnet(
        #         engines=[eng], statistic=stats
        #     )
        #     df_dict["mbarnet"][protein][eng] = (
        #         df[eng]["experimental"],
        #         df_err[eng]["experimental"],
        #         df_ci[eng]["experimental"]
        #     )
        #     print(stats_name, eng, df[eng]["experimental"], df_ci[eng]["experimental"])
        # except:
        #     print("oop")
        df_dict["mbarnet"][protein][eng] = _NO_STATS

    # text tables: "value (ci lower,ci upper)" per engine.
    # DataFrame.map replaces the deprecated applymap.
    print(stats_name)
    for method in _ANALYSIS_METHODS:
        print(method)
        df = pd.DataFrame(df_dict[method])
        df = df.map(lambda x: f"{x[0]:.2f} ({x[2][0]:.2f},{x[2][1]:.2f})")
        print(df)

    # compare stats
    # (no assignment to plt.xlim / plt.ylim here: that would replace the pyplot
    # functions with a tuple for the rest of the session)
    fig, ax = plt.subplots(figsize=(5, 5), sharex=True, sharey=True)

    df = _merge_method_stats(df_dict, 0)  # bar heights
    df_err = _merge_method_stats(df_dict, 1)  # error bars

    # df_lower = df_err.applymap(lambda x: x[0])
    # df_upper = df_err.applymap(lambda x: x[1])
    # df_err = (df_upper - df_lower)/2

    df.plot(
        kind="bar",
        color=["purple", "orchid", "lavender"],
        yerr=df_err,
        title=prot_dict_name[protein],
        ax=ax,
        xlabel="MD Engine",
        ylabel=f"{stats_name}",
    )

In [None]:
# Box plot of the absolute ligand errors, grouped by MD engine and coloured by
# network analysis method.

# plot column -> key of net_ana_method_dict it is flattened from
_column_sources = {
    "method": "method",
    "MD engine": "engine",
    "MAE dG (kcal/mol)": "value",
    "Protein": "protein",
}
plotting_dict = {
    column: flatten_comprehension(net_ana_method_dict[source])
    for column, source in _column_sources.items()
}

fig, ax = plt.subplots(figsize=(3.25, 3.25), dpi=500)
df = pd.DataFrame(plotting_dict)
sns.boxplot(
    df,
    x="MD engine",
    y="MAE dG (kcal/mol)",
    hue="method",
    palette=["purple", "orchid", "lavender"],
    ax=ax,
)
# modify individual font size of elements (ax is the current axes, so these are
# the same calls the pyplot shortcuts would make)
ax.legend(fontsize=10)
ax.set_xlabel("MD Engine", fontsize=10)
ax.set_ylabel("MAE ΔG (kcal/mol)", fontsize=10)
ax.tick_params(axis="both", which="major", labelsize=10)
# ax.set_ylim(top=12)

In [None]:
# comparing the different networks
# For each statistic: gather the per-engine value of every network, print the
# table as it grows, then draw one bar chart with a bar group per engine.
# NOTE: ana_obj is rebound inside the loop and is left pointing at the last
# network of network_dict once this cell has run.
df_dict = {}
for stats_name, stats in zip(
    ["MAE (kcal/mol)", "Kendall's Rank", "R2"], ["MUE", "KTAU", "R2"]
):
    df_dict[protein] = {}

    for net_name in network_dict.keys():
        print(net_name)

        ana_obj = network_dict[net_name]["plain"]
        # statistic name -> method of this network's analysis object
        func_dict = {
            "MUE": ana_obj.calc_mae_engines,
            "KTAU": ana_obj.calc_kendalls_rank_engines,
            "R2": ana_obj.calc_r2_engines,
        }
        func = func_dict[stats]

        df_dict[protein][net_name] = {}

        df, df_err, df_ci = func(pert_val="val", recalculate=False)

        for eng in ana_obj.engines:
            try:
                df_dict[protein][net_name][eng] = (
                    df[eng]["experimental"],
                    df_err[eng]["experimental"],
                    df_ci[eng]["experimental"],
                )
                # print(stats_name, eng, df[eng]["experimental"], df_ci[eng]["experimental"])
            except Exception as err:
                # best effort: report the missing entry and plot a zero placeholder
                print(f"no {stats} for {net_name} {eng}: {err!r}")
                df_dict[protein][net_name][eng] = (0, 0, (0, 0))

        # table of everything gathered so far: "value (ci lower,ci upper)".
        # DataFrame.map replaces the deprecated applymap.
        print(stats_name)
        print("cinnabar")
        df = pd.DataFrame(df_dict[protein])
        df = df.map(lambda x: f"{x[0]:.2f} ({x[2][0]:.2f},{x[2][1]:.2f})")
        print(df)

    # compare stats
    # (no assignment to plt.xlim / plt.ylim here: that would replace the pyplot
    # functions with a tuple for the rest of the session)
    df = pd.DataFrame(df_dict[protein]).map(lambda x: x[0])  # bar heights
    df_err = pd.DataFrame(df_dict[protein]).map(lambda x: x[1])  # error bars
    # df_lower = df_err.applymap(lambda x: x[0])
    # df_upper = df_err.applymap(lambda x: x[1])
    # df_err = (df_upper - df_lower)/2

    # NOTE(review): col_dict is not used by the plot below, but it rebinds the
    # notebook-wide col_dict that other plotting cells pass as color=col_dict.
    # Kept as it was; confirm whether anything further down relies on it.
    col_dict = {
        "AMBER": plt.get_cmap("autumn"),
        "SOMD": plt.get_cmap("cool"),
        "GROMACS": plt.get_cmap("viridis"),
    }

    # a single figure per statistic
    fig, ax = plt.subplots(figsize=(8, 6))

    df.plot(
        kind="bar",
        color=plt.cm.plasma(np.linspace(0, 1, len(df.columns))),
        yerr=df_err,
        title=prot_dict_name[protein],
        ax=ax,
        xlabel="Network generation method",
        ylabel=f"{stats_name}",
    )
    ax.legend(
        loc="center left",
        bbox_to_anchor=(1, 0.5),  # fancybox=True, shadow=True
    )