In [None]:
import os as _os

# All data/meta/ref paths in this notebook are relative to the parent of the notebook's
# starting directory (presumably the project root -- TODO confirm). Guard the chdir so that
# re-running this cell does not climb one more directory each time.
if not globals().get('_CWD_MOVED_TO_PARENT', False):
    _os.chdir('..')
    _CWD_MOVED_TO_PARENT = True

In [None]:
%load_ext autoreload
# NOTE(review): mode 0 *disables* automatic reloading, so edits to local modules
# (sfacts, lib) are not picked up until the kernel restarts or they are reloaded
# manually. Use `%autoreload 2` if live reloading is intended.
%autoreload 0

In [None]:
import sfacts as sf

In [None]:
import os
import warnings
from glob import glob
from itertools import cycle
from operator import eq

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
import scipy as sp
import scipy.spatial.distance
import scipy.stats
import seaborn as sns
import xarray as xr
from tqdm import tqdm

import lib.plot
import lib.stats
from lib.pandas_util import idxwhere

In [None]:
# Assemble one metadata row per sequencing library by joining the HMP2 tables along
# library -> preparation -> stool -> visit -> subject.
mgen = pd.read_table('meta/hmp2/mgen.tsv', index_col='library_id')
preparation = pd.read_table('meta/hmp2/preparation.tsv', index_col='preparation_id')
stool = pd.read_table('meta/hmp2/stool.tsv', index_col='stool_id')
visit = pd.read_table('meta/hmp2/visit.tsv', index_col='visit_id')
subject = pd.read_table('meta/hmp2/subject.tsv', index_col='subject_id')

meta_all = (
    mgen
    .join(preparation.drop(columns='library_type'), on='preparation_id')
    .join(stool, on='stool_id')
    # rsuffix='_' disambiguates any visit columns that collide with columns already present.
    .join(visit, on='visit_id', rsuffix='_')
    .join(subject, on='subject_id')
    # Display label "<subject>_<week>_<library>". The week is zero-padded to 3 digits
    # (missing -> 999) so that sorting labels as strings orders each subject's samples
    # chronologically; unpadded, week '10' would sort before week '2'.
    .assign(new_name=lambda x: (
        x['subject_id'].astype(str)
        + '_' + x['week_number'].fillna(999).astype(int).map('{:03d}'.format)
        + '_' + x.index.to_series().astype(str)
    ))
)
# library_id stays the index; new_name is only a display label.

library_id_to_new_name = meta_all.new_name

assert not meta_all.subject_id.isna().any(), "every library must map to a subject"

# TODO: Rename samples based on subject and visit number
# TODO: Drop duplicate stools

In [None]:
# Unique species IDs (as strings) belonging to the "hmp2" species group.
species_list = (
    pd.read_table("meta/species_group.tsv")
    .query("species_group_id == 'hmp2'")["species_id"]
    .astype(str)
    .unique()
)

In [None]:
# Per-species, per-subject metagenotype heterogeneity summaries.
# For each species with a filtered metagenotype file, and each subject with samples in it:
#   num_samples             - number of that subject's samples
#   mean_pairwise_mgen_diss - mean pairwise metagenotype dissimilarity among those samples
#   mean_entropy            - depth-weighted mean of per-sample metagenotype entropy
#   mean_squared_entropy    - depth-weighted mean of squared entropy
#   mean_depth              - unweighted mean of per-sample mean depth
# Species without an input file are recorded in `missing_species`.

# Heuristic cutoffs used to flag subjects (previously inline magic numbers).
ENTROPY_HETEROGENEITY_THRESH = 0.025
DISS_TRANSITION_THRESH = 0.1

all_stats = []
missing_species = []
for species_id in tqdm(species_list):
    inpath = f'data/group/hmp2/species/sp-{species_id}/r.proc.gtpro.filt-poly05-cvrg05.mgtp.nc'
    if not os.path.exists(inpath):
        missing_species.append(species_id)
        continue
    mgtp1 = sf.Metagenotype.load(inpath)
    meta1 = meta_all.loc[mgtp1.sample]
    # Sort once per species (not once per subject); subject samples come out in week order.
    subject_of_sample = meta1.sort_values('week_number').subject_id

    heterogeneity_stats1 = {}
    for subject_id in meta1.subject_id.unique():
        sample_list = idxwhere(subject_of_sample == subject_id)
        mgtp2 = mgtp1.sel(sample=sample_list)
        with warnings.catch_warnings():
            # pdist() returns a square matrix; squareform() condenses it. For single-sample
            # subjects the condensed vector is empty and np.mean warns and returns nan.
            warnings.simplefilter("ignore")
            entropy = mgtp2.entropy()
            depth = mgtp2.mean_depth()
            total_depth = depth.sum()
            heterogeneity_stats1[subject_id] = dict(
                num_samples=len(sample_list),
                mean_pairwise_mgen_diss=np.mean(sp.spatial.distance.squareform(mgtp2.pdist())),
                mean_entropy=((entropy * depth).sum() / total_depth).values,
                mean_squared_entropy=((np.square(entropy) * depth).sum() / total_depth).values,
                mean_depth=depth.values.mean(),
            )

    if not heterogeneity_stats1:
        # No subjects for this species: an empty frame would break the column-based assign below.
        continue

    heterogeneity_stats1 = pd.DataFrame(heterogeneity_stats1).T.assign(
        multiple_samples=lambda x: x.num_samples > 1,
        probable_strain_heterogeneity=lambda x: x.mean_entropy > ENTROPY_HETEROGENEITY_THRESH,
        probable_strain_transition=lambda x: x.mean_pairwise_mgen_diss > DISS_TRANSITION_THRESH,
        species_id=species_id,
    )
    all_stats.append(heterogeneity_stats1)

In [None]:
# Stack the per-species frames (index = subject) into one table indexed by (species, subject).
all_stats = (
    pd.concat(all_stats)
    .rename_axis(index='subject_id')
    .reset_index()
    .set_index(['species_id', 'subject_id'])
)

In [None]:
# Per species, count subjects (with >1 sample) flagged / not flagged for strain
# heterogeneity (d1) and strain transition (d2), plus the flagged fraction.
multi_sample_stats = all_stats[lambda x: x.multiple_samples]
d1 = (
    multi_sample_stats
    .groupby(['species_id', 'probable_strain_heterogeneity']).size()
    .unstack(fill_value=0)
    # Ensure both outcome columns exist even if a flag value never occurs;
    # otherwise x[False] / x[True] below raise KeyError.
    .reindex(columns=[False, True], fill_value=0)
    .assign(num=lambda x: x[False] + x[True])
    .assign(het_frac=lambda x: x[True] / x.num)
    .rename(columns={True: 'het', False: 'nohet'})
)
d2 = (
    multi_sample_stats
    .groupby(['species_id', 'probable_strain_transition']).size()
    .unstack(fill_value=0)
    .reindex(columns=[False, True], fill_value=0)
    .assign(num=lambda x: x[False] + x[True])
    .assign(trans_frac=lambda x: x[True] / x.num)
    .rename(columns={True: 'trans', False: 'notrans'})
)

d = d1.join(d2[['notrans', 'trans', 'trans_frac']])

# Stacked bars for species with more than 5 multi-sample subjects, most-sampled first.
d_plot = d[d.num > 5].sort_values('num', ascending=False)
fig, axs = plt.subplots(nrows=2, figsize=(60, 10), sharex=True)
d_plot[['het', 'nohet']].plot.bar(stacked=True, ax=axs[0])
d_plot[['trans', 'notrans']].plot.bar(stacked=True, ax=axs[1]);

In [None]:
# Per-species heterogeneity fraction vs. transition fraction on logit axes, colored by
# number of multi-sample subjects, with a LOWESS trend and Spearman correlation.
plt.scatter('het_frac', 'trans_frac', data=d, c='num')
# x/y by keyword: seaborn >= 0.12 no longer accepts data variables positionally.
sns.regplot(x='het_frac', y='trans_frac', data=d, scatter=False, lowess=True)
plt.xscale('logit')
plt.yscale('logit')
sp.stats.spearmanr(d['het_frac'], d['trans_frac'])

In [None]:
# Metagenotype heatmap for species 102478 on at most 500 positions (via random_sample;
# NOTE(review): no seed is passed, so the subsample may differ between runs).
species_id = '102478'
_mgtp = sf.Metagenotype.load(f'data/sp-{species_id}.hmp2.a.r.proc.gtpro.filt-poly05-cvrg05.mgen.nc')
print(_mgtp.sizes)

_num_positions = min(_mgtp.sizes['position'], 500)
sf.plot.plot_metagenotype(
    _mgtp.random_sample(position=_num_positions),
    # Column linkage / colors computed from the subsampled world passed to the callbacks.
    col_linkage_func=lambda world: world.metagenotype.linkage(),
    col_colors_func=lambda world: world.metagenotype.entropy() > 0.05,
)

In [None]:
# Metagenotype heatmap for species 102517 on at most 500 positions.
# NOTE(review): unlike the 102478 cell, linkage/colors here use the full `_mgtp` (all
# positions), not the subsample handed to the callbacks -- confirm which is intended.
species_id = '102517'
_mgtp = sf.Metagenotype.load(f'data/sp-{species_id}.hmp2.a.r.proc.gtpro.filt-poly05-cvrg05.mgen.nc')
print(_mgtp.sizes)

_num_positions = min(_mgtp.sizes['position'], 500)
sf.plot.plot_metagenotype(
    _mgtp.random_sample(position=_num_positions),
    col_linkage_func=lambda _world: _mgtp.linkage(),
    col_colors_func=lambda _world: _mgtp.entropy() > 0.05,
)

In [None]:
# Metagenotype heatmap for species 102327 on at most 500 positions; linkage and column
# colors (entropy > 0.05) come from the full, unsubsampled `_mgtp`.
species_id = '102327'
_mgtp = sf.Metagenotype.load(f'data/sp-{species_id}.hmp2.a.r.proc.gtpro.filt-poly05-cvrg05.mgen.nc')
print(_mgtp.sizes)

_num_positions = min(_mgtp.sizes['position'], 500)
sf.plot.plot_metagenotype(
    _mgtp.random_sample(position=_num_positions),
    col_linkage_func=lambda _world: _mgtp.linkage(),
    col_colors_func=lambda _world: _mgtp.entropy() > 0.05,
)

In [None]:
# species_id -> ';'-separated taxonomy string, plus one column per rank (p__ .. s__)
# taken from positions 1..6 of the split string (position 0 is not extracted).
species_taxonomy = (
    pd.read_table(
        'ref/gtpro/species_taxonomy_ext.tsv',
        names=['genome_id', 'species_id', 'taxonomy_string'],
    )
    .assign(species_id=lambda x: x.species_id.astype(str))
    .set_index('species_id')
    [['taxonomy_string']]
)
_taxonomy_parts = species_taxonomy['taxonomy_string'].str.split(';')
species_taxonomy = species_taxonomy.assign(**{
    rank: [parts[position] for parts in _taxonomy_parts]
    for rank, position in [('p__', 1), ('c__', 2), ('o__', 3), ('f__', 4), ('g__', 5), ('s__', 6)]
})

In [None]:
# Number of analyzed species per phylum.
species_taxonomy.loc[all_stats.index.unique(level='species_id'), 'p__'].value_counts()

In [None]:
# Same per-species tallies as above, now joined to taxonomy and ordered by phylum then
# number of multi-sample subjects; bars are labeled "<species_id>_<phylum>".
multi_sample_stats = all_stats[lambda x: x.multiple_samples]
d1 = (
    multi_sample_stats
    .groupby(['species_id', 'probable_strain_heterogeneity']).size()
    .unstack(fill_value=0)
    # Ensure both outcome columns exist even if a flag value never occurs.
    .reindex(columns=[False, True], fill_value=0)
    .assign(num=lambda x: x[False] + x[True])
    .assign(het_frac=lambda x: x[True] / x.num)
    .rename(columns={True: 'het', False: 'nohet'})
)
d2 = (
    multi_sample_stats
    .groupby(['species_id', 'probable_strain_transition']).size()
    .unstack(fill_value=0)
    .reindex(columns=[False, True], fill_value=0)
    .assign(num=lambda x: x[False] + x[True])
    .assign(trans_frac=lambda x: x[True] / x.num)
    .rename(columns={True: 'trans', False: 'notrans'})
)

d = (
    d1
    .join(d2[['notrans', 'trans', 'trans_frac']])
    .join(species_taxonomy)
    .sort_values(['p__', 'num'])
    .reset_index()
    .assign(species_and_tax=lambda x: x.species_id + '_' + x.p__)
    .set_index('species_and_tax')
)

fig, axs = plt.subplots(nrows=2, figsize=(60, 10), sharex=True)
d[d.num > 5][['het', 'nohet']].plot.bar(stacked=True, ax=axs[0])
d[d.num > 5][['trans', 'notrans']].plot.bar(stacked=True, ax=axs[1]);

In [None]:
# Species 103682 metagenotype with samples relabeled "<subject>_<week>_<library>" and
# columns shown in label order (clustering disabled).
species_id = '103682'
_mgtp = sf.Metagenotype.load(f'data/sp-{species_id}.hmp2.a.r.proc.gtpro.filt-poly05-cvrg05.mgen.nc')
_mgtp.data['sample'] = library_id_to_new_name.loc[_mgtp.sample].to_list()
print(_mgtp.sizes)

_subsampled = _mgtp.random_sample(position=min(_mgtp.sizes['position'], 500))
sf.plot.plot_metagenotype(
    _subsampled.mlift('sortby', 'sample'),
    col_linkage_func=lambda _world: _mgtp.linkage(),
    col_colors_func=lambda _world: _mgtp.entropy() > 0.05,
    col_cluster=False,
)

In [None]:
# Flatten the per-(species, subject) stats, drop boolean flags and the squared-entropy
# column, give metagenotype columns explicit "mgtp" names, and export to TSV.
subject_by_species_stats = (
    all_stats
    .reset_index()
    .drop(columns=[
        'multiple_samples',
        'probable_strain_heterogeneity',
        'probable_strain_transition',
        'mean_squared_entropy',
    ])
    .rename(columns={
        'num_samples': 'num_mgtp_samples',
        'mean_pairwise_mgen_diss': 'mean_pairwise_mgtp_diss',
        'mean_entropy': 'mean_mgtp_entropy',
        'mean_depth': 'mean_mgtp_depth',
    })
)

subject_by_species_stats.to_csv(
    'data/hmp2.a.r.proc.gtpro.filt-poly05-cvrg05.subject_by_species_stats.tsv',
    sep='\t',
    index=False,
)
subject_by_species_stats

In [None]:
# Reload the exported stats; species IDs are read as strings to match other tables.
_stats_path = 'data/hmp2.a.r.proc.gtpro.filt-poly05-cvrg05.subject_by_species_stats.tsv'
subject_by_species_stats = pd.read_table(_stats_path, dtype={'species_id': str})

In [None]:
# Library x species depth table (absent species -> 0). Libraries whose total depth is
# below 100 are dropped, and the rest are normalized to per-library relative abundance.
species_depth_uncleaned = (
    pd.read_table(
        'data/hmp2.a.r.proc.gtpro.species_depth.tsv',
        dtype=dict(sample=str, species_id=str, depth=float),
    )
    .rename(columns={'sample': 'library_id'})
    .set_index(['library_id', 'species_id'])
    ['depth']
    .unstack(fill_value=0)
)
species_depth = species_depth_uncleaned.loc[species_depth_uncleaned.sum(axis=1) >= 100]
species_rabund = species_depth.div(species_depth.sum(axis=1), axis=0)
species_rabund

In [None]:
# Library x strain depth table from the sfacts fit (absent strains -> 0). Libraries with
# total depth below 100 are dropped; the rest are normalized to relative abundance.
strain_depth_uncleaned = (
    pd.read_table(
        'data/hmp2.a.r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts18-s75-seed0.clean-diss05-abund05-entr100.strain_depth.tsv',
        dtype=dict(sample=str, strain=str, depth=float),
    )
    .rename(columns={'sample': 'library_id'})
    .set_index(['library_id', 'strain'])
    ['depth']
    .unstack(fill_value=0)
)
strain_depth = strain_depth_uncleaned.loc[strain_depth_uncleaned.sum(axis=1) >= 100]
strain_rabund = strain_depth.div(strain_depth.sum(axis=1), axis=0)
strain_rabund

In [None]:
# Genome length per species_id, from the MIDAS UHGG metadata table.
species_genome_length = (
    pd.read_table(
        'ref_temp/midasdb_uhgg.metadata.tsv',
        dtype=dict(species_id=str, Length=int),
        index_col='species_id',
    )
    ['Length']
    .rename('genome_length')
)

In [None]:
# Sanity check: sequenced reads per library vs. genome-length-weighted species depth
# (species without a known genome length contribute 0).
_weighted_depth = species_depth_uncleaned.mul(species_genome_length).fillna(0).sum(axis=1)
d = (
    species_depth_uncleaned.sum(axis=1)
    .to_frame('total_species_depth')
    .join(mgen['sequenced_reads'])
    .assign(expect_total_sequence=_weighted_depth)
)
plt.scatter('sequenced_reads', 'expect_total_sequence', data=d, s=1)
# Reference line with slope 95 (NOTE(review): presumably bases per read -- confirm).
plt.plot([0, 1e7], [0, 1e7 * 95])
sp.stats.pearsonr(d.sequenced_reads, d.total_species_depth)

In [None]:
# Subjects x species: number of samples with a metagenotype, and number of samples in
# which the species' depth exceeds 0.05.
num_mgtp_samples_matrix = (
    subject_by_species_stats
    .set_index(['subject_id', 'species_id'])
    ['num_mgtp_samples']
    .unstack(fill_value=0)
)
num_depth_samples_matrix = species_depth.gt(0.05).groupby(meta_all['subject_id']).sum()

In [None]:
# Number of subjects in which each species ever exceeds depth 0.05; keep species present
# in more than 85 subjects.
species_subject_prevalence = (
    species_depth.gt(0.05)
    .groupby(meta_all['subject_id']).any()
    .sum()
    .sort_values(ascending=False)
)
most_prevalent_species = idxwhere(species_subject_prevalence > 85)
len(most_prevalent_species)

In [None]:
# Mean metagenotype entropy per prevalent species, split by each subject's (first
# recorded) IBD diagnosis.
_subject_diagnosis = meta_all.groupby('subject_id')['ibd_diagnosis'].first()
d0 = subject_by_species_stats.join(_subject_diagnosis, on='subject_id')
sns.boxplot(
    data=d0[d0['species_id'].isin(most_prevalent_species)],
    x='species_id',
    y='mean_mgtp_entropy',
    hue='ibd_diagnosis',
)
lib.plot.rotate_xticklabels()

In [None]:
# Mann-Whitney U: species 102492 mean metagenotype entropy, nonIBD vs. all other diagnoses.
d0 = subject_by_species_stats.join(meta_all.groupby('subject_id')['ibd_diagnosis'].first(), on='subject_id')
d1 = d0.query("species_id == '102492'")
_is_non_ibd = d1['ibd_diagnosis'] == 'nonIBD'
sp.stats.mannwhitneyu(
    d1.loc[_is_non_ibd, 'mean_mgtp_entropy'].astype(float).dropna(),
    d1.loc[~_is_non_ibd, 'mean_mgtp_entropy'].astype(float).dropna(),
)

In [None]:
feat = 'ibd_diagnosis'

# Per subject: True when the first recorded diagnosis is anything other than 'nonIBD'.
_is_ibd = meta_all.groupby('subject_id')[feat].first().ne('nonIBD')
d0 = subject_by_species_stats.join(_is_ibd, on='subject_id')
d1 = d0.loc[d0['species_id'] == '102492']

sns.boxplot(data=d1, x='species_id', y='mean_mgtp_entropy', hue=feat)
lib.plot.rotate_xticklabels()

lib.stats.mannwhitneyu(feat, 'mean_mgtp_entropy', data=d1)

In [None]:
# Taxonomy of species 102492.
species_taxonomy.loc['102492', :]

In [None]:
# Per subject: mean relative abundance of species 102492 and its metagenotype stats,
# compared across IBD diagnosis groups.
d = (
    subject
    .assign(mean_rabund_102492=species_rabund.groupby(meta_all.subject_id).mean()['102492'])
    .join(subject_by_species_stats[lambda x: x.species_id == '102492'].set_index('subject_id'))
)

fig, axs = plt.subplots(2)
# x/y by keyword: seaborn >= 0.12 rejects positional data variables.
sns.boxplot(x='ibd_diagnosis', y='mean_rabund_102492', data=d, ax=axs[0])
sns.boxplot(x='ibd_diagnosis', y='mean_mgtp_entropy', data=d, ax=axs[1]);

In [None]:
# Species 102492 metagenotype with samples relabeled "<subject>_<week>_<library>" and
# columns shown in label order (clustering disabled).
species_id = '102492'
_mgtp = sf.Metagenotype.load(f'data/sp-{species_id}.hmp2.a.r.proc.gtpro.filt-poly05-cvrg05.mgen.nc')
_mgtp.data['sample'] = library_id_to_new_name.loc[_mgtp.sample].to_list()
print(_mgtp.sizes)

_subsampled = _mgtp.random_sample(position=min(_mgtp.sizes['position'], 500))
sf.plot.plot_metagenotype(
    _subsampled.mlift('sortby', 'sample'),
    col_linkage_func=lambda _world: _mgtp.linkage(),
    col_colors_func=lambda _world: _mgtp.entropy() > 0.05,
    col_cluster=False,
)

In [None]:
# Distribution of per-sample metagenotype entropy for species 102492, split by antibiotic
# status, plus a Mann-Whitney U test between the two groups.
species_id = '102492'
_mgtp = sf.Metagenotype.load(f'data/sp-{species_id}.hmp2.a.r.proc.gtpro.filt-poly05-cvrg05.mgen.nc')
d = _mgtp.entropy().to_series().to_frame().join(meta_all)
# NOTE(review): assumes status_antibiotics is a bool column with no missing values;
# `~` on an object column would not act as logical NOT -- confirm dtype.
# The second histogram is semi-transparent so it does not hide the first, and both are labeled.
plt.hist(d[~d.status_antibiotics].entropy, density=True, label='no antibiotics')
plt.hist(d[d.status_antibiotics].entropy, density=True, alpha=0.5, label='antibiotics')
plt.legend()

lib.stats.mannwhitneyu('status_antibiotics', 'entropy', data=d)

In [None]:
_fit = sf.World.load('data_temp/sp-102492.hmp2.a.r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts27-s75-seed0.world.nc')

# Metagenotype entropy vs. inferred community entropy from the sfacts fit (presumably one
# point per sample), colored by metagenotype depth on a log color scale.
d = pd.DataFrame(dict(
    mgtp_entropy=_fit.metagenotype.entropy(),
    comm_entropy=_fit.community.entropy(),
    mgtp_depth=_fit.metagenotype.mean_depth(),
))
plt.scatter('mgtp_entropy', 'comm_entropy', data=d, s=5, c='mgtp_depth', norm=mpl.colors.LogNorm())
# x/y by keyword: seaborn >= 0.12 rejects positional data variables.
sns.regplot(x='mgtp_entropy', y='comm_entropy', data=d, scatter=False)

# plt.yscale('symlog', linthresh=1e-5)
# plt.xscale('symlog', linthresh=1e-5)

In [None]:
# Mann-Whitney U on mean metagenotype entropy, non-IBD (False) vs. IBD (True) subjects.
# Uses `d1` / `feat` from the ibd_diagnosis boxplot cell. NaNs are dropped because
# scipy's default nan_policy='propagate' would otherwise make the whole result NaN.
# `== False` / `== True` (not `~`) is deliberate: the joined flag column may hold NaN.
sp.stats.mannwhitneyu(
    d1.loc[d1[feat] == False, 'mean_mgtp_entropy'].astype(float).dropna(),  # noqa: E712
    d1.loc[d1[feat] == True, 'mean_mgtp_entropy'].astype(float).dropna(),  # noqa: E712
)

In [None]:
# Subjects x species matrix of mean pairwise metagenotype dissimilarity (missing -> 0),
# cosine-clustered over the most prevalent species.
d = subject_by_species_stats.set_index(['subject_id', 'species_id']).mean_pairwise_mgtp_diss.astype(float).unstack().fillna(0)
# reindex (not []) so prevalent species lacking metagenotype stats become zero columns
# instead of raising KeyError; +1e-5 keeps cosine distance defined for all-zero rows.
sns.clustermap(d.reindex(columns=most_prevalent_species, fill_value=0) + 1e-5, metric='cosine')

In [None]:
# Subjects x species matrix of mean metagenotype entropy (missing -> 0), cosine-clustered
# over the most prevalent species.
d = subject_by_species_stats.set_index(['subject_id', 'species_id']).mean_mgtp_entropy.astype(float).unstack().fillna(0)
# reindex (not []) so prevalent species lacking metagenotype stats become zero columns
# instead of raising KeyError; +1e-5 keeps cosine distance defined for all-zero rows.
sns.clustermap(d.reindex(columns=most_prevalent_species, fill_value=0) + 1e-5, metric='cosine')

In [None]:
# Per subject, for species 100022: mean metagenotype entropy vs. mean pairwise
# dissimilarity (rows with either missing are dropped), on symlog axes.
_stats_by_pair = subject_by_species_stats.set_index(['subject_id', 'species_id'])
d = _stats_by_pair[['mean_pairwise_mgtp_diss', 'mean_mgtp_entropy']].dropna()
plt.scatter('mean_mgtp_entropy', 'mean_pairwise_mgtp_diss', data=d.xs('100022', level='species_id'), s=1)
for _set_scale in (plt.xscale, plt.yscale):
    _set_scale('symlog', linthresh=1e-5)

In [None]:
from itertools import product
from scipy.spatial.distance import squareform, pdist

# Sample x sample dissimilarities from strain relative abundances.
r = strain_rabund
all_strain_bc_diss = pd.DataFrame(squareform(pdist(r, metric='braycurtis')), index=r.index, columns=r.index)
# Jaccard on presence/absence. scipy's 'jaccard' on non-boolean input counts positions whose
# *values* differ, so continuous abundances would give ~1 for nearly every pair.
all_strain_jc_diss = pd.DataFrame(squareform(pdist(r > 0, metric='jaccard')), index=r.index, columns=r.index)

In [None]:
from itertools import product
from scipy.spatial.distance import squareform, pdist

# Sample x sample dissimilarities from species relative abundances.
r = species_rabund
all_species_bc_diss = pd.DataFrame(squareform(pdist(r, metric='braycurtis')), index=r.index, columns=r.index)
# Jaccard on presence/absence. scipy's 'jaccard' on non-boolean input counts positions whose
# *values* differ, so continuous abundances would give ~1 for nearly every pair.
all_species_jc_diss = pd.DataFrame(squareform(pdist(r > 0, metric='jaccard')), index=r.index, columns=r.index)

In [None]:
# Library x library indicator (1.0 / 0.0) of whether two libraries come from different
# subjects. A vectorized broadcast comparison replaces pdist with a Python callable, which
# made O(n^2) Python calls and relied on numpy's deprecated 1-element-array -> scalar
# conversion. Values stay float64, as pdist produced.
m = meta_all['subject_id'].to_frame()
_subject_ids = m['subject_id'].to_numpy()
diff_subj = pd.DataFrame(
    (_subject_ids[:, None] != _subject_ids[None, :]).astype(float),
    index=m.index,
    columns=m.index,
)

In [None]:
# Sample x sample metagenotype dissimilarity for species 102492 (samples keep library IDs).
species_id = '102492'
_mgtp = sf.Metagenotype.load(
    f'data/sp-{species_id}.hmp2.a.r.proc.gtpro.filt-poly05-cvrg05.mgen.nc'
)
mgtp_102492_diss = _mgtp.pdist()

In [None]:
# Sample x sample metagenotype dissimilarity for species 102506 (samples keep library IDs).
species_id = '102506'
_mgtp = sf.Metagenotype.load(
    f'data/sp-{species_id}.hmp2.a.r.proc.gtpro.filt-poly05-cvrg05.mgen.nc'
)
mgtp_102506_diss = _mgtp.pdist()

In [None]:
# For every pair of samples present in all three matrices: strain Bray-Curtis dissimilarity
# (x) vs. species 102492 metagenotype dissimilarity (y), colored by different-subject (z).
x = all_strain_bc_diss
y = mgtp_102492_diss
z = diff_subj.astype(bool)

# sorted(): iteration order of a set of strings changes between interpreter runs (hash
# randomization), which would change point draw order and make the figure irreproducible.
idx = sorted(set(x.index) & set(y.index) & set(z.index))
d = pd.DataFrame(dict(
    x=squareform(x.loc[idx, idx]),
    y=squareform(y.loc[idx, idx]),
    z=squareform(z.loc[idx, idx]),
))

plt.scatter('x', 'y', data=d, s=1, c='z')
plt.xlabel('strain Bray-Curtis dissimilarity')
plt.ylabel('sp-102492 metagenotype dissimilarity');

In [None]:
# For every pair of samples present in all three matrices: strain Bray-Curtis dissimilarity
# (x) vs. species 102506 metagenotype dissimilarity (y), colored by different-subject (z).
x = all_strain_bc_diss
y = mgtp_102506_diss
z = diff_subj.astype(bool)

# sorted(): set iteration order of strings varies between runs (hash randomization).
idx = sorted(set(x.index) & set(y.index) & set(z.index))
d = pd.DataFrame(dict(
    x=squareform(x.loc[idx, idx]),
    y=squareform(y.loc[idx, idx]),
    z=squareform(z.loc[idx, idx]),
))

plt.scatter('x', 'y', data=d, s=1, c='z')
plt.xlabel('strain Bray-Curtis dissimilarity')
plt.ylabel('sp-102506 metagenotype dissimilarity');

In [None]:
# Species 102506 metagenotype dissimilarity for sample pairs from different vs. the same
# subject. `x` is used only to restrict to samples that also have strain data.
x = all_strain_bc_diss
y = mgtp_102506_diss
z = diff_subj.astype(bool)

# sorted(): set iteration order of strings varies between runs (hash randomization).
idx = sorted(set(x.index) & set(y.index) & set(z.index))
d = pd.DataFrame(dict(
    x=squareform(x.loc[idx, idx]),
    y=squareform(y.loc[idx, idx]),
    z=squareform(z.loc[idx, idx]),
))

bins = np.linspace(0, 1.4, num=51)
plt.hist(d.y[d.z], bins=bins, label='diff')
plt.hist(d.y[~d.z], bins=bins, label='same', alpha=0.5)
plt.yscale('log')
# Without a legend the `label=`s above are never shown.
plt.legend();

In [None]:
# Species 102492 metagenotype dissimilarity for sample pairs from different vs. the same
# subject. `x` is used only to restrict to samples that also have strain data.
x = all_strain_bc_diss
y = mgtp_102492_diss
z = diff_subj.astype(bool)

# sorted(): set iteration order of strings varies between runs (hash randomization).
idx = sorted(set(x.index) & set(y.index) & set(z.index))
d = pd.DataFrame(dict(
    x=squareform(x.loc[idx, idx]),
    y=squareform(y.loc[idx, idx]),
    z=squareform(z.loc[idx, idx]),
))

bins = np.linspace(0, 1.4, num=51)
plt.hist(d.y[d.z], bins=bins, label='diff')
plt.hist(d.y[~d.z], bins=bins, label='same', alpha=0.5)
plt.yscale('log')
# Without a legend the `label=`s above are never shown.
plt.legend();

In [None]:
# For every pair of samples present in all three matrices: species Bray-Curtis
# dissimilarity (x) vs. strain Bray-Curtis dissimilarity (y), colored by different-subject (z).
x = all_species_bc_diss
y = all_strain_bc_diss
z = diff_subj.astype(bool)

# sorted(): set iteration order of strings varies between runs (hash randomization).
idx = sorted(set(x.index) & set(y.index) & set(z.index))
d = pd.DataFrame(dict(
    x=squareform(x.loc[idx, idx]),
    y=squareform(y.loc[idx, idx]),
    z=squareform(z.loc[idx, idx]),
))

plt.scatter('x', 'y', data=d, s=1, c='z')
plt.xlabel('species Bray-Curtis dissimilarity')
plt.ylabel('strain Bray-Curtis dissimilarity');