## Preamble

In [None]:
import pandas as pd
from lib.util import info, idxwhere
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import scipy as sp

from scipy.spatial.distance import pdist, squareform

import pyro
import pyro.distributions as dist
import torch
from functools import partial
import arviz as az
from pyro.ops.contract import einsum
import seaborn as sns
from tqdm import tqdm
import xarray as xr

import sqlite3

def rss(x, y):
    """Return sqrt(sum((x - y)**2)), i.e. the Euclidean distance between x and y."""
    squared_error = (x - y) ** 2
    total = np.sum(squared_error)
    return np.sqrt(total)

def binary_entropy(p):
    """Return the binary (Bernoulli) entropy of probability ``p``, in bits.

    H(p) = -p*log2(p) - (1-p)*log2(1-p), using the standard convention
    0*log(0) = 0, so H(0) == H(1) == 0 rather than NaN.

    ``p`` may be a scalar, NumPy array, or a labelled array (pandas/xarray);
    ``xlogy`` is a ufunc, so those types pass through element-wise.
    Values outside [0, 1] yield NaN.
    """
    from scipy.special import xlogy

    q = 1 - p
    # xlogy(x, y) = x * log(y), defined as 0 when x == 0; divide by ln(2) for bits.
    return -(xlogy(p, p) + xlogy(q, q)) / np.log(2)

def plot_loss_history(loss_history):
    """Plot a loss trace shifted down by its minimum, on a log-scaled y-axis.

    loss_history: sequence of per-step loss values supporting ``len()``,
        ``.min()`` and elementwise subtraction (e.g. a NumPy array).

    Draws onto the current pyplot axes: the shifted trace, plus a grey dashed
    reference diagonal from (0, n) to (n, 0) with n = len(loss_history).
    The title shows the subtracted minimum. Returns the current Axes.
    """
    n_steps = len(loss_history)
    floor = loss_history.min()
    plt.plot(loss_history - floor)
    # Reference diagonal: falls from y=n_steps at x=0 to y=0 at x=n_steps.
    ref_x = np.linspace(0, n_steps, num=1000)
    ref_y = np.linspace(n_steps, 0, num=1000)
    plt.plot(ref_x, ref_y, lw=1, linestyle='--', color='grey')
    plt.title(f'+{floor:0.3e}')
    plt.yscale('log')
    return plt.gca()

def mean_residual_count(expect_frac, obs_count, m):
    """Return the depth-weighted mean absolute residual of observed vs expected fractions.

    Computes sum(|obs_count / m - expect_frac| * m) / sum(m). Entries whose
    residual is NaN (e.g. 0/0 where m == 0) contribute 0.

    Assumes array inputs of matching shape whose residual supports
    boolean-mask assignment (e.g. NumPy arrays).
    """
    residual = np.abs(obs_count / m - expect_frac)
    residual[np.isnan(residual)] = 0
    weighted_total = (residual * m).sum()
    return weighted_total / m.sum()

def linear_distance(linear_index):
    """Return the pairwise city-block distance matrix of a 1-D coordinate Series.

    linear_index: pandas Series of coordinates, one per label.
    Returns a square DataFrame, indexed and columned by the Series' index,
    whose (i, j) entry is |x_i - x_j|.
    """
    coords = linear_index.to_frame()
    labels = coords.index
    condensed = pdist(coords, metric='cityblock')
    return pd.DataFrame(squareform(condensed), index=labels, columns=labels)

# Render all figures at 120 dpi.
mpl.rcParams.update({'figure.dpi': 120})

## Metadata

In [None]:
# Open the project's SQLite metadata database (hard-coded absolute path).
con = sqlite3.connect('/pollard/home/bsmith/Projects/ucfmt/sdata/database.db')


# Library -> sample mapping from the `mgen_library` table, indexed by
# library_id (renamed from mgen_library_id).
library = (
    pd.read_sql(
        'SELECT mgen_library_id, sample_id FROM mgen_library',
        index_col='mgen_library_id',
        con=con
    )
    .rename_axis(index='library_id')
)
# All sample metadata, indexed by sample_id.
sample = pd.read_sql('SELECT * FROM sample', index_col='sample_id', con=con)
sample['collection_date'] = pd.to_datetime(sample['collection_date'])

# Human-readable sample names: '<subject_id>.<sample_type_specific>.<i>'.
# `name_i` numbers the samples sharing the same '<subject_id>.<sample_type_specific>'
# (as strings '0', '1', ...) in their row order within `sample`; the groupby/apply
# result is indexed by sample_id so it aligns back onto `sample`.
# NOTE(review): `sample.groupby('name').cumcount()` looks equivalent and simpler — confirm.
sample['name'] = sample.subject_id + '.' + sample.sample_type_specific
sample['name_i'] = (sample.reset_index()
                          .groupby('name')
                          .apply(lambda x: pd.Series(range(len(x)),
                                                     index=x.sample_id,
                                                     dtype=str))
                          .droplevel(level='name', axis='index'))
sample['name'] = sample['name'] + '.' + sample['name_i'].astype(str)

# Subject-level metadata, indexed by subject_id.
subject = pd.read_sql("SELECT * FROM subject", index_col='subject_id', con=con)

# Visit-level metadata, indexed by visit_id.
visit = pd.read_sql("SELECT * FROM visit", con=con, index_col=['visit_id'])
# Visits paired with the sample named in their donor_sample_id (the inner join
# drops visits with no matching sample). The sample's own subject_id comes
# through as `subject_id_donor` via rsuffix. Result is indexed by donor_sample_id.
donor_sample = (visit[['subject_id', 'donor_sample_id', 'visit_type']]
                     .join(sample[['sample_type', 'sample_notes', 'subject_id']],
                           on='donor_sample_id', how='inner', rsuffix='_donor')
                     .set_index('donor_sample_id'))
visit['date'] = pd.to_datetime(visit['date'])
# Full Mayo score = sum of the four sub-scores; skipna=False leaves it NaN if
# any sub-score is missing. The partial score sums only the first two.
visit['mayo_score'] = visit[['status_mayo_score_stool_frequency',
                             'status_mayo_score_rectal_bleeding',
                             'status_mayo_score_endoscopy_mucosa',
                             'status_mayo_score_global_physician_rating']
                           ].sum(1, skipna=False)
visit['mayo_partial_score'] = visit[['status_mayo_score_stool_frequency',
                                     'status_mayo_score_rectal_bleeding']
                                   ].sum(1, skipna=False)
subject['recipient'] = subject['recipient'].astype(bool)
# Mayo score at each subject's colonoscopy_1 / colonoscopy_2 visit, aligned on
# subject_id (subjects without that visit get NaN).
# NOTE(review): assumes at most one such visit per subject; duplicate
# subject_ids would make these column assignments fail — confirm.
subject['mayo_score_1'] = (visit[visit.visit_type_specific == 'colonoscopy_1']
                                    .set_index('subject_id')['mayo_score'])
subject['mayo_score_2'] = (visit[visit.visit_type_specific == 'colonoscopy_2']
                                  .set_index('subject_id')['mayo_score'])
subject['mayo_score_change'] = (subject.mayo_score_2 - subject.mayo_score_1)
# Readable labels; codes other than 1/0 map to NaN.
subject['_antibiotics'] = subject.treatment_abx_pre.map({1: 'abx+', 0: 'abx-'})
# Treatment arm label: '<abx+|abx->/<treatment_maintenance_method>'.
subject['arm'] = subject._antibiotics + '/' + subject.treatment_maintenance_method
subject['_responder'] = subject.responder_status.map({1: 'responder', 0: 'nonresponder'})
# Assign donors as their own donor_subject_id
subject.loc[~subject.recipient, 'donor_subject_id'] = subject.loc[~subject.recipient].index
subject['date_of_initial_fmt'] = pd.to_datetime(subject['date_of_initial_fmt'])
# Sample id of each subject's "baseline" sample, keyed by subject_id.
# NOTE(review): the (left) join with `subject` does not appear to change the
# selected sample_id values — presumably vestigial; confirm.
subject['baseline_sample_id'] = (sample.query('sample_type == "baseline"')
                                       .join(subject, on='subject_id')
                                       .reset_index()
                                       .set_index('subject_id')
                                       .sample_id)

# Whole days from the subject's initial FMT date to each visit date (row-wise).
visit['days_post_fmt'] = (visit.join(subject, on='subject_id')
                               .apply(lambda x: (x.date - x.date_of_initial_fmt).days, axis=1))

# Add one synthetic `sample` row per donor, indexed '<donor_id>_mean', as the
# metadata placeholder for that donor's averaged profile.
for donor_id in subject['donor_subject_id'].dropna().unique():
    mean_sample_id = donor_id + '_mean'
    sample.loc[mean_sample_id, 'subject_id'] = donor_id
    sample.loc[mean_sample_id, 'sample_type'] = 'donor'
    sample.loc[mean_sample_id, 'sample_type_specific'] = 'donor_mean'
    sample.loc[mean_sample_id, 'name'] = mean_sample_id

# One row per library, carrying its sample's and subject's metadata, ordered by
# donor group, then subject, then collection date.
meta = (
    library
    .join(sample, on='sample_id')
    .join(subject, on='subject_id')
    .sort_values(['donor_subject_id', 'subject_id', 'collection_date'])
)
meta['days_post_fmt'] = (meta['collection_date'] - meta['date_of_initial_fmt']).dt.days

# Every library, and every sample, should appear exactly once.
assert meta.index.is_unique
assert meta['sample_id'].is_unique

## Load data

In [None]:
# Load the saved model fit for one species. The filename suffix presumably
# encodes the fit's settings.
species_id = 102506
fit = xr.open_dataset(
    f'data/ucfmt.sp-{species_id}.gtpro-pileup.filt.sfacts-s50-g5000-gamma1-rho1-pi1-eps1000-alph100.nc'
)
fit

## Check fit

In [None]:
# Convergence check: ELBO trace relative to its minimum.
plot_loss_history(fit['elbo_trace'].values)

In [None]:
# Sum of |y - p_noerr * m| normalized by the total of m.
np.abs(fit['y'] - fit['p_noerr'] * fit['m']).sum() / fit['m'].sum()

In [None]:
# Distribution over libraries of each library's largest strain fraction.
plt.hist(x=fit['pi'].max(dim='strain'), bins=20)
#plt.yscale('log')
None

In [None]:
# Each strain's peak fraction over libraries, in decreasing order.
plt.plot(
    fit['pi']
    .max(dim='library_id')
    .to_series()
    .sort_values(ascending=False)
    .values
)
plt.axhline(0, lw=1, c='k', linestyle='--')

In [None]:
#fit.gamma.sel(strain=fit.pi.max('library_id'))

In [None]:
# Distribution over strains of each strain's peak fraction (log-scaled counts).
plt.hist(x=fit['pi'].max(dim='library_id'), bins=20)
plt.yscale('log')
None

In [None]:
# Per strain: how many libraries carry it at a fraction above 0.1.
plt.hist(x=(fit['pi'] > 0.1).sum(dim='library_id'), bins=50)
plt.yscale('log')

## UCFMT

In [None]:
# Libraries present in both the metadata and the fit.
ucfmt_libs = list(
    set(meta.index) & set(fit['library_id'].values)
)
ucfmt_meta = meta.loc[ucfmt_libs]
# Strains whose summed fraction over those libraries exceeds 0.1.
ucfmt_strains = idxwhere(
    (fit['pi'].sel(library_id=ucfmt_libs).sum(dim='library_id') > 1e-1).to_series()
)

sns.clustermap(
    fit['pi'].sel(library_id=ucfmt_libs, strain=ucfmt_strains).to_pandas(),
    metric='cosine',
    xticklabels=1,
    yticklabels=1,
    norm=mpl.colors.PowerNorm(1 / 5),
)

In [None]:
# Genotypes mapped through 2*gamma - 1 (presumably [0, 1] -> [-1, 1]) so the
# diverging colormap centers on gamma = 0.5.
sns.clustermap(
    (2 * fit['gamma'] - 1).T.to_pandas(),
    metric='cosine',
    cmap='coolwarm',
    xticklabels=1,
)

In [None]:
# Strain fractions with libraries ordered by subject, then sample type.
d = fit['pi'].to_pandas().loc[
    ucfmt_meta.sort_values(['subject_id', 'sample_type_specific']).index
]

sns.heatmap(d, norm=mpl.colors.PowerNorm(1 / 2))

In [None]:
# Genotypes rescaled by 2*gamma - 1, transposed, rows sorted by label.
d = (2 * fit['gamma'].to_pandas() - 1).T.sort_index()

sns.heatmap(d, cmap='coolwarm', vmin=-1, vmax=1, cbar_kws={'ticks': [-1, 0, 1]})

In [None]:
# Observed fraction y/m rescaled by 2x - 1; columns sorted, then transposed.
d = (2 * (fit['y'] / fit['m']) - 1).to_pandas().sort_index(axis=1).T
fig = plt.figure(figsize=(3, 5))
sns.heatmap(d, cmap='coolwarm', vmin=-1, vmax=1, cbar_kws={'ticks': [-1, 0, 1]})

In [None]:
# Observed fraction y/m rescaled by 2x - 1, columns sorted, then transposed.
# NaN entries (e.g. 0/0 where m == 0) are set to 0 so cosine clustering works.
d = ((fit.y / fit.m) * 2 - 1).to_pandas().sort_index(axis=1).T.fillna(0)
# clustermap always creates its own figure, so the size is passed to it
# directly; a preceding plt.figure(figsize=...) only left an empty figure.
sns.clustermap(d, metric='cosine', cmap='coolwarm', vmin=-1, vmax=1, figsize=(3, 5))

In [None]:
from itertools import cycle

# Subjects to plot, one per panel (filled row-major into the 3x4 grid below).
subject_id_order = ['S0041', 'S0047', 'S0053', 'S0055',
                    'S0056', 'S0001', 'S0004', 'S0013',
                    'S0021', 'S0024', 'S0027', 'S0008']

# x-axis order of sample types; 'donor' (the donor-mean point) sits at x=0.
sample_type_specific_order = [
    'donor',
    'baseline',
    'pre_maintenance_1',
    'pre_maintenance_2',
    'pre_maintenance_3',
    'pre_maintenance_4',
    'pre_maintenance_5',
    'pre_maintenance_6',
    'followup_1',
    'followup_2',
]
sample_type_specific_order_idx = pd.Series({v: i for i, v in enumerate(sample_type_specific_order)})

# Strains ordered by mean fraction over all libraries; this fixes each strain's color.
ucfmt_strain_order = fit.pi.sel(strain=ucfmt_strains).mean('library_id').to_series().sort_values(ascending=False).index

color_cycle = ['blue', 'green',
               'red', 'cyan',
               'magenta', 'paleturquoise',
               'yellowgreen',
               'pink',
               'orange',
               'coral', 'purple', 'teal',
               'lime', 'gold',
               'turquoise', 'darkgreen', 'lavender',
               'tan', 'salmon', 'brown',
               ]
# Colors repeat if there are more strains than colors.
strain_cmap = dict(zip(ucfmt_strain_order, cycle(color_cycle)))


# Strain fractions: UCFMT libraries as rows, selected strains as columns.
d0 = fit.pi.sel(library_id=ucfmt_libs, strain=ucfmt_strains).to_pandas()
# Subject of each d0 row, aligned explicitly so boolean masks match d0's index
# (avoids pandas silently reindexing a key built on the larger `meta` index).
d0_subject_id = meta.subject_id.reindex(d0.index)

fig, axs = plt.subplots(3, 4, sharex=True, sharey=True, figsize=(20, 10))
for subject_id, ax in zip(subject_id_order, axs.flatten()):
    donor_subject_id = subject.loc[subject_id].donor_subject_id
    ax.set_title(f'{subject_id} ({donor_subject_id})')

    # One row per sample type: this subject's samples plus the mean over the
    # donor's libraries, in plotting order. `sample_type_specific_order`
    # already begins with 'donor', so 'donor' is not prepended again (doing so
    # duplicated the donor row).
    d1 = pd.concat([
        d0[d0_subject_id == subject_id].rename(meta.sample_type_specific),
        d0[d0_subject_id == donor_subject_id].mean().to_frame('donor').T
    ]).reindex(sample_type_specific_order)
    # Attach x positions; drop sample types with no data (all-NaN rows).
    d2 = d1.assign(x=sample_type_specific_order_idx).dropna()
    # The donor point is drawn as a marker only, not joined to the series.
    # errors='ignore': the donor row is absent when the donor has no libraries in d0.
    d2_recipient = d2.drop('donor', errors='ignore')

    for strain in ucfmt_strain_order:
        if strain not in d1.columns:
            continue
        ax.plot(d2_recipient.x, d2_recipient[strain], marker='o', label=strain, c=strain_cmap[strain])
        ax.scatter(d2.x, d2[strain], marker='o', c=strain_cmap[strain])

plt.xticks(ticks=sample_type_specific_order_idx, labels=sample_type_specific_order, rotation=45, ha='right')
plt.legend(bbox_to_anchor=(1, 1))
# Matplotlib >= 3.3 names this `linthresh`; `linthreshy` was removed in 3.5.
plt.yscale('symlog', linthresh=1e-3)
plt.ylim(1e-4, 2)

None

In [None]:
from itertools import cycle

subject_id_order = ['S0041', 'S0047', 'S0053', 'S0055',
                    'S0056', 'S0001', 'S0004', 'S0013',
                    'S0021', 'S0024', 'S0027', 'S0008']

# x-axis order of sample types; 'donor' (the donor-mean point) sits at x=0.
sample_type_specific_order = [
    'donor',
    'baseline',
    'pre_maintenance_1',
    'pre_maintenance_2',
    'pre_maintenance_3',
    'pre_maintenance_4',
    'pre_maintenance_5',
    'pre_maintenance_6',
    'followup_1',
    'followup_2',
]
sample_type_specific_order_idx = pd.Series({v: i for i, v in enumerate(sample_type_specific_order)})

# Strains ordered by mean fraction over all libraries; this fixes each strain's color.
ucfmt_strain_order = fit.pi.sel(strain=ucfmt_strains).mean('library_id').to_series().sort_values(ascending=False).index

color_cycle = ['blue', 'green',
               'red', 'cyan',
               'magenta', 'paleturquoise',
               'yellowgreen',
               'pink',
               'orange',
               'coral', 'purple', 'teal',
               'lime', 'gold',
               'turquoise', 'darkgreen', 'lavender',
               'tan', 'salmon', 'brown',
               ]
# Colors repeat if there are more strains than colors.
strain_cmap = dict(zip(ucfmt_strain_order, cycle(color_cycle)))


# Strain fractions: UCFMT libraries as rows, selected strains as columns.
d0 = fit.pi.sel(library_id=ucfmt_libs, strain=ucfmt_strains).to_pandas()
# Subject of each d0 row, aligned explicitly so boolean masks match d0's index.
d0_subject_id = meta.subject_id.reindex(d0.index)

fig, axs = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(12, 2))
for subject_id, ax in zip(['S0008', 'S0056', 'S0047'], axs.flatten()):
    donor_subject_id = subject.loc[subject_id].donor_subject_id
    ax.set_title(f'{subject_id} ({donor_subject_id})')

    # One row per sample type: this subject's samples plus the donor mean, in
    # plotting order ('donor' is already first in sample_type_specific_order).
    d1 = pd.concat([
        d0[d0_subject_id == subject_id].rename(meta.sample_type_specific),
        d0[d0_subject_id == donor_subject_id].mean().to_frame('donor').T
    ]).reindex(sample_type_specific_order)
    # Attach x positions; drop sample types with no data (all-NaN rows).
    d2 = d1.assign(x=sample_type_specific_order_idx).dropna()
    # errors='ignore': the donor row is absent when the donor has no libraries in d0.
    d2_recipient = d2.drop('donor', errors='ignore')

    for strain in ucfmt_strain_order:
        if strain not in d1.columns:
            continue
        ax.plot(d2_recipient.x, d2_recipient[strain], marker='o', label=strain, c=strain_cmap[strain])
        ax.scatter(d2.x, d2[strain], marker='o', c=strain_cmap[strain])
    # One integer tick per sample-type slot; set after plotting, as before.
    ax.set_xticks(np.arange(0, len(sample_type_specific_order)))

#plt.xticks(ticks=sample_type_specific_order_idx, labels=sample_type_specific_order, rotation=45, ha='right')
plt.legend(bbox_to_anchor=(1, 1))
# Matplotlib >= 3.3 names this `linthresh`; `linthreshy` was removed in 3.5.
plt.yscale('symlog', linthresh=1e-3)
plt.ylim(1e-4, 2)

None

### D0097

In [None]:
# Libraries in the D0097 donor group (donor_subject_id == 'D0097') present in the fit.
d97_libs = list(
    set(idxwhere(meta['donor_subject_id'] == 'D0097')) & set(fit['library_id'].values)
)
# Strains whose summed fraction over those libraries exceeds 0.1.
d97_strains = idxwhere(
    (fit['pi'].sel(library_id=d97_libs).sum(dim='library_id') > 1e-1).to_series()
)

sns.clustermap(
    fit['pi'].sel(library_id=d97_libs, strain=d97_strains).to_pandas(),
    metric='cosine',
    xticklabels=1,
    yticklabels=1,
    norm=mpl.colors.PowerNorm(1 / 1),
)

In [None]:
# Only libraries whose subject_id is D0097, restricted to the D0097-group strains.
sns.clustermap(
    fit['pi'].sel(
        library_id=ucfmt_meta[ucfmt_meta.subject_id == 'D0097'].index,
        strain=d97_strains,
    ).to_pandas(),
    metric='cosine',
    xticklabels=1,
    yticklabels=1,
    norm=mpl.colors.PowerNorm(1 / 3),
)

In [None]:
# Genotypes (2*gamma - 1) of the D0097-group strains.
sns.clustermap(
    (2 * fit['gamma'] - 1)
    .sel(strain=d97_strains)
    .T
    .to_pandas()
    .sort_index(),
    metric='cosine',
    cmap='coolwarm',
    xticklabels=1,
    # row_cluster=False,
)

In [None]:
from scripts.strain_facts import genotype_distance

# Genotype distance between two hard-coded strain ids.
genotype_distance(fit['gamma'].sel(strain=18), fit['gamma'].sel(strain=9))

### D0044

In [None]:
# Libraries in the D0044 donor group (donor_subject_id == 'D0044') present in the fit.
d44_libs = list(
    set(idxwhere(meta['donor_subject_id'] == 'D0044')) & set(fit['library_id'].values)
)
# Strains whose summed fraction over those libraries exceeds 0.1.
d44_strains = idxwhere(
    (fit['pi'].sel(library_id=d44_libs).sum(dim='library_id') > 1e-1).to_series()
)

sns.clustermap(
    fit['pi'].sel(library_id=d44_libs, strain=d44_strains).to_pandas(),
    metric='cosine',
    xticklabels=1,
    yticklabels=1,
    norm=mpl.colors.PowerNorm(1 / 1),
)

In [None]:
# Only libraries whose subject_id is D0044, restricted to the D0044-group strains.
sns.clustermap(
    fit['pi'].sel(
        library_id=ucfmt_meta[ucfmt_meta.subject_id == 'D0044'].index,
        strain=d44_strains,
    ).to_pandas(),
    metric='cosine',
    xticklabels=1,
    yticklabels=1,
    norm=mpl.colors.PowerNorm(1 / 2),
)

In [None]:
# Genotypes (2*gamma - 1) of the D0044-group strains.
sns.clustermap(
    (2 * fit['gamma'] - 1)
    .sel(strain=d44_strains)
    .T
    .to_pandas()
    .sort_index(),
    metric='cosine',
    cmap='coolwarm',
    xticklabels=1,
    # row_cluster=False,
)

## Strain Entropy

In [None]:
# Per strain: each library's mean depth m over positions, weighted by the
# strain's fraction pi in that library, summed over libraries.
strain_total_cvrg = (fit['pi'].T * fit['m'].mean(dim='position')).sum(dim='library_id')
# Per strain: binary entropy of its genotype, averaged over positions.
strain_entropy = binary_entropy(fit['gamma']).mean(dim='position')

plt.scatter(strain_total_cvrg, strain_entropy, s=5)
plt.xlabel('estimated-total-coverage')
plt.ylabel('strain-entropy')

#plt.xscale('symlog', linthresh=1e2)
#plt.xlim(-10, 1000)
plt.xscale('log')

In [None]:
# Distribution of per-strain mean genotype entropy.
plt.hist(x=strain_entropy)
None