# Preamble

In [None]:
# Load IPython's autoreload extension so that the `%autoreload` magic is available.
# NOTE(review): no `%autoreload` mode is selected in this cell — confirm whether
# automatic module reloading is actually wanted for this notebook.
%load_ext autoreload

In [None]:
import os as _os

# NOTE: This cell is to allow the notebook to be run from a subdirectory of the project root.
# TODO: Either export a "PROJECT_ROOT" environment variable, or modify this cell for your setup.
# Remember the directory the kernel started in the first time this cell runs,
# so that re-running the cell does not walk one level further up each time.
if "_NOTEBOOK_START_DIR" not in globals():
    _NOTEBOOK_START_DIR = _os.getcwd()
_os.chdir(_os.path.join(_NOTEBOOK_START_DIR, _os.pardir))
# Display the resolved working directory as the cell output.
_os.path.realpath(_os.path.curdir)

## Imports

In [None]:
import os
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd

In [None]:
def rotate_xticklabels(ax=None, rotation=45, ha="right", **kwargs):
    """Re-apply the current x tick labels of an axes with a rotation.

    Parameters
    ----------
    ax : axes-like, optional
        Object providing ``get_xticklabels`` / ``set_xticklabels``.
        When omitted, ``plt.gca()`` is used.
    rotation : passed through to ``set_xticklabels`` (default 45).
    ha : horizontal alignment passed through to ``set_xticklabels``
        (default ``"right"``).
    **kwargs : forwarded unchanged to ``set_xticklabels``.
    """
    target = plt.gca() if ax is None else ax
    # Collect the text of every existing label so it can be set again
    # together with the new text properties.
    texts = []
    for tick_label in target.get_xticklabels():
        texts.append(tick_label.get_text())
    target.set_xticklabels(texts, rotation=rotation, ha=ha, **kwargs)

def idxwhere(condition, x=None):
    """Return, as a list, the index labels of the entries selected by `condition`.

    Parameters
    ----------
    condition : boolean mask used to index `x` (``x[condition]``).
    x : object to be indexed, optional. When omitted, `condition` is
        indexed by itself.

    Returns
    -------
    list
        The ``.index`` of the selected entries.
    """
    target = condition if x is None else x
    selected = target[condition]
    return list(selected.index)

In [None]:
# Global plot styling: seaborn "talk" context and a figure resolution of 50 dpi.
sns.set_context("talk")
plt.rcParams.update({"figure.dpi": 50})

# Prepare Metadata

In [None]:
species_taxonomy_inpath = "ref/uhgg_genomes_all_v2.tsv"

# Keep only the rows whose genome is its own species representative, derive
# species_id as "1" + the third "-"-separated field of the MGnify accession,
# and index the table by that id.
_species_taxonomy = (
    pd.read_table(species_taxonomy_inpath)
    .loc[lambda d: d["Genome"] == d["Species_rep"]]
    .assign(
        species_id=lambda d: "1" + d["MGnify_accession"].str.split("-").str.get(2)
    )
    .set_index("species_id")
)


def parse_taxonomy_string(taxonomy_string):
    """Split a ";"-delimited lineage string into a Series keyed by rank prefix.

    The pieces are assigned, in order, to the labels
    ``d__, p__, c__, o__, f__, g__, s__``; pandas raises if the number of
    pieces is not seven.
    """
    rank_prefixes = ["d__", "p__", "c__", "o__", "f__", "g__", "s__"]
    fields = taxonomy_string.split(";")
    return pd.Series(data=fields, index=rank_prefixes)


# Expand every lineage string into one column per rank prefix, then display the table.
species_taxonomy = _species_taxonomy["Lineage"].apply(parse_taxonomy_string)
species_taxonomy

In [None]:
def label_een_experiment_sample(x):
    """Build a human-readable label for one row of the sample metadata table.

    Parameters
    ----------
    x : row-like object (e.g. the ``pandas.Series`` passed by
        ``DataFrame.apply(..., axis=1)``) exposing ``name``, ``sample_type``
        and ``diet_or_media`` plus, depending on the sample type,
        ``collection_date_relative_een_end``, ``source_samples`` and
        ``status_mouse_inflamed``.

    Returns
    -------
    str
        ``"[<name>] ..."`` followed by the fields relevant to the sample type.

    Raises
    ------
    ValueError
        If ``sample_type`` is not one of ``human``, ``Fermenter_inoculum``,
        ``Fermenter`` or ``mouse``, or if a mouse sample has a
        ``status_mouse_inflamed`` other than ``Inflamed`` / ``not_Inflamed``.
    """
    prefix = f"[{x.name}]"

    if x.sample_type == "human":
        return f"{prefix} {x.collection_date_relative_een_end} {x.diet_or_media}"

    if x.sample_type == "Fermenter_inoculum":
        return f"{prefix} {x.source_samples} inoc {x.diet_or_media}"

    if x.sample_type == "Fermenter":
        return f"{prefix} {x.source_samples} {x.diet_or_media}"

    if x.sample_type == "mouse":
        # An earlier label variant also included x.mouse_genotype; the current
        # label only carries the source sample, the diet and the inflammation status.
        if x.status_mouse_inflamed == "Inflamed":
            return f"{prefix} {x.source_samples} {x.diet_or_media} inflam"
        if x.status_mouse_inflamed == "not_Inflamed":
            return f"{prefix} {x.source_samples} {x.diet_or_media} not_inf"
        # Report the offending field under its real name (this is not a sample type).
        raise ValueError(
            f"mouse inflammation status {x.status_mouse_inflamed} not understood"
        )

    raise ValueError(f"sample type {x.sample_type} not understood")


# Sample metadata table indexed by sample_id, with two derived columns:
#   label      - tuple of (collection_date_relative_een_end, diet_or_media, sample_id)
#   full_label - text label built per row by label_een_experiment_sample
sample = (
    pd.read_table("meta/een-mgen/sample.tsv")
    .assign(
        label=lambda df: df.loc[
            :, ["collection_date_relative_een_end", "diet_or_media", "sample_id"]
        ].apply(tuple, axis="columns")
    )
    .set_index("sample_id")
    .pipe(
        lambda df: df.assign(
            full_label=df.apply(label_een_experiment_sample, axis="columns")
        )
    )
)

# Prepare Data

In [None]:
# Long-format (sample, species_id, depth) table pivoted to samples x species,
# with missing combinations filled with 0. Sample names are rewritten to
# "CF_<int>" and the CF_11 / CF_15 labels are exchanged (sample swap).
species_depth = (
    pd.read_table(
        "data/group/een/r.proc.gene99_new-v22-agg75.spgc_specgene-ref-t25-p95.species_depth.tsv",
        names=["sample", "species_id", "depth"],
        index_col=["sample", "species_id"],
    )["depth"]
    .unstack(fill_value=0)
    .rename(columns=str, index=lambda name: f"CF_{int(name.split('_')[1])}")
    .rename(index={"CF_15": "CF_11", "CF_11": "CF_15"})  # Sample swap
)

# Per-sample relative abundance: each row divided by its own total depth.
species_relabund = species_depth.div(species_depth.sum(axis=1), axis="index")

species_relabund

In [None]:
# Strains with maximum, within-species fraction less than this are dropped.
min_strain_frac = 0.05

strain_depth = []
strain_to_species = []
missing_files = []
for species_id in species_depth.columns:
    # Load data
    path = f"data/group/een/species/sp-{species_id}/r.proc.gtpro.sfacts-fit.comm.tsv"
    if os.path.exists(path):
        _strain_frac = (
            pd.read_table(path, index_col=["sample", "strain"])
            # Squeeze only the column axis: a bare .squeeze() collapses a
            # one-row table into a scalar, which has no .unstack().
            .squeeze("columns")
            # A (sample, strain) pair absent from the table is a fraction of 0,
            # matching how species_depth is unstacked.
            .unstack(fill_value=0)
            # Normalize sample names and correct the sample swap.
            .rename(columns=str, index=lambda x: "CF_" + str(int(x.split("_")[1])))
            .rename({"CF_11": "CF_15", "CF_15": "CF_11"})
        )
    else:
        # No strain table for this species: record it and fall through with an
        # empty frame, so all of the species' depth ends up in "__other".
        missing_files.append(path)
        _strain_frac = pd.DataFrame([])
    assert _strain_frac.index.isin(species_depth.index).all()

    # Drop very rare strains.
    _keep_strains = idxwhere(_strain_frac.max() >= min_strain_frac)
    _strain_frac = _strain_frac.reindex(
        index=species_depth.index, columns=_keep_strains, fill_value=0
    )
    # Whatever fraction is not covered by the kept strains goes to "__other";
    # negative values are clipped before rows are renormalized to sum to 1.
    _strain_frac = _strain_frac.assign(__other=lambda x: 1 - x.sum(1))
    _strain_frac[_strain_frac < 0] = 0
    _strain_frac = _strain_frac.divide(_strain_frac.sum(1), axis=0)

    # Multiply strain relative abundance by species depth (resulting in strain depth.)
    _strain_depth = _strain_frac.multiply(species_depth[species_id], axis=0)

    # Rename strains to prefix with species_id.
    _strain_depth = _strain_depth.rename(columns=lambda s: f"{species_id}_{s}")

    # Append to results
    strain_depth.append(_strain_depth)
    strain_to_species.append(pd.Series(species_id, index=_strain_depth.columns))

# Compile full tables.
strain_depth = pd.concat(strain_depth, axis=1)
strain_to_species = pd.concat(strain_to_species)
strain_relabund = strain_depth.divide(strain_depth.sum(1), axis=0)

# Check that we didn't introduce error: summing strain depths within each
# species must give back the species depth. The table is transposed for the
# groupby (grouping along columns via `axis=` is deprecated in pandas) and the
# result is put in species_depth's column order, because np.allclose compares
# by position and groupby returns its keys sorted.
assert np.allclose(
    strain_depth.T.groupby(strain_to_species)
    .sum()
    .T.reindex(columns=species_depth.columns),
    species_depth,
)

# Report how many species were run: (number of species, number of missing strain files).
print(len(species_depth.columns), len(missing_files))

# Visualize E. coli strain fractions

In [None]:
species_id = "102506"
# Strains that never reach this within-species fraction in a subject are
# lumped into the grey "other" bar of that subject's panel.
strain_frac_thresh = 0.2

print(species_taxonomy.loc[species_id])

# Select the samples that have depth data FIRST and sort afterwards:
# `.loc[labels]` returns rows in the order of `labels`, so sorting before the
# selection would be silently discarded.
sample_meta = sample.loc[strain_depth.index].sort_values(
    [
        "subject_id",
        "sample_type",
        "collection_date_relative_een_end",
        "diet_or_media",
        "mouse_genotype",
    ]
)
# Within-species strain fractions (each row divided by its own total), with
# rows put in the same order as sample_meta.
strain_frac = (
    strain_depth.loc[:, strain_to_species == species_id]
    .pipe(lambda d: d.divide(d.sum(axis=1), axis=0))
    .loc[sample_meta.index]
)

all_strains = strain_frac.columns
strain_palette = dict(
    zip(all_strains, mpl.cm.Spectral(np.linspace(0, 1, num=len(all_strains))))
)
subject_order = sample_meta.subject_id.unique()

# One panel per subject, laid out on a grid with num_cols columns.
num_subjects = len(subject_order)
num_cols = 4
num_rows = int(np.ceil(num_subjects / num_cols))
max_num_samples = sample_meta.subject_id.value_counts().max()
fig, axs = plt.subplots(
    num_rows,
    num_cols,
    figsize=(15 * num_cols, 6 * num_rows),
    gridspec_kw=dict(wspace=0.1, hspace=3),
    sharex=False,
    sharey=True,
    squeeze=False,
)

for subject_id, ax in zip(subject_order, axs.flatten()):
    _subject_strain_frac = strain_frac.loc[sample_meta.subject_id == subject_id]
    _subject_sample_meta = sample_meta.loc[sample_meta.subject_id == subject_id].assign(
        xpos=lambda x: range(len(x.index))
    )
    _subject_strain_order = idxwhere(_subject_strain_frac.max() >= strain_frac_thresh)
    # Stack one bar segment per retained strain; last_top tracks the running
    # top of the stack for every sample.
    last_top = 0
    for strain_id in _subject_strain_order:
        ax.bar(
            x=_subject_sample_meta.xpos,
            height=_subject_strain_frac[strain_id],
            bottom=last_top,
            width=1.0,
            alpha=1.0,
            color=strain_palette[strain_id],
            edgecolor="k",
            lw=1,
            label=strain_id,
        )
        last_top += _subject_strain_frac[strain_id]
    # Fill the remainder of each bar (up to 1.0) in grey.
    ax.bar(
        x=_subject_sample_meta.xpos,
        height=1.0 - last_top,
        bottom=last_top,
        width=1.0,
        alpha=1.0,
        color="grey",
        edgecolor="k",
        lw=1,
        label="other",
    )
    ax.set_title((subject_id, species_id))
    ax.set_xticks(range(len(_subject_sample_meta.index)))
    ax.set_xticklabels(_subject_sample_meta.full_label.values)
    # Same x-range in every panel so bar widths are comparable across subjects.
    ax.set_xlim(-0.5, max_num_samples - 0.5)
    # ax.legend(bbox_to_anchor=(1, 1))
    rotate_xticklabels(ax=ax)