In [None]:
import anndata
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from IPython.core.display import display
from scipy import linalg, spatial, stats
from sklearn.linear_model import LinearRegression

from rp2 import get_data_path, get_output_path, GeneSymbolMap

Load the phagocyte results from Hagai *et al.* (2018) and determine the list of "responsive" genes, i.e. those with an adjusted *p*-value below 0.01 in the study species.

In [None]:
# Supplementary table of Hagai et al. (2018). The sheet name is reproduced verbatim
# (presumably the workbook's own "diveregnce" spelling), so do not "correct" it.
phagocyte_genes_df = pd.read_excel(
    io=get_data_path("hagai_2018", "41586_2018_657_MOESM4_ESM.xlsx"),
    sheet_name="phagocytes_FC_diveregnce",
)
display(phagocyte_genes_df)

In [None]:
study_species = "mouse"
# Adjusted p-value cut-off below which a gene counts as "responsive" in the study species
responsive_padj_threshold = 0.01

responsive_phagocyte_genes = phagocyte_genes_df.loc[
    phagocyte_genes_df[f"{study_species}_padj"] < responsive_padj_threshold, "gene"
]
print(f"{len(responsive_phagocyte_genes):,} responsive phagocyte genes")

Load the Hagai *et al.* (2018) UMI counts and reduce the dataset to the genes, replicates, treatments and time points of interest, to reduce memory usage and speed up access.

In [None]:
array_express_path = get_data_path("ArrayExpress")
umi_ad = anndata.read_h5ad(array_express_path.joinpath(f"E-MTAB-6754.processed.2.{study_species}.h5ad"))

# Study design: which genes, replicates, treatments and time points to keep
study_genes = responsive_phagocyte_genes
study_replicates = ["1", "2", "3"]
study_control = "unst"
study_treatments = ["lps", "pic"]
study_time_points = ["0", "2", "4", "6"]

# A cell is kept only if it satisfies every criterion; combining the masks up front
# means a single subset (cells x responsive genes) is taken and then materialised
# with .copy() so it no longer references the full dataset.
keep_cells = (
    umi_ad.obs.replicate.isin(study_replicates)
    & umi_ad.obs.treatment.isin([*study_treatments, study_control])
    & umi_ad.obs.time_point.isin(study_time_points)
)
umi_ad = umi_ad[keep_cells, study_genes].copy()

Create a map between gene IDs and gene symbols, using the BioMart gene table for the study species.

In [None]:
# BioMart export for the study species. Explicit column names are supplied (so pandas
# treats the file as header-less); the gene ID column becomes the index.
biomart_genes_path = get_data_path("BioMart", f"{study_species}_genes.tsv")
gene_symbols_df = pd.read_table(biomart_genes_path, index_col=0, names=["id", "symbol", "description"])
symbol_map = GeneSymbolMap(gene_symbols_df)

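Make a rough estimate of gene variability: the variance of each gene's UMI counts across all retained cells. It is used later to pick genes to plot.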

In [None]:
# Rough per-gene variability: variance of UMI counts across all retained cells, most variable first.
# .toarray() densifies the sparse count matrix (the sparse ``.A`` shortcut is deprecated in recent SciPy).
gene_variability = pd.Series(
    index=umi_ad.var.index,
    data=umi_ad.X.toarray().var(axis=0),
).sort_values(ascending=False)
print(gene_variability)

Create a `StudyTreatmentSet` for each treatment of interest, plus an "all" set that pools every treatment, so that each can be analysed separately.

In [None]:
class StudyTreatmentSet:
    """UMI counts for one or more treatments plus the control, with per-set analysis results.

    Analysis functions are attached to the class after its definition via
    :meth:`add_method`, and store their results on the attributes initialised in
    ``__init__``.
    """

    @staticmethod
    def add_method(method):
        """Attach ``method`` to ``StudyTreatmentSet`` under ``method.__name__``.

        Returns ``method`` unchanged, so this can also be used as a decorator.
        """
        setattr(StudyTreatmentSet, method.__name__, method)
        return method

    def __init__(self, umi_counts_ad, treatments, control_name):
        """Keep the cells whose ``obs.treatment`` is one of ``treatments`` or ``control_name``.

        Parameters
        ----------
        umi_counts_ad : anndata.AnnData
            UMI counts whose ``obs`` has a ``treatment`` column.
        treatments : str or iterable of str
            Treatment name(s) to keep: a single name or any iterable (list, tuple, ...).
        control_name : str
            Name of the control treatment, which is always kept.
        """
        # A lone string would otherwise be unpacked into its characters below
        if isinstance(treatments, str):
            treatments = [treatments]
        all_treatments = [*treatments, control_name]
        # NOTE(review): presumably an AnnData view (no .copy()), i.e. it shares data with umi_counts_ad
        self.umi_counts_ad = umi_counts_ad[umi_counts_ad.obs.treatment.isin(all_treatments), :]
        # Per-(gene, condition) statistics; filled in by calculate_condition_stats
        self.condition_stats_df = None
        # Per-gene mean-variance regression results with / without an intercept;
        # filled in by calculate_linear_regression_fit
        self.lr_fit_with_intercept_df = None
        self.lr_fit_without_intercept_df = None


# One set per individual treatment, each paired with the shared control
treatment_sets = {}
for treatment in study_treatments:
    treatment_sets[treatment] = StudyTreatmentSet(umi_ad, [treatment], study_control)

# The "all" set pools every treatment; with a single treatment it is simply that set
treatment_sets["all"] = (
    StudyTreatmentSet(umi_ad, study_treatments, study_control)
    if len(study_treatments) > 1
    else treatment_sets[study_treatments[0]]
)

# Restrict focus to the LPS treatment for now
focus_treatment_set, all_treatment_set = treatment_sets["lps"], treatment_sets["all"]

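Calculate per-condition statistics, where a condition is one replicate × treatment × time point combination, and display them for the focus treatment.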

In [None]:
def calculate_condition_stats(self: StudyTreatmentSet):
    """Compute per-gene UMI count statistics for every condition in this set.

    A condition is one (replicate, treatment, time_point) combination that has at
    least one cell in ``self.umi_counts_ad.obs``. The result is stored in
    ``self.condition_stats_df``. It has one row per (gene, condition) and the columns
    ``gene``, ``replicate``, ``treatment``, ``time_point``, ``n_barcodes`` (number of
    cells), ``min``, ``max``, ``mean``, ``variance`` and ``std_dev``. The last two
    use ``ddof=1``. Rows are sorted by gene, replicate, time point and treatment.
    """
    stat_columns = ["gene", "replicate", "treatment", "time_point", "n_barcodes",
                    "min", "max", "mean", "variance", "std_dev"]
    # Take gene labels from the data itself so they are guaranteed to line up with
    # the columns of X, rather than relying on the global ``study_genes``.
    gene_ids = np.asarray(self.umi_counts_ad.var_names)

    condition_frames = []
    # observed=True: only visit combinations that actually contain cells, even when
    # the obs columns are categoricals with unused categories. An empty group would
    # make the min/max reductions below fail.
    condition_groups = self.umi_counts_ad.obs.groupby(["replicate", "treatment", "time_point"], observed=True)
    for (replicate, treatment, time_point), group_df in condition_groups:
        cell_view = self.umi_counts_ad[group_df.index, :]
        # Densify once per condition rather than once per statistic. ``.toarray()``
        # replaces the sparse ``.A`` shortcut, which is deprecated in recent SciPy.
        counts = cell_view.X.toarray()

        condition_frames.append(
            pd.DataFrame(
                data={
                    "gene": gene_ids,
                    "replicate": replicate,
                    "treatment": treatment,
                    "time_point": time_point,
                    "n_barcodes": cell_view.n_obs,
                    "min": counts.min(axis=0),
                    "max": counts.max(axis=0),
                    "mean": counts.mean(axis=0),
                    "variance": counts.var(axis=0, ddof=1),
                    "std_dev": counts.std(axis=0, ddof=1),
                },
            )
        )

    # DataFrame.append was removed in pandas 2.0. Concatenating once also avoids
    # re-copying the accumulated frame on every iteration.
    if condition_frames:
        stats_df = pd.concat(condition_frames, ignore_index=True)
    else:
        stats_df = pd.DataFrame(columns=stat_columns)

    self.condition_stats_df = stats_df.sort_values(["gene", "replicate", "time_point", "treatment"])


StudyTreatmentSet.add_method(calculate_condition_stats)

# Populate condition_stats_df on every set (individual treatments and "all")
for treatment_set in treatment_sets.values():
    treatment_set.calculate_condition_stats()

# Show the focus treatment's statistics, annotated with gene symbols
condition_stats_df = focus_treatment_set.condition_stats_df
annotated_condition_stats_df = symbol_map.added_to(condition_stats_df)
display(annotated_condition_stats_df)

Display the (gene, condition) rows, across all treatments, whose mean UMI count is zero, and count them per time point.

In [None]:
# Rows (gene x condition) of the pooled set where the gene was never detected
zero_mean_df = all_treatment_set.condition_stats_df.loc[lambda stats_df: stats_df["mean"].eq(0)]
display(symbol_map.added_to(zero_mean_df))
# How many such rows fall at each time point
display(zero_mean_df["time_point"].value_counts().sort_index())

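Determine outliers from the Mahalanobis distance of each condition's (mean, variance) pair to the centroid of that gene's conditions. A condition is flagged when its distance exceeds $\sqrt{\chi^2_{2,\,0.95}}$. Equivalently, its squared distance exceeds the 95th percentile of the $\chi^2$ distribution with 2 degrees of freedom.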

In [None]:
def calculate_mahalanobis_distance(df, column_name="distance"):
    centroid = df.mean(axis=0)
    cov_mtx = np.cov(df, rowvar=False)

    try:
        inv_cov_mtx = linalg.inv(cov_mtx)
        distances = [spatial.distance.mahalanobis(row, centroid, inv_cov_mtx)
                     for row in df.to_numpy()]
    except linalg.LinAlgError:
        distances = np.nan

    return pd.DataFrame(
        index=df.index,
        data={column_name: distances},
    )

In [None]:
def calculate_outliers(self: StudyTreatmentSet, distance_threshold):
    """Flag conditions whose (mean, variance) pair is an outlier for its gene.

    For each gene, the Mahalanobis distance of every condition's ``mean`` and
    ``variance`` from that gene's centroid is written to a new ``m_distance``
    column of ``self.condition_stats_df`` (modified in place). The ``outlier``
    column is True where the distance exceeds ``distance_threshold``. NaN
    distances (covariance not invertible) are never flagged, since
    ``NaN > threshold`` is False.
    """
    stats_df = self.condition_stats_df
    # group_keys=False keeps the original row labels in the result. Without it,
    # pandas >= 2.0 prepends the gene as an extra index level for this
    # transform-like apply, and the row-label lookup below would fail.
    distances_df = (
        stats_df
        .groupby("gene", group_keys=False)[["mean", "variance"]]
        .apply(calculate_mahalanobis_distance)
    )
    stats_df["m_distance"] = distances_df["distance"].reindex(stats_df.index)
    stats_df["outlier"] = stats_df["m_distance"] > distance_threshold


StudyTreatmentSet.add_method(calculate_outliers)

# Under a bivariate-normal assumption the squared Mahalanobis distance follows a
# chi-squared distribution whose degrees of freedom equal the number of features
# (mean and variance). The threshold is therefore the square root of its quantile.
outlier_confidence = 0.95
outlier_feature_count = 2
outlier_distance_threshold = np.sqrt(stats.chi2.ppf(outlier_confidence, outlier_feature_count))
for treatment_set in treatment_sets.values():
    treatment_set.calculate_outliers(outlier_distance_threshold)

print(f"Outliers: {condition_stats_df.outlier.sum():,}")

Plot histograms of the mean and variance over all non-outlier conditions, and of each gene's maximum non-outlier mean.

In [None]:
# Distribution of each statistic across non-outlier (gene, condition) rows, log-scaled counts
for stat_name in ("Mean", "Variance"):
    hist_ax = condition_stats_df.loc[~condition_stats_df.outlier, stat_name.lower()].plot.hist(log=True)
    hist_ax.set(xlabel=stat_name)
    plt.show()

# Highest non-outlier mean reached by each gene
max_mean_ax = (
    condition_stats_df.loc[~condition_stats_df.outlier]
    .groupby("gene")["mean"]
    .max()
    .plot.hist(log=True)
)
max_mean_ax.set(xlabel="Max mean")
plt.show()

For each gene, fit a linear regression of variance on mean over its non-outlier conditions, both with and without an intercept.

In [None]:
def fit_linear_regression(df, fit_intercept=True):
    """Regress the second column of a two-column frame on the first.

    Parameters
    ----------
    df : pandas.DataFrame
        Exactly two numeric columns: predictor first, response second
        (in this notebook: ``mean`` then ``variance``).
    fit_intercept : bool
        Passed through to ``LinearRegression``.

    Returns
    -------
    pandas.Series
        ``slope``, then ``intercept`` (only when ``fit_intercept``), then ``r2``
        (coefficient of determination on the fitted data).
    """
    # (n, 2) -> (1, n, 2) -> transpose (2, n, 1): unpacking gives the predictor and the
    # response as separate (n, 1) column arrays, the 2-D shape LinearRegression expects.
    features, target = df.to_numpy().reshape(1, -1, 2).T

    model = LinearRegression(fit_intercept=fit_intercept).fit(features, target)

    fit_summary = {"slope": np.squeeze(model.coef_)}
    if fit_intercept:
        fit_summary["intercept"] = np.squeeze(model.intercept_)
    fit_summary["r2"] = model.score(features, target)
    return pd.Series(fit_summary)

In [None]:
def calculate_linear_regression_fit(self: StudyTreatmentSet):
    """Fit the per-gene variance-on-mean regression, excluding outlier conditions.

    Results are stored as one row per gene in ``self.lr_fit_with_intercept_df``
    and ``self.lr_fit_without_intercept_df`` (columns as produced by
    ``fit_linear_regression``).
    """
    inlier_stats_df = self.condition_stats_df[~self.condition_stats_df.outlier]
    per_gene = inlier_stats_df.groupby("gene")[["mean", "variance"]]
    self.lr_fit_with_intercept_df, self.lr_fit_without_intercept_df = (
        per_gene.apply(fit_linear_regression, fit_intercept=with_intercept)
        for with_intercept in (True, False)
    )


StudyTreatmentSet.add_method(calculate_linear_regression_fit)

# Fit both regression variants on every set
for treatment_set in treatment_sets.values():
    treatment_set.calculate_linear_regression_fit()

# Show the focus treatment's with-intercept fits, annotated with gene symbols
gene_regression_df = focus_treatment_set.lr_fit_with_intercept_df
annotated_gene_regression_df = symbol_map.added_to(gene_regression_df)
display(annotated_gene_regression_df)

Save the per-condition descriptive statistics and the per-gene results of the regression fitted with an intercept.

In [None]:
def save_data(self: StudyTreatmentSet, output_path, prefix):
    """Write this set's results to CSV files in ``output_path``.

    ``<prefix>_stats_per_condition_per_gene.csv`` holds ``condition_stats_df``
    (without its index). ``<prefix>_lr_fit_per_gene.csv`` holds
    ``lr_fit_with_intercept_df`` (indexed by gene). The no-intercept fit is not written.
    """
    stats_csv_path = output_path.joinpath(prefix + "_stats_per_condition_per_gene.csv")
    lr_fit_csv_path = output_path.joinpath(prefix + "_lr_fit_per_gene.csv")

    self.condition_stats_df.to_csv(stats_csv_path, index=False)
    self.lr_fit_with_intercept_df.to_csv(lr_fit_csv_path)


StudyTreatmentSet.add_method(save_data)

output_path = get_output_path()
output_path.mkdir(parents=True, exist_ok=True)

# File names are prefixed with species and treatment set name, e.g. "mouse_lps"
for treatment, treatment_set in treatment_sets.items():
    treatment_set.save_data(output_path, f"{study_species}_{treatment}")

Show mean–variance plots, with the fitted regression line, for the 10 most variable genes.

In [None]:
plot_outliers = True

replicate_colour_map = {
    "1": "black",
    "2": "red",
    "3": "blue",
}

for gene in gene_variability.index[:10]:
    gene_df = condition_stats_df.loc[condition_stats_df.gene == gene]
    outlier_df = gene_df.loc[gene_df.outlier]
    inlier_df = gene_df.loc[~gene_df.outlier]

    slope, intercept = gene_regression_df.loc[gene, ["slope", "intercept"]]

    # The regression line spans the plotted means; outliers count only when they are drawn
    lr_plot_x = np.sort((gene_df if plot_outliers else inlier_df)["mean"])
    lr_plot_y = (lr_plot_x * slope) + intercept

    fig, ax = plt.subplots()
    if plot_outliers and not outlier_df.empty:
        # Neutral colour + legend entry so outliers are not mistaken for a replicate
        outlier_df.plot.scatter("mean", "variance", ax=ax, marker="x", color="grey", label="Outlier")

    for replicate, replicate_df in inlier_df.groupby("replicate"):
        # .get: fall back to the default colour for replicates missing from the map
        replicate_df.plot.scatter("mean", "variance", ax=ax, marker="o", label=replicate,
                                  color=replicate_colour_map.get(replicate))
    ax.plot(lr_plot_x, lr_plot_y)
    ax.set(title=f"{symbol_map.lookup(gene)} / {gene}", xlabel="Mean", ylabel="Variance")
    ax.set_xlim(left=min(0, np.min(lr_plot_x)))
    ax.set_ylim(bottom=min(0, np.min(lr_plot_y)))
    ax.legend(title="Replicate")
    plt.show()
    # Release the figure explicitly; many figures are created in this loop
    plt.close(fig)