## Preamble

### Project Template

In [None]:
%load_ext autoreload

In [None]:
import os as _os

# Every data path below is relative, so anchor the working directory at $PROJECT_ROOT
# (raises KeyError if the variable is unset) and display the resolved location.
_os.chdir(_os.environ["PROJECT_ROOT"])
_os.path.realpath(_os.curdir)

### Imports

In [None]:
import os
import subprocess
import sys
import time
from datetime import datetime
from glob import glob
from itertools import chain, product
from tempfile import mkstemp

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
from scipy.spatial.distance import pdist, squareform
import seaborn as sns
import sfacts as sf
import statsmodels.api as sm
import statsmodels.formula.api as smf
import xarray as xr
from mpl_toolkits.axes_grid1 import make_axes_locatable
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm

import lib.plot
import lib.stats
from lib.dissimilarity import load_dmat_as_pickle
from lib.pandas_util import align_indexes, aligned_index, idxwhere, invert_mapping

In [None]:
import lib.thisproject.data

### Set Style

In [None]:
# Presentation-scale fonts/line widths and a fixed figure DPI for every plot below.
plt.rcParams.update({"figure.dpi": 100})
sns.set_context("talk")

## Data Setup

### Species Metadata and Taxonomy

In [None]:
# Unique species IDs (as strings) that belong to the "hmp2" species group.
species_list = (
    pd.read_table("meta/species_group.tsv")
    .loc[lambda df: df["species_group_id"] == "hmp2", "species_id"]
    .astype(str)
    .unique()
)

In [None]:
def parse_taxonomy_string(taxonomy_string):
    """Split a semicolon-delimited lineage string into a Series keyed by rank prefix.

    Parameters
    ----------
    taxonomy_string : str
        Lineage such as ``"d__...;p__...;c__...;o__...;f__...;g__...;s__..."``.

    Returns
    -------
    pandas.Series
        The seven fields, indexed by ``d__`` through ``s__``.

    Raises
    ------
    ValueError
        If the string does not split into exactly seven fields.
    """
    rank_prefixes = ["d__", "p__", "c__", "o__", "f__", "g__", "s__"]
    fields = taxonomy_string.split(";")
    return pd.Series(fields, index=rank_prefixes)

In [None]:
species_taxonomy_inpath = "ref/uhgg_genomes_all_v2.tsv"

# One row per species: keep only each species' representative genome, build species_id as
# "1" + the third "-"-separated field of its MGnify accession, and split the lineage into ranks.
species_taxonomy = (
    pd.read_table(species_taxonomy_inpath)
    .loc[lambda df: df["Genome"] == df["Species_rep"]]
    .assign(
        species_id=lambda df: "1" + df["MGnify_accession"].str.split("-").str[2]
    )
    .set_index("species_id")["Lineage"]
    .apply(parse_taxonomy_string)
)
species_taxonomy

In [None]:
# Colour palette keyed by phylum; phyla are passed in sorted order, drawn from "tab10".
phylum_palette = lib.plot.construct_ordered_palette(
    sorted(species_taxonomy["p__"].unique()), cm="tab10"
)

# NOTE(review): "tab10" has only 10 colours. This disabled check would fail if two phyla
# received the same colour — confirm whether construct_ordered_palette cycles colours.
# assert len(set(phylum_palette.values())) == len((phylum_palette.values()))

### Sample Metadata

In [None]:
# Sample metadata: one row per metagenomic library, joined through
# library -> preparation -> stool -> visit -> subject.
mgen = pd.read_table('meta/hmp2/mgen.tsv', index_col='library_id')
preparation = pd.read_table('meta/hmp2/preparation.tsv', index_col='preparation_id')
stool = pd.read_table('meta/hmp2/stool.tsv', index_col='stool_id')
visit = pd.read_table('meta/hmp2/visit.tsv', index_col='visit_id')
subject = pd.read_table('meta/hmp2/subject.tsv', index_col='subject_id')

meta_all = (
    mgen
    .join(preparation.drop(columns='library_type'), on='preparation_id')
    .join(stool, on='stool_id')
    .join(visit, on='visit_id', rsuffix='_')
    .join(subject, on='subject_id')
    # new_name = "<subject_id>_<week_number>_<library_id>"; a missing week is written as 999.
    .assign(
        new_name=lambda meta: (
            meta[['subject_id', 'week_number']]
            .assign(
                library_id=meta.index,
                week_number=lambda parts: parts['week_number'].fillna(999).astype(int),
            )
            .apply(lambda row: '_'.join(row.astype(str)), axis=1)
        )
    )
)

# Every library must resolve to a subject.
assert not any(meta_all.subject_id.isna())

library_id_to_new_name = meta_all.new_name

# TODO: Rename samples based on subject and visit number
# TODO: Drop duplicate stools

### Species Depth

In [None]:
# Depth of every hmp2-group species in every sample (samples x species, 0 where absent).
# Species without a depth file are recorded in _missing_species and skipped.
species_depth = []
_missing_species = []

for species in tqdm(species_list):
    inpath = f"data/group/hmp2/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_v15-v22-agg75.spgc-fit.species_depth.tsv"
    if os.path.exists(inpath):
        data = pd.read_table(inpath, names=['sample', 'depth']).assign(species=species)
        species_depth.append(data)
    else:
        _missing_species.append(species)

species_depth = (
    pd.concat(species_depth)
    .set_index(['sample', 'species'])['depth']
    .unstack(fill_value=0)
)

print(f"{len(_missing_species)} out of {len(species_list)} species are missing.")

In [None]:
# A species is "found" in a sample when its depth exceeds depth_thresh.
depth_thresh = 0.2
species_found = species_depth > depth_thresh

# Prevalence: fraction of subjects with the species found in at least one of their samples.
species_prevalence = (
    species_found
    .groupby(meta_all.subject_id)
    .any()
    .mean()
    .sort_values(ascending=False)
)

# Relative abundance (depth over the sample's summed species depth), NaN where not found.
species_rabund_when_found = (
    species_depth
    .div(species_depth.sum(axis=1), axis=0)
    .where(species_found)
)
# Summarise within each subject first, then across subjects (NaNs are skipped).
species_mean_rabund_when_found = (
    species_rabund_when_found
    .groupby(meta_all.subject_id)
    .mean()
    .mean()
    .sort_values(ascending=False)
)
species_median_rabund_when_found = (
    species_rabund_when_found
    .groupby(meta_all.subject_id)
    .median()
    .median()
    .sort_values(ascending=False)
)

# Top 20 species by prevalence, with abundance summaries and taxonomy.
(
    species_prevalence
    .to_frame('prevalence')
    .assign(
        mean_rabund=species_mean_rabund_when_found,
        median_rabund=species_median_rabund_when_found,
    )
    .join(species_taxonomy)
    .head(20)
)

### Strain Statistics

In [None]:
def _genome_flag(x, name):
    """Return QC flag ``name`` of row ``x`` as a bool; raise ValueError if it is missing (NA)."""
    value = getattr(x, name)
    if pd.isna(value):
        raise ValueError(f"Genome has a missing {name!r} flag:", x)
    return bool(value)


def classify_genome(x):
    """Assign one genome (a row of the strain metadata table) to a QC class.

    Uses ``and``/``not`` rather than bitwise ``&``/``~``. ``~`` on a Python bool yields
    -2/-1 and is deprecated since Python 3.12.

    Parameters
    ----------
    x : pandas.Series
        Row with ``genome_type`` ("Isolate", "MAG" or "SPGC") and a boolean
        ``passes_filter``. SPGC rows that fail ``passes_filter`` additionally need
        boolean ``passes_geno_positions`` and ``passes_in_sample_list``.

    Returns
    -------
    str
        One of "isolate", "isolate_fails_qc", "mag", "mag_fails_qc", "spgc",
        "sfacts_only" or "sfacts_fails_qc".

    Raises
    ------
    ValueError
        If ``genome_type`` is not recognised, or a required flag is NA.
    """
    genome_type = x.genome_type
    if genome_type == "Isolate":
        return "isolate" if _genome_flag(x, "passes_filter") else "isolate_fails_qc"
    if genome_type == "MAG":
        return "mag" if _genome_flag(x, "passes_filter") else "mag_fails_qc"
    if genome_type == "SPGC":
        if _genome_flag(x, "passes_filter"):
            return "spgc"
        # Failed the full filter, but may still pass the sfacts-level checks.
        if _genome_flag(x, "passes_geno_positions") and _genome_flag(x, "passes_in_sample_list"):
            return "sfacts_only"
        return "sfacts_fails_qc"
    raise ValueError("Genome did not match classification criteria:", x)

In [None]:
# Per-genome QC statistics for every species (reference genomes and sfacts strains).
# Each row is labelled with a genome_class by classify_genome.
filt_stats = []
missing_species = []

for species in tqdm(species_list):
    inpath = f"data/group/hmp2/species/sp-{species}/r.proc.gtpro.sfacts-fit.gene99_v15-v22-agg75.spgc-fit.strain_meta_spgc_and_ref.tsv"
    if os.path.exists(inpath):
        data = pd.read_table(inpath).assign(species=species, inpath=inpath)
        filt_stats.append(data)
    else:
        missing_species.append(species)

filt_stats = (
    pd.concat(filt_stats)
    .assign(genome_class=lambda df: df.apply(classify_genome, axis=1))
)

print(f"{len(missing_species)} out of {len(species_list)} species are missing stats.")

In [None]:
# Build a (sample x strain) depth table spanning all species.
# For each species: read its sfacts table indexed by (sample, strain). This is presumably
# per-sample strain fractions — TODO confirm. Keep strains whose value summed over all
# samples exceeds 0.05, and lump the remainder into an "other" strain with id -1. Then
# renormalise each sample to sum to 1 and scale by that species' depth in the sample.
# TODO: Collect strain fractions:

strain_depth = []
strain_to_species = []
missing_species = []
for species in tqdm(species_list):
    inpath = f"data/group/hmp2/species/sp-{species}/r.proc.gtpro.sfacts-fit.comm.tsv"
    if not os.path.exists(inpath):
        # No fit for this species: the empty table below sends all of its depth to "other".
        missing_species.append(species)
        d = pd.DataFrame([])
    else:
        # .squeeze() assumes a single value column besides the index — TODO confirm.
        d = (
            pd.read_table(inpath, index_col=["sample", "strain"])
            .squeeze()
            .unstack()
        )

    # A species absent from species_depth contributes zero depth to every strain.
    if species in species_depth.columns:
        _species_depth = species_depth[species]
    else:
        _species_depth = 0
    
    # Keep strains whose value, summed across samples, exceeds 0.05.
    _keep_strains = idxwhere(d.sum() > 0.05)
    assert d.index.isin(species_depth.index).all()
    # Align to every sample in species_depth; samples absent from the fit get 0 for each strain.
    d = d.reindex(index=species_depth.index, columns=_keep_strains, fill_value=0)
    # Remainder goes to an "other" pseudo-strain (id -1); negative values are clipped to 0.
    d = d.assign(__other=lambda x: 1 - x.sum(1)).rename(columns={"__other": -1})
    d[d < 0] = 0
    # Renormalise each sample to sum to 1, then convert fractions into depths.
    d = d.divide(d.sum(1), axis=0)
    d = d.multiply(_species_depth, axis=0)
    # Globally unique column names "<species>_<strain>"; the "other" column becomes "<species>_-1".
    d = d.rename(columns=lambda s: f"{species}_{s}")
    strain_depth.append(d)
    strain_to_species.append(pd.Series(species, index=d.columns))

strain_depth = pd.concat(strain_depth, axis=1)
# Each strain's depth as a fraction of the sample's total strain depth.
strain_rabund = strain_depth.divide(strain_depth.sum(1), axis=0)
# NOTE(review): rows of strain_rabund already sum to 1, or are all-NaN when a sample's total
# depth is 0. So this column is ~0, or 1 for those all-NaN rows — confirm it is still needed.
strain_rabund['-1'] = 1 - strain_rabund.sum(1)
# Maps each strain_depth column to its species id.
strain_to_species = pd.concat(strain_to_species)

print(
    len(missing_species),
    "out of",
    len(species_list),
    "species are missing stats.",
)

In [None]:
# Shared strains between pairs of samples, and whether each pair comes from different subjects.
# Columns ending in "-1" are per-species "other" pseudo-strains (unassigned depth), not real strains.
other_strain_list = idxwhere(strain_depth.columns.to_series().str.endswith("-1"))
strain_present = strain_depth > 0.1
# Exclude samples with 10 or fewer *real* strains present. Counting the "other" pseudo-strains
# here would let samples with few real strains through.
low_strain_samples = idxwhere(strain_present.drop(columns=other_strain_list).sum(1) <= 10)
m, x = align_indexes(meta_all, strain_present.drop(columns=other_strain_list, index=low_strain_samples))

# Number of strains present in both samples of each pair, in pdist's condensed order
# (row-major upper triangle). This equals pdist(x, metric=lambda u, v: (u & v).sum()),
# but is computed as one matrix product instead of one Python call per pair.
_presence = x.to_numpy(dtype=int)
_upper = np.triu_indices(len(_presence), k=1)
shared_strains = (_presence @ _presence.T)[_upper].astype(float)
# 1.0 for pairs from different subjects, 0.0 for pairs from the same subject (same order).
_subject_codes = pd.factorize(m['subject_id'])[0]
diff_subj = (_subject_codes[:, None] != _subject_codes[None, :])[_upper].astype(float)

In [None]:
fig, ax = plt.subplots()

# Unit-width integer bins. The "+ 2" gives the maximum count its own bin;
# np.arange(0, max) would drop pairs sharing exactly `max` strains.
bins = np.arange(0, shared_strains.max() + 2)

for _diff_subj, label, c in zip([0, 1], ['Same Subject', 'Different Subject'], ['tab:purple', 'tab:red']):
    ax.hist(shared_strains[diff_subj == _diff_subj], alpha=0.5, color=c, bins=bins, density=True, label=label)
    ax.hist(shared_strains[diff_subj == _diff_subj], histtype='step', color=c, bins=bins, density=True, label='__nolegend__')

ax.set_xlabel('Shared Strains Per Sample Pair')
ax.set_ylabel('density')
ax.set_yticks([0])
ax.legend(loc='upper right');

In [None]:
# Shared strains per sample pair (same vs. different subject) on a broken y-axis:
# the top panel shows only densities above 0.5, the bottom panel densities up to 0.15.
fig, axs = plt.subplots(2, figsize=(6, 3), gridspec_kw=dict(height_ratios=(0.3, 0.7)))

# Unit-width integer bins. The "+ 2" keeps pairs sharing exactly the maximum number of
# strains; np.arange(0, max) would drop them.
bins = np.arange(0, shared_strains.max() + 2)
for ax in axs:
    for _diff_subj, label, c in zip([0, 1], ['Same Subject', 'Different Subject'], ['tab:purple', 'tab:red']):
        ax.hist(shared_strains[diff_subj == _diff_subj], alpha=0.5, color=c, bins=bins, density=True, label=label)
        ax.hist(shared_strains[diff_subj == _diff_subj], histtype='step', color=c, bins=bins, density=True, label='__nolegend__')
axs[0].set_ylim(bottom=0.5)
axs[1].set_ylim(top=0.15)
axs[1].set_xlabel('Shared Strains Per Sample Pair')
axs[1].set_ylabel('Pairs (density)')

# Hide the facing spines/ticks so the two panels read as a single broken axis.
axs[1].spines['top'].set_visible(False)
axs[0].spines['bottom'].set_visible(False)
axs[0].set_xticks([])

axs[1].legend(loc='upper right');

In [None]:
# Subject x strain presence: a real (non-"other") strain is present in a subject when its depth
# exceeds 0.1 in at least one of that subject's samples.
strain_by_subject = (strain_depth.groupby(meta_all.subject_id).max() > 0.1).drop(columns=other_strain_list)
total_strains_per_subject = strain_by_subject.sum(1)
total_subjects_per_strain = strain_by_subject.sum(0)

fig, axs = plt.subplots(2, figsize=(6, 7), gridspec_kw=dict(hspace=0.5))

bins = np.linspace(0, total_strains_per_subject.max(), num=40)
axs[0].hist(total_strains_per_subject, bins=bins, density=False, color='k', histtype='step')
axs[0].hist(total_strains_per_subject, bins=bins, density=False, color='k', alpha=0.5)
axs[0].set_ylabel('Subjects (count)')
axs[0].set_xlabel('Strains per Subject')

# Unit-width integer bins. The "+ 2" keeps strains found in the maximum number of subjects;
# np.arange(0, max) would drop them.
bins = np.arange(0, total_subjects_per_strain.max() + 2)
axs[1].hist(total_subjects_per_strain, bins=bins, density=False, color='k', histtype='step')
axs[1].hist(total_subjects_per_strain, bins=bins, density=False, color='k', alpha=0.5)
axs[1].set_ylabel('Strains (count)')
axs[1].set_xlabel('Subjects per Strain');

In [None]:
# Number of subjects with samples in species_depth, then quartiles of samples per subject.
print(meta_all.loc[species_depth.index, 'subject_id'].value_counts().shape)
print(
    meta_all.loc[species_depth.index, 'subject_id']
    .value_counts()
    .quantile([0.25, 0.5, 0.75])
)

In [None]:
depth_thresh = 0.1
other_strain_list = idxwhere(strain_depth.columns.to_series().str.endswith("-1"))
# Sample x strain presence of real strains ("other" pseudo-strain columns dropped).
strain_presence = (strain_depth > depth_thresh).drop(columns=other_strain_list)

# Distinct strains of each species detected in each sample.
strains_per_species_per_sample = strain_presence.T.groupby(strain_to_species).sum().T
# Distinct strains of each species detected in each subject. A strain counts once if it is
# present in any of the subject's samples; summing presence over samples would count it
# once per sample.
strains_per_species_per_subject = (
    strain_presence.groupby(meta_all.subject_id).any()
    .T.groupby(strain_to_species).sum().T
)

fig, ax = plt.subplots()
bins = np.arange(50)
ax.hist(strains_per_species_per_subject.stack(), bins=bins, label='Per subject')
ax.hist(strains_per_species_per_sample.stack(), bins=bins, alpha=0.5, label='Per sample')
ax.set_yscale('log')
ax.set_xlabel('Strains per species')
ax.set_ylabel('Count')
ax.legend();

In [None]:
# Histogram of the number of subjects each strain is present in (integer bins 0-49).
strain_by_subject.sum(axis=0).pipe(plt.hist, bins=np.arange(50))

In [None]:
# Mann-Whitney U test of shared-strain counts: different-subject (x=True) vs. same-subject
# (x=False) sample pairs.
# NOTE(review): assumes lib.stats.mannwhitneyu(group_col, value_col, data=...) compares `y`
# across the groups defined by `x` — confirm in lib/stats.
lib.stats.mannwhitneyu('x', 'y', data=pd.DataFrame(dict(x=diff_subj.astype(bool), y=shared_strains)))

In [None]:
# Mann-Whitney U test (scipy's default two-sided alternative): shared strains for
# different-subject pairs vs. same-subject pairs.
sp.stats.mannwhitneyu(
    shared_strains[diff_subj != 0],
    shared_strains[diff_subj == 0],
)

In [None]:
# Shared strains between pairs of subjects: number of real (non-"other") strains present
# (depth > 0.1 in at least one sample) in both subjects.
other_strain_list = idxwhere(strain_depth.columns.to_series().str.endswith("-1"))
x = (strain_depth.groupby(meta_all.subject_id).max() > 0.1).drop(columns=other_strain_list)

# Co-presence counts via X @ X.T, taken in pdist's condensed (row-major upper-triangle) order.
# This equals pdist(x, metric=lambda u, v: (u & v).sum()).
_subject_presence = x.to_numpy(dtype=int)
_subject_co = _subject_presence @ _subject_presence.T
shared_strains_between_subjects = _subject_co[np.triu_indices(len(_subject_co), k=1)].astype(float)

In [None]:
fig, ax = plt.subplots()
# Unit-width integer bins covering every observed value. A fixed upper edge (previously 150)
# silently dropped larger counts; initial=0 keeps this valid for an empty array.
bins = np.arange(0, np.max(shared_strains_between_subjects, initial=0) + 2)
ax.hist(shared_strains_between_subjects, bins=bins)
ax.set_xlabel('Shared Strains Per Subject Pair')
ax.set_ylabel('Subject pairs (count)');

In [None]:
# Quantile summary of the number of strains shared per subject pair.
_quantile_levels = [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]
pd.Series(shared_strains_between_subjects).quantile(q=_quantile_levels)