In [None]:
# Load IPython's autoreload extension so edited project modules (e.g. `lib`,
# `sfacts`) can be picked up without restarting the kernel.
# NOTE(review): the explicit mode line below is commented out, so reloading
# relies on the extension's default mode. Uncomment it (e.g. `%autoreload 2`)
# if automatic reloading is intended — confirm.
%load_ext autoreload
# %autoreload

In [None]:
import os

# Run from the parent of the notebook directory (presumably the project root)
# so the relative paths used below ('data/', 'data_temp/', 'meta/', `lib`)
# resolve. The starting directory is remembered on the first execution, so
# re-running this cell does not keep climbing one more level each time.
NOTEBOOK_DIR = globals().setdefault('NOTEBOOK_DIR', os.getcwd())
os.chdir(os.path.join(NOTEBOOK_DIR, '..'))
os.path.realpath(os.path.curdir)

In [None]:
import sfacts as sf

In [None]:
import pandas as pd
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import seaborn as sns
import xarray as xr
from lib.pandas_util import idxwhere
import matplotlib as mpl
import lib.plot

In [None]:
# Focal species for this notebook: Bacteroides_B dorei.
# GTDB lineage: d__Bacteria;p__Bacteroidota;c__Bacteroidia;o__Bacteroidales;
#               f__Bacteroidaceae;g__Bacteroides_B;s__Bacteroides_B dorei
species_id = "102478"

In [None]:
# Per-sample species depth (GT-Pro output, per the filename): rows are samples,
# columns are species IDs cast to str so they match `species_id`; species
# missing from a sample are filled with 0.
species_depth = (
    pd.read_table('data/ucfmt.a.r.proc.gtpro.species_depth.tsv')
    .assign(species_id=lambda x: x.species_id.astype(str))
    .set_index(['sample', 'species_id'])
    # axis='columns' collapses only the single value column into a Series;
    # a bare squeeze() would also collapse a one-row table into a scalar.
    .squeeze(axis='columns')
    .unstack(fill_value=0)
)

pseudo = 1e-3  # pseudo-count so zero-depth samples survive the log10
fig, ax = plt.subplots()
ax.hist(np.log10(species_depth[species_id] + pseudo), bins=50)
ax.set(
    xlabel=f'log10(depth + {pseudo:g})',
    ylabel='samples',
    title=f'Species {species_id} depth',
);

In [None]:
# Relative abundance: each sample's species depths divided by that sample's
# total depth over all species (rows sum to 1; a zero-total row becomes NaN).
species_rabund = species_depth.div(species_depth.sum(axis=1), axis=0)

pseudo = 1e-6
plt.hist(np.log10(species_rabund[species_id] + pseudo), bins=50);

In [None]:
# Seed NumPy's global RNG. NOTE(review): this assumes sfacts' random_sample
# draws from the global NumPy RNG, making `position_ss` reproducible — confirm.
np.random.seed(0)

# Load the sfacts fit for this species, then (per the method names) merge
# strains closer than 0.05 and drop strains below 0.01 abundance.
# Earlier alternative input:
#   data/sp-{species_id}.ucfmt.a.r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts11-s75-seed0.world.nc
fit = (
    sf.World.load(
        f'data_temp/sp-{species_id}.ucfmt.a.r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts20-s75-seed0.world.nc'
    )
    .collapse_similar_strains(thresh=0.05)
    .drop_low_abundance_strains(thresh=0.01)
)

# At most 1000 randomly chosen positions, used to keep the plots below light.
position_ss = fit.random_sample(position=min(1000, len(fit.position))).position

In [None]:
# Metadata tables: metagenomes -> samples -> subjects.
mgen = pd.read_table('meta/ucfmt/mgen.tsv', index_col=['mgen_id'])
sample = pd.read_table('meta/ucfmt/sample.tsv', index_col=['sample_id'])
subject = pd.read_table('meta/ucfmt/subject.tsv', index_col=['subject_id'])

# Every foreign key must resolve; otherwise the joins below would silently add NaN rows.
assert mgen.sample_id.isin(sample.index).all()
assert sample.subject_id.isin(subject.index).all()

# One row per metagenome with sample/subject fields plus depth columns.
# The depth columns align on the index; NOTE(review): presumably the 'sample'
# index of species_depth holds mgen IDs — rows without a match get NaN.
meta_all = mgen.join(sample, on='sample_id').join(subject, on='subject_id').assign(
    total_species_depth=species_depth.sum(1),
    species_depth=species_depth[species_id],
    is_fit=lambda x: x.index.to_series().isin(fit.sample.to_series()),
)

# Readable label '<subject_id>.<sample_type_specific>.<k>', where k numbers the
# metagenomes sharing a (subject, timepoint) in table order. cumcount() keeps
# meta_all's own index; groupby().apply() would prepend the group keys to the
# result index (pandas >= 2), so the assignment would no longer align.
# dropna=False keeps the counter integer; rows with a missing key still get NaN
# because the string concatenation propagates the missing value.
meta_all['fullname'] = (
    meta_all.subject_id
    + '.'
    + meta_all.sample_type_specific
    + '.'
    + meta_all.groupby(['subject_id', 'sample_type_specific'], dropna=False).cumcount().astype(str)
)
meta_all

In [None]:
# Community plot over the subsampled positions; columns are ordered by the
# linkage of the samples' metagenotypes.
sf.plot.plot_community(
    fit.sel(position=position_ss),
    col_linkage_func=lambda world: world.metagenotype.linkage(),
)

In [None]:
# Metagenotype plot over the subsampled positions, with the same metagenotype-
# based column ordering and a compressed vertical scale.
sf.plot.plot_metagenotype(
    fit.sel(position=position_ss),
    scaley=0.05,
    col_linkage_func=lambda world: world.metagenotype.linkage(),
)

In [None]:
# Genotype plot over the subsampled positions; columns ordered by a linkage
# of the metagenotype computed along 'position'.
sf.plot.plot_genotype(
    fit.sel(position=position_ss),
    col_linkage_func=lambda world: world.metagenotype.linkage('position'),
)

In [None]:
# Recipient metagenomes that share a (subject, timepoint) with another one,
# shown with their total species depth (presumably to choose which to drop).
# NOTE: astype(bool) maps a missing `recipient` (NaN) to True, so such rows are included.
duplicated_subject_mgen_id_list = idxwhere(
    meta_all
    .loc[lambda df: df['recipient'].astype(bool)]
    .duplicated(subset=['subject_id', 'sample_type_specific'], keep=False)
)
(
    meta_all
    .loc[duplicated_subject_mgen_id_list, ['subject_id', 'sample_type_specific', 'total_species_depth']]
    .sort_values(['subject_id', 'sample_type_specific'])
)

In [None]:
# Metagenomes excluded from the downstream tables and plots.
# NOTE(review): presumably one of each duplicated recipient (subject, timepoint)
# pair from the duplicate table — confirm the selection rule.
drop_mgen_id_list = [
    'SS01008',
    'SS01093c',
    'SS01013',
    'SS01117',
    'SS01120',
    'SS01126',
    'SS01185',
]

In [None]:
# Canonical timeline order used for table columns and plot x positions.
sample_type_specific_order = (
    ['baseline']
    + ['pre_maintenance_1', 'pre_maintenance_2', 'pre_maintenance_3',
       'pre_maintenance_4', 'pre_maintenance_5', 'pre_maintenance_6']
    + ['followup_1', 'followup_2', 'followup_3']
)
sample_type_specific_order

In [None]:
# Total species depth per metagenome (subject x timepoint), linear colour scale,
# for rows whose `recipient` flag is not 0.
# Filtering the already-dropped frame (instead of dropping labels computed from
# the full meta_all) cannot raise KeyError if an ID is in both drop sets.
m = meta_all.drop(drop_mgen_id_list)[lambda x: x.recipient != 0]
d = m.set_index(['subject_id', 'sample_type_specific']).total_species_depth.unstack()[sample_type_specific_order]

ax = sns.heatmap(d)
ax.set_title('Total species depth');

In [None]:
# Total species depth (sum over all species) per metagenome, laid out as
# subject x timepoint on a symlog colour scale for rows whose `recipient` flag
# is not 0. Cells are annotated with `is_fit` (metagenome included in the strain fit).
# Filtering the already-dropped frame (instead of dropping labels computed from
# the full meta_all) cannot raise KeyError if an ID is in both drop sets.
m = meta_all.drop(drop_mgen_id_list)[lambda x: x.recipient != 0]
d = m.set_index(['subject_id', 'sample_type_specific']).total_species_depth.unstack()[sample_type_specific_order]
is_fit_annot = m.set_index(['subject_id', 'sample_type_specific']).is_fit.unstack(fill_value=False)[sample_type_specific_order]

ax = sns.heatmap(d, annot=is_fit_annot, norm=mpl.colors.SymLogNorm(linthresh=1))
ax.set_title('Total species depth (annotation: in strain fit)');

In [None]:
# Depth of the focal species per metagenome, laid out as subject x timepoint on
# a symlog colour scale for rows whose `recipient` flag is not 0.
# Cells are annotated with `is_fit` (metagenome included in the strain fit).
# Filtering the already-dropped frame (instead of dropping labels computed from
# the full meta_all) cannot raise KeyError if an ID is in both drop sets.
m = meta_all.drop(drop_mgen_id_list)[lambda x: x.recipient != 0]
d = m.set_index(['subject_id', 'sample_type_specific']).species_depth.unstack()[sample_type_specific_order]
is_fit_annot = m.set_index(['subject_id', 'sample_type_specific']).is_fit.unstack(fill_value=False)[sample_type_specific_order]

ax = sns.heatmap(d, annot=is_fit_annot, norm=mpl.colors.SymLogNorm(linthresh=1))
ax.set_title(f'Species {species_id} depth (annotation: in strain fit)');

In [None]:
# Distribution of the focal species' relative abundance (log10, pseudo-count 1e-6).
plt.hist(np.log10(species_rabund[species_id].add(1e-6)))

In [None]:
# Community-level strain abundance: the fit's per-sample strain fractions
# scaled by the species' relative abundance in that sample. The xarray product
# aligns on shared dimension names (presumably 'sample').
# Samples that have species-depth data but are absent from the fit are filled with 0.
strain_rabund = (
    (fit.community.data * species_rabund[species_id].to_xarray())
    .to_pandas()
    .reindex(idxwhere(meta_all.species_depth.notna()), fill_value=0)
)
strain_rabund

In [None]:
# Clustered heatmap of strain abundances, labelling every sample row.
sns.clustermap(strain_rabund, yticklabels=1, figsize=(10, 15))

In [None]:
# One panel per FMT recipient: community-level abundance of every strain across
# the canonical timeline (x = 0..N-1), with the mean abundance in the donor's
# metagenomes drawn at x = -1.
thresh = 1e-4  # linear region of the symlog y-axis; also its lower limit
color_list = ["#35618f", "#9dd84e", "#6538ac", "#5c922f", "#e033d3", "#61f22d", "#dd8eeb", "#0b5313", "#fd8992", "#20d8fd", "#91322d", "#62ecb6", "#ed0e1c", "#c2dcb8", "#cf115d", "#399283", "#f37d21", "#5310f0", "#f1c039", "#5d4030", "#f8cac2", "#74aff3", "#aa7b1b"]
strain_order = strain_rabund.mean().sort_values(ascending=False).index
strain_palette = lib.plot.construct_ordered_palette_from_list(strain_order, colors=color_list)
m0 = meta_all.drop(drop_mgen_id_list)

subject_order = subject[subject.recipient == 1].index
ncols = 4
nrows = int(np.ceil(len(subject_order) / ncols))
# Explicit x positions instead of relying on matplotlib's categorical units.
timepoint_x = np.arange(len(sample_type_specific_order))

# squeeze=False keeps `axs` 2-D even when nrows == 1, so `axs[-1]` is always a row.
fig, axs = plt.subplots(
    nrows, ncols, sharex=True, sharey=True,
    figsize=(3 * ncols, 2.5 * nrows), squeeze=False,
)

for subject_id, ax in zip(subject_order, axs.flat):
    donor_subject_id = subject.loc[subject_id].donor_subject_id
    ax.set_title(f'{subject_id} (donor: {donor_subject_id})')

    subject_mgen_list = idxwhere(m0.subject_id == subject_id)
    donor_subject_mgen_list = idxwhere(m0.subject_id == donor_subject_id)

    # reindex (not .loc): donor metagenomes missing from strain_rabund become
    # NaN rows that mean() skips, instead of raising KeyError.
    rabund_donor = strain_rabund.reindex(donor_subject_mgen_list).mean()

    # Left join: subject metagenomes missing from strain_rabund give NaN (a gap
    # in the line). reindex onto the timeline requires at most one metagenome
    # per timepoint, which drop_mgen_id_list is meant to ensure.
    m2 = (
        m0.loc[subject_mgen_list, ['sample_type_specific']]
        .join(strain_rabund)
        .set_index('sample_type_specific')
        .reindex(sample_type_specific_order)
    )

    for strain_id in strain_order:
        color = strain_palette[strain_id]
        ax.plot(timepoint_x, m2[strain_id].to_numpy(), marker='o', markerfacecolor='none', linestyle='-', color=color)
        ax.scatter([-1], rabund_donor[strain_id], label=strain_id, color=color)

# Apply the y scale to every panel explicitly rather than via a leaked loop variable.
for ax in axs.flat:
    ax.set_yscale('symlog', linthresh=thresh, linscale=0.3)
    ax.set_ylim(-thresh, 1.0)

for ax in axs[-1]:
    ax.set_xticks(range(-1, len(sample_type_specific_order)))
    ax.set_xticklabels(['donor'] + sample_type_specific_order, rotation=45, ha='right')

In [None]:
# Stand-alone legend: one invisible scatter handle per strain, in its palette colour.
for sid in strain_palette:
    colour = strain_palette[sid]
    plt.scatter([], [], marker='o', c=colour, label=sid)
plt.legend()