In [None]:
%load_ext autoreload

In [None]:
import os
# Work from the parent of the notebook's directory: every relative path used
# below ('meta/...', 'data/...', 'ref/...') is resolved against it.
# A bare os.chdir('..') climbs one level further each time this cell is
# re-run, so remember the starting directory on the first run and always
# resolve the parent relative to that.
if '_NOTEBOOK_START_DIR' not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(_NOTEBOOK_START_DIR, os.pardir))
os.path.realpath(os.path.curdir)

In [None]:
import pandas as pd
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import seaborn as sns
import xarray as xr
from lib.pandas_util import idxwhere, repeated
import matplotlib as mpl
import lib.plot
import statsmodels as sm
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm

In [None]:
import sfacts as sf

In [None]:
# Load the metadata tables under meta/hmp2, each keyed by its own id column.
mgen = pd.read_csv('meta/hmp2/mgen.tsv', sep='\t', index_col='library_id')
preparation = pd.read_csv('meta/hmp2/preparation.tsv', sep='\t', index_col='preparation_id')
stool = pd.read_csv('meta/hmp2/stool.tsv', sep='\t', index_col='stool_id')
visit = pd.read_csv('meta/hmp2/visit.tsv', sep='\t', index_col='visit_id')
subject = pd.read_csv('meta/hmp2/subject.tsv', sep='\t', index_col='subject_id')

# Follow the foreign keys library -> preparation -> stool -> visit -> subject
# so that every library row carries the metadata of everything upstream.
# Columns of `visit` that clash with ones already present get a '_' suffix.
mgen_meta = mgen.join(preparation.drop(columns=['library_type']), on='preparation_id')
mgen_meta = mgen_meta.join(stool, on='stool_id')
mgen_meta = mgen_meta.join(visit, on='visit_id', rsuffix='_')
mgen_meta = mgen_meta.join(subject, on='subject_id')

# Every library must resolve to a subject after the joins.
assert not mgen_meta['subject_id'].isna().any()

In [None]:
# Build one metadata row per (subject, week). Visits missing a subject or a
# week number are dropped; where several visits share a subject-week, the one
# with the most non-null fields is kept.
_subject_week = (
    visit
    .join(subject, on='subject_id')
    .reset_index()
    .dropna(subset=['subject_id', 'week_number'])
    .groupby(['subject_id', 'week_number'])
    # sort_values() is ascending, so the last index label belongs to the row
    # with the highest non-null count. NOTE(review): ties are broken by the
    # (unstable) sort order - confirm that is acceptable.
    .apply(lambda d: d.loc[d.notna().sum(1).sort_values().index[-1]])
    # Row key of the form '<subject_id>_<week_number as int>'.
    .assign(subject_week_id=lambda x: x.subject_id + '_' + x.week_number.astype(int).astype(str))
    .set_index('subject_week_id')
    # Attach the mean fecal calprotectin over the stool records of the chosen visit.
    .join(stool.groupby('visit_id').fecal_calprotectin.mean(), on='visit_id')
    .sort_values(['subject_id', 'week_number'])
)

# Map each library that has a week number to its subject-week key, using the
# same '<subject_id>_<week>' format as the index built above.
mgen_to_subject_week = mgen_meta.dropna(subset=['week_number']).apply(lambda x: x.subject_id + '_' + str(int(x.week_number)), axis=1).rename('subject_week_id')
mgen_to_subject_week
#.groupby(['subject_id', 'week_number']).visit_id.count().sort_values(ascending=False)

In [None]:
# Long (sample, species) depth table -> wide samples x species matrix, then
# libraries are pooled (summed) into their subject-week. Column labels are
# converted to strings so species can be addressed as e.g. '102506'.
species_depth = (
    pd.read_csv(
        'data/hmp2.a.r.proc.gtpro.species_depth.tsv',
        sep='\t',
        index_col=['sample', 'species_id'],
    )
    .squeeze()
    .unstack(level='species_id', fill_value=0)
    .groupby(by=mgen_to_subject_week)
    .sum()
    .rename(columns=str)
)

In [None]:
# Flag the subject-weeks that have at least one pooled metagenome row.
subject_week = _subject_week.assign(
    has_mgen=_subject_week.index.isin(species_depth.index),
)

In [None]:
# Per-strain depth estimates, reshaped to samples x strains and pooled by
# subject-week exactly like the species table.
strain_depth = (
    pd.read_csv(
        'data_temp/hmp2.a.r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts18-s75-seed0.strain_depth.tsv',
        sep='\t',
        index_col=['sample', 'strain'],
    )
    .squeeze()
    .unstack(level='strain', fill_value=0)
    .groupby(by=mgen_to_subject_week)
    .sum()
)

# Sanity check: per-sample difference between total strain depth and total
# species depth.
plt.hist(strain_depth.sum(axis=1) - species_depth.sum(axis=1), bins=50)
None

In [None]:
# Normalise each sample (row) by its total depth to get relative abundances.
species_rabund = species_depth.div(species_depth.sum(axis=1), axis='index')
strain_rabund = strain_depth.div(strain_depth.sum(axis=1), axis='index')

In [None]:
# Header-less taxonomy table; keep only the taxonomy string, keyed by the
# species id as a string, plus a temporary ';'-split version of it.
species_taxonomy = (
    pd.read_csv(
        'ref/gtpro/species_taxonomy_ext.tsv',
        sep='\t',
        names=['genome_id', 'species_id', 'taxonomy_string'],
    )
    .assign(species_id=lambda x: x.species_id.astype(str))
    .set_index('species_id')
    [['taxonomy_string']]
    .assign(taxonomy_split=lambda x: x.taxonomy_string.str.split(';'))
)

# One column per rank, taken from positions 1..6 of the split string
# (position 0 is skipped).
for level_number, level_name in enumerate(('p__', 'c__', 'o__', 'f__', 'g__', 's__'), start=1):
    species_taxonomy = species_taxonomy.assign(
        **{level_name: species_taxonomy.taxonomy_split.apply(lambda x: x[level_number])}
    )
species_taxonomy = species_taxonomy.drop(columns=['taxonomy_split'])

# Strain labels look like '<species_id>-<n>'; the part before the first '-'
# links each strain to its species' taxonomy.
strain_taxonomy = (
    strain_depth.columns.to_series()
    .str.split('-').str[0]
    .to_frame(name='species_id')
    .join(species_taxonomy, on='species_id')
)

# From here on `species_taxonomy` only covers species that have a strain column.
species_taxonomy = strain_taxonomy.drop_duplicates(subset=['species_id']).set_index('species_id')

## E. coli MIDAS Genes

In [None]:
# Pangenome cluster table for species 102506, keyed by the 99% centroid id.
gene_clusters = pd.read_csv(
    'ref_temp/midasdb_uhgg/pangenomes/102506/cluster_info.txt',
    sep='\t',
    index_col='centroid_99',
)

In [None]:
# FIXME: Install python-lz4 into the sfacts module so I can open this file as 'data_temp/sp-102506.hmp2.a.r.proc.midas_merge/genes/102506/102506.genes_depth.tsv.lz4' instead?
# The file is genes x libraries. Transpose first and pool libraries into
# subject-weeks along the rows: `groupby(..., axis='columns')` is deprecated
# as of pandas 2.1, and grouping the transposed frame yields the same
# subject-week x gene matrix.
genes_depth = (
    pd.read_table('data_temp/sp-102506.hmp2.a.r.proc.midas_genes.depth.tsv', index_col='gene_id')
    .T
    .groupby(mgen_to_subject_week)
    .sum()
)

In [None]:
# Sum the depth of all genes that share a 75% cluster. The grouping is done
# on the transposed frame because `groupby(..., axis='columns')` is deprecated
# as of pandas 2.1; transposing back restores samples x clusters.
genes_depth_75 = genes_depth.T.groupby(gene_clusters.centroid_75).sum().T

In [None]:
# Per-sample median depth over the gene clusters with non-zero depth.
# (Despite the name this is a median; the variable is reassigned further
# down the notebook with a trimmed-mean estimate.)
mean_depth_present_genes = genes_depth_75.apply(
    lambda sample_depths: sample_depths[sample_depths > 0].median(),
    axis=1,
)

In [None]:
# Samples ordered by increasing depth of species 102506.
species_depth.loc[:, '102506'].sort_values(ascending=True)

In [None]:
# Distribution of log10 gene-cluster depth in sample C3009_10 (the 1e-3
# pseudocount keeps zero-depth clusters finite), with the species-level depth
# of 102506 in the same sample drawn as a vertical line.
plt.hist(x=np.log10(genes_depth_75.loc['C3009_10'] + 1e-3), bins=100)
plt.axvline(x=np.log10(species_depth.loc['C3009_10', '102506']))
plt.yscale('log')

In [None]:
def trim_gmean_nonzero(x, proportiontocut, axis=0):
    """Trimmed geometric mean of the strictly positive entries of `x`.

    Entries that are not greater than zero are discarded, the rest are
    log-transformed, a trimmed mean is taken in log space and the result is
    exponentiated.

    Parameters
    ----------
    x : pandas.Series or numpy.ndarray
        Values to summarise. Boolean-mask indexing is used, so both pandas
        objects and NumPy arrays are accepted (the previous callable-indexer
        form only worked on pandas objects).
    proportiontocut : float
        Fraction to trim from each tail, passed to `scipy.stats.trim_mean`.
    axis : int, optional
        Axis passed to `scipy.stats.trim_mean`.

    Returns
    -------
    float
        exp(trimmed mean of log(positive values)).
    """
    positive = x[x > 0]
    return np.exp(sp.stats.trim_mean(np.log(positive), proportiontocut, axis=axis))

In [None]:
def trim_mean_top_n(x, n, proportiontocut, axis=0):
    """Trimmed mean of the `n` largest values of `x`.

    `x` is sorted ascending, its last `n` entries are kept, and
    `scipy.stats.trim_mean` is applied to them with the given
    `proportiontocut` and `axis`.
    """
    ascending = np.sort(x)
    largest = ascending[-n:]
    return sp.stats.trim_mean(largest, proportiontocut, axis=axis)

In [None]:
# Three alternative per-sample summaries of gene-cluster depth:
#   B - median of the non-zero clusters,
#   A - 10%-trimmed geometric mean of the non-zero clusters,
#   (unsuffixed) - 30%-trimmed mean of the 2000 deepest clusters.
mean_depth_present_genesB = genes_depth_75.apply(
    lambda sample_depths: sample_depths[sample_depths > 0].median(),
    axis=1,
)
mean_depth_present_genesA = genes_depth_75.apply(
    trim_gmean_nonzero,
    axis=1,
    proportiontocut=0.1,
)
mean_depth_present_genes = genes_depth_75.apply(
    trim_mean_top_n,
    axis=1,
    n=2000,
    proportiontocut=0.3,
)

In [None]:
# Species-level depth of 102506 against the gene-based depth estimate, on
# symlog axes with a reference diagonal from (0, 0) to (100, 100).
d = pd.DataFrame({
    'species': species_depth['102506'],
    'genes': mean_depth_present_genes,
})
plt.scatter(x='species', y='genes', data=d)
plt.plot([0, 1e2], [0, 1e2])
plt.xscale('symlog')
plt.yscale('symlog')

## E. coli Strains

In [None]:
# Keep only the strain columns assigned to species 102506.
ecoli_strain_depth = strain_depth.loc[:, strain_taxonomy['species_id'].eq('102506')]

In [None]:
# Clustered heatmap of strain depth; the power-law colour scale (gamma 0.1)
# compresses the high end so low-depth strains remain visible.
sns.clustermap(
    data=ecoli_strain_depth,
    norm=mpl.colors.PowerNorm(gamma=1/10),
)

In [None]:
# Number of samples in which each strain exceeds a depth of 1e-3, most
# prevalent first; the 20 most prevalent strain labels are kept.
d = ecoli_strain_depth.gt(1e-3).sum(axis=0).sort_values(ascending=False)
top_strains = d.index[:20]

In [None]:
# First five entries of the per-strain prevalence counts.
d.head(n=5)

### Full-species

In [None]:
# Samples where species 102506 is clearly present (depth > 1) and samples
# where it is effectively absent (depth < 1e-5); everything in between is
# left out of both sets.
ecoli_samples = idxwhere(species_depth['102506'].gt(1.0))
no_ecoli_samples = idxwhere(species_depth['102506'].lt(1e-5))
(len(ecoli_samples), len(no_ecoli_samples))

In [None]:
# Species-level depth of 102506 against the gene-based depth estimate, on
# symlog axes with a reference diagonal from (0, 0) to (100, 100).
d = pd.DataFrame({
    'species': species_depth['102506'],
    'genes': mean_depth_present_genes,
})
plt.scatter(x='species', y='genes', data=d)
plt.plot([0, 1e2], [0, 1e2])
plt.xscale('symlog')
plt.yscale('symlog')

In [None]:
# Cosine distance between every species' depth profile and every 75% gene
# cluster's depth profile, both taken across the same ordered set of samples.
x = species_depth
y = genes_depth_75.loc[species_depth.index]
gene_species_cos_dist = pd.DataFrame(
    data=sp.spatial.distance.cdist(x.to_numpy().T, y.to_numpy().T, metric='cosine'),
    index=x.columns,
    columns=y.columns,
)

In [None]:
# Gene clusters whose depth profile lies within cosine distance 0.1 of
# species 102506 are treated as species-level hits.
species_thresh = 0.1
species_gene_hits = idxwhere(gene_species_cos_dist.loc['102506'].lt(species_thresh))
len(species_gene_hits)

In [None]:
# For each gene cluster find the species with the closest depth profile, then
# show the five species that are the closest match most often.
(
    gene_species_cos_dist
    .idxmin(axis=0)
    .value_counts()
    .head(n=5)
)

In [None]:
# Overlaid histograms of cosine distance to species 102506:
#   1. all gene clusters,
#   2. clusters for which 102506 is the closest species,
#   3. clusters carrying a marker_id annotation.
# The dashed line marks the hit threshold.
bins = np.linspace(0, 1, 101)
plt.hist(x=gene_species_cos_dist.loc['102506'], bins=bins)
plt.hist(
    x=gene_species_cos_dist.loc['102506', gene_species_cos_dist.idxmin() == '102506'],
    bins=bins,
)
plt.hist(
    x=gene_species_cos_dist.loc['102506'].reindex(gene_clusters.marker_id.dropna().index, axis='columns').dropna(),
    bins=bins,
)
plt.axvline(x=species_thresh, lw=1, linestyle='--', color='k')

plt.yscale('log')
None

In [None]:
# Marker annotations found among the species-level gene hits, with counts.
gene_clusters.loc[species_gene_hits, 'marker_id'].dropna().value_counts()

In [None]:
# Depth of the species-level gene hits in the samples that contain 102506,
# clustered by cosine distance on a symmetric-log colour scale.
sns.clustermap(
    data=genes_depth_75.loc[ecoli_samples, species_gene_hits],
    norm=mpl.colors.SymLogNorm(linthresh=0.5),
    metric='cosine',
)

In [None]:
# Per-sample 10%-trimmed mean depth over the species-level gene hits, followed
# by a look at the ten deepest samples.
mean_depth_species_genes = genes_depth_75.loc[:, species_gene_hits].apply(
    sp.stats.trim_mean,
    axis=1,
    proportiontocut=0.1,
)
mean_depth_species_genes.sort_values(ascending=False).head(n=10)

In [None]:
# Species-level depth of 102506 against the trimmed-mean depth of its gene
# hits, on symlog axes with a reference diagonal.
d = pd.DataFrame({
    'species': species_depth['102506'],
    'genes': mean_depth_species_genes,
})
plt.scatter(x='species', y='genes', data=d)
plt.plot([0, 1e2], [0, 1e2])
plt.xscale('symlog')
plt.yscale('symlog')

In [None]:
# Same comparison, but the gene-side estimate is the plain per-sample mean
# over only those species hits that carry a marker_id annotation.
d = pd.DataFrame({
    'species': species_depth['102506'],
    'genes': genes_depth_75[
        gene_clusters.loc[species_gene_hits].marker_id.dropna().index
    ].mean(axis=1),
})
plt.scatter(x='species', y='genes', data=d)
plt.plot([0, 1e2], [0, 1e2])
plt.xscale('symlog')
plt.yscale('symlog')

In [None]:
# Express gene depth relative to the per-sample species-gene depth estimate,
# once for the 75% clusters and once for the unclustered gene table.
# TODO: Should this be the 75% clustering???
depth_ratio_75 = genes_depth_75.div(mean_depth_species_genes, axis='index')
depth_ratio = genes_depth.div(mean_depth_species_genes, axis='index')

In [None]:
# log2 depth ratio (pseudocount 1e-2) of every gene cluster in two samples
# plotted against each other, coloured by species-hit membership.
d = np.log2(depth_ratio_75.loc[ecoli_samples] + 1e-2).dropna().T
sns.jointplot(
    data=d.assign(hits=lambda x: x.index.isin(species_gene_hits)),
    x='C3009_10',
    y='M2034_42',
    hue='hits',
    alpha=0.5,
    s=2,
)
None

### Strain 102506-2

In [None]:
# Focal strain for this section: the first entry of `top_strains`, i.e. the
# strain detected (depth > 1e-3) in the most samples.
# NOTE(review): the section heading names strain 102506-2; nothing here
# enforces that - confirm the displayed value matches.
strain_id = top_strains[0]
strain_id

In [None]:
# Samples in which a single strain accounts for more than 95% of the total
# 102506 strain depth. `axis` is passed by keyword: DataFrame.any() no longer
# accepts it positionally in pandas >= 2.0.
pure_samples = idxwhere(
    (ecoli_strain_depth.apply(lambda row: row / row.sum(), axis=1) > 0.95).any(axis=1)
)

# Of those, the samples that contain 102506 and whose dominant strain is the
# focal one. sorted() replaces list(set(...)) so the sample order (and thus
# every downstream selection) is the same on every run instead of depending
# on set iteration order.
pure_samples_with_strain = sorted(
    set(ecoli_samples)
    & set(idxwhere(ecoli_strain_depth.idxmax(axis=1) == strain_id))
    & set(pure_samples)
)

# Focal set: pure focal-strain samples plus the samples without 102506.
focal_samples = sorted(set(pure_samples_with_strain) | set(no_ecoli_samples))
len(pure_samples_with_strain), len(focal_samples)

In [None]:
# Species depth against gene-hit depth, restricted to the focal samples, on
# symlog axes with a reference diagonal.
d = pd.DataFrame({
    'species': species_depth['102506'],
    'genes': mean_depth_species_genes,
})
plt.scatter(x='species', y='genes', data=d.loc[focal_samples])
plt.plot([0, 1e2], [0, 1e2])
plt.xscale('symlog')
plt.yscale('symlog')

In [None]:
# Cosine distance between species depth profiles and individual (unclustered)
# gene depth profiles, computed across the focal samples only.
x = species_depth.loc[focal_samples]
y = genes_depth.loc[focal_samples]
gene_strain_cos_dist = pd.DataFrame(
    data=sp.spatial.distance.cdist(x.to_numpy().T, y.to_numpy().T, metric='cosine'),
    index=x.columns,
    columns=y.columns,
)

In [None]:
# Genes within cosine distance 0.2 of species 102506 over the focal samples
# are treated as strain-level hits.
strain_thresh = 0.2
strain_gene_hits = idxwhere(gene_strain_cos_dist.loc['102506'].lt(strain_thresh))
len(strain_gene_hits)

In [None]:
# For each gene find the species with the closest depth profile over the focal
# samples, then show the five species that win most often.
(
    gene_strain_cos_dist
    .idxmin(axis=0)
    .value_counts()
    .head(n=5)
)

In [None]:
# Overlaid histograms of cosine distance to species 102506 over the focal
# samples:
#   1. all genes,
#   2. genes for which 102506 is the closest species,
#   3. genes carrying a marker_id annotation.
bins = np.linspace(0, 1, 101)
plt.hist(x=gene_strain_cos_dist.loc['102506'], bins=bins)
plt.hist(
    x=gene_strain_cos_dist.loc['102506', gene_strain_cos_dist.idxmin() == '102506'],
    bins=bins,
)
plt.hist(
    x=gene_strain_cos_dist.loc['102506'].reindex(gene_clusters.marker_id.dropna().index, axis='columns').dropna(),
    bins=bins,
)

plt.yscale('log')
None

In [None]:
# Marker annotations found among the strain-level gene hits, with counts.
gene_clusters.loc[strain_gene_hits, 'marker_id'].dropna().value_counts()

In [None]:
# Depth of the strain-level gene hits across all samples containing 102506.
sns.clustermap(
    data=genes_depth.loc[ecoli_samples, strain_gene_hits],
    norm=mpl.colors.SymLogNorm(linthresh=1e-1),
    metric='cosine',
)

In [None]:
# Same heatmap, limited to the samples that are pure for the focal strain.
sns.clustermap(
    data=genes_depth.loc[pure_samples_with_strain, strain_gene_hits],
    norm=mpl.colors.SymLogNorm(linthresh=1e-1),
    metric='cosine',
)

In [None]:
# Compare two per-gene summaries of the depth ratio on a log2 scale:
#   d0 - `strain_mean_depth_ratio` (plotted as 'wmean_depth_ratio'),
#   d1 - the unweighted mean depth ratio over the pure focal-strain samples.
# NOTE(review): `strain_mean_depth_ratio` is not assigned anywhere in the
# cells visible in this notebook, so this cell fails on a fresh kernel
# (Restart & Run All). The column name suggests a weighted mean of
# `depth_ratio` - restore its defining cell before this one.
d0 = np.log2(strain_mean_depth_ratio)
d1 = np.log2(depth_ratio.loc[pure_samples_with_strain].mean())

# Only used by the commented-out histograms below.
bins = np.linspace(-10, 5, num=51)

# One row per gene, flagged by membership in the strain-level hits.
d2 = pd.DataFrame(dict(mean_depth_ratio=d1, wmean_depth_ratio=d0)).assign(hit=lambda x: x.index.isin(strain_gene_hits))

g = sns.jointplot(data=d2, x='wmean_depth_ratio', y='mean_depth_ratio', hue='hit')
# Identity line: points on it have equal values for both summaries.
g.ax_joint.plot([-6, 6], [-6, 6], color='k')
# plt.hist(d0, bins=bins)
# plt.hist(d0.reindex(idxwhere(gene_clusters.marker_id.notna())).dropna(), bins=bins)

# plt.hist(d1, bins=bins)
# plt.hist(d1.reindex(idxwhere(gene_clusters.marker_id.notna())).dropna(), bins=bins)
# plt.yscale('log')
None

In [None]:
# For every 75% cluster, pick the member gene whose depth profile over the
# focal samples is closest (smallest cosine distance) to species 102506.
# Undefined distances are treated as 1.0 so they never win unless every
# member of the cluster is undefined.
# NOTE(review): this relies on the gene ids in `genes_depth` matching the
# centroid_99 index of `gene_clusters` - confirm.
best_strain_gene_match = gene_strain_cos_dist.loc['102506'].fillna(1.0).groupby(gene_clusters.centroid_75).idxmin()
# The mapping must be one-to-one: each cluster yields a distinct gene.
assert best_strain_gene_match.is_unique and best_strain_gene_match.index.is_unique
# Invert it: chosen gene (as 'centroid_99') -> its 75% cluster id.
strain_to_species_gene = best_strain_gene_match.to_frame('centroid_99').reset_index().set_index('centroid_99', drop=False)['centroid_75']
strain_to_species_gene

In [None]:
# Per-gene summary table combining the depth ratio, marker annotation and the
# two cosine distances (gene-level over the focal samples, and cluster-level
# over all samples via the representative-gene mapping).
# NOTE(review): `strain_mean_depth_ratio` is not assigned anywhere in the
# cells visible in this notebook; this cell depends on hidden kernel state.
strain_gene_info = (
    strain_mean_depth_ratio
    .to_frame('depth_ratio')
    .assign(
        marker_id=gene_clusters.marker_id,
        # Missing distances are treated as maximally distant.
        strain_cos=gene_strain_cos_dist.loc['102506'].fillna(1.0),
        # species_cos=gene_species_cos_dist.loc['102506'].fillna(1.0),
        # Only the representative gene of each 75% cluster gets a value here;
        # all other genes are left NaN by index alignment.
        species_gene_id=strain_to_species_gene,
    )
    # Look up the cluster-level distance through the 75% cluster id.
    .join(gene_species_cos_dist.loc['102506'].to_frame('species_cos'), on='species_gene_id')
    .assign(species_cos=lambda x: x.species_cos.fillna(1.0))
    # A hit needs a depth ratio above 0.2 and a cosine distance below 0.2 at
    # either the gene level or the cluster level.
    .assign(hit=lambda x: (x.depth_ratio > 0.2) & ((x.strain_cos < 0.2) | (x.species_cos < 0.2)))
)

In [None]:
# Genes passing the hit criteria (depth ratio > 0.2 and either cosine
# distance < 0.2), ordered so the highest depth ratios are drawn last.
d = (
    strain_gene_info
    .loc[lambda x: x.depth_ratio.gt(0.2) & (x.strain_cos.lt(0.2) | x.species_cos.lt(0.2))]
    .sort_values(by='depth_ratio')
)

# Gene-level vs cluster-level cosine distance on log axes, coloured by depth
# ratio on a log colour scale.
plt.scatter(
    x='strain_cos',
    y='species_cos',
    data=d,
    c='depth_ratio',
    s=5,
    norm=mpl.colors.LogNorm(),
)
plt.xscale('log')
plt.yscale('log')
plt.colorbar()

# Number of distinct marker annotations among the hits, then their counts.
print(d.marker_id.nunique())
d.marker_id.value_counts()

In [None]:
# The 20 pure focal-strain samples with the highest species-gene depth.
(
    mean_depth_species_genes
    .loc[pure_samples_with_strain]
    .sort_values(ascending=False)
    .head(n=20)
)

In [None]:
# log2 depth ratio (pseudocount 1e-2) of every gene in two samples, joined with
# the per-gene summary so points can be coloured by hit status. Rows are
# sorted by `hit` so that hits are drawn on top of non-hits.
x, y = 'M2064_53', 'H4040_22'
d = (
    np.log2(depth_ratio.loc[[x, y]] + 1e-2)
    .dropna()
    .T
    .join(strain_gene_info)
)
sns.jointplot(
    data=d.sort_values('hit'),
    x=x,
    y=y,
    hue='hit',
    s=3,
    marginal_kws={'common_norm': False},
)

In [None]:
# Depth of the final hit genes in the pure focal-strain samples.
sns.clustermap(
    data=genes_depth.loc[pure_samples_with_strain, idxwhere(strain_gene_info.hit)],
    norm=mpl.colors.SymLogNorm(linthresh=1e-1),
    metric='cosine',
)

In [None]:
# Each sample's depth ratio divided by the per-gene reference depth ratio,
# for the hit genes in samples containing 102506; rows with missing values
# are dropped.
d = (
    depth_ratio
    .div(strain_gene_info.depth_ratio, axis='columns')
    .loc[ecoli_samples, strain_gene_info.hit]
    .dropna()
)

# Row colours: red for pure focal-strain samples, blue for the rest.
sns.clustermap(
    data=d,
    norm=mpl.colors.SymLogNorm(linthresh=1e-1, vmin=1e-2, vmax=1e2),
    metric='cosine',
    row_colors=d.index.to_series().isin(pure_samples_with_strain).map({False: 'blue', True: 'red'}),
)

In [None]:
# Which libraries were pooled into subject-week C3023_36?
idxwhere(mgen_to_subject_week.eq('C3023_36'))

In [None]:
# Per-gene output of a single library from the midas_output directory, keyed
# by gene id, for comparison against the merged table.
C3023_36_rerun = pd.read_csv(
    'data_temp/sp-102506.hmp2.a.r.proc.midas_output/CSM7KOTA_G110632/genes/102506.genes.tsv',
    sep='\t',
    index_col=['gene_id'],
)

In [None]:
# Side-by-side log2 coverage per gene: the single-library output ('new')
# against the pooled subject-week depth ('old'). Genes missing from the old
# table are filled with 1e-2 before the log.
d = (
    C3023_36_rerun['mean_coverage']
    .to_frame('mean_coverage_new')
    .join(genes_depth.loc['C3023_36'].to_frame('mean_coverage_old'))
    .fillna(1e-2)
    .apply(np.log2)
)

plt.scatter(x='mean_coverage_old', y='mean_coverage_new', data=d, s=2)
# plt.yscale('symlog', linthresh=1e-2)
# plt.xscale('symlog', linthresh=1e-2)
# plt.ylim(-1e-2, 1e4)
# plt.xlim(-1e-2, 1e4)