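This notebook loads and processes the single-cell UMI count data associated with Hagai *et al.* (2018): it restricts the data to responsive genes and the conditions of interest, applies quality control, and fits per-gene mean–variance regression models.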

In [None]:
import anndata
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from IPython.core.display import display
from scipy import linalg, spatial, stats
from sklearn.linear_model import LinearRegression

from rp2 import get_data_path, get_output_path, GeneSymbolMap

# 1. Loading and preparing data<a id="1" />

## 1.1. Settings<a id="1_1" />

In [None]:
# Subset of the dataset to analyse
study_species = "mouse"
study_replicates = [str(replicate) for replicate in (1, 2, 3)]
study_control = "unst"
study_treatments = ["lps", "pic"]
study_time_points = [str(tp) for tp in (0, 2, 4, 6)]

# A gene counts as "responsive" when its adjusted p-value is below this threshold
maximum_responsive_gene_padj = 1e-2

## 1.2. List of "responsive" genes<a id="1_2" />

In [None]:
# Supplementary workbook; the phagocyte sheet has a `gene` column and per-species adjusted p-values.
# NOTE(review): the sheet name looks misspelled but must match the workbook verbatim.
phagocyte_genes_df = pd.read_excel(
    sheet_name="phagocytes_FC_diveregnce",
    io=get_data_path("hagai_2018", "41586_2018_657_MOESM4_ESM.xlsx"),
)
display(phagocyte_genes_df.head())

In [None]:
# Keep the genes whose species-specific adjusted p-value is below the threshold
is_responsive = phagocyte_genes_df[f"{study_species}_padj"] < maximum_responsive_gene_padj
responsive_phagocyte_genes = phagocyte_genes_df.loc[is_responsive, "gene"]
print(f"{len(responsive_phagocyte_genes):,} responsive phagocyte genes")

Create a map between gene IDs and gene symbols (from a BioMart export).

In [None]:
# Headerless BioMart table: columns id / symbol / description, with the id column used as the index
gene_symbols_df = pd.read_table(
    get_data_path("BioMart", f"{study_species}_genes.tsv"),
    index_col=0,
    names=["id", "symbol", "description"],
)
symbol_map = GeneSymbolMap(gene_symbols_df)

## 1.3. UMI counts<a id="1_3" />

Load the UMI counts and reduce the dataset to the genes, replicates, treatments and time points of interest (to optimise memory usage and access performance).

In [None]:
array_express_path = get_data_path("ArrayExpress")
umi_ad = anndata.read_h5ad(array_express_path.joinpath(f"E-MTAB-6754.processed.2.{study_species}.h5ad"))

# Restrict to the responsive genes, then select the cells of interest with one combined mask
umi_ad = umi_ad[:, responsive_phagocyte_genes]
cell_obs = umi_ad.obs
keep_cells = (
    cell_obs.replicate.isin(study_replicates)
    & cell_obs.treatment.isin(study_treatments + [study_control])
    & cell_obs.time_point.isin(study_time_points)
)
# Materialise the view into an independent (smaller) object
umi_ad = umi_ad[keep_cells, :].copy()

Create a `StudyTreatmentSet` for each treatment of interest, plus one combining all treatments, so that each can be analysed separately.

In [None]:
class StudyTreatmentSet:
    """UMI counts and derived per-gene statistics for a group of treatments plus the control.

    Analysis steps defined in later cells are attached to the class with
    :meth:`add_method`.
    """

    @staticmethod
    def add_method(method):
        """Attach ``method`` to the class under its own ``__name__``."""
        setattr(StudyTreatmentSet, method.__name__, method)

    def __init__(self, umi_counts_ad, treatments, control_name):
        """Keep only the cells treated with one of ``treatments`` or with ``control_name``.

        Args:
            umi_counts_ad: AnnData object whose ``obs`` has a ``treatment`` column.
            treatments: Iterable of treatment names (a single string is treated as one name).
            control_name: Name of the control treatment; always included.
        """
        if isinstance(treatments, str):
            treatments = [treatments]
        all_treatments = [*treatments, control_name]
        self.umi_counts_ad = umi_counts_ad[umi_counts_ad.obs.treatment.isin(all_treatments), :]
        # Populated by the analysis methods attached in later cells
        self.condition_stats_df = None
        self.lr_fit_with_intercept_df = None
        self.lr_fit_without_intercept_df = None

    @property
    def number_of_genes(self):
        """Number of genes (variables) currently retained."""
        return self.umi_counts_ad.n_vars

    @property
    def number_of_conditions(self):
        """Number of distinct (replicate, treatment, time_point) combinations among the cells."""
        return len(self.umi_counts_ad.obs.loc[:, ["replicate", "treatment", "time_point"]].drop_duplicates())

    def check_integrity(self):
        """Assert that the per-condition statistics (if computed) cover exactly the retained genes."""
        if self.condition_stats_df is None:
            return
        gene_set = set(self.umi_counts_ad.var_names)
        assert gene_set == set(self.condition_stats_df.gene), \
            "condition statistics and UMI counts disagree on the set of genes"

    def drop_genes(self, genes):
        """Remove ``genes`` from the UMI counts and, when present, from the condition statistics."""
        self.umi_counts_ad = self.umi_counts_ad[:, ~self.umi_counts_ad.var_names.isin(genes)]
        # The statistics only exist once calculate_condition_stats has been run
        if self.condition_stats_df is not None:
            self.condition_stats_df = self.condition_stats_df.loc[~self.condition_stats_df.gene.isin(genes)]

        self.check_integrity()


# One set per individual treatment, plus a combined set containing every treatment
treatment_sets = {
    treatment: StudyTreatmentSet(umi_ad, [treatment], study_control)
    for treatment in study_treatments
}
treatment_sets["all"] = StudyTreatmentSet(umi_ad, study_treatments, study_control)

for set_name, treatment_set in treatment_sets.items():
    summary_lines = [
        f'"{set_name}" treatment set:',
        f"  {treatment_set.number_of_genes} genes",
        f"  {treatment_set.number_of_conditions} conditions",
    ]
    print("\n".join(summary_lines))

# 2. Quality control<a id="2" />

## 2.1. Settings<a id="2_1" />

In [None]:
minimum_samples_per_gene: int = 6  # genes with fewer qualifying conditions are treated as undersampled

## 2.2. Remove undersampled genes<a id="2_2" />

Calculate per-condition statistics for each gene (for each treatment set).

In [None]:
def calculate_condition_stats(self: "StudyTreatmentSet"):
    """Compute per-gene UMI statistics for every (replicate, treatment, time_point) condition.

    Populates ``self.condition_stats_df`` with one row per gene and condition, holding the
    number of barcodes and the min, max, mean, variance and standard deviation (``ddof=1``)
    of that gene's UMI counts across the condition's cells. Rows are sorted by gene,
    replicate, time point and treatment and carry a unique integer index, which later
    steps use to address individual rows.
    """
    counts_ad = self.umi_counts_ad
    condition_frames = []

    # observed=True: if the obs columns are categorical, pandas could otherwise also
    # visit category combinations that contain no cells at all
    for (replicate, treatment, time_point), group_df in counts_ad.obs.groupby(
        ["replicate", "treatment", "time_point"], observed=True
    ):
        cell_view = counts_ad[group_df.index, :]
        # Densify once per condition; `.A` is deprecated for SciPy sparse matrices and
        # does not exist on dense arrays
        counts = cell_view.X.toarray() if hasattr(cell_view.X, "toarray") else np.asarray(cell_view.X)

        condition_frames.append(
            pd.DataFrame(
                data={
                    "gene": cell_view.var_names,
                    "replicate": replicate,
                    "treatment": treatment,
                    "time_point": time_point,
                    "n_barcodes": cell_view.n_obs,
                    "min": counts.min(axis=0),
                    "max": counts.max(axis=0),
                    "mean": counts.mean(axis=0),
                    "variance": counts.var(axis=0, ddof=1),
                    "std_dev": counts.std(axis=0, ddof=1),
                },
            )
        )

    if condition_frames:
        # Concatenate once (DataFrame.append was removed in pandas 2.0 and is quadratic)
        stats_df = pd.concat(condition_frames, ignore_index=True)
    else:
        stats_df = pd.DataFrame(columns=["gene", "replicate", "treatment", "time_point", "n_barcodes",
                                         "min", "max", "mean", "variance", "std_dev"])

    self.condition_stats_df = stats_df.sort_values(["gene", "replicate", "time_point", "treatment"])


StudyTreatmentSet.add_method(calculate_condition_stats)

for treatment_set in treatment_sets.values():
    treatment_set.calculate_condition_stats()
    # Preview via symbol_map.added_to (presumably attaches gene symbols to the IDs)
    preview_df = symbol_map.added_to(treatment_set.condition_stats_df)
    display(preview_df.head())

Drop undersampled genes (those with too few conditions that have a non-zero mean UMI count).

In [None]:
for set_name, treatment_set in treatment_sets.items():
    stats_df = treatment_set.condition_stats_df
    # Per gene: how many conditions have a mean UMI count different from zero
    n_non_zero_means_series = stats_df["mean"].ne(0).groupby(stats_df["gene"]).sum()
    is_undersampled = n_non_zero_means_series < minimum_samples_per_gene
    undersampled_genes = n_non_zero_means_series.index[is_undersampled.to_numpy()]
    treatment_set.drop_genes(undersampled_genes)
    print(f'"{set_name}" dropped {len(undersampled_genes)} genes ({treatment_set.number_of_genes} remaining)')

Enumerate the remaining conditions with a zero mean UMI count by time point (for each treatment set), and drop them when `remove_all_zero_mean_samples` is set.

In [None]:
def drop_sample_points(self: StudyTreatmentSet, condition_ids):
    """Remove the rows labelled ``condition_ids`` from ``condition_stats_df`` and re-check integrity."""
    remaining_df = self.condition_stats_df.drop(index=condition_ids)
    self.condition_stats_df = remaining_df
    self.check_integrity()


StudyTreatmentSet.add_method(drop_sample_points)

# Set to False to only report zero-mean conditions without removing them
remove_all_zero_mean_samples = True

for set_name, treatment_set in treatment_sets.items():
    print(f'Zero mean conditions remaining for "{set_name}" treatment set:')
    stats_df = treatment_set.condition_stats_df
    zero_mean_series = stats_df.loc[stats_df["mean"] == 0, "time_point"]
    # Series.iteritems() was removed in pandas 2.0; items() works on every version
    for time_point, n in zero_mean_series.value_counts().sort_index().items():
        print(f"  {n} at time point {time_point}")

    if remove_all_zero_mean_samples:
        treatment_set.drop_sample_points(zero_mean_series.index)

## 2.3. Remove genes with low abundance<a id="2_3" />

In [None]:
# Genes whose largest per-condition mean UMI count is below this value are rejected
minimum_max_mean = 2
# How many of the rejected genes closest to the threshold to plot
n_rejections_to_plot = 4

for set_name, treatment_set in treatment_sets.items():
    print(f'"{set_name}" treatment set:')
    stats_df = treatment_set.condition_stats_df
    # Highest per-condition mean of each gene, sorted ascending
    max_mean_series = stats_df.groupby("gene")["mean"].max().sort_values()

    rejected_genes = max_mean_series[max_mean_series < minimum_max_mean]
    print(f"  {len(rejected_genes):,} genes rejected ({len(max_mean_series) - len(rejected_genes):,} remaining)")

    _, hist_ax = plt.subplots(figsize=(16, 4))
    max_mean_series.plot.hist(bins=40, log=True, ax=hist_ax).set(xlabel="Max mean")
    hist_ax.axvline(x=minimum_max_mean)
    plt.show()

    if rejected_genes.empty:
        continue

    # max_mean_series is sorted ascending, so the tail holds the rejections closest to the threshold
    top_rejected_genes = rejected_genes.iloc[-n_rejections_to_plot:]
    print(f"  Top {len(top_rejected_genes)} rejections:")
    # squeeze=False keeps `axes` a 2-D array even when a single gene is plotted
    _, axes = plt.subplots(
        1, len(top_rejected_genes), figsize=(4 * len(top_rejected_genes), 4), squeeze=False
    )
    for gene_id, ax in zip(top_rejected_genes.index, axes.flatten()):
        stats_df.loc[stats_df.gene == gene_id].plot.scatter("mean", "variance", ax=ax)
        ax.set_title(symbol_map.lookup(gene_id))
        ax.set_xlim(0, minimum_max_mean)
        ax.set_ylim(0)
    plt.show()

    treatment_set.drop_genes(rejected_genes.index)

## 2.4. Detect outliers<a id="2_4" />

Flag outlier conditions for each gene, based on the Mahalanobis distance of each condition's (mean, variance) pair from that gene's centroid.

In [None]:
def calculate_mahalanobis_distance(df, column_name="distance"):
    """Mahalanobis distance of each row of ``df`` from the column means.

    The covariance matrix is estimated from the rows themselves (``np.cov`` with
    ``rowvar=False``). If ``scipy.linalg.inv`` reports that matrix as singular
    (e.g. the rows are collinear), its Moore–Penrose pseudo-inverse is used
    instead of aborting the whole analysis with ``LinAlgError``.

    Args:
        df: DataFrame of numeric columns, one observation per row.
        column_name: Name of the single output column.

    Returns:
        DataFrame indexed like ``df`` with one column, ``column_name``.
    """
    values = df.to_numpy()
    centroid = df.mean(axis=0).to_numpy()
    cov_mtx = np.cov(values, rowvar=False)
    try:
        inv_cov_mtx = linalg.inv(cov_mtx)
    except linalg.LinAlgError:
        inv_cov_mtx = linalg.pinv(cov_mtx)
    distances = [spatial.distance.mahalanobis(row, centroid, inv_cov_mtx) for row in values]

    return pd.DataFrame(
        index=df.index,
        data={column_name: distances},
    )

In [None]:
def calculate_outliers(self: StudyTreatmentSet, distance_threshold):
    """Flag conditions whose (mean, variance) pair lies far from the rest of its gene's conditions.

    For every gene, the Mahalanobis distance of each condition's (mean, variance)
    pair from that gene's centroid is stored in ``condition_stats_df["m_distance"]``,
    and ``condition_stats_df["outlier"]`` is True where it exceeds ``distance_threshold``.
    Re-running overwrites both columns.
    """
    stats_df = self.condition_stats_df
    # group_keys=False keeps the original row index: pandas >= 2.0 would otherwise prepend
    # the gene as an extra index level and break the realignment below
    mahalanobis_distances = stats_df.groupby("gene", group_keys=False)[["mean", "variance"]].apply(
        calculate_mahalanobis_distance
    )
    m_distance = mahalanobis_distances["distance"].reindex(stats_df.index)
    # assign() returns a new frame, avoiding chained-assignment warnings on sliced frames
    self.condition_stats_df = stats_df.assign(
        m_distance=m_distance,
        outlier=m_distance > distance_threshold,
    )


StudyTreatmentSet.add_method(calculate_outliers)

# Cut-off: square root of the 95th percentile of a chi-squared distribution with
# 2 degrees of freedom (one per statistic used: mean and variance)
outlier_distance_threshold = np.sqrt(stats.chi2.ppf(0.95, 2))

for set_name, treatment_set in treatment_sets.items():
    treatment_set.calculate_outliers(outlier_distance_threshold)
    outlier_flags = treatment_set.condition_stats_df.outlier
    n_outliers = np.count_nonzero(outlier_flags)
    print(f'{n_outliers:,} outliers in "{set_name}" treatment set')

In [None]:
def count_zero(a):
    """Return how many elements of ``a`` compare equal to zero (``False`` counts as zero)."""
    is_zero = a == 0
    return np.count_nonzero(is_zero)


for set_name, treatment_set in treatment_sets.items():
    # Per gene: number of conditions NOT flagged as outliers, ascending
    non_outlier_counts = treatment_set.condition_stats_df.groupby("gene", as_index=True)["outlier"].apply(count_zero)
    n_non_outlier_series = non_outlier_counts.sort_values()
    is_undersampled = n_non_outlier_series < minimum_samples_per_gene
    undersampled_genes = n_non_outlier_series.index[is_undersampled.to_numpy()]
    print(f'"{set_name}" has {len(undersampled_genes)} undersampled genes if outliers are discounted')

# 3. Mean-variance plots<a id="3" />

## 3.1. Fit regression models<a id="3_1" />

Fit linear regression models to the mean–variance relationship of every gene in each treatment set (excluding outlier conditions), both with and without an intercept.

In [None]:
def fit_linear_regression(df, fit_intercept=True):
    """Fit ``y = slope * x (+ intercept)`` by ordinary least squares.

    Args:
        df: Two-column DataFrame; the first column is the predictor (x) and the
            second the response (y), e.g. ``df[["mean", "variance"]]``.
        fit_intercept: Whether to fit an intercept; when False the line is forced
            through the origin.

    Returns:
        Series with float entries ``slope``, ``intercept`` (only when
        ``fit_intercept``) and ``r2`` (coefficient of determination on ``df``).

    Raises:
        ValueError: If ``df`` does not have exactly two columns.
    """
    values = df.to_numpy()
    if values.ndim != 2 or values.shape[1] != 2:
        raise ValueError(f"expected a two-column frame, got shape {values.shape}")
    x = values[:, :1]  # (n_samples, 1): scikit-learn expects a 2-D feature matrix
    y = values[:, 1]

    lr = LinearRegression(fit_intercept=fit_intercept)
    lr.fit(x, y)

    results = {"slope": float(lr.coef_[0])}
    if fit_intercept:
        results["intercept"] = float(lr.intercept_)
    results["r2"] = float(lr.score(x, y))

    return pd.Series(results)

In [None]:
def calculate_linear_regression_fit(self: StudyTreatmentSet):
    """Fit per-gene mean→variance regressions, with and without intercept, ignoring outliers.

    Results go to ``lr_fit_with_intercept_df`` and ``lr_fit_without_intercept_df``,
    one row per gene.
    """
    stats_df = self.condition_stats_df
    inlier_df = stats_df[~stats_df["outlier"]]
    mean_variance_by_gene = inlier_df.groupby("gene")[["mean", "variance"]]
    self.lr_fit_with_intercept_df = mean_variance_by_gene.apply(fit_linear_regression, fit_intercept=True)
    self.lr_fit_without_intercept_df = mean_variance_by_gene.apply(fit_linear_regression, fit_intercept=False)


StudyTreatmentSet.add_method(calculate_linear_regression_fit)

r2_score_frames = []

for set_name, treatment_set in treatment_sets.items():
    treatment_set.calculate_linear_regression_fit()
    fit_by_intercept = {
        True: treatment_set.lr_fit_with_intercept_df,
        False: treatment_set.lr_fit_without_intercept_df,
    }
    r2_score_frames.extend(
        pd.DataFrame(data={"set": set_name, "intercept": has_intercept, "r2": fit_df.r2})
        for has_intercept, fit_df in fit_by_intercept.items()
    )
    print(f'"{set_name}" mean R2 values:')
    print(f"  With intercept: {fit_by_intercept[True].r2.mean():.2f}")
    print(f"  Without intercept: {fit_by_intercept[False].r2.mean():.2f}")

# One row per gene, treatment set and model variant
r2_scores_df = pd.concat(r2_score_frames)
ax = sns.boxplot(x="set", y="r2", hue="intercept", data=r2_scores_df)
# The y-axis starts at 0.6, so R2 values below it fall outside the view
ax.set_ylim(bottom=0.6)
ax.legend(loc="upper left", bbox_to_anchor=(1, 1))
plt.show()

## 3.2. Save results<a id="3_2" />

Save the per-condition descriptive statistics and the regression fits (the models fitted with an intercept).

In [None]:
def save_data(self: StudyTreatmentSet, output_path, prefix):
    """Write the per-condition statistics and the with-intercept regression fits to CSV.

    Creates ``<prefix>_stats_per_condition_per_gene.csv`` (without the row index) and
    ``<prefix>_lr_fit_per_gene.csv`` (with the gene index) inside ``output_path``.
    """
    stats_file = output_path.joinpath(prefix + "_stats_per_condition_per_gene.csv")
    fit_file = output_path.joinpath(prefix + "_lr_fit_per_gene.csv")
    self.condition_stats_df.to_csv(stats_file, index=False)
    self.lr_fit_with_intercept_df.to_csv(fit_file)


StudyTreatmentSet.add_method(save_data)

output_path = get_output_path()
output_path.mkdir(parents=True, exist_ok=True)

# File prefixes look like "<species>_<set name>", e.g. "mouse_lps" or "mouse_all"
for set_name, treatment_set in treatment_sets.items():
    treatment_set.save_data(output_path, f"{study_species}_{set_name}")

## 3.3. Display plots<a id="3_3" />

Show mean–variance plots for the 10 most variable genes common to all treatment sets.

In [None]:
common_genes = set.intersection(*[set(treatment_set.umi_counts_ad.var_names)
                                  for treatment_set in treatment_sets.values()])
print(f"{len(common_genes):,} genes common to all treatment sets")

# Rough estimate of gene variability: variance of the UMI counts over all retained cells.
# `.A` is deprecated for SciPy sparse matrices, so densify via toarray() when available.
umi_counts = umi_ad.X.toarray() if hasattr(umi_ad.X, "toarray") else np.asarray(umi_ad.X)
gene_variability = pd.Series(index=umi_ad.var.index, data=umi_counts.var(axis=0))
# pandas >= 2.0 rejects a set as an indexer; a sorted list also makes tie order reproducible
gene_variability = gene_variability.loc[sorted(common_genes)].sort_values(ascending=False)

n_genes_to_plot = 10
points_treatment_set = treatment_sets["all"]
points_stats_df = points_treatment_set.condition_stats_df

for gene_id in gene_variability.index[:n_genes_to_plot]:
    gene_df = points_stats_df.loc[points_stats_df.gene == gene_id]

    fig, ax = plt.subplots()

    sns.scatterplot(
        data=gene_df,
        x="mean",
        y="variance",
        hue="treatment",
        style="replicate",
        ax=ax,
    )

    # Each fitted line is drawn from x=0 to the largest observed mean of this gene
    lr_plot_x = np.asarray([0, gene_df["mean"].max()])
    line_frames = []

    for set_name, treatment_set in treatment_sets.items():
        slope, intercept = treatment_set.lr_fit_with_intercept_df.loc[gene_id, ["slope", "intercept"]]
        slope_through_origin = treatment_set.lr_fit_without_intercept_df.loc[gene_id, "slope"]

        line_frames.append(pd.DataFrame(data={
            "set": set_name,
            "intercept": True,
            "mean": lr_plot_x,
            "variance": (lr_plot_x * slope) + intercept,
        }))
        line_frames.append(pd.DataFrame(data={
            "set": set_name,
            "intercept": False,
            "mean": lr_plot_x,
            "variance": lr_plot_x * slope_through_origin,
        }))

    line_df = pd.concat(line_frames)

    sns.lineplot(
        x="mean",
        y="variance",
        data=line_df,
        hue="set",
        style="intercept",
        ax=ax,
    )

    ax.set_title(f"{symbol_map.lookup(gene_id)} / {gene_id}")
    ax.set_xlabel("Mean")
    ax.set_ylabel("Variance")
    ax.set_xlim(left=0)
    ax.set_ylim(bottom=0)
    ax.legend(loc="upper left", bbox_to_anchor=(1, 1))
    plt.show()
    # Release the figure so looping over many genes does not accumulate open figures
    plt.close(fig)