# Deconvolution
Evaluate the performance of deconvolution for age prediction.

In [None]:
from utils.variation import meanexp, group_counts, sum_squares, var_comp
from utils.misc import extract_number, mae
import pandas as pd
import numpy as np
from scipy.stats import rankdata
import os
import matplotlib.pyplot as plt
from utils.viz import single_line_plot, single_scatter_plot, multi_scatter_plot

## Load Data

In [None]:
# Read the training table: the index holds gene names, the columns hold samples.
gene_expressions = pd.read_csv("data/train_data.csv", index_col=0)
genenames = np.array(gene_expressions.index.tolist())
samples = gene_expressions.columns.tolist()
gene_expressions_mat = gene_expressions.to_numpy()

# Ages are parsed from the sample labels via extract_number.
ages = np.array(list(map(extract_number, samples)))
unique_ages = np.unique(ages)

# Keep only genes whose value is strictly positive in every sample
# (prevalence is the fraction of samples with a positive value).
prevalence = (gene_expressions_mat > 0).mean(axis=1)
subset_gene_id = np.flatnonzero(prevalence == 1)
subset_genenames = genenames[subset_gene_id]
gene_expressions_mat = gene_expressions_mat[subset_gene_id, :]
gene_expressions = gene_expressions.loc[subset_genenames, :]

# Natural-log expressions; safe because every retained value is > 0.
log_gene_expressions = np.log(gene_expressions)
log_gene_expressions_mat = np.log(gene_expressions_mat)

# Flip every table from genes-by-samples to samples-by-genes.
gene_expressions, log_gene_expressions = gene_expressions.T, log_gene_expressions.T
gene_expressions_mat, log_gene_expressions_mat = gene_expressions_mat.T, log_gene_expressions_mat.T

# Rank the samples within each gene (column-wise ranks of the log table).
gene_expressions_rank = log_gene_expressions.rank()


## Visualizing Genes
Plot the genes that have the most distinct expressions between samples of different ages.

In [None]:
def plot_top_genes(expressions_df, ages, age_min, age_max, type="raw", num_genes=50):
    """Save one scatter plot per top-ranked gene for samples in an age window.

    Genes are ordered by the 'R2' column returned by ``var_comp`` (largest
    first) and the first ``num_genes`` of them are plotted against age.
    Each figure is written to
    ``gene_plots/<type>/age_<age_min>_<age_max>/`` and then closed.

    Parameters
    ----------
    expressions_df : pandas.DataFrame
        Expression table with samples as rows and genes as columns.
    ages : numpy.ndarray
        One age per row of ``expressions_df``.
    age_min, age_max : number
        Inclusive bounds of the age window to keep.
    type : {"raw", "log", "rank"}
        Transformation applied to ``expressions_df`` before plotting.
    num_genes : int
        Maximum number of genes to plot; capped at the number of genes
        available so a short table no longer raises ``IndexError``.

    Raises
    ------
    ValueError
        If ``type`` is not one of the supported transformations.
    """
    y_axis_labels = {"raw": "Raw Expression",
                     "log": "Log Expression",
                     "rank": "Rank of Expression"}

    if type not in y_axis_labels:
        raise ValueError("Argument must be one of 'raw', 'log', 'rank'")

    if type == "raw":
        input_df = expressions_df
    elif type == "log":
        input_df = np.log(expressions_df)
    else:
        # NOTE(review): ranks are taken over ALL samples, before the age
        # filter below is applied - confirm this is intended.
        input_df = expressions_df.rank()

    mask = np.logical_and(ages >= age_min, ages <= age_max)
    input_df = input_df.loc[mask, :]
    ages = ages[mask]
    unique_ages = np.unique(ages)
    # Ages are placed on the x axis by rank, and labelled with the real age.
    age_rank = rankdata(ages, method='min').astype(int)
    unique_ranks = np.unique(age_rank)
    x_positions = age_rank.reshape(1, -1)

    varcomp_sorted = var_comp(geneexp_df=input_df, groups=ages)
    varcomp_sorted = varcomp_sorted.sort_values(by='R2', ascending=False)
    sorted_genes = varcomp_sorted.index.tolist()

    # Create the destination folder up front; savefig does not create it.
    output_dir = os.path.join("gene_plots", type, f"age_{age_min}_{age_max}")
    os.makedirs(output_dir, exist_ok=True)

    n_plots = max(0, min(num_genes, len(sorted_genes)))
    for gene_index, genename in enumerate(sorted_genes[:n_plots]):
        values = input_df.loc[:, genename].to_numpy().reshape(1, -1)
        fig = single_scatter_plot(ymat=values, xmat=x_positions,
                                  xticks=unique_ranks, xticknames=unique_ages.astype(str),
                                  xname="Age (Month)", yname=y_axis_labels[type], title=genename)
        filename = f"{gene_index}_{genename}_{age_min}_{age_max}_{type}.png"
        fig.savefig(os.path.join(output_dir, filename), bbox_inches="tight")
        plt.close(fig)


In [None]:
# Top-gene plots for the 2-42 age window, once per expression type.
for expression_type in ("raw", "log", "rank"):
    plot_top_genes(expressions_df=gene_expressions, ages=ages,
                   age_min=2, age_max=42, type=expression_type)

In [None]:
# Top-gene plots for the 4-23 age window, once per expression type.
for expression_type in ("raw", "log", "rank"):
    plot_top_genes(expressions_df=gene_expressions, ages=ages,
                   age_min=4, age_max=23, type=expression_type)

In [None]:
# Top-gene plots for the 4-42 age window, once per expression type.
for expression_type in ("raw", "log", "rank"):
    plot_top_genes(expressions_df=gene_expressions, ages=ages,
                   age_min=4, age_max=42, type=expression_type)

In [None]:
# Top-gene plots for the 6-18 age window, once per expression type.
for expression_type in ("raw", "log", "rank"):
    plot_top_genes(expressions_df=gene_expressions, ages=ages,
                   age_min=6, age_max=18, type=expression_type)

## Deconvolution

I ran deconvolution while varying the following:
* Input: raw gene expression, log gene expression or ranks of gene expression
* Number of genes included
* Weight of the genes

In [None]:
from utils.deconvolution import loo_predict

In [None]:
def mae_tune(input_expression, labels, label_min=2, label_max=50, filter="raw", deconv="raw",
             numgenes=np.array([50, 100, 200, 400, 600, 800, 1000, 2000, 5000, 10000, 16000])):
    """Score leave-one-out predictions for several gene-set sizes.

    Samples whose label lies in ``[label_min, label_max]`` are kept. Genes are
    ordered by the 'R2' column of ``var_comp`` (largest first) computed on the
    ``filter`` representation; for every size in ``numgenes`` the leading
    genes of the ``deconv`` representation are passed to ``loo_predict`` and
    the result is scored with ``mae(..., type="mean")``.

    Parameters
    ----------
    input_expression : pandas.DataFrame
        Expression table with samples as rows and genes as columns.
    labels : numpy.ndarray
        One label (age) per row of ``input_expression``.
    label_min, label_max : number
        Inclusive label window to keep.
    filter : {"raw", "log", "rank"}
        Representation used to order the genes.
    deconv : {"raw", "log", "rank"}
        Representation handed to ``loo_predict``.
    numgenes : sequence of int
        Gene-set sizes to evaluate; a size larger than the number of genes
        simply selects every gene.

    Returns
    -------
    results_df : pandas.DataFrame
        Columns "NumGene" and "MAE" (float), one row per entry of ``numgenes``.
    prediction_df : pandas.DataFrame
        Columns "Truth" and "Predicted" for the size with the lowest MAE.
        "Predicted" stays all zeros if no run produced a comparable MAE
        (empty ``numgenes`` or every MAE is NaN).

    Raises
    ------
    ValueError
        If ``filter`` or ``deconv`` is not a supported representation.
    """
    valid_types = ("raw", "log", "rank")

    if filter not in valid_types:
        raise ValueError("'filter' Argument must be one of 'raw', 'log', 'rank'")

    if deconv not in valid_types:
        raise ValueError("'deconv' Argument must be one of 'raw', 'log', 'rank'")

    targets = labels.copy()
    mask = np.logical_and(targets >= label_min, targets <= label_max)
    targets = targets[mask]

    # Filter first, then copy: avoids duplicating rows that are discarded.
    expressions = input_expression.loc[mask, :].copy()
    representations = {
        "raw": expressions,
        "log": np.log(expressions),
        # Ranks are taken over the retained samples only.
        "rank": expressions.rank(),
    }

    varcomp_sorted = var_comp(geneexp_df=representations[filter], groups=targets)
    varcomp_sorted = varcomp_sorted.sort_values(by='R2', ascending=False)
    sorted_genes = varcomp_sorted.index.tolist()

    deconv_input = representations[deconv]

    # Start from +inf so the first finite MAE always becomes the best one,
    # whatever the scale of the labels.
    best_mae = np.inf
    best_predictions = np.zeros_like(targets)
    maes = []

    for num_genes in numgenes:
        selected_genes = sorted_genes[0:num_genes]
        # Genes are passed unweighted (weighted="None").
        output = loo_predict(expression_df=deconv_input.loc[:, selected_genes],
                             labels=targets, weighted="None", normalize=True)
        latest_mae = mae(output[0]["Truth"], output[0]["Predicted"], type="mean")
        maes.append(latest_mae)
        if latest_mae < best_mae:
            best_mae = latest_mae
            best_predictions = output[0]["Predicted"]

    # Build the table in one go: chained assignment such as
    # results_df["MAE"][j] = value is not guaranteed to write through.
    results_df = pd.DataFrame({"NumGene": numgenes,
                               "MAE": np.asarray(maes, dtype=float)})

    prediction_df = pd.DataFrame({"Truth": targets,
                                  "Predicted": best_predictions})

    return results_df, prediction_df


In [None]:
def aggregate_performance_report(expressions, ages, min_age=2, max_age=42, ngenes_tune = None):
    """Tune deconvolution over every filter/deconv pair and plot the results.

    ``mae_tune`` is run for each of the nine (filter, deconv) combinations of
    "raw", "log" and "rank". Two figures are saved: the MAE curves of all
    combinations, and predicted-vs-true ages for the best combination.

    Parameters
    ----------
    expressions : pandas.DataFrame
        Expression table with samples as rows and genes as columns.
    ages : numpy.ndarray
        One age per row of ``expressions``.
    min_age, max_age : number
        Inclusive age window forwarded to ``mae_tune``.
    ngenes_tune : list of int, optional
        Gene-set sizes to evaluate; a default grid is used when omitted.

    Returns
    -------
    mae_summary : dict
        ``"<filter>_<deconv>"`` -> MAE table returned by ``mae_tune``.
    prediction_summary : dict
        ``"<filter>_<deconv>"`` -> prediction table returned by ``mae_tune``.
    best : tuple
        ``(optimal_filter_base, optimal_deconv_base, optimal_ngene)``.

    Raises
    ------
    ValueError
        If no combination produced a usable MAE (for example an empty
        ``ngenes_tune``), instead of failing later with a ``KeyError``.
    """
    if ngenes_tune is None:
        ngenes_tune = [50, 100, 200, 400, 600, 800, 1000, 2000, 5000, 10000, 16000]

    bases = ("raw", "log", "rank")
    # Filter base varies slowest, matching the colour/linetype layout below.
    combinations = [(filter_base, deconv_base)
                    for filter_base in bases for deconv_base in bases]

    mae_summary = {}
    prediction_summary = {}
    optimal_filter_base = None
    optimal_deconv_base = None
    optimal_ngene = 0

    # +inf rather than a fixed cut-off, so the best run is always selected.
    best_mae = np.inf
    maes_mat = np.zeros((len(combinations), len(ngenes_tune)))

    for row, (filter_base, deconv_base) in enumerate(combinations):
        mae_df, prediction = mae_tune(input_expression=expressions, labels=ages,
                                      label_min=min_age, label_max=max_age,
                                      filter=filter_base, deconv=deconv_base,
                                      numgenes=ngenes_tune)
        key = f"{filter_base}_{deconv_base}"
        mae_summary[key] = mae_df
        prediction_summary[key] = prediction
        maes_mat[row, :] = mae_df["MAE"].to_numpy()

        combination_best = np.min(mae_df["MAE"])
        if combination_best < best_mae:
            best_mae = combination_best
            optimal_filter_base = filter_base
            optimal_deconv_base = deconv_base
            optimal_ngene = ngenes_tune[np.argmin(mae_df["MAE"])]

    if optimal_filter_base is None:
        raise ValueError("No filter/deconv combination produced a valid MAE")

    best_prediction = prediction_summary[f"{optimal_filter_base}_{optimal_deconv_base}"]
    truth = best_prediction["Truth"].to_numpy()
    predictions = best_prediction["Predicted"].to_numpy()

    performance_title = f"{min_age}-{max_age} Months old"

    colors_map = {"Raw-GeneSelect": "#344885", "Log-GeneSelect": "#db382c",
                  "Rank-GeneSelect": "#2f7028"}

    linetypes_map = {"Raw-Deconv": '-', "Log-Deconv": '-.', "Rank-Deconv": ':'}

    # One colour per filter base (repeated), one line type per deconv base (tiled).
    performance_plot = single_line_plot(ymat=maes_mat, xticks=np.arange(len(ngenes_tune)),
                                 xticknames=ngenes_tune,
                                 colors=np.repeat(["#344885", "#db382c", "#2f7028"], 3),
                                 linetypes=np.tile(['-o', '-.o', ':o'], 3),
                                 xname="Number of Genes", yname="Mean Absolute Error",
                                 colors_map=colors_map, linetypes_map=linetypes_map,
                                 title=performance_title, size=(6.5, 4))
    performance_path = f"deconvolution/plots/MAEs/MAE_{min_age}_{max_age}_perm.png"
    # savefig does not create missing folders.
    os.makedirs(os.path.dirname(performance_path), exist_ok=True)
    performance_plot.savefig(performance_path, bbox_inches="tight")

    # NOTE(review): no xmat is passed here, unlike the other
    # single_scatter_plot call in this file - confirm the helper's default.
    prediction_plot = single_scatter_plot(ymat=predictions.reshape(1, -1), xticks=truth, xticknames=truth.astype(str),
                                         xname="True Age (Month)", yname="Predicted Age (Month)",
                                         size=(6, 4), diag_line=True)
    prediction_path = f"deconvolution/plots/scatterplots/prediction_{min_age}_{max_age}_optimal.png"
    os.makedirs(os.path.dirname(prediction_path), exist_ok=True)
    prediction_plot.savefig(prediction_path, bbox_inches="tight")

    return mae_summary, prediction_summary, (optimal_filter_base, optimal_deconv_base, optimal_ngene)


In [None]:
# Tune and report on the 2-42 age window.
report = aggregate_performance_report(
    expressions=gene_expressions,
    ages=ages,
    min_age=2,
    max_age=42,
    ngenes_tune=[50, 100, 200, 400, 600, 800, 1000, 2000],
)
mae_summary, prediction_summary, best_param_choice = report

In [None]:
# Tune and report on the 4-42 age window.
report = aggregate_performance_report(
    expressions=gene_expressions,
    ages=ages,
    min_age=4,
    max_age=42,
    ngenes_tune=[50, 100, 200, 400, 600, 800, 1000, 2000],
)
mae_summary, prediction_summary, best_param_choice = report

In [None]:
# Tune and report on the 4-23 age window.
report = aggregate_performance_report(
    expressions=gene_expressions,
    ages=ages,
    min_age=4,
    max_age=23,
    ngenes_tune=[50, 100, 200, 400, 600, 800, 1000, 2000],
)
mae_summary, prediction_summary, best_param_choice = report

In [None]:
# Tune and report on the 6-18 age window.
report = aggregate_performance_report(
    expressions=gene_expressions,
    ages=ages,
    min_age=6,
    max_age=18,
    ngenes_tune=[50, 100, 200, 400, 600, 800, 1000, 2000],
)
mae_summary, prediction_summary, best_param_choice = report