## Preamble

In [None]:
%load_ext autoreload
# Mode 0: disable automatic reloading; modules are reloaded only on demand.
%autoreload 0

In [None]:
# Reload all modules now (e.g. after editing a local package).
%autoreload

In [None]:
import sys
# NOTE(review): hard-coded absolute path to a local StrainFacts checkout, presumably
# so that `import sfacts` below resolves; machine-specific — adjust on other hosts.
sys.path.append('/pollard/home/bsmith/Projects/haplo-benchmark/include/StrainFacts')

In [None]:
import xarray as xr
import sqlite3
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import warnings
import torch
import pyro
import scipy as sp

import lib.plot
from scipy.spatial.distance import pdist, squareform
from sklearn.metrics.pairwise import pairwise_distances
from sklearn.cluster import AgglomerativeClustering
from lib.pandas_util import idxwhere

import sfacts as sf

from tqdm import tqdm

# from lib.project_style import color_palette, major_allele_frequency_bins
# from lib.project_data import metagenotype_db_to_xarray
# from lib.plot import ordination_plot, mds_ordination, nmds_ordination
# import lib.plot
# from lib.plot import construct_ordered_pallete
# from lib.pandas_util import idxwhere

## Load Data

In [None]:
species_id = 102492

# Load the fitted StrainFacts world (communities, genotypes and metagenotypes are used below).
fit = sf.data.World.load(f'data/zshi.sp-{species_id}.metagenotype.filt-poly05-cvrg25.fit-sfacts44-s200-g5000-seed0.refit-sfacts41-g10000-seed0.world.nc')
# Integer positions; the reference below is subset to these positions.
fit.data['position'] = fit.data.position.astype(int)
print(fit.sizes)


# Strains whose maximum abundance over all samples does not exceed this are dropped.
cull_threshold = 0.05
# Factor converting raw pairwise distances to the scaled units used in plots and
# thresholds below; its derivation is not shown here.
distance_proportionality = 0.10687

# Keep only strains exceeding cull_threshold in at least one sample.
fit_communities = fit.communities.mlift('sel', strain=fit.communities.max("sample") > cull_threshold)
# Diagnostic: largest total abundance removed from any one sample by culling.
print((1 - fit_communities.sum("strain")).max())
# Renormalize so the retained strains' abundances sum to 1 in every sample.
fit_communities = sf.Communities(fit_communities.data / fit_communities.sum("strain"))
# Restrict genotypes to the retained strains.
fit_genotypes = fit.genotypes.mlift('sel', strain=fit_communities.strain)

# Rebuild the world from the filtered components.
fit = sf.World.from_combined(fit_communities, fit_genotypes, fit.metagenotypes)
print(fit.sizes)

In [None]:
# Reference genotypes for the same species, restricted to the positions kept in
# the fit and converted to estimated genotypes (pseudo=0).
ref = sf.data.Metagenotypes.load(f'data/gtprodb.sp-{species_id}.genotype.nc').mlift('sel', position=fit.position).to_estimated_genotypes(pseudo=0)
ref.sizes

In [None]:
# Distribution of the per-position mean genotype across reference strains.
plt.hist(ref.mean('strain'))

In [None]:
# TODO: Decide if I want to discretize here.
# Pairwise distance tables (DataFrames) among reference strains and among
# fitted strains, computed on discretized genotypes.
ref_dist = ref.discretized().pdist()
fit_dist = fit_genotypes.discretized().pdist()

In [None]:
# Strains closer than 0.001 in scaled distance units are treated as duplicates;
# dividing by distance_proportionality converts that to raw distance units.
dedup_thresh = 0.001 / distance_proportionality


def _average_linkage_labels(distance_frame, labels_index, threshold):
    """Cluster a precomputed square distance matrix with average linkage.

    The dendrogram is cut at `threshold`. Returns the cluster label of each
    row of `distance_frame`, as a Series indexed by `labels_index`.
    NOTE(review): scikit-learn >= 1.2 renamed `affinity` to `metric`
    (`affinity` was removed in 1.4); kept here to match the installed version.
    """
    clusterer = AgglomerativeClustering(
        n_clusters=None,
        distance_threshold=threshold,
        linkage='average',
        affinity='precomputed',
    )
    return pd.Series(clusterer.fit_predict(distance_frame), index=labels_index)


ref_clust = _average_linkage_labels(ref_dist, ref_dist.columns, dedup_thresh)
# Fitted strain IDs are cast to int, presumably to match the integer strain
# coordinate of the fitted genotypes — confirm.
fit_clust = _average_linkage_labels(fit_dist, fit_dist.columns.astype(int), dedup_thresh)

ref_clust.value_counts()

In [None]:
# Sizes of the largest groups of duplicate fitted strains.
fit_clust.value_counts().sort_values(ascending=False).head()

In [None]:
# FIXME: Discretize deduplicated strains?
# Collapse each duplicate cluster into one genotype by averaging its members'
# genotypes; the cluster label becomes the new strain coordinate.
ref_dedup = sf.Genotypes(ref.to_series().unstack('strain').groupby(ref_clust, axis='columns').mean().rename_axis(columns='strain').T.stack().to_xarray())
fit_dedup = sf.Genotypes(fit_genotypes.to_series().unstack('strain').groupby(fit_clust, axis='columns').mean().rename(columns=lambda x: int(x)).rename_axis(columns='strain').T.stack().to_xarray())

In [None]:
# Strain counts before vs. after deduplication, for ref and for fit.
(ref.sizes['strain'], ref_dedup.sizes['strain']), (fit_genotypes.sizes['strain'], fit_dedup.sizes['strain'])

In [None]:
from scipy.spatial.distance import pdist, squareform

# Stack deduplicated reference and fitted genotypes along the strain dimension.
# Downstream cells assume the combined strain names start with 'ref_' / 'fit_'.
g = sf.data.Genotypes.concat(dict(
    ref=ref_dedup,
    fit=fit_dedup,
), dim='strain')

# Square pairwise distance matrix over all combined strains (discretized genotypes).
dist = pd.DataFrame(g.discretized().pdist(), index=g.strain, columns=g.strain)

In [None]:
# Nearest-neighbour distances within and between groups. Within a group the
# diagonal (self-distance) is masked to NaN, so each column minimum is the
# distance to the nearest *other* strain. This does not depend on the distance
# scale (adding 1 to the diagonal only worked when some other strain was closer
# than 1). A group with a single strain yields NaN instead of a spurious 1.
d = dist.loc[
        lambda x: x.index.str.startswith('ref_'),
        lambda x: x.columns.str.startswith('ref_')
    ]
min_dist_ref_to_ref = d.mask(np.eye(len(d), dtype=bool)).min()

d = dist.loc[
        lambda x: x.index.str.startswith('fit_'),
        lambda x: x.columns.str.startswith('fit_')
    ]
min_dist_fit_to_fit = d.mask(np.eye(len(d), dtype=bool)).min()

# For each fit strain (column): distance to the nearest reference strain (rows).
min_dist_fit_to_ref = dist.loc[
    lambda x: x.index.str.startswith('ref_'),
    lambda x: x.columns.str.startswith('fit_')
].min()

# Histogram bin edges on the scaled distance axis.
bins = np.linspace(0, 0.03, num=31)

for min_dist, hist_label in [
    (min_dist_ref_to_ref, 'ref2ref'),
    (min_dist_fit_to_fit, 'fit2fit'),
    (min_dist_fit_to_ref, 'fit2ref'),
]:
    plt.hist(
        min_dist.dropna() * distance_proportionality,
        bins=bins,
        alpha=0.5,
        density=True,
        label=hist_label,
    )

plt.xlabel('scaled distance to nearest strain')
plt.ylabel('density')
plt.legend()

In [None]:
# Per fit strain: unscaled distance to the nearest reference (x) vs. to the
# nearest other fit strain (y); the line marks y = x.
plt.scatter(min_dist_fit_to_ref, min_dist_fit_to_fit)
plt.plot([0, 0.225], [0, 0.225])

In [None]:
# Paired one-sided Wilcoxon signed-rank test: are fit strains farther from their
# nearest reference than from their nearest other fit strain? Pairing is by
# position; both series come from the 'fit_' columns of `dist` in the same order.
sp.stats.wilcoxon(min_dist_fit_to_ref, min_dist_fit_to_fit, alternative='greater')

In [None]:
np.random.seed(0)

# Genotype heatmap of a (seeded, presumably random) sample of 3500 positions of
# the combined discretized genotypes. Rows are colored by whether the strain is
# a fit strain and ordered by average-linkage clustering over all positions of `g`.
sf.plot.plot_genotype(
    g.random_sample(position=3500).discretized(),
    row_colors_func=lambda w: w.strain.str.startswith('fit_'),
    row_linkage_func=lambda w: g.discretized().linkage(method='average'),
    scaley=3e-2,
    scalex=2e-3,
    yticklabels=0,
    dheight=0.001,
    cmap='gray_r',
    norm=mpl.colors.PowerNorm(1, vmin=-0.1, vmax=1.),
)

In [None]:
# Co-clustering threshold: the 10th percentile of all pairwise (unscaled)
# distances (squareform converts the square matrix to a condensed vector);
# printed unscaled and scaled.
clust_thresh = np.quantile(squareform(dist), 0.1)
print(clust_thresh, clust_thresh * distance_proportionality)

plt.hist(squareform(dist))
plt.axvline(clust_thresh, color='k')

In [None]:
# Co-cluster the combined (ref + fit) deduplicated genotypes with average
# linkage, cutting the dendrogram at clust_thresh. Labels follow the row order
# of `dist`, whose index is g.strain.
# NOTE(review): scikit-learn >= 1.2 renamed `affinity` to `metric` (`affinity`
# was removed in 1.4); kept to match the deduplication step above.
all_clust = pd.Series(
    AgglomerativeClustering(
        distance_threshold=clust_thresh, n_clusters=None, affinity='precomputed', linkage='average'
    ).fit_predict(dist),
    index=g.strain,
)

# Strain origin ('ref' or 'fit'): the first three characters of each strain name.
clust_type = all_clust.index.to_series().str[:3]

In [None]:
from itertools import product
    
def count_clust_types(clust, key):
    """Tally cluster composition by strain type.

    Parameters
    ----------
    clust : pd.Series
        Cluster label of each strain (index = strain name).
    key : pd.Series
        Type of each strain (e.g. 'ref' / 'fit'), aligned to `clust` by index.

    Returns
    -------
    clust_types : pd.DataFrame
        One row per cluster and one column per strain type (sorted). Each cell
        is the number of strains of that type in the cluster (0 when absent).
    tally : pd.Series
        Number of clusters showing each presence/absence pattern of strain
        types. Indexed by a MultiIndex over every True/False combination (one
        level per strain type, named and ordered like the `clust_types`
        columns), with 0 for patterns no cluster exhibits.
    """
    clust_types = (
        clust
        .to_frame(name='clust')
        .assign(key=key)
        .groupby(['clust', 'key'])
        .size()
        .unstack(fill_value=0)
    )
    all_keys = clust_types.columns.to_list()
    present = clust_types > 0
    # Enumerate every pattern explicitly so absent patterns still appear with a
    # count of 0. (Passing a DataFrame to Series.reindex is not a valid label
    # input.) Comparing the boolean frame to a list broadcasts across columns,
    # one entry per strain type.
    all_patterns = list(product([True, False], repeat=len(all_keys)))
    tally = pd.Series(
        [int((present == list(pattern)).all(axis=1).sum()) for pattern in all_patterns],
        index=pd.MultiIndex.from_tuples(all_patterns, names=all_keys),
    )
    return clust_types, tally

# Per-cluster strain counts by origin, and the number of clusters showing each
# presence/absence pattern of origins.
clust_stats, clust_type_tally = count_clust_types(all_clust, clust_type)

clust_type_tally

In [None]:
# Annotate each co-cluster with its composition.
_clust_stats = (
    clust_stats
    .assign(
        # Total number of strains in the cluster, summed across origins.
        tally=lambda x: x.sum(1),
    )
    # A cluster with no ref strains is fit-only, and vice versa.
    .assign(only_fit=lambda x: x.ref==0, only_ref=lambda x: x.fit==0)
    .assign(both=lambda x: ~(x.only_fit | x.only_ref))
    # Integer class code: 0 = only_ref, 1 = both, 2 = only_fit.
    .assign(clust_class=lambda x: x[['only_ref', 'both', 'only_fit']].values.argmax(1))
)

# One genotype per co-cluster: the mean of its members' genotypes.
clust_genotypes = sf.Genotypes(g.to_series().unstack('strain').groupby(all_clust, axis='columns').mean().rename_axis(columns='strain').T.stack().to_xarray())

# Heatmap of a sample of 2500 positions of the cluster genotypes, with rows
# colored by cluster class.
sf.plot.plot_genotype(
    clust_genotypes.random_sample(position=2500),
    row_colors_func=lambda w: _clust_stats[['clust_class']].to_xarray(),
    scaley=1e-2,
    scalex=3e-3,
    yticklabels=0
)

In [None]:
# Same heatmap, with the cluster genotypes discretized.
sf.plot.plot_genotype(
    clust_genotypes.discretized().random_sample(position=2500),
    row_colors_func=lambda w: _clust_stats[['clust_class']].to_xarray(),
    scaley=1e-2,
    scalex=3e-3,
    yticklabels=0
)

In [None]:
np.random.seed(0)

# Saved heatmap: 3500 sampled positions of the discretized cluster genotypes.
# Rows are ordered by complete-linkage clustering of the (non-discretized)
# cluster genotypes and colored by cluster class.
sf.plot.plot_genotype(
    clust_genotypes.discretized().random_sample(position=3500),
    row_colors_func=lambda w: _clust_stats[['clust_class']].to_xarray(),
    row_linkage_func=lambda w: clust_genotypes.linkage(method='complete'),
    scaley=3e-2,
    scalex=2e-3,
    yticklabels=0,
    dheight=0.001,
    cmap='gray_r',
    norm=mpl.colors.PowerNorm(1, vmin=-0.1, vmax=1.),
)

plt.savefig(f'fig/coclustering_{species_id}.png', dpi=400)

In [None]:
# Number of co-clusters.
len(_clust_stats)

In [None]:
# Count the number of each type of genotype in each type of cluster.
# Each column (fit, ref) is normalized to fractions summing to 1.
_clust_stats.groupby(['only_fit', 'both', 'only_ref'])[['fit', 'ref']].sum().apply(lambda x: x / x.sum())

In [None]:
# Fit strains that fell into co-clusters containing no reference strain.
unmatched_inferred_strains = idxwhere(all_clust.isin(idxwhere(_clust_stats.only_fit)))
print(len(unmatched_inferred_strains))

# 1 minus the mean scaled distance from each unmatched fit strain to its nearest
# reference strain. Reference strains are selected on the row index (previously
# done via `x.columns`, which only worked because `dist` is square and symmetric).
print(1 - dist.loc[lambda x: x.index.str.startswith('ref_'), unmatched_inferred_strains].min().mean() * distance_proportionality)

In [None]:
np.random.seed(0)

def permutation_clust_types(clust, key, n=1, progress=False):
    """Null distribution of cluster-pattern tallies under label shuffling.

    Returns a pair (observed, permuted):
    - observed: the pattern tally from count_clust_types(clust, key).
    - permuted: a DataFrame with one row per permutation, holding the same
      tally after the cluster labels are shuffled across strains while each
      strain's `key` stays fixed.
    Uses the global NumPy random state.
    """
    observed = count_clust_types(clust, key)[1]
    labels = clust.values
    tallies = []
    for _ in tqdm(range(n), disable=not progress):
        shuffled = np.random.choice(labels, size=len(clust), replace=False)
        tallies.append(count_clust_types(pd.Series(shuffled, index=clust.index), key)[1])
    return observed, pd.DataFrame(tallies)

obs, perm = permutation_clust_types(all_clust, all_clust.index.to_series().str[:3], n=9999, progress=True)

In [None]:
# Observed number of clusters per presence pattern, with fractions of the total.
obs.to_frame(name='tally').assign(frac=lambda x: x / x.sum())

In [None]:
# Null distributions (histograms) vs. observed values (vertical lines) of the
# cluster count per pattern. Pattern tuples are presumably (fit present,
# ref present), following the sorted strain-type columns — confirm.
fig, axs = plt.subplots(3, figsize=(5, 5), sharex=True, sharey=True)

for (key, c, label), ax in zip([((True, True), 'grey', 'both'), ((True, False), 'tab:red', 'fit'), ((False, True), 'tab:blue', 'ref')], axs):
    ax.hist(perm[key], color=c)
    ax.axvline(obs[key], color=c, label=label)
    ax.set_title(label)
fig.tight_layout()

In [None]:
def tally_permutation_test(obs, perm):
    """Count how each observed statistic compares with its permutation draws.

    For every label in `obs.index`, counts the draws in `perm[label]` that
    the observed value `obs[label]` is greater than ('>'), equal to ('=='),
    and less than ('<'). Returns a DataFrame with one row per label and one
    column per comparison symbol.
    """
    comparisons = [
        ('>', lambda observed, draws: observed > draws),
        ('==', lambda observed, draws: observed == draws),
        ('<', lambda observed, draws: observed < draws),
    ]
    counts = {
        (label, symbol): compare(obs[label], perm[label]).sum()
        for label in obs.index
        for symbol, compare in comparisons
    }
    return pd.Series(counts).unstack()

# Per pattern: fraction of permutations in which the observed tally is greater
# than, equal to, or less than the permuted tally.
tally_permutation_test(obs, perm).apply(lambda x: x / x.sum(), axis=1)

In [None]:
np.random.seed(0)

def _strain_counts_by_pattern(stats):
    """Sum per-cluster strain counts over clusters sharing a presence pattern.

    `stats` is the per-cluster count table from count_clust_types. Each
    cluster's pattern is the tuple of (count > 0) flags across its columns.
    The result is stacked into a Series keyed by (pattern, strain type).
    """
    presence_pattern = (stats > 0).apply(lambda row: tuple(row), axis=1)
    return stats.groupby(presence_pattern).sum().stack()

def permutation_strain_clust_type(clust, key, n=1, progress=False):
    """Null distribution of strain counts per cluster presence pattern.

    Returns (observed, permuted). `observed` holds the strain counts per
    (pattern, strain type) for the given clustering. `permuted` is a DataFrame
    with one row per permutation, computed after shuffling the cluster labels
    across strains (global NumPy random state).
    """
    observed = _strain_counts_by_pattern(count_clust_types(clust, key)[0])
    null_draws = [
        _strain_counts_by_pattern(
            count_clust_types(
                pd.Series(np.random.choice(clust.values, size=len(clust), replace=False), index=clust.index),
                key,
            )[0]
        )
        for _ in tqdm(range(n), disable=not progress)
    ]
    return observed, pd.DataFrame(null_draws)

obs2, perm2 = permutation_strain_clust_type(all_clust, all_clust.index.to_series().str[:3], n=9999, progress=True)

In [None]:
# Observed strain counts per (presence pattern, strain origin).
obs2

In [None]:
# Permuted strain counts: one row per permutation.
perm2

In [None]:
fig, axs = plt.subplots(5, figsize=(5, 5), sharex=True, sharey=True)

# One panel per statistic: (title, color, keys into obs2/perm2 whose values are
# summed). Keys are (presence pattern, strain origin). Colors match the
# cluster-level permutation plot (fit = red, ref = blue, both = grey).
panels = [
    ('fit', 'tab:red', [((True, False), 'fit')]),
    ('ref', 'tab:blue', [((False, True), 'ref')]),
    ('both (both)', 'grey', [((True, True), 'fit'), ((True, True), 'ref')]),
    ('both (fit)', 'grey', [((True, True), 'fit')]),
    ('both (ref)', 'grey', [((True, True), 'ref')]),
]

# Null distribution (histogram) vs. observed value (vertical line).
for ax, (title, c, keys) in zip(axs, panels):
    ax.hist(sum(perm2[k] for k in keys), color=c)
    ax.axvline(sum(obs2[k] for k in keys), color=c, label=title)
    ax.set_title(title)

axs[-1].set_xlabel('number of strains')
fig.tight_layout()

In [None]:
(
    (perm2[((True, True), 'fit')] + perm2[((True, True), 'ref')])
    >=
    (obs2[((True, True), 'fit')] + obs2[((True, True), 'ref')])
).mean()

In [None]:
(
    (perm2[((True, True), 'fit')])
    >=
    (obs2[((True, True), 'fit')])
).mean()

In [None]:
(
    (perm2[((True, True), 'ref')])
    >=
    (obs2[((True, True), 'ref')])
).mean()

In [None]:
(
    obs2[((True, False), 'fit')]
    <=
    perm2[((True, False), 'fit')]    
).mean()

In [None]:
perm2[((True, False), 'fit')].mean()

In [None]:
obs2[((True, False), 'fit')]