In [None]:
# Load IPython's autoreload extension in manual mode: `%autoreload 0` turns automatic
# reloading off, so edited modules (e.g. `lib`, `sfacts`) are only re-imported when
# `%autoreload` is run explicitly.
%load_ext autoreload
%autoreload 0

In [None]:
# Reload all (non-excluded) modules right now; with autoreload in mode 0 nothing is
# reloaded automatically, so re-run this cell after editing project code.
%autoreload

In [None]:
import xarray as xr
import sqlite3
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import warnings
import torch
import pyro
import scipy as sp

import lib.plot
from scipy.spatial.distance import pdist, squareform
from sklearn.metrics.pairwise import pairwise_distances
from sklearn.cluster import AgglomerativeClustering
from lib.pandas_util import idxwhere


import sfacts as sf

# from lib.project_style import color_palette, major_allele_frequency_bins
# from lib.project_data import metagenotype_db_to_xarray
# from lib.plot import ordination_plot, mds_ordination, nmds_ordination
# import lib.plot
# from lib.plot import construct_ordered_pallete
# from lib.pandas_util import idxwhere

## UCFMT Strain Tracking

In [None]:
# Fitted sfacts "world" for sp-100022; the filename encodes the filtering and model
# settings. An alternative fit with different settings, kept for reference:
#   'data/ucfmt.sp-100022.metagenotype.filt-poly05-cvrg15-g2000.fit-sfacts8-s100-seed0.world.nc'
fit = sf.data.World.load(
    'data/ucfmt.sp-100022.metagenotype.filt-poly05-cvrg05-g2000.fit-sfacts12-s100-g2000-seed0.world.nc'
)


In [None]:
# Community plot; columns are ordered by a linkage computed from the metagenotypes
# over the "sample" dimension.
sf.plot.plot_community(
    fit,
    col_linkage_func=lambda world: world.metagenotypes.linkage("sample"),
)

In [None]:
# Genotype plot (transposed); columns are ordered by a linkage computed from the
# metagenotypes over the "position" dimension.
sf.plot.plot_genotype(
    fit,
    col_linkage_func=lambda world: world.metagenotypes.linkage("position"),
    transpose=True,
)

In [None]:
# Display order of `sample_type_specific` values (used as the x-axis of the
# trajectory plots below).
sample_type_specific_order = [
    'baseline',
    'pre_maintenance_1', 'pre_maintenance_2', 'pre_maintenance_3',
    'pre_maintenance_4', 'pre_maintenance_5', 'pre_maintenance_6',
    'followup_1', 'followup_2',
]

# Library (mgen) metadata joined to its sample's metadata, then restricted to and
# ordered like `fit.sample` (whose labels are therefore mgen ids).
sample = pd.read_table('meta/sample.tsv', index_col='sample_id')
library = pd.read_table('meta/mgen.tsv', index_col='mgen_id')
meta = (
    library
    .join(sample, on='sample_id', lsuffix='_mgen', rsuffix='_sample')
    .loc[fit.sample]
)

# Keep only samples that have at least one library in the fit.
sample = sample.loc[meta.sample_id.unique()]

In [None]:
# Libraries (labels of `meta`'s index, i.e. mgen ids) whose sample_id is shared with
# at least one other library. `keep=False` marks every member of a duplicate group.
# `idxwhere` is a project helper; presumably it returns the index labels where the
# mask is True — confirm in lib.pandas_util.
duplicate_samples = idxwhere(meta.sample_id.duplicated(keep=False))

# NOTE(review): the earlier plan here was to drop one library of each pair (the 'b'
# variant), but the `rabund` cell further down averages all libraries per sample_id
# instead. Reconcile the two approaches.
duplicate_samples

In [None]:
# Removed a stray `sf.data.Metagenotypes.to` expression: an unfinished attribute
# lookup left over from interactive exploration. Nothing downstream used it, and it
# would raise AttributeError if no such attribute exists.

In [None]:
# Side-by-side look at the libraries that share a sample_id. Reuse the selection
# computed in the `duplicate_samples` cell rather than recomputing the same mask, so
# the plots always match the list displayed there.
d = fit.sel(sample=duplicate_samples)

sf.plot.plot_metagenotype(d, scaley=2e-3)
sf.plot.plot_community(d)

In [None]:
# Display the fitted model's `mu` variable for the libraries that share a sample_id.
# NOTE(review): the meaning of `mu` is defined by the sfacts model — confirm there.
fit.data.mu.sel(sample=duplicate_samples)

In [None]:
# Subject and sample-type columns for each sample retained in the fit.
sample.loc[:, ['subject_id', 'sample_type']]

In [None]:
# Community table flattened to a frame (innermost index level unstacked to columns,
# presumably strains), then collapsed to one row per sample_id: libraries that share
# a sample_id are averaged together.
rabund = (
    fit.communities
    .to_series()
    .unstack()
    .groupby(meta.sample_id)
    .mean()
)

In [None]:
# Strain abundances at the 'followup_1' time point for a single subject.
# This cell previously read `d1`, which only exists after the trajectory-plot loop
# further down has run (it held the last subject plotted), so it raised NameError on
# Restart & Run All. It now rebuilds the same table from `rabund`, `sample` and
# `sample_type_specific_order`, all defined above.
inspect_subject_id = 'S0053'  # last subject of the first plot's list; TODO confirm intent
(
    rabund
    .join(sample[['subject_id', 'sample_type_specific']])
    .groupby(['subject_id', 'sample_type_specific'])
    .mean()
    .xs(inspect_subject_id)
    .reindex(sample_type_specific_order)
    .T['followup_1']
)

In [None]:
# Distinct subjects among the samples retained in the fit.
sample['subject_id'].unique()

In [None]:
# Mean strain abundance per (subject, sample type): libraries were already averaged
# per sample_id in `rabund`; here samples of the same subject/time point are averaged.
d0 = (
    rabund
    .join(sample[['subject_id', 'sample_type_specific']])
    .groupby(['subject_id', 'sample_type_specific'])
    .mean()
)

# One fixed color per strain column so a strain has the same color in every panel.
strain_color_palette = lib.plot.construct_ordered_pallete(d0.columns, cm='tab20_r')

# Hand-picked subset; the full subject panel is drawn in a later cell.
subject_id_list = ['S0001', 'S0056', 'S0053']

ncol = 3
nrow = int(np.ceil(len(subject_id_list) / ncol))

# squeeze=False keeps `axs` 2-D for any grid shape, so `.flat` always works.
fig, axs = plt.subplots(nrow, ncol, figsize=(6 * ncol, 4 * nrow), squeeze=False)

# Decade ticks 0.1% .. 100%. Labelled via `100 * y:g` so 1e-3 reads "0.1%" (the
# previous `.0%` format rendered it as "0%").
yticks = 10. ** np.arange(-3, 1)
yticklabels = [f'{100 * y:g}%' for y in yticks]

for subject_id, ax in zip(subject_id_list, axs.flat):
    # Rows follow the study order; time points missing for this subject become NaN
    # after reindexing and are left as gaps in the lines.
    d1 = d0.xs(subject_id).reindex(sample_type_specific_order)
    for strain in d1.columns:
        ax.plot(d1[strain].values, c=strain_color_palette[strain], lw=2)
    ax.set_yscale('symlog', linthresh=1e-2, linscale=0.1)
    ax.set_xticks(range(len(sample_type_specific_order)))
    ax.set_xticklabels(sample_type_specific_order, rotation=45, ha='right')
    ax.set_yticks(yticks)
    ax.set_yticklabels(yticklabels)
    # Limits are set AFTER the ticks: set_yticks expands the view to include every
    # tick, which previously overrode a bottom limit of 1e-2 - 1e-3 and made the plot
    # actually span 1e-3 .. 1.1. That rendered range is now stated explicitly.
    ax.set_ylim(1e-3, 1e0 + 1e-1)
    ax.set_ylabel('strain relative abundance')
    ax.set_title(subject_id)

# Hide grid cells left over when the subject count is not a multiple of ncol.
for ax in axs.flat[len(subject_id_list):]:
    ax.set_visible(False)

fig.tight_layout()

In [None]:
# Mean strain abundance per subject, averaged over all of that subject's samples.
rabund.groupby(sample['subject_id']).mean()

In [None]:
# Mean strain abundance per (subject, sample type): libraries were already averaged
# per sample_id in `rabund`; here samples of the same subject/time point are averaged.
d0 = (
    rabund
    .join(sample[['subject_id', 'sample_type_specific']])
    .groupby(['subject_id', 'sample_type_specific'])
    .mean()
)

# One fixed color per strain column so a strain has the same color in every panel.
strain_color_palette = lib.plot.construct_ordered_pallete(d0.columns, cm='tab20_r')

# Subjects in display order. The previous list repeated 'S0053' and 'S0056' at the
# end, drawing those subjects in two panels each; every subject now appears once.
subject_id_list = [
    'S0001', 'S0056', 'S0053',
    'S0004', 'S0013', 'S0008',
    'S0024', 'S0021', 'S0047',
]

ncol = 3
nrow = int(np.ceil(len(subject_id_list) / ncol))

# squeeze=False keeps `axs` 2-D for any grid shape, so `.flat` always works.
fig, axs = plt.subplots(nrow, ncol, figsize=(6 * ncol, 4 * nrow), squeeze=False)

# Decade ticks 0.1% .. 100%. Labelled via `100 * y:g` so 1e-3 reads "0.1%" (the
# previous `.0%` format rendered it as "0%").
yticks = 10. ** np.arange(-3, 1)
yticklabels = [f'{100 * y:g}%' for y in yticks]

for subject_id, ax in zip(subject_id_list, axs.flat):
    # Rows follow the study order; time points missing for this subject become NaN
    # after reindexing and are left as gaps in the lines.
    d1 = d0.xs(subject_id).reindex(sample_type_specific_order)
    for strain in d1.columns:
        ax.plot(d1[strain].values, c=strain_color_palette[strain], lw=2)
    ax.set_yscale('symlog', linthresh=1e-2, linscale=0.1)
    ax.set_xticks(range(len(sample_type_specific_order)))
    ax.set_xticklabels(sample_type_specific_order, rotation=45, ha='right')
    ax.set_yticks(yticks)
    ax.set_yticklabels(yticklabels)
    # Limits are set AFTER the ticks: set_yticks expands the view to include every
    # tick, which previously overrode a bottom limit of 1e-2 - 1e-3 and made the plot
    # actually span 1e-3 .. 1.1. That rendered range is now stated explicitly.
    ax.set_ylim(1e-3, 1e0 + 1e-1)
    ax.set_ylabel('strain relative abundance')
    ax.set_title(subject_id)

# Hide grid cells left over when the subject count is not a multiple of ncol.
for ax in axs.flat[len(subject_id_list):]:
    ax.set_visible(False)

fig.tight_layout()