In [None]:
%load_ext autoreload

In [None]:
import os as _os

# All relative paths below (ref/, meta/, data/) are resolved from the project root,
# taken from the PROJECT_ROOT environment variable. The final expression echoes the
# resolved working directory as the cell output.
_os.chdir(_os.environ['PROJECT_ROOT'])
_os.path.realpath(_os.getcwd())

#### Imports

In [None]:
import pandas as pd
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import seaborn as sns
import xarray as xr
from lib.pandas_util import idxwhere, aligned_index, align_indexes, invert_mapping
import lib.thisproject.data
import matplotlib as mpl
import lib.plot
import statsmodels as sm
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm
import subprocess
from tempfile import mkstemp
import time
import subprocess
from itertools import chain
import os
from itertools import product
from mpl_toolkits.axes_grid1 import make_axes_locatable
import lib.thisproject.data

import sfacts as sf

In [None]:
# Written using ChatGPT

import numpy as np

def logit_space(start, end, num=50, endpoint=True, base=10.0):
    """
    Return numbers spaced evenly on a logit scale.

    The returned probabilities p_i have log-odds ``log_base(p_i / (1 - p_i))`` that are
    linear in i. The first value equals `start`, and the last equals `end` when
    `endpoint` is True. All logarithm bases are constant multiples of one another, so
    `base` does not change the resulting points. It only selects the base used for the
    intermediate log-odds and is kept for parity with ``numpy.logspace``.

    Parameters:
        start (float): The starting value for the range (0 < start < 1).
        end (float): The ending value for the range (0 < end < 1).
        num (int, optional): Number of points in the output array. Default is 50.
        endpoint (bool, optional): If True, `end` is the last value in the range. If False,
                                   `end` is not included. Default is True.
        base (float, optional): Base of the logarithm used for the intermediate log-odds.
                                Must be greater than 1. Default is 10.0.

    Returns:
        numpy.ndarray: An array of `num` values in (0, 1), equally spaced on the logit scale.

    Raises:
        ValueError: If `start` or `end` is outside (0, 1), `num` is not positive,
                    or `base` is not greater than 1.
    """
    if not (0 < start < 1) or not (0 < end < 1):
        raise ValueError("Start and end values must be in the (0, 1) interval.")
    if num <= 0:
        raise ValueError("Number of points (num) must be positive.")
    if base <= 1.0:
        raise ValueError("Base must be greater than 1.0 for logit space.")

    # Log-odds of the endpoints, expressed in the SAME base that np.logspace uses to
    # exponentiate. Mixing the natural log here with a base-10 logspace would shift every
    # point, so that even the first output would not equal `start`.
    log_base = np.log(base)
    start_logit = np.log(start / (1 - start)) / log_base
    end_logit = np.log(end / (1 - end)) / log_base

    # Odds evenly spaced on a log scale: base ** linspace(start_logit, end_logit, ...).
    odds = np.logspace(start_logit, end_logit, num=num, endpoint=endpoint, base=base)

    # Convert odds back to probabilities.
    return odds / (1 + odds)

# Demo: 100 points between 0.1 and 0.9, evenly spaced in log-odds, printed and
# plotted against their index.
start_val, end_val, num_points = 0.1, 0.9, 100

logit_space_values = logit_space(start_val, end_val, num=num_points)
print(logit_space_values)

plt.plot(logit_space_values)

### Set Style

In [None]:
# Low DPI keeps inline figures compact; the seaborn "talk" context enlarges fonts and lines.
mpl.rcParams['figure.dpi'] = 50
sns.set_context(context='talk')

In [None]:
# Species taxonomy table loaded via the project helper. Later cells use its p__, c__,
# s__ and taxonomy_string columns, and use it to group species_strain_counts
# (presumably it is indexed by species ID — confirm against load_species_taxonomy).
species_taxonomy = lib.thisproject.data.load_species_taxonomy("ref/gtpro/species_taxonomy_ext.tsv")
species_taxonomy

In [None]:
# MIDAS/UHGG genome metadata indexed by genome_id. IDs are re-keyed by dropping the
# first len("GUT_GENOME") characters and prefixing "UHGG". The slice is unconditional,
# so this assumes every ID starts with "GUT_GENOME".
genomes_meta = (
    pd.read_table('ref/midasdb_uhgg/genomes.tsv', index_col='genome')
    .rename_axis(index='genome_id')
    .rename(lambda s: 'UHGG' + s[len("GUT_GENOME"):])
)
genomes_meta

In [None]:
# Reference genome table, re-keyed to the same "UHGG..." genome IDs as genomes_meta
# and then joined with it. `species` is read as str; later cells group on it and
# join the result onto species_strain_counts.
reference_meta = (
    pd.read_table('ref/uhgg_genomes_all_4644.tsv', dtype={'species': str}, index_col='Genome')
    .rename_axis(index='genome_id')
    .rename(lambda s: 'UHGG' + s[len("GUT_GENOME"):])
    .join(genomes_meta)
)
reference_meta

In [None]:
# - How many samples
# Samples in the analysis group, excluding any mgen_id starting with "xjin_".
group = 'xjin_ucfmt_hmp2'  # But drop xjin_ samples

sample_list = pd.read_table('meta/mgen_group.tsv')[lambda x: (x.mgen_group == group) & (~x.mgen_id.str.startswith('xjin_'))].mgen_id.to_list()
assert len(sample_list) == len(set(sample_list))  # sample IDs must be unique
len(sample_list)

In [None]:
# - How many species analyzed
# Species assigned to this group in the species-group table (IDs must be unique).
species_list1 = pd.read_table('meta/species_group.tsv')[lambda x: (x.species_group_id == group)].species_id.to_list()
assert len(species_list1) == len(set(species_list1))
len(species_list1)

In [None]:
# - How many species found in xjin_ucfmt_hmp2?
# TODO: Remind myself of what my species filters were.
# Species IDs from the pipeline's species-selection list, converted to str so they
# match the string species IDs used in paths and joins below.
species_list2 = [str(x) for x in lib.pandas_util.read_list('data/group/xjin_ucfmt_hmp2/r.proc.gtpro.horizontal_coverage.select_species.list')]
len(species_list2)

In [None]:
# - For each species:
# Example species worked through step by step in the next few cells.
species_id = '102506'

In [None]:
#   - How many species-x-samples pairs at sufficient depth (1x)
# GT-Pro species depth per sample, zero-filled for samples in sample_list without an entry.
species_depth = pd.read_table(f'data/group/xjin_ucfmt_hmp2/species/sp-{species_id}/r.proc.gtpro.species_depth.tsv', names=['sample', 'depth'], index_col=['sample']).depth.reindex(sample_list, fill_value=0)
sample_list2 = idxwhere((species_depth > 1))
len(sample_list2)

In [None]:
#   - How many strains were these collapsed into (with at least 1x depth)
# SFacts strain fractions (sample x strain) restricted to the >1x samples. Strain IDs
# are cast to str so they match strain_meta's index below.
strain_frac = pd.read_table(f'data/group/xjin_ucfmt_hmp2/species/sp-{species_id}/r.proc.gtpro.sfacts-fit.comm.tsv', index_col=['sample', 'strain']).community.unstack('strain').loc[sample_list2].rename(columns=str)
# Per-sample strain depth = strain fraction x species depth.
strain_depth = (strain_frac.T * species_depth.loc[sample_list2]).T
strain_list = idxwhere((strain_depth > 1).any())
len(strain_list)

In [None]:
#   - How many of these had at least one "pure" sample
# "Pure" here means the strain makes up >95% of the species in that sample.
strain_list2 = idxwhere((strain_frac[strain_list] > 0.95).any())
len(strain_list2)

In [None]:
#   - How many passed "species gene frac" threshold?
# SPGC per-strain metadata (index cast to str). Strains missing from it get
# species_gene_frac = 0 via reindex, so they fail the >0.9 threshold.
strain_meta = pd.read_table(f'data/group/xjin_ucfmt_hmp2/species/sp-{species_id}/r.proc.gtpro.sfacts-fit.gene99_new-v22-agg75.spgc_specgene-ref-t25-p95_ss-all_t-30_thresh-corr100-depth250.strain_meta.tsv', index_col='strain').rename(str)
strain_list3 = idxwhere(strain_meta.reindex(strain_list2, fill_value=0).species_gene_frac > 0.9)
len(strain_list3)

In [None]:
#   - How many passed gene count filtering? (These are our final numbers)
# Fit a heavy-tailed Student's t (df fixed at 2) to the gene counts of the strains that
# passed the species-gene-fraction filter. Then keep strains whose count lies strictly
# inside the 0.1%-99.9% quantiles of a normal with the fitted loc/scale.
x = strain_meta.loc[strain_list3].num_genes
_df, _loc, _scale = sp.stats.t.fit(x.values, fix_df=2)
_dist0 = sp.stats.t(_df, _loc, _scale)  # the fitted t itself, kept for inspection
_dist1 = sp.stats.norm(_loc, _scale)
thresh_min_num_uhgg_genes, thresh_max_num_uhgg_genes = _dist1.ppf([0.001, 0.999])

strain_list4 = idxwhere(
    (x > thresh_min_num_uhgg_genes)
    & (x < thresh_max_num_uhgg_genes)
)
len(strain_list4)

In [None]:
# Per-species funnel of strain-filtering stages, run over every species in species_list2.
# Builds two notebook-wide tables:
#   species_strain_counts: one row per species, one count column per stage.
#   strain_details: indexed by (species, strain); strain_meta columns plus one boolean
#                   `has_*` flag per stage, for every strain in strain_list0.
# Species whose SFacts or SPGC outputs are missing are reported and skipped.
species_strain_counts = {}
strain_details = []

for species_id in tqdm(species_list2):
    #   - How many species-x-samples pairs at sufficient depth (1x)
    # NOTE(review): the cutoff actually applied below is depth > 0.05, not the 1x stated
    # here and in the `num_species_samples` comment — confirm which is intended.
    species_depth = pd.read_table(f'data/group/xjin_ucfmt_hmp2/species/sp-{species_id}/r.proc.gene99_new-v22-agg75.spgc_specgene-ref-t25-p95.species_depth.tsv', names=['sample', 'depth'], index_col=['sample']).depth.reindex(sample_list, fill_value=0)
    sample_list2 = idxwhere((species_depth > 0.05))
    
    #   - How many strains were these collapsed into (with at least 1x depth)
    # NOTE(review): strains are actually selected by relative abundance (> 0.5 in at least
    # one sample_list2 sample, strain_list0); `strain_depth` below is computed but unused.
    try:
        strain_frac = pd.read_table(f'data/group/xjin_ucfmt_hmp2/species/sp-{species_id}/r.proc.gtpro.sfacts-fit.comm.tsv', index_col=['sample', 'strain']).community.unstack('strain').reindex(sample_list2, fill_value=0)
    except FileNotFoundError as err:
        print(f"SFacts output missing for {species_id}.")
        print(err)
        continue

    strain_list0 = idxwhere((strain_frac > 0.5).any())
    strain_depth = (strain_frac.T * species_depth.loc[sample_list2]).T    
    #   - How many of these had at least one "pure" sample
    strain_list1 = idxwhere((strain_frac[strain_list0] > 0.95).any())
    
    # SPGC per-strain metadata, restricted to strains with a pure sample.
    try:
        strain_meta = pd.read_table(f'data/group/xjin_ucfmt_hmp2/species/sp-{species_id}/r.proc.gtpro.sfacts-fit.gene99_new-v22-agg75.spgc_specgene-ref-t25-p95_ss-all_t-30_thresh-corr100-depth250.strain_meta.tsv', index_col='strain').loc[strain_list1]
    except FileNotFoundError as err:
        print(f"SPGC output missing for {species_id}.")
        print(err)
        continue

    #   - How many of these had a total depth of >1x?
    strain_list2 = idxwhere(strain_meta.sum_depth > 1)  # sum_depth is taken as-is from the SPGC strain_meta table

    #   - How many passed "species gene frac" threshold?
    strain_list3 = idxwhere(strain_meta.reindex(strain_list2, fill_value=0).species_gene_frac > 0.9)
    
    #   - How many passed gene count filtering? (These are our final numbers)
    # Fit a Student's t (df=2) to the gene counts of the complete strains, then keep
    # strains strictly inside the 0.1%-99.9% quantiles of a normal with the fitted loc/scale.
    x0 = strain_meta.loc[strain_list3].num_genes  # Use only high quality strains to create distribution.
    x1 = strain_meta.num_genes  # Assess all strains.
    if len(x0) < 1:
        # No complete strains to fit a distribution to.
        strain_list4 = []
    else:
        _df, _loc, _scale = sp.stats.t.fit(x0.values, fix_df=2)
        _dist0 = sp.stats.t(_df, _loc, _scale)
        _dist1 = sp.stats.norm(_loc, _scale)
        thresh_max_num_uhgg_genes = _dist1.ppf(0.999)
        thresh_min_num_uhgg_genes = _dist1.ppf(0.001)
        strain_list4 = idxwhere((x1 > thresh_min_num_uhgg_genes) & (x1 < thresh_max_num_uhgg_genes))

    species_strain_counts[species_id] = pd.Series(dict(
        num_species_samples=len(sample_list2),  # Species depth >1x
        num_inferred_strains=len(strain_list0),  # "Inferred" means >50% in at least one sample.
        num_strains_with_pure_sample=len(strain_list1),  # At least one "pure" sample
        num_strains_with_sufficient_depth=len(strain_list2),  # >1x depth across all samples  # NOTE: This includes xjin samples.
        num_complete_spgc=len(strain_list3),  # Species gene frac >90%
        num_passing_spgc=len(set(strain_list3) & set(strain_list4)),  # Not a gene count outlier.
    ))
    # One row per inferred strain (strain_meta columns are NaN for strains without a pure
    # sample) with one boolean flag per funnel stage.
    strain_details.append(
        pd.DataFrame(index=strain_list0)
        .join(strain_meta)
        .assign(
            species=species_id,
            strain=lambda x: x.index,
            has_inference=lambda x: x.index.isin(strain_list0),
            has_pure_sample=lambda x: x.index.isin(strain_list1),
            has_sufficient_depth=lambda x: x.index.isin(strain_list2),
            has_species_genes=lambda x: x.index.isin(strain_list3),
            has_reasonable_gene_count=lambda x: x.index.isin(strain_list4),
            has_species_genes_and_reasonable_gene_count=lambda x: x.has_species_genes & x.has_reasonable_gene_count,
        )
    )
# NOTE(review): pd.concat raises ValueError if every species was skipped (empty list).
species_strain_counts = pd.DataFrame(species_strain_counts).T
strain_details = pd.concat(strain_details).set_index(['species', 'strain'])

In [None]:
# Funnel totals summed over all species.
species_strain_counts.sum()

In [None]:
# Number of strains passing each funnel stage (sums of the boolean has_* flags).
strain_details[['has_inference', 'has_pure_sample', 'has_sufficient_depth', 'has_species_genes', 'has_reasonable_gene_count', 'has_species_genes_and_reasonable_gene_count']].sum()

In [None]:
# Funnel counts aggregated to "phylum;class", most final strains first.
species_strain_counts.groupby(species_taxonomy.apply(lambda x: x.p__ + ';' + x.c__, axis=1)).sum().sort_values('num_passing_spgc', ascending=False)

In [None]:
# Inspect the species assigned to class Coriobacteriia.
species_taxonomy[lambda x: x.c__ == 'c__Coriobacteriia']

In [None]:
# Same "phylum;class" aggregation as above, left in label order (sort commented out).
species_strain_counts.groupby(species_taxonomy.apply(lambda x: x.p__ + ';' + x.c__, axis=1)).sum()#.sort_values('num_passing_spgc', ascending=False)

In [None]:
# Funnel curves per phylum: each line traces one phylum's summed counts across the
# filtering stages (x axis = stage, symlog y axis).
d = (
    species_strain_counts
    .groupby(
        # species_taxonomy.apply(lambda x: x.p__ + ';' + x.c__, axis=1)
        species_taxonomy.apply(lambda x: x.p__, axis=1)
    )
    .sum()
    .sort_values('num_passing_spgc', ascending=False)
    .T  # rows = funnel stages, columns = phyla
)
_palette = lib.plot.construct_ordered_palette(d.columns, cm='rainbow')

fig, ax = plt.subplots(figsize=(3, 6))
for taxon in d.columns:
    ax.plot(d[taxon], c=_palette[taxon], label=taxon, lw=3, alpha=0.8)
# d.plot(kind='line')
ax.set_yscale('symlog', linthresh=1)
ax.set_ylim(1)
lib.plot.rotate_xticklabels()
ax.legend(bbox_to_anchor=(1, 1))

In [None]:
# Overlaid bar chart of funnel counts per phylum (excluding species-sample pairs), plus a
# separate legend-only figure mapping bar colors to readable stage names.
d = (
    species_strain_counts
    .groupby(
        # species_taxonomy.apply(lambda x: x.p__ + ';' + x.c__, axis=1)
        species_taxonomy.apply(lambda x: x.p__, axis=1)
    )
    .sum()
    .sort_values('num_passing_spgc', ascending=False)
    .assign(phylum=lambda x: x.index.to_series().str[len('p__'):])  # strip the "p__" prefix
    .set_index('phylum')
    .drop(columns=['num_species_samples'])
)
_palette = lib.plot.construct_ordered_palette(d.columns, cm='Spectral')

# Stages are drawn in column order, so each later (smaller) stage is painted over the earlier one.
fig = plt.figure(figsize=(9, 5))
for level in d.columns:
    plt.bar(d.index, d[level], color=_palette[level])

    
plt.yscale('log')
plt.ylabel('Count')
# plt.legend(bbox_to_anchor=(1, 1))
plt.ylim(0.1)
plt.yticks(np.logspace(0, 5, num=6), minor=False)
plt.yticks([], minor=True)
lib.plot.rotate_xticklabels(rotation=25)

# Legend-only figure: zero-height bars supply one legend handle per stage; axes are hidden.
fig, ax = plt.subplots()
_rename_levels = dict(
    num_species_samples='Species-Sample Pairs',  # Species depth >1x
    num_inferred_strains='Strains Inferred',  # "Inferred" means >50% in at least one sample.
    num_strains_with_pure_sample='+ Pure Samples (>95%)',  # At least one "pure" sample
    num_strains_with_sufficient_depth='+ Sufficient Depth (>1x)',  # >1x depth across all samples  # NOTE: This includes xjin samples.
    num_complete_spgc='+ High "Completeness" (>90%)',  # Species gene frac >90%
    num_passing_spgc='+ Appropriate Gene Count',  # Not a gene count outlier.
)
for level in d.columns:
    ax.bar(d.index, 0, color=_palette[level], label=_rename_levels[level])
ax.legend()
ax.xaxis.set_visible(False)
ax.yaxis.set_visible(False)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['bottom'].set_visible(False)
ax.spines['left'].set_visible(False)


# fig, ax = plt.subplots()
# for level, height in zip(_palette, reversed(np.linspace(0, 1, num=len(_palette) + 1))):
#     ax.bar(0, height, color=_palette[level])
#     ax.annotate(level, xy=(0, height), ha='center', va='top')

In [None]:
# Overlaid bar chart of all funnel stages per "phylum;class" (log y axis). This cell
# draws on pyplot's current figure; it does not create one explicitly.
d = (
    species_strain_counts
    .groupby(
        species_taxonomy.apply(lambda x: x.p__ + ';' + x.c__, axis=1)
        # species_taxonomy.apply(lambda x: x.p__, axis=1)
    )
    .sum()
    .sort_values('num_passing_spgc', ascending=False)
)
_palette = lib.plot.construct_ordered_palette(d.columns, cm='Spectral')

for level in d.columns:
    plt.bar(d.index, d[level], color=_palette[level], label=level)

# for level in reversed(d.columns):
#     plt.bar(d.index, 0, color=_palette[level], label=level)

plt.yscale('log')
plt.legend(bbox_to_anchor=(1, 1))
plt.ylim(0.1)
plt.yticks(np.logspace(0, 5, num=6), minor=False)
plt.yticks([], minor=True)
lib.plot.rotate_xticklabels()

In [None]:
# Top 20 "phylum;class" groups by final strain count (uses `d` from the previous cell).
d.sort_values('num_passing_spgc', ascending=False).head(20).index

In [None]:
# _species_list = ['100022', '102506', '102492']
# taxon_list = species_taxonomy.loc[_species_list].taxonomy_string

# Funnel counts per full taxonomy string, plotted for the 40 taxa with the most
# species-sample pairs (overlaid bars, symlog y axis).
d = (
    species_strain_counts.groupby(
        # species_taxonomy.apply(lambda x: x.p__ + ';' + x.c__, axis=1)
        species_taxonomy.taxonomy_string
    )
    .sum()
    .sort_values("num_passing_spgc", ascending=False)
)
taxon_list = d.sort_values("num_species_samples", ascending=False).head(40).index
_palette = lib.plot.construct_ordered_palette(d.columns, cm="rainbow")

fig, ax = plt.subplots(figsize=(15, 5))
for level in d.columns:
    plt.bar(taxon_list, d.loc[taxon_list, level], color=_palette[level], label=level)

# for level in reversed(d.columns):
#     plt.bar(d.index, 0, color=_palette[level], label=level)
plt.legend(bbox_to_anchor=(1, 1))

plt.yscale("symlog", linthresh=1, linscale=0.1)
plt.ylim(0.1)
plt.yticks(np.logspace(0, 5, num=6), minor=False)
plt.yticks([], minor=True)
lib.plot.rotate_xticklabels()

In [None]:
# Look up species whose species-level name contains "fragilis".
species_taxonomy[species_taxonomy.s__.str.contains('fragilis')]

In [None]:
# Funnel counts per species joined with the number of UHGG reference genomes of each
# Genome_type (columns from value_counts, e.g. Isolate / MAG) and the taxonomy.
d = (
    species_strain_counts
    .join(
        reference_meta.groupby('species').Genome_type.value_counts().unstack(fill_value=0).rename(str)
    )
    .join(species_taxonomy)
)

d.sort_values('num_species_samples', ascending=False).head(40)

In [None]:
# Compare final inferred strains with Isolate and MAG reference-genome counts for the
# 40 species with the most passing strains (grouped bars, log y axis).
d = (
    species_strain_counts
    .join(
        reference_meta.groupby('species').Genome_type.value_counts().unstack(fill_value=0).rename(str)
    )
    .join(species_taxonomy)
    .sort_values('num_passing_spgc')
    # .set_index('taxonomy_string')
    [['num_passing_spgc', 'Isolate', 'MAG']]
)

taxon_list = d.sort_values('num_passing_spgc', ascending=False).head(40).index
# _palette = lib.plot.construct_ordered_palette(d.columns, cm='rainbow')

fig, ax = plt.subplots(figsize=(15, 5))
d.loc[taxon_list].plot.bar(ax=ax)
ax.set_yscale('log')

In [None]:
# 2D histogram of strains: total depth (log-spaced bins) vs. species gene fraction.
# Square-root color scaling; empty bins (count < 1) are left blank.
xbins = np.logspace(-2, 4, num=50)
ybins = np.linspace(0, 1, num=50)

plt.hist2d(
    'sum_depth',
    'species_gene_frac',
    data=strain_details,
    bins=(xbins, ybins),
    norm=mpl.colors.PowerNorm(1/2),
    cmap='magma_r',
    cmin=1,
)
plt.colorbar()
plt.xscale('log')
plt.xlabel('Total Core Genome Depth')
plt.ylabel('Species Gene Fraction')

In [None]:
# Per-strain max vs. total depth (log-log), colored by 1 - species_gene_frac
# (symmetric-log color scale, linear within 0.1 of zero).
fig = plt.figure(figsize=(10, 10))
plt.scatter(
    "max_depth",
    "sum_depth",
    data=strain_details.assign(
        one_minus_species_gene_frac=lambda x: 1 - x.species_gene_frac
    ),
    s=1,
    c="one_minus_species_gene_frac",
    norm=mpl.colors.SymLogNorm(0.1),
)
plt.yscale("log")
plt.xscale("log")
plt.colorbar()

In [None]:
import sklearn as skl
from sklearn.svm import SVC

d = strain_details[['max_depth', 'sum_depth', 'has_reasonable_gene_count']].dropna()

# Per-strain max vs. total depth on log-log axes, colored by whether the strain
# passed the gene-count outlier filter.
fig, ax = plt.subplots(figsize=(10, 10))
ax.scatter('max_depth', 'sum_depth', data=d, s=1, c='has_reasonable_gene_count')
ax.set_yscale("log")
ax.set_xscale("log")

In [None]:
# Per-strain total depth vs. `num_sample` (a strain_meta column) on log-log axes, colored
# by 1 - species_gene_frac on a symmetric-log color scale.
fig = plt.figure(figsize=(10, 10))
plt.scatter(
    "sum_depth",
    "num_sample",
    data=strain_details.assign(
        one_minus_species_gene_frac=lambda x: 1 - x.species_gene_frac
    ),
    s=1,
    c="one_minus_species_gene_frac",
    norm=mpl.colors.SymLogNorm(0.1),
)
plt.yscale("log")
plt.xscale("log")
plt.colorbar()

In [None]:
# HMP2 metadata: metagenome libraries -> preparations -> stool samples -> subjects,
# joined into one library-indexed table (mgen_meta).
mgen_inpath="meta/hmp2/mgen.tsv"
preparation_inpath="meta/hmp2/preparation.tsv"
stool_inpath="meta/hmp2/stool.tsv"
subject_inpath="meta/hmp2/subject.tsv"

mgen = pd.read_table(mgen_inpath, index_col='library_id')
preparation = pd.read_table(preparation_inpath, index_col='preparation_id')
stool = pd.read_table(stool_inpath, index_col='stool_id')
subject = pd.read_table(subject_inpath, index_col='subject_id')

mgen_meta = mgen.join(preparation, on='preparation_id', lsuffix='_mgen', rsuffix='_preparation').join(stool, on='stool_id').join(subject, on='subject_id')
# Number of libraries per subject.
mgen_meta.subject_id.value_counts()

In [None]:
# Per-sample strain relative abundances for every species, restricted to HMP2 samples
# (those matched in mgen_meta). Columns are renamed "{species_id}_{strain}" so strains
# from different species do not collide after concatenation.
# The species-depth table is not needed here and is no longer read. It was previously
# loaded and then ignored, which could only crash the loop on a missing file.
all_frac_from_hmp2 = []

for species_id in tqdm(species_list2):
    # Strain relative abundance per sample from the SFacts fit; skip species without one.
    try:
        strain_frac = pd.read_table(
            f"data/group/xjin_ucfmt_hmp2/species/sp-{species_id}/r.proc.gtpro.sfacts-fit.comm.tsv",
            index_col=["sample", "strain"],
        ).community.unstack("strain")
    except FileNotFoundError as err:
        print(f"SFacts output missing for {species_id}.")
        print(err)
        continue

    # align_indexes is assumed to restrict both tables to their shared samples. As the
    # message below says, an AssertionError is treated as "no HMP2 samples".
    try:
        _frac, _meta = align_indexes(strain_frac, mgen_meta)
    except AssertionError:
        print(set(strain_frac.index) & set(mgen_meta.index))
        print(f"No HMP2 samples found in strain table for {species_id}.")
        continue

    all_frac_from_hmp2.append(_frac.rename(columns=lambda x: f"{species_id}_{x}"))

# Samples without a given species get NaN in that species' strain columns.
all_frac_from_hmp2 = pd.concat(all_frac_from_hmp2, axis=1)

In [None]:
# Strain presence per HMP2 sample (relative abundance > 0.25, missing treated as 0),
# aligned with its metadata.
strain_found, _meta = align_indexes(all_frac_from_hmp2.fillna(0) > 0.25, mgen_meta)

# Condensed (pdist-order) boolean vector: True for sample pairs from different subjects.
different_subject = sp.spatial.distance.pdist(
    _meta[['subject_id']],
    metric=lambda x, y: x != y,
).astype(bool)

In [None]:
# Jaccard distance between samples on strain presence, in the same pair order as different_subject.
jaccard_all_species = sp.spatial.distance.pdist(strain_found, metric='jaccard')

In [None]:
# Jaccard-distance densities: different-subject pairs first, then same-subject pairs.
bins = np.linspace(0, 1, num=100)
plt.hist(jaccard_all_species[different_subject], density=True, bins=bins)
plt.hist(jaccard_all_species[~different_subject], density=True, bins=bins)
None  # suppress the hist return-value repr

In [None]:
# Pairwise count of strains shared between samples. This is a similarity, not a
# distance, so squareform() leaves zeros on the diagonal. Fill the diagonal with each
# sample's own strain count ("strains shared with itself").
shared_strains_cdmat = sp.spatial.distance.pdist(strain_found, metric=lambda x, y: (x & y).sum())
shared_strains_dmat = pd.DataFrame(
    sp.spatial.distance.squareform(shared_strains_cdmat) + np.diag(strain_found.sum(1)),
    # Label with strain_found's own index: pdist() saw strain_found's rows (as produced
    # by align_indexes), whose order/subset need not match all_frac_from_hmp2.index.
    index=strain_found.index,
    columns=strain_found.index,
)

In [None]:
# Clustered heatmap of shared-strain counts with square-root color scaling.
d = shared_strains_dmat
sns.clustermap(d, norm=mpl.colors.PowerNorm(1/2))

In [None]:
# Number of sample pairs in the condensed vector.
different_subject.shape

In [None]:
bins = np.linspace(0, 200)

# x: True where a sample pair comes from different subjects; y: shared-strain count per pair.
x = pd.Series(different_subject)
y = pd.Series(shared_strains_cdmat)

# Broken y axis: the top panel shows densities 0.20-0.23, the bottom panel 0-0.03.
fig, (ax1, ax2) = plt.subplots(2, sharex=True)
fig.subplots_adjust(hspace=0.15)

for ax in [ax1, ax2]:
    ax.hist(y[x], alpha=0.7, bins=bins, density=True, label='Different Subject')
    ax.hist(y[~x], alpha=0.7, bins=bins, density=True, label='Same Subject')

ax1.set_ylim(bottom=0.2, top=0.23)
ax2.set_ylim(bottom=0, top=0.03)

ax1.spines.bottom.set_visible(False)
ax2.spines.top.set_visible(False)
ax1.xaxis.tick_top()
ax1.tick_params(labeltop=False)  # don't put tick labels at the top
ax2.xaxis.tick_bottom()

# Label each axis once. The x label belongs on the bottom panel. The y label spans both
# panels (plt.ylabel would only label the bottom one, and plt.xlabel would silently
# overwrite the bottom panel's x label).
ax2.set_xlabel('Shared Strains (count)')
fig.supylabel('Density')
ax1.legend()
None

In [None]:
# Empirical quantile curves of shared-strain counts (sorted values vs. evenly spaced
# ranks in [0, 1]) for different-subject and then same-subject pairs; x, y come from
# the previous cell.
plt.plot(np.linspace(0, 1, (x).sum()), y[x].sort_values().values)
plt.plot(np.linspace(0, 1, (~x).sum()), y[~x].sort_values().values)

In [None]:
y = pd.Series(shared_strains_cdmat)
x = pd.Series(different_subject)

# Shared-strain count quantiles: (same-subject pairs, different-subject pairs).
y[~x].quantile([0.05, 0.25, 0.5, 0.75, 0.95]), y[x].quantile([0.05, 0.25, 0.5, 0.75, 0.95])

In [None]:
# Most strains are found in only one subject.
# Histogram over strains of the number of subjects in which the strain exceeds 50%
# relative abundance in at least one sample.
bins = np.arange(1, 40)
plt.hist((all_frac_from_hmp2 > 0.5).groupby(mgen_meta.subject_id).any().sum(), bins=bins)

In [None]:
thresh = 0.5  # relative-abundance (fraction) threshold for "observed"

# How many strains are found above thresh (relative abundance) in any sample?
observed_strain_list = idxwhere((all_frac_from_hmp2 > thresh).any())
print(len(observed_strain_list))
# How many of these are found above this thresh in more than one sample:
multi_sample_observed_strain_list = idxwhere((all_frac_from_hmp2 > thresh).sum() > 1)
print(len(multi_sample_observed_strain_list))
# How many of these are found above this thresh in more than one subject?
multi_subject_observed_strain_list = idxwhere((all_frac_from_hmp2 > thresh).groupby(mgen_meta.subject_id).any().sum() > 1)
print(len(multi_subject_observed_strain_list))
# What is the distribution of number of subjects in multi-sample strains?

x = (all_frac_from_hmp2[multi_sample_observed_strain_list] > thresh).groupby(mgen_meta.subject_id).any().sum()

plt.hist(x, bins=np.linspace(0, 40, num=41), density=True)
# plt.yscale('log')
x.value_counts()

In [None]:
# Multi-sample strains ranked by number of subjects they were observed in.
x.sort_values(ascending=False)

In [None]:
# Per species, compare how often sample pairs share a strain within the same subject
# vs. across different subjects, via a 2x2 contingency table over sample pairs and a
# pseudocounted odds ratio.
# Notes:
#   - The species-depth table is not needed here and is no longer read. It was loaded
#     and ignored before, from a path that also disagreed with the other cells.
#   - The per-species subject-pair vector is kept in `_different_subject` so the loop no
#     longer overwrites the notebook-wide `different_subject` defined earlier.
species_strain_specificity = {}

for species_id in tqdm(species_list2):
    # Strain relative abundance per sample from the SFacts fit; skip species without one.
    try:
        strain_frac = pd.read_table(
            f"data/group/xjin_ucfmt_hmp2/species/sp-{species_id}/r.proc.gtpro.sfacts-fit.comm.tsv",
            index_col=["sample", "strain"],
        ).community.unstack("strain")
    except FileNotFoundError as err:
        print(f"SFacts output missing for {species_id}.")
        print(err)
        continue

    # align_indexes is assumed to restrict both tables to their shared samples. As the
    # message below says, an AssertionError is treated as "no HMP2 samples".
    try:
        _frac, _meta = align_indexes(strain_frac, mgen_meta)
    except AssertionError:
        print(set(strain_frac.index) & set(mgen_meta.index))
        print(f"No HMP2 samples found in strain table for {species_id}.")
        continue

    num_subjects = len(_meta.subject_id.unique())

    # Per sample pair: True when no strain is above 5% in both samples.
    share_no_strains = sp.spatial.distance.pdist(
        (_frac > 0.05), metric=lambda x, y: (x * y).sum() == 0
    )
    # Per sample pair: nonzero when the two samples come from different subjects.
    _different_subject = sp.spatial.distance.pdist(
        _meta[['subject_id']],
        metric=lambda x, y: x != y,
    )
    # Count pairs in each (subject, sharing) cell; cells with no pairs are filled with 0.
    contingency = (
        pd.DataFrame(
            dict(
                subject=pd.Series(_different_subject).astype(bool).map({True: 'different', False: 'same'}),
                share_strains=pd.Series(share_no_strains).astype(bool).map({True: 'none_shared', False: 'shared'}),
            )
        )
        .value_counts()
        .reindex(
            [('same', 'shared'), ('same', 'none_shared'), ('different', 'shared'), ('different', 'none_shared')], fill_value=0
        )
    )
    # A pseudocount of 1 per cell keeps the odds ratio finite when a cell is empty.
    contingency_pc = contingency.unstack() + 1
    odds_ratio_pc = (
        contingency_pc.loc['same', 'shared'] / contingency_pc.loc['same', 'none_shared']
    ) / (contingency_pc.loc['different', 'shared'] / contingency_pc.loc['different', 'none_shared'])
    species_strain_specificity[species_id] = pd.concat(
        [
            contingency,
            pd.Series(dict(odds_ratio_pc=odds_ratio_pc, num_samples=_frac.shape[0], num_subjects=num_subjects)),
        ]
    )

species_strain_specificity = pd.DataFrame(species_strain_specificity).T

In [None]:
# Contingency counts, pseudocounted odds ratio, and sample/subject counts per species.
species_strain_specificity

In [None]:
# log2 odds ratio of within-subject strain sharing vs. number of samples per species;
# the dashed line marks no enrichment (log2 OR = 0).
plt.scatter('num_samples', 'log_odds_ratio_pc', data=species_strain_specificity.assign(log_odds_ratio_pc=lambda x: np.log2(x.odds_ratio_pc)))
plt.xscale('log')
plt.axhline(0, lw=1, linestyle='--', color='k')

In [None]:
# The 20 species with the most HMP2 samples.
species_strain_specificity.sort_values('num_samples', ascending=False).head(20)

In [None]:
# The 20 species with the strongest within-subject strain sharing (highest odds ratio).
species_strain_specificity.sort_values('odds_ratio_pc', ascending=False).head(20)

In [None]:
# Absolute strain depth (strain fraction x species depth) per HMP2 sample, concatenated
# across species into one sample x strain table with "{species_id}_{strain}" columns.
all_depth_from_hmp2 = []

for species_id in tqdm(species_list2):
    # Species depth per sample, zero-filled for samples in sample_list without an entry.
    # Read from the same "gene99_new-v22-agg75" SPGC run as the species-depth and
    # strain_meta inputs of the strain-count loop. This cell previously read a
    # "gene99-v22-agg75" path that disagreed with them.
    # NOTE(review): confirm the older gene99-v22 run was not intended here.
    species_depth = pd.read_table(
        f"data/group/xjin_ucfmt_hmp2/species/sp-{species_id}/r.proc.gene99_new-v22-agg75.spgc_specgene-ref-t25-p95.species_depth.tsv",
        names=["sample", "depth"],
        index_col=["sample"],
    ).depth.reindex(sample_list, fill_value=0)

    # Strain relative abundance per sample from the SFacts fit; skip species without one.
    try:
        strain_frac = pd.read_table(
            f"data/group/xjin_ucfmt_hmp2/species/sp-{species_id}/r.proc.gtpro.sfacts-fit.comm.tsv",
            index_col=["sample", "strain"],
        ).community.unstack("strain")
    except FileNotFoundError as err:
        print(f"SFacts output missing for {species_id}.")
        print(err)
        continue

    # Restrict strain fractions, metadata and species depth to shared samples. As the
    # message below says, an AssertionError is treated as "no HMP2 samples".
    try:
        _frac, _meta, _species_depth = align_indexes(strain_frac, mgen_meta, species_depth)
    except AssertionError:
        print(set(strain_frac.index) & set(mgen_meta.index))
        print(f"No HMP2 samples found in strain table for {species_id}.")
        continue

    # Scale each sample's strain fractions by that sample's species depth.
    _depth = _frac.multiply(_species_depth, axis=0)

    all_depth_from_hmp2.append(_depth.rename(columns=lambda x: f"{species_id}_{x}"))

# Samples missing from a species' table get depth 0 for that species' strains.
all_depth_from_hmp2 = pd.concat(all_depth_from_hmp2, axis=1).fillna(0)

In [None]:
# Per-sample comparison of total species depth vs. total strain depth (the row sum of
# all_depth_from_hmp2); presumably a consistency check expected to lie near y = x.
# NOTE(review): `all_species_depth` is not defined anywhere in the visible notebook, so
# this cell raises NameError on a fresh kernel (Restart & Run All). It likely came from a
# deleted cell. TODO: restore its definition (a sample x species depth table?) or drop this cell.
x, y = lib.pandas_util.align_indexes(all_species_depth.sum(1), all_depth_from_hmp2.sum(1))

plt.scatter(x, y)

In [None]:
# Per-sample strain relative abundance across ALL species. Each row is divided by its
# total strain depth; rows with zero total depth become NaN.
_row_totals = all_depth_from_hmp2.sum(1)
all_strain_frac = all_depth_from_hmp2.div(_row_totals, axis=0)
del _row_totals

# Bray-Curtis IS a proper dissimilarity (zero self-distance), so squareform() gives the
# correct all-zero diagonal here.
shared_strains_bc_cdmat = sp.spatial.distance.pdist(all_strain_frac, metric='braycurtis')
shared_strains_bc_dmat = pd.DataFrame(
    sp.spatial.distance.squareform(shared_strains_bc_cdmat),
    index=all_strain_frac.index,
    columns=all_strain_frac.index,
)

In [None]:
# Clustered heatmap of Bray-Curtis dissimilarity between samples' strain compositions.
sns.clustermap(shared_strains_bc_dmat)

In [None]:
# Metadata in the same row order as all_strain_frac, so pair order matches shared_strains_bc_cdmat.
_meta = mgen_meta.loc[all_strain_frac.index]
_diff_subject = sp.spatial.distance.pdist(
    _meta[['subject_id']],
    metric=lambda x, y: x != y,
).astype(bool)

# Sanity check: the number of same-subject pairs equals sum over subjects of n*(n-1)/2.
assert _meta.subject_id.value_counts().map(lambda x: x * (x - 1) / 2).sum() == (~_diff_subject).sum()

# Bray-Curtis densities: different-subject pairs, then same-subject pairs.
bins = np.linspace(0, 1, num=11)
plt.hist(shared_strains_bc_cdmat[_diff_subject], density=True, bins=bins, alpha=0.6)
plt.hist(shared_strains_bc_cdmat[~_diff_subject], density=True, bins=bins, alpha=0.6)

None

In [None]:
# Bray-Curtis quantiles: (different-subject pairs, same-subject pairs).
q = [0.05, 0.25, 0.5, 0.75, 0.95]
(
    np.quantile(shared_strains_bc_cdmat[_diff_subject], q),
    np.quantile(shared_strains_bc_cdmat[~_diff_subject], q),
)

In [None]:
# Jaccard distance between samples on strain presence (relative abundance > 0.001),
# in condensed and square (sample x sample) form. The square matrix must be built from
# the Jaccard condensed vector, not the Bray-Curtis one.
shared_strains_jacc_cdmat = sp.spatial.distance.pdist(all_strain_frac > 0.001, metric='jaccard')
shared_strains_jacc_dmat = pd.DataFrame(
    sp.spatial.distance.squareform(shared_strains_jacc_cdmat),
    index=all_strain_frac.index,
    columns=all_strain_frac.index,
)

In [None]:
# Drop low-coverage samples (total strain depth < 50) before comparing strain sharing.
_drop_samples = idxwhere(all_depth_from_hmp2.sum(1) < 50)

filt_all_strain_frac, _meta = align_indexes(all_depth_from_hmp2.divide(all_depth_from_hmp2.sum(1), axis=0).drop(_drop_samples), mgen_meta)

# A strain counts as "present" in a sample above this relative abundance. The same
# threshold now defines both the pairwise shared counts and the diagonal (own) counts, so
# a sample can never appear to share more strains with another than it has itself.
# (The diagonal previously used 0.001 while the shared counts used 0.0001.)
# NOTE(review): confirm 0.0001 rather than 0.001 is the intended presence threshold.
_presence_thresh = 0.0001
_strain_present = filt_all_strain_frac > _presence_thresh

# Not a distance, so squareform will be wrong along the diagonal; fill the diagonal with
# each sample's own strain count.
filt_shared_strains_cdmat = sp.spatial.distance.pdist(_strain_present, metric=lambda x, y: (x & y).sum())
filt_shared_strains_dmat = pd.DataFrame(
    sp.spatial.distance.squareform(filt_shared_strains_cdmat) + np.diag(_strain_present.sum(1)),
    index=filt_all_strain_frac.index,
    columns=filt_all_strain_frac.index,
)
# Condensed boolean vector: True for sample pairs from different subjects.
different_subject_cdmat = sp.spatial.distance.pdist(
    _meta[['subject_id']],
    metric=lambda x, y: x != y,
).astype(bool)

In [None]:
# Broken-y-axis histogram of shared-strain counts for the depth-filtered samples:
# the top panel shows densities 0.80-0.94, the bottom panel 0-0.1.
bins = np.linspace(0, 80, num=80)

fig, (ax1, ax2) = plt.subplots(2, sharex=True)
fig.subplots_adjust(hspace=0.15)

for ax in [ax1, ax2]:
    ax.hist(filt_shared_strains_cdmat[different_subject_cdmat], density=True, alpha=0.6, bins=bins, label='Different Subject')
    ax.hist(filt_shared_strains_cdmat[~different_subject_cdmat], density=True, alpha=0.6, bins=bins, label='Same Subject')

ax1.set_ylim(bottom=0.8, top=0.94)
ax2.set_ylim(bottom=0, top=0.1)

# Hide the facing spines and put ticks on the outer edges to draw the axis break.
ax1.spines.bottom.set_visible(False)
ax2.spines.top.set_visible(False)
ax1.xaxis.tick_top()
ax1.tick_params(labeltop=False)  # don't put tick labels at the top
ax2.xaxis.tick_bottom()

ax2.set_xlabel('Shared Strains')
ax1.legend()

None

In [None]:
# Peek at the specificity table's column layout (tuple-keyed contingency columns).
species_strain_specificity.head(2)

In [None]:
# Per species: fraction of same-subject pairs sharing a strain (x) vs. fraction of
# different-subject pairs sharing a strain (y). Marker area = 10 x number of subjects;
# color = number of samples.
d0 = species_strain_specificity.assign(
    frac_same_subject_with_shared_strains=lambda x: x[('same', 'shared')] / (x[('same', 'shared')] + x[('same', 'none_shared')]),
    frac_diff_subject_with_shared_strains=lambda x: x[('different', 'shared')] / (x[('different', 'shared')] + x[('different', 'none_shared')]),
    num_subjects_x_10=lambda x: x.num_subjects * 10,
)

# Keep species with >2 subjects and defined fractions; sorting by subject count draws
# species with more subjects last (on top).
d1 = d0.dropna(subset=["frac_same_subject_with_shared_strains", "frac_diff_subject_with_shared_strains"])[lambda x: x.num_subjects > 2].sort_values('num_subjects', ascending=True)


fig, ax = plt.subplots(figsize=(10, 10))

cbar_artist = ax.scatter('frac_same_subject_with_shared_strains', 'frac_diff_subject_with_shared_strains', c='num_samples', s='num_subjects_x_10', data=d1, alpha=0.4, label='__nolegend__')
# Empty scatters provide size-legend entries on the same 10 x n area scale.
for n in [2, 4, 8, 16, 32]:
    ax.scatter([], [], s=n * 10, c='k', alpha=0.4, label=n)
ax.legend(title='Distinct Subjects')

# Colorbar (number of samples) in its own axes to the right; drawn opaque despite alpha=0.4 points.
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.15, 0.025, 0.67])
cbar = fig.colorbar(cbar_artist, cax=cbar_ax)#, label="count strains")
cbar.solids.set_alpha(1.0)

ax.set_xlim(-0.05, 1.05)
ax.set_ylim(-0.05, 1.05)
ax.set_aspect(1.0)
ax.set_xticks([0, 0.5, 1.0])
ax.set_yticks([0, 0.5, 1.0])

ax.set_xlabel('Same Subject Strain Sharing')
ax.set_ylabel('Different Subject Strain Sharing')