In [7]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stats
import scipy.optimize as opt
import os
from lib.NA_sequence_utilities import *

# Cell lines present in the expression data, grouped by their role in this analysis.
cell_lines_main = ["HEK293T", "HeLa", "SKNSH", "MCF7", "HUH-7", "A549"]
cell_lines_other = ["HaCaT", "JEG-3", "Tera-1", "PC-3"]
# Initial list of measured cell lines; section 1.3 re-derives this name from the
# measurement columns, so later cells see that derived list.
cell_lines_measured = ["HEK293T", "HeLa", "SKNSH", "MCF7"]
cell_lines = cell_lines_main + cell_lines_other

# Root folders for the figures and tables written by this notebook.
plot_folder = "../plots/1_microRNA_data/"
output_folder = "../output/1_output/"

# exist_ok=True creates missing folders and is a no-op for existing ones,
# avoiding the check-then-create race of a separate os.path.exists() test.
os.makedirs(plot_folder, exist_ok=True)
os.makedirs(output_folder, exist_ok=True)

## 1.1 Compare quantile normalized and non-normalized data

In [2]:
def normalize_expr_df_to_rpm(df):
    """Rescale every column of an expression table to reads per million (RPM).

    The columns are first scaled to sum to 1e6, then shifted so that the
    smallest value of each column becomes 1 (keeping a later log transform
    finite), and finally scaled to sum to 1e6 again.

    Parameters
    ----------
    df : pandas.DataFrame
        Numeric table; each column is normalised independently.

    Returns
    -------
    pandas.DataFrame
        Table of the same shape whose columns each sum to 1e6.
    """
    def _to_rpm(frame):
        # divide each column by its own total
        return frame.div(frame.sum(axis=0), axis=1) * 1000000

    rpm = _to_rpm(df)
    shifted = rpm - rpm.min() + 1
    return _to_rpm(shifted)

In [5]:
# Load the quantile-normalised expression table, rescale it to RPM and
# move it to log10 space in one step.
alles_quantile_path = '../input_data/mirna_expression_data/1_input/Alles2019_quantile_norm_processed.xlsx'
df_alles_quantile = np.log10(
    normalize_expr_df_to_rpm(pd.read_excel(alles_quantile_path, index_col=0))
)

## 1.2 Look at the correlation between cell lines

In [None]:
# Heatmap of the squared pairwise correlation between the expression columns.
os.makedirs(os.path.join(plot_folder, "1.2_cell_line_r2"), exist_ok=True)

corr = df_alles_quantile.corr()**2
# Hide the upper triangle (diagonal included). The mask is built from the
# matrix shape rather than its values: np.triu(corr) is only truthy where the
# correlation is non-zero, so an r2 of exactly 0 would have stayed visible.
mask_matrix = np.triu(np.ones_like(corr, dtype=bool))
sns.heatmap(corr, annot=True, fmt='.2f', vmin=0.3, vmax=1, cmap='Blues', mask=mask_matrix, cbar_kws={'label': 'r2'})
plt.title('Correlation between cell lines (quantile normalized)')
plt.savefig(os.path.join(plot_folder, "1.2_cell_line_r2/cell_line_r2.png"), dpi=300)

## 1.3 Plot measured stability against the expression data

In [10]:
# Folder holding the stability measurements of library 1 (one CSV per data set).
data_dir_input = "../input_data/measurements_lib1"

# list every file in the measurement folder
reference_files = os.listdir(data_dir_input)

# read all CSV files into a dictionary keyed by the file name up to its first dot
reference_dict = {}
for reference_file in reference_files:
    if reference_file.endswith(".csv"):
        reference_dict[reference_file.split('.')[0]] = pd.read_csv(os.path.join(data_dir_input, reference_file), index_col=0)

# ADD EXPRESSION DATA TO THE DATAFRAMES
# work on copies of all data sets whose key contains "single"
single_dfs = {key: reference_dict[key].copy() for key in reference_dict.keys() if "single" in key}

for key, df in single_dfs.items():
    # Re-index by miRNA name, keep only the "*_3UTR_log10" measurement columns
    # (without K562 and HepG2), and restrict the rows to miRNAs that also occur
    # in the expression table.
    df.set_index("miRNA1", inplace=True)
    df = df.filter(regex='_3UTR_log10')
    df = df.drop(['K562_3UTR_log10', 'HepG2_3UTR_log10'], axis=1)
    df = df.loc[df.index.intersection(df_alles_quantile.index)]

    # NOTE(review): this overwrites the module-level `cell_lines_measured` with the
    # column order of the data set processed last; all later cells index `scales`
    # through this list, so its order matters — confirm it is the same for every file.
    cell_lines_measured = list(df.columns.str.replace('_3UTR_log10', ''))
    for cell_line in cell_lines_measured:
        # attach the log10 expression of each miRNA as "<cell line>_exp"
        df.loc[:, f"{cell_line}_exp"] = df_alles_quantile.loc[df.index, cell_line]

    # drop rows with a missing measurement or expression value
    df.dropna(inplace=True)

    # store the processed frame back under the same key
    single_dfs[key] = df

In [11]:
# select the processed single-target data set of context 1 for the analyses below
measured_single = single_dfs["1_full_single_context1"]

In [12]:
%%capture output
# One scatter plot per measured cell line: log10 stability against log10 expression,
# titled with the squared Spearman correlation.
os.makedirs(os.path.join(plot_folder, "1.3_stability_vs_expression"), exist_ok=True)

plt.rcParams.update({'font.size': 8})
df = measured_single.copy()
for cell_line in cell_lines_measured:
    expression = df[f"{cell_line}_exp"]
    stability = df[f"{cell_line}_3UTR_log10"]

    # A fresh figure per cell line; the former plt.clf() only produced a stray
    # empty figure and the label-less plt.legend() call only raised a warning.
    fig, ax = plt.subplots(figsize=(3, 2))
    ax.scatter(expression, stability, s=10, color="black")

    # calculate the correlation coefficient
    r, p = stats.spearmanr(expression, stability)

    ax.set_xlabel(r"log$_{10}$"+f"({cell_line} expression)")
    ax.set_ylabel(r"log$_{10}$(RNA/DNA)")
    ax.set_xlim(0, 5.5)
    ax.set_title(f"{cell_line}, " + r"$\rho^2$ = " + str(round(r**2, 2)), fontsize=7)
    fig.savefig(os.path.join(plot_folder, f"1.3_stability_vs_expression/{cell_line}_stability_vs_expression.png"), dpi=300)
    # the cell output is captured, so release the figure once it is on disk
    plt.close(fig)

In [13]:
# Split the measurements into a knockdown table ("*_3UTR_log10" columns) and an
# expression table (columns containing "exp"); in both, the columns are renamed
# to the part before the first underscore.
df_knockdown = measured_single.filter(regex='_3UTR_log10').rename(columns=lambda col: col.split("_")[0])
df_expression = measured_single.filter(regex='exp').rename(columns=lambda col: col.split("_")[0])

## 1.4 Fitting to the data

In [14]:
def hill_func_log_scales(datasets, c1=3.5, c2=10, n=1, *scales):
    """Evaluate the log-space Hill model on several datasets with per-dataset offsets.

    Each dataset holds log10 expression values. The matching entry of ``scales``
    is added to it before the model is evaluated; datasets and scales are paired
    with ``zip``, so surplus entries on either side are ignored.

    Parameters
    ----------
    datasets : sequence of array-like
        log10 expression values, one array per dataset.
    c1, c2 : float
        log10 of the two model constants. C2 is generally ignored by setting it high.
    n : float
        Hill exponent.
    *scales : float
        Additive log10 offset for each dataset.

    Returns
    -------
    numpy.ndarray
        log10 model output of all datasets, concatenated in dataset order.

    Note: a later cell of this notebook redefines this name with a different signature.
    """
    half_max = (10**c1)**n
    saturation = 10**c2

    def _log_response(log_x, shift):
        # move to linear space, apply the Hill term, return to log10
        lin_x = (10**(log_x + shift))**n
        return np.log10((1 / (1 + lin_x / half_max)) * (1 + lin_x / saturation))

    return np.concatenate([_log_response(x, scale) for x, scale in zip(datasets, scales)])

def hill_func_log_alt(x, c1=3, c2=4.5, n=1):
    """Evaluate the log-space Hill model for log10 expression values.

    Parameters
    ----------
    x : float or array-like
        log10 expression.
    c1, c2 : float
        log10 of the two model constants. C2 is generally ignored by setting it high.
    n : float
        Hill exponent, applied to ``x`` and ``c1``.

    Returns
    -------
    float or numpy.ndarray
        log10 of ``(1 / (1 + X / C1)) * (1 + X / C2)`` with
        ``X = (10**x)**n``, ``C1 = (10**c1)**n`` and ``C2 = 10**c2``.
    """
    lin_x = (10**x)**n
    half_max = (10**c1)**n
    saturation = 10**c2
    repression = 1 / (1 + lin_x / half_max)
    relief = 1 + lin_x / saturation
    return np.log10(repression * relief)

In [None]:
# Trial fit of the Hill model to one cell line only.
cell_line = "SKNSH"
# .index() also verifies that the cell line is among the measured ones
cell_line_index = cell_lines_measured.index(cell_line)
x_data_log = df_expression[cell_lines_measured[cell_line_index]].values
y_data_log = df_knockdown[cell_lines_measured[cell_line_index]].values

# Initial guesses and bounds for (c1, c2, n). The bounds on c2 and n are
# near-degenerate, so the saturation term is effectively switched off
# because it leads to unwanted fitting behavior.
p0 = [1, 10, 1]
lower_bounds = [0.1, 9.9, 0.9999999]
upper_bounds = [7, 10.1, 1.0000001]
bounds = (lower_bounds, upper_bounds)

popt_hill_log, pcov = opt.curve_fit(hill_func_log_alt, x_data_log, y_data_log, p0=p0, bounds=bounds, maxfev=5000)
print(popt_hill_log)

In [None]:
# Overlay the single-cell-line Hill fit on the measured data.
x_range_log = np.arange(0, 5.2, 0.01)
expression_log = df_expression[f"{cell_line}"]
knockdown_log = df_knockdown[f"{cell_line}"]

# squared Pearson correlation between model prediction and measurement
r2 = stats.pearsonr(hill_func_log_alt(expression_log, *popt_hill_log), knockdown_log)[0]**2

plt.scatter(expression_log, knockdown_log, s=5, color="black")
plt.title(f"{cell_line} (R2={r2:.2f})")
plt.plot(x_range_log, hill_func_log_alt(x_range_log, *popt_hill_log), color="red", label="Hill function fit")

In [17]:
# Collect the per-cell-line inputs for the global fit.
x_data = []
y_data = []
dataset_indices = []
for i, cell_line in enumerate(cell_lines_measured):
    ex_df = df_expression[cell_line].values
    knock_df = df_knockdown[cell_line].values
    x_data.append(ex_df)
    y_data.append(knock_df)
    # remember which dataset every point belongs to
    dataset_indices.append([i] * len(ex_df))

# x_data stays a list with one array per cell line: hill_func_log_scales iterates
# over the datasets and concatenates its own output.
y_data = np.concatenate(y_data)
# dataset_indices likewise stays a list of per-dataset lists in this cell.

In [18]:
# Initial guesses and bounds for the shared Hill parameters (c1, c2, n).
# The bounds on c2 and n are near-degenerate, so effectively only c1 is fitted.
p0 = [3, 10, 1]
num_params = len(p0)
bounds = ([1, 9.9999, 0.9999999], [7, 10.00001, 1.00000001])

# One additive log10 expression offset ("scale") per dataset, free within [-2, 2].
scale_guesses = [0 for _ in range(len(x_data))]
scale_bounds_min = [-2 for _ in range(len(x_data))]
scale_bounds_max = [2 for _ in range(len(x_data))]

# Pin the scale of the first dataset to ~0 so the other scales are relative to it.
# NOTE(review): this is meant to be HEK293T, i.e. it assumes HEK293T is the
# first entry of cell_lines_measured — confirm.
scale_bounds_min[0] = -0.001
scale_bounds_max[0] = 0.001

# the parameter vector is [c1, c2, n, scale_0, scale_1, ...]
p0_scale = p0 + scale_guesses
bounds_scale = (bounds[0]+scale_bounds_min, bounds[1]+scale_bounds_max)

# Fit all datasets at once. Only popt_scales is assigned here; the earlier
# chained assignment also wrote the unfiltered result to popt_scales_filter.
popt_scales, pcov = opt.curve_fit(
    hill_func_log_scales,
    x_data,
    y_data,
    p0=p0_scale,
    bounds=bounds_scale,
    maxfev=5000
)

# the per-dataset offsets follow the shared parameters
scales = list(popt_scales[num_params:])

In [None]:
# display the fitted shared parameters (c1, c2, n) without the per-dataset scales
popt_scales[:num_params]

In [20]:
%%capture output
# Plot the global fit per cell line and record R2, RMSD and the residuals.
os.makedirs(os.path.join(plot_folder, "1.4_global_fit_with_scales"), exist_ok=True)

r2_vals_scales = {}
rmsd_vals_scales = {}
# residuals (measured - predicted, both log10) for every miRNA and cell line
df_deviation = pd.DataFrame(columns=df_knockdown.columns, index=df_knockdown.index)
x_range_log = np.arange(0, 5.2, 0.01)
plt.rcParams.update({'font.size': 8})

for cell_line_index, cell_line in enumerate(cell_lines_measured):
    current_scale = scales[cell_line_index]

    # evaluate the model once on the shifted expression and reuse it below
    scaled_expression = df_expression[f"{cell_line}"] + current_scale
    observed = df_knockdown[f"{cell_line}"]
    predicted = hill_func_log_alt(scaled_expression, *popt_scales[:num_params])
    residual = observed - predicted

    # the y axis is shown in linear space (10**log10 value)
    fig = plt.figure(figsize=(3, 2.5))
    plt.scatter(scaled_expression, 10**observed, s=5, color="black")
    plt.plot(x_range_log+current_scale, 10**hill_func_log_alt(x_range_log+current_scale,
                *popt_scales[:num_params]), color="forestgreen", linewidth=2, label="fit")

    # squared Pearson correlation and RMSD, both computed in log10 space
    r2 = stats.pearsonr(observed, predicted)[0]**2
    r2_vals_scales[cell_line] = r2
    rmsd = np.sqrt(np.mean(residual**2))
    rmsd_vals_scales[cell_line] = rmsd

    df_deviation[cell_line] = residual

    plt.xlabel("log10(expression)")
    # the plotted values are 10**log10(stability), so the axis is not logarithmic
    plt.ylabel("stability")
    plt.title(f"{cell_line}, r2 = {round(r2, 2)}, rmsd = {round(rmsd, 2)}", fontsize=8)

    plt.tight_layout()
    plt.savefig(os.path.join(plot_folder, f"1.4_global_fit_with_scales/{cell_line}_global_fit.png"), dpi=300)
    # the cell output is captured, so release the figure once it is on disk
    plt.close(fig)

In [None]:
# Per cell line: plot data and fit in log space and label the ten miRNAs with
# the largest absolute deviation from the fit.
for cell_line_index, cell_line in enumerate(cell_lines_measured):
    current_scale = scales[cell_line_index]
    shifted_range = x_range_log + current_scale

    # ten largest absolute deviations
    deviation = abs(df_deviation[cell_line]).sort_values(ascending=False)[:10]

    fig = plt.figure(figsize=(3, 2.5))
    plt.rcParams.update({'font.size': 8})
    plt.scatter(df_expression[f"{cell_line}"]+current_scale, df_knockdown[f"{cell_line}"], s=5, color="skyblue")
    plt.plot(shifted_range,
            hill_func_log_alt(shifted_range, *popt_scales[:num_params]), color="forestgreen", linewidth=2, label="fit")

    # label each outlier with its name minus the first two dash-separated parts
    for mirna in deviation.index.to_list():
        short_name = "-".join(mirna.split("-")[2:])
        plt.text(df_expression.loc[mirna, cell_line]+current_scale, df_knockdown.loc[mirna, cell_line], short_name, fontsize=8)

    plt.xlabel("log10(expression)")
    plt.ylabel("log10(stability)")
    plt.title(f"{cell_line}", fontsize=8)
    plt.tight_layout()
    plt.savefig(os.path.join(plot_folder, f"1.4_global_fit_with_scales/outliers_{cell_line}.png"), dpi=300)

## 1.5 Plot outliers against raw count data

In [23]:
# load the combined table of library 1; the count columns are selected in the next cell
raw_counts = pd.read_csv('../input_data/measurements_lib1/log2fc_combined.csv', index_col=0)

In [24]:
# keep only the columns whose name contains "count"
raw_counts = raw_counts.filter(regex="count")
# restrict the rows to entries that are also present in the context-1 single-target data set
raw_counts = raw_counts.loc[raw_counts.index.intersection(reference_dict["1_full_single_context1"].index), :]
# look up the miRNA name of every row (aligned on the shared index) and use it as the new index
raw_counts["miRNA"] = reference_dict["1_full_single_context1"]["miRNA1"]
raw_counts.set_index("miRNA", inplace=True)

In [25]:
%%capture output
# create the plot folder if it does not exist
if not os.path.exists(os.path.join(plot_folder, "1.5_counts_vs_error")):
    os.makedirs(os.path.join(plot_folder, "1.5_counts_vs_error"))

# One figure per cell line and count column suffix ("r1", "r2"): fit deviation
# against the log10 count taken from the "count_3UTR_<cell line>_<r>" column.
for cell_line in cell_lines_measured:
    for r in ["r1", "r2"]:
        fig = plt.figure(figsize=(3, 2.5))
        # NOTE(review): a count of 0 would become -inf under log10 — confirm the counts are positive
        plt.scatter(np.log10(raw_counts.loc[df_deviation.index, f"count_3UTR_{cell_line}_{r}"]), df_deviation[cell_line], s=5)
        plt.xticks(np.arange(1, 5, 1))
        plt.xlabel("log10(RNA counts in sequencing data)")
        plt.ylabel("deviation")
        plt.tight_layout()
        plt.savefig(os.path.join(plot_folder, f"1.5_counts_vs_error/{cell_line}_{r}.png"), dpi=300)

## 1.6 Investigate the effect of a single context

In [26]:
# Same split as for context 1, applied to the context-2 data set: a knockdown
# table ("*_3UTR_log10" columns) and an expression table (columns containing
# "exp"), both with columns renamed to the part before the first underscore.
measured_single_c2 = single_dfs["1_full_single_context2"]

df_knockdown_c2 = measured_single_c2.filter(regex='_3UTR_log10').rename(columns=lambda col: col.split("_")[0])
df_expression_c2 = measured_single_c2.filter(regex='exp').rename(columns=lambda col: col.split("_")[0])

In [27]:
%%capture output
# create the plot folder if it does not exist
if not os.path.exists(os.path.join(plot_folder, "1.6_context1_context2")):
    os.makedirs(os.path.join(plot_folder, "1.6_context1_context2"))

# NOTE(review): these two dicts reuse the names filled for context 1 in section 1.4,
# so the context-1 metrics are no longer available after this cell.
r2_vals_scales = {}
rmsd_vals_scales = {}
# Indexed like the context-1 table: miRNAs without a context-2 value stay NaN.
df_deviation_c2 = pd.DataFrame(columns=df_knockdown.columns, index=df_knockdown.index)

# Evaluate the fit obtained on context 1 (popt_scales, scales) on the context-2 data.
for cell_line in cell_lines_measured:
    # position of the cell line in cell_lines_measured, used to look up its scale
    cell_line_index = cell_lines_measured.index(cell_line)
    current_scale = scales[cell_line_index]
    # plot the data
    fig = plt.figure(figsize=(3, 2.5))
    plt.rcParams.update({'font.size': 8})

    plt.scatter(df_expression_c2[f"{cell_line}"]+current_scale, df_knockdown_c2[f"{cell_line}"], s=5, color="black")
    plt.plot(x_range_log+current_scale,
            hill_func_log_alt(x_range_log+current_scale, *popt_scales[:num_params]),
            color="forestgreen", linewidth=2, label="fit")

    # squared Pearson correlation between measurement and model prediction
    r2 = stats.pearsonr(df_knockdown_c2[f"{cell_line}"],
                        hill_func_log_alt(df_expression_c2[f"{cell_line}"]+current_scale, *popt_scales[:num_params]))[0]**2
    r2_vals_scales[cell_line] = r2

    # root-mean-square deviation between measurement and model prediction
    rmsd = np.sqrt(np.mean((df_knockdown_c2[f"{cell_line}"]-
                            hill_func_log_alt(df_expression_c2[f"{cell_line}"]+current_scale, *popt_scales[:num_params]))**2))
    rmsd_vals_scales[cell_line] = rmsd

    # residual (measured - predicted) per miRNA
    df_deviation_c2[cell_line] = (df_knockdown_c2[f"{cell_line}"]
                                -hill_func_log_alt(df_expression_c2[f"{cell_line}"]+current_scale,*popt_scales[:num_params]))

    plt.xlabel("log10(expression)")
    plt.ylabel("log10(stability)")
    plt.title(f"{cell_line}, r2 = {round(r2, 2)}, rmsd = {round(rmsd, 2)}", fontsize=8)
    plt.tight_layout()
    plt.savefig(os.path.join(plot_folder, f"1.6_context1_context2/{cell_line}_global_fit_with_scales.png"), dpi=300)

In [None]:
# Compare the fit residuals of context 1 and context 2 per cell line.
for cell_line in cell_lines_measured:
    dev_c1 = df_deviation[cell_line]
    # align context 2 to the context-1 rows
    dev_c2 = df_deviation_c2.loc[df_deviation.index, cell_line]

    fig = plt.figure(figsize=(3, 2.5))
    plt.scatter(dev_c1, dev_c2, s=5)

    # miRNAs missing in one context are NaN and would turn the correlation into
    # NaN, so only rows with a value in both contexts enter the statistic
    in_both = dev_c1.notna().values & dev_c2.notna().values
    r2 = stats.pearsonr(dev_c1.values[in_both].astype(float), dev_c2.values[in_both].astype(float))[0]**2

    plt.xlabel("deviation context 1")
    plt.ylabel("deviation context 2")
    plt.title(f"{cell_line} (R2={r2:.2f})")
    plt.tight_layout()
    plt.savefig(os.path.join(plot_folder, f"1.6_context1_context2/{cell_line}_deviation_comparison.png"), dpi=300)

In [29]:
# absolute difference and absolute sum of the residuals of the two contexts
deviation_diff = (df_deviation - df_deviation_c2).abs()
deviation_add = (df_deviation + df_deviation_c2).abs()

In [30]:
%%capture output

# Context-1 data with the ten miRNAs highlighted whose summed residual over
# both contexts (deviation_add) is largest in absolute value.
for cell_line in cell_lines_measured:
    # NOTE(review): despite the variable name, this ranks by the absolute *sum* of
    # the two residuals, i.e. miRNAs that deviate in the same direction in both contexts
    most_different_between_ctxt= deviation_add.sort_values(by=cell_line, ascending=False).head(10).index

    # position of the cell line in cell_lines_measured, used to look up its scale
    cell_line_index = cell_lines_measured.index(cell_line)
    current_scale = scales[cell_line_index]

    # plot the data
    fig = plt.figure(figsize=(3, 2.5))
    plt.rcParams.update({'font.size': 8})

    # all points in black, the selected ones again in red on top
    plt.scatter(df_expression[f"{cell_line}"]+current_scale, df_knockdown[f"{cell_line}"], s=5, color="black")
    plt.scatter(df_expression.loc[most_different_between_ctxt,f"{cell_line}"]+current_scale,
                df_knockdown.loc[most_different_between_ctxt,f"{cell_line}"], s=5, color="red")
    plt.plot(x_range_log+current_scale, hill_func_log_alt(x_range_log+current_scale, *popt_scales[:num_params]),
            color="forestgreen", linewidth=2, label="fit")

    plt.xlabel("log10(expression)")
    plt.ylabel("log10(stability)")
    plt.title(f"{cell_line}", fontsize=8)
    plt.tight_layout()
    plt.savefig(os.path.join(plot_folder, f"1.6_context1_context2/{cell_line}_add_context1_outliers.png"), dpi=300)

In [31]:
%%capture output

# Context-2 data with the ten miRNAs highlighted whose residuals differ most
# between the two contexts (largest deviation_diff).
for cell_line in cell_lines_measured:
    most_different_between_ctxt= deviation_diff.sort_values(by=cell_line, ascending=False).head(10).index

    # position of the cell line in cell_lines_measured, used to look up its scale
    cell_line_index = cell_lines_measured.index(cell_line)
    current_scale = scales[cell_line_index]

    # plot the data
    fig = plt.figure(figsize=(3, 2.5))
    plt.rcParams.update({'font.size': 8})

    # all points in black, the selected ones again in red on top
    plt.scatter(df_expression_c2[f"{cell_line}"]+current_scale, df_knockdown_c2[f"{cell_line}"], s=5, color="black")
    plt.scatter(df_expression_c2.loc[most_different_between_ctxt,f"{cell_line}"]+current_scale,
                df_knockdown_c2.loc[most_different_between_ctxt,f"{cell_line}"], s=5, color="red")
    plt.plot(x_range_log+current_scale, hill_func_log_alt(x_range_log+current_scale, *popt_scales[:num_params]), 
            color="forestgreen", linewidth=2, label="fit")

    plt.xlabel("log10(expression)")
    plt.ylabel("log10(stability)")
    plt.title(f"{cell_line}", fontsize=8)
    plt.tight_layout()
    plt.savefig(os.path.join(plot_folder, f"1.6_context1_context2/{cell_line}_context2_outliers.png"), dpi=300)

## 1.7 Heuristically filter false positives

In [32]:
def fill_in_sequence(sequence):
    """Pad a sequence to 21 nucleotides.

    Sequences longer than 20 nt are returned unchanged. A 20-nt sequence ending
    in "A" gets another "A" (avoids introducing an ATG); every other sequence is
    padded with "T" up to a length of 21.
    """
    length = len(sequence)
    if length > 20:
        return sequence
    if length == 20 and sequence[-1] == "A":
        return sequence + "A"
    return sequence + "T" * (21 - length)

In [37]:
# Load the miRBase table and keep the miRNAs that are both annotated with
# confidence "high" and present in the expression table.
mirbase_df = pd.read_csv('../input_data/mirbase_with_families_and_targets.csv', index_col=0)
is_high_confidence = mirbase_df['confidence'] == 'high'
has_expression = mirbase_df.index.isin(df_alles_quantile.index)
high_confidence = mirbase_df[is_high_confidence & has_expression].copy()

In [38]:
# miRNAs whose log10 knockdown stays above -0.5 in every cell line
df_knockdown_low = df_knockdown.loc[df_knockdown.min(axis=1) > -0.5]

# Of their residuals keep only the positive ones (measured above the fit,
# i.e. less knockdown than expected); everything else becomes NaN.
df_deviation_low = df_deviation.loc[df_knockdown_low.index, :]
df_deviation_low = df_deviation_low.where(df_deviation_low > 0)

In [None]:
# Rank the positive residuals within each cell line (rank 1 = smallest residual).
df_rank = pd.DataFrame(
    {col: df_deviation_low[col].rank(ascending=True) for col in df_deviation.columns}
)

# show the miRNA with the highest rank for each cell line
df_rank.idxmax(axis=0)

In [None]:
# count how often each microRNA appears in the top 30% of df_rank
# (per cell line, the 0.7 quantile of the ranks is the threshold)
thresholds = df_rank.quantile(0.7)

# Boolean mask: True where the rank reaches the top 30% of its cell line
mask = df_rank.ge(thresholds)

# Number of cell lines in which each miRNA is in the top 30%
counts = mask.sum(axis=1)

# keep the miRNAs that are in the top 30% in at least three cell lines
counts_high = counts[counts >= 3]
counts_high

In [41]:
# we heuristically exclude these from the design process in case they are false positives
# (the file name carries the prefix "1.8" although this is section 1.7)
counts_high.to_csv('../output/1_output/1.8_false_positives.csv')

In [42]:
# Drop the suspected false positives from the knockdown, expression and
# full expression tables.
flagged_mirnas = counts_high.index
df_knockdown_filter = df_knockdown.loc[~df_knockdown.index.isin(flagged_mirnas)]
df_expression_filter = df_expression.loc[~df_expression.index.isin(flagged_mirnas)]
df_alles_quantile_filter = df_alles_quantile.loc[~df_alles_quantile.index.isin(flagged_mirnas)]

In [43]:
# Repeat the global fit without the suspected false positives.
x_data = [df_expression_filter[f"{cell_line}"].values for cell_line in cell_lines_measured]
y_data = [df_knockdown_filter[f"{cell_line}"].values for cell_line in cell_lines_measured]
y_data = np.concatenate(y_data)

popt_scales_filter, pcov = opt.curve_fit(hill_func_log_scales, x_data, y_data, p0=p0_scale, bounds=bounds_scale, maxfev=5000)
# Take the per-dataset scales from the fit that was just run. They were
# previously read from popt_scales, the fit on the unfiltered data, so the
# plots below mixed filtered Hill parameters with unfiltered scales.
scales = list(popt_scales_filter[num_params:])

In [44]:
%%capture output
# create the plot folder if it does not exist
if not os.path.exists(os.path.join(plot_folder, "1.7_false_positives")):
    os.makedirs(os.path.join(plot_folder, "1.7_false_positives"))

# Per cell line: retained miRNAs in black, excluded miRNAs (counts_high) in red,
# and the Hill curve of the fit on the filtered data.
for cell_line in cell_lines_measured:
    cell_line_index = cell_lines_measured.index(cell_line)
    current_scale = scales[cell_line_index]

    fig = plt.figure(figsize=(3, 2.5))
    plt.rcParams.update({'font.size': 8})

    # the data points are shifted by the per-cell-line scale, so the curve is
    # drawn directly against the shifted axis
    plt.scatter(df_expression_filter[cell_line]+current_scale, df_knockdown_filter[cell_line], color="black", s=5)
    plt.scatter(df_expression[cell_line].loc[counts_high.index,]+current_scale,
                df_knockdown[cell_line].loc[counts_high.index,], color="red", s=5)
    plt.plot(x_range_log, hill_func_log_alt(x_range_log, *popt_scales_filter[:num_params]), color="forestgreen", linewidth=2, label="fit")

    # squared Pearson correlation, computed on the retained miRNAs only
    r2 = stats.pearsonr(hill_func_log_alt(df_expression_filter[cell_line]+current_scale, *popt_scales_filter[:num_params]),
                        df_knockdown_filter[cell_line])[0]**2
                        
    plt.xlabel("log10(expression)")
    plt.ylabel("log10(stability)")
    plt.title(f"{cell_line}, r2 = " + str(round(r2, 2)), fontsize=8)
    plt.tight_layout()
    plt.savefig(os.path.join(plot_folder, f"1.7_false_positives/{cell_line}_individual_fit_with_outliers.png"), dpi=300)

## 1.8 Look at family crosstalk
Here, we look at crosstalk from a family-based perspective. This is *not* what we use for the actual heuristic filtering, which is done in the next section.

In [45]:
# all distinct extended-family names among the high-confidence miRNAs
families = high_confidence['family_extended'].unique()

def get_family_mirnas(family):
    """Return the names of all high-confidence miRNAs whose extended family equals `family`."""
    is_member = high_confidence['family_extended'] == family
    return high_confidence.loc[is_member].index.to_list()

#### Filter to the highest expressed microRNA in each family that I actually measured in library 1

In [None]:
# For each measured microRNA, look up the highest expression reached by any
# member of its family in the full expression table.
# NOTE(review): the lookup uses df_expression.index (unfiltered) while the column is
# added to df_expression_filter; the assignment aligns on the index, but the .loc
# lookup presumably requires every measured miRNA to be in high_confidence — confirm.
df_expression_filter["family"] = high_confidence.loc[df_expression.index, "family_extended"].copy()

for index, row in df_expression_filter.iterrows():
    family = row['family']
    # all high-confidence members of this miRNA's family
    other_mirnas = get_family_mirnas(family)
    for cell_line in cell_lines_measured:
        # maximum log10 expression within the family, stored as "<cell line>_max"
        highest_expression = df_alles_quantile.loc[other_mirnas, cell_line].max()
        df_expression_filter.loc[index, f"{cell_line}_max"] = highest_expression

In [47]:
%%capture output
# Fit residual of each measured miRNA against the highest expression within its family.
os.makedirs(os.path.join(plot_folder, "1.8_familiy_crosstalk"), exist_ok=True)

for cell_line in cell_lines_measured:
    # One new figure per cell line; the former plt.clf() before plt.figure()
    # only cleared or created an unrelated figure.
    fig = plt.figure(figsize=(3, 2.5))
    plt.scatter(df_expression_filter[f"{cell_line}_max"], df_deviation.loc[df_expression_filter.index, f"{cell_line}"], s=5)
    plt.xlabel("Max expression in family")
    plt.ylabel("deviation")
    plt.tight_layout()
    plt.savefig(os.path.join(plot_folder, f"1.8_familiy_crosstalk/{cell_line}_deviation_max.png"), dpi=300)
    # the cell output is captured, so release the figure once it is on disk
    plt.close(fig)

In [None]:
# histogram of family sizes: how many measured microRNAs belong to each family
df_expression_filter["family"].value_counts().hist(bins=20)

In [49]:
# For each family and cell line, find the measured miRNA with the highest expression.
# family_max_df holds that expression value at (miRNA, cell line) and stays NaN for
# every miRNA that is not the top member of its family in that cell line.
family_max_df = pd.DataFrame(columns=cell_lines_measured, index=df_expression_filter.index)
for family in df_expression_filter["family"].unique():
    df_family = df_expression_filter[df_expression_filter["family"] == family]
    for cell_line in cell_lines_measured:
        # label of the family member with the highest expression in this cell line
        max_id = df_family[cell_line].idxmax(axis=0)
        family_max_df.loc[max_id, cell_line] = df_family.loc[max_id, cell_line]

In [50]:
%%capture output

# Per cell line: show only the top-expressed measured member of each family
# together with the fit on the filtered data.
for cell_line_index, cell_line in enumerate(cell_lines_measured):
    current_scale = scales[cell_line_index]

    # family_max_df is NaN for every miRNA that is not the top member of its family
    ex_df = family_max_df[cell_line].dropna().astype(float)
    knock_df = df_knockdown_filter.loc[ex_df.index, cell_line]

    fig = plt.figure(figsize=(3, 2.5))
    plt.rcParams.update({'font.size': 8})
    plt.scatter(ex_df+current_scale, knock_df, color="black", s=5)
    # Pass all shared Hill parameters (c1, c2, n). The slice used to be [:2],
    # which dropped the fitted n and silently fell back to the default n=1.
    plt.plot(x_range_log, hill_func_log_alt(x_range_log, *popt_scales_filter[:num_params]), color="forestgreen", linewidth=2, label="fit")
    r2 = stats.pearsonr(hill_func_log_alt(ex_df+current_scale, *popt_scales_filter[:num_params]), knock_df)[0]**2
    plt.xlabel("log10(expression)")
    plt.ylabel("log10(stability)")
    plt.title(f"{cell_line}, r2 = " + str(round(r2, 2)), fontsize=8)
    plt.tight_layout()
    plt.savefig(os.path.join(plot_folder, f"1.8_familiy_crosstalk/{cell_line}_family_max_fit.png"), dpi=300)
    # the cell output is captured, so release the figure once it is on disk
    plt.close(fig)

### Filter to the highest expression miRNA for each family in the entire microRNA expression dataset

In [51]:
# Annotate every miRNA of the expression table with its extended family (as string).
# Note that this adds a non-numeric "family" column to df_alles_quantile.
df_alles_quantile["family"] = high_confidence.loc[df_alles_quantile.index, "family_extended"].astype(str)

# For each family and cell line, find the top-expressed miRNA in the entire
# expression table. The frame is declared with all cell lines because the loop
# fills every entry of `cell_lines`; declaring only the measured ones made
# pandas append the remaining columns implicitly during assignment.
family_max_df = pd.DataFrame(columns=cell_lines, index=df_alles_quantile.index)
for family in df_alles_quantile["family"].unique():
    df_family = df_alles_quantile[df_alles_quantile["family"] == family]
    for cell_line in cell_lines:
        # label of the family member with the highest expression in this cell line
        max_id = df_family[cell_line].idxmax(axis=0)
        family_max_df.loc[max_id, cell_line] = df_family.loc[max_id, cell_line]

In [52]:
def hill_func_log_scales(x_data, dataset_indices, c1=3, c2=5, n=1, *scales):
    """Evaluate the log-space Hill model on pooled data with per-dataset offsets.

    This redefines the earlier function of the same name with a flat-array
    interface: all datasets are concatenated in ``x_data`` and
    ``dataset_indices`` says which dataset each point belongs to.

    Parameters
    ----------
    x_data : array-like
        log10 expression values of all datasets.
    dataset_indices : array-like of int
        Dataset number of every entry of ``x_data`` (same length).
    c1, c2 : float
        log10 of the two model constants.
    n : float
        Hill exponent, applied to the expression and to ``c1``.
    *scales : float
        Additive log10 offset of dataset 0, 1, ...

    Returns
    -------
    numpy.ndarray
        log10 model output, one value per entry of ``x_data`` and in the same
        order. Points whose dataset index has no matching scale are NaN.
    """
    x_data = np.asarray(x_data, dtype=float)
    dataset_indices = np.asarray(dataset_indices)
    c1 = (10**c1)**n
    c2 = 10**c2

    # Write every result back to the position of its input. Concatenating the
    # per-dataset results only matches the input order when x_data is already
    # grouped by dataset; otherwise predictions no longer line up with the y data.
    results = np.full(x_data.shape, np.nan)
    for i, scale in enumerate(scales):
        mask = (dataset_indices == i)
        x = (10**(x_data[mask] + scale))**n
        results[mask] = np.log10((1 / (1 + x / c1)) * (1 + x / c2))
    return results

In [53]:
# Pooled fit input restricted to miRNAs that are the top-expressed member of
# their family in the respective cell line and that were measured.
x_data = []
y_data = []
dataset_indices = []
for i, cell_line in enumerate(cell_lines_measured):
    # miRNAs with an entry in family_max_df for this cell line ...
    allowed_mirnas = family_max_df[cell_line].dropna().index
    # ... that are also part of the filtered measurement table
    allowed_mirnas = allowed_mirnas[allowed_mirnas.isin(df_expression_filter.index)]
    ex_df = df_expression_filter.loc[allowed_mirnas, cell_line].values
    knock_df = df_knockdown_filter.loc[allowed_mirnas, cell_line].values
    x_data.append(ex_df)
    y_data.append(knock_df)
    # dataset number of every point, consumed by the flat hill_func_log_scales
    dataset_indices.append([i] * len(ex_df))

# flatten everything; the points stay grouped by cell line in list order
x_data = np.concatenate(x_data)
y_data = np.concatenate(y_data)
dataset_indices = np.concatenate(dataset_indices)

In [54]:
# Re-assemble start values and bounds: [c1, c2, n] followed by one scale per dataset.
p0_scale = p0 + scale_guesses
bounds_scale = (bounds[0]+scale_bounds_min, bounds[1]+scale_bounds_max)

# Fit on the family-maximum subset. The lambda passes dataset_indices through to
# the flat hill_func_log_scales, because curve_fit only supplies x and the parameters.
popt_scales_filter, pcov = opt.curve_fit(
    lambda x, *params: hill_func_log_scales(x, dataset_indices, *params),
    x_data,
    y_data,
    p0=p0_scale,
    bounds=bounds_scale,
    maxfev=5000
)
# the per-dataset scales follow the shared parameters
scales = list(popt_scales_filter[num_params:])

In [55]:
%%capture output

# Residuals of the family-maximum fit. This replaces the df_deviation of section 1.4;
# here the residual is taken in linear space (10**measured - 10**predicted).
df_deviation = pd.DataFrame(columns=df_knockdown_filter.columns, index=df_expression_filter.index)
plt.rcParams.update({'font.size': 8})

for cell_line_index, cell_line in enumerate(cell_lines_measured):
    current_scale = scales[cell_line_index]

    # top-expressed family members that were measured
    allowed_mirnas = family_max_df[cell_line].dropna().index
    allowed_mirnas = allowed_mirnas[allowed_mirnas.isin(df_expression_filter.index)]
    # ex_df already contains the per-cell-line scale
    ex_df = df_expression_filter.loc[allowed_mirnas, cell_line]+current_scale
    knock_df = df_knockdown_filter.loc[allowed_mirnas, cell_line]
    predicted = hill_func_log_alt(ex_df, *popt_scales_filter[:num_params])

    fig = plt.figure(figsize=(3, 2.5))
    plt.scatter(ex_df, 10**knock_df, color="black", s=5)
    plt.plot(x_range_log, 10**hill_func_log_alt(x_range_log, *popt_scales_filter[:num_params]), color="forestgreen", linewidth=2, label="fit")
    r2 = stats.pearsonr(predicted, knock_df)[0]**2

    # Use the same prediction as for r2. The residual was previously computed on
    # ex_df+current_scale, which applied the scale a second time.
    df_deviation[cell_line] = 10**knock_df - 10**predicted

    plt.xlabel("log10(expression)")
    # the plotted values are 10**log10(stability), so the axis is not logarithmic
    plt.ylabel("stability")
    plt.title(f"{cell_line}, r2 = " + str(round(r2, 2)), fontsize=8)
    plt.tight_layout()
    plt.savefig(os.path.join(plot_folder, f"1.8_familiy_crosstalk/{cell_line}_family_max_entire.png"), dpi=300)
    # the cell output is captured, so release the figure once it is on disk
    plt.close(fig)

In [None]:
# Label the ten largest residuals of the family-maximum fit per cell line and
# collect all miRNAs that entered the fit.
allowed_mirna_all = []

for cell_line_index, cell_line in enumerate(cell_lines_measured):
    current_scale = scales[cell_line_index]

    allowed_mirnas = family_max_df[cell_line].dropna().index
    allowed_mirnas = allowed_mirnas[allowed_mirnas.isin(df_expression_filter.index)]
    allowed_mirna_all += allowed_mirnas.to_list()
    # ex_df already contains the per-cell-line scale; it must not be added again below
    ex_df = df_expression_filter.loc[allowed_mirnas, cell_line]+current_scale
    knock_df = df_knockdown_filter.loc[allowed_mirnas, cell_line]

    # ten largest absolute deviations
    deviation = abs(df_deviation.loc[allowed_mirnas, cell_line]).sort_values(ascending=False)[:10]

    # plot the data
    fig = plt.figure(figsize=(3, 2.5))
    plt.rcParams.update({'font.size': 8})
    # Points and labels sit at expression + scale. They were previously drawn at
    # ex_df+current_scale, i.e. shifted twice relative to the curve.
    plt.scatter(ex_df, knock_df, s=5, color="skyblue")
    plt.plot(x_range_log+current_scale,
            hill_func_log_alt(x_range_log+current_scale, *popt_scales_filter[:num_params]), color="forestgreen", linewidth=2, label="fit")

    # label each outlier with its name minus the first two dash-separated parts
    for i in deviation.index.to_list():
        plt.text(ex_df[i], knock_df[i], "-".join(i.split("-")[2:]), fontsize=8)

    plt.xlabel("log10(expression)")
    plt.ylabel("log10(stability)")
    plt.title(f"{cell_line}", fontsize=8)
    plt.tight_layout()
    plt.savefig(os.path.join(plot_folder, f"1.8_familiy_crosstalk/outliers_{cell_line}.png"), dpi=300)

## 1.9 Explore the crosstalk some more

In [None]:
# miRNAs that are high in at least one cell line (log10 > 2.5) and low in at
# least one other (log10 < 2). df_alles_quantile carries the string column
# "family" by now, so the row-wise max/min is restricted to the numeric
# columns; mixing strings and floats in max(axis=1) fails on current pandas.
first_filter = df_alles_quantile[df_alles_quantile.select_dtypes(include="number").max(axis=1) > 2.5]
second_filter = first_filter[first_filter.select_dtypes(include="number").min(axis=1) < 2]
print(len(second_filter))
print(len(second_filter["family"].unique()))

In [None]:
# look at the behavior of the let-7 family
let7_5p = get_family_mirnas("let-7-5p")
# only family members that were actually measured can be looked up below
let7_measured = [mirna for mirna in let7_5p
                 if mirna in df_expression_filter.index and mirna in df_knockdown_filter.index]

for cell_line_index, cell_line in enumerate(cell_lines_measured):
    current_scale = scales[cell_line_index]

    # The scale is an expression offset and applies to the x values only. It was
    # previously also added to the knockdown values, shifting the points vertically.
    ex_df = df_expression_filter.loc[let7_measured, cell_line]+current_scale
    knock_df = df_knockdown_filter.loc[let7_measured, cell_line]

    # annotate the members whose name contains "7e"
    for i, row in knock_df.items():
        if "7e" in i:
            plt.text(ex_df[i], 10**row, "-".join(i.split("-")[2:]), fontsize=8)

    plt.scatter(ex_df,
                10**knock_df, label=cell_line)

# the fitted curve is the same for every cell line, so draw it once
plt.plot(x_range_log, 10**hill_func_log_alt(x_range_log, *popt_scales_filter[:num_params]), linewidth=2, color="black")

plt.legend()
plt.show()

In [59]:
# heuristically define regions that might be important in determining the crosstalk
# Each tuple is a half-open (start, end) range of 0-based sequence positions, used
# as a slice by sum_mismatches_in_regions; together the ranges cover positions 0-20.
regions = [(0, 1), (1, 8), (8, 11), (11, 14), (14, 17), (17, 20), (20, 21)]

def get_mismatches(target, query):
    """Returns per-position mismatch and wobble flags between two miRNA sequences.

    Both arguments are miRNA sequences (5'->3'), NOT an actual target site:
    THIS DOES NOT WORK FOR A REAL TARGET SEQUENCE, WHICH IS THE REVERSE COMPLEMENT.
    `target` is the miRNA the site is perfectly complementary to, `query` is the
    miRNA whose pairing to that site is evaluated.

    'U' and 'T' are treated as the same base, so RNA- and DNA-alphabet input
    give identical results.

    Args:
        target: miRNA sequence defining the positions to compare.
        query: miRNA sequence compared position by position; must be at least as
            long as `target` (extra trailing bases are ignored).

    Returns:
        (mismatch_positions, wobble_positions): two lists of 0/1 flags, each of
        length len(target). A position is flagged as wobble OR mismatch, never both.

    Raises:
        IndexError: if `query` is shorter than `target`.
    """
    # compare in a single alphabet so that both the wobble rule (written with 'T')
    # and the position-1 rule (written with 'U' originally) can actually fire
    target = target.replace('U', 'T')
    query = query.replace('U', 'T')

    mismatch_positions = [0] * len(target)
    wobble_positions = [0] * len(target)
    for i in range(len(target)):
        t, q = target[i], query[i]
        # target A / query G and target C / query U(T) are counted as wobbles,
        # presumably because the query base can still form a G:U pair with the site
        if (t, q) in (('A', 'G'), ('C', 'T')):
            wobble_positions[i] = 1
        elif t != q:
            mismatch_positions[i] = 1

    # the RISC likes an "A" in the target sequence at position 1 (i.e. a U/T at the
    # first position of the target miRNA); therefore, we never count it as a mismatch
    if target and target[0] == 'T':
        mismatch_positions[0] = 0
        wobble_positions[0] = 0

    return mismatch_positions, wobble_positions

def sum_mismatches_in_regions(mismatch, wobble, regions):
    """Bins per-position mismatch and wobble flags into per-region counts.

    Each entry of `regions` supplies a (start, stop) pair that is used as a
    half-open slice into both flag lists. Returns (mismatch_counts, wobble_counts),
    one count per region, in the order of `regions`.

    If the binned totals differ from the overall totals (i.e. the regions miss or
    double-count flagged positions), the inputs and counts are printed as a
    diagnostic; the counts are returned regardless.
    """
    spans = [slice(region[0], region[1]) for region in regions]
    mismatch_counts = [sum(mismatch[span]) for span in spans]
    wobble_counts = [sum(wobble[span]) for span in spans]

    binned_total = sum(mismatch_counts) + sum(wobble_counts)
    overall_total = sum(mismatch) + sum(wobble)
    if binned_total != overall_total:
        print(mismatch_counts,wobble_counts,mismatch,wobble)

    return mismatch_counts, wobble_counts

def count_mismatches_in_region(target, query, regions):
    """Per-region mismatch and wobble counts for one target/query pair.

    Feeds the per-position flags from get_mismatches(target, query) into
    sum_mismatches_in_regions and returns its (mismatch_counts, wobble_counts).
    """
    per_position_flags = get_mismatches(target, query)
    return sum_mismatches_in_regions(*per_position_flags, regions)

In [None]:
# sanity check: per-region mismatch/wobble counts of every miRNA in the
# extended family of let-7i, each compared against let-7a as the target
family_mirnas = get_family_mirnas(high_confidence.loc['hsa-let-7i-5p', "family_extended"])
target = high_confidence.loc['hsa-let-7a-5p', "sequence"]
for mirna in family_mirnas:
    query = high_confidence.loc[mirna, "sequence"]
    mismatch, wobble = count_mismatches_in_region(target, query, regions)
    # both values are lists with one count per entry of `regions`
    print(f"{mirna} has {mismatch} mismatches and {wobble} wobbles")

In [61]:
# Pairwise comparison of all high-confidence miRNAs.
# mismatch_dict[target_mirna] is a DataFrame indexed by query miRNA with two columns,
# "mismatch" and "wobble", each holding the list of per-region counts
# (one entry per tuple in `regions`).
sequences = high_confidence["sequence"]
mismatch_dict = {}

for mirna_target, target in sequences.items():
    # compute all queries for this target first, then build the table in one go
    # instead of writing list values cell by cell
    per_query = [count_mismatches_in_region(target, query, regions) for query in sequences]
    query_df = pd.DataFrame(
        {
            "mismatch": [mismatch for mismatch, _ in per_query],
            "wobble": [wobble for _, wobble in per_query],
        },
        index=high_confidence.index,
        columns=["mismatch", "wobble"],
    )

    mismatch_dict[mirna_target] = query_df

We assume that crosstalk does not occur if either

a) the seed is different, or
b) the overall number of mismatches is large enough.

We ignore wobble base pairs as they might not count as proper mismatches.

In [62]:
# add two summary columns to every pairwise table (the tables are modified in place)
for key, df in mismatch_dict.items():
    # entry 1 of the per-region list belongs to region (1, 8), i.e. the seed
    df["mismatch_seed"] = df["mismatch"].map(lambda counts: counts[1])
    # sum of the first five regions, i.e. positions 0-16
    # NOTE(review): the earlier comment here read "all mismatches except those at
    # position 21", which would be counts[0:6]; confirm which one is intended
    df["mismatch_four"] = df["mismatch"].map(lambda counts: sum(counts[0:5]))

# keep, per target miRNA, only the queries that could plausibly crosstalk:
# fewer than five counted mismatches and no mismatch in the seed
mismatch_dict_filter = {}

for key in mismatch_dict:
    df = mismatch_dict[key]
    df = df[df["mismatch_four"] < 5]
    seed_matches = df["mismatch_seed"] < 1
    mismatch_dict_filter[key] = df[seed_matches]

# for the most part, the family is sufficient to get potential cross-talk miRNAs
# for key in mismatch_dict_filter.keys():
#     df = mismatch_dict_filter[key]
#     family = high_confidence.loc[key, "family_extended"]
#     family_mirnas = get_family_mirnas(family)
#     if len(family_mirnas) < len(df):
#         print(key, len(df), len(family_mirnas))
#         print(mismatch_dict_filter[key])
#         print("------------------------------------------")

In [63]:
# some manual checking of whether it's doing the right thing
# mirs = high_confidence[high_confidence["family_extended"] == "mir-506-3p"].index
# mirs = mirs[mirs.isin(df_expression_filter.index)]
# df_alles_quantile_filter.loc[mirs, "HEK293T"]
# df = mismatch_dict["hsa-miR-512-3p"]
# df[df["mismatch_seed"] < 2]

In [64]:
# some mirnas/families seem suspicious. these can be removed manually. do not do this here.
# outlier_families = ["mir-506-3p", 'mir-515-3p', 'mir-146-3p', 'mir-142-5p', 'mir-194-5p']
# outlier_families_mirnas = []
# for outlier_family in outlier_families:
#     for mirna in high_confidence[high_confidence["family_extended"] == outlier_family].index:
#         outlier_families_mirnas.append(mirna)
        
# df_expression_filter2 = df_expression_filter[~df_expression_filter.index.isin(outlier_families_mirnas)]
# df_knockdown_filter2 = df_knockdown_filter[~df_knockdown_filter.index.isin(outlier_families_mirnas)]

We filter out a miRNA if one of its potential crosstalk partners (see above) meets both of the following criteria:

1) Its expression is more than 1.35x the expression of the miRNA in the same cell line.
2) Its expression is above 400 tpm.

In [65]:
# crosstalk_dict[cell_line][mirna] lists the potential crosstalk partners of `mirna`
# (taken from mismatch_dict_filter) that dominate it in that cell line: expressed more
# than 1.35x as high as `mirna` itself and above 400 on the linear scale
crosstalk_dict = {}
for cell_line in cell_lines:
    # the table is log10, the thresholds are applied to linear expression
    expr_df = 10**df_alles_quantile_filter[cell_line]
    competitors_by_mirna = {}
    for mirna, own_expression in expr_df.items():
        candidates = mismatch_dict_filter[mirna].index
        # drop candidates that are no longer part of the expression table
        candidates = candidates[candidates.isin(expr_df.index)]
        candidate_expression = expr_df[candidates]
        is_dominant = (candidate_expression > 1.35*own_expression) & (candidate_expression > 400)
        competitors_by_mirna[mirna] = candidate_expression[is_dominant].index.to_list()
    crosstalk_dict[cell_line] = competitors_by_mirna

In [66]:
# per cell line, the miRNAs for which no dominating crosstalk partner was found
allowed_mirnas_all = {}
for cell_line in cell_lines:
    allowed_mirnas = [
        mirna
        for mirna, competitors in crosstalk_dict[cell_line].items()
        if len(competitors) == 0
    ]
    allowed_mirnas_all[cell_line] = allowed_mirnas

In [67]:
# Assemble the flat arrays for the joint fit over all measured cell lines:
# x_data = expression, y_data = knockdown, dataset_indices = position of the
# cell line in cell_lines_measured. Only miRNAs that passed the crosstalk filter
# and are present in the measured tables are used.
allowed_mirnas_all_meas = {}
x_chunks, y_chunks, index_chunks = [], [], []
for i, cell_line in enumerate(cell_lines_measured):
    allowed_mirnas_meas = []
    for mirna in allowed_mirnas_all[cell_line]:
        if mirna in df_expression_filter.index:
            allowed_mirnas_meas.append(mirna)
    allowed_mirnas_all_meas[cell_line] = allowed_mirnas_meas

    ex_df = df_expression_filter.loc[allowed_mirnas_meas, cell_line].values
    knock_df = df_knockdown_filter.loc[allowed_mirnas_meas, cell_line].values
    x_chunks.append(ex_df)
    y_chunks.append(knock_df)
    index_chunks.append([i] * len(ex_df))

# the chunks are concatenated in cell-line order, so the data stay grouped by dataset
x_data = np.concatenate(x_chunks)
y_data = np.concatenate(y_chunks)
dataset_indices = np.concatenate(index_chunks)

In [68]:
def hill_func_log_scales(x_data, dataset_indices, c1=3, c2=5, n=1, *scales):
    """Hill-type model with one additive log10 expression offset per dataset.

    The expression is assumed to be normalized to one.
    The microRNA data is assumed to be log10.
    The return value is also log10.

    Args:
        x_data: 1-D array-like of log10 expression values.
        dataset_indices: array-like of the same length; entry k is the index of
            the dataset (position in `scales`) that x_data[k] belongs to.
        c1: log10 of the first constant; it is raised to the power n together with x.
        c2: log10 of the second constant.
            NOTE(review): unlike c1, c2 is not raised to the power n - confirm intended.
        n: Hill exponent.
        *scales: one log10 offset per dataset, added to x before evaluation.

    Returns:
        numpy array of log10 model values in the SAME order as x_data. Entries
        whose dataset index has no matching scale are NaN.
    """
    x_data = np.asarray(x_data, dtype=float)
    dataset_indices = np.asarray(dataset_indices)
    c1 = (10**c1)**n
    c2 = 10**c2

    # write results back through the mask so the output order follows x_data
    # even when the datasets are interleaved rather than grouped
    results = np.full(x_data.shape, np.nan)
    for i, scale in enumerate(scales):
        mask = (dataset_indices == i)
        x = x_data[mask] + scale
        x = (10**x)**n
        results[mask] = np.log10((1 / (1 + x / c1)) * (1 + x / c2))
    return results

In [69]:
# initial guesses and bounds for the shared curve parameters (c1, c2, n);
# the bounds on c2 and n are so narrow that they are effectively pinned to 10 and 1
p0 = [3, 10, 1]
num_params = len(p0)
bounds = ([1, 9.9999999, 0.9999999], [7, 10.000001, 1.00000001])

# one scale (log10 expression offset) per measured cell line, starting at 0
scale_guesses = [0] * len(cell_lines_measured)
scaled=True
if scaled:
    # let the scales vary within two orders of magnitude
    scale_bounds_min = [-2] * len(cell_lines_measured)
    scale_bounds_max = [2] * len(cell_lines_measured)
else:
    # effectively fix every scale at 0
    scale_bounds_min = [-0.001] * len(cell_lines_measured)
    scale_bounds_max = [0.001] * len(cell_lines_measured)

# the first measured cell line (HEK293T) is the reference: keep its scale at ~0
scale_bounds_min[0], scale_bounds_max[0] = -0.001, 0.001

# full parameter vector: curve parameters followed by the per-cell-line scales
p0_scale = p0 + scale_guesses
bounds_scale = (bounds[0]+scale_bounds_min, bounds[1]+scale_bounds_max)

popt_scales_filter, pcov = opt.curve_fit(
    lambda x, *params: hill_func_log_scales(x, dataset_indices, *params),
    x_data,
    y_data,
    p0_scale,
    bounds=bounds_scale,
    maxfev=5000
)
scales_filter = list(popt_scales_filter[num_params:])

In [None]:
# fitted parameter vector: the first num_params entries are the curve parameters,
# followed by one scale per entry of cell_lines_measured
popt_scales_filter

In [None]:
# %%capture output
# create the plot folder if it does not exist
os.makedirs(os.path.join(plot_folder, "1.9_crosstalk_filtering"), exist_ok=True)

# deviation of the measured log10 stability from the fitted curve, per miRNA and
# cell line; miRNAs that are not plotted for a cell line stay NaN in that column
df_deviation = pd.DataFrame(columns=df_knockdown_filter.columns, index=df_expression_filter.index)

plt.rcParams.update({'font.size': 8})

for cell_line in cell_lines_measured:
    cell_line_index = cell_lines_measured.index(cell_line)
    current_scale = scales_filter[cell_line_index]
    print(cell_line)
    print(current_scale)

    allowed_mirnas_meas = allowed_mirnas_all_meas[cell_line]

    # ex_df already contains the fitted scale of this cell line;
    # it must not be added a second time further down
    ex_df = df_expression_filter.loc[allowed_mirnas_meas, cell_line] + current_scale
    knock_df = df_knockdown_filter.loc[allowed_mirnas_meas, cell_line]

    # model prediction at the (scaled) data points, shared by r2 and the deviation
    fit_at_data = hill_func_log_alt(ex_df, *popt_scales_filter[:num_params])

    fig = plt.figure(figsize=(3, 2.5))
    plt.scatter(ex_df, knock_df, color="black", s=5)
    plt.plot(x_range_log, hill_func_log_alt(x_range_log, *popt_scales_filter[:num_params]), color="forestgreen", linewidth=2, label="fit")
    r2 = stats.pearsonr(fit_at_data, knock_df)[0]**2

    # calculate the deviation
    df_deviation[cell_line] = knock_df - fit_at_data
    # for i, row in knock_df.items():
    #     #if high_confidence.loc[i, "family_extended"] in outlier_families:
    #     plt.text(ex_df[i], 10**row, "-".join(i.split("-")[2:]), fontsize=5)

    plt.xlabel("log10(expression)")
    plt.ylabel("log10(stability)")
    #plt.legend(loc="lower left", fontsize=8)
    plt.title(f"{cell_line}, r2 = " + str(round(r2, 2)), fontsize=8)
    plt.tight_layout()
    if scaled:
        plt.savefig(os.path.join(plot_folder, f"1.9_crosstalk_filtering/{cell_line}__crosstalk_rem_with_scales.png"), dpi=300)
    else:
        plt.savefig(os.path.join(plot_folder, f"1.9_crosstalk_filtering/{cell_line}__crosstalk_rem_wo_scales.png"), dpi=300)

In [None]:
# one figure with the crosstalk-filtered data of all measured cell lines
# (expression shifted by the fitted scale) on top of the shared fitted curve
fig = plt.figure(figsize=(3, 2.5))
plt.rcParams.update({'font.size': 9})
cell_line_colors = {"HEK293T": "blue", "SKNSH": "green", "HeLa": "red", "MCF7": "orange"}
cell_line_symbols = {"HEK293T": "o", "SKNSH": "s", "HeLa": "D", "MCF7": "v"}
ex_dfs = []
knock_dfs = []

# plot all
for cell_line_index, cell_line in enumerate(cell_lines_measured):
    current_scale = scales_filter[cell_line_index]
    print(cell_line)
    print(current_scale)

    allowed_mirnas_meas = allowed_mirnas_all_meas[cell_line]

    ex_df = df_expression_filter.loc[allowed_mirnas_meas, cell_line] + current_scale
    knock_df = df_knockdown_filter.loc[allowed_mirnas_meas, cell_line]

    # collected for the pooled r2 below
    ex_dfs.append(ex_df)
    knock_dfs.append(knock_df)

    plt.scatter(
        ex_df,
        knock_df,
        s=5,
        color=cell_line_colors[cell_line],
        marker=cell_line_symbols[cell_line],
        label=cell_line,
    )

plt.plot(
    x_range_log,
    hill_func_log_alt(x_range_log, *popt_scales_filter[:num_params]),
    color="black",
    linestyle="--",
    linewidth=1.5,
    label="hill function",
)

# squared Pearson correlation between prediction and measurement, pooled over cell lines
r2 = stats.pearsonr(hill_func_log_alt(pd.concat(ex_dfs), *popt_scales_filter[:num_params]), pd.concat(knock_dfs))[0]**2

plt.xlabel("log10(expression)")
plt.ylabel("log10(stability)")
plt.legend(loc="lower left", fontsize=8)
plt.title(f"r2 = " + str(round(r2, 2)), fontsize=8)
plt.tight_layout()
plt.savefig(
    os.path.join(plot_folder, f"1.9_crosstalk_filtering/all_crosstalk_rem_with_scales.png")
    if scaled
    else os.path.join(plot_folder, f"1.9_crosstalk_filtering/all_crosstalk_rem_wo_scales.png"),
    dpi=300,
)

## 1.10 Export the data for later use

In [73]:
def intersection_of_lists(dictionary):
    """Returns the elements that occur in every list of `dictionary`.

    Args:
        dictionary: mapping whose values are lists (or other iterables) of
            hashable items.

    Returns:
        A list without duplicates, ordered by first appearance in the first
        value of `dictionary`. The order is therefore the same in every run,
        which keeps anything derived from it (e.g. row order of an export)
        reproducible. An empty dictionary yields an empty list.
    """
    lists = list(dictionary.values())
    if not lists:
        return []

    first, others = lists[0], lists[1:]
    common = set(first).intersection(*others)

    # dict.fromkeys removes duplicates while keeping the order of `first`
    return [item for item in dict.fromkeys(first) if item in common]

# get the intersection of all allowed mirnas, i.e. the miRNAs that passed the
# crosstalk filter in every cell line that has an entry in allowed_mirnas_all
allowed_mirnas_all_intersection = intersection_of_lists(allowed_mirnas_all)

In [None]:
# apply the fitted per-cell-line scale to the log10 expression of the measured
# cell lines; all other columns are carried over unchanged from the copy
df_alles_quantile_scaled = df_alles_quantile_filter.copy()
for cell_line_index, cell_line in enumerate(cell_lines_measured):
    current_scale = scales_filter[cell_line_index]
    print(cell_line)
    print(current_scale)

    scaled_column = df_alles_quantile_filter.loc[:,cell_line] + current_scale
    df_alles_quantile_scaled.loc[:,cell_line] = scaled_column

In [75]:
# restrict the scaled table to the miRNAs that passed the crosstalk filter in all cell lines
df_alles_quantile_scaled2 = df_alles_quantile_scaled.loc[allowed_mirnas_all_intersection,:]
# optional manual removal of suspicious families (disabled):
# df_alles_quantile_scaled2 = df_alles_quantile_scaled2[~df_alles_quantile_scaled2.index.isin(outlier_families_mirnas)]

In [78]:
# Export the scaled expression tables: once for all miRNAs and once restricted to
# the miRNAs that passed the crosstalk filter in every cell line.
# NOTE(review): this writes below "miRNA_expression_data", whereas the input at the top
# of the notebook is read from "mirna_expression_data"; on a case-sensitive file system
# these are two different folders - confirm which spelling downstream notebooks expect.
no_filter_csv = '../input_data/miRNA_expression_data/1_output/1.10_alles_quantile_no_crosstalk_filter.csv'
crosstalk_filter_csv = '../input_data/miRNA_expression_data/1_output/1.10_alles_quantile_crosstalk_filter.csv'

# to_csv does not create missing folders, so make sure the target folder exists
for csv_path in (no_filter_csv, crosstalk_filter_csv):
    os.makedirs(os.path.dirname(csv_path), exist_ok=True)

df_alles_quantile_scaled.to_csv(no_filter_csv)
df_alles_quantile_scaled2.to_csv(crosstalk_filter_csv)

In [80]:
# save the parameters
if scaled:
    filename = '../output/1_output/1.10_fit_parameters_with_scales.txt'
else:
    filename = '../output/1_output/1.10_fit_parameters_without_scales.txt'

# create the output folder if it does not exist
os.makedirs(os.path.dirname(filename), exist_ok=True)

# popt_scales_filter holds c1, c2, n followed by one scale per entry of
# cell_lines_measured, in that order. Build the header from the same list so the
# column labels cannot get out of sync with the values.
header = ["c1", "c2", "n"] + list(cell_lines_measured)

with open(filename, "w") as f:
    f.write("\t".join(header) + "\n")
    f.write("\t".join(str(value) for value in popt_scales_filter) + "\n")