# Notebook Setup

In [None]:
import pandas as pd
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import statsmodels.formula.api as sm
import sqlite3
import seaborn as sns
import patsy
from sklearn.decomposition import PCA
from lifelines import KaplanMeierFitter
from matplotlib.ticker import StrMethodFormatter
from statsmodels.stats.multitest import fdrcorrection
import itertools

import matplotlib as mpl

import rpy2.ipython
%load_ext rpy2.ipython.rmagic

from scripts.lib.stats import raise_low, lrt_phreg, phreg_aic, mannwhitneyu
from scripts.lib.plotting import boxplot_with_points, load_style, residuals_plot
from skbio.diversity.alpha import chao1, simpson_e
from skbio.stats import subsample_counts
from skbio import DistanceMatrix
from skbio.stats.ordination import pcoa

def concat(list_of_lists):
    """Flatten one level of nesting: ``[[a, b], [c]] -> [a, b, c]``."""
    return list(itertools.chain.from_iterable(list_of_lists))


def richness(x):
    """Return how many entries of *x* are strictly greater than zero."""
    return (x > 0).sum()

def order_within(df, groupby, sortby, ascending=True):
    """Number the rows of *df* 0, 1, 2, ... within each group, in sort order.

    Rows are sorted by *sortby*. Each row then gets its position within its
    *groupby* group under that ordering.

    Parameters
    ----------
    df : pandas.DataFrame
    groupby : label or list of labels
        Passed to ``DataFrame.groupby``.
    sortby : label or list of labels
        Passed to ``DataFrame.sort_values``.
    ascending : bool or list of bool, default True
        Passed to ``DataFrame.sort_values``.

    Returns
    -------
    pandas.Series
        Named ``'order'`` and indexed by *df*'s index, so it can be assigned
        straight back as a column of *df*.
    """
    # cumcount numbers rows within each group in their current (sorted) order.
    # The earlier groupby().apply(pd.Series(...)) form returned a DataFrame
    # rather than a Series when there was only a single group. Its index shape
    # also depended on pandas' group_keys handling.
    return (df.sort_values(sortby, ascending=ascending)
              .groupby(groupby)
              .cumcount()
              .rename('order'))

In [None]:
loaded_style = load_style('paper')

# Bind the entries of the 'paper' style that later cells use to plain names.
(color_map, mark_map, assign_significance_symbol,
 savefig, fullwidth) = (loaded_style[key] for key in ('color_map',
                                                      'mark_map',
                                                      'assign_significance_symbol',
                                                      'savefig',
                                                      'fullwidth'))

In [None]:
from scripts.lib.data import load_data
loaded_data = load_data('res/C2013.results.db')

# Bind the loaded entries to plain names; `con` is the connection passed to
# pd.read_sql in the cells below.
con, conc, mols, mol_c_count, mouse = (
    loaded_data[name] for name in ('con', 'conc', 'mols', 'mol_c_count', 'mouse'))

In [None]:
# Read tallies of each unique sequence (taxon_id). The result has one row per
# mouse and one column per taxon. Only libraries whose spike_id is not NULL
# are included.
count = (pd.read_sql("""
    SELECT mouse_id, rrs_library_id, taxon_id, tally
    FROM rrs_library_taxon_count
    JOIN rrs_library_metadata USING (rrs_library_id)
    WHERE taxon_level = 'unique'
    AND spike_id NOT NULL
                    """, con=con,
                    index_col=['mouse_id', 'rrs_library_id', 'taxon_id'])
        # Reshape into wide format; taxa absent from a library get a tally of 0.
        ['tally'].unstack().fillna(0).astype(int)
        # Drop libraries without an associated mouse_id
        .reset_index().dropna(subset=['mouse_id']).set_index('mouse_id')
        .drop('rrs_library_id', axis='columns')
        )
# With rrs_library_id dropped, a mouse with more than one library would
# produce duplicate index entries.
assert count.index.is_unique
# Relative abundance: each row divided by its row total. This is the
# vectorised equivalent of count.apply(lambda x: x / x.sum(), axis='columns')
# and avoids a Python-level call per row.
rabund = count.div(count.sum(axis='columns'), axis='index')


In [None]:
# Assign each unique sequence to its 'otu-0.03' taxon (confidence > 0.7).
# Label it "<otu>_<rank>", where rank 0001 is that OTU's most abundant sequence.
taxonomy = pd.read_sql(
    """
    SELECT taxon_id AS seq_id, taxon_id_b AS otu
    FROM taxonomy
    WHERE taxon_level = 'unique'
      AND taxon_level_b = 'otu-0.03'
      AND confidence > 0.7
    """,
    con=con,
    index_col=['seq_id'],
)

# Mean relative abundance across mice. Sequences that are not columns of
# `rabund` come out as NaN and are discarded.
taxonomy['mean_abund'] = rabund.mean(axis='index')
taxonomy.dropna(subset=['mean_abund'], inplace=True)

# 0-based abundance rank within each OTU, most abundant first.
taxonomy['order'] = order_within(taxonomy, 'otu', 'mean_abund', ascending=False)
taxonomy['short_seq_id'] = (
    taxonomy['otu']
    + '_'
    + (taxonomy['order'] + 1).astype(str).str.pad(width=4, fillchar='0')
)

# Relative abundances with columns relabelled by short id; taxonomy re-keyed to match.
rabund_rn = rabund.rename(columns=taxonomy['short_seq_id'])
taxonomy = taxonomy.reset_index().set_index('short_seq_id')

# Explore Common Sequences

In [None]:
# The six most abundant unique sequences belonging to Otu0001 or Otu0004.
(taxonomy.loc[taxonomy['otu'].isin(['Otu0001', 'Otu0004'])]
         .sort_values('mean_abund', ascending=False)
         .head(6))

In [None]:
# The 20 unique sequences with the highest mean relative abundance.
taxonomy.sort_values(by='mean_abund', ascending=False).iloc[:20]

# Phylotype Analysis (Partial)

In [None]:
# Total reads of each unique sequence assigned to Otu0001 or Otu0004, summed
# over all rows of rrs_library_taxon_count (no spike_id filter, unlike
# `count`). Most abundant first.
# taxon_level_b is restricted to 'otu-0.03', matching the `taxonomy` query.
# Without it, a sequence also listed under another OTU level with the same
# OTU name would be joined, and its tally summed, more than once.
# Grouping by both ids keeps taxon_id_b from being a bare (ungrouped) column.
subseq_counts = pd.read_sql("""
    SELECT taxon_id, taxon_id_b, SUM(tally) AS total
    FROM rrs_library_taxon_count
    JOIN taxonomy USING (taxon_id, taxon_level)
    WHERE taxon_level = 'unique'
      AND taxon_level_b = 'otu-0.03'
      AND taxon_id_b IN ('Otu0001', 'Otu0004')
    GROUP BY taxon_id, taxon_id_b
    ORDER BY total DESC
    """, con=con)

In [None]:
# Sequences within Otu0001/Otu0004 with more than 1000 reads in total.
subseq_counts.loc[subseq_counts['total'] > 1000]

In [None]:
# Distribution of per-sequence read totals.
subseq_counts['total'].quantile(q=[0, 0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.95, 0.99, 1])

In [None]:
# Back-of-the-envelope check: how many reads of one *specific* single-base
# variant of the most common sequence would sequencing error alone produce?
# ASSUMPTIONS:
error_rate = 0.001      # Constant per-base error probability
subs_alphabet_size = 3  # Given an erroneous base, each of the other 3 bases is equally likely
seq_length =  240

# P(one particular substitution at one position, all other bases correct),
# evaluated in log space.
specific_error_rate = np.exp(np.log(1 - error_rate) * (seq_length - 1) +
                             np.log(error_rate / subs_alphabet_size) * 1)
# P(every one of the seq_length bases read correctly).
perfect_seq_rate = (1 - error_rate) ** seq_length

for heading, value in [
        ('Probability of seeing a particular erroneous sequence:', specific_error_rate),
        ('Probability of a perfect sequence:', perfect_seq_rate),
        ('Ratio of specific error seq to perfect seqs:', specific_error_rate / perfect_seq_rate),
]:
    print(heading)
    print(value)
    print()

print("^For every correct sequence we expect this many with a particular (single-position) error^")
print()

# Scale that ratio by the largest per-sequence read total in subseq_counts.
max_count = subseq_counts.total.max()
specific_errors_expect = max_count * specific_error_rate / perfect_seq_rate
print(('So, in our library, where we recovered {} of the most common sequence,'
       ' we expect all ({} * {}) 1-base deviations from that sequence to have {:.1f} copies')
      .format(max_count, seq_length, subs_alphabet_size, specific_errors_expect))
print()
# The normal approximation to a Poisson is only reasonable when its mean is large.
print(('If the distribution were Poisson, and therefore nearly normal'
       ' it would have a mean of {0:.1f} and a stdev of sqrt({0:.1f})={1:.1f}')
      .format(specific_errors_expect, np.sqrt(specific_errors_expect)))

# Generate Supplemental Figure

In [None]:
# Supplemental figure, one panel per OTU in `otus`. Each mouse gets a stacked
# bar: the relative abundance of that OTU's `take_top_k_seqs` most abundant
# unique sequences, plus all remaining sequences pooled as "other".
# Mice are grouped by site and treatment.
data = mouse.join(rabund_rn, how='inner')
data['total_otu1'] = data[list(taxonomy[taxonomy.otu.isin(['Otu0001'])].index)].sum(axis='columns')
# Within each site, control mice come before acarbose mice ('control' sorts
# after 'acarbose', hence the descending sort). Within a group, mice are
# ordered by total Otu0001 abundance.
data = data.sort_values(['site', 'treatment', 'total_otu1'],
                        ascending=[True, False, True])

# Each OTU needs a matching colormap. To add a panel, extend both lists
# together (e.g. 'Otu0005' with mpl.cm.coolwarm).
otus = ['Otu0001', 'Otu0004']
cmaps = [mpl.cm.PuOr, mpl.cm.PiYG]
assert len(cmaps) >= len(otus), 'need one colormap per OTU'
rename_otus = {'Otu0001_0001': 'OTU-1.1',
               'Otu0001_0002': 'OTU-1.2',
               'Otu0001_0003': 'OTU-1.3',
               'Otu0001_0004': 'OTU-1.4',
               'Otu0004_0001': 'OTU-4.1',
               'Otu0004_0002': 'OTU-4.2',
               'Otu0004_0003': 'OTU-4.3',
               'Otu0004_0004': 'OTU-4.4',
               'Otu0001_other': 'Other OTU-1',
               'Otu0004_other': 'Other OTU-4',
               }

take_top_k_seqs = 4

# squeeze=False keeps `axs` a 1-D array even when only one OTU is plotted.
fig, axs = plt.subplots(nrows=len(otus), figsize=(fullwidth, 1.75 * len(otus)),
                        sharex=True, sharey=True, squeeze=False)
axs = axs[:, 0]

legends = []
for otu, ax, cm in zip(otus, axs, cmaps):
    # One colour per top sequence from this OTU's colormap, plus grey for "other".
    # (np.vstack replaces np.row_stack, which NumPy 2.0 deprecates.)
    color = np.vstack([cm(np.linspace(0, 1, take_top_k_seqs)),
                       np.array([0.4, 0.4, 0.4, 1])])
    all_seqs = list(taxonomy[taxonomy.otu.isin([otu])]
                    .sort_values('mean_abund', ascending=False).index)
    top_seqs = all_seqs[:take_top_k_seqs]
    other_seqs = all_seqs[take_top_k_seqs:]
    # The top sequences are already columns of `data` (from the join above);
    # only the pooled columns need adding.
    data['{}_other'.format(otu)] = rabund_rn[other_seqs].sum(axis='columns')
    data['{}_total'.format(otu)] = rabund_rn[all_seqs].sum(axis='columns')

    cols = top_seqs + ['{}_other'.format(otu)]
    (data[cols].rename(columns=rename_otus)
         .plot.bar(stacked=True, width=1,
                   ax=ax, color=color))
    ax.set_xticklabels([])
    # Anchor the legend's upper-left corner just outside the axes' top-right,
    # rather than letting loc='best' choose a position relative to that point.
    legends.append(ax.legend(loc='upper left', bbox_to_anchor=(1.01, 1)))

    # Label each site/treatment group. Separate groups with vertical lines:
    # dashed between treatments, solid between sites.
    split_location = 0
    for site in ['TJL', 'UM', 'UT']:
        d0 = data[data.site == site]
        for treatment in ['control', 'acarbose']:
            d1 = d0[d0.treatment == treatment]
            ax.annotate('{} {}'.format(site, treatment),
                        xy=(split_location / len(data) + 0.01, 0.91),
                        xycoords='axes fraction', rotation=0,
                        fontsize=5)
            split_location += len(d1)
            ax.axvline(split_location - 0.5, color='k', linestyle='--')
        ax.axvline(split_location - 0.5, color='k')

annotations = []
for panel_index, ax in enumerate(axs):
    # Panel letters A, B, C, ... in top-to-bottom order.
    letter = ax.annotate(chr(ord('A') + panel_index), xy=(0, 1.03),
                         xycoords='axes fraction', fontweight='heavy')
    annotations.append(letter)
    ax.set_ylabel('')
    # Percent labels via a formatter, which stays in sync with the tick
    # locator. Fixed labels built from get_yticks() install a FixedFormatter
    # that matplotlib warns about and that can drift from the ticks.
    ax.yaxis.set_major_formatter(
        mpl.ticker.FuncFormatter(lambda y, _pos: '{:2.0f}%'.format(y * 100)))

axs[-1].set_xlabel('Mouse')

savefig(fig, 'fig/s247_phylotypes',
        bbox_inches='tight', bbox_extra_artists=legends + annotations)