This notebook loads and processes data associated with Hagai *et al.* (2018).

In [None]:
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from IPython.display import display
from scipy import linalg, spatial, stats
from sklearn import linear_model

import rp2
from rp2 import create_gene_symbol_map, hagai_2018
from rp2.paths import get_output_path

# Project-provided environment check from rp2; presumably verifies that the
# packages/data the notebook relies on are available — TODO confirm.
rp2.check_environment()

Analysis will be performed on an scRNA-seq dataset generated as part of the study by Hagai *et al.* (2018). UMI counts are available for phagocytes stimulated with lipopolysaccharide (LPS) and with dsRNA (polyinosinic:polycytidylic acid, poly(I:C) or PIC):

> Primary bone marrow-derived mononuclear phagocytes originating from females of four different species (black 6 mouse, brown Norway rat, rabbit and pig)... Cells were stimulated with: (1) ...LPS...or with (2) ...poly(I:C)... LPS stimulation time courses of 0, 2, 4, 6 h were performed for all species. Poly(I:C) stimulations were performed for rodents for 0, 2, 4, 6 h.

There are three biological replicates (i.e. three individuals) for each species (as alluded to in the paper and confirmed more explicitly in Supplementary Table 2). In the case of poly(I:C), mouse replicate 1 appears to be missing UMI counts for time 6 h whereas replicate 2 has two readings for this time point (6 and 6A). This does not appear to be discussed in the paper but may suggest that replicate 2 was inadvertently sampled a second time in place of replicate 1. The arrangement of samples is further illustrated in the [table included with the ArrayExpress dataset](https://www.ebi.ac.uk/arrayexpress/experiments/E-MTAB-6754/samples/). The UMI matrices have been QCed and clustered:

> Since bone marrow-derived phagocytes may include secondary cell populations, we focused our analysis on the major cell population. We identified clusters within each data set...and have taken the cells belonging to the largest cluster for further analysis, resulting in a less heterogeneous population of cells.

Subsequent analysis will focus on a subset of the genes:

> To quantify transcriptional divergence in immune responses between species, we focused on genes that were differentially expressed during the stimulation (see Methods). For simplicity, we refer to these genes as ‘responsive genes’ (Fig. 1c). In this analysis, we study the subset of these genes with one-to-one orthologues across the studied species. There are 955 such responsive genes in dsRNA-stimulated human fibroblasts and 2,336 in LPS-stimulated mouse phagocytes. 

# 1. Loading and preparing data<a id="1" />

## 1.1. Settings<a id="1_1" />

In [None]:
# Which part of the Hagai et al. (2018) dataset to analyse.
study_species = "mouse"
study_control = "unst"  # label of the unstimulated control condition
study_treatments = ["lps", "pic"]  # stimulations compared against the control
study_replicates = ["1", "2", "3"]  # biological replicates (individuals)
study_time_points = ["0", "2", "4", "6"]  # hours after stimulation

## 1.2. List of "responsive" genes<a id="1_2" />

In [None]:
# Gene IDs the study reports as responsive to LPS stimulation; used below to
# restrict the UMI matrix.
responsive_phagocyte_genes = hagai_2018.load_lps_responsive_genes()
print("{:,} responsive phagocyte genes".format(len(responsive_phagocyte_genes)))

Create a map between gene IDs and symbols

In [None]:
# Gene ID <-> symbol lookup for the study species; used later via
# symbol_map.added_to(...) to annotate tables and symbol_map.lookup(...) to title plots.
symbol_map = create_gene_symbol_map(study_species)

## 1.3. UMI counts<a id="1_3" />

Load the UMI counts and reduce the dataset to values of interest (to optimise memory usage and access performance)

In [None]:
umi_ad = hagai_2018.load_umi_count(study_species)

# Keep only the responsive genes and the cells from the studied replicates,
# treatments (plus control) and time points. A single combined selection
# replaces a chain of views; .copy() materialises the result.
keep_cells = (
    umi_ad.obs.replicate.isin(study_replicates)
    & umi_ad.obs.treatment.isin(study_treatments + [study_control])
    & umi_ad.obs.time_point.isin(study_time_points)
)
umi_ad = umi_ad[keep_cells.to_numpy(), responsive_phagocyte_genes].copy()

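Create a `StudyTreatmentSet` for each treatment of interest (each paired with the control), plus an `"all"` set combining every treatment with the control, so that each may be analysed separately.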

In [None]:
class StudyTreatmentSet:
    """UMI counts for a group of treatments (plus control) and derived per-gene tables.

    Cells are kept when their ``obs.treatment`` is one of the chosen treatments
    or the control. Analysis steps (condition statistics, outliers, regression
    fits, saving) are attached to the class later via :meth:`add_method`.
    """

    @staticmethod
    def add_method(method):
        """Attach ``method`` to ``StudyTreatmentSet`` under ``method.__name__``."""
        setattr(StudyTreatmentSet, method.__name__, method)

    def __init__(self, umi_counts_ad, treatments, control_name):
        """Select the cells of ``umi_counts_ad`` that belong to this treatment set.

        Args:
            umi_counts_ad: AnnData-like UMI counts whose ``obs`` has a ``treatment``
                column (``replicate`` and ``time_point`` are used by
                ``number_of_conditions``).
            treatments: iterable of treatment labels to keep (list, tuple, Index, ...).
            control_name: label of the control condition, always kept.
        """
        # list() so any iterable of labels can be combined with the control label.
        all_treatments = list(treatments) + [control_name]
        self.umi_counts_ad = umi_counts_ad[umi_counts_ad.obs.treatment.isin(all_treatments), :]
        # Populated later by the attached analysis methods.
        self.condition_stats_df = None
        self.lr_fit_df = None

    @property
    def number_of_genes(self):
        """Number of genes (variables) in the UMI counts."""
        return self.umi_counts_ad.n_vars

    @property
    def number_of_conditions(self):
        """Number of distinct (replicate, treatment, time_point) combinations among the cells."""
        return len(self.umi_counts_ad.obs.loc[:, ["replicate", "treatment", "time_point"]].drop_duplicates())

    def check_integrity(self):
        """Assert that the condition statistics cover exactly the genes in the UMI counts.

        Raises:
            AssertionError: if ``condition_stats_df`` has been computed and its
                ``gene`` values differ from the UMI counts' ``var_names``.
        """
        gene_set = set(self.umi_counts_ad.var_names)
        if self.condition_stats_df is not None:
            assert gene_set == set(self.condition_stats_df.gene), \
                "condition statistics and UMI counts disagree on the set of genes"

    def drop_genes(self, genes):
        """Remove ``genes`` from the UMI counts and from any derived tables.

        Works whether or not ``condition_stats_df`` / ``lr_fit_df`` have been
        computed yet; gene IDs that are not present are ignored.

        Args:
            genes: iterable of gene IDs to remove (list, Index, generator, ...).
        """
        # Materialise once: the IDs are tested against up to three tables.
        genes = list(genes)

        self.umi_counts_ad = self.umi_counts_ad[:, ~self.umi_counts_ad.var_names.isin(genes)]
        if self.condition_stats_df is not None:
            self.condition_stats_df = self.condition_stats_df.loc[~self.condition_stats_df.gene.isin(genes)]
        # Keep any existing regression fits consistent with the remaining genes.
        if self.lr_fit_df is not None:
            self.lr_fit_df = self.lr_fit_df.loc[~self.lr_fit_df.gene.isin(genes)]

        self.check_integrity()


# One set per individual treatment (each paired with the control), plus an
# "all" set that pools every treatment with the control.
treatment_sets = {
    treatment: StudyTreatmentSet(umi_ad, [treatment], study_control)
    for treatment in study_treatments
}
treatment_sets["all"] = StudyTreatmentSet(umi_ad, study_treatments, study_control)

for set_name, treatment_set in treatment_sets.items():
    print("\n".join([
        f'"{set_name}" treatment set:',
        f"  {treatment_set.number_of_genes} genes",
        f"  {treatment_set.number_of_conditions} conditions",
    ]))

# 2. Quality control<a id="2" />

## 2.1. Settings<a id="2_1" />

In [None]:
# Minimum number of conditions per gene: genes with fewer conditions having a
# non-zero mean UMI count are dropped (section 2.2), and the same threshold is
# used to count genes with too few non-outlier conditions (section 2.4).
minimum_samples_per_gene = 6

## 2.2 Remove undersampled genes<a id="2_2" />

Calculate per condition statistics (for each treatment set)

In [None]:
def calculate_condition_stats(self: StudyTreatmentSet):
    """Compute per-condition UMI statistics for every gene and store them.

    Delegates to ``hagai_2018.calculate_umi_condition_stats`` on this set's UMI
    counts; the returned DataFrame replaces ``self.condition_stats_df``.
    """
    per_condition_df = hagai_2018.calculate_umi_condition_stats(self.umi_counts_ad)
    self.condition_stats_df = per_condition_df


StudyTreatmentSet.add_method(calculate_condition_stats)

# Compute the statistics for every set and preview them with gene symbols attached.
for treatment_set in treatment_sets.values():
    treatment_set.calculate_condition_stats()
    preview_df = symbol_map.added_to(treatment_set.condition_stats_df)
    display(preview_df.head())

Drop undersampled genes (those that have too few conditions with a non-zero mean UMI count)

In [None]:
for set_name, treatment_set in treatment_sets.items():
    set_stats_df = treatment_set.condition_stats_df
    # Per gene: how many conditions have a non-zero mean UMI count.
    n_non_zero_means_series = set_stats_df.groupby("gene")["mean"].agg(np.count_nonzero)
    is_undersampled = n_non_zero_means_series < minimum_samples_per_gene
    undersampled_genes = n_non_zero_means_series[is_undersampled].index

    treatment_set.drop_genes(undersampled_genes)
    print(f'"{set_name}" dropped {len(undersampled_genes)} genes ({treatment_set.number_of_genes} remaining)')

Enumerate, by time point, the conditions that still have a zero mean UMI count (for each treatment set) and, when `remove_all_zero_mean_samples` is set, drop them from the statistics.

In [None]:
def drop_sample_points(self: StudyTreatmentSet, condition_ids):
    """Discard rows of ``self.condition_stats_df`` by index label, then re-check integrity.

    Args:
        condition_ids: index labels of the ``condition_stats_df`` rows to remove.
    """
    remaining_df = self.condition_stats_df.drop(condition_ids)
    self.condition_stats_df = remaining_df
    self.check_integrity()


StudyTreatmentSet.add_method(drop_sample_points)

# When True, conditions whose mean UMI count is zero are removed from the statistics.
remove_all_zero_mean_samples = True

for set_name, treatment_set in treatment_sets.items():
    print(f'Zero mean conditions remaining for "{set_name}" treatment set:')
    set_stats_df = treatment_set.condition_stats_df
    zero_mean_series = set_stats_df.loc[set_stats_df["mean"] == 0, "time_point"]
    # Series.iteritems() was removed in pandas 2.0; Series.items() works on all versions.
    for time_point, n in zero_mean_series.value_counts().sort_index().items():
        print(f"  {n} at time point {time_point}")

    if remove_all_zero_mean_samples:
        treatment_set.drop_sample_points(zero_mean_series.index)

## 2.3 Remove genes with low abundance<a id="2_3" />

In [None]:
# Genes whose largest per-condition mean UMI count is below this are rejected.
minimum_max_mean = 1
# How many of the highest-scoring rejected genes to inspect visually.
n_top_rejections = 4

for set_name, treatment_set in treatment_sets.items():
    print(f'"{set_name}" treatment set:')
    max_mean_series = treatment_set.condition_stats_df.groupby("gene")["mean"].max().sort_values()

    rejected_genes = max_mean_series[max_mean_series < minimum_max_mean]
    print(f"  {len(rejected_genes):,} genes rejected ({len(max_mean_series) - len(rejected_genes):,} remaining)")

    max_mean_series.plot.hist(figsize=(16, 4), bins=40, log=True).set(xlabel="Max mean")
    plt.axvline(x=minimum_max_mean)
    plt.show()

    if len(rejected_genes) == 0:
        continue

    # max_mean_series is sorted ascending, so the last rows are the rejections
    # closest to the threshold. .iloc makes the positional slice explicit.
    top_rejected_genes = rejected_genes.iloc[-n_top_rejections:]
    print(f"  Top {len(top_rejected_genes)} rejections:")
    # squeeze=False always returns a 2-D array of Axes; otherwise a single
    # rejected gene yields a bare Axes and .flatten() fails.
    _, axes = plt.subplots(
        1, len(top_rejected_genes), figsize=(4 * len(top_rejected_genes), 4), squeeze=False
    )
    for gene_id, ax in zip(top_rejected_genes.index, axes.flatten()):
        sns.scatterplot(
            x="mean",
            y="variance",
            hue="time_point",
            style="replicate",
            data=treatment_set.condition_stats_df.loc[treatment_set.condition_stats_df.gene == gene_id],
            ax=ax,
        )
        ax.set_title(symbol_map.lookup(gene_id))
        ax.set_xlim(0, minimum_max_mean)
        ax.set_ylim(0)
    plt.legend(loc="upper left")
    plt.tight_layout()
    plt.show()

    treatment_set.drop_genes(rejected_genes.index)

## 2.4 Detect outliers<a id="2_4" />

Determine outliers based on Mahalanobis distance

In [None]:
def calculate_mahalanobis_distance(df, column_name="distance"):
    """Mahalanobis distance of every row of ``df`` from the column means.

    The covariance is estimated from ``df`` itself (``np.cov`` with one variable
    per column). A Moore-Penrose pseudo-inverse is used instead of a plain
    inverse. A singular covariance (e.g. a constant column, or exactly collinear
    columns) therefore yields finite distances instead of raising ``LinAlgError``
    and aborting the caller's per-gene ``groupby``. For an invertible covariance
    the pseudo-inverse equals the inverse.

    Args:
        df: numeric DataFrame, one observation per row.
        column_name: name of the output column.

    Returns:
        DataFrame indexed like ``df`` with a single ``column_name`` column.
    """
    values = df.to_numpy(dtype=float)
    deviations = values - values.mean(axis=0)
    # atleast_2d: np.cov returns a 0-d array when df has a single column.
    inv_cov_mtx = linalg.pinv(np.atleast_2d(np.cov(values, rowvar=False)))
    # Row-wise quadratic form d_i^T * S^+ * d_i, vectorised over all rows.
    squared = np.einsum("ij,jk,ik->i", deviations, inv_cov_mtx, deviations)
    # Floating-point rounding can push squared distances marginally below zero.
    distances = np.sqrt(np.clip(squared, 0, None))

    return pd.DataFrame(
        index=df.index,
        data={column_name: distances},
    )

In [None]:
def calculate_outliers(self: StudyTreatmentSet, distance_threshold):
    """Flag conditions whose (mean, variance) point is an outlier for its gene.

    For each gene, the Mahalanobis distance of every condition's (mean, variance)
    pair from that gene's centroid is stored in ``condition_stats_df.m_distance``.
    ``condition_stats_df.outlier`` is True where that distance exceeds
    ``distance_threshold``.

    Args:
        distance_threshold: distance above which a condition is an outlier.
    """
    # group_keys=False keeps each row's original index label on the per-gene
    # results. From pandas 2.0 the default (True) also prepends the gene for
    # transform-like functions, which would break the .loc lookup below.
    mahalanobis_distances = (
        self.condition_stats_df
        .groupby("gene", group_keys=False)[["mean", "variance"]]
        .apply(calculate_mahalanobis_distance)
    )
    self.condition_stats_df["m_distance"] = mahalanobis_distances.loc[self.condition_stats_df.index, "distance"]
    self.condition_stats_df["outlier"] = self.condition_stats_df.m_distance > distance_threshold


StudyTreatmentSet.add_method(calculate_outliers)

# Distance threshold = square root of the 95th percentile of a chi-squared
# distribution with 2 degrees of freedom (one per variable: mean and variance),
# i.e. the usual cut-off for squared Mahalanobis distances under a bivariate-normal assumption.
outlier_distance_threshold = np.sqrt(stats.chi2.ppf(0.95, 2))

for set_name, treatment_set in treatment_sets.items():
    treatment_set.calculate_outliers(outlier_distance_threshold)
    n_outliers = int(treatment_set.condition_stats_df.outlier.sum())
    print(f'{n_outliers:,} outliers in "{set_name}" treatment set')

In [None]:
def count_zero(a):
    """Return how many elements of ``a`` compare equal to zero (``False`` counts as zero)."""
    is_zero = a == 0
    return np.count_nonzero(is_zero)


for set_name, treatment_set in treatment_sets.items():
    set_stats_df = treatment_set.condition_stats_df
    # Per gene: number of conditions NOT flagged as outliers (False counts as zero).
    n_non_outlier_series = set_stats_df.groupby("gene")["outlier"].apply(count_zero).sort_values()
    too_few = n_non_outlier_series < minimum_samples_per_gene
    undersampled_genes = n_non_outlier_series[too_few].index
    print(f'"{set_name}" has {len(undersampled_genes)} undersampled genes if outliers are discounted')

# 3. Mean-variance plots<a id="3" />

## 3.1 Fit regression models<a id="3_1" />

Fit an ordinary least-squares (OLS) linear regression to the mean-variance relationship of each gene in every treatment set, excluding outlier conditions.

In [None]:
def fit_linear_regression(df, lr):
    """Fit ``lr`` to a two-column frame (predictor first, response second).

    Args:
        df: DataFrame with exactly two columns, the predictor followed by the
            response (here ``mean`` then ``variance``).
        lr: scikit-learn style regressor exposing ``fit``, ``score``, ``coef_``
            and ``intercept_``; it is refitted in place.

    Returns:
        Series with ``slope``, ``intercept`` (floats) and ``r2`` (the value of
        ``lr.score`` on the fitted data).

    Raises:
        ValueError: if ``df`` does not have exactly two columns.
    """
    values = df.to_numpy(dtype=float)
    if values.ndim != 2 or values.shape[1] != 2:
        raise ValueError(f"expected two columns (x, y), got shape {values.shape}")
    # (n, 1) design matrix and a 1-D target: scikit-learn regressors expect a
    # 1-D y, and some (e.g. HuberRegressor) warn about a column-vector target.
    lr_x = values[:, :1]
    lr_y = values[:, 1]

    with warnings.catch_warnings():
        # NOTE(review): this hides every warning raised while fitting (e.g.
        # convergence warnings); narrow the filter if those need to be seen.
        warnings.simplefilter("ignore")
        lr.fit(lr_x, lr_y)

    return pd.Series(data={
        "slope": float(np.squeeze(lr.coef_)),
        "intercept": float(np.squeeze(lr.intercept_)),
        "r2": lr.score(lr_x, lr_y),
    })

In [None]:
def concat_dataframes(df_list, df_ids, id_column_name):
    """Stack DataFrames vertically, recording which input each row came from.

    Args:
        df_list: DataFrames to concatenate.
        df_ids: one identifier per DataFrame, in the same order.
        id_column_name: name of the leading column holding the identifier.

    Returns:
        The concatenated DataFrame with ``id_column_name`` as its first column
        and a fresh 0..n-1 index.
    """
    stacked = pd.concat(df_list, keys=df_ids, names=[id_column_name])
    # The outer index level holds the identifiers; turn it into a column and
    # discard the original row labels.
    with_ids = stacked.reset_index(level=0)
    return with_ids.reset_index(drop=True)


def calculate_linear_regression_fit(self: StudyTreatmentSet, include_robust_methods=False):
    """Fit per-gene linear models of variance against mean and store them.

    Ordinary least squares (method ``"ols"``) is fitted to the non-outlier
    conditions of each gene. With ``include_robust_methods=True``, Huber, RANSAC
    and Theil-Sen regressors are additionally fitted to all conditions (outliers
    included), because those estimators are intended to tolerate outliers.

    The result replaces ``self.lr_fit_df``. Its columns are ``gene``, ``method``,
    ``slope``, ``intercept``, ``r2``, ``min_mean`` and ``max_mean``, and it is
    sorted by gene and method.

    Args:
        include_robust_methods: also fit the robust regressors. Defaults to False,
            which matches the previous behaviour in which they were commented out.
    """
    stats_df = self.condition_stats_df
    non_outlier_df = stats_df.loc[~stats_df.outlier]

    methods_and_views = [
        ({"ols": linear_model.LinearRegression(fit_intercept=True)}, non_outlier_df),
    ]

    if include_robust_methods:
        # RANSACRegressor stores its coefficients on the wrapped estimator_;
        # expose them under the attribute names fit_linear_regression reads.
        if not hasattr(linear_model.RANSACRegressor, "coef_"):
            setattr(linear_model.RANSACRegressor, "coef_", property(lambda est: est.estimator_.coef_))
        if not hasattr(linear_model.RANSACRegressor, "intercept_"):
            setattr(linear_model.RANSACRegressor, "intercept_", property(lambda est: est.estimator_.intercept_))

        robust_methods = {
            "huber": linear_model.HuberRegressor(),
            # Fixed seeds so the randomised estimators give reproducible fits.
            "ransac": linear_model.RANSACRegressor(random_state=0),
            "theil_sen": linear_model.TheilSenRegressor(random_state=0),
        }
        methods_and_views.append((robust_methods, stats_df))

    concat_df_list = []
    for methods, method_view in methods_and_views:
        group = method_view.groupby("gene")[["mean", "variance"]]
        concat_df_list.append(concat_dataframes(
            [group.apply(fit_linear_regression, method).reset_index()
             for method in methods.values()],
            list(methods.keys()),
            "method",
        ))

    lr_fit_df = pd.concat(concat_df_list, ignore_index=True)

    # NOTE(review): this aggregates the per-condition "min"/"max" columns
    # (produced by calculate_umi_condition_stats), not the "mean" column, despite
    # the output names min_mean/max_mean — confirm which is intended.
    # String aggregator names are used; passing builtin max is deprecated in pandas.
    range_df = (
        non_outlier_df.groupby("gene")
        .agg({"min": "min", "max": "max"})
        .rename(columns={"min": "min_mean", "max": "max_mean"})
    )
    lr_fit_df = lr_fit_df.merge(range_df, left_on="gene", right_index=True, how="left")

    self.lr_fit_df = lr_fit_df.loc[
        :, ["gene", "method", "slope", "intercept", "r2", "min_mean", "max_mean"]
    ].sort_values(by=["gene", "method"])


StudyTreatmentSet.add_method(calculate_linear_regression_fit)

for treatment_set in treatment_sets.values():
    treatment_set.calculate_linear_regression_fit()

# Distribution of per-gene R² for every treatment set and fitting method.
r2_plot_df = concat_dataframes(
    [ts.lr_fit_df for ts in treatment_sets.values()],
    treatment_sets.keys(),
    "set",
)

r2_ax = sns.boxplot(data=r2_plot_df, x="set", y="r2", hue="method")
# Zoom to R² >= 0.6; note that any boxes/points below this are clipped from view.
r2_ax.set_ylim(bottom=0.6)
r2_ax.legend(loc="upper left", bbox_to_anchor=(1, 1))
plt.show()

## 3.2 Save results<a id="3_2" />

Save the descriptive statistics and results of fitting the regression models

In [None]:
def save_data(self: StudyTreatmentSet, output_path, prefix, symbol_map):
    """Write the per-condition statistics and the regression fits as CSV files.

    Both tables are annotated with gene symbols via ``symbol_map.added_to`` and
    written without the index to ``output_path/<prefix>_stats_per_condition_per_gene.csv``
    and ``output_path/<prefix>_lr_fit_per_gene.csv``.
    """
    outputs = (
        (self.condition_stats_df, "_stats_per_condition_per_gene.csv"),
        (self.lr_fit_df, "_lr_fit_per_gene.csv"),
    )
    for frame, suffix in outputs:
        annotated_df = symbol_map.added_to(frame)
        annotated_df.to_csv(output_path.joinpath(prefix + suffix), index=False)


StudyTreatmentSet.add_method(save_data)

output_path = get_output_path()
output_path.mkdir(parents=True, exist_ok=True)

# Two CSV files per treatment set, prefixed "<species>_<set name>".
for treatment, treatment_set in treatment_sets.items():
    treatment_set.save_data(output_path, prefix=f"{study_species}_{treatment}", symbol_map=symbol_map)

## 3.3 Display plots<a id="3_3" />

Show mean-variance plots for the 25 genes (among those common to all treatment sets) whose OLS fits have the highest R² in the LPS treatment set.

In [None]:
# Genes retained in every treatment set. Sorted: a list (not a set) is required
# by DataFrame.loc in pandas >= 2.0, and sorting makes the saved file deterministic.
common_genes = sorted(set.intersection(*[set(treatment_set.umi_counts_ad.var_names)
                                         for treatment_set in treatment_sets.values()]))
print(f"{len(common_genes):,} genes common to all treatment sets")

get_output_path(f"{study_species}_common_treatment_set_genes.txt").write_text("\n".join(common_genes))

# Number of genes to plot, ranked by R² of the OLS fit in the "lps" set.
n_plot_genes = 25

plot_lr_df = treatment_sets["lps"].lr_fit_df
plot_lr_df = plot_lr_df.loc[plot_lr_df.method == "ols"].set_index("gene").loc[common_genes]
top_r2_genes = plot_lr_df.r2.sort_values(ascending=False)

# Condition statistics plotted as points come from the pooled "all" set.
points_treatment_set = treatment_sets["all"]

plot_gene_ids = top_r2_genes.index[:n_plot_genes]


def _fit_lines_for_gene(sets, gene_id, x_values):
    """Evaluate each treatment set's fitted line(s) for ``gene_id`` at ``x_values``.

    Returns one long-format DataFrame with columns set, method, mean, variance.
    """
    line_frames = []
    for set_name, treatment_set in sets.items():
        lr_fit_df = treatment_set.lr_fit_df.loc[treatment_set.lr_fit_df.gene == gene_id]

        for lr_method, lr_method_df in lr_fit_df.groupby("method"):
            slope, intercept = lr_method_df[["slope", "intercept"]].squeeze()

            line_frames.append(pd.DataFrame(data={
                "set": set_name,
                "method": lr_method,
                "mean": x_values,
                "variance": (x_values * slope) + intercept,
            }))

    return pd.concat(line_frames, ignore_index=True)


for gene_idx, gene_id in enumerate(plot_gene_ids):
    gene_df = points_treatment_set.condition_stats_df.loc[points_treatment_set.condition_stats_df.gene == gene_id]

    fig, ax = plt.subplots()

    sns.scatterplot(
        data=gene_df,
        x="mean",
        y="variance",
        hue="treatment",
        style="replicate",
        ax=ax,
    )

    line_df = _fit_lines_for_gene(treatment_sets, gene_id, np.asarray([0, gene_df["mean"].max()]))
    # Only distinguish line styles when more than one fitting method is present.
    line_style = "method" if line_df.method.nunique() > 1 else None

    sns.lineplot(
        x="mean",
        y="variance",
        data=line_df,
        hue="set",
        style=line_style,
        ax=ax,
    )

    ax.set_title(f"{gene_idx + 1}. {symbol_map.lookup(gene_id)} ({gene_id})")
    ax.set_xlabel("Mean")
    ax.set_ylabel("Variance")
    ax.set_xlim(left=0)
    ax.set_ylim(bottom=0)
    ax.legend(loc="upper left", bbox_to_anchor=(1, 1))
    plt.show()
    # Release the figure; many figures are created in this loop.
    plt.close(fig)