In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib
import seaborn as sns
import scipy.stats as stats
import pickle
import itertools
import random
import os
import ast
from library2_utils.color_scheme import cell_line_colors, cell_line_symbols
from library2_utils.transfer_functions import transfer_function, inverse_transfer
from library2_utils.mirna_combinations import get_combinations
from library2_utils.additive_model import add_mirna_expression
from library2_utils.design_utilities import tsi
from library2_utils.plotting_utilities import density_scatter, HandlerSize

# Global figure text settings: font size 7 and the Helvetica family.
plt.rcParams.update({'font.size': 7, 'font.family': 'Helvetica'})

# Cell lines with measurements: a core subset followed by the remaining lines.
cell_lines_subset = ["HEK293T", "HeLa", "SKNSH", "MCF7", "HUH7", "A549"]
cell_lines_rest = ["HaCaT", "JEG3", "Tera1", "PC3"]
cell_lines_measured = [*cell_lines_subset, *cell_lines_rest]

# "<cell line>_3UTR": columns that hold the measured values.
cell_lines_measured_UTR = [f"{cell_line}_3UTR" for cell_line in cell_lines_measured]
cell_lines_subset_UTR = [f"{cell_line}_3UTR" for cell_line in cell_lines_subset]

# "predicted_<cell line>": columns that hold the model predictions.
cell_lines_measured_pred = [f"predicted_{cell_line}" for cell_line in cell_lines_measured]
cell_lines_subset_pred = [f"predicted_{cell_line}" for cell_line in cell_lines_subset]

# "target_<cell line>": columns that hold the design targets.
cell_lines_all_target = [f"target_{cell_line}" for cell_line in cell_lines_measured]
cell_lines_subset_target = [f"target_{cell_line}" for cell_line in cell_lines_subset]

# Hyphenated cell line labels mapped to the hyphen-free names used above.
label_rename = {
    hyphenated: hyphenated.replace("-", "")
    for hyphenated in ("HUH-7", "JEG-3", "Tera-1", "SK-N-SH", "PC-3")
}

# miRBase table, indexed by its first column
mirbase = pd.read_csv("../microrna_data/mirbase_extended.csv", index_col=0)

# Folders for the figures and the outputs of this notebook.
# exist_ok=True makes the creation idempotent and removes the gap between
# checking for the folder and creating it.
base_plot_folder = "../plots/9_design_eval"
os.makedirs(base_plot_folder, exist_ok=True)
output_folder = "../outputs/9_design_eval"
os.makedirs(output_folder, exist_ok=True)

# Short names for the individual design files and for the merged design sets.
key_shorthand = {
    # quality designs, subset of cell lines
    "24_miRNA_full_subset_quality_AND4": "quality_subset_AND4",
    "25_miRNA_full_subset_quality_AND5": "quality_subset_AND5",
    "26_miRNA_full_subset_quality_AND6": "quality_subset_AND6",
    # quality designs, all cell lines
    "27_miRNA_full_quality_AND4": "quality_all_AND4",
    "28_miRNA_full_quality_AND5": "quality_all_AND5",
    "29_miRNA_full_quality_AND6": "quality_all_AND6",
    # mse designs, subset of cell lines
    "30_miRNA_AND4_subset_mse_designs": "mse_subset_AND4",
    "31_miRNA_AND5_subset_mse_designs": "mse_subset_AND5",
    "32_miRNA_AND6_subset_mse_designs": "mse_subset_AND6",
    # mse designs, all cell lines
    "33_miRNA_AND4_all_mse_designs": "mse_all_AND4",
    "34_miRNA_AND5_all_mse_designs": "mse_all_AND5",
    "35_miRNA_AND6_all_mse_designs": "mse_all_AND6",
    # merged sets keep their own name
    "subset_quality": "subset_quality",
    "full_quality": "full_quality",
    "subset_mse": "subset_mse",
    "all_mse": "all_mse",
}

# Figure title per run mode; the "_merged" variant of a mode shares its title.
_model_titles = {
    "9.1": "baseline model",
    "9.2": "updated model",
    "9.3": "inverted transfer function",
}
titles = {
    f"{prefix}{suffix}": title
    for prefix, title in _model_titles.items()
    for suffix in ("", "_merged")
}

# shared plot colors
main_colormap = "rocket"
box_color = "deepskyblue"

In [2]:
data_dir_input = "../measured_data/2_normalized_log10"

# List the input folder in sorted order. os.listdir returns entries in arbitrary
# order, and the insertion order of reference_dict carries over to every
# dictionary derived from it (some plots below label series by that order).
reference_files = sorted(os.listdir(data_dir_input))

# read every CSV into a dictionary keyed by the file name up to its first dot
reference_dict = {}
for reference_file in reference_files:
    if reference_file.endswith(".csv"):
        reference_dict[reference_file.split('.')[0]] = pd.read_csv(os.path.join(data_dir_input, reference_file), index_col=0)

# apply 10**x to the measured columns of every table
# (the folder name suggests the stored values are log10 - TODO confirm)
for key in reference_dict.keys():
    reference_dict[key][cell_lines_measured_UTR] = 10**reference_dict[key][cell_lines_measured_UTR]

In [3]:
# Split the reference tables into quality designs and mse designs. The "target"
# (and, for mse designs, "emphasis") columns are stored as strings and are
# parsed back into Python objects.
quality_designs = {}
mse_designs = {}
for key, reference_df in reference_dict.items():
    if "quality" in key:
        quality_df = reference_df.copy()
        quality_df["target"] = quality_df["target"].apply(ast.literal_eval)
        quality_designs[key] = quality_df
    if "mse_designs" in key:
        mse_df = reference_df.copy()
        for literal_column in ("target", "emphasis"):
            mse_df[literal_column] = mse_df[literal_column].apply(ast.literal_eval)
        mse_designs[key] = mse_df

In [4]:
# one dictionary with every design set; on a key clash the mse entry wins
all_designs = dict(quality_designs)
all_designs.update(mse_designs)

In [5]:
# For every design set, keep only the columns whose name starts with "miRNA".
expression_dfs = {}
for key, df in all_designs.items():
    is_mirna_column = df.columns.str.startswith("miRNA")
    expression_dfs[key] = df[df.columns[is_mirna_column]]

# 9.1 - Use original predictions

In [6]:
# Run mode "9.1": this prefix names the plot folder and selects the figure title below.
plot_folder_prefix = "9.1"

In [7]:
# false positive-aware merging: load the expression table that still contains
# potential crosstalk, and the filtered/scaled table with NaN rows removed
df_combined = pd.read_csv(
    "../microrna_data/3_output/Alles_Keller_combined_expression_with_crosstalk.csv",
    index_col=0,
)
df_crosstalk_filter = pd.read_csv(
    "../microrna_data/3_output/Alles_Keller_combined_expression_wo_crosstalk_scaled.csv",
    index_col=0,
)
df_crosstalk_filter = df_crosstalk_filter.dropna()

mirna_expression_with_crosstalk, mirna_expression_wo_crosstalk = df_combined, df_crosstalk_filter
# miRNAs that only appear in the unfiltered table
mirna_potential_crosstalk = mirna_expression_with_crosstalk.index.difference(mirna_expression_wo_crosstalk.index)

# expression table and dataset name used from here on
used_mirna_name = "combined_dataset"
mirna_expression = mirna_expression_wo_crosstalk

In [8]:
# Load the pickled `popt` values that are later unpacked into transfer_function.
popt_path = f"../outputs/3_fitting/{used_mirna_name}/{used_mirna_name}_popt_wo_crosstalk.pkl"
with open(popt_path, "rb") as popt_file:
    popt = pickle.load(popt_file)

# 9.2 - Recalculate using microRNA expression data

In [89]:
# Run mode "9.2": this prefix names the plot folder and selects the figure title below.
plot_folder_prefix = "9.2"

In [None]:
# Keep only the designs whose miRNAs all have complete expression data.
for key in all_designs.keys():
    print(key)
    df_expression = expression_dfs[key].copy()

    # Collect the failing rows first and drop them in a single call instead of
    # rebuilding the frame once per dropped row.
    rows_to_drop = []
    for index, row in df_expression.iterrows():
        mirnas = row
        if not all(mirnas.isin(mirna_expression.index)):
            # at least one miRNA of the design is missing from the expression table
            rows_to_drop.append(index)
        else:
            expression_vals = mirna_expression.loc[mirnas, cell_lines_measured]
            if expression_vals.isnull().values.any():
                # the expression is missing for at least one measured cell line
                rows_to_drop.append(index)

    df_expression = df_expression.drop(rows_to_drop)
    expression_dfs[key] = df_expression

In [91]:
# undo the log10 scaling of the expression table, then combine the miRNAs of each design
mirna_expression_lin = 10**mirna_expression
added_dfs = {
    key: add_mirna_expression(mirna_expression_lin, expression_df)
    for key, expression_df in expression_dfs.items()
}

# apply the transfer function to the added dfs
knockdown_from_added = {
    key: transfer_function(added_df, *popt)
    for key, added_df in added_dfs.items()
}

In [92]:
# prefix the columns of the recomputed knockdown so they match the prediction columns
for predicted_df in knockdown_from_added.values():
    predicted_df.columns = "predicted_" + predicted_df.columns

# replace predicted_{cell_line} in all_designs with the knockdown_from_added values
for key in all_designs.keys():
    pred_columns = cell_lines_subset_pred if "subset" in key else cell_lines_measured_pred
    all_designs[key][pred_columns] = knockdown_from_added[key][pred_columns]

    # designs without a recomputed prediction are left with NaNs and removed here
    all_designs[key] = all_designs[key].dropna()

# 9.3 - Recalculate using the measured knockdown

In [93]:
# Run mode "9.3": this prefix names the plot folder and selects the figure title below.
plot_folder_prefix = "9.3"

In [94]:
used_mirna_name = "combined_dataset"

# Load the pickled `popt` values used by transfer_function / inverse_transfer below.
popt_path = f"../outputs/3_fitting/{used_mirna_name}/{used_mirna_name}_popt.pkl"
with open(popt_path, "rb") as popt_file:
    popt = pickle.load(popt_file)

In [95]:
# Measured single-miRNA table: index by miRNA and keep only the "_3UTR" columns.
single_knockdown = (
    reference_dict["1_mirna_full_single_high_conf"]
    .set_index("miRNA1")
    .filter(regex="_3UTR")
)
# strip the "_3UTR" suffix so the columns are plain cell line names
single_knockdown.columns = single_knockdown.columns.str.replace("_3UTR", "")
# cap all values at 1
single_knockdown = single_knockdown.mask(single_knockdown > 1, 1)
# map the capped values through the inverse transfer function
mirna_expr_fr_knockdown = inverse_transfer(single_knockdown, *popt)

In [None]:
# Remove every design that uses a miRNA without an entry in mirna_expr_fr_knockdown.

meausured_mirnas = set(mirna_expr_fr_knockdown.index)
for key in expression_dfs.keys():
    df = expression_dfs[key]
    print(key, len(df))
    # all columns whose name starts with "miRNA"
    miRNA_columns = [col for col in df.columns if col.startswith("miRNA")]
    # every miRNA used anywhere in this design set
    mirna_list = set(
        itertools.chain.from_iterable(df[mirna_column].to_list() for mirna_column in miRNA_columns)
    )
    missing_mirnas = mirna_list.difference(meausured_mirnas)
    # designs that contain at least one of the missing miRNAs
    uses_missing_mirna = df[miRNA_columns].isin(missing_mirnas).any(axis=1)
    missing_indices = df.index[uses_missing_mirna]
    # drop those designs from the expression dataframe
    df = df.drop(missing_indices)
    expression_dfs[key] = df
    print(key, len(df))

In [None]:
# combine the knockdown-derived miRNA expression of each design
added_dfs = {
    key: add_mirna_expression(mirna_expr_fr_knockdown, expression_df)
    for key, expression_df in expression_dfs.items()
}

# apply the transfer function to the added dfs
knockdown_from_added = {
    key: transfer_function(added_df, *popt)
    for key, added_df in added_dfs.items()
}

# prefix the columns so they match the prediction columns
for predicted_df in knockdown_from_added.values():
    predicted_df.columns = "predicted_" + predicted_df.columns

# replace predicted_{cell_line} in all_designs with the knockdown_from_added values
for key in all_designs.keys():
    pred_columns = cell_lines_subset_pred if "subset" in key else cell_lines_measured_pred
    all_designs[key][pred_columns] = knockdown_from_added[key][pred_columns]

    # designs without a recomputed prediction are left with NaNs and removed here
    all_designs[key] = all_designs[key].dropna()

# 9.4 - Process design data for plotting

## 9.4.1 - Adjust for values larger than 1

In [None]:
dropped_designs = {}
for key in all_designs.keys():
    df = all_designs[key].copy()

    # measured columns that belong to this design set
    curr_cell_meas = cell_lines_subset_UTR if "subset" in key else cell_lines_measured_UTR

    # designs whose largest measured value exceeds 1.5 are removed entirely
    curr_max = df[curr_cell_meas].max(axis=1)
    mask = curr_max > 1.5
    print(key, mask.sum())
    dropped_designs[key] = df.index[mask]
    df = df.loc[~mask].copy()

    # cap the remaining values at 1
    # this is important so that we later don't punish stable designs with values over 1
    df[curr_cell_meas] = df[curr_cell_meas].clip(upper=1)

    all_designs[key] = df

In [None]:
# number of remaining designs per design set, and their total
total_designs = [len(design_df) for design_df in all_designs.values()]
print(sum(total_designs))

## 9.4.2 - Merge designs for different miRNA numbers

In [11]:
# Merged design set -> the AND4/AND5/AND6 design sets that are combined into it.
merge_dict = {
    "subset_quality": [
        "24_miRNA_full_subset_quality_AND4",
        "25_miRNA_full_subset_quality_AND5",
        "26_miRNA_full_subset_quality_AND6",
    ],
    "full_quality": [
        "27_miRNA_full_quality_AND4",
        "28_miRNA_full_quality_AND5",
        "29_miRNA_full_quality_AND6",
    ],
    "subset_mse": [
        "30_miRNA_AND4_subset_mse_designs",
        "31_miRNA_AND5_subset_mse_designs",
        "32_miRNA_AND6_subset_mse_designs",
    ],
    "all_mse": [
        "33_miRNA_AND4_all_mse_designs",
        "34_miRNA_AND5_all_mse_designs",
        "35_miRNA_AND6_all_mse_designs",
    ],
}

In [12]:
all_designs_merged = {}
for key in merge_dict:
    # stack the design sets listed for this merged set
    dfs = [all_designs[curr_key].copy() for curr_key in merge_dict[key]]
    curr_df = pd.concat(dfs)

    # also merge range_target, random_target_log, and range_target_log in to random_target
    if "type" in curr_df.columns:
        is_graduated_variant = curr_df["type"].isin(["range_target", "range_target_log", "random_target_log"])
        curr_df.loc[is_graduated_variant, "type"] = "random_target"

    all_designs_merged[key] = curr_df

all_designs = all_designs_merged
# Append the suffix only once: running this cell again must not produce
# "..._merged_merged", which has no entry in `titles`.
if not plot_folder_prefix.endswith("_merged"):
    plot_folder_prefix = plot_folder_prefix + "_merged"

### Option 1: Do not merge quality and mse designs

In [13]:
# Split the design sets by name; each entry is an independent copy.
mse_designs = {
    key: design_df.copy()
    for key, design_df in all_designs.items()
    if "mse" in key
}
quality_designs = {
    key: design_df.copy()
    for key, design_df in all_designs.items()
    if "quality" in key
}

In [None]:
# display which design sets ended up in mse_designs
mse_designs.keys()

### Option 2: Merge quality and mse designs

In [15]:
# Quality design set -> the mse design set it is merged into.
quality_mse_associations = {
    # individual design sets
    "24_miRNA_full_subset_quality_AND4": "30_miRNA_AND4_subset_mse_designs",
    "25_miRNA_full_subset_quality_AND5": "31_miRNA_AND5_subset_mse_designs",
    "26_miRNA_full_subset_quality_AND6": "32_miRNA_AND6_subset_mse_designs",
    "27_miRNA_full_quality_AND4": "33_miRNA_AND4_all_mse_designs",
    "28_miRNA_full_quality_AND5": "34_miRNA_AND5_all_mse_designs",
    "29_miRNA_full_quality_AND6": "35_miRNA_AND6_all_mse_designs",
    # merged design sets
    "subset_quality": "subset_mse",
    "full_quality": "all_mse",
}

def generate_emphasis_dict(target_dict, max_emphasis):
    """Build the emphasis weights that belong to a target dictionary.

    Every key whose target value equals 1 gets `max_emphasis`; every other key
    gets a weight of 1.
    """
    return {
        name: (max_emphasis if target == 1 else 1)
        for name, target in target_dict.items()
    }

In [16]:
def extract_target_cell_lines(cell_line_dict, target_value=0):
    """Return the cell lines whose value equals `target_value` as one string.

    Names that appear in `label_rename` are replaced by their renamed form.
    The names keep the iteration order of `cell_line_dict` and are joined
    with ", ".
    """
    return ", ".join(
        label_rename.get(cell_line, cell_line)
        for cell_line, value in cell_line_dict.items()
        if value == target_value
    )

def get_cell_line_to_row_mapping(df):
    """Map every cell line name to its position (0, 1, 2, ...).

    Parameters
    ----------
    df : pandas.DataFrame, list, tuple or pandas.Index
        Either a DataFrame whose columns are the cell lines, or a sequence of
        cell line names.

    Returns
    -------
    dict
        Cell line name -> position. For a name that occurs more than once the
        last position is kept. Any other input type yields an empty dict.
    """
    # isinstance (rather than an exact type comparison) also accepts subclasses,
    # and tuples / Index objects are handled like lists.
    if isinstance(df, pd.DataFrame):
        cell_lines = df.columns
    elif isinstance(df, (list, tuple, pd.Index)):
        cell_lines = df
    else:
        return {}
    return {cell_line: row for row, cell_line in enumerate(cell_lines)}

def get_figsize(key, design_type, best=False):
    """Return the figure size (width, height) for a design set and design type.

    The size depends on whether `key` contains "subset" or "all" ("subset" is
    checked first), on whether `design_type` contains "single", and on `best`.

    Raises
    ------
    ValueError
        If `key` contains neither "subset" nor "all".
    """
    if "subset" in key:
        scope = "subset"
    elif "all" in key:
        scope = "all"
    else:
        raise ValueError("The key and design type do not match.")

    is_single = "single" in design_type

    # (scope, is_single) -> figsize, once for the "best" layout and once for the default
    sizes_best = {
        ("subset", True): (2.0, 1.1),
        ("all", True): (2.4, 1.6),
        ("subset", False): (2.4, 1.1),
        ("all", False): (2.4, 1.6),
    }
    sizes_default = {
        ("subset", True): (2.2, 1.2),
        ("all", True): (2.8, 1.6),
        ("subset", False): (2.8, 1.4),
        ("all", False): (2.8, 1.6),
    }
    sizes = sizes_best if best else sizes_default
    return sizes[(scope, is_single)]

In [None]:
# Append every quality design set to its associated mse design set (as
# "single_active" designs) and re-sort the single_active rows by target cell line.
# NOTE(review): the result is written back into mse_designs, so running this cell
# twice appends the quality designs a second time.
quality_merged = True
for key in quality_designs.keys():
    df_quality = quality_designs[key].copy()
    df_mse = mse_designs[quality_mse_associations[key]].copy()
    print(key, len(df_mse))

    # give the quality designs the columns that the mse designs carry:
    # the emphasis maximum and the sublabel are taken from the first single_active mse design
    df_quality["type"] = "single_active"
    df_single_active = df_mse[df_mse["type"] == "single_active"]
    emphasis_max = max(df_single_active["emphasis"].iloc[0].values())
    df_quality["emphasis"] = df_quality["target"].apply(lambda x: generate_emphasis_dict(x, emphasis_max))
    df_quality["sublabel"] = df_single_active["sublabel"].iloc[0]

    # append df_quality to df_mse
    df_mse = pd.concat([df_mse, df_quality])

    # -------------- reorder the target cell lines after adding the quality df --------------
    # get the target cell line (the cell lines whose target value is 1, as one string)
    df_single_active = df_mse[df_mse["type"] == "single_active"].copy()
    df_single_active.loc[:, "extracted_target"] = df_single_active.loc[:, "target"].apply(lambda x: extract_target_cell_lines(x, 1))

    if "subset" in key:
        curr_cell_line_set = cell_lines_subset
    else:
        curr_cell_line_set = cell_lines_measured
    
    # reorder the index such that the order of the target cell lines is the same as curr_cell_line_set
    # create a mapping of cell line names to their index in curr_cell_line_set for sorting
    cell_line_order = {cell_line: idx for idx, cell_line in enumerate(curr_cell_line_set)}
    # add a new column to df_single_active for sorting
    # (an extracted_target that is not a single known cell line maps to NaN)
    df_single_active['sort_key'] = df_single_active['extracted_target'].map(cell_line_order)
    # sort df_single_active by this new sort_key
    df_single_active = df_single_active.sort_values(by='sort_key')
    # remove the 'sort_key' and 'extracted_target' columns
    df_single_active = df_single_active.drop(columns=['sort_key', 'extracted_target'])
    
    # drop the unmodified df_single_active from df_mse
    df_mse = df_mse[df_mse["type"] != "single_active"]
    # then, add it back (the sorted single_active rows end up at the bottom)
    df_mse = pd.concat([df_mse, df_single_active])

    mse_designs[quality_mse_associations[key]] = df_mse
    print(key, len(df_mse))

## 9.4.3 - microRNA usage statistics (Fig. S24)

In [20]:
# `which` is the file name prefix of the figures saved in this section;
# the statistics below are computed on a copy of the "all_mse" design set.
which = "all_cell_types"
full_mse_designs = mse_designs["all_mse"].copy()

In [21]:
# split the designs by the "AND4" / "AND5" / "AND6" tag in their index label
AND4, AND5, AND6 = (
    full_mse_designs[full_mse_designs.index.str.contains(f"AND{n_targets}")]
    for n_targets in (4, 5, 6)
)

AND_counts = [len(AND4), len(AND5), len(AND6)]

# one row per number of targets
AND_counts_df = pd.DataFrame({"count": AND_counts}, index=["4", "5", "6"])

In [None]:
# horizontal bar chart: number of designs per number of microRNA targets
fig, ax = plt.subplots(figsize=(1.5, 1))
AND_counts_df.plot.barh(color=box_color, ax=ax, legend=False)
ax.set(xlabel="count", ylabel="microRNA targets")
# hide the upper and right spines
for side in ('top', 'right'):
    ax.spines[side].set_visible(False)
plt.savefig(os.path.join(base_plot_folder, f"{which}_AND_counts.svg"), bbox_inches="tight")

In [23]:
# number of designs per design type
design_type_counts = full_mse_designs["type"].value_counts()

In [24]:
# display names of the design types
design_type_labels = {
    "single_active": "include 1",
    "single_target": "exclude 1",
    "random_target": "graduated",
    "double_active": "include 2",
    "double_target": "exclude 2",
    "three_target": "exclude 3",
}
design_type_counts = design_type_counts.rename(design_type_labels)

# total of all non-graduated designs
# NOTE: "binary" is not part of the selection below, so it is dropped again
design_type_counts["binary"] = design_type_counts.drop("graduated").sum()

# fixed display order, reversed for plotting
design_type_order = ["exclude 1", "exclude 2", "exclude 3", "include 1", "include 2", "graduated"]
design_type_counts = design_type_counts[design_type_order].iloc[::-1]

In [None]:
# horizontal bar chart: number of designs per design type
fig, ax = plt.subplots(figsize=(1.5, 1.5))
design_type_counts.plot.barh(color=box_color, ax=ax)
ax.set(xlabel="count", ylabel="")
# hide the upper and right spines
for side in ('top', 'right'):
    ax.spines[side].set_visible(False)
plt.savefig(os.path.join(base_plot_folder, f"{which}_design_type_counts.svg"), bbox_inches="tight")

In [26]:
# how often each miRNA is used: overall, in binary designs, and in graduated designs
mirna_columns = full_mse_designs.columns[full_mse_designs.columns.str.contains("miRNA")]
is_graduated = full_mse_designs["type"] == "random_target"
binary_designs = full_mse_designs[~is_graduated]
graduated_designs = full_mse_designs[is_graduated]
mirna_counts = full_mse_designs[mirna_columns].melt()["value"].value_counts()
mirna_counts_binary = binary_designs[mirna_columns].melt()["value"].value_counts()
mirna_counts_graduated = graduated_designs[mirna_columns].melt()["value"].value_counts()

In [None]:
# histogram of the log10 number of uses per microRNA
plt.figure(figsize=(2, 1.6))
log10_uses = np.log10(mirna_counts)
plt.hist(log10_uses, bins=20, color=box_color, alpha=1)
plt.xlabel(r"log$_{10}$(uses per microRNA)")
plt.ylabel("number of microRNAs")
plt.savefig(os.path.join(base_plot_folder, f"{which}_mirna_counts.svg"), bbox_inches="tight")

In [28]:
# raise every expression value below 2 to 2 (on a new frame; the source table is untouched)
mirna_expression_threshold = mirna_expression_with_crosstalk.mask(mirna_expression_with_crosstalk < 2, 2)

# keep only the miRNAs that mirbase lists with confidence "high"
mirbase_high = mirbase[mirbase["confidence"] == "high"]
is_high_confidence = mirna_expression_threshold.index.isin(mirbase_high.index)
mirna_expression_threshold = mirna_expression_threshold.loc[is_high_confidence]

In [None]:
# tsi score and maximum expression of every thresholded, high-confidence miRNA
mirna_expression_tsi = tsi(mirna_expression_threshold.to_numpy())
mirna_expression_tsi = pd.Series(mirna_expression_tsi, index=mirna_expression_threshold.index, name="tsi")
max_expression = mirna_expression_threshold.max(axis = 1)

# miRNAs with a tsi score that are used in designs / never used in designs
common_index = mirna_expression_tsi.index.intersection(mirna_counts.index)
uncommon_index = mirna_expression_tsi.index.difference(mirna_counts.index)
# add the uncommon index to mirna_counts with a placeholder count of 0.5,
# so that log10 of the count stays finite in the plots below
mirna_counts = mirna_counts.astype(float)
# Series.append was removed in pandas 2.0; pd.concat is the replacement
mirna_counts = pd.concat([mirna_counts, pd.Series(0.5, index=uncommon_index)])

total_index = mirna_expression_tsi.index.intersection(mirna_counts.index)

In [None]:
# uses per microRNA (log10) against the tsi score
plt.figure(figsize=(3, 1.8))
log10_uses = np.log10(mirna_counts)
for index_group, point_color, group_label in (
    (common_index, "tab:blue", "used in designs"),
    (uncommon_index, "tab:red", "not used in designs"),
):
    plt.scatter(mirna_expression_tsi[index_group], log10_uses[index_group], color=point_color, s=5, label=group_label)
# squared Pearson correlation over the miRNAs that are used in designs
r2 = stats.pearsonr(mirna_expression_tsi[common_index], log10_uses[common_index])[0]**2
plt.title(f"r$^2$ = {r2:.2f}")
plt.xlabel("miRNA tsi")
plt.ylabel(r"log$_{10}$(uses per microRNA)")
plt.legend(loc=[1.01, 0.5])
plt.tight_layout()
plt.savefig(os.path.join(base_plot_folder, f"{which}_mirna_counts_vs_tsi.svg"), bbox_inches="tight", dpi=600)

In [None]:
# uses per microRNA (log10) against the maximum expression across cell lines
plt.figure(figsize=(3, 1.8))
log10_uses = np.log10(mirna_counts)
for index_group, point_color, group_label in (
    (common_index, "tab:blue", "used in designs"),
    (uncommon_index, "tab:red", "not used in designs"),
):
    plt.scatter(max_expression[index_group], log10_uses[index_group], color=point_color, s=5, label=group_label)
# squared Pearson correlation over the miRNAs that are used in designs
r2 = stats.pearsonr(max_expression[common_index], log10_uses[common_index])[0]**2
plt.title(f"r$^2$ = {r2:.2f}")
plt.xlabel("maximum expression across cell lines")
plt.ylabel(r"log$_{10}$(uses per microRNA)")
plt.legend(loc=[1.01, 0.5])
plt.tight_layout()
plt.savefig(os.path.join(base_plot_folder, f"{which}_mirna_counts_vs_max.svg"), bbox_inches="tight", dpi=600)

In [None]:
# tsi score against the maximum expression, split by whether the miRNA is used in designs
plt.figure(figsize=(3, 1.6))
for index_group, point_color, group_label in (
    (common_index, "tab:blue", "used in designs"),
    (uncommon_index, "tab:red", "not used in designs"),
):
    plt.scatter(max_expression[index_group], mirna_expression_tsi[index_group],
                color=point_color, s=5, label=group_label, rasterized=True)
plt.xlabel("maximum expression across cell lines")
plt.ylabel("miRNA tsi")
plt.legend(loc=[1.01, 0.5])
plt.tight_layout()
plt.savefig(os.path.join(base_plot_folder, f"{which}_max_vs_tsi.svg"), bbox_inches="tight", dpi=600)

# 9.5 - General performance analysis

In [None]:
# Folder for the figures of the current run mode.
# exist_ok=True makes the creation idempotent without a separate existence check.
plot_folder = os.path.join(base_plot_folder, f"{plot_folder_prefix}_mse")
os.makedirs(plot_folder, exist_ok=True)
print(plot_folder)

In [19]:
def calculate_mse(df_measured, mse_target, loss_emphasis=None):
    """Mean squared error of every design (row) with respect to its target.

    Parameters
    ----------
    df_measured : DataFrame with one column per cell line and one row per design.
    mse_target : target values that are subtracted from `df_measured`.
    loss_emphasis : optional per-cell-line weights applied to the squared
        errors; every cell line is weighted with 1 when omitted.

    Returns
    -------
    Series with the row-wise mean of the weighted squared errors. The mean is
    taken over the cell lines and is not normalised by the sum of the weights.
    """
    weights = loss_emphasis
    if weights is None:
        weights = dict.fromkeys(df_measured.columns, 1)

    squared_error = (df_measured - mse_target).pow(2)
    return squared_error.mul(weights, axis=1).mean(axis=1)

def calculate_mse_log(df_measured, mse_target, loss_emphasis=None):
    """Mean squared error of every design (row) on a log10 scale.

    Values below 0.05 are raised to 0.05 in both the measurement and the
    target before the log10 is taken.

    Parameters
    ----------
    df_measured : DataFrame with one column per cell line and one row per design.
    mse_target : DataFrame, Series, dict or array with the target values.
        Labelled targets (DataFrame, Series, dict) are matched to
        `df_measured` by label; a plain array is matched by position.
    loss_emphasis : optional per-cell-line weights applied to the squared
        errors; every cell line is weighted with 1 when omitted.

    Returns
    -------
    Series with the row-wise mean of the weighted squared log10 errors.
    """
    if loss_emphasis is None:
        loss_emphasis = {cell_line: 1 for cell_line in df_measured.columns}

    floor = 0.05

    # set all values smaller than 0.05 to 0.05
    # NOTE: NaN measurements fail the >= comparison and are therefore floored as well
    df_measured = df_measured.where(df_measured >= floor, floor)

    if isinstance(mse_target, dict):
        mse_target = pd.Series(mse_target)
    if isinstance(mse_target, (pd.DataFrame, pd.Series)):
        # Keep the labels: np.where would return a bare array, and the
        # subtraction below would then pair columns by position, not by name.
        mse_target = mse_target.mask(mse_target < floor, floor)
    else:
        mse_target = np.where(mse_target < floor, floor, mse_target)

    df_measured = np.log10(df_measured)
    mse_target = np.log10(mse_target)

    log_mse = (df_measured - mse_target)**2
    log_mse = log_mse.mul(loss_emphasis, axis=1)
    log_mse = log_mse.mean(axis=1)

    return log_mse

def calculate_quality(weighted_mse):
    """Convert a (weighted) mean squared error into a quality score: 1 / mse + 1."""
    inverse_mse = 1 / weighted_mse
    return inverse_mse + 1

## 9.5.1 - Calculate quality statistics

At this point, stability values are linear.

In [None]:
# Per design set: arrays for the scatter plots below.
# *_x holds the statistic of the prediction, *_y the statistic of the measurement,
# both computed against the design target.
quality_x = {}
rmse_x = {}
rmse_log_x = {}
quality_y = {}
rmse_y = {}
rmse_log_y  = {}

for key in mse_designs.keys():
    print(key)
    df = mse_designs[key].copy()
    
    # measured and predicted values of the cell lines that belong to this design set
    if "subset" in key:
        df_measured = df[cell_lines_subset_UTR]
        df_predicted = df[cell_lines_subset_pred]
    else:
        df_measured = df[cell_lines_measured_UTR]
        df_predicted = df[cell_lines_measured_pred]
        
    # drop the "_3UTR" suffix / "predicted_" prefix so both frames use the plain cell line name
    df_measured.columns = df_measured.columns.str.split('_').str[0]
    df_predicted.columns = df_predicted.columns.str.split('_').str[1]

    # -----------------------------------------------------------------------
    # get the target and emphasis columns and calculate the mse to the target
    # (each cell of "target" / "emphasis" is a dict, expanded into one column per key)
    df_target = pd.DataFrame(df["target"].to_list())
    df_target.index = df.index

    # rename columns according to label_rename
    df_target.rename(columns=label_rename, inplace=True)

    df_emphasis = pd.DataFrame(df["emphasis"].to_list())
    df_emphasis.index = df.index

    # rename columns according to label_rename
    df_emphasis.rename(columns=label_rename, inplace=True)
    
    # the weighted mse between the measured and the target
    # this was used for the fitness function
    mse_weighted_measured = calculate_mse(df_measured, df_target, df_emphasis)
    # the non-weighted mse between the measured and the target
    # this is a better measure when comparing different designs types
    mse_measured = calculate_mse(df_measured, df_target)
    mse_meausured_log = calculate_mse_log(df_measured, df_target)
    
    # results are written into the stored frame (mse_designs[key]), not into the local copy `df`
    mse_designs[key]["measured_quality"] = calculate_quality(mse_weighted_measured)
    mse_designs[key]["measured_mse"] = mse_measured
    mse_designs[key]["measured_mse_log"] = mse_meausured_log
    
    # -----------------------------------------------------------------------
    # get the mse between the prediction and the target
    mse_weighted_predicted = calculate_mse(df_predicted, df_target, df_emphasis)
    mse_predicted = calculate_mse(df_predicted, df_target)
    mse_predicted_log = calculate_mse_log(df_predicted, df_target)
    
    mse_designs[key]["predicted_quality"] = calculate_quality(mse_weighted_predicted)
    mse_designs[key]["predicted_mse"] = mse_predicted
    mse_designs[key]["predicted_mse_log"] = mse_predicted_log
    
    # -----------------------------------------------------------------------
    # get the mse between the prediction and the measured (weighted with the emphasis)
    mse_accuracy = calculate_mse(df_predicted, df_measured, df_emphasis)
    quality_accuracy = calculate_quality(mse_accuracy)
    mse_designs[key]["accuracy_quality"] = quality_accuracy
    
    # -----------------------------------------------------------------------
    # store target values in the dataframe
    for column in df_target.columns:
        mse_designs[key]["target_" + column] = df_target[column]
        
    # make sure predicted and measured values are linear, not log
    # (this only checks that no value is negative)
    assert (df_predicted.min().min() >= 0) and (df_measured.min().min() >= 0)
        
    # store the difference between the measured and predicted and target results
    for column in df_measured.columns:
        mse_designs[key]["diff_to_pred_" + column] = df_measured[column] - df_predicted[column]
        mse_designs[key]["diff_to_target_" + column] = df_measured[column] - df_target[column]
    
    # -----------------------------------------------------------------------
    # use non-weighted mse for the quality comparison between different designs
    # (the plotted quality is log10(1/mse + 1))
    quality_x[key] = np.log10(1/mse_predicted.values + 1) 
    quality_y[key] = np.log10(1/mse_measured.values + 1)
    
    rmse_x[key] = np.sqrt(mse_predicted.values)
    rmse_y[key] = np.sqrt(mse_measured.values)
    
    rmse_log_x[key] = np.sqrt(mse_predicted_log.values)
    rmse_log_y[key] = np.sqrt(mse_meausured_log.values)

#### Predicted vs measured quality

In [110]:
%%capture output
# Predicted vs. measured design quality, one figure per design set.
# The "9.3" run gets a larger, fully labelled figure; other runs get a compact one without ticks.
for key in quality_x.keys():
    is_inverted_run = "9.3" in plot_folder_prefix

    plt.figure(figsize=(1.6, 1.3) if is_inverted_run else (0.7, 0.8))
    plt.scatter(quality_x[key], quality_y[key], s=2, alpha=1, edgecolor="none", rasterized=True, color="tab:blue")
    # identity line
    plt.plot([0, 3], [0, 3], color="black", linestyle="--", linewidth=1)

    r2 = np.corrcoef(quality_x[key], quality_y[key])[0,1]**2
    rmse = np.sqrt(np.mean((quality_x[key] - quality_y[key])**2))

    if is_inverted_run:
        plt.text(1.7, 0.15, f"rmsd = {rmse:.2f}", fontsize=6.5)
    else:
        plt.text(0.1, 2.3, f"{rmse:.2f}", fontsize=6.5)
    plt.title(titles[plot_folder_prefix], fontsize=7)
    plt.xlim([0, 3])
    plt.ylim([0, 3])
    if is_inverted_run:
        plt.xlabel("Predicted design quality")
        plt.ylabel("Measured design quality")
    else:
        plt.xticks([])
        plt.yticks([])
    plt.tight_layout()
    for file_format in ["png", "svg"]:
        plt.savefig(os.path.join(plot_folder, f"{plot_folder_prefix}.1_{key}_design_quality_scatter.{file_format}"), bbox_inches="tight", dpi=300)

In [None]:
# Measured (x) vs. predicted (y) design rmsd, one figure per design set.
# The "9.2" run gets a larger figure with axis labels.
for key in quality_x.keys():
    print(key)
    is_updated_run = "9.2" in plot_folder_prefix

    plt.figure(figsize=(1.75, 1.3) if is_updated_run else (1.0, 1.0))
    plt.scatter(rmse_y[key], rmse_x[key], s=1, alpha=1, edgecolor="none", rasterized=True, color="tab:blue")
    # identity line
    plt.plot([0, 0.8], [0, 0.8], color="black", linestyle="--", linewidth=1)

    r2 = np.corrcoef(rmse_x[key], rmse_y[key])[0,1]**2
    rmse = np.sqrt(np.mean((rmse_x[key] - rmse_y[key])**2))

    print(r2)

    annotation = f"rmsd = {rmse:.2f}" if is_updated_run else f"{rmse:.2f}"
    plt.text(0.05, 0.65, annotation, fontsize=6.5)
    plt.title(titles[plot_folder_prefix], fontsize=7)
    plt.xlim([0, 0.8])
    plt.ylim([0, 0.8])

    # the ticks are the same for every run mode
    plt.xticks([0, 0.4, 0.8])
    plt.yticks([0, 0.4, 0.8])
    if is_updated_run:
        plt.xlabel("Measured design rmsd")
        plt.ylabel("Predicted design rmsd")
    plt.tight_layout()
    for file_format in ["png", "svg"]:
        plt.savefig(os.path.join(plot_folder, f"{plot_folder_prefix}.1_{key}_mse_scatter.{file_format}"), bbox_inches="tight", dpi=600)

In [None]:
# Index label of the "all_mse" design with the largest predicted (non-weighted) mse.
# NOTE(review): mse_designs_updated is only assigned (or loaded from disk) further
# down in this notebook, so on a fresh kernel this cell raises a NameError unless
# those cells have been run first.
mse_designs_updated["all_mse"]["predicted_mse"].idxmax()

#### Does this change a lot with the number of targets?

In [83]:
# Histograms of the measured and of the predicted design quality, one series per
# non-subset design set. Only drawn when the design sets have not been merged.
# Colors and labels are assigned in the iteration order of quality_x.
colors = ["tab:blue", "tab:orange", "tab:green"]
labels = ["4 targets", "5 targets", "6 targets"]
if not 'merged' in plot_folder_prefix:
    for quality_values, quality_kind in ((quality_y, "measured"), (quality_x, "predicted")):
        plt.figure(figsize=(2, 1.5))
        i = 0
        for key in quality_x.keys():
            if "subset" in key:
                continue
            plt.hist(quality_values[key], bins=np.arange(0,3,0.1), color=colors[i], alpha=0.5,
                     label=f"{labels[i]}, median = {np.median(quality_values[key]):.2f}")
            i += 1
        plt.xlabel(f"{quality_kind} design quality")
        plt.ylabel("count")
        plt.tight_layout()
        plt.legend()
        for file_format in ["png", "svg"]:
            plt.savefig(os.path.join(plot_folder, f"{plot_folder_prefix}.1_quality_hist_{quality_kind}.{file_format}"), bbox_inches="tight", dpi=300)

### Save the calculated quality statistics

In [None]:
# Keep a copy of the current run's mse_designs under a model-specific name,
# chosen from the active plot_folder_prefix (plain or "_merged" variant).
if plot_folder_prefix in ("9.1", "9.1_merged"):
    mse_designs_original = mse_designs.copy()
    print("9.1")
elif plot_folder_prefix in ("9.2", "9.2_merged"):
    mse_designs_updated = mse_designs.copy()
    print("9.2")
elif plot_folder_prefix in ("9.3", "9.3_merged"):
    mse_designs_inverted = mse_designs.copy()
    print("9.3")

In [21]:
# Save the three per-model tables to disk, or load them back if this session has not
# produced all three.
# NOTE: this rebinds output_folder, which the first cell set to "../outputs/9_design_eval".
output_folder = "../outputs/9_designs"
# create it if it doesn't exist
if not os.path.exists(output_folder):
    os.makedirs(output_folder)

# At cell level locals() is the notebook namespace: write the pickles only when all three
# tables exist in this kernel session, otherwise fall back to the files from an earlier run.
if "mse_designs_original" in locals() and "mse_designs_updated" in locals() and "mse_designs_inverted" in locals():
    print("saving")
    with open(os.path.join(output_folder, "mse_designs_original.pkl"), "wb") as f:
        pickle.dump(mse_designs_original, f)
    with open(os.path.join(output_folder, "mse_designs_updated.pkl"), "wb") as f:
        pickle.dump(mse_designs_updated, f)
    with open(os.path.join(output_folder, "mse_designs_inverted.pkl"), "wb") as f:
        pickle.dump(mse_designs_inverted, f)
else:
    # NOTE(review): pickle.load executes arbitrary code from the file; only load pickles
    # written by this notebook. Raises FileNotFoundError if they were never saved.
    with open(os.path.join(output_folder, "mse_designs_original.pkl"), "rb") as f:
        mse_designs_original = pickle.load(f)
    with open(os.path.join(output_folder, "mse_designs_updated.pkl"), "rb") as f:
        mse_designs_updated = pickle.load(f)
    with open(os.path.join(output_folder, "mse_designs_inverted.pkl"), "rb") as f:
        mse_designs_inverted = pickle.load(f)

## 9.5.2 - Plot on a per design type basis

In [65]:
# For each model (prefix) build dictionaries keyed by "<table key>_<design type>" holding
# the predicted (x) and measured (y) quality, RMSD and log-RMSD series, and file them
# under the prefix in the *_by_model dictionaries used by the plotting cells below.
models = list(zip(["9.1_merged", "9.2_merged", "9.3_merged"],
                  [mse_designs_original, mse_designs_updated, mse_designs_inverted]))

# the output folder is the same for every model
plot_folder = os.path.join(base_plot_folder, "cross_model_per_design")
if not os.path.exists(plot_folder):
    os.makedirs(plot_folder)

# create the cross-model containers only if this session does not have them yet
if "quality_x_by_model" not in locals() or "rmse_x_by_model" not in locals() or "rmse_log_x_by_model" not in locals():
    quality_x_by_model = {}
    quality_y_by_model = {}
    rmse_x_by_model = {}
    rmse_y_by_model = {}
    rmse_log_x_by_model = {}
    rmse_log_y_by_model = {}

for prefix, curr_mse_designs in models:
    quality_x, quality_y = {}, {}
    rmse_x, rmse_y = {}, {}
    rmse_log_x, rmse_log_y = {}, {}

    for key in curr_mse_designs.keys():
        if "subset" in key:
            continue

        df = full_df = curr_mse_designs[key].copy()
        design_types = full_df["type"].unique()
        for design_type in design_types:
            df = full_df[full_df["type"] == design_type]
            tag = key + "_" + design_type

            # non-weighted design quality, log10(1/mse + 1), for easier comparison
            quality_predicted = np.log10(1/df["predicted_mse"]+1)
            quality_measured = np.log10(1/df["measured_mse"]+1)
            quality_x[tag] = quality_predicted
            quality_y[tag] = quality_measured

            rmse_x[tag] = np.sqrt(df["predicted_mse"])
            rmse_y[tag] = np.sqrt(df["measured_mse"])
            rmse_log_x[tag] = np.sqrt(df["predicted_mse_log"])
            rmse_log_y[tag] = np.sqrt(df["measured_mse_log"])

    # flatten them
    quality_x_flatten = list(itertools.chain.from_iterable(quality_x.values()))
    quality_y_flatten = list(itertools.chain.from_iterable(quality_y.values()))

    quality_x_by_model[prefix] = quality_x
    quality_y_by_model[prefix] = quality_y
    rmse_x_by_model[prefix] = rmse_x
    rmse_y_by_model[prefix] = rmse_y
    rmse_log_x_by_model[prefix] = rmse_log_x
    rmse_log_y_by_model[prefix] = rmse_log_y

In [57]:
%%capture output
# Scatter of predicted vs. measured log-scale RMSD for the binary design types,
# one figure per model. "random" (graduated) designs are skipped here.
prefixes = ["9.1_merged", "9.2_merged", "9.3_merged"]
colors = ["tab:blue", "tab:orange", "tab:green", "tab:red", "tab:purple", "tab:brown"]
markers = ["o", "s", "D", "v", "^", "P"]
# legend text per "<table key>_<design type>" entry
binary_design_titles = {
    "all_mse_double_active": "two active",
    "all_mse_double_target": "two inactive",
    "all_mse_single_active": "one active",
    "all_mse_single_target": "one inactive",
    "all_mse_three_target": "three inactive",
}

for prefix in prefixes:
    sc_plots = []
    # *_x holds the predicted values, *_y the measured ones
    quality_x_current = rmse_log_x_by_model[prefix]
    quality_y_current = rmse_log_y_by_model[prefix]
    x_flatten = []
    y_flatten = []

    plt.figure(figsize=(2, 1.7))
    for i, key in enumerate(quality_x_current.keys()):
        if "random" in key:
            continue

        # measured on the x-axis, predicted on the y-axis
        sc = plt.scatter(quality_y_current[key], quality_x_current[key], s=2, label=binary_design_titles[key],
                    color=colors[i], marker=markers[i], alpha=1, edgecolors="none")
        x_flatten += quality_x_current[key].tolist()
        y_flatten += quality_y_current[key].tolist()
        sc_plots.append(sc)

    # summary statistics over all plotted design types together
    x_flatten = np.array(x_flatten)
    y_flatten = np.array(y_flatten)
    r2 = np.corrcoef(x_flatten, y_flatten)[0,1]**2
    rmse = np.sqrt(np.mean((x_flatten - y_flatten)**2))

    plt.text(0.1, 1, r"rmsd = "+f"{rmse:.2f}")
    plt.text(0.1, 0.85, r"r$^2$ = "+f"{r2:.2f}")
    plt.xlim([0, 1.2])
    plt.ylim([0, 1.2])
    plt.plot([0, 1.2], [0, 1.2], color="black", linestyle="--", linewidth=1)
    # fixed: the closing parenthesis of plt.legend(...) was missing (SyntaxError)
    plt.legend(handler_map={sc: HandlerSize(12) for sc in sc_plots})
    plt.title(titles[prefix])
    # fixed: labels now match the plotted data (x = measured, y = predicted),
    # consistent with the linear-scale version of this figure
    plt.xlabel(r"Measured rmsd(log$_{10}$(stability))")
    plt.ylabel(r"Predicted rmsd(log$_{10}$(stability))")
    plt.tight_layout()
    for format in ["png", "svg"]:
        plt.savefig(os.path.join(plot_folder, f"{prefix}_binary_log_pred_vs_meas.{format}"), bbox_inches="tight", dpi=300)

In [54]:
%%capture output
# Linear-scale counterpart of the binary scatter: measured RMSD on x, predicted RMSD on y,
# one figure per model. Reuses `colors` and `markers` defined in the previous cell.
prefixes = ["9.1_merged", "9.2_merged", "9.3_merged"]

for prefix in prefixes:
    quality_x_current = rmse_x_by_model[prefix]
    quality_y_current = rmse_y_by_model[prefix]
    x_flatten = []
    y_flatten = []
    
    plt.figure(figsize=(2, 1.7))
    # i runs over all keys (including skipped "random" ones), so colors/markers are
    # tied to the key's position in the dictionary, not to the plot order
    for i, key in enumerate(quality_x_current.keys()):
        if "random" in key:
            continue
        
        plt.scatter(quality_y_current[key], quality_x_current[key], s=2, label=key, color=colors[i], marker=markers[i], alpha=1, edgecolors="none")
        x_flatten += quality_x_current[key].tolist()
        y_flatten += quality_y_current[key].tolist()
    
    # statistics are computed over all plotted design types pooled together
    x_flatten = np.array(x_flatten)
    y_flatten = np.array(y_flatten)
    r2 = np.corrcoef(x_flatten, y_flatten)[0,1]**2
    rmse = np.sqrt(np.mean((x_flatten - y_flatten)**2))
    # identity line for reference
    plt.plot([0, 0.8], [0, 0.8], color="black", linestyle="--", linewidth=1)
    
    plt.text(0.1, 0.7, r"rmsd = "+f"{rmse:.2f}")
    plt.text(0.1, 0.6, r"r$^2$ = "+f"{r2:.2f}")
    # plt.xlim([0.5, 3])
    # plt.ylim([0.5, 3])
    # plt.xticks([1, 2, 3])
    # plt.yticks([1, 2, 3])
    #plt.legend()
    plt.title(titles[prefix])
    plt.xlabel(r"Measured rmsd(stability)")
    plt.ylabel(r"Predicted rmsd(stability)")
    plt.tight_layout()
    for format in ["png", "svg"]:
        plt.savefig(os.path.join(plot_folder, f"{prefix}_binary_linear_pred_vs_meas.{format}"), bbox_inches="tight", dpi=300)

# ----------------------------------------------------------------------------------------------

In [77]:
%%capture output
# Density scatter of predicted vs. measured RMSD (linear scale) for the "random"
# (graduated) designs: one figure per model and per matching key.
prefixes = ["9.1_merged", "9.2_merged", "9.3_merged"]
for prefix in prefixes:
    # *_x holds the predicted values, *_y the measured ones
    quality_x_current = rmse_x_by_model[prefix]
    quality_y_current = rmse_y_by_model[prefix]
    for key in quality_x_current.keys():
        if not "random" in key:
            continue
        fig, ax = plt.subplots(figsize=(2, 1.5))

        # fixed: pass measured first and predicted second so the points match the axis
        # labels below (x = Measured, y = Predicted), as in the log10 version of this plot
        ax,sc = density_scatter(quality_y_current[key], quality_x_current[key],
            ax=ax, s=3, cmap='viridis', zorder=2)
        # identity line, drawn below the points
        plt.plot([-0.5, 1.5], [-0.5, 1.5], color="black", linestyle="--", linewidth=1.5, zorder=1)

        r2 = np.corrcoef(quality_y_current[key], quality_x_current[key])[0,1]**2
        rmse = np.sqrt(np.mean((quality_x_current[key] - quality_y_current[key])**2))
        plt.title(titles[prefix])

        plt.text(0.35, 0.15, r"rmsd = "+f"{rmse:.2f}", fontsize=8)
        plt.text(0.35, 0.05, r"r$^2$ = "+f"{r2:.2f}", fontsize=8)

        plt.xlim([-0.05, 0.65])
        plt.ylim([-0.05, 0.65])

        plt.ylabel(r"Predicted rmsd(stability)")
        plt.xlabel(r"Measured rmsd(stability)")

        plt.tight_layout()
        for format in ["png", "svg"]:
            plt.savefig(os.path.join(plot_folder, f"{prefix}_{key}_linear_rmsd_density.{format}"), bbox_inches="tight", dpi=300)
        # close each figure so the loop does not accumulate open figures
        plt.close()

In [84]:
%%capture output
# Density scatter of measured (x) vs. predicted (y) log10-scale RMSD for the "random"
# (graduated) designs, one figure per model and matching key.
prefixes = ["9.1_merged", "9.2_merged", "9.3_merged"]
for prefix in prefixes:
    quality_x_current = rmse_log_x_by_model[prefix]
    quality_y_current = rmse_log_y_by_model[prefix]
    for key in quality_x_current.keys():
        if not "random" in key:
            continue
        fig, ax = plt.subplots(figsize=(2.2, 1.9))
        
        ax,sc = density_scatter(quality_y_current[key], quality_x_current[key],
            ax=ax, s=3, cmap='viridis', zorder=2)
        #plt.plot([0.25, 3], [0.25, 3], color="black", linestyle="--", linewidth=1.5, zorder=1)
        # identity line, drawn below the points
        plt.plot([-0.5, 1.5], [-0.5, 1.5], color="black", linestyle="--", linewidth=1.5, zorder=1)

        r2 = np.corrcoef(quality_x_current[key], quality_y_current[key])[0,1]**2
        rmse = np.sqrt(np.mean((quality_x_current[key] - quality_y_current[key])**2))
        plt.title(titles[prefix])
        
        plt.text(0.05, 0.85, r"rmsd = "+f"{rmse:.2f}", fontsize=8)
        plt.text(0.05, 0.73, r"r$^2$ = "+f"{r2:.2f}", fontsize=8)

        plt.xlim([0, 1])
        plt.ylim([0, 1])
        # plt.xticks([1, 2, 3])
        # plt.yticks([1, 2, 3])

        plt.ylabel(r"Predicted rmsd(log$_{10}$(stability))")
        plt.xlabel(r"Measured rmsd(log$_{10}$(stability))")
        
        # Add colorbar
        # cbar = plt.colorbar(sc)
        # cbar.set_label('Density')
        plt.tight_layout()
        # NOTE(review): the file name does not contain `key`; if more than one key contains
        # "random", later figures overwrite earlier ones — confirm only one such key exists.
        for format in ["png", "svg"]:
            plt.savefig(os.path.join(plot_folder, f"{prefix}_random_log10_rmsd_density.{format}"), bbox_inches="tight", dpi=300)
        plt.close()

# ----------------------------------------------------------------------------------------------

In [34]:
# Select the "updated" model tables for the following cells and make sure its plot folder exists.
plot_folder_prefix = "9.2_merged"
curr_mse_designs = mse_designs_updated.copy()
plot_folder = os.path.join(base_plot_folder, f"{plot_folder_prefix}_mse")
os.makedirs(plot_folder, exist_ok=True)

In [None]:
# Violin plot of the predicted design quality per design type.
df = curr_mse_designs["all_mse"].copy()
df["predicted_quality"] = np.log10(df["predicted_quality"]+1)
# NOTE: "predicted_mse" is overwritten with log10(1/mse + 1); from here on the column holds
# the quality score that is plotted below, not an MSE.
df["predicted_mse"] = np.log10(1/df["predicted_mse"]+1)

# plotting order of the values in the "type" column and the tick label shown for each
order = ["single_target", "double_target", "three_target", "single_active", "double_active", "random_target"]
labels = ["one\ninactive", "two\ninactive", "three\ninactive", "one\nactive", "two\nactive", "graduated"]

# violin plot of the per-design quality distribution for each design type
plt.figure(figsize=(2.8, 1.3))
sns.violinplot(data=df, x="type", y="predicted_mse", density_norm="width", inner="quartile", order=order, width=0.6, linewidth=1.5, color="skyblue",
               linecolor="black")
plt.ylabel("Predicted design quality")
plt.xlabel("Design type", rotation=0)
# replace the raw type names with the display labels (not rotated)
plt.xticks(ticks=np.arange(len(order)), labels=labels, rotation=0)
for format in ["png", "svg"]:
    plt.savefig(os.path.join(plot_folder, f"{plot_folder_prefix}.5.2_predicted_quality_violin.{format}"), bbox_inches="tight", dpi=300)

In [None]:
# Violin plot of the measured design quality per design type.
df = curr_mse_designs["all_mse"].copy()
df["measured_quality"] = np.log10(df["measured_quality"]+1)
# NOTE: "measured_mse" is overwritten with log10(1/mse + 1); from here on the column holds
# the quality score that is plotted below, not an MSE.
df["measured_mse"] = np.log10(1/df["measured_mse"]+1)

# plotting order of the values in the "type" column and the tick label shown for each
order = ["single_target", "double_target", "three_target", "single_active", "double_active", "random_target"]
labels = ["one\ninactive", "two\ninactive", "three\ninactive", "one\nactive", "two\nactive", "graduated"]

# violin plot of the per-design quality distribution for each design type
plt.figure(figsize=(2.8, 1.3))
sns.violinplot(data=df, x="type", y="measured_mse", density_norm="width", inner="quartile", order=order,
               width=0.6, linewidth=1.5, color="skyblue",
               linecolor="black")
plt.ylabel("Measured design quality")
plt.xlabel("Design type")
# replace the raw type names with the display labels (not rotated)
plt.xticks(ticks=np.arange(len(order)), labels=labels, rotation=0)
for format in ["png", "svg"]:
    plt.savefig(os.path.join(plot_folder, f"{plot_folder_prefix}.5.2_measured_quality_violin.{format}"), bbox_inches="tight", dpi=300)

In [37]:
# Table used by the boxplot cells below: the two *_mse columns are replaced by their
# square root (RMSD), and "target_no" records which of AND4 / AND5 / AND6 appears in the
# design's index label (0 when none does).
df = curr_mse_designs["all_mse"].copy()
for column in ("measured_mse", "predicted_mse"):
    df[column] = np.sqrt(df[column])

df["target_no"] = 0
for n_targets in (4, 5, 6):
    df.loc[df.index.str.contains(f"AND{n_targets}"), "target_no"] = n_targets

# values of the "type" column, and the panel title shown for each
design_types = ["single_target", "double_target", "three_target", "single_active", "double_active", "random_target"]
labels = ["one inactive", "two inactive", "three inactive", "one active", "two active", "graduated"]

In [None]:
def add_pvalue_annotations(ax, data, x, y):
    """Draw a bracket with a Mann-Whitney U p-value for every pair of groups.

    Parameters
    ----------
    ax : object providing ``get_ylim``, ``plot`` and ``text`` (a matplotlib Axes)
    data : DataFrame holding the grouping column ``x`` and the value column ``y``
    x : name of the grouping column; groups are taken in sorted order and placed at
        positions 0, 1, 2, ...
    y : name of the value column compared between groups
    """
    levels = sorted(data[x].unique())
    for (pos_a, level_a), (pos_b, level_b) in itertools.combinations(enumerate(levels), 2):
        values_a = data.loc[data[x] == level_a, y]
        values_b = data.loc[data[x] == level_b, y]
        _, p_value = stats.mannwhitneyu(values_a, values_b)

        # bracket sits 5% above the larger of the data maximum and the current axis top
        y_max = max(data[y].max(), ax.get_ylim()[1])
        h = y_max * 1.05
        top = h+0.05*y_max
        ax.plot([pos_a, pos_a, pos_b, pos_b], [h, top, top, h], lw=1.5, c='k')
        ax.text((pos_a+pos_b)*.5, top, f'p={p_value:.2e}', ha='center', va='bottom')
            
def add_significance_annotations(ax, data, x, y, y_max_1=0.9, y_max_2=0.8):
    """Draw a horizontal bar with a significance symbol for every pair of groups.

    Each pair of groups in column ``x`` is compared on column ``y`` with a Mann-Whitney U
    test; the p-value is shown as '***' (< 0.001), '**' (< 0.01), '*' (< 0.05) or 'ns'.

    Parameters
    ----------
    ax : object providing ``plot`` and ``text`` (a matplotlib Axes)
    data : DataFrame holding the grouping column ``x`` and the value column ``y``
    x, y : column names of the grouping and the compared values
    y_max_1 : base height for the bar of the (4, 6) pair
    y_max_2 : base height for the bars of all other pairs
        The bar is drawn at 1.08 times the base height.
    """
    thresholds = ((0.001, '***'), (0.01, '**'), (0.05, '*'))

    def _symbol_for(p_value):
        # first cutoff the p-value falls below, otherwise "not significant"
        return next((symbol for cutoff, symbol in thresholds if p_value < cutoff), 'ns')

    levels = sorted(data[x].unique())
    for pos_a, level_a in enumerate(levels):
        for pos_b in range(pos_a + 1, len(levels)):
            level_b = levels[pos_b]
            values_a = data.loc[data[x] == level_a, y]
            values_b = data.loc[data[x] == level_b, y]
            _, p_value = stats.mannwhitneyu(values_a, values_b)

            # the 4-vs-6 comparison gets its own height so its bar clears the other two
            if level_a == 4 and level_b == 6:
                base = y_max_1
            else:
                base = y_max_2
            h = base * 1.08
            ax.plot([pos_a+0.2, pos_b-0.2], [h, h], lw=1.5, c='k')
            ax.text((pos_a+pos_b)*.5, h, _symbol_for(p_value), ha='center', va='bottom')

# One boxplot panel per design type: measured RMSD grouped by target number,
# annotated with pairwise significance symbols.
fig, axes = plt.subplots(1, 6, figsize=(5, 1.9))
axes = axes.flatten()
outlier_style = dict(marker='o', markersize=3, markerfacecolor='black', linestyle='none', markeredgecolor='none')

for i, (design_type, ax) in enumerate(zip(design_types, axes)):
    df_design_type = df[df["type"] == design_type]
    sns.boxplot(data=df_design_type, x="target_no", y="measured_mse",
                ax=ax, width=0.6, linewidth=1.5, color="skyblue",
                linecolor="black",
                flierprops=dict(outlier_style))

    ax.set_title(labels[i], fontsize=8)
    ax.set_xlabel("Target number")

    add_significance_annotations(ax, df_design_type, "target_no", "measured_mse")
    ax.set_ylim([0, 1.1])

    # only the leftmost panel keeps its y label and ticks
    if i == 0:
        ax.set_ylabel("Measured design RMSD")
    else:
        ax.set_ylabel("")
        ax.set_yticks([])

plt.tight_layout()

for format in ["png", "svg"]:
    plt.savefig(os.path.join(plot_folder, f"{plot_folder_prefix}.2_targets_quality_all.{format}"), bbox_inches="tight", dpi=300)

In [None]:
# One boxplot panel per design type: predicted RMSD grouped by target number,
# annotated with pairwise significance symbols.
fig, axes = plt.subplots(1, 6, figsize=(5, 1.9))
axes = axes.flatten()
outlier_style = dict(marker='o', markersize=3, markerfacecolor='black', linestyle='none', markeredgecolor='none')

for i, (design_type, ax) in enumerate(zip(design_types, axes)):
    df_design_type = df[df["type"] == design_type]
    sns.boxplot(data=df_design_type, x="target_no", y="predicted_mse",
                ax=ax, width=0.6, linewidth=1.5, color="skyblue",
                linecolor="black",
                flierprops=dict(outlier_style))

    ax.set_title(labels[i], fontsize=8)
    ax.set_xlabel("Target number")

    add_significance_annotations(ax, df_design_type, "target_no", "predicted_mse")
    ax.set_ylim([0, 1.1])

    # only the leftmost panel keeps its y label and ticks
    if i == 0:
        ax.set_ylabel("Predicted design RMSD")
    else:
        ax.set_ylabel("")
        ax.set_yticks([])

plt.tight_layout()

for format in ["png", "svg"]:
    plt.savefig(os.path.join(plot_folder, f"{plot_folder_prefix}.2_targets_quality_all_predicted.{format}"), bbox_inches="tight", dpi=300)

In [None]:
# Violin plot of the per-cell-line prediction error (measured - predicted) over all designs
# in the "all_mse" table.
df =  curr_mse_designs["all_mse"].copy()

# column names of the measured values, the predicted values, and the bare cell line names
curr_cell_meas = cell_lines_measured_UTR
curr_cell_pred = cell_lines_measured_pred
curr_cell_set = cell_lines_measured

# element-wise difference; .values is used so the differently named measured/predicted
# columns are paired by position rather than by label
diff_df = pd.DataFrame(columns=curr_cell_set)
diff_df[curr_cell_set] = df[curr_cell_meas].values - df[curr_cell_pred].values
diff_df.index = df.index

plt.figure(figsize=(2.8, 1.65))

# violin plot
# NOTE(review): colors are taken from cell_line_colors.values() in dictionary order —
# confirm that order matches the columns of diff_df.
sns.violinplot(diff_df, palette=cell_line_colors.values(), inner="quart", linewidth=0.75, width=0.7, zorder=2)
# make it a boxplot instead
#sns.boxplot(data=diff_df, palette=cell_line_colors.values(), linewidth=1.5, width=0.6, flierprops=dict(marker='o', markersize=3, markerfacecolor='black', linestyle='none', markeredgecolor='none'))

# horizontal reference line at y = 0 (no prediction error), drawn below the violins
plt.axhline(y=0, color='darkgrey', linestyle='-', linewidth=1.5, label="prediction goal", zorder=1)

plt.ylabel("measured - predicted")
plt.tick_params(axis="x", rotation=45)
plt.ylim(-1, 1)
plt.yticks([-1, -0.5, 0, 0.5, 1])

plt.tight_layout()
plt.legend()
for format in ["png", "svg"]:
    plt.savefig(os.path.join(plot_folder, f"{plot_folder_prefix}_predict_violin.{format}"), dpi=300)

### Calculate the fraction of values with |measured - predicted| < 0.2

In [None]:
# Percentage of all (design, cell line) differences whose absolute value is below 0.2.
diff_df_flatten = diff_df.values.flatten()
100 * np.count_nonzero(np.abs(diff_df_flatten) < 0.2) / len(diff_df_flatten)

# 9.6 - Analyze linear design stability for graduated designs

#### Get the data

In [31]:
# Use the "updated" model tables and its plot folder for this section.
# Unlike the earlier setup cell, the folder is not created here.
curr_mse_designs = mse_designs_updated.copy()
plot_folder_prefix = "9.2_merged"
plot_folder = os.path.join(base_plot_folder, f"{plot_folder_prefix}_mse")

In [None]:
# keep only the "random_target" designs of the "all_mse" table and report how many there are
df =  curr_mse_designs["all_mse"].copy()
df = df[df["type"] == "random_target"]
print(len(df))

In [33]:
# Split the filtered table into predicted / measured frames with bare cell line names as
# columns, and compute the per-design mean absolute prediction error.
curr_cell_meas = cell_lines_measured_UTR
curr_cell_pred = cell_lines_measured_pred
curr_cell_set = cell_lines_measured

pred_df = df[curr_cell_pred]
# "predicted_<cell line>" -> "<cell line>"
pred_df.columns = pred_df.columns.str.split('_').str[1]
meas_df = df[curr_cell_meas]
# "<cell line>_3UTR" -> "<cell line>"
meas_df.columns = meas_df.columns.str.split('_').str[0]
# measured - predicted, paired by column position; the index of df is not carried over here
diff_df = pd.DataFrame(columns=curr_cell_set)
diff_df[curr_cell_set] = df[curr_cell_meas].values - df[curr_cell_pred].values
# mean over cell lines of |measured - predicted|, one value per design
diff_df_mean_abs = abs(diff_df).mean(axis=1)

#### Mean prediction error for stability

In [None]:
# Histogram of the per-design mean absolute prediction error, with the 0.2 threshold marked.
plt.figure(figsize=(2.2, 1.7))
plt.hist(diff_df_mean_abs, bins=np.arange(0, 1.1, 0.025), color="tab:blue", alpha=1)
# plot a vertical line at 0.2
plt.axvline(x=0.2, color='black', linestyle='-', linewidth=1.5, label="design failure")
# what fraction are below 0.2?
# NOTE: the text positions (y = 140 and y = 110) are in data coordinates (counts), so they
# only land inside the axes if the histogram is tall enough.
plt.text(0.25, 140,
         "below 0.2: "+ \
         f"{100*len(diff_df_mean_abs[diff_df_mean_abs < 0.2])/len(diff_df_mean_abs):.0f}%", fontsize=8)
plt.text(0.25, 110,
         "median: "+ \
         f"{diff_df_mean_abs.median():.2f}", fontsize=8)
plt.title(titles[plot_folder_prefix], fontsize=8)
# bins extend to about 1.1 but only the range up to 0.5 is shown
plt.xlim([0, 0.5])
plt.xlabel("abs(measured-predicted)")
plt.ylabel("count")
plt.tight_layout()
for format in ["png", "svg"]:
    plt.savefig(os.path.join(plot_folder, f"{plot_folder_prefix}_predict_hist.{format}"), dpi=300)

#### What fraction lies within a certain margin and how does this compare with random guessing?

In [None]:
# Flatten the predicted / measured / difference tables to 1-D arrays (one entry per design
# and cell line); these arrays and `margin` are reused by the scatter plots below.
pred_df_flatten = pred_df.values.flatten()
meas_df_flatten = meas_df.values.flatten()
diff_df_flatten = diff_df.values.flatten()
    
margin = 0.2
# differences whose absolute value is below the margin
diff_df_flatten_small = [x for x in diff_df_flatten if abs(x) < margin]
# percentage of values within the margin, rounded to a whole number
fraction_approx_correct = round(100 * len(diff_df_flatten_small)/len(diff_df_flatten))
# baseline in percent: 1 - (1 - margin)^2, which equals the probability that two independent
# Uniform(0, 1) draws differ by less than margin
fraction_random = (1-(1-margin)**2)*100
print(fraction_approx_correct)
print(fraction_random)
# ratio baseline / observed (raises ZeroDivisionError if no value is within the margin)
print(fraction_random/fraction_approx_correct)

#### logarithmic scatter plot

In [None]:
# Scatter of log10(predicted) vs. log10(measured) stability over all graduated designs
# and cell lines, annotated with the squared Pearson correlation.
fig, ax = plt.subplots(figsize=(2, 1.8))
log_pred = np.log10(pred_df_flatten)
log_meas = np.log10(meas_df_flatten)
plt.scatter(log_pred, log_meas, s=0.5, alpha=0.5, edgecolor="none", rasterized=True, color="darkblue")
#density_scatter(np.log10(pred_df_flatten), np.log10(meas_df_flatten), ax=ax, s=1, cmap='viridis', zorder=2)
r2 = np.corrcoef(log_pred, log_meas)[0,1]**2

plt.xlabel(r"log$_{10}$(predicted stability)")
plt.ylabel(r"log$_{10}$(measured stability)")
plt.text(-1.95, -0.25, f"r$^2$ = {r2:.2f}", fontsize=8)
plt.ylim(-2, 0)
plt.xlim(-2, 0)
plt.tight_layout()
# fixed: the file extension used to come from a `format` variable left over from a loop in
# an earlier cell (or the builtin `format` on a fresh kernel); save both formats
# explicitly, like the other figures in this notebook
for fmt in ["png", "svg"]:
    plt.savefig(os.path.join(plot_folder, f"{plot_folder_prefix}_predict_scatter_log.{fmt}"), dpi=600)

#### linear scatter plot

In [None]:
# Scatter of predicted vs. measured stability (linear scale) with the +/- margin band
# around the identity line highlighted.
fig, ax = plt.subplots(figsize=(2, 2))
#density_scatter(pred_df_flatten, meas_df_flatten, ax=ax, sort=True, bins=30, s=2)
# dashed lines: measured = predicted + margin and measured = predicted - margin
plt.plot([0, 1], [margin, 1+margin], color="black", linestyle="--", linewidth=1)
plt.plot([0, 1], [-margin, 1-margin], color="black", linestyle="--", linewidth=1)
# shade the band where |measured - predicted| is smaller than margin
plt.fill_between([0, 1], [-margin, 1-margin], [margin, 1+margin], color="lightgreen", alpha=0.25)
# shade the regions outside the band (difference larger than margin)
plt.fill_between([0, 1], [-margin, 0], [-margin, 1-margin], color="lightcoral", alpha=0.25)
plt.fill_between([0, 1], [margin, 1+margin], [1, 1], color="lightcoral", alpha=0.25)

plt.scatter(pred_df_flatten, meas_df_flatten, s=0.5, alpha=0.5, edgecolor="none", rasterized=True, color="darkblue")
r2 = np.corrcoef(pred_df_flatten, meas_df_flatten)[0,1]**2

plt.xlabel("predicted stability")
plt.ylabel("measured stability")
plt.title(f"r$^2$ = {r2:.2f}\nfraction abs(diff) < {margin} = {fraction_approx_correct}%", fontsize=8)
plt.ylim(0, 1)
plt.xlim(0, 1)
plt.tight_layout()
# fixed: the file extension used to come from a `format` variable left over from a loop in
# an earlier cell (or the builtin `format` on a fresh kernel); save both formats
# explicitly, like the other figures in this notebook
for fmt in ["png", "svg"]:
    plt.savefig(os.path.join(plot_folder, f"{plot_folder_prefix}_predict_scatter_linear.{fmt}"), dpi=600)

# 9.7 - Do constraints on the microRNA choice hinder designs?

In [39]:
# values of the "type" column that count as binary designs vs. graduated designs
binary_designs = ['double_active', 'single_target', 'double_target', 'three_target', 'single_active']
graduated_designs = ['random_target']

In [40]:
# split the "all_mse" table into binary and graduated designs (independent copies)
df = curr_mse_designs["all_mse"].copy()
df_binary = df[df["type"].isin(binary_designs)].copy()
df_graduated = df[df["type"].isin(graduated_designs)].copy()

In [41]:
# string form of each target, plus the trailing number of the design's index label
df_binary["target_string"] = df_binary["target"].apply(str)
df_binary["design_number"] = df_binary.index.str.split("_").str[-1]
df_binary["design_number_mod"] = df_binary["design_number"].apply(int) % 4

# designs whose number is 1 or 2 modulo 4 are unconstrained; all others are constrained
df_binary["design_constrained"] = ~df_binary["design_number_mod"].isin([1, 2])

In [None]:
# Compare log10(measured quality) between constrained and unconstrained binary designs:
# overlaid density histograms, then a two-sample t-test (shown as the cell output).
is_constrained = df_binary["design_constrained"]
log_quality_constrained = np.log10(df_binary.loc[is_constrained, "measured_quality"])
log_quality_unconstrained = np.log10(df_binary.loc[~is_constrained, "measured_quality"])

plt.figure(figsize=(2.8, 1.65))
plt.hist(log_quality_constrained, bins=20, alpha=0.5, color="tab:blue", label="constrained", density=True)
plt.hist(log_quality_unconstrained, bins=20, alpha=0.5, color="tab:orange", label="unconstrained", density=True)
plt.legend()

# are the two distributions significantly different?
from scipy.stats import ttest_ind
ttest_ind(log_quality_constrained, log_quality_unconstrained)

# 9.8 - Heatmaps

### Get an overview of the design types

In [None]:
# list the distinct design types in the first table of curr_mse_designs
key = list(curr_mse_designs.keys())[0]
df = curr_mse_designs[key].copy()
df["type"].unique()

In [44]:
# groups of "type" values; the heatmap cells below iterate over active_designs + target_designs
active_designs = ["single_active", "double_active"]
target_designs = ["single_target", "double_target", "three_target"]
random_designs = ["random_target", "random_target_log"]
range_designs = ["range_target", "range_target_log"]

# 9.8.1 - Binary designs - All Designs

In [None]:
# Switch to the "original" model tables for the heatmaps and make sure its plot folder exists.
plot_folder_prefix = "9.1_merged"
curr_mse_designs = mse_designs_original.copy()
plot_folder = os.path.join(base_plot_folder, f"{plot_folder_prefix}_mse")
os.makedirs(plot_folder, exist_ok=True)
print(plot_folder)

In [51]:
# Keep, in every table of curr_mse_designs, only the designs whose index label contains
# the chosen miRNA number tag.
miRNA_number = "AND5"
for key, full_table in list(curr_mse_designs.items()):
    df = full_table[full_table.index.str.contains(miRNA_number)]
    curr_mse_designs[key] = df

In [None]:
# %%capture output
# Heatmaps (designs along x, cell lines along y) of predicted and measured stability for
# every binary design type, with a box around the cell lines each design targets.
for key in curr_mse_designs.keys():
    for design_type in (active_designs+target_designs):
        # copy after filtering so the column assignments below act on an independent frame
        df = curr_mse_designs[key]
        df = df[df["type"] == design_type].copy()
        if df.empty:
            # nothing to draw for this table / design type
            continue

        # cell lines that get a box; the second argument differs between "active" (1) and
        # "target" (0) designs — see extract_target_cell_lines for its exact meaning
        if "active" in design_type:
            df["target_cell_lines"] = df["target"].apply(lambda x: extract_target_cell_lines(x, 1))
        if "target" in design_type:
            df["target_cell_lines"] = df["target"].apply(lambda x: extract_target_cell_lines(x, 0))

        # REORDER THE DESIGNS by the positions of their target cell lines in curr_cell_set
        cell_line_indices = df["target_cell_lines"].apply(lambda x: [curr_cell_set.index(cell) for cell in x.split(", ")])
        cell_line_indices = cell_line_indices.sort_values().index
        df = df.loc[cell_line_indices]

        # consecutive designs with the same target cell lines form one group;
        # counts holds the group sizes in plotting order
        df['group'] = (df['target_cell_lines'] != df['target_cell_lines'].shift(1)).cumsum()
        counts = df.groupby('group').size().tolist()

        if "subset" in key:
            df_measured = df[cell_lines_subset_UTR]
            df_predicted = df[cell_lines_subset_pred]
        else:
            df_measured = df[cell_lines_measured_UTR]
            df_predicted = df[cell_lines_measured_pred]

        # drop the suffix/prefix so both frames use bare cell line names
        df_measured.columns = df_measured.columns.str.split('_').str[0]
        df_predicted.columns = df_predicted.columns.str.split('_').str[1]

        # get a mapping to rows
        cell_line_to_row_map = get_cell_line_to_row_mapping(df_measured)

        # The predicted and measured panels are drawn by the same code.
        # fixed: the measured file name referenced the undefined name `mirna_number`;
        # the variable set in the filter cell is `miRNA_number`.
        # NOTE(review): the predicted file name does not contain the miRNA number, so runs
        # with different numbers overwrite each other — confirm this is intended.
        panels = [
            (df_predicted, "Predicted cell line",
             f"{plot_folder_prefix}.8.1_{design_type}_{key_shorthand[key]}_predicted"),
            (df_measured, "Measured cell line",
             f"{plot_folder_prefix}.8.1_{design_type}_{miRNA_number}_{key_shorthand[key]}_measured"),
        ]
        for panel_df, y_label, file_stem in panels:
            # a fresh figure per panel (no plt.clf() first: with no open figure it would
            # create an extra blank one)
            plt.figure(figsize=get_figsize(key, design_type))

            ax = sns.heatmap(panel_df.T, cmap=main_colormap, vmin=0, vmax=1, cbar_kws={'label': 'stability'})

            # one box per group and targeted cell line, spanning the group's columns
            x_pos = 0
            for count in counts:
                target_cell_lines = df.iloc[x_pos:x_pos+count]["target_cell_lines"].iloc[0]
                # split() also handles a single cell line (no ", " present)
                target_rows = [cell_line_to_row_map[cell_line] for cell_line in target_cell_lines.split(", ")]

                for row in target_rows:
                    rect = patches.Rectangle((x_pos, row), count, 1, linewidth=1, edgecolor=box_color, facecolor='none')
                    ax.add_patch(rect)

                x_pos += count

            # small padding so boxes on the border are not clipped; y-axis stays inverted
            ax.set_xlim([-0.1, panel_df.shape[0] + 0.1])
            ax.set_ylim([len(panel_df.columns) + 0.1, -0.1])

            plt.xticks([])
            plt.xlabel("Designs")
            plt.ylabel(y_label)
            plt.tight_layout()
            for format in ["png", "svg"]:
                plt.savefig(os.path.join(plot_folder, f"{file_stem}.{format}"),
                            bbox_inches="tight", dpi=300)
            plt.close()

# 9.8.2 - Binary designs - Best-performing Designs

In [56]:
# Switch back to the "updated" model tables (an unfiltered copy) and its plot folder.
curr_mse_designs = mse_designs_updated.copy()
plot_folder_prefix = "9.2_merged"
plot_folder = os.path.join(base_plot_folder, f"{plot_folder_prefix}_mse")

In [57]:
# This is for predictions for tissues. Not used.
# When the flag is True, the tissue design table is loaded; the next cell uses its index
# to restrict the designs.
tissue_designs_only = False
if tissue_designs_only:
    tissue_designs = pd.read_csv(os.path.join("../outputs/16_predict_on_tissues", "all_designs_tissue.csv"), index_col=0)

In [None]:
def _save_best_design_heatmap(values, target_cell_lines, cell_line_to_row, tick_labels, ylabel, figsize, out_stem):
    """Draw one heatmap of the selected designs, box the targeted cell lines and save it.

    values: DataFrame with one row per design and one column per cell line; it is
        transposed for plotting so that the designs run along the x-axis.
    target_cell_lines: iterable with one ", "-separated string of cell line names per
        design, in the same order as the rows of `values`.
    cell_line_to_row: mapping from cell line name to the heatmap row it is drawn in.
    tick_labels: y-axis tick labels, one per plotted cell line.
    ylabel: label of the y-axis.
    figsize: figure size handed to matplotlib.
    out_stem: output path without extension; a png and an svg file are written.
    """
    fig, ax = plt.subplots(figsize=figsize)
    sns.heatmap(values.T, cmap=main_colormap, vmin=0, vmax=1, fmt=".2f", square=True, annot=False, cbar=False, ax=ax)

    # one box per (design, targeted cell line); splitting also covers the single cell line case
    for x_pos, cell_lines in enumerate(target_cell_lines):
        for cell_line in cell_lines.split(", "):
            rect = patches.Rectangle((x_pos, cell_line_to_row[cell_line]), 1, 1, linewidth=1.5, edgecolor=box_color, facecolor='none')
            ax.add_patch(rect)

    # small margin so that boxes on the outer cells are not clipped
    ax.set_xlim([-0.1, values.shape[0] + 0.1])
    ax.set_ylim([len(values.columns) + 0.1, -0.1])

    ax.set_yticks(np.arange(0.5, len(tick_labels)+0.5))
    ax.set_yticklabels(tick_labels)
    ax.set_xticks([])
    ax.set_xlabel("Designs")
    ax.set_ylabel(ylabel)
    fig.tight_layout()
    for file_format in ["png", "svg"]:
        fig.savefig(f"{out_stem}.{file_format}", bbox_inches="tight", dpi=300)
    plt.close(fig)


saved_dfs = {}
for key in curr_mse_designs.keys():
    for design_type in (active_designs+target_designs):
        df = curr_mse_designs[key]
        df = df[df["type"] == design_type]
        if tissue_designs_only:
            df = df[df.index.isin(tissue_designs.index)]
        # copy after filtering: columns are added below and must not be written into a slice
        df = df.copy()

        # nothing to select or plot for this combination
        if df.empty:
            continue

        if "all" in key:
            curr_cell_meas = cell_lines_measured_UTR
            curr_cell_pred = cell_lines_measured_pred
            curr_cell_set = cell_lines_measured
        if "subset" in key:
            curr_cell_meas = cell_lines_subset_UTR
            curr_cell_pred = cell_lines_subset_pred
            curr_cell_set = cell_lines_subset

        if "active" in design_type:
            df["target_cell_lines"] = df["target"].apply(lambda x: extract_target_cell_lines(x, 1))
        if "target" in design_type:
            df["target_cell_lines"] = df["target"].apply(lambda x: extract_target_cell_lines(x, 0))

        # targets and measurements labelled by the bare cell line name; these are only
        # needed by the (disabled) margin-based selection
        df_target = pd.DataFrame(df["target"].to_list())
        df_target.index = df.index

        # rename columns according to label_rename
        df_target.rename(columns=label_rename, inplace=True)

        df_measured = df[curr_cell_meas]
        df_measured.columns = [column.split("_")[0] for column in df_measured.columns]
        # df["margin"] = calculate_margin(df_measured, df_target)

        # for each set of target cell lines, pick the design with the highest measured quality
        max_quality_id = df.groupby("target_cell_lines")["measured_quality"].idxmax()
        #max_quality_id = df.groupby("target_cell_lines")["margin"].idxmax()
        # reorder to have the same index order as the original dataframe
        ordered_indices = df.index.intersection(max_quality_id)
        df_max_quality = df.loc[ordered_indices]

        # REORDER THE DESIGNS
        # sort by the positions of the targeted cell lines within curr_cell_set
        cell_line_indices = df_max_quality["target_cell_lines"].apply(lambda x: [curr_cell_set.index(cell) for cell in x.split(", ")])
        cell_line_indices = cell_line_indices.sort_values().index
        df_max_quality = df_max_quality.loc[cell_line_indices]

        if key == "all_mse":
            saved_dfs[design_type] = df_max_quality.copy()

        # get a mapping to rows
        cell_line_to_row_map = get_cell_line_to_row_mapping(curr_cell_set)

        out_stem = os.path.join(plot_folder, f"{plot_folder_prefix}.8.2_{design_type}_{key_shorthand[key]}_best")

        # -----------------------------------------------------------------------
        # PREDICTED VALUES
        _save_best_design_heatmap(df_max_quality[curr_cell_pred], df_max_quality["target_cell_lines"],
                                  cell_line_to_row_map, curr_cell_set, "Predicted cell line",
                                  get_figsize(key, design_type, True), f"{out_stem}_predicted")

        # -----------------------------------------------------------------------
        # MEASURED VALUES
        _save_best_design_heatmap(df_max_quality[curr_cell_meas], df_max_quality["target_cell_lines"],
                                  cell_line_to_row_map, curr_cell_set, "Measured cell line",
                                  get_figsize(key, design_type, True), f"{out_stem}_measured")

In [40]:
# Persist the selected designs (saved_dfs), but only for the tissue-restricted run.
if tissue_designs_only:
    best_designs_path = os.path.join(output_folder, "best_tissue_designs.pkl")
    with open(best_designs_path, "wb") as handle:
        pickle.dump(saved_dfs, handle)

# 9.8.3 - Graduated designs

In [59]:
# Section 9.8.3 plots the designs stored in `mse_designs_original`.
curr_mse_designs = mse_designs_original.copy()
plot_folder_prefix = "9.1_merged"
plot_folder = os.path.join(base_plot_folder, f"{plot_folder_prefix}_mse")
# make sure the folder exists so that the savefig calls of this section do not depend
# on an earlier cell having created it
os.makedirs(plot_folder, exist_ok=True)

### plot all designs

In [None]:
# %%capture output
def _save_graduated_heatmap(values, figsize, cmap, vmin, vmax, cbar_label, ylabel, out_stem):
    """Draw one heatmap (cell lines x designs) with a labelled colorbar and save it.

    values: DataFrame with one row per design and one column per cell line; it is
        transposed for plotting so that the designs run along the x-axis.
    figsize: figure size handed to matplotlib.
    cmap, vmin, vmax: colormap and color limits of the heatmap.
    cbar_label: label of the colorbar.
    ylabel: label of the y-axis, or None to leave the axis label untouched.
    out_stem: output path without extension; a png and an svg file are written.
    """
    fig, ax = plt.subplots(figsize=figsize)
    sns.heatmap(values.T, cmap=cmap, vmin=vmin, vmax=vmax, cbar_kws={'label': cbar_label}, ax=ax)
    ax.set_xticks([])
    ax.set_xlabel("Designs")
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    fig.tight_layout()
    for file_format in ["png", "svg"]:
        fig.savefig(f"{out_stem}.{file_format}", bbox_inches="tight", dpi=300)
    plt.close(fig)


# the difference columns do not depend on the loop variables, so build them once
cell_lines_subset_diff = ["diff_to_pred_" + cell_line for cell_line in cell_lines_subset]
cell_lines_all_diff = ["diff_to_pred_" + cell_line for cell_line in cell_lines_measured]

for key in curr_mse_designs.keys():
    for design_type in (random_designs+range_designs):
        df = curr_mse_designs[key].copy()
        df = df[df["type"] == design_type]

        # nothing to plot for this combination
        if len(df) == 0:
            continue

        # sort the df by the accuracy_quality column
        df = df.sort_values("accuracy_quality", ascending=False)

        is_subset = "subset" in key
        if is_subset:
            column_sets = (cell_lines_subset_UTR, cell_lines_subset_pred, cell_lines_subset_target, cell_lines_subset_diff)
        else:
            column_sets = (cell_lines_measured_UTR, cell_lines_measured_pred, cell_lines_all_target, cell_lines_all_diff)
        df_measured, df_predicted, df_target, df_difference = (df[columns].copy() for columns in column_sets)

        # drop the suffix/prefix so that every frame is labelled by the bare cell line name
        df_measured.columns = df_measured.columns.str.split('_').str[0]
        df_predicted.columns = df_predicted.columns.str.split('_').str[1]
        df_target.columns = df_target.columns.str.split('_').str[1]
        df_difference.columns = df_difference.columns.str.split('_').str[-1]

        # the subset plots get their size from get_figsize; the full set uses fixed sizes
        figsize = get_figsize(key, design_type) if is_subset else (5, 1.6)
        difference_figsize = get_figsize(key, design_type) if is_subset else (6, 1.6)
        out_stem = os.path.join(plot_folder, f"{plot_folder_prefix}.8.3_{design_type}_{key_shorthand[key]}")

        # -----------------------------------------------------------------------
        # PREDICTED VALUES
        _save_graduated_heatmap(df_predicted, figsize, main_colormap, 0, 1,
                                'stability', "Predicted cell line", f"{out_stem}_predicted")

        # -----------------------------------------------------------------------
        # MEASURED VALUES
        # the colorbar shows the measured values themselves (same scale as predicted and
        # target), so it carries the same label instead of the one of the difference plot
        _save_graduated_heatmap(df_measured, figsize, main_colormap, 0, 1,
                                'stability', "Measured cell line", f"{out_stem}_measured")

        # -----------------------------------------------------------------------
        # TARGET VALUES
        _save_graduated_heatmap(df_target, figsize, main_colormap, 0, 1,
                                'stability', "Target for cell line", f"{out_stem}_target")

        # -----------------------------------------------------------------------
        # DIFFERENCE VALUES
        # diverging colormap centred on zero; no y-axis label is set for this plot
        _save_graduated_heatmap(df_difference, difference_figsize, "icefire", -0.5, 0.5,
                                'expression\nmeasured-predicted', None, f"{out_stem}_difference")

        # print the mean of the difference value for each cell line
        print("Mean difference values between prediction and measurement")
        print(key, design_type)
        for cell_line in df_difference.columns:
            print(cell_line, df_difference[cell_line].mean())

### Predicted and measured difference to the target across deciles (not used)

In [None]:
# for key in curr_mse_designs.keys():
#     for design_type in ["random_target"]:
#         df = curr_mse_designs[key].copy()
#         df = df[df["type"] == design_type]
        
#         # sort the df by the accuracy_quality column
#         df = df.sort_values("accuracy_quality", ascending=False)
        
#         if "subset" in key:
#             df_measured = df[cell_lines_subset_UTR]
#             df_predicted = df[cell_lines_subset_pred]
#             df_target = df[cell_lines_subset_target]
            
#             cell_lines_subset_diff = ["diff_to_pred_" + cell_line for cell_line in cell_lines_subset]
#             df_difference = df[cell_lines_subset_diff]
            
#             # drop the suffix/prefix
#             df_measured.columns = df_measured.columns.str.split('_').str[0]
#             df_predicted.columns = df_predicted.columns.str.split('_').str[1]
#             df_target.columns = df_target.columns.str.split('_').str[1]
#             df_difference.columns = df_difference.columns.str.split('_').str[-1]
            
#         else:
#             df_measured = df[cell_lines_measured_UTR]
#             df_predicted = df[cell_lines_measured_pred]
#             df_target = df[cell_lines_all_target]
            
#             cell_lines_all_diff = ["diff_to_pred_" + cell_line for cell_line in cell_lines_measured]
#             df_difference = df[cell_lines_all_diff]
            
#             # drop the suffix/prefix
#             df_measured.columns = df_measured.columns.str.split('_').str[0]
#             df_predicted.columns = df_predicted.columns.str.split('_').str[1]
#             df_target.columns = df_target.columns.str.split('_').str[1]
#             df_difference.columns = df_difference.columns.str.split('_').str[-1]
        
#         if len(df_measured) == 0:
#             continue
        
#         df = df.sort_values("predicted_quality", ascending=False)
        
#         # divide into tenths
#         df["group"] = pd.qcut(df["predicted_quality"], 10, labels=False)
        
#         # for each group, create a boxplot of the mean difference
#         plt.clf()
#         plt.figure(figsize=(2, 1.4))
#         sns.violinplot(data=df, x="group", y="measured_mse", linewidth=1, color="tab:blue", scale="width", inner="quartile")

### plot target versus measured for one model only (not used)

In [None]:
# for key in curr_mse_designs.keys():
#     for design_type in (random_designs+range_designs):
#         df = curr_mse_designs[key].copy()
#         df = df[df["type"] == design_type]
#         if len(df) == 0:
#             continue
        
#         if "all" in key:
#             curr_cell_meas = cell_lines_measured_UTR
#             curr_cell_pred = cell_lines_measured_pred
#             curr_cell_set = cell_lines_measured
#         if "subset" in key:
#             curr_cell_meas = cell_lines_subset_UTR
#             curr_cell_pred = cell_lines_subset_pred
#             curr_cell_set = cell_lines_subset
        
#         df_measured = df[curr_cell_meas]
#         df_measured.columns = [column.split("_")[0] for column in df_measured.columns]
        
#         # df = df.sort_values(by="measured_quality", ascending=False).head(int(len(df)/2))
#         # print(key, len(df))
        
#         n_designs = len(df)
#         df_max = df.sort_values(by="accuracy_quality", ascending=False)
#         quantile = [0, 0.25, 0.5, 0.75]
#         # chose four designs from each quantile
#         max_quality_id = []
#         for q in quantile:
#             for i in range(4):
#                 max_quality_id.append(df_max.iloc[int(q*n_designs)+i].name)
#         df_max = df_max.loc[max_quality_id]
        
#         # get the targets
#         df_target = pd.DataFrame(df_max["target"].to_list())
#         df_target.index = df_max.index
#         df_target.rename(columns=label_rename, inplace=True)
        
#         # get a new dataframe with the designs with the highest quality
#         n_designs = len(df_max)
        
#         if "subset" in key:
#             fig, axes = plt.subplots(1, n_designs, figsize=(n_designs * 0.3, 1.2))
#         else:
#             fig, axes = plt.subplots(1, n_designs, figsize=(n_designs * 0.3, 1.6)) 
#             #fig, axes = plt.subplots(1, n_designs, figsize=(n_designs * 1, 1.6))
        
#         for i, (ax, design_index) in enumerate(zip(axes.flat, max_quality_id)):
#             design_df = pd.DataFrame(columns=curr_cell_set, index=["measured", "predicted", "target"])
#             design_df.loc["measured"] = df_max.loc[design_index][curr_cell_meas].values
#             design_df.loc["predicted"] = df_max.loc[design_index][curr_cell_pred].values
#             design_df.loc["target"] = df_target.loc[design_index].values
#             design_df = design_df.astype(float)
            
#             sns.heatmap(design_df.T, cmap=main_colormap, vmin=0, vmax=1, annot=False, fmt=".2f", cbar=False, ax=ax)
            
#             # Set the y-axis label only for the first subplot
#             if i == 0:
#                 ax.set_ylabel("")
#             else:
#                 ax.set_ylabel("")
#                 ax.set_yticks([])
                
#             ax.set_xticks([0.5, 1.5, 2.5], ["m", "b", "t"], rotation=0, fontsize=6)
            
#             # delete the actual tickmarks
#             ax.tick_params(axis='both', which='both', length=0)

#             ax.set_xlim([-0.15, design_df.shape[0] + 0.15])
#             ax.set_ylim([len(design_df.columns) + 0.15, -0.15])
                
#         fig.subplots_adjust(wspace=0.2)
#         #plt.tight_layout()
#         for format in ["png", "svg"]:
#             plt.savefig(os.path.join(plot_folder, f"{plot_folder_prefix}.8.3_{design_type}_{key_shorthand[key]}_quantiles.{format}"), bbox_inches="tight", dpi=300)

# 9.9 - Combined heatmaps for the different models

In [None]:
#%%capture output
# The combined heatmaps of section 9.9 go into their own folder; runs with a
# "merged" prefix are kept apart from the others.
merged_suffix = "_merged" if "merged" in plot_folder_prefix else ""
plot_folder = os.path.join(base_plot_folder, "9.9_mse_best_margin" + merged_suffix)
if not os.path.exists(plot_folder):
    os.makedirs(plot_folder)
print(plot_folder)

## 9.9.1 - Binary designs

## Best designs

In [None]:
%%capture output
# The combined plots compare the predictions of all three models (original, updated
# and inverted), so nothing is done unless all three design collections exist.
if "mse_designs_original" in locals() and "mse_designs_updated" in locals() and "mse_designs_inverted" in locals():
    for key in mse_designs_inverted.keys():
        for design_type in (active_designs+target_designs):
            df_orig = mse_designs_original[key].copy()
            df_orig = df_orig[df_orig["type"] == design_type]
            df_inverted = mse_designs_inverted[key].copy()
            df_inverted = df_inverted[df_inverted["type"] == design_type]
            df_updated = mse_designs_updated[key].copy()
            df_updated = df_updated[df_updated["type"] == design_type]

            # filter the three to common indices
            common_indices = df_orig.index.intersection(df_inverted.index).intersection(df_updated.index)
            # nothing to plot for this combination; a figure with zero columns cannot be created
            if len(common_indices) == 0:
                continue
            df_orig = df_orig.loc[common_indices]
            df_inverted = df_inverted.loc[common_indices]
            df_updated = df_updated.loc[common_indices]

            if "all" in key:
                curr_cell_meas = cell_lines_measured_UTR
                curr_cell_pred = cell_lines_measured_pred
                curr_cell_set = cell_lines_measured
            if "subset" in key:
                curr_cell_meas = cell_lines_subset_UTR
                curr_cell_pred = cell_lines_subset_pred
                curr_cell_set = cell_lines_subset

            if "active" in design_type:
                df_orig["target_cell_lines"] = df_orig["target"].apply(lambda x: extract_target_cell_lines(x, 1))
            if "target" in design_type:
                df_orig["target_cell_lines"] = df_orig["target"].apply(lambda x: extract_target_cell_lines(x, 0))

            # get the targets
            df_target = pd.DataFrame(df_orig["target"].to_list())
            df_target.index = df_orig.index
            df_target.rename(columns=label_rename, inplace=True)

            # for each set of target cell lines, pick the design with the highest measured quality
            max_quality_id = df_orig.groupby("target_cell_lines")["measured_quality"].idxmax()
            # reorder to have the same index order as the original dataframe
            ordered_indices = df_orig.index.intersection(max_quality_id)
            df_max_quality = df_orig.loc[ordered_indices]
            # REORDER THE DESIGNS
            # sort by the positions of the targeted cell lines within curr_cell_set
            cell_line_indices = df_max_quality["target_cell_lines"].apply(lambda x: [curr_cell_set.index(cell) for cell in x.split(", ")])
            cell_line_indices = cell_line_indices.sort_values().index
            df_max_quality = df_max_quality.loc[cell_line_indices]

            # the same selected designs, as seen by each of the three models
            df_max_quality_original = df_orig.loc[df_max_quality.index]
            df_max_quality_updated = df_updated.loc[df_max_quality.index]
            df_max_quality_inverted = df_inverted.loc[df_max_quality.index]
            df_target = df_target.loc[df_max_quality.index]

            # get a mapping to rows
            cell_line_to_row_map = get_cell_line_to_row_mapping(curr_cell_set)
            n_designs = len(df_max_quality.index)

            # squeeze=False always returns a 2D array of axes, so that `axes.flat`
            # also works when only a single design is plotted
            fig_height = 1.2 if "subset" in key else 1.6
            fig, axes = plt.subplots(1, n_designs, figsize=(n_designs * 0.5, fig_height), squeeze=False)

            for i, (ax, design_index) in enumerate(zip(axes.flat, df_max_quality.index)):
                # one column per row of design_df after transposing; "empty" is a visual gap
                design_df = pd.DataFrame(columns=curr_cell_set, index=["target", "measured", "empty", "original", "updated", "inverted"])
                design_df.loc["target"] = df_target.loc[design_index].values
                design_df.loc["measured"] = df_max_quality_original.loc[design_index][curr_cell_meas].values
                design_df.loc["empty"] = np.nan
                design_df.loc["original"] = df_max_quality_original.loc[design_index][curr_cell_pred].values
                design_df.loc["updated"] = df_max_quality_updated.loc[design_index][curr_cell_pred].values
                design_df.loc["inverted"] = df_max_quality_inverted.loc[design_index][curr_cell_pred].values
                design_df = design_df.astype(float)

                # # Show colorbar only for the last subplot
                # cbar = i == n_designs - 1
                # sns.heatmap(design_df.T, cmap=main_colormap, vmin=0, vmax=1, annot=False, fmt=".2f",
                #             cbar_kws={'label': 'stability'} if cbar else None, cbar=cbar, ax=ax)
                # don't show the colorbar
                sns.heatmap(design_df.T, cmap=main_colormap, vmin=0, vmax=1, annot=False, fmt=".2f", cbar=False, ax=ax)

                # keep the cell line tick labels only on the first subplot
                ax.set_ylabel("")
                if i != 0:
                    ax.set_yticks([])

                ax.set_xticks([0.5, 1.5, 2.5, 3.5, 4.5, 5.5], ["t", "m", "", "b", "u", "i"], rotation=0, fontsize=6)

                # delete the actual tickmarks
                ax.tick_params(axis='both', which='both', length=0)

                # small margin so that the boxes on the outer cells are not clipped
                ax.set_xlim([-0.15, design_df.shape[0] + 0.15])
                ax.set_ylim([len(design_df.columns) + 0.15, -0.15])

                # box the rows of the targeted cell lines across all six columns;
                # splitting also covers the single cell line case
                target_cell_lines = df_orig.loc[design_index]["target_cell_lines"].split(", ")
                target_rows = [cell_line_to_row_map[cell_line] for cell_line in target_cell_lines]

                for target_row in target_rows:
                    rect = patches.Rectangle((0, target_row), 6, 1, linewidth=1, edgecolor=box_color, facecolor='none')
                    ax.add_patch(rect)

            fig.subplots_adjust(wspace=0.2)
            #plt.tight_layout()
            for file_format in ["png", "svg"]:
                fig.savefig(os.path.join(plot_folder, f"9.9.1_{design_type}_{key_shorthand[key]}_best_measured.{file_format}"), bbox_inches="tight", dpi=300)

## 9.9.2 - Plot heatmaps for the graduated designs for all three models

In [23]:
# keep track of prediction errors
# Filled by the following cell: saved_dfs[quartile_no] is a list holding, in this order,
# the measured values, the original, updated and inverted predictions, and the targets
# of the designs in that quartile.
saved_dfs = {}

In [None]:
# One figure per success quartile. The predictions of all three models (original, updated
# and inverted) are compared, so nothing is done unless all three collections exist.
for quartile_no in [1, 2, 3, 4]:
    if "mse_designs_original" in locals() and "mse_designs_updated" in locals() and "mse_designs_inverted" in locals():
        for key in mse_designs_original.keys():
            for design_type in (["random_target"]):
            # for design_type in (active_designs+target_designs):
                df_orig = mse_designs_original[key].copy()
                df_orig = df_orig[df_orig["type"] == design_type]

                if len(df_orig) == 0:
                    continue

                # get the indices for all three
                df_inverted = mse_designs_inverted[key].copy()
                df_inverted = df_inverted[df_inverted["type"] == design_type]
                df_updated = mse_designs_updated[key].copy()
                df_updated = df_updated[df_updated["type"] == design_type]

                # filter the three to common indices
                common_indices = df_orig.index.intersection(df_inverted.index).intersection(df_updated.index)
                df_orig = df_orig.loc[common_indices]
                df_inverted = df_inverted.loc[common_indices]
                df_updated = df_updated.loc[common_indices]

                if "active" in design_type:
                    df_orig["target_cell_lines"] = df_orig["target"].apply(lambda x: extract_target_cell_lines(x, 1))
                if "target" in design_type:
                    df_orig["target_cell_lines"] = df_orig["target"].apply(lambda x: extract_target_cell_lines(x, 0))

                if len(df_orig) == 0:
                    continue

                if "all" in key:
                    curr_cell_meas = cell_lines_measured_UTR
                    curr_cell_pred = cell_lines_measured_pred
                    curr_cell_set = cell_lines_measured
                if "subset" in key:
                    curr_cell_meas = cell_lines_subset_UTR
                    curr_cell_pred = cell_lines_subset_pred
                    curr_cell_set = cell_lines_subset

                # measurements labelled by the bare cell line name (not used further in this cell)
                df_measured = df_orig[curr_cell_meas]
                df_measured.columns = [column.split("_")[0] for column in df_measured.columns]

                # start organizing designs by quality
                df_max = df_orig.sort_values(by="measured_quality", ascending=False)

                # get the targets
                df_target = pd.DataFrame(df_max["target"].to_list())
                df_target.index = df_max.index
                df_target.rename(columns=label_rename, inplace=True)

                # Some of the best designs are very boring because they essentially just turn everything off
                # Use the mean of the target values to filter out the boring designs
                df_target_mean = df_target.mean(axis=1)
                # filter to those with a mean above 0.35
                # if "random" in design_type:
                #     df_target_mean = df_target_mean[df_target_mean > 0.35]
                df_max = df_max.loc[df_target_mean.index, :]
                n_designs = len(df_max)

                # choose the designs that are plotted
                max_quality_id = []
                if not "random" in design_type:
                    # one design from each tenth of the quality ranking
                    for i in range(10):
                        max_quality_id.append(df_max.iloc[int(n_designs/10)*i].name)
                else:
                    # random designs: work on the quartile selected by the outer loop
                    indices_range = n_designs//4
                    indices_start = indices_range * (quartile_no-1)
                    if "all" in key:
                        # NOTE(review): the "-1" leaves out the last design of every quartile;
                        # kept as is so that the downstream statistics do not change - confirm intent
                        indices = list(range(indices_start, indices_start+indices_range-1))
                        index_ids = df_max.iloc[indices].index
                        saved_dfs[quartile_no] = [df_orig.loc[index_ids][curr_cell_meas],
                                            df_orig.loc[index_ids][curr_cell_pred],
                                            df_updated.loc[index_ids][curr_cell_pred],
                                            df_inverted.loc[index_ids][curr_cell_pred],
                                            df_target.loc[index_ids]]

                    # roughly ten evenly spaced designs of the quartile; the step must be at
                    # least 1, otherwise range() raises for quartiles with fewer than 10 designs
                    indices_step = max(1, indices_range//10)
                    indices = list(range(indices_start, indices_start+indices_range-1, indices_step))
                    for index in indices:
                        max_quality_id.append(df_max.iloc[index].name)
                        #print(index)

                df_max = df_max.loc[max_quality_id]
                df_max_updated = df_updated.loc[max_quality_id]
                df_max_inverted = df_inverted.loc[max_quality_id]

                n_designs_chosen = len(df_max)
                # a figure with zero columns cannot be created
                if n_designs_chosen == 0:
                    continue

                # squeeze=False always returns a 2D array of axes, so that `axes.flat`
                # also works when only a single design is plotted
                fig_height = 1.2 if "subset" in key else 1.6
                fig, axes = plt.subplots(1, n_designs_chosen, figsize=(n_designs_chosen * 0.5, fig_height), squeeze=False)

                if not "random" in design_type:
                    # row lookup for the target boxes, matching the cell lines plotted here
                    cell_line_to_row_map = get_cell_line_to_row_mapping(curr_cell_set)

                for i, (ax, design_index) in enumerate(zip(axes.flat, max_quality_id)):
                    # one column per row of design_df after transposing; "empty" is a visual gap
                    design_df = pd.DataFrame(columns=curr_cell_set, index=["target", "measured", "empty", "original", "updated", "inverted"])
                    design_df.loc["target"] = df_target.loc[design_index].values
                    design_df.loc["measured"] = df_max.loc[design_index][curr_cell_meas].values
                    design_df.loc["empty"] = np.nan
                    design_df.loc["original"] = df_max.loc[design_index][curr_cell_pred].values
                    design_df.loc["updated"] = df_max_updated.loc[design_index][curr_cell_pred].values
                    design_df.loc["inverted"] = df_max_inverted.loc[design_index][curr_cell_pred].values

                    design_df = design_df.astype(float)

                    # replace HEK293T in the columns with 293T
                    design_df.columns = design_df.columns.str.replace("HEK293T", "293T")

                    #target_df = df_target.loc[design_index].values
                    sns.heatmap(design_df.T, cmap="rocket", vmin=0, vmax=1, annot=False, fmt=".2f", cbar=False, ax=ax,
                                cbar_kws={'label': 'stability', 'ticks': [0, 0.25, 0.5, 0.75, 1]})

                    # keep the cell line tick labels only on the first subplot
                    ax.set_ylabel("")
                    if i != 0:
                        ax.set_yticks([])

                    ax.set_xticks([0.5, 1.5, 2.5, 3.5, 4.5, 5.5])
                    ax.set_xticklabels(["t", "m", "", "b", "u", "i"], rotation=0, fontsize=6)

                    # delete the actual tickmarks
                    ax.tick_params(axis='both', which='both', length=0)

                    if not "random" in design_type:
                        # box the rows of the targeted cell lines across all six columns;
                        # splitting also covers the single cell line case
                        target_cell_lines = df_orig.loc[design_index]["target_cell_lines"].split(", ")
                        target_rows = [cell_line_to_row_map[cell_line] for cell_line in target_cell_lines]

                        for target_row in target_rows:
                            rect = patches.Rectangle((0, target_row), 6, 1, linewidth=1, edgecolor=box_color, facecolor='none')
                            ax.add_patch(rect)

                fig.subplots_adjust(wspace=0.4)
                #plt.tight_layout()
                # the save belongs to the figure that was just drawn, so it lives inside the
                # design_type loop (a loop-level `else` would also run when every design type
                # was skipped and would then save an unrelated figure)
                if not "random" in design_type:
                    file_stem = f"9.9.2_MANY_{design_type}_{key_shorthand[key]}_predictions"
                else:
                    file_stem = f"9.9.2_quartile{quartile_no}_{design_type}_{key_shorthand[key]}_predictions"
                for file_format in ["png", "svg"]:
                    fig.savefig(os.path.join(plot_folder, f"{file_stem}.{file_format}"), bbox_inches="tight", dpi=300)

## 9.9.3 - Use the graduated model prediction errors to make a boxplot

In [None]:
# Each saved_dfs[quartile] list is ordered as: measured, original, updated, inverted, target.
# For the three model predictions and for the measurements, take the square root of
# calculate_mse(values, target) and compare the distributions per quartile.
saved_df_position = {"base model": 1,
                     "updated model": 2,
                     "inverted transfer function model": 3,
                     "measured": 0}

mse_by_quartile = {model_type: {} for model_type in saved_df_position}
mean_values = {}
for quartile in [1, 2, 3, 4]:
    target_values = saved_dfs[quartile][4].values
    for model_type, position in saved_df_position.items():
        mse_by_quartile[model_type][quartile] = np.sqrt(calculate_mse(saved_dfs[quartile][position], target_values).to_numpy())

# long format: one row per (model, quartile, value)
data = [[model_type, quartile, mse]
        for model_type, quartiles in mse_by_quartile.items()
        for quartile, mse_values in quartiles.items()
        for mse in mse_values]

df = pd.DataFrame(data, columns=["Model Type", "Quartile", "MSE"])

# shorter model names for the legend
model_rename = {"base model": "baseline",
                "updated model": "updated",
                "inverted transfer function model": "inverted",
                "measured": "measured"}
df["Model Type"] = df["Model Type"].map(model_rename)

# boxplot with small black fliers
fig, ax = plt.subplots(figsize=(2.4, 1.6))
black_fliers = dict(marker='o', markersize=3, markeredgecolor='none',
                    linestyle='none', markerfacecolor='black')
sns.boxplot(x="Quartile", y="MSE", hue="Model Type", data=df,
            flierprops=black_fliers, linewidth=1, ax=ax) #linecolor="black")

# print and remember the mean value of every (model, quartile) combination
print("Mean RMSE values")
for model_type in model_rename.values():
    for quartile in range(1, 5):
        selected = (df["Model Type"] == model_type) & (df["Quartile"] == quartile)
        mean_rmse = df[selected]["MSE"].mean()
        print(model_type, quartile, mean_rmse)
        mean_values[model_type + " " + str(quartile)] = mean_rmse

ax.set_xlabel("design success quartile")
ax.set_ylabel("rmsd to target")
ax.set_ylim([0, 0.7])
# legend with two columns
ax.legend(loc="upper left", ncol=2, fontsize=6)
fig.tight_layout()
for file_format in ["png", "svg"]:
    fig.savefig(os.path.join(plot_folder, f"9.9.3_mse_boxplot.{file_format}"), bbox_inches="tight", dpi=300)

In [None]:
# Ratio of each model's mean value to the mean value of the measurements, per quartile.
# mean_values is keyed by "<model> <quartile>" with the models baseline, updated,
# inverted and measured; it has no "target" entry, so "measured" is the reference.
# NOTE(review): assumes the measured values are the intended denominator - confirm.
for i in range(4):
    print(mean_values["baseline " + str(i+1)]/mean_values["measured " + str(i+1)])
    print(mean_values["updated " + str(i+1)]/mean_values["measured " + str(i+1)])
    print(mean_values["inverted " + str(i+1)]/mean_values["measured " + str(i+1)])

#### significance values ... makes the plot messy

In [None]:
def return_significance_notation(data, x_condition, x_fixed_col, x_fixed_val, y):
    """Print pairwise Mann-Whitney U test results between the groups of one column.

    The rows of ``data`` are first restricted to those where ``x_fixed_col`` equals
    ``x_fixed_val``. Every unordered pair of distinct values found in ``x_condition``
    is then compared on column ``y`` with a two-sided Mann-Whitney U test, and one
    line per pair is printed containing the p-value and a significance symbol.

    Parameters
    ----------
    data : pandas.DataFrame
        Long-form table containing the columns named by the other arguments.
    x_condition : str
        Column that contains the different groups to compare.
    x_fixed_col : str
        Column that is held fixed.
    x_fixed_val : object
        Value of ``x_fixed_col`` that selects the rows to analyse.
    y : str
        Column holding the values that are compared between groups.

    Returns
    -------
    None
        Despite the function name, the results are only printed.
    """
    def get_significance_symbol(p):
        """Map a p-value to '***' (< 1e-10), '**' (< 1e-5), '*' (< 1e-2) or 'ns'."""
        if p < 1E-10:
            return '***'
        elif p < 1E-5:
            return '**'
        elif p < 1E-2:
            return '*'
        else:
            return 'ns'

    data = data[data[x_fixed_col] == x_fixed_val]
    groups = data[x_condition].unique()
    # Group once up front instead of re-running groupby twice for every pair.
    values_by_group = data.groupby(x_condition)[y]

    for first_index, first_group in enumerate(groups):
        for second_group in groups[first_index + 1:]:
            # The alternative is spelled out so the test does not depend on the
            # SciPy version's default.
            _, p = stats.mannwhitneyu(
                values_by_group.get_group(first_group),
                values_by_group.get_group(second_group),
                alternative="two-sided",
            )
            group1_name = f"{x_fixed_val} {first_group}"
            group2_name = f"{x_fixed_val} {second_group}"
            print(f"Comparison between {group1_name} and {group2_name}: p = {p} {get_significance_symbol(p)}")

# Compare the quartiles of a single model type against each other.
# Model types available in df: ['baseline', 'updated', 'inverted', 'measured']
x_condition, x_fixed_col, x_fixed_val, y = "Quartile", "Model Type", "updated", "MSE"
return_significance_notation(
    data=df,
    x_condition=x_condition,
    x_fixed_col=x_fixed_col,
    x_fixed_val=x_fixed_val,
    y=y,
)

In [None]:
# Rebuild the per-quartile table, redraw the grouped boxplot and annotate it with
# Mann-Whitney U significance brackets between adjacent quartiles of each model.


def significance_symbol(pval):
    """Map a p-value to '***' (< 1e-10), '**' (< 1e-5), '*' (< 1e-2) or 'ns'."""
    if pval < 1e-10:
        return "***"
    if pval < 1e-5:
        return "**"
    if pval < 1e-2:
        return "*"
    return "ns"


mean_values = {}

# Position inside saved_dfs[quartile] of the frame that is scored for each model type.
# Position 4 is the reference every frame is scored against
# (presumably the design targets, given the "rmsd to target" axis label -- TODO confirm).
frame_index_by_model = {
    "base model": 1,
    "updated model": 2,
    "inverted transfer function model": 3,
    "measured": 0,
}
mse_by_quartile = {
    model_type: {
        quartile: np.sqrt(
            calculate_mse(saved_dfs[quartile][frame_index], saved_dfs[quartile][4].values)
        ).to_numpy()
        for quartile in [1, 2, 3, 4]
    }
    for model_type, frame_index in frame_index_by_model.items()
}

# Long-form rows (one per value) so seaborn can group by quartile and model type.
data = [
    [model_type, quartile, mse]
    for model_type, quartiles in mse_by_quartile.items()
    for quartile, mse_values in quartiles.items()
    for mse in mse_values
]

# NOTE: the column is called "MSE" but holds the square roots computed above.
df = pd.DataFrame(data, columns=["Model Type", "Quartile", "MSE"])
model_rename = {
    "base model": "baseline",
    "updated model": "updated",
    "inverted transfer function model": "inverted",
    "measured": "measured",
}
# Rename the model types
df["Model Type"] = df["Model Type"].map(model_rename)

# Total width taken by the boxes of one quartile. It is shared by the boxplot call and
# by the manual box-centre computation below so the brackets stay aligned with the boxes.
group_width = 0.8

# Plotting the data using seaborn
plt.figure(figsize=(3.2, 1.6))
ax = sns.boxplot(
    x="Quartile",
    y="MSE",
    hue="Model Type",
    data=df,
    flierprops=dict(
        marker="o",
        markersize=3,
        markeredgecolor="none",
        linestyle="none",
        markerfacecolor="black",
    ),
    linewidth=1,
    width=group_width,
)

plt.xlabel("design success quartile")
plt.ylabel("rmsd to target")
ymax = df["MSE"].max()
# Leave head room above the data for the significance brackets.
plt.ylim([0, ymax + 0.2])
plt.legend(loc=[1.05, 0], ncol=1, fontsize=7)
plt.tight_layout()

# Compute box positions manually
xticks = ax.get_xticks()  # Positions of the quartiles on the x-axis

# Get unique quartiles and model types (hue order follows first appearance in df)
quartiles = sorted(df["Quartile"].unique())
model_types = list(df["Model Type"].unique())
n_hue_levels = len(model_types)

# Centre of each box relative to its quartile tick, assuming the boxes of one
# quartile split group_width evenly.
width_per_box = group_width / n_hue_levels
offsets = np.linspace(
    -group_width / 2 + width_per_box / 2,
    group_width / 2 - width_per_box / 2,
    n_hue_levels,
)

# Map (model_type, quartile) to the x position of the corresponding box
box_map = {
    (model_type, quartile): xticks[quartile_index] + offsets[hue_index]
    for quartile_index, quartile in enumerate(quartiles)
    for hue_index, model_type in enumerate(model_types)
}

# Height at which the brackets of each annotated model are drawn. Only the models
# listed here get brackets; "measured" is deliberately absent.
y_by_model = {
    "baseline": 0.45,
    "updated": 0.52,
    "inverted": 0.59,
}

# Compute p-values between adjacent quartiles for each annotated model type
annotations = []
for quartile1, quartile2 in zip(quartiles[:-1], quartiles[1:]):
    for model_type in y_by_model:
        is_model = df["Model Type"] == model_type
        group1 = df.loc[is_model & (df["Quartile"] == quartile1), "MSE"]
        group2 = df.loc[is_model & (df["Quartile"] == quartile2), "MSE"]
        _, p = stats.mannwhitneyu(group1, group2, alternative="two-sided")
        annotations.append(
            {
                "model_type": model_type,
                "quartile1": quartile1,
                "quartile2": quartile2,
                "pval": p,
            }
        )

# Add the significance annotations
h = 0.01  # Height of the vertical bracket ends
for ann in annotations:
    model_type = ann["model_type"]

    # Pull the bracket ends slightly inwards from the two box centres.
    x1 = box_map[(model_type, ann["quartile1"])] + 0.05
    x2 = box_map[(model_type, ann["quartile2"])] - 0.05
    y = y_by_model[model_type]

    # Draw the bracket
    ax.plot([x1, x1, x2, x2], [y, y + h, y + h, y], lw=1, c="k")

    # Add the significance symbol centred above the bracket
    ax.text(
        (x1 + x2) / 2,
        y + h,
        significance_symbol(ann["pval"]),
        ha="center",
        va="bottom",
        color="k",
        fontsize=7,
    )

# `fmt` rather than `format`, so the builtin is not shadowed in the notebook namespace.
for fmt in ["png", "svg"]:
    plt.savefig(
        os.path.join(plot_folder, f"9.9.3_mse_boxplot_significance.{fmt}"),
        bbox_inches="tight",
        dpi=300,
    )


## 9.9.3 - Plot violin plots for the difference between predicted and measured values

In [None]:
#%%capture output
# One violin plot per (model, design key) showing measured - predicted for each cell line.
plot_folder = os.path.join(base_plot_folder, "9.9.3_accuracy")
os.makedirs(plot_folder, exist_ok=True)

labels = ["original", "updated", "inverted"]
for label, mse_df in zip(labels, [mse_designs_original, mse_designs_updated, mse_designs_inverted]):
    for key in mse_df.keys():
        df = mse_df[key].copy()
        df = df[df["type"] != "random_target"]

        # "subset" wins when a key contains both markers.
        # NOTE(review): a key containing neither marker keeps the selection of the
        # previous iteration -- confirm that such keys cannot occur.
        if "subset" in key:
            curr_cell_meas = cell_lines_subset_UTR
            curr_cell_pred = cell_lines_subset_pred
            curr_cell_set = cell_lines_subset
        elif "all" in key:
            curr_cell_meas = cell_lines_measured_UTR
            curr_cell_pred = cell_lines_measured_pred
            curr_cell_set = cell_lines_measured

        # Residuals, with one column per cell line (positional match of the column lists).
        diff_df = pd.DataFrame(
            df[curr_cell_meas].values - df[curr_cell_pred].values,
            columns=curr_cell_set,
        )

        plt.figure(figsize=(2.4, 1.65))
        # create a horizontal line at y = 0
        plt.axhline(y=0, color='darkgrey', linestyle='-', linewidth=1.5, label="prediction goal", zorder=1)
        # Look the colours up by cell-line name so every violin gets its own cell line's
        # colour regardless of the ordering or size of cell_line_colors.
        violin_palette = [cell_line_colors[cell_line] for cell_line in curr_cell_set]
        sns.violinplot(diff_df, palette=violin_palette, inner="quart", linewidth=0.75, width=0.7, zorder=2)

        plt.ylabel("measured - predicted")
        plt.tick_params(axis="x", rotation=90)
        plt.ylim(-1, 1)
        plt.yticks([-1, -0.5, 0, 0.5, 1])

        plt.tight_layout()
        plt.legend()
        # `fmt` rather than `format`, so the builtin is not shadowed.
        for fmt in ["png", "svg"]:
            plt.savefig(os.path.join(plot_folder, f"{label}_model_{key}_binary.{fmt}"), dpi=300)

In [None]:
#%%capture output
# One violin plot per design key comparing the pooled residuals (measured - predicted,
# all cell lines stacked together) of the three models.
plot_folder = os.path.join(base_plot_folder, "9.4_accuracy")
os.makedirs(plot_folder, exist_ok=True)

# do the variables exist?
required_names = ("mse_designs_original", "mse_designs_updated", "mse_designs_inverted")
defined_names = locals()
if all(name in defined_names for name in required_names):
    # (column label in the plot, name used in the printed summary, designs per key)
    model_specs = [
        ("baseline", "Original", mse_designs_original),
        ("updated", "Updated", mse_designs_updated),
        ("inverted", "Inverted", mse_designs_inverted),
    ]

    for key in mse_designs_inverted.keys():
        # "subset" wins when a key contains both markers.
        # NOTE(review): a key containing neither marker keeps the selection of the
        # previous iteration -- confirm that such keys cannot occur.
        if "subset" in key:
            curr_cell_meas = cell_lines_subset_UTR
            curr_cell_pred = cell_lines_subset_pred
            curr_cell_set = cell_lines_subset
        elif "all" in key:
            curr_cell_meas = cell_lines_measured_UTR
            curr_cell_pred = cell_lines_measured_pred
            curr_cell_set = cell_lines_measured

        # Same pipeline for every model: drop random targets, take residuals per cell
        # line, then stack all cell lines into one long series.
        residuals = {}
        for column_label, _, designs in model_specs:
            designs_df = designs[key]
            designs_df = designs_df[designs_df["type"] != "random_target"]
            diff_df = pd.DataFrame(
                designs_df[curr_cell_meas].values - designs_df[curr_cell_pred].values,
                columns=curr_cell_set,
            )
            residuals[column_label] = diff_df.stack().reset_index(drop=True)

        # Summary statistics are only reported for the "all" keys.
        if "all" in key:
            print("All cell lines")
            for column_label, printed_name, _ in model_specs:
                series = residuals[column_label]
                print("{} Series - Mean: {:.2f}, Std Dev: {:.2f}".format(printed_name, series.mean(), series.std()))

            # what share of values lies within 1 standard deviation of the mean?
            # (printed as a fraction between 0 and 1)
            print("Percentage of values within 1 standard deviation")
            for column_label, printed_name, _ in model_specs:
                series = residuals[column_label]
                mean, std = series.mean(), series.std()
                within_std = ((series > mean - std) & (series < mean + std)).mean()
                print("{} Series: {:.2f}".format(printed_name, within_std))

        diff_df_all = pd.DataFrame(residuals)

        plt.figure(figsize=(1.6, 1.65))

        sns.violinplot(diff_df_all, inner="quart", linewidth=0.75)

        plt.ylabel("measured - predicted")
        plt.tick_params(axis="x", rotation=90)
        plt.ylim(-1, 1)
        plt.yticks([-1, -0.5, 0, 0.5, 1])

        # plot a grey line at 0
        plt.axhline(y=0, color='darkgrey', linestyle='-', linewidth=1.5, label="prediction goal", zorder=0)

        plt.tight_layout()
        # `fmt` rather than `format`, so the builtin is not shadowed.
        for fmt in ["png", "svg"]:
            plt.savefig(os.path.join(plot_folder, f"all_model_{key}_binary.{fmt}"), dpi=300)

### Calculate correlation values for each cell line

In [None]:
#%%capture output
# Measured-vs-predicted scatter plot per cell line, plus the squared correlation
# between measured and predicted values per cell line for every model.
plot_folder = os.path.join(base_plot_folder, "9.9.3_accuracy")
os.makedirs(plot_folder, exist_ok=True)

labels = ["original", "updated", "inverted"]
corrs = {}
for label, mse_df_curr in zip(labels, [mse_designs_original, mse_designs_updated, mse_designs_inverted]):
    for key in mse_df_curr.keys():
        # NOTE(review): unlike the violin-plot cells, "random_target" rows are not
        # filtered out here -- confirm this is intended.
        df = mse_df_curr[key].copy()

        # "subset" wins when a key contains both markers.
        # NOTE(review): a key containing neither marker keeps the selection of the
        # previous iteration -- confirm that such keys cannot occur.
        if "subset" in key:
            curr_cell_meas = cell_lines_subset_UTR
            curr_cell_pred = cell_lines_subset_pred
            curr_cell_set = cell_lines_subset
        elif "all" in key:
            curr_cell_meas = cell_lines_measured_UTR
            curr_cell_pred = cell_lines_measured_pred
            curr_cell_set = cell_lines_measured

        # Relabel both tables with the bare cell-line names. The three lists are built
        # from the same cell-line list in the same order, so a positional relabel is
        # exact and, unlike splitting on "_", also works for names containing "_".
        df_meas = df[curr_cell_meas].copy()
        df_meas.columns = list(curr_cell_set)
        df_pred = df[curr_cell_pred].copy()
        df_pred.columns = list(curr_cell_set)

        # create a scatter plot of the measured vs predicted values for each cell line
        for cell_line in curr_cell_set:
            plt.figure(figsize=(1.6, 1.6))
            plt.scatter(df_meas[cell_line], df_pred[cell_line], color=cell_line_colors[cell_line], s=1)
            # identity line: perfect prediction
            plt.plot([0, 1], [0, 1], color="black", linestyle="--")
            plt.xlabel("measured")
            plt.ylabel("predicted")
            plt.xlim(0, 1)
            plt.ylim(0, 1)
            plt.tight_layout()
            # `fmt` rather than `format`, so the builtin is not shadowed.
            for fmt in ["png", "svg"]:
                plt.savefig(os.path.join(plot_folder, f"{label}_model_{key}_{cell_line}.{fmt}"), dpi=300)
            plt.close()

        # squared correlation between the measured and predicted values per cell line
        corr = df_meas.corrwith(df_pred, axis=0)**2
        # only the "all" keys are collected for the summary table
        if "all" in key:
            corrs[label] = corr

# create a dataframe with the correlations (rows: cell lines, columns: models)
corrs_df = pd.DataFrame(corrs)
corrs_df