In [None]:
# Load IPython's autoreload extension. The `%autoreload` line below is left
# commented out, so this cell does not itself trigger or configure a reload.
%load_ext autoreload
# %autoreload

In [None]:
import os

# Step up one directory so the relative paths used below ('meta/...',
# 'data/...', 'fig/...') resolve; presumably the notebook lives one level
# below the project root -- confirm.
# NOTE: not idempotent -- re-running this cell moves up one more level.
os.chdir(os.pardir)
os.path.realpath(os.curdir)

In [None]:
import sfacts as sf

In [None]:
import pandas as pd
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import seaborn as sns
import xarray as xr
from lib.pandas_util import idxwhere
import matplotlib as mpl
import lib.plot

In [None]:
# Species analysed throughout this notebook. Kept as a string: the species
# depth table below casts its species_id labels to str before indexing by it.
# Taxonomy noted by the original author (presumably for the active ID -- confirm):
# d__Bacteria;p__Bacteroidota;c__Bacteroidia;o__Bacteroidales;f__Bacteroidaceae;g__Bacteroides_B;s__Bacteroides_B dorei
species_id = '102478'  # alternative ID left by the author: '100035'

In [None]:
# Three metadata tables (metagenome, sample, subject), each indexed by its own
# ID column.
mgen, sample, subject = (
    pd.read_table(tsv_path, index_col=[id_column])
    for id_column, tsv_path in [
        ('mgen_id', 'meta/ucfmt/mgen.tsv'),
        ('sample_id', 'meta/ucfmt/sample.tsv'),
        ('subject_id', 'meta/ucfmt/subject.tsv'),
    ]
)

# Referential integrity: every metagenome points at a known sample and every
# sample at a known subject, so the joins further down cannot silently miss.
assert mgen['sample_id'].isin(sample.index).all()
assert sample['subject_id'].isin(subject.index).all()

In [None]:
# Long (sample, species_id, value) table -> wide sample x species matrix.
# species_id is cast to str so columns can be looked up with the string
# `species_id` defined above; combinations absent from the file become 0.
# `.squeeze()` collapses the remaining value column to a Series (assumes the
# file has exactly one such column -- confirm).
_species_depth = (
    pd.read_table('data/ucfmt.a.r.proc.gtpro.species_depth.tsv')
    .astype({'species_id': str})
    .set_index(['sample', 'species_id'])
    .squeeze()
    .unstack(fill_value=0)
)

# Distribution of log10 depth for the focal species; the pseudocount keeps
# zero-depth samples finite.
pseudo = 1e-3
plt.hist(np.log10(pseudo + _species_depth[species_id]), bins=50);

In [None]:
# Metagenomes of subjects whose ID starts with 'S0', annotated with their
# total depth summed over all species.
m = (
    mgen
    .join(sample, on='sample_id')
    .loc[lambda x: x.subject_id.str.startswith('S0')]
    .assign(total_species_depth=_species_depth.sum(axis=1))
)

# Every metagenome that shares its (subject, sample type) pair with at least
# one other metagenome (keep=False flags all members of a duplicate set).
duplicated_subject_mgen_id_list = idxwhere(
    m.duplicated(subset=['subject_id', 'sample_type_specific'], keep=False)
)

# Show the duplicates side by side, with depth, for inspection.
(
    m.loc[duplicated_subject_mgen_id_list]
    .sort_values(['subject_id', 'sample_type_specific'])
    .loc[:, ['subject_id', 'sample_type_specific', 'total_species_depth']]
)

In [None]:
# Hand-curated metagenome IDs that are removed from `mgen` before the metadata
# table is built further down. Presumably chosen from the duplicate listing in
# the previous cell so each (subject, sample type) pair is unique -- confirm.
drop_mgen_id_list = ['SS01008', 'SS01093c', 'SS01013', 'SS01117', 'SS01120', 'SS01126', 'SS01185']

In [None]:
# Seed NumPy's global RNG. Presumably this makes the random position subsample
# below reproducible (depends on how sfacts draws it -- confirm).
np.random.seed(0)

# Load the fitted sfacts world for the focal species, then apply
# collapse_similar_strains (thresh=0.01) and drop_low_abundance_strains
# (thresh=0.05).
# Earlier variants tried by the author: loading from 'data/' instead of
# 'data_temp/', and threshold pairs (0.05, 0.01) and (0.1, 0.05).
_fit_path = f'data_temp/sp-{species_id}.ucfmt.a.r.proc.gtpro.filt-poly05-cvrg05.ss-g10000-block0-seed0.fit-sfacts11-s75-seed0.world.nc'
_fit = (
    sf.World.load(_fit_path)
    .collapse_similar_strains(thresh=0.01)
    .drop_low_abundance_strains(thresh=0.05)
)

# Random subset of at most 1000 positions, reused by the plotting cells below.
_n_position_ss = min(1000, len(_fit.position))
position_ss = _fit.random_sample(position=_n_position_ss).position

In [None]:
# Error of the filtered fit against the metagenotype, with discretized=True.
# The exact metric and return value are defined by sfacts
# (sf.evaluation.metagenotype_error2), not in this notebook.
sf.evaluation.metagenotype_error2(_fit, discretized=True)

In [None]:
# Community plot of the filtered fit on the position subsample; rows are
# clustered by genotype linkage and columns by metagenotype linkage.
sf.plot.plot_community(
    _fit.sel(position=position_ss),
    row_linkage_func=lambda world: world.genotype.linkage(),
    col_linkage_func=lambda world: world.metagenotype.linkage(),
)

In [None]:
# Spot-check one metagenome: its metagenotype frequency spectrum, then the
# five largest entries of its community vector (presumably strain fractions).
mgen_id = 'DS0097_035'
sf.plot.plot_metagenotype_frequency_spectrum(_fit, mgen_id, bins=51)
(
    _fit.community.data
    .sel(sample=mgen_id)
    .to_series()
    .sort_values(ascending=False)
    .head()
)

In [None]:
# Spot-check one metagenome: its metagenotype frequency spectrum, then the
# five largest entries of its community vector (presumably strain fractions).
mgen_id = 'DS0097_025'
sf.plot.plot_metagenotype_frequency_spectrum(_fit, mgen_id, bins=51)
(
    _fit.community.data
    .sel(sample=mgen_id)
    .to_series()
    .sort_values(ascending=False)
    .head()
)

In [None]:
# Spot-check one metagenome: its metagenotype frequency spectrum, then the
# five largest entries of its community vector (presumably strain fractions).
mgen_id = 'DS0044_003'
sf.plot.plot_metagenotype_frequency_spectrum(_fit, mgen_id, bins=51)
(
    _fit.community.data
    .sel(sample=mgen_id)
    .to_series()
    .sort_values(ascending=False)
    .head()
)

In [None]:
# Spot-check one metagenome: its metagenotype frequency spectrum, then the
# five largest entries of its community vector (presumably strain fractions).
mgen_id = 'DS0097_013'
sf.plot.plot_metagenotype_frequency_spectrum(_fit, mgen_id, bins=51)
(
    _fit.community.data
    .sel(sample=mgen_id)
    .to_series()
    .sort_values(ascending=False)
    .head()
)

In [None]:
# Spot-check one metagenome: its metagenotype frequency spectrum, then the
# five largest entries of its community vector (presumably strain fractions).
mgen_id = 'DS0097_034'
sf.plot.plot_metagenotype_frequency_spectrum(_fit, mgen_id, bins=51)
(
    _fit.community.data
    .sel(sample=mgen_id)
    .to_series()
    .sort_values(ascending=False)
    .head()
)

In [None]:
# Spot-check one metagenome: its metagenotype frequency spectrum, then the
# five largest entries of its community vector (presumably strain fractions).
mgen_id = 'SS01200'
sf.plot.plot_metagenotype_frequency_spectrum(_fit, mgen_id, bins=51)
(
    _fit.community.data
    .sel(sample=mgen_id)
    .to_series()
    .sort_values(ascending=False)
    .head()
)

In [None]:
# Metagenotype plot of the filtered fit on the position subsample, with
# columns clustered by metagenotype linkage.
sf.plot.plot_metagenotype(
    _fit.sel(position=position_ss),
    scaley=0.01,
    col_linkage_func=lambda world: world.metagenotype.linkage(),
)

In [None]:
# Dimension sizes of the filtered fit, displayed as a sanity check after the
# strain collapsing/dropping applied when it was loaded.
_fit.sizes

In [None]:
# Genotype plot of the filtered fit on the position subsample; columns are
# ordered by a metagenotype linkage computed over the 'position' dimension.
sf.plot.plot_genotype(
    _fit.sel(position=position_ss),
    col_linkage_func=lambda world: world.metagenotype.linkage('position'),
)

In [None]:
# Per-metagenome metadata: remove the hand-curated exclusions, attach the
# sample- and subject-level columns, then add depth summaries and a flag for
# whether the metagenome is present in the fitted world.
_meta_all = (
    mgen
    .drop(drop_mgen_id_list)
    .join(sample, on='sample_id')
    .join(subject, on='subject_id')
    .assign(
        total_species_depth=_species_depth.sum(1),
        species_depth=_species_depth[species_id],
        is_fit=lambda x: x.index.to_series().isin(_fit.sample.to_series()),
    )
)

# Human-readable unique label '<subject>.<sample type>.<n>', where n is a
# 0-based running counter within each (subject, sample type) pair.
# cumcount() returns a Series on the original index, so the assignment aligns
# row by row. The previous groupby().apply(df.assign(...)) form relied on
# apply NOT prepending the group keys to the index, which no longer holds
# under pandas 2.x (group_keys=True is honoured), and it also needed a slow
# row-wise apply to build the string.
_meta_all['fullname'] = (
    _meta_all.subject_id
    + '.' + _meta_all.sample_type_specific
    + '.' + _meta_all.groupby(['subject_id', 'sample_type_specific']).cumcount().astype(str)
)
_meta_all

In [None]:
# Metagenome IDs that are both in the metadata and in the fitted world.
_fit_mgen_ids = idxwhere(_meta_all.is_fit)

# Restrict the fit to those samples and relabel its sample coordinate with the
# readable fullname (same order as the selection), then re-wrap as a World.
_fit_relabelled = _fit.data.sel(sample=_fit_mgen_ids)
_fit_relabelled['sample'] = _meta_all.loc[_fit_mgen_ids].fullname.values
fit = sf.World(_fit_relabelled)

# Key the metadata and the depth matrix by fullname as well. Depth rows whose
# ID has no entry in the mapping keep their original label.
meta_all = _meta_all.set_index('fullname')
species_depth = _species_depth.rename(index=_meta_all.fullname)

# Bring both tables onto a common index (exact behavior defined by
# lib.plot.align_indexes).
meta_all, species_depth = lib.plot.align_indexes(meta_all, species_depth)

# Same log10 depth histogram as before, now on the aligned table.
pseudo = 1e-3
plt.hist(np.log10(pseudo + species_depth[species_id]), bins=50);

In [None]:
# Relative abundance: each sample's species depths divided by that sample's
# total depth (row-wise normalisation).
species_rabund = species_depth.div(species_depth.sum(axis=1), axis='index')

# log10 relative abundance of the focal species; the smaller pseudocount
# reflects the smaller scale of fractions compared with depths.
pseudo = 1e-6
plt.hist(np.log10(pseudo + species_rabund[species_id]), bins=50);

In [None]:
# Community plot again, now on the relabelled fit (samples named by fullname).
sf.plot.plot_community(
    fit.sel(position=position_ss),
    row_linkage_func=lambda world: world.genotype.linkage(),
    col_linkage_func=lambda world: world.metagenotype.linkage(),
)

In [None]:
# Metagenotype plot again, now on the relabelled fit.
sf.plot.plot_metagenotype(
    fit.sel(position=position_ss),
    scaley=0.01,
    col_linkage_func=lambda world: world.metagenotype.linkage(),
)

In [None]:
# Genotype plot again, now on the relabelled fit.
sf.plot.plot_genotype(
    fit.sel(position=position_ss),
    col_linkage_func=lambda world: world.metagenotype.linkage('position'),
)

In [None]:
# Column order for the per-subject heatmaps below: baseline, six
# pre-maintenance samples, then three follow-ups.
sample_type_specific_order = (
    ['baseline']
    + [f'pre_maintenance_{i}' for i in range(1, 7)]
    + [f'followup_{i}' for i in range(1, 4)]
)
sample_type_specific_order

In [None]:
# Keep only rows whose `recipient` flag is not 0, then pivot total depth into
# a subject x sample-type grid with columns in the canonical order.
m = meta_all.drop(index=idxwhere(meta_all.recipient == 0))
d = (
    m
    .set_index(['subject_id', 'sample_type_specific'])
    .total_species_depth
    .unstack()
    .loc[:, sample_type_specific_order]
)

sns.heatmap(d)

In [None]:
# Total metagenome depth
# Subject x sample-type grid of total depth on a symlog color scale; each cell
# is annotated with whether that metagenome is part of the fit.

m = meta_all.drop(index=idxwhere(meta_all.recipient == 0))
d = (
    m
    .set_index(['subject_id', 'sample_type_specific'])
    .total_species_depth
    .unstack()
    .loc[:, sample_type_specific_order]
)
# Combinations with no metagenome are annotated as not fit (False).
is_fit_annot = (
    m
    .set_index(['subject_id', 'sample_type_specific'])
    .is_fit
    .unstack(fill_value=False)
    .loc[:, sample_type_specific_order]
)

sns.heatmap(d, norm=mpl.colors.SymLogNorm(linthresh=1), annot=is_fit_annot)

In [None]:
# Species depth
# Same grid as the total-depth heatmap, but for the focal species' depth.

m = meta_all.drop(index=idxwhere(meta_all.recipient == 0))
d = (
    m
    .set_index(['subject_id', 'sample_type_specific'])
    .species_depth
    .unstack()
    .loc[:, sample_type_specific_order]
)
# Combinations with no metagenome are annotated as not fit (False).
is_fit_annot = (
    m
    .set_index(['subject_id', 'sample_type_specific'])
    .is_fit
    .unstack(fill_value=False)
    .loc[:, sample_type_specific_order]
)

sns.heatmap(d, norm=mpl.colors.SymLogNorm(linthresh=1), annot=is_fit_annot)

In [None]:
# log10 relative abundance of the focal species (pseudocount 1e-6), using
# matplotlib's default bin count. The return value of plt.hist is left as the
# cell output.
plt.hist(np.log10(species_rabund[species_id] + 1e-6))

In [None]:
# Samples for which the focal species has a depth value in the metadata.
_depth_known = idxwhere(meta_all.species_depth.notna())

# Scale each sample's community vector from the fit by the species' depth
# (resp. relative abundance) in that sample. Samples in `_depth_known` that
# are missing from the product are added as all-zero rows by the reindex.
strain_depth = (
    (fit.community.data * species_depth[species_id].to_xarray())
    .to_pandas()
    .reindex(_depth_known, fill_value=0)
)
strain_rabund = (
    (fit.community.data * species_rabund[species_id].to_xarray())
    .to_pandas()
    .reindex(_depth_known, fill_value=0)
)
strain_rabund

In [None]:
# Clustered heatmap of the strain relative-abundance table; yticklabels=1
# labels every row and the tall figure leaves room for those labels.
sns.clustermap(strain_rabund, figsize=(10, 15), yticklabels=1)

In [None]:
# Per-subject strain trajectories for a fixed set of subjects: one panel per
# subject showing strain relative abundance by sample type, plus a separate
# marker for the mean over that subject's donor samples.

color_list = ["#91322d", "#62ecb6", "#ed0e1c", "#c2dcb8", "#cf115d", "#399283", "#f37d21", "#5310f0", "#f1c039", "#5d4030", "#f8cac2", "#74aff3", "#aa7b1b", "#35618f", "#9dd84e", "#6538ac", "#5c922f", "#e033d3", "#61f22d", "#dd8eeb", "#0b5313", "#fd8992", "#20d8fd"]
# Strains ordered by mean relative abundance across samples, highest first.
strain_order = strain_rabund.mean().sort_values(ascending=False).index
# Strain -> color lookup built by the project helper (how colors are assigned,
# and what 'other' maps to when no `other=` is passed, is defined in lib.plot).
strain_palette = lib.plot.construct_ordered_palette_from_list(strain_order, colors=color_list)

# Relative-abundance threshold for calling a strain present; also the linear
# range of the symlog y-axis set after the loop.
thresh = 1e-4
# x position of each sample type; the donor marker sits at x=-1, left of baseline.
sample_type_x = pd.Series(dict(
    donor=-1,
    baseline=0,
    pre_maintenance_1=1,
    pre_maintenance_2=2,
    pre_maintenance_3=3,
    pre_maintenance_4=4,
    pre_maintenance_5=5,
    pre_maintenance_6=6,
    followup_1=7,
    followup_2=8,
    followup_3=9,
))


subject_order = ['S0001', 'S0027', 'S0053', 'S0059', 'S0063']
ncols = 5
# Enough rows to give every subject its own panel.
nrows = int(np.ceil(len(subject_order) / ncols))

fig, axs = plt.subplots(nrows, ncols, sharex=True, sharey=True, figsize=(3 * ncols, 2.5 * nrows))
# Force a 2-D (nrows, ncols) array so that axs[-1] below is always the bottom row.
axs = np.reshape(axs, (nrows, ncols))

for subject_id, ax in zip(subject_order, axs.flatten()):
    donor_id = subject.loc[subject_id].donor_subject_id

    # Strains whose mean relative abundance over the donor's samples exceeds thresh.
    donor_sample_list = idxwhere(meta_all.subject_id == donor_id)
    donor_mean_sample_rabund = strain_rabund.loc[donor_sample_list].mean()
    donor_strain_list = idxwhere((donor_mean_sample_rabund > thresh))

    # Strains above thresh in the subject's first baseline sample; empty when
    # the subject has no baseline sample (index[0] raises IndexError).
    try:
        baseline_sample = meta_all[lambda x: (x.sample_type_specific == 'baseline') & (x.subject_id == subject_id)].index[0]
    except IndexError:
        baseline_strain_list = []
    else:
        baseline_sample_rabund = strain_rabund.loc[baseline_sample]
        baseline_strain_list = idxwhere((baseline_sample_rabund > thresh))

    # Strains above 10 * thresh in at least 3 of the subject's and donor's
    # samples combined.
    subject_sample_list = idxwhere(meta_all.subject_id == subject_id)
    common_strains = idxwhere((strain_rabund.loc[subject_sample_list + donor_sample_list] > thresh * 10).sum() >= 3)
    # Union of the three criteria; these strains get their own line in the panel.
    focal_strains = list(set(donor_strain_list) | set(baseline_strain_list) | set(common_strains))

    # d0: the subject's samples, re-keyed by sample type.
    d0 = strain_rabund.loc[subject_sample_list].join(meta_all[['sample_type_specific']]).set_index('sample_type_specific')
    # d1: d0 plus one extra row labelled 'donor' holding the donor mean.
    d1 = pd.concat([d0, donor_mean_sample_rabund.to_frame('donor').T])
    # d2: focal strains only, all remaining strains summed into 'other', with
    # the x position attached (aligned on the sample-type index) and rows
    # sorted along x.
    d2 = d1[focal_strains].assign(other=d1.sum(1) - d1[focal_strains].sum(1)).assign(x=sample_type_x).sort_values('x')
    # d3: the subject's own samples; the donor is drawn separately below.
    d3 = d2.drop('donor')

    for strain in focal_strains + ['other']:
        ax.plot(d3['x'], d3[strain], c=strain_palette[strain], marker='o', alpha=0.7, lw=2)
        # Donor marker (larger, not connected to the line); skipped when no
        # donor samples were found for this subject.
        if donor_sample_list:
            ax.scatter(d2.loc['donor', 'x'], d2.loc['donor', strain], c=strain_palette[strain], marker='o', alpha=0.7, s=70)
    ax.set_title((subject_id, donor_id))

# Applied to the last axes used in the loop. The panels were created with
# sharey=True, so presumably scale and limits propagate to all of them -- confirm.
ax.set_yscale('symlog', linthresh=thresh)
ax.set_ylim(0, 1)

# Label the x ticks with sample-type names on the bottom row only.
for ax in axs[-1]:
    ax.set_xticks(sample_type_x.values)
    ax.set_xticklabels(sample_type_x.index, rotation=45, ha='right')

In [None]:
# Fixed list of 23 hex colors handed to the palette helper.
color_list = [
    "#91322d", "#62ecb6", "#ed0e1c", "#c2dcb8", "#cf115d", "#399283",
    "#f37d21", "#5310f0", "#f1c039", "#5d4030", "#f8cac2", "#74aff3",
    "#aa7b1b", "#35618f", "#9dd84e", "#6538ac", "#5c922f", "#e033d3",
    "#61f22d", "#dd8eeb", "#0b5313", "#fd8992", "#20d8fd",
]
# Strains ordered by mean relative abundance across samples, highest first.
strain_order = strain_rabund.mean().sort_values(ascending=False).index
strain_palette = lib.plot.construct_ordered_palette_from_list(strain_order, colors=color_list)

# Stand-alone legend: draw one empty scatter per palette entry so that only
# the legend handles appear.
for strain_id in strain_palette:
    plt.scatter([], [], label=strain_id, c=strain_palette[strain_id], marker='o')
plt.legend()

In [None]:
# Per-subject strain trajectories for every recipient with more than two
# samples (a fixed set of focal subjects first), saved to fig/ as PDF and PNG.
# Each panel shows strain relative abundance by sample type, plus a separate
# marker for the mean over that subject's donor samples.

color_list = ["#91322d", "#62ecb6", "#ed0e1c", "#c2dcb8", "#cf115d", "#399283", "#f37d21", "#5310f0", "#f1c039", "#5d4030", "#f8cac2", "#74aff3", "#aa7b1b", "#35618f", "#9dd84e", "#6538ac", "#5c922f", "#e033d3", "#61f22d", "#dd8eeb", "#0b5313", "#fd8992", "#20d8fd"]
# Strains ordered by mean relative abundance across samples, highest first.
strain_order = strain_rabund.mean().sort_values(ascending=False).index
# Strain -> color lookup built by the project helper; other='lightgrey' is
# passed through to it (presumably the color of the pooled 'other' series --
# confirm in lib.plot).
strain_palette = lib.plot.construct_ordered_palette_from_list(strain_order, colors=color_list, other='lightgrey')

# Relative-abundance threshold for calling a strain present; also the linear
# range of the symlog y-axis set after the loop.
thresh = 1e-4
# x position of each sample type; the donor marker sits at x=-2, set apart
# from the baseline at x=0.
sample_type_x = pd.Series(dict(
    donor=-2,
    baseline=0,
    pre_maintenance_1=1,
    pre_maintenance_2=2,
    pre_maintenance_3=3,
    pre_maintenance_4=4,
    pre_maintenance_5=5,
    pre_maintenance_6=6,
    followup_1=7,
    followup_2=8,
    followup_3=9,
))


# Subjects with a truthy `recipient` flag that have more than two rows in meta_all.
all_subjects = idxwhere(meta_all[meta_all.recipient.astype(bool)].subject_id.value_counts() > 2)
focal_subjects = ['S0001', 'S0027', 'S0053', 'S0059', 'S0063']
# Focal subjects first, then the remaining subjects without repeating any.
subject_order = focal_subjects + [s for s in all_subjects if s not in focal_subjects]

ncols = 5
# Enough rows to give every subject its own panel; trailing panels may stay empty.
nrows = int(np.ceil(len(subject_order) / ncols))

fig, axs = plt.subplots(nrows, ncols, sharex=True, sharey=True, figsize=(3 * ncols, 2.5 * nrows))
# Force a 2-D (nrows, ncols) array so that axs[-1] below is always the bottom row.
axs = np.reshape(axs, (nrows, ncols))

for subject_id, ax in zip(subject_order, axs.flatten()):
    donor_id = subject.loc[subject_id].donor_subject_id

    # Strains whose mean relative abundance over the donor's samples exceeds thresh.
    donor_sample_list = idxwhere(meta_all.subject_id == donor_id)
    donor_mean_sample_rabund = strain_rabund.loc[donor_sample_list].mean()
    donor_strain_list = idxwhere((donor_mean_sample_rabund > thresh))

    # Strains above thresh in the subject's first baseline sample; empty when
    # the subject has no baseline sample (index[0] raises IndexError).
    try:
        baseline_sample = meta_all[lambda x: (x.sample_type_specific == 'baseline') & (x.subject_id == subject_id)].index[0]
    except IndexError:
        baseline_strain_list = []
    else:
        baseline_sample_rabund = strain_rabund.loc[baseline_sample]
        baseline_strain_list = idxwhere((baseline_sample_rabund > thresh))

    # Strains above 10 * thresh in at least 3 of the subject's and donor's
    # samples combined.
    subject_sample_list = idxwhere(meta_all.subject_id == subject_id)
    common_strains = idxwhere((strain_rabund.loc[subject_sample_list + donor_sample_list] > thresh * 10).sum() >= 3)
    # Union of the three criteria; these strains get their own line in the panel.
    focal_strains = list(set(donor_strain_list) | set(baseline_strain_list) | set(common_strains))

    # d0: the subject's samples, re-keyed by sample type.
    d0 = strain_rabund.loc[subject_sample_list].join(meta_all[['sample_type_specific']]).set_index('sample_type_specific')
    # d1: d0 plus one extra row labelled 'donor' holding the donor mean.
    d1 = pd.concat([d0, donor_mean_sample_rabund.to_frame('donor').T])
    # d2: focal strains only, all remaining strains summed into 'other', with
    # the x position attached (aligned on the sample-type index) and rows
    # sorted along x.
    d2 = d1[focal_strains].assign(other=d1.sum(1) - d1[focal_strains].sum(1)).assign(x=sample_type_x).sort_values('x')
    # d3: the subject's own samples; the donor is drawn separately below.
    d3 = d2.drop('donor')

    # 'other' is plotted first, so the focal strains are drawn after it.
    for strain in ['other'] + focal_strains:
        ax.plot(d3['x'], d3[strain], c=strain_palette[strain], marker='o', alpha=0.7, lw=2)
        # Donor marker (larger, not connected to the line); skipped when no
        # donor samples were found for this subject.
        if donor_sample_list:
            ax.scatter(d2.loc['donor', 'x'], d2.loc['donor', strain], c=strain_palette[strain], marker='o', alpha=0.7, s=70)
    ax.set_title(f"{subject_id} ({donor_id})")

# Applied to the last axes used in the loop. The panels were created with
# sharey=True, so presumably scale and limits propagate to all of them -- confirm.
ax.set_yscale('symlog', linthresh=thresh)
ax.set_ylim(0, 1)

# Label the x ticks with sample-type names on the bottom row only.
for ax in axs[-1]:
    ax.set_xticks(sample_type_x.values)
    ax.set_xticklabels(sample_type_x.index, rotation=45, ha='right')
    
# Write the figure in vector (PDF) and raster (PNG, 400 dpi) form; the 'fig/'
# directory must already exist relative to the working directory.
fig.savefig(f'fig/ucfmt_engraftment_{species_id}.pdf', bbox_inches='tight')
fig.savefig(f'fig/ucfmt_engraftment_{species_id}.png', dpi=400, bbox_inches='tight')