In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import itertools
import random
import os
from lib.transfer_functions import transfer_function
from lib.additive_model import add_mirna_combs
from lib.design_utilities import tsi, calculate_quality
import warnings
import ast

# silence pandas PerformanceWarning messages for the whole notebook
warnings.simplefilter(action='ignore', category=pd.errors.PerformanceWarning)

# cell-line groups: the six "subset" lines, the four others, and the four listed as measured
cell_lines_subset = ["HEK293T", "HeLa", "SKNSH", "MCF7", "HUH-7", "A549"]
cell_lines_other = ["HaCaT", "JEG-3", "Tera-1", "PC-3"]
cell_lines_measured = ["HEK293T", "HeLa", "SKNSH", "MCF7"]
cell_lines = [*cell_lines_subset, *cell_lines_other]

# every figure of this notebook is written below this folder
plot_folder = "../plots/9_create_targeted_/designs/"
if not os.path.exists(plot_folder):  # create it on the first run
    os.makedirs(plot_folder)

### Here, we use a genetic algorithm to create designs with targeted stability patterns. Designs are made either for a subset of six cell lines or for all ten cell lines.

# Load the miRNA data

In [3]:
# load microRNA data for the studied cell lines; the values are raised to the power of 10
# below, i.e. the file is taken to hold log10 expression values
mirna_data_filter = 10 ** (
    pd.read_csv(
        "../input_data/miRNA_expression_data/1_output/1.10_alles_quantile_crosstalk_filter.csv",
        index_col=0,
    ).loc[:, cell_lines]
)

In [None]:
# filter out miRNAs whose target sites contain unwanted sequence motifs
# BsaI recognition site in both orientations (GAGACC is the reverse complement of GGTCTC)
restriction_sites = ["GAGACC", "GGTCTC"]
# polyadenylation signal hexamers
polyA_signals = ["AATAAA", "ATTAAA"]
filter_motifs = restriction_sites + polyA_signals

mirbase_df = pd.read_csv("../input_data/mirbase_with_families_and_targets.csv", index_col=0)
mirbase_df = mirbase_df[mirbase_df["confidence"] == "high"]
# only the target sites are needed here
mirbase_df = mirbase_df[["target"]]

# NOTE(review): matching is case-sensitive and uses DNA letters (T, not U) -- confirm the
# 'target' column is stored that way, otherwise nothing gets filtered.
motif_pattern = "|".join(filter_motifs)
# na=False treats a missing target sequence as motif-free (same outcome as the former `== True`)
has_motif = mirbase_df["target"].str.contains(motif_pattern, regex=True, na=False)
forbidden = mirbase_df.index[has_motif.to_numpy()]
# print miRNAs that contain unwanted motifs in their target sequence
print(forbidden)
mirna_data_filter = mirna_data_filter[~mirna_data_filter.index.isin(forbidden)]

In [None]:
# fitted transfer-function parameters: skip the first line, then read c1 and c2 (tab-separated)
filename = "../output/1_output/1.10_fit_parameters_without_scales.txt"
with open(filename, "r") as file:
    file.readline()  # skip the first line
    line = file.readline().split("\t")
c1, c2 = float(line[0]), float(line[1])

# the parameters are stored as log10 values; undo that before applying the transfer function
popt_filter = np.power(10, np.array([c1, c2]))
# column-wise (per cell line) stability of every single miRNA
stability = mirna_data_filter.apply(lambda column: transfer_function(column, *popt_filter))

# Code for visualization

In [32]:
def create_heatmap(df, title, split_indices=None, filename="heatmap.png", annot=False, loss="lin", adaptive_cbar=False,
                    sublabel=False, cmap="rocket", designs_per_cell_line=5, label_colorbar="rel. expression"):
    """Draw `df` as a seaborn heatmap and save it to f"{plot_folder}{filename}" at 300 dpi.

    Parameters
    ----------
    df : DataFrame
        Values to draw; each row becomes a heatmap row, each column a heatmap column
        (callers pass design tables transposed, so columns are designs).
    title : str
        Figure title.
    split_indices : ignored
        Unused; kept only so existing calls stay valid.
    filename : str
        Output path relative to the global `plot_folder`.
    annot : bool
        Write the cell values into the heatmap.
    loss : {"lin", "log"}
        Fixed colour range: "lin" -> [0, 1], "log" -> [-1.3, 0]. Ignored if `adaptive_cbar`.
    adaptive_cbar : bool
        Use the data's own min/max as colour range instead of the `loss` range.
    sublabel : False or sequence of str
        Group labels placed at y=-0.5 (just above the first row), centred over consecutive
        groups of `designs_per_cell_line` columns, with white vertical separators between groups.
    cmap : str
        Colormap name passed to seaborn.
    designs_per_cell_line : int
        Number of columns per sublabel group.
    label_colorbar : str
        Colorbar label.

    Raises
    ------
    ValueError
        If `adaptive_cbar` is False and `loss` is neither "lin" nor "log".
    """
    # pick the colour range first so an unsupported `loss` fails before a figure is opened
    if adaptive_cbar:
        vmin, vmax = df.min().min(), df.max().max()
    elif loss == "lin":
        vmin, vmax = 0, 1
    elif loss == "log":
        vmin, vmax = -1.3, 0  # -1.3 ~= log10(0.05)
    else:
        raise ValueError(f"loss must be 'lin' or 'log' when adaptive_cbar is False, got {loss!r}")

    plt.figure(figsize=(8, 5))
    plt.rcParams.update({'font.size': 10})
    plt.title(title, y=1.12, fontsize=12)

    ax = sns.heatmap(df, cmap=cmap, vmin=vmin, vmax=vmax, annot=annot, cbar_kws={"shrink": .95})

    # label and restyle the colorbar
    cbar = ax.collections[0].colorbar
    cbar.set_label(label_colorbar, rotation=270, labelpad=20, fontsize=8)
    cbar.ax.tick_params(length=2, width=1)
    # individual column ticks are hidden; groups are identified by the sublabels instead
    ax.set_xticks([])
    ax.set_xticklabels([])
    ax.set_xlabel("")

    if sublabel is not False:
        ax.tick_params(axis='both', which='both', length=2, width=1)
        ax.set_ylabel("")
        # centre each group label over its block of `designs_per_cell_line` columns
        for i, label in enumerate(sublabel):
            plt.text(designs_per_cell_line * (i + 0.5), -0.5, label,
                     fontsize=10, ha='center', va='center', rotation=25)
        # white separators between neighbouring groups
        for i in range(1, len(sublabel)):
            plt.axvline(x=designs_per_cell_line * i, color='white', linewidth=2)

    # save to file
    plt.tight_layout()
    plt.savefig(f"{plot_folder}{filename}", dpi=300, bbox_inches='tight')

In [33]:
def scatter_ought_vs_is(df, filename, title, mode="lin"):
    """Scatter target ("ought", x) against achieved ("is", y) expression for all designs.

    df: design table with one column per cell line and a 'target' column holding the string
        repr of a {cell_line: target_value} dict; the cell lines are taken from the first row.
    filename: path (without extension) relative to `plot_folder`; saved as PNG at 300 dpi.
    title: figure title.
    mode: "log" plots log10 values on both axes with a [-1.3, 0] diagonal; otherwise linear
        values with a [0, 1] diagonal.
    """
    used_cell_lines = list(ast.literal_eval(df["target"].iloc[0]).keys())
    all_targets = []
    all_predictions = []
    for _, row in df.iterrows():
        target = ast.literal_eval(row["target"])
        # look targets up by cell line so each one is paired with the matching prediction,
        # independent of the key order stored in this row's target string
        target_values = [target[cell_line] for cell_line in used_cell_lines]
        predicted = row[used_cell_lines].to_list()
        if mode == "log":
            target_values = [np.log10(x) for x in target_values]
            predicted = [np.log10(x) for x in predicted]
        all_targets.extend(target_values)
        all_predictions.extend(predicted)

    plt.figure(figsize=(5, 5))
    plt.scatter(all_targets, all_predictions, s=10)
    # identity line: points on it hit their target exactly
    if mode == "log":
        plt.plot([-1.3, 0], [-1.3, 0], color="black")
    else:
        plt.plot([0, 1], [0, 1], color="black")
    plt.xlabel("ought")
    plt.ylabel("is")
    plt.title(title)
    plt.savefig(f"{plot_folder}{filename}.png", dpi=300, bbox_inches="tight")

In [34]:
def plot_quality(df, filename=False):
    """Bar plot of the 'quality' column of a design table, one bar per design.

    Assumes `df` has the columns 'quality', 'target' (one string per design) and 'sublabel'
    (string repr of a list of group labels). The block size `target_counts` is the count of the
    most frequent target; each tab10 colour is repeated that many times, so contiguous designs of
    one target share a colour (assuming designs are grouped by target).

    If `filename` is False the figure is shown; otherwise it is saved to
    f"{plot_folder}quality/{filename}" (the folder is created if missing).
    """
    plt.figure(figsize=(5, 4))
    plt.rcParams.update({'font.size': 9})
    if filename is False:
        plt.title("Quality of designs", y=1.05, fontsize=9)
    else:
        plt.title(filename.split(".")[0], y=1.05, fontsize=9)
    # count how often the most frequent target is used; .iloc makes the positional lookup
    # explicit (Series[0] on a string index is deprecated in pandas 2)
    target_counts = df["target"].value_counts().iloc[0]
    sublabel = ast.literal_eval(df["sublabel"].iloc[0])

    # get a palette of individual colors and repeat every color target_counts times
    palette = sns.color_palette("tab10")
    palette = list(itertools.chain.from_iterable(itertools.repeat(x, target_counts) for x in palette))

    sns.barplot(x=df.index, y=df["quality"], palette=palette)
    plt.xticks([])

    # add the sublabels next to every target_counts-th bar
    if sublabel is not False:
        for i in range(len(sublabel)):
            plt.text(target_counts*(i+0.4), -1.8, sublabel[i],
                     fontsize=8, ha='center', va='center', rotation=90)
    if filename is False:
        plt.show()
    else:
        # only plot_folder itself is created at start-up, not its quality/ sub-folder
        os.makedirs(f"{plot_folder}quality", exist_ok=True)
        plt.savefig(f"{plot_folder}quality/{filename}", dpi=300, bbox_inches='tight')

def plot_quality_all(df):
    """Call plot_quality for every table in a nested {outer_key: {inner_key: DataFrame}} dict.

    Each figure is named f"quality_{inner_key}_{outer_key}.png".
    """
    for outer_key, tables in df.items():
        for inner_key, table in tables.items():
            plot_quality(table, filename=f"quality_{inner_key}_{outer_key}.png")

In [35]:
def visualize_off_on_performance(df, cell_lines_used):
    """Scatter each design's expression per cell line, grouped by target state (off=0 / on=1).

    df: design table with one column per cell line and a 'target' column holding the string
        repr of a {cell_line: value} dict; a target value of exactly 1 counts as "on".
    cell_lines_used: cell lines to plot; every design gets its own Set3 colour (cycled).
    """
    plt.figure(figsize=(8, 5))
    plt.rcParams.update({'font.size': 10})
    # get the Set3 color palette from matplotlib
    colors = plt.get_cmap("Set3")

    # enumerate the rows: the index holds design names (strings), which a colormap cannot take
    for design_number, (_, row) in enumerate(df.iterrows()):
        target = ast.literal_eval(row["target"])
        x_vals = np.zeros(len(cell_lines_used))
        y_vals = np.zeros(len(cell_lines_used))
        for i, cell_line in enumerate(cell_lines_used):
            if target[cell_line] == 1:
                x_vals[i] = 1  # "on" cell lines go to x=1 (was `==`, a no-op comparison)
            y_vals[i] = row[cell_line]
        plt.scatter(x_vals, y_vals, color=colors(design_number % colors.N))

    plt.xticks([0, 1], ["off", "on"])
    plt.xlabel("target")
    plt.ylabel("relative expression")
    plt.show()

# Define the evolutionary algorithm

In [36]:
def calculate_mse(df, mse_target, loss_emphasis):
    """Weighted mean squared error of every design against a target pattern.

    df: expression values, one row per design and one column per cell line.
    mse_target: {cell_line: target_value} mapping (aligned to df's columns by pandas).
    loss_emphasis: {cell_line: weight} mapping multiplying each squared error.
    Returns a Series indexed like df: mean over cell lines of weight * (value - target)**2.
    """
    residuals = df - mse_target
    weighted_squares = (residuals ** 2).mul(loss_emphasis, axis=1)
    return weighted_squares.mean(axis=1)

def calculate_fitness(pop, expression, cell_line, loss_emphasis=None, which="quality", mse_target=None):
    """Fitness of every design in `pop` (higher is better).

    pop: list of designs (tuples of miRNAs) passed to add_mirna_combs.
    expression: miRNA expression table for the cell lines being designed for.
    cell_line: target cell line, only used by the 'quality' objective.
    loss_emphasis: {cell_line: weight} for the MSE objectives; None or empty means weight 1
        for every cell line.
    which: 'quality' -> calculate_quality on expression + TSI;
           'mse' -> 1 / weighted MSE against `mse_target`;
           'mse-log' -> 1 / weighted MSE of log10 values.
    mse_target: {cell_line: target_value}; required for the MSE objectives.

    Raises ValueError for any other `which`.
    """
    if which not in ("quality", "mse", "mse-log"):
        raise ValueError(f"which must be 'quality', 'mse' or 'mse-log', got {which!r}")

    # Calculate the stability levels for the designs in the population according to the additive model.
    # NOTE(review): transfer_function is applied with its default parameters here, whereas the
    # `stability` table of this notebook uses the fitted popt_filter -- confirm the defaults match.
    add_expr = add_mirna_combs(expression, pop).apply(transfer_function)

    # default: equal weight for every cell line
    if loss_emphasis is None or len(loss_emphasis) == 0:
        loss_emphasis = {column: 1 for column in add_expr.columns}

    if which == "quality":
        add_expr["tsi"] = tsi(add_expr.to_numpy())
        return calculate_quality(add_expr, cell_line)

    if which == "mse":
        return 1 / calculate_mse(add_expr, mse_target, loss_emphasis)

    # which == "mse-log": compare on log10 scale
    mse_target_log = {key: np.log10(value) for key, value in mse_target.items()}
    return 1 / calculate_mse(np.log10(add_expr), mse_target_log, loss_emphasis)

def evaluate_fitness(pop, expression, cell_line, loss_emphasis={}, which="quality", mse_target=[]):
    """Predicted expression of each design in `pop`, plus its fitness in a 'quality' column.

    The expression columns are add_mirna_combs(expression, pop) passed through transfer_function;
    'quality' is whatever calculate_fitness returns for `which`. The default arguments are only
    forwarded, never mutated.
    """
    scored = add_mirna_combs(expression, pop).apply(transfer_function)
    fitness = calculate_fitness(pop, expression, cell_line,
                                loss_emphasis=loss_emphasis, which=which, mse_target=mse_target)
    scored["quality"] = fitness
    return scored

def drop_duplicates(df):
    """Drop designs that contain the same miRNAs in a different order.

    Assumes the index entries are tuples of microRNAs; two designs are duplicates when their
    sorted tuples are equal. Returns a tuple of
      - a new frame keeping the first occurrence of each design (columns unchanged), and
      - a boolean Series (indexed like `df`) marking every row that has a duplicate.
    The input frame is left unmodified (no helper column is added to it).
    """
    canonical = pd.Series([tuple(sorted(design)) for design in df.index], index=df.index)
    duplicates = canonical.duplicated(keep=False)
    # positional mask avoids index alignment, which matters if the index itself repeats
    keep_mask = ~canonical.duplicated(keep="first").to_numpy()
    return df.loc[keep_mask], duplicates

def select_parents(fitnesses, tournament_size=3):
    """Select two parents from the population using tournament selection.

    fitnesses: Series of fitness values (higher is better) indexed by design.
    tournament_size: individuals sampled without replacement per tournament; clamped to the
        population size so small populations do not make `sample` raise.
    Returns a tuple with the index labels of the two tournament winners.
    """
    size = min(tournament_size, len(fitnesses))
    parents = []
    for _ in range(2):  # Select two parents
        tournament = fitnesses.sample(size)
        # the fittest individual of the tournament wins
        parents.append(tournament.sort_values(ascending=False).index[0])
    return tuple(parents)

def crossover(parent1, parent2, n):
    """Single-point crossover: the head of parent1 joined to the tail of parent2.

    The cut point is drawn uniformly from 0..n-1, so for parents of length n the child always
    keeps at least the last position of parent2 (a cut at 0 copies parent2 entirely).
    """
    cut = random.randint(0, n - 1)
    head = parent1[:cut]
    tail = parent2[cut:]
    return head + tail

def mutate(child, miRNAs, n, mutation_rate=0.2):
    """With probability `mutation_rate`, replace one random position of `child`.

    child: sequence of miRNAs (length n).
    miRNAs: pool to draw the replacement from.
    n: number of positions; the replaced position is drawn from 0..n-1.
    mutation_rate: mutation probability per call (default 0.2, the previous fixed rate).
    Always returns a tuple.
    """
    if random.random() < mutation_rate:
        position = random.randint(0, n - 1)
        mutated = list(child)
        mutated[position] = random.choice(miRNAs)
        return tuple(mutated)
    return tuple(child)

# -----------------------------------------------------------------------

def determine_mirna_usage(df):
    """Count how often each miRNA occurs across all designs in `df`'s index.

    Every occurrence counts, including repeats within one design. Returns a dict ordered by
    descending count; ties keep the order in which miRNAs were first seen.
    """
    usage = {}
    for mirna in itertools.chain.from_iterable(df.index.tolist()):
        usage[mirna] = usage.get(mirna, 0) + 1
    # sort dict by value
    ranked = sorted(usage.items(), key=lambda pair: pair[1], reverse=True)
    return dict(ranked)

def count_mirnas_per_design(df):
    """Count in how many designs each miRNA appears.

    `df` is a design table whose index entries are tuples of miRNAs. A miRNA used several times
    in one design is counted once for that design. Returns a DataFrame with a single 'count'
    column, indexed by miRNA and sorted by descending count.
    """
    designs_containing = {}
    for design in df.index.tolist():
        # dict.fromkeys keeps first-appearance order while dropping repeats within the design
        for mirna in dict.fromkeys(design):
            designs_containing[mirna] = designs_containing.get(mirna, 0) + 1

    counts_df = pd.DataFrame.from_dict(designs_containing, orient='index', columns=['count'])
    return counts_df.sort_values(by=['count'], ascending=False)

In [37]:
def generate_genetic_design(target, n_mirnas, mirnas, mirna_expression, loss_emphasis={},
                            no_designs=10, cell_line="none", loss="mse", generations=30, population_size=300):
    """Evolve combinations of `n_mirnas` miRNAs towards `target` with a simple genetic algorithm.

    Starts from `population_size` random designs drawn (with replacement) from `mirnas`. For
    `generations` rounds, every child is built by tournament selection, single-point crossover
    and mutation. The final population is scored with evaluate_fitness, order-insensitive
    duplicates are dropped, and the `no_designs` designs with the highest 'quality' are returned.

    loss selects the objective passed to calculate_fitness ('quality', 'mse' or 'mse-log');
    cell_line is only used by the 'quality' objective.
    """
    def random_design():
        return tuple(random.choice(mirnas) for _ in range(n_mirnas))

    population = [random_design() for _ in range(population_size)]

    for _generation in range(generations):
        fitnesses = calculate_fitness(population, mirna_expression, cell_line,
                                      loss_emphasis=loss_emphasis, which=loss, mse_target=target)
        # selection -> crossover -> mutation, one child at a time
        population = [
            mutate(crossover(*select_parents(fitnesses), n_mirnas), mirnas, n_mirnas)
            for _ in range(population_size)
        ]

    final_scores = evaluate_fitness(population, mirna_expression, cell_line,
                                    loss_emphasis=loss_emphasis, which=loss, mse_target=target)
    unique_designs, _ = drop_duplicates(final_scores)
    return unique_designs.sort_values(by=['quality'], ascending=False).head(no_designs)

def add_numbered_index(df, base_name):
    """Move the (multi-)index of microRNAs into columns and number the designs.

    The new index is f"{base_name}_1", f"{base_name}_2", ... in row order.
    """
    labels = [f"{base_name}_{position}" for position in range(1, len(df) + 1)]
    return df.reset_index().set_axis(labels, axis=0)

# Quality designs

These designs use a different objective function from the MSE-based designs, namely the 'quality' of a design. Quality is only defined for designs that should be active in a single target cell line: quality = expression(target) × TSI(expression), where TSI is the tissue-specificity index.

In [38]:
def generate_quality_designs(n_mirnas, mirnas, mirna_expression, base_name, used_cell_lines,
                             designs_per_cell_line=5, increase_diversity=1):
    """Run the quality-objective GA once per cell line in `used_cell_lines`.

    For each cell line the target is 1 there and 0 in every other used cell line. The GA runs
    `increase_diversity` rounds, each returning int(designs_per_cell_line / increase_diversity)
    designs. After each round the two miRNAs found in the most designs are removed from the
    candidate pool, pushing later rounds towards different combinations.

    Returns all designs concatenated, with 'target' (str of the target dict) and
    'type' (= base_name) columns added.
    """
    designs_by_cell_line = {}
    for cell_line in used_cell_lines:
        target = dict.fromkeys(used_cell_lines, 0)
        target[cell_line] = 1
        candidate_mirnas = mirnas.copy()
        rounds = []
        for _round in range(increase_diversity):
            designs = generate_genetic_design(
                target=target,
                cell_line=cell_line,
                n_mirnas=n_mirnas,
                mirnas=candidate_mirnas,
                mirna_expression=mirna_expression,
                no_designs=int(designs_per_cell_line / increase_diversity),
                loss="quality",
            )
            designs["target"] = str(target)
            designs["type"] = base_name

            most_used = count_mirnas_per_design(designs).head(2).index.to_list()
            candidate_mirnas = [mirna for mirna in candidate_mirnas if mirna not in most_used]
            rounds.append(designs)

        designs_by_cell_line[cell_line] = pd.concat(rounds)

    return pd.concat(list(designs_by_cell_line.values()))

## miRNA_full_subset_quality

In [39]:
# List of microRNAs and their impacts
# shared inputs for the quality designs below
mirna_expression = mirna_data_filter  # alias, not a copy
miRNAs = mirna_data_filter.index.tolist()
designs_per_cell_line = 9
# GA rounds per cell line; designs_per_cell_line is split evenly across them
increase_diversity = 3

In [41]:
%%capture output

# Generate designs
n = 4
quality_designs_4 = generate_quality_designs(n, miRNAs,
                    mirna_expression[cell_lines_subset],
                    "quality",
                    used_cell_lines=cell_lines_subset,
                    designs_per_cell_line=designs_per_cell_line,
                    increase_diversity=increase_diversity)

quality_designs_4[cell_lines_subset] = quality_designs_4[cell_lines_subset].astype("float")
quality_designs_4 = add_numbered_index(quality_designs_4, base_name="24_miRNA_full_subset_quality_AND4")
quality_designs_4.to_csv(f"../designs/24_miRNA_full_subset_quality_AND4.csv")
create_heatmap(quality_designs_4[cell_lines_subset].T, title="24_AND4_quality_heatmap",
              filename="24_miRNA_full_subset_quality_AND4.png", sublabel=cell_lines_subset, designs_per_cell_line=designs_per_cell_line)                    

In [None]:
%%capture output

# Generate designs
n = 5
quality_designs_5 = generate_quality_designs(n, miRNAs,
                    mirna_expression[cell_lines_subset],
                    "quality",
                    used_cell_lines=cell_lines_subset,
                    designs_per_cell_line=designs_per_cell_line,
                    increase_diversity=increase_diversity)

quality_designs_5[cell_lines_subset] = quality_designs_5[cell_lines_subset].astype("float")
quality_designs_5 = add_numbered_index(quality_designs_5, base_name="25_miRNA_full_subset_quality_AND5")
quality_designs_5.to_csv(f"../designs/25_miRNA_full_subset_quality_AND5.csv")
create_heatmap(quality_designs_5[cell_lines_subset].T, title="25_AND5_quality_heatmap",
               filename="25_miRNA_full_subset_quality_AND5.png", sublabel=cell_lines_subset, designs_per_cell_line=designs_per_cell_line)

In [None]:
%%capture output

# Generate designs
n = 6
quality_designs_6 = generate_quality_designs(n, miRNAs,
                    mirna_expression[cell_lines_subset],
                    "quality",
                    used_cell_lines=cell_lines_subset,
                    designs_per_cell_line=designs_per_cell_line,
                    increase_diversity=increase_diversity)

quality_designs_6[cell_lines_subset] = quality_designs_6[cell_lines_subset].astype("float")
quality_designs_6 = add_numbered_index(quality_designs_6, base_name="26_miRNA_full_subset_quality_AND6")
quality_designs_6.to_csv(f"../designs/26_miRNA_full_subset_quality_AND6.csv")
create_heatmap(quality_designs_6[cell_lines_subset].T, title="26_AND6_quality_heatmap",
               filename="26_miRNA_full_subset_quality_AND6.png", sublabel=cell_lines_subset, designs_per_cell_line=designs_per_cell_line)

## miRNA_full_quality

In [None]:
%%capture output

# Generate designs
n = 4
quality_designs_4 = generate_quality_designs(n, miRNAs,
                    mirna_expression[cell_lines],
                    "quality",
                    used_cell_lines=cell_lines,
                    designs_per_cell_line=designs_per_cell_line,
                    increase_diversity=increase_diversity)

quality_designs_4[cell_lines] = quality_designs_4[cell_lines].astype("float")
quality_designs_4 = add_numbered_index(quality_designs_4, base_name="27_miRNA_full_quality_AND4")
quality_designs_4.to_csv(f"../designs/27_miRNA_full_quality_AND4.csv")
create_heatmap(quality_designs_4[cell_lines].T, title="27_AND4_quality_heatmap",
              filename="27_miRNA_full_quality_AND4.png", sublabel=cell_lines, designs_per_cell_line=designs_per_cell_line)  

In [None]:
%%capture output

# Generate designs
n = 5
quality_designs_5 = generate_quality_designs(n, miRNAs,
                    mirna_expression[cell_lines],
                    "quality",
                    used_cell_lines=cell_lines,
                    designs_per_cell_line=designs_per_cell_line,
                    increase_diversity=increase_diversity)

quality_designs_5[cell_lines] = quality_designs_5[cell_lines].astype("float")
quality_designs_5 = add_numbered_index(quality_designs_5, base_name="28_miRNA_full_quality_AND5")
quality_designs_5.to_csv(f"../designs/28_miRNA_full_quality_AND5.csv")
create_heatmap(quality_designs_5[cell_lines].T, title="28_AND5_quality_heatmap",
              filename="28_miRNA_full_quality_AND5.png", sublabel=cell_lines, designs_per_cell_line=designs_per_cell_line)  

In [None]:
%%capture output

# Generate designs
n = 6
quality_designs_6 = generate_quality_designs(n, miRNAs,
                    mirna_expression[cell_lines],
                    "quality",
                    used_cell_lines=cell_lines,
                    designs_per_cell_line=designs_per_cell_line,
                    increase_diversity=increase_diversity)

quality_designs_6[cell_lines] = quality_designs_6[cell_lines].astype("float")
quality_designs_6 = add_numbered_index(quality_designs_6, base_name="29_miRNA_full_quality_AND6")
quality_designs_6.to_csv(f"../designs/29_miRNA_full_quality_AND6.csv")
create_heatmap(quality_designs_6[cell_lines].T, title="29_AND6_quality_heatmap",
              filename="29_miRNA_full_quality_AND6.png", sublabel=cell_lines, designs_per_cell_line=designs_per_cell_line)  

# mse targets

In [42]:
def generate_mse_designs(mse_targets, designs_per_target, base_name, used_cell_lines,
                         loss="mse", n_mirnas=4, loss_emphases={}, increase_diversity=1):
    """Run the MSE-objective GA for every target pattern in `mse_targets`.

    Uses the global `mirna_data_filter` restricted to `used_cell_lines`. Each target gets
    `increase_diversity` GA rounds of int(designs_per_target / increase_diversity) designs; after
    each round the two miRNAs found in the most designs are removed from that target's candidate
    pool. `loss_emphases` holds one {cell_line: weight} dict per target; when empty, every target
    uses the GA's default (equal) weights.

    Returns all designs concatenated, with 'target', 'emphasis' (string reprs) and
    'type' (= base_name) columns added.
    """
    expression_used = mirna_data_filter[used_cell_lines]
    all_mirnas = list(mirna_data_filter.index)

    if len(loss_emphases) == 0:
        loss_emphases = [{} for _ in range(len(mse_targets))]

    per_target = []
    for target_idx, mse_target in enumerate(mse_targets):
        candidate_mirnas = all_mirnas.copy()
        rounds = []
        for _round in range(increase_diversity):
            emphasis = loss_emphases[target_idx]
            designs = generate_genetic_design(
                target=mse_target,
                loss_emphasis=emphasis,
                n_mirnas=n_mirnas,
                mirnas=candidate_mirnas,
                mirna_expression=expression_used,
                no_designs=int(designs_per_target / increase_diversity),
                loss=loss,
            )
            designs["target"] = str(mse_target)
            designs["emphasis"] = str(emphasis)
            designs["type"] = base_name

            most_used = count_mirnas_per_design(designs).head(2).index.to_list()
            candidate_mirnas = [mirna for mirna in candidate_mirnas if mirna not in most_used]
            rounds.append(designs)

        per_target.append(pd.concat(rounds))

    return pd.concat(per_target)

# MSE designs

In [43]:
# MSE designs, keyed by "AND{n}_{subset|all}" and then by design type (base_name);
# the key order (subset AND4-6, then all AND4-6) fixes the file numbers used when saving
all_mse_designs = {f"AND{n}_{group}": {} for group in ("subset", "all") for n in (4, 5, 6)}

# which cell lines each design group targets
cell_line_design_targets = dict(subset=cell_lines_subset, all=cell_lines)

In [44]:
def calc_diff(designs, used_cell_lines, log=False):
    """Per-design difference between target and achieved expression (target - achieved).

    designs: DataFrame with one column per cell line and a 'target' column holding the string
        repr of a {cell_line: target_value} dict.
    used_cell_lines: cell lines to compare; the result has exactly these columns, in this order.
    log: compare log10 values instead of linear ones.
    Returns a float DataFrame indexed like `designs`; cell lines missing from a target give NaN.
    """
    rows = []
    for _, row in designs.iterrows():
        # restrict the target to the compared cell lines before any log transform
        target = pd.Series(ast.literal_eval(row["target"])).reindex(used_cell_lines).astype("float")
        achieved = row[used_cell_lines].astype("float")
        if log:
            target, achieved = np.log10(target), np.log10(achieved)
        rows.append(target - achieved)
    # build the frame in one go instead of growing it row by row with .loc
    diff = pd.DataFrame(rows, index=designs.index, columns=used_cell_lines)
    return diff.astype("float")

In [45]:
def make_multiple_mse_designs(base_name, targets, emphases, design_targets, cell_lines_used,
                              loss, sublabel, designs_per_cell_line, plot_diff=False, trial=False):
    """Generate MSE-based designs for several targets at each AND size, plot them and store them.

    For every AND size n (4, 5, 6; only 5 when `trial`) this calls generate_mse_designs, numbers
    the designs, and draws:
      - a heatmap of all designs,
      - a heatmap of every `designs_per_cell_line`-th row (the top design of each target's first
        diversity round, assuming each target yields exactly `designs_per_cell_line` designs),
      - an ought-vs-is scatter plot, and
      - if `plot_diff`, heatmaps of target - achieved.
    Results are stored in all_mse_designs[f"AND{n}_{design_targets}"][base_name].

    Parameters
    ----------
    base_name : str -- design type; also the plot sub-folder name.
    targets, emphases : lists of {cell_line: value} dicts, one per target.
    design_targets : "subset" or "all"; selects the design-number block (30-32 or 33-35).
    cell_lines_used : cell-line columns of the designs.
    loss : "mse" plots linear values; any other loss (e.g. "mse-log") plots log10 values.
    sublabel : heatmap group labels (one per target).
    designs_per_cell_line : designs generated per target.
    plot_diff : also plot target - achieved.
    trial : quick run with a single diversity round and n = 5 only.
    """
    if trial:
        diversity = 1
        n_list = [5]
    else:
        diversity = 2
        n_list = [4, 5, 6]

    # designs are numbered 30-32 (subset) or 33-35 (all) for AND4-AND6; the offset is derived
    # from n (not the loop position) so trial runs with n_list == [5] still get the AND5 number
    base_number = 30 if design_targets == "subset" else 33
    use_log = loss != "mse"
    heatmap_loss = "log" if use_log else "lin"
    scatter_mode = "log" if use_log else "lin"

    def to_plot_scale(frame):
        # log losses are visualised on log10 scale
        return np.log10(frame) if use_log else frame

    for n in n_list:
        designs = generate_mse_designs(targets,
                                       used_cell_lines=cell_lines_used,
                                       designs_per_target=designs_per_cell_line,
                                       base_name=base_name,
                                       loss_emphases=emphases,
                                       loss=loss,
                                       n_mirnas=n,
                                       increase_diversity=diversity)
        designs[cell_lines_used] = designs[cell_lines_used].astype("float")
        designs["sublabel"] = str(sublabel)
        designs = add_numbered_index(
            designs, base_name=f"{base_number + n - 4}_miRNA_full_{design_targets}_{base_name}_AND{n}")
        os.makedirs(f"{plot_folder}{base_name}", exist_ok=True)

        run_name = f"{base_name}_{design_targets}_AND{n}"
        best_rows = designs.iloc[::designs_per_cell_line, :]

        # plot all designs
        create_heatmap(to_plot_scale(designs[cell_lines_used].T),
                       filename=f"{base_name}/mse_{run_name}_heatmap", loss=heatmap_loss,
                       title=f"mse_{run_name}", cmap="magma", sublabel=sublabel,
                       designs_per_cell_line=designs_per_cell_line)
        # plot the best design per target
        create_heatmap(to_plot_scale(best_rows.loc[:, cell_lines_used].T),
                       filename=f"{base_name}/mse_{base_name}_{design_targets}_opt_AND{n}_heatmap",
                       loss=heatmap_loss, title=f"mse_{run_name}", cmap="magma", annot=True,
                       sublabel=sublabel, designs_per_cell_line=1)
        # make a scatter plot
        scatter_ought_vs_is(designs, f"{base_name}/{run_name}", run_name, scatter_mode)

        if plot_diff:
            diff = calc_diff(designs, cell_lines_used, log=use_log)
            create_heatmap(diff.T, filename=f"{base_name}/mse_diff_{run_name}_heatmap", adaptive_cbar=True,
                           title=f"mse_diff_{run_name}", cmap="magma", sublabel=sublabel,
                           designs_per_cell_line=designs_per_cell_line)
            create_heatmap(diff.iloc[::designs_per_cell_line, :].T,
                           filename=f"{base_name}/mse_diff_opt_{run_name}_heatmap", adaptive_cbar=True,
                           annot=True, title=f"mse_diff_{run_name}", cmap="magma", sublabel=sublabel,
                           designs_per_cell_line=1)

        all_mse_designs[f"AND{n}_{design_targets}"][base_name] = designs

In [46]:
def plot_saved_design(loaded_df, base_name, design_targets, n, loss="mse"):
    """Re-draw the heatmaps of one saved design type.

    loaded_df: {design_type: DataFrame} for one AND size / cell-line set (one entry of the dict
        returned by load_mse_designs).
    base_name: design type to plot; also the plot sub-folder.
    design_targets: "subset" or "all"; selects the cell-line columns via cell_line_design_targets.
    n: AND size, only used in file names and titles.
    loss: "mse" plots linear values; anything else plots log10 values.
    """
    df = loaded_df[base_name]
    cell_lines_used = cell_line_design_targets[design_targets]
    # designs per target, read from the most frequent target (assumes all targets have the same
    # count); .iloc makes the positional lookup explicit -- Series[0] on a string index is deprecated
    target_counts = df["target"].value_counts().iloc[0]
    sublabel = ast.literal_eval(df["sublabel"].iloc[0])
    os.makedirs(f"{plot_folder}{base_name}", exist_ok=True)
    if loss == "mse":
        create_heatmap(df[cell_lines_used].T, filename=f"{base_name}/mse_{base_name}_{design_targets}_AND{n}_heatmap",
                        title=f"mse_{base_name}_{design_targets}_AND{n}", cmap="magma", sublabel=sublabel, designs_per_cell_line=target_counts)
        create_heatmap(df.iloc[::target_counts,:].loc[:,cell_lines_used].T, filename=f"{base_name}/mse_{base_name}_{design_targets}_opt_AND{n}_heatmap",
                        title=f"mse_{base_name}_{design_targets}_AND{n}", annot=True, cmap="magma", sublabel=sublabel, designs_per_cell_line=1)
    else:
        create_heatmap(np.log10(df[cell_lines_used].T), filename=f"{base_name}/mse_{base_name}_{design_targets}_AND{n}_heatmap",
                        title=f"mse_{base_name}_{design_targets}_AND{n}", loss="log", cmap="magma", sublabel=sublabel, designs_per_cell_line=target_counts)
        create_heatmap(np.log10(df.iloc[::target_counts,:].loc[:,cell_lines_used].T), filename=f"{base_name}/mse_{base_name}_{design_targets}_opt_AND{n}_heatmap",
                        title=f"mse_{base_name}_{design_targets}_AND{n}", loss="log", annot=True, cmap="magma", sublabel=sublabel, designs_per_cell_line=1)
        
def plot_saved_designs(df, base_name, log=False):
    """Re-draw the `base_name` heatmaps for every design set in `df`.

    df: {"AND{n}_{subset|all}": {design_type: DataFrame}}, as returned by load_mse_designs.
        The keys of `df` itself are iterated (not the global all_mse_designs), and sets that
        contain no `base_name` designs (e.g. not generated or not saved) are skipped.
    log: plot log10 values instead of linear ones.
    """
    loss = "mse-log" if log else "mse"
    for key, designs_by_type in df.items():
        if base_name not in designs_by_type:
            continue
        and_part, design_targets = key.split("_")[0], key.split("_")[1]
        n = and_part[-1]
        plot_saved_design(designs_by_type, base_name, design_targets, n, loss=loss)

In [47]:
def save_mse_designs():
    """Write every non-empty entry of all_mse_designs to one CSV.

    Files are named f"../designs/{number}_miRNA_{key}_mse_designs.csv", numbered from 30 in the
    key order of all_mse_designs. Empty entries (e.g. AND sizes skipped in a trial run) are
    skipped, because pd.concat([]) raises; they still consume their number, so the numbering
    stays stable.
    """
    for number, (key, designs_by_type) in enumerate(all_mse_designs.items(), start=30):
        if not designs_by_type:
            continue
        merged_df = pd.concat(list(designs_by_type.values()))
        merged_df.to_csv(f"../designs/{number}_miRNA_{key}_mse_designs.csv")

In [48]:
def load_mse_designs(design_dir="../designs", keys=None, start_index=30):
    """Load design sets written by ``save_mse_designs``.

    The ``j``-th key (0-based) is filled from every file in ``design_dir``
    whose name starts with ``f"{start_index + j}_"``. Each file is split back
    into one DataFrame per distinct value of its ``"type"`` column.

    Parameters
    ----------
    design_dir : str, default "../designs"
        Directory that both the listing and the reads use.
    keys : iterable of str, optional
        Keys in the order they were saved. Defaults to the keys of the
        notebook-global ``all_mse_designs``.
    start_index : int, default 30
        Numeric file-name prefix of the first key.

    Returns
    -------
    dict
        ``{key: {design_type: DataFrame}}``; a key with no matching file maps
        to an empty dict.
    """
    if keys is None:
        keys = all_mse_designs.keys()
    # List the same directory the files are read from (previously the
    # listing used "designs" while reads used "../designs"). Sort so that,
    # if several files share a prefix, the result does not depend on the
    # filesystem's listing order.
    filenames = sorted(os.listdir(design_dir))
    designs = {}
    for i, key in enumerate(keys, start=start_index):
        designs[key] = {}
        for filename in filenames:
            if filename.startswith(f"{i}_"):
                loaded_df = pd.read_csv(f"{design_dir}/{filename}", index_col=0)
                # separate the dataframe by the type column
                for design_type in loaded_df["type"].unique():
                    designs[key][design_type] = loaded_df[loaded_df["type"] == design_type]
    return designs

# Load and plot saved designs

In [None]:
# Reload the design sets from the CSV files on disk and show which keys were found.
# NOTE(review): load_mse_designs() reads the keys of the existing global
# all_mse_designs, so that name must already be defined by an earlier cell.
all_mse_designs = load_mse_designs()
print(all_mse_designs.keys())

In [None]:
%%capture output
plot_saved_designs(df, 'single_active', log=False)

## Active in a single cell line

In [49]:
%%capture output
designs_per_cell_line = 4
base_name="single_active"

for design_targets in ["subset", "all"]:
    cell_lines_used = cell_line_design_targets[design_targets]

    single_actives = []
    for i in range(len(cell_lines_used)):
        single_active = {cell_line: 0 for i, cell_line in enumerate(cell_lines_used)}
        single_active[cell_lines[i]] = 1
        single_actives.append(single_active)

    single_emphases = []
    for i in range(len(single_actives)):
        single_emphasis = {cell_line: 1 for i, cell_line in enumerate(cell_lines_used)}
        single_emphasis[cell_lines[i]] = len(cell_lines_used)/2.0
        single_emphases.append(single_emphasis)
    
    make_multiple_mse_designs(base_name=base_name,
                        targets=single_actives,
                        emphases=single_emphases,
                        design_targets=design_targets,
                        cell_lines_used=cell_lines_used,
                        loss="mse",
                        sublabel=cell_lines_used,
                        designs_per_cell_line=designs_per_cell_line)

In [None]:
# Save all design sets collected so far (including the ones just generated) as CSV files.
save_mse_designs()

## Active in two cell lines

In [None]:
%%capture output
designs_per_cell_line = 4
base_name="double_active"

for design_targets in ["subset", "all"]:
    cell_lines_used = cell_line_design_targets[design_targets]

    # pick two random cell lines 10 times
    pick_two = set()
    while len(pick_two) < 10:
        pick_two.add(tuple(sorted(random.sample(cell_lines_used, 2))))
    pick_two = list(pick_two)
    # covert pick two to a list of strings
    pick_two_str = [item[0] + "\n" + item[1] for item in pick_two]

    # set the target values
    pick_two_targets = []
    for i in range(len(pick_two)):
        pick_two_target = {cell_line: 0 for i, cell_line in enumerate(cell_lines_used)}
        for j in range(len(pick_two[i])):
            pick_two_target[pick_two[i][j]] = 1
        pick_two_targets.append(pick_two_target)

    two_emphases = []
    for i in range(len(pick_two)):
        two_emphasis = {cell_line: 1 for i, cell_line in enumerate(cell_lines_used)}
        for j in range(len(pick_two[i])):
            two_emphasis[pick_two[i][j]] = len(cell_lines_used)/(2.0)
        two_emphases.append(two_emphasis)
    
    make_multiple_mse_designs(base_name=base_name,
                        targets=pick_two_targets,
                        emphases=two_emphases,
                        design_targets=design_targets,
                        cell_lines_used=cell_lines_used,
                        loss="mse",
                        sublabel=pick_two_str,
                        designs_per_cell_line=designs_per_cell_line)

In [None]:
# Save all design sets collected so far (including the ones just generated) as CSV files.
save_mse_designs()

## Inactive in a single cell line

In [None]:
%%capture output

designs_per_cell_line = 4
base_name = "single_target"

for design_targets in ["subset", "all"]:
    cell_lines_used = cell_line_design_targets[design_targets]

    single_targets = []
    for i in range(len(cell_lines_used)):
        single_target = {cell_line: 1 for i, cell_line in enumerate(cell_lines_used)}
        single_target[cell_lines[i]] = 0
        single_targets.append(single_target)

    single_emphases = []
    for i in range(len(single_targets)):
        single_emphasis = {cell_line: 1 for i, cell_line in enumerate(cell_lines_used)}
        single_emphasis[cell_lines[i]] = len(cell_lines_used)/2.0
        single_emphases.append(single_emphasis)
    
    make_multiple_mse_designs(base_name=base_name,
                        targets=single_targets,
                        emphases=single_emphases,
                        design_targets=design_targets,
                        cell_lines_used=cell_lines_used,
                        loss="mse",
                        sublabel=cell_lines_used,
                        designs_per_cell_line=designs_per_cell_line)

In [None]:
# Save all design sets collected so far (including the ones just generated) as CSV files.
save_mse_designs()

## Inactive in two cell lines

In [None]:
%%capture output
designs_per_cell_line = 4
base_name="double_target"

for design_targets in ["subset", "all"]:
    cell_lines_used = cell_line_design_targets[design_targets]

    # pick two random cell lines 10 times
    pick_two = set()
    while len(pick_two) < 10:
        pick_two.add(tuple(sorted(random.sample(cell_lines_used, 2))))
    pick_two = list(pick_two)
    # covert pick two to a list of strings
    pick_two_str = [item[0] + "\n" + item[1] for item in pick_two]

    # set the target values
    pick_two_targets = []
    for i in range(len(pick_two)):
        pick_two_target = {cell_line: 1 for i, cell_line in enumerate(cell_lines_used)}
        for j in range(len(pick_two[i])):
            pick_two_target[pick_two[i][j]] = 0
        pick_two_targets.append(pick_two_target)

    two_emphases = []
    for i in range(len(pick_two)):
        two_emphasis = {cell_line: 1 for i, cell_line in enumerate(cell_lines_used)}
        for j in range(len(pick_two[i])):
            two_emphasis[pick_two[i][j]] = len(cell_lines_used)/(2.0)
        two_emphases.append(two_emphasis)
    
    make_multiple_mse_designs(base_name=base_name,
                        targets=pick_two_targets,
                        emphases=two_emphases,
                        design_targets=design_targets,
                        cell_lines_used=cell_lines_used,
                        loss="mse",
                        sublabel=pick_two_str,
                        designs_per_cell_line=designs_per_cell_line)

In [None]:
# Save all design sets collected so far (including the ones just generated) as CSV files.
save_mse_designs()

## Inactive in three cell lines

In [None]:
%%capture output
designs_per_cell_line = 4
base_name="three_target"

for design_targets in ["subset", "all"]:
    cell_lines_used = cell_line_design_targets[design_targets]   

    # pick three random cell lines 10 times
    pick_three = set()
    while len(pick_three) < 10:
        pick_three.add(tuple(sorted(random.sample(cell_lines_used, 3))))
    pick_three = list(pick_three)
    # covert pick two to a list of strings
    pick_three_str = [item[0] + "\n" + item[1] + "\n" + item[2] for item in pick_three]

    # set the target values
    pick_three_targets = []
    for i in range(len(pick_three)):
        pick_three_target = {cell_line: 1 for i, cell_line in enumerate(cell_lines_used)}
        for j in range(len(pick_three[i])):
            pick_three_target[pick_three[i][j]] = 0
        pick_three_targets.append(pick_three_target)

    three_emphases = []
    for i in range(len(pick_three)):
        three_emphasis = {cell_line: 1 for i, cell_line in enumerate(cell_lines_used)}
        for j in range(len(pick_three[i])):
            three_emphasis[pick_three[i][j]] = len(cell_lines_used)/(3.0)
        three_emphases.append(three_emphasis)
    
    make_multiple_mse_designs(base_name=base_name,
                        targets=pick_three_targets,
                        emphases=three_emphases,
                        design_targets=design_targets,
                        cell_lines_used=cell_lines_used,
                        loss="mse",
                        sublabel=pick_three_str,
                        designs_per_cell_line=designs_per_cell_line)

In [None]:
# Save all design sets collected so far (including the ones just generated) as CSV files.
save_mse_designs()

## Randomly generated target values

In [None]:
%%capture output
designs_per_cell_line = 8
base_name="random_target"

for design_targets in ["subset", "all"]:
    cell_lines_used = cell_line_design_targets[design_targets]   

    # set the target values
    random_targets = []
    for i in range(70):
        random_target = {cell_line: random.random() for cell_line in cell_lines_used}
        random_targets.append(random_target)

    emphases = []
    for i in range(len(random_targets)):
        emphasis = {cell_line: 1 for i, cell_line in enumerate(cell_lines_used)}
        emphases.append(emphasis)
    
    make_multiple_mse_designs(base_name=base_name,
                        targets=random_targets,
                        design_targets=design_targets,
                        emphases=emphases,
                        cell_lines_used=cell_lines_used,
                        plot_diff = True,
                        loss="mse",
                        sublabel=False,
                        trial=False,
                        designs_per_cell_line=designs_per_cell_line)

In [None]:
# Keep every (designs_per_cell_line/2)-th design of the current base_name and
# renumber the kept rows as <prefix>_1, <prefix>_2, ..., where <prefix> is the
# first row's index label without its last "_"-separated field.
# NOTE: not idempotent - re-running this cell subsamples the rows again.
stride = int(designs_per_cell_line/2)
for key, designs_by_type in all_mse_designs.items():
    kept = designs_by_type[base_name].iloc[::stride, :]
    prefix = "_".join(kept.index[0].split("_")[:-1])
    kept.index = [f"{prefix}_{number}" for number in range(1, len(kept) + 1)]
    designs_by_type[base_name] = kept

In [None]:
# Save all design sets collected so far (including the subsampled random_target designs) as CSV files.
save_mse_designs()

## Target values evenly spaced between 0 and 1 (randomly permuted across cell lines)

In [None]:
#%%capture output
# Targets are the k evenly spaced values 0, 1/(k-1), ..., 1 (k = number of cell
# lines used), assigned to the cell lines in up to n random, distinct orders.
import math

designs_per_cell_line = 8
base_name="range_target"
n = 70

for design_targets in ["subset", "all"]:
    cell_lines_used = cell_line_design_targets[design_targets]
    k = len(cell_lines_used)

    # set the target values
    target_range = np.linspace(0, 1, num=k)

    # Draw distinct random permutations one at a time instead of building the
    # list of all k! permutations (3,628,800 tuples for k = 10) just to sample
    # n of them. Cap at k! so the loop terminates for small k.
    n_permutations = min(n, math.factorial(k))
    seen_permutations = set()
    target_range_permutations = []
    while len(target_range_permutations) < n_permutations:
        permutation = tuple(random.sample(list(target_range), k))
        if permutation not in seen_permutations:
            seen_permutations.add(permutation)
            target_range_permutations.append(permutation)

    curr_target_dicts = [dict(zip(cell_lines_used, permutation)) for permutation in target_range_permutations]
    # equal emphasis on every cell line
    emphases = [{cell_line: 1 for cell_line in cell_lines_used} for _ in target_range_permutations]

    make_multiple_mse_designs(base_name=base_name,
                        targets=curr_target_dicts,
                        design_targets=design_targets,
                        emphases=emphases,
                        cell_lines_used=cell_lines_used,
                        plot_diff = True,
                        loss="mse",
                        sublabel=False,
                        trial=False,
                        designs_per_cell_line=designs_per_cell_line)

In [None]:
# Keep every (designs_per_cell_line/2)-th design of the current base_name and
# renumber the kept rows as <prefix>_1, <prefix>_2, ..., where <prefix> is the
# first row's index label without its last "_"-separated field.
# NOTE: not idempotent - re-running this cell subsamples the rows again.
stride = int(designs_per_cell_line/2)
for key, designs_by_type in all_mse_designs.items():
    kept = designs_by_type[base_name].iloc[::stride, :]
    prefix = "_".join(kept.index[0].split("_")[:-1])
    kept.index = [f"{prefix}_{number}" for number in range(1, len(kept) + 1)]
    designs_by_type[base_name] = kept

In [None]:
# Save all design sets collected so far (including the subsampled range_target designs) as CSV files.
save_mse_designs()

## Randomly generated target values (log target)

In [None]:
%%capture output
designs_per_cell_line = 8
base_name="random_target_log"

for design_targets in ["subset", "all"]:
    cell_lines_used = cell_line_design_targets[design_targets]   

    # set the target values
    random_targets = []
    for i in range(70):
        random_target = {cell_line: 10**(1.3*(random.random()-1)) for cell_line in cell_lines_used}
        random_targets.append(random_target)

    emphases = []
    for i in range(len(random_targets)):
        emphasis = {cell_line: 1 for i, cell_line in enumerate(cell_lines_used)}
        emphases.append(emphasis)

    make_multiple_mse_designs(base_name=base_name,
                        targets=random_targets,
                        design_targets=design_targets,
                        emphases=emphases,
                        cell_lines_used=cell_lines_used,
                        plot_diff = True,
                        loss="mse-log",
                        sublabel=False,
                        trial=False,
                        designs_per_cell_line=designs_per_cell_line)

In [None]:
# Keep every (designs_per_cell_line/2)-th design of the current base_name and
# renumber the kept rows as <prefix>_1, <prefix>_2, ..., where <prefix> is the
# first row's index label without its last "_"-separated field.
# NOTE: not idempotent - re-running this cell subsamples the rows again.
stride = int(designs_per_cell_line/2)
for key, designs_by_type in all_mse_designs.items():
    kept = designs_by_type[base_name].iloc[::stride, :]
    prefix = "_".join(kept.index[0].split("_")[:-1])
    kept.index = [f"{prefix}_{number}" for number in range(1, len(kept) + 1)]
    designs_by_type[base_name] = kept

In [None]:
# Save all design sets collected so far (including the subsampled random_target_log designs) as CSV files.
save_mse_designs()

## Log-spaced target values between 10^-1.3 and 1 (randomly permuted, log target)

In [None]:
%%capture output
designs_per_cell_line = 8
base_name="range_target_log"

for design_targets in ["subset", "all"]:
    cell_lines_used = cell_line_design_targets[design_targets]   

    # set the target values
    target_range = np.linspace(-1.3, 0, num=len(cell_lines_used))
    target_range = [10**i for i in target_range]
    # get n random unique permutations of the target_range
    n = 70
    target_range_permutations = random.sample(list(itertools.permutations(target_range)), n)
    curr_target_dicts = []
    for permutation in target_range_permutations:
        curr_target_dicts.append({cell_line: permutation[i] for i, cell_line in enumerate(cell_lines_used)})

    emphases = []
    for i in range(len(curr_target_dicts)):
        emphasis = {cell_line: 1 for i, cell_line in enumerate(cell_lines_used)}
        emphases.append(emphasis)

    make_multiple_mse_designs(base_name=base_name,
                        targets=curr_target_dicts,
                        design_targets=design_targets,
                        emphases=emphases,
                        cell_lines_used=cell_lines_used,
                        plot_diff = True,
                        loss="mse-log",
                        sublabel=False,
                        trial=False,
                        designs_per_cell_line=designs_per_cell_line)

In [None]:
# Keep every (designs_per_cell_line/2)-th design of the current base_name and
# renumber the kept rows as <prefix>_1, <prefix>_2, ..., where <prefix> is the
# first row's index label without its last "_"-separated field.
# NOTE: not idempotent - re-running this cell subsamples the rows again.
stride = int(designs_per_cell_line/2)
for key, designs_by_type in all_mse_designs.items():
    kept = designs_by_type[base_name].iloc[::stride, :]
    prefix = "_".join(kept.index[0].split("_")[:-1])
    kept.index = [f"{prefix}_{number}" for number in range(1, len(kept) + 1)]
    designs_by_type[base_name] = kept

In [None]:
# Save all design sets collected so far (including the subsampled range_target_log designs) as CSV files.
save_mse_designs()

## Cleanup of design outputs

In [None]:
# Design types whose rows this cell treats as consecutive pairs (two designs
# per target - presumably the two left by the subsampling cells above).
clean_design_targets = ["range_target", "random_target", "random_target_log", "range_target_log"]

# NOTE: not idempotent - re-running this cell thins the designs out again.
for key, designs_by_type in all_mse_designs.items():
    for key2 in clean_design_targets:
        df = designs_by_type[key2]

        # Put the higher-quality design of each pair first. Build a positional
        # order and apply it once, instead of swapping rows with chained
        # `df.iloc[a], df.iloc[b] = ...` assignment, whose result depends on
        # whether `df.iloc[row]` is a copy or a view and which round-trips each
        # row through a single Series.
        quality = df["quality"].to_numpy()
        order = list(range(len(df)))
        for i in range(len(df) // 2):
            if quality[2 * i] < quality[2 * i + 1]:
                order[2 * i], order[2 * i + 1] = order[2 * i + 1], order[2 * i]
        best_first = df.iloc[order]
        # The previous in-place swap moved row contents but left the index
        # labels where they were; keep that labelling.
        best_first.index = df.index

        # Keep every design in the first third, then only the first (best)
        # design of each remaining pair. The cut is rounded down to an even
        # row so the stride-2 slice starts on a pair boundary.
        cutoff = (len(best_first) // 3) // 2 * 2
        all_mse_designs[key][key2] = pd.concat([best_first.iloc[:cutoff], best_first.iloc[cutoff::2]])

save_mse_designs()