In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats as stats
import scipy.optimize as opt
import itertools
import random
import re
import os
import pickle
from library2_utils.color_scheme import cell_line_colors, cell_line_symbols
from library2_utils.design_utilities import tsi

# Global typography for every figure of this notebook: 7 pt Helvetica
plt.rcParams.update({'font.size': 7, 'font.family': 'Helvetica'})

# Restrict the analysis to cell lines for which a substantial amount of data
# from different sources is available
cell_lines_measured = ["HEK293T", "HeLa", "SKNSH", "MCF7", "A549", "PC3"]

# Root directory for all figures written by this notebook (created on first run)
plot_folder = "../plots/11_data_quality"
if not os.path.exists(plot_folder):
    os.makedirs(plot_folder)

### This notebook compares different datasets. Links to the individual datasets are given in the Supplementary Tables. The data need to be downloaded and processed before this notebook can be run. If that proves challenging, feel free to contact me.

In [2]:
def normalize_expr_df_to_rpm(df):
    """Rescale every column of an expression table to a total of one million.

    Steps, all applied column-wise:
      1. scale the column so that it sums to 1,000,000,
      2. subtract the column minimum and add 1, so the smallest entry becomes 1
         (keeps later divisions / logarithms away from 0),
      3. scale the shifted column to a sum of 1,000,000 again.

    Parameters
    ----------
    df : pandas.DataFrame
        Numeric table; each column is normalized independently.

    Returns
    -------
    pandas.DataFrame
        New table with the same index and columns; the input is not modified.
    """
    per_million = 1000000

    # first pass: column sums -> one million
    scaled = df.div(df.sum(axis=0), axis=1) * per_million

    # shift so that the per-column minimum is exactly 1
    shifted = scaled - scaled.min() + 1

    # second pass: the shift changed the totals, so normalize once more
    return shifted.div(shifted.sum(axis=0), axis=1) * per_million

## Load stability data

In [3]:
data_dir_input = "../measured_data/2_normalized_log10"

# every entry of the input directory; anything that is not a .csv file is skipped below
reference_files = os.listdir(data_dir_input)

# one DataFrame per CSV file, keyed by the file name up to its first dot;
# the first CSV column becomes the index
reference_dict = {
    file_name.split('.')[0]: pd.read_csv(os.path.join(data_dir_input, file_name), index_col=0)
    for file_name in reference_files
    if file_name.endswith(".csv")
}

In [4]:
# keep only the tables whose key contains one of these tags
wanted_tags = ("full_single_high_conf", "full_repeat")

single_dfs = {}
for table_name, table in reference_dict.items():
    if not any(tag in table_name for tag in wanted_tags):
        continue

    # index by the "miRNA1" column and keep only columns whose name contains "_3UTR"
    # (set_index returns a new frame, so the tables in reference_dict stay untouched)
    utr_table = table.set_index("miRNA1").filter(regex='_3UTR')

    # strip the "_3UTR" marker from the column names
    utr_table.columns = utr_table.columns.str.replace('_3UTR', '')

    single_dfs[table_name] = utr_table

# Load the miRNA data

In [6]:
# miRBase annotation table (first CSV column is the index);
# only rows whose "confidence" column equals "high" are kept
mirbase = pd.read_csv("../microrna_data/mirbase_extended.csv", index_col=0)
is_high_confidence = mirbase["confidence"].eq("high")
mirbase_high_confidence = mirbase.loc[is_high_confidence]

In [7]:
mirna_input_folder_microarray = "../microrna_data/11_input/microarray"
mirna_input_folder_sequencing = "../microrna_data/11_input/sequencing"

# directory listings of the two input folders
microarray_files = os.listdir(mirna_input_folder_microarray)
sequencing_files = os.listdir(mirna_input_folder_sequencing)

# one DataFrame per CSV file, keyed by the file name up to its first dot
microarray_dict = {
    file_name.split('.')[0]: pd.read_csv(os.path.join(mirna_input_folder_microarray, file_name), index_col=0)
    for file_name in microarray_files
    if file_name.endswith(".csv")
}
sequencing_dict = {
    file_name.split('.')[0]: pd.read_csv(os.path.join(mirna_input_folder_sequencing, file_name), index_col=0)
    for file_name in sequencing_files
    if file_name.endswith(".csv")
}

# merge both folders into one lookup; if a name occurs in both,
# the table from the sequencing folder replaces the microarray one
mirna_dict = dict(microarray_dict)
mirna_dict.update(sequencing_dict)

In [9]:
dataset_metadata = pd.read_excel("../microrna_data/11_input/dataset_metadata.xlsx")

# Every dataset named in the metadata must have been loaded. Report all missing
# ones at once instead of failing on the first with a bare KeyError.
missing_datasets = [name for name in dataset_metadata["Name"] if name not in mirna_dict]
if missing_datasets:
    raise KeyError(
        f"Datasets listed in dataset_metadata.xlsx but not found among the loaded CSV files: {missing_datasets}"
    )

# reorganize the order in mirna dict to match the order in dataset_metadata["Name"]
# (datasets that are loaded but not listed in the metadata are dropped here)
mirna_dict = {key: mirna_dict[key] for key in dataset_metadata["Name"]}

In [None]:
# report the number of rows of every dataset before any filtering is applied
for dataset_name, dataset in mirna_dict.items():
    print(f"Length of {dataset_name}:", len(dataset))

In [None]:
# filter them to common mirnas
# pandas Index.intersection is used instead of Python sets: iterating a set of
# strings yields a different order in every kernel session, which made the row
# order of all downstream tables vary from run to run.
initial_key = next(iter(mirna_dict))  # Get the first key from the dictionary
common_indices = mirna_dict[initial_key].index.unique()

# iterate over the rest of the DataFrames and keep only the indices present in all of them
for key in mirna_dict:
    if key != initial_key:
        common_indices = common_indices.intersection(mirna_dict[key].index, sort=False)

# common_indices contains the indices present in all datasets
common_mirnas = list(common_indices)

# filter them to high confidence data in mirbase, then normalize and make them log10
# NOTE: mirna_dict is overwritten in place with the log10 tables, so this cell is
# not idempotent - re-run the loading cells before running it a second time.
for key in mirna_dict.keys():
    df = mirna_dict[key].loc[common_mirnas, :]

    # filter to high confidence
    df = df[df.index.isin(mirbase_high_confidence.index)]

    # normalize every column to a sum of 1e6
    df = normalize_expr_df_to_rpm(df)

    # floor all values at 100 (i.e. 2 on the log10 scale)
    df = df.clip(lower=100)

    mirna_dict[key] = np.log10(df)
    print(key, len(df))

In [12]:
%%capture output
current_plot_folder = f"{plot_folder}/mirna_dataset_correlations"
if not os.path.exists(current_plot_folder):
    os.makedirs(current_plot_folder)

cell_lines = cell_lines_measured
keys = list(mirna_dict.keys())
no_datasets = len(keys)

# One (dataset x dataset) table of squared Pearson correlations per cell line.
# Entries stay NaN for datasets that have no column for the cell line.
dataset_correlations = {cell_line: pd.DataFrame(columns=keys, index=keys) for cell_line in cell_lines}
for cell_line in cell_lines:
    # datasets that have a column for this cell line
    valid_keys = [key for key in keys if cell_line in mirna_dict[key]]
    no_datasets = len(valid_keys)

    print(cell_line, no_datasets)

    # A pairwise comparison needs at least two datasets. Checking this BEFORE the
    # figure is created avoids leaving an empty figure open and avoids the
    # ValueError that plt.subplots raises for a 0 x 0 grid.
    if no_datasets < 2:
        continue

    fig, axes = plt.subplots(no_datasets, no_datasets, figsize=(no_datasets*0.6, no_datasets*0.6))
    plt.subplots_adjust(hspace=0.1, wspace=0.1)

    # abbreviated dataset names for the axis labels (first matching metadata row),
    # looked up once per dataset instead of once per panel
    abbrev_names = {
        key: dataset_metadata[dataset_metadata["Name"] == key]["Abbrev_name"].values[0]
        for key in valid_keys
    }

    for i in range(no_datasets):
        for j in range(no_datasets):
            # dataset i is on the x axis (subplot column i), dataset j on the y axis (subplot row j)
            ax = axes[j, i]
            key_x = valid_keys[i]
            key_y = valid_keys[j]

            # Get the data for each key
            mirna_x = mirna_dict[key_x][cell_line]
            mirna_y = mirna_dict[key_y][cell_line]

            # the correlation is stored for every pair, including the pairs that are not drawn
            r2 = stats.pearsonr(mirna_x, mirna_y)[0] ** 2
            dataset_correlations[cell_line].loc[key_x, key_y] = r2

            # only the lower-left triangle is shown; the diagonal and the
            # upper-right triangle are hidden, so nothing needs to be drawn there
            if j <= i:
                ax.set_visible(False)
                continue

            # Plot data
            ax.scatter(mirna_x, mirna_y, color="tab:blue", s=1.5, edgecolor="none", rasterized=True)

            # dataset names only on the outer panels: bottom row (x) and left column (y)
            if j == no_datasets - 1:
                ax.set_xlabel(abbrev_names[key_x], fontsize=6.5)

            if i == 0:
                ax.set_ylabel(abbrev_names[key_y], fontsize=6.5)

            ax.set_xlim(1.5, 6)
            ax.set_ylim(1.5, 6)

            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_xticklabels([])
            ax.set_yticklabels([])

            # r^2 annotation without the leading zero (e.g. ".87")
            ax.text(1.7, 5, f"{r2:.2f}".lstrip('0'), fontsize=6, bbox=dict(facecolor='white', alpha=1, edgecolor='none', boxstyle='round,pad=0.1'))

            # legend for the platform abbreviations, attached to the panel in row 2 / column 0
            if i==0 and j==2:
                ax.text(0.5, 1.1, "Aff.: Affymetrix microarray\nAgi.: Agilent microarray\nSt. Seq: Standard Sequencing\nIm. Seq: Improved Sequencing", fontsize=7)

    # `file_format` instead of `format`: the loop variable is a notebook global
    # and would otherwise shadow the builtin format() for all later cells
    for file_format in ["png", "svg"]:
        fig.savefig(f"{current_plot_folder}/{cell_line}_mirna_correlations.{file_format}", dpi=600)

    # the figure is on disk; close it so that open figures do not accumulate
    plt.close(fig)

In [13]:
# Running sum of r^2 values and number of contributing cell lines per dataset pair.
# Both accumulators are created with an explicit numeric dtype; an empty
# object-dtype frame followed by fillna(0) relies on implicit downcasting.
dataset_correlations_average = pd.DataFrame(0.0, index=keys, columns=keys)
dataset_correlations_count = pd.DataFrame(0, index=keys, columns=keys)

for cell_line in cell_lines:
    # the per-cell-line tables are object dtype with NaN for unavailable pairs
    r2_table = dataset_correlations[cell_line].astype(float)

    # add the count to non-NaN values
    dataset_correlations_count += r2_table.notna().astype(int)

    # unavailable pairs contribute nothing to the sum
    dataset_correlations_average += r2_table.fillna(0.0)

In [14]:
# Pairs whose r^2 sum is exactly 0 are marked as missing; the remaining sums are
# divided by the number of cell lines that contributed to them, giving the mean.
has_signal = dataset_correlations_average != 0
dataset_correlations_average = dataset_correlations_average.where(has_signal) / dataset_correlations_count

In [15]:
# bring both rows and columns into the dataset order of dataset_metadata["Name"]
metadata_order = dataset_metadata["Name"]
dataset_correlations_average = dataset_correlations_average.reindex(index=metadata_order, columns=metadata_order)

In [None]:
# plot a heatmap of the average r^2 between datasets
current_plot_folder = f"{plot_folder}/mirna_dataset_correlations"
if not os.path.exists(current_plot_folder):
    os.makedirs(current_plot_folder)

fig, ax = plt.subplots(figsize=(3, 2.86))

# the matrix is symmetric: hide the diagonal and the upper triangle
mask = np.triu(np.ones_like(dataset_correlations_average, dtype=bool))

sns.heatmap(dataset_correlations_average, mask=mask, cmap="viridis", ax=ax, vmin=0.2, vmax=0.8, cbar_kws={'label': r'$r^2$'}, square=True)

# label both axes with dataset_metadata["Abbrev_name"], one tick at the centre of every heatmap cell
tick_positions = np.arange(len(dataset_metadata)) + 0.5
ax.set_xticks(tick_positions)
ax.set_xticklabels(dataset_metadata["Abbrev_name"], rotation=90)
ax.set_yticks(tick_positions)
ax.set_yticklabels(dataset_metadata["Abbrev_name"], rotation=0)
ax.set_xlabel("")
ax.set_ylabel("")

# `file_format` instead of `format`: the loop variable is a notebook global and
# would otherwise shadow the builtin format() for all later cells.
# The figure is intentionally left open so that it is displayed below the cell.
for file_format in ["png", "svg"]:
    fig.savefig(f"{current_plot_folder}/mirna_correlations_heatmap.{file_format}", dpi=600)

# Determine cross-dataset correlation

# Fit the Hill function

In [17]:
def hill_func_log_scales(x_data, dataset_indices, c1=3, c2=10, *scales):
    """Hill function for pooled microRNA expression values with one additive
    log10 offset ("scale") per dataset.

    Each value ``x_data[k]`` is shifted by ``scales[dataset_indices[k]]`` and
    then passed through :func:`hill_func_log_regular`.

    The expression is assumed to be normalized to one.
    The microRNA data is assumed to be log10.
    The return value is also log10.

    Parameters
    ----------
    x_data : array-like, 1-D
        log10 microRNA expression values of all datasets, pooled.
    dataset_indices : array-like of int, same length as ``x_data``
        For every entry of ``x_data`` the position of its dataset in ``scales``.
    c1, c2 : float
        log10 of the two constants of the Hill function.
    *scales : float
        One log10 offset per dataset.

    Returns
    -------
    numpy.ndarray
        log10 response for every entry of ``x_data``, in the order of
        ``x_data``. For input that is grouped by ascending dataset index (as
        produced by concatenating the datasets one after another) this equals
        the concatenation of the per-dataset results.

    Raises
    ------
    ValueError
        If ``x_data`` and ``dataset_indices`` differ in shape, or if an index
        does not refer to one of the given ``scales``.
    """
    x_data = np.asarray(x_data, dtype=float)
    dataset_indices = np.asarray(dataset_indices).astype(np.intp)
    scale_per_dataset = np.asarray(scales, dtype=float)

    if dataset_indices.shape != x_data.shape:
        raise ValueError("x_data and dataset_indices must have the same shape")
    if dataset_indices.size and (
        dataset_indices.min() < 0 or dataset_indices.max() >= scale_per_dataset.size
    ):
        raise ValueError("dataset_indices must lie in range(len(scales))")

    # Look up the offset of every single data point. This keeps the output
    # aligned with the input even when the datasets are interleaved, and avoids
    # a Python loop on every call made by the optimizer.
    shifted = x_data + scale_per_dataset[dataset_indices]
    return hill_func_log_regular(shifted, c1, c2)

def hill_func_log_regular(x, c1=3, c2=10):
    """Hill function on log10 scales.

    Computes ``log10((1 + X / C2) / (1 + X / C1))`` with ``X = 10**x``,
    ``C1 = 10**c1`` and ``C2 = 10**c2``.

    The expression is assumed to be normalized to one.
    The microRNA data is assumed to be log10.
    The return value is also log10.

    Parameters
    ----------
    x : float, numpy.ndarray or pandas.Series
        log10 microRNA expression.
    c1, c2 : float
        log10 of the two constants of the Hill function.

    Returns
    -------
    Same kind of object as ``x`` (scalar in, scalar out; array in, array out).
    """
    # leave the log domain for the actual Hill expression
    x = 10**x
    c1 = 10**c1
    c2 = 10**c2

    result = (1 / (1 + x / c1)) * (1 + x / c2)
    return np.log10(result)

In [18]:
%%capture output
# r^2 and RMSD of the Hill fit, one table (cell line x miRNA dataset) per stability dataset
r2_vals = {}
rmsd_vals = {}
x_range_log = np.arange(0, 5.5, 0.01)

# The scale of this cell line is pinned to ~0; the scales of all other cell lines
# are fitted relative to it.
reference_cell_line = "HEK293T"

for knockdown_key in single_dfs.keys():
    r2_vals[knockdown_key] = pd.DataFrame(index=cell_lines_measured, columns=mirna_dict.keys())
    rmsd_vals[knockdown_key] = pd.DataFrame(index=cell_lines_measured, columns=mirna_dict.keys())
    # create a plot folder for this knockdown_key
    current_plot_folder = os.path.join(plot_folder, knockdown_key)
    # create it if it doesn't exist
    if not os.path.exists(current_plot_folder):
        os.makedirs(current_plot_folder)

    for mirna_key in mirna_dict.keys():
        df_expression = mirna_dict[mirna_key].copy()
        df_knockdown = single_dfs[knockdown_key].copy()

        # Cell lines present in both tables. The order must be deterministic with
        # the reference cell line first, because position 0 is the one whose scale
        # is pinned below; a list built from a set has an arbitrary order.
        shared_cell_lines = df_knockdown.columns.intersection(df_expression.columns, sort=False)
        # stable sort: reference first (key False), everything else keeps its order
        current_cell_lines = sorted(shared_cell_lines, key=lambda name: name != reference_cell_line)

        # miRNAs present in both tables, in a reproducible order
        common_mirnas = list(df_knockdown.index.intersection(df_expression.index, sort=False))

        # nothing to fit without shared cell lines or shared miRNAs
        if len(current_cell_lines) == 0 or len(common_mirnas) == 0:
            print(f"Skipping {knockdown_key} / {mirna_key}: no shared cell lines or miRNAs")
            continue

        # constrain the dataframes
        df_knockdown = df_knockdown.loc[common_mirnas, current_cell_lines]
        df_expression = df_expression.loc[common_mirnas, current_cell_lines]

        # -----------------------------------------------------------------
        # PREPARE DATA
        # all cell lines are pooled into flat arrays; dataset_indices records for
        # every point the position of its cell line in current_cell_lines
        x_data = np.concatenate([df_expression[cell_line].values for cell_line in current_cell_lines])
        y_data = np.concatenate([df_knockdown[cell_line].values for cell_line in current_cell_lines])
        dataset_indices = np.repeat(np.arange(len(current_cell_lines)), len(df_expression))

        # -----------------------------------------------------------------
        # EXECUTE FIT
        # set bounds and initial guesses for non-scale fitting parameters
        # (c1 free within [1, 10]; c2 confined to 9.99-10.01, i.e. effectively fixed at 10)
        p0 = [3, 10]
        num_params = len(p0)
        bounds = ([1, 9.99], [10, 10.01])

        # Guess initial scale values for all datasets
        num_cell_lines = len(current_cell_lines)
        scale_guesses = [0] * num_cell_lines
        scale_bounds_min = [-2] * num_cell_lines
        scale_bounds_max = [2] * num_cell_lines

        # pin the scale of the first cell line to ~0: this is HEK293T whenever it
        # is available, otherwise the first shared cell line
        scale_bounds_min[0] = -0.001
        scale_bounds_max[0] = 0.001

        # set up parameters
        p0_scale = p0 + scale_guesses
        bounds_scale = (bounds[0]+scale_bounds_min, bounds[1]+scale_bounds_max)

        popt_scales, pcov = opt.curve_fit(
            lambda x, *params: hill_func_log_scales(x, dataset_indices, *params),
            x_data,
            y_data,
            p0=p0_scale,
            bounds=bounds_scale,
            maxfev=5000
        )

        scales = list(popt_scales[num_params:])
        hill_params = popt_scales[:num_params]

        # -----------------------------------------------------------------
        # PLOT AND SAVE CORRELATION

        for i, cell_line in enumerate(current_cell_lines):
            current_scale = scales[i]

            # expression shifted by the fitted scale and the corresponding model prediction
            # (computed once and shared by the plot, r2 and rmsd)
            scaled_expression = df_expression[cell_line] + current_scale
            measured = df_knockdown[cell_line]
            predicted = hill_func_log_regular(scaled_expression, *hill_params)

            fig, ax = plt.subplots(figsize=(3, 2))

            ax.scatter(scaled_expression, measured, s=5, color="black")
            ax.plot(x_range_log, hill_func_log_regular(x_range_log, *hill_params),
                    color="forestgreen", linewidth=2, label="fit")

            # calculate the R2 value
            r2 = stats.pearsonr(measured, predicted)[0]**2
            r2_vals[knockdown_key].loc[cell_line, mirna_key] = r2

            # calculate the RMSD value
            rmsd = np.sqrt(np.mean((measured - predicted)**2))
            rmsd_vals[knockdown_key].loc[cell_line, mirna_key] = rmsd

            ax.set_xlabel("miRNA expression")
            ax.set_ylabel(r"log$_{10}$(RNA/DNA)")
            ax.set_title(f"{mirna_key},{knockdown_key},{cell_line}\nr2 = {round(r2, 2)}, rmsd = {round(rmsd, 2)}", fontsize=7)

            ax.set_xlim(0, 5.5)
            ax.set_ylim(-1.7, 0.25)

            fig.tight_layout()
            # `file_format` instead of `format`: the loop variable is a notebook
            # global and would otherwise shadow the builtin format()
            for file_format in ["png", "svg"]:
                fig.savefig(f"{current_plot_folder}/{knockdown_key}_{mirna_key}_{cell_line}.{file_format}", dpi=300)

            # one figure per (stability dataset, miRNA dataset, cell line):
            # close each one after saving so that they do not pile up in memory
            plt.close(fig)

In [19]:
%%capture output
# shallow copy: the entries are replaced below by (dataset x cell line) tables
# that additionally carry a "Type" column, which the following cell reads
r2_vals_copy = r2_vals.copy()
for knockdown_key in r2_vals_copy.keys():
    curr_r2_vals = r2_vals_copy[knockdown_key].copy().astype("float")

    # select the measured cell lines (rows) and the datasets in metadata order (columns);
    # this raises a KeyError if one of them is absent from the table
    curr_r2_vals = curr_r2_vals.loc[cell_lines_measured, dataset_metadata["Name"]]

    # add the Type of the dataset (datasets become rows, cell lines columns)
    curr_r2_vals = curr_r2_vals.T
    curr_r2_vals = curr_r2_vals.join(dataset_metadata.set_index("Name")["Type"])
    r2_vals_copy[knockdown_key] = curr_r2_vals

    # create a heatmap
    fig = plt.figure(figsize=(3,2.7))
    sns.heatmap(curr_r2_vals[cell_lines_measured], vmin=0.2, vmax=0.8, cmap="viridis", annot=True, fmt=".2f", cbar_kws={'label': r'r$^2$'})
    plt.xticks(rotation=45)
    plt.yticks(ticks=np.arange(0.5, len(dataset_metadata)+0.5, 1), labels=(dataset_metadata["Abbrev_name"]))

    plt.title(f"{knockdown_key}")
    plt.tight_layout()
    # platform name of every dataset, written next to its row
    # NOTE(review): x=20 is in data coordinates, far to the right of the heatmap
    # columns - confirm that this placement is intended
    for i, entry in enumerate(dataset_metadata["Platform"].values):
        plt.text(20, i, entry, ha="center", va="center", fontsize=7)
    # `file_format` instead of `format`: the loop variable is a notebook global
    # and would otherwise shadow the builtin format()
    for file_format in ["png", "svg"]:
        fig.savefig(os.path.join(plot_folder, f"{knockdown_key}_R2.{file_format}"), dpi=300)

    # the figure is on disk; close it so that open figures do not accumulate
    plt.close(fig)

In [None]:
%%capture output
for knockdown_key in r2_vals_copy.keys():
    curr_r2_vals = r2_vals_copy[knockdown_key].copy()

    # shorter / two-line category names so that the tick labels fit the small figure
    curr_r2_vals["Type"] = curr_r2_vals["Type"].replace({"Affymetrix microarray": "Affymetrix\nmicroarray",
                                                         "Agilent microarray": "Agilent",
                                                         "standard NGS": "basic\nNGS",
                                                         "improved NGS": "improved"})

    # Average over cell lines so that every dataset contributes exactly one value.
    # (Unrolling along the cell lines instead would let a single dataset dominate
    # the boxplot.)
    curr_r2_vals["mean"] = curr_r2_vals[cell_lines_measured].mean(axis=1)
    print(knockdown_key)
    print(curr_r2_vals["mean"].groupby(curr_r2_vals["Type"]).mean())

    # create a boxplot of the per-dataset mean r^2, grouped by dataset type
    # NOTE(review): recent seaborn versions may warn about `palette` being passed
    # without `hue` - confirm against the installed version
    fig = plt.figure(figsize=(2,1.3))
    sns.boxplot(data=curr_r2_vals, x="Type", y="mean", width=0.6, palette="viridis", showfliers=True, flierprops=dict(marker='o', markersize=2))

    plt.ylabel(r"r$^2$")
    plt.xticks(fontsize=6.5)
    plt.xlabel("")
    plt.ylim(0, 1)
    plt.tight_layout()
    # `file_format` instead of `format`: the loop variable is a notebook global
    # and would otherwise shadow the builtin format()
    for file_format in ["png", "svg"]:
        fig.savefig(os.path.join(plot_folder, f"{knockdown_key}_R2_boxplot.{file_format}"), dpi=300)

    # the figure is on disk; close it so that open figures do not accumulate
    plt.close(fig)