In [None]:
import pandas as pd
import sqlite3
import numpy as np
import matplotlib.pyplot as plt
from itertools import product
from scipy.spatial.distance import cdist, pdist
import seaborn as sns

def idxwhere(x):
    """Return the index labels at which boolean Series ``x`` is True.

    Parameters
    ----------
    x : pandas.Series
        Boolean mask, e.g. the result of ``frame.col == value``.

    Returns
    -------
    pandas.Index
        The subset of ``x.index`` whose entries are True, in original order.
    """
    mask = x
    return mask.loc[mask].index

## Smith 2019 Data

In [None]:
# Open the Smith 2019 results database read-only. A plain sqlite3.connect()
# silently creates an empty file when the path is wrong, deferring the failure
# to a confusing "no such table" error in the first query; mode=ro fails here.
con2019 = sqlite3.connect('file:../longev/res/C2013.results.db?mode=ro', uri=True)

In [None]:
# Series mapping each 'unique' taxon_id to its 'otu-0.03' taxon_id (values are
# named 'taxon_id_b'). Select the column explicitly: .squeeze() would collapse a
# single-row result to a scalar.
unique_to_otu2019 = pd.read_sql(
    """
    SELECT taxon_id, taxon_id_b
    FROM taxonomy
    WHERE taxon_level = 'unique'
      AND taxon_level_b = 'otu-0.03'
    """,
    index_col=['taxon_id'],
    con=con2019,
)['taxon_id_b']
# Used as a grouping key over taxon_id labels, so each taxon must map to one OTU.
assert unique_to_otu2019.index.is_unique, 'some unique taxa map to more than one otu-0.03'

In [None]:
# Wide OTU taxonomy table: one row per 'otu-0.03' taxon_id, one column per
# rank (pivoted out of the taxon_level_b index level), restricted to the
# phylum..genus ranks.
otu_taxonomy2019 = pd.read_sql(
    """
    SELECT taxon_id, taxon_level_b, taxon_id_b FROM taxonomy
    WHERE taxon_level = 'otu-0.03'
    """,
    con=con2019,
    index_col=['taxon_id', 'taxon_level_b'],
).squeeze().unstack(level='taxon_level_b').loc[
    :, ['phylum', 'class', 'order', 'family', 'genus']
]

In [None]:
# Per-extraction tallies of each unique taxon, summed over sequencing libraries;
# extraction/taxon pairs with no row become 0.
unique_counts2019 = pd.read_sql(
    """
    SELECT extraction_id, taxon_id, SUM(tally) AS tally
    FROM rrs_library_taxon_count
    JOIN rrs_library USING (rrs_library_id)
    GROUP BY extraction_id, taxon_id
    """,
    index_col=['extraction_id', 'taxon_id'],
    con=con2019,
)['tally'].unstack(fill_value=0)

# Pool unique-taxon columns into their 0.03 OTU. groupby(axis='columns') is
# deprecated since pandas 2.1, so group the rows of the transposed frame.
# Taxa without an OTU mapping get a NaN key and are dropped, as before.
count2019 = unique_counts2019.T.groupby(unique_to_otu2019).sum().T

In [None]:
# First ten 2019 OTUs classified to the family Muribaculaceae.
otu_taxonomy2019.loc[otu_taxonomy2019['family'].eq('Muribaculaceae')].head(10)

## Smith 2020 Data

In [None]:
# Open the Smith 2020 database read-only. A plain sqlite3.connect() silently
# creates an empty file when the path is wrong; mode=ro raises immediately.
con2020 = sqlite3.connect('file:data/core.muri2.2.denorm.db?mode=ro', uri=True)

In [None]:
# 2020 count matrix: rows indexed by extraction_id, columns by otu_id. Pairs
# missing from rrs_taxon_count appear as NaN after unstacking and are then
# zero-filled, so the counts are float64 here.
count2020 = pd.read_sql(
    """
        SELECT extraction_id, otu_id, SUM(tally) AS tally
        FROM rrs_taxon_count
        GROUP BY extraction_id, otu_id
        """,
    index_col=['extraction_id', 'otu_id'],
    con=con2020,
).squeeze().unstack().fillna(0)

In [None]:
# 2020 taxonomy indexed by otu_id, one column per rank (the trailing underscores
# are part of the rrs_taxonomy column names).
otu_taxonomy2020 = pd.read_sql(
    """
    SELECT DISTINCT otu_id, domain_, phylum_, class_, order_, family_, genus_ FROM rrs_taxonomy
    """,
    index_col='otu_id',
    con=con2020,
)
# DISTINCT only collapses identical rows; an OTU with conflicting assignments
# would produce duplicate index labels, so label lookups on this table would
# return several rows for one OTU.
assert otu_taxonomy2020.index.is_unique, 'rrs_taxonomy has conflicting taxonomy rows for some otu_id'

In [None]:
# First ten 2020 OTUs classified to the family Muribaculaceae.
otu_taxonomy2020.loc[otu_taxonomy2020['family_'].eq('Muribaculaceae')].head(10)

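## Matching

Match each 2020 Muribaculaceae OTU to the 2019 OTU whose count profile across the shared extractions is most similar.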

In [None]:
# Size and a short preview of the 2019 extraction x OTU count matrix, rather
# than rendering the whole table.
display(count2019.shape, count2019.head())

In [None]:
# The matching compares both datasets over the 2019 extractions via
# count2020.loc[count2019.index]; pandas raises a bare KeyError if any label is
# missing, so check it explicitly with a readable message.
missing_from_2020 = count2019.index.difference(count2020.index)
assert missing_from_2020.empty, f'2019 extractions absent from 2020 data: {list(missing_from_2020)}'
count2020.loc[count2019.index].head()

In [None]:
# Muribaculaceae OTUs in each dataset. For 2020, keep only OTUs with at least
# one read in the extractions shared with 2019: an all-zero profile carries no
# matching signal and has an undefined correlation distance.
muri_otus2019 = idxwhere(otu_taxonomy2019.loc[count2019.columns].family == 'Muribaculaceae')
muri_otus2020 = idxwhere(
    (otu_taxonomy2020.loc[count2020.columns].family_ == 'Muribaculaceae')
    # Parenthesized on purpose: `&` binds tighter than `>`, so the bare form
    # parsed as `(is_muri & total) > 0`. That only worked because pandas casts
    # float totals to bool; with integer counts it is a bitwise AND (odd totals only).
    & (count2020.loc[count2019.index].sum() > 0)
)

# fig, axs = plt.subplots(nrows=8, ncols=8, figsize=(15, 15))

# for otu2019, row in zip(muri_otus2019, axs):
#     for otu2020, ax in zip(muri_otus2020, row):
#         ax.scatter(count2019[otu2019], count2020.loc[count2019.index, otu2020])

In [None]:
def _muri_cross_distance(metric):
    """Distance between every 2019 and every 2020 Muribaculaceae OTU.

    Each OTU is represented by its vector of counts over the 2019 extractions
    (``count2019.index``), so both datasets are compared sample-by-sample.

    Parameters
    ----------
    metric : str
        Any metric name accepted by ``scipy.spatial.distance.cdist``.

    Returns
    -------
    pandas.DataFrame
        Rows are ``muri_otus2019`` (axis 'otus2019'), columns are
        ``muri_otus2020`` (axis 'otus2020').
    """
    profiles2019 = count2019.loc[:, muri_otus2019].T
    # Reindex 2020 rows to the 2019 extraction order so vector positions match.
    profiles2020 = count2020.loc[count2019.index, muri_otus2020].T
    return pd.DataFrame(
        cdist(profiles2019, profiles2020, metric=metric),
        index=muri_otus2019,
        columns=muri_otus2020,
    ).rename_axis(index='otus2019', columns='otus2020')


# 'correlation' is 1 - Pearson r (profile shape, scale-free);
# 'cityblock' is the L1 distance on raw counts (sensitive to magnitude).
dmat_corr = _muri_cross_distance('correlation')
dmat_cb = _muri_cross_distance('cityblock')

In [None]:
# For each 2020 Muribaculaceae OTU (a column of the distance matrices), the
# nearest 2019 OTU and its distance under each metric.
best_hit = pd.DataFrame({
    'corr_hit': dmat_corr.idxmin(),
    'corr': dmat_corr.min(),
    'cb_hit': dmat_cb.idxmin(),
    'cb': dmat_cb.min(),
    # Restricted to the Muri OTUs so the frame is not padded with NaN rows
    # for every other 2020 OTU.
    'total2020': count2020.loc[count2019.index, muri_otus2020].sum(),
})

# Attach the total 2019 count of each OTU's best correlation hit.
best_hit.join(count2019.sum().rename('total2019'), on='corr_hit').loc[muri_otus2020].head(20)