## Preamble

In [None]:
# Load IPython's autoreload extension, then select mode 0, which (per the
# IPython autoreload documentation) disables automatic module reloading.
%load_ext autoreload
%autoreload 0

In [None]:
# With no argument, %autoreload reloads all modules once, right now
# (per the IPython autoreload documentation). Re-run this cell after editing
# project code.
%autoreload

In [None]:
import sys
# Extend the module search path with a local StrainFacts checkout -- presumably
# so that `import sfacts` resolves; TODO confirm.
# NOTE(review): this is an absolute, user-specific path; the notebook will not
# find the package on another machine unless this is adjusted.
sys.path.append('/pollard/home/bsmith/Projects/haplo-benchmark/include/StrainFacts')

In [None]:
import xarray as xr
import sqlite3
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import warnings
import torch
import pyro
import scipy as sp

import lib.plot
from scipy.spatial.distance import pdist, squareform
from sklearn.metrics.pairwise import pairwise_distances
from sklearn.cluster import AgglomerativeClustering
from lib.pandas_util import idxwhere


import sfacts as sf

# from lib.project_style import color_palette, major_allele_frequency_bins
# from lib.project_data import metagenotype_db_to_xarray
# from lib.plot import ordination_plot, mds_ordination, nmds_ordination
# import lib.plot
# from lib.plot import construct_ordered_pallete
# from lib.pandas_util import idxwhere

## Load Data

In [None]:
# Load the fitted sfacts World from disk and display its dimension sizes.
fit_path = 'data/zshi.sp-102492.metagenotype.filt-poly05-cvrg25.fit-sfacts13-s500-g500-seed3.world.nc'
fit = sf.data.World.load(fit_path)
fit.sizes

In [None]:
# Load the reference metagenotypes, restrict them to the positions present in
# `fit`, and convert to estimated genotypes with a pseudocount of 0.
ref_metagenotypes = sf.data.Metagenotypes.load('data/gtprodb.sp-102492.genotype.nc')
ref_at_fit_positions = ref_metagenotypes.mlift('sel', position=fit.position)
ref = ref_at_fit_positions.to_estimated_genotypes(pseudo=0)
ref.sizes

In [None]:
# Overlay the entropy distributions of three genotype sources on one log-scaled
# histogram: genotypes estimated directly from the metagenotypes, the fitted
# genotypes, and the reference genotypes.
bins = np.linspace(0, 1, num=51)

entropy_by_source = {
    'metagenotypes (pseudo=1e-3)': fit.metagenotypes.to_estimated_genotypes(pseudo=1e-3).entropy(),
    'fit genotypes': fit.genotypes.entropy(),
    'reference genotypes': ref.entropy(),
}
for source_label, entropy in entropy_by_source.items():
    # Semi-transparent bars plus a label: without them the three overlapping
    # histograms hide each other and cannot be told apart.
    plt.hist(entropy, bins=bins, alpha=0.5, label=source_label)

plt.yscale('log')
plt.legend()
None  # keep the notebook from echoing the last expression

In [None]:
# Two stacked, log-scaled histograms over [0, 1]:
#   top    - every fitted genotype value,
#   bottom - the maximum of `fit.communities` over the "strain" dimension.
bins = np.linspace(0, 1, num=51)

fig, axs = plt.subplots(2)
panel_values = (
    fit.genotypes.values.flatten(),
    fit.communities.max("strain").values.flatten(),
)
for ax, values in zip(axs, panel_values):
    ax.hist(values, bins=bins)
    ax.set_yscale('log')
None  # keep the notebook from echoing the last expression

In [None]:
# Draw the fitted world with sfacts' genotype plot; scalex/scaley are passed
# through unchanged and yticklabels=0 presumably suppresses row labels.
sf.plot.plot_genotype(
    fit,
    scaley=2e-2,
    scalex=1e-3,
    yticklabels=0,
)

In [None]:
# Same genotype plot for the reference genotypes (note the smaller scaley).
ref_plot_options = dict(scaley=1e-2, scalex=1e-3, yticklabels=0)
sf.plot.plot_genotype(ref, **ref_plot_options)

In [None]:
# Partition the inferred strains by genotype entropy: strains below the
# threshold are kept for downstream analysis, the rest are set aside.
entropy_threshold = 0.25
strain_entropy = fit.genotypes.entropy()  # computed once, reused for both masks

fit_genotypes_filt = fit.genotypes.mlift('sel', strain=strain_entropy < entropy_threshold)
# `>=` rather than `>` so the two selections are complementary: a strain whose
# entropy equals the threshold exactly is no longer dropped from both sets.
fit_genotypes_highent = fit.genotypes.mlift('sel', strain=strain_entropy >= entropy_threshold)

fit_genotypes_filt.sizes, fit_genotypes_highent.sizes

In [None]:
# Log-scaled histogram of pairwise dissimilarities among the low-entropy
# strains, binned on [0, 0.5]. `squareform` converts between the square and
# condensed distance forms (presumably square -> condensed here; TODO confirm
# what `pdist()` returns).
pairwise_dissimilarity = squareform(fit_genotypes_filt.pdist())
dissimilarity_bins = np.linspace(0, 0.5, num=101)
plt.hist(pairwise_dissimilarity, bins=dissimilarity_bins)
plt.yscale('log')
None  # keep the notebook from echoing the last expression

In [None]:
# Stack the high-entropy ("bad") and low-entropy ("good") inferred strains
# along the strain dimension and plot them together.
g = sf.data.Genotypes.concat(
    {'bad': fit_genotypes_highent, 'good': fit_genotypes_filt},
    dim='strain',
)


def _is_good_strain(w):
    """Flag strains whose label starts with 'good' (used to colour rows)."""
    return w.strain.str.startswith('good')


sf.plot.plot_genotype(g, row_colors_func=_is_good_strain, scaley=2e-2, scalex=1e-3, yticklabels=0)

In [None]:
# Combine reference, low-entropy inferred and high-entropy inferred strains
# into one Genotypes object, then plot with rows flagged by whether the strain
# label starts with 'fit'.
strain_groups = {
    'ref': ref,
    'fit': fit_genotypes_filt,
    'ent': fit_genotypes_highent,
}
g = sf.data.Genotypes.concat(strain_groups, dim='strain')


sf.plot.plot_genotype(
    g,
    row_colors_func=lambda w: w.strain.str.startswith('fit'),
    scaley=2e-2,
    scalex=1e-3,
    yticklabels=0,
)

In [None]:
# Reference strains next to the low-entropy inferred strains only.
g = sf.data.Genotypes.concat({'ref': ref, 'fit': fit_genotypes_filt}, dim='strain')

# Rows are flagged by whether the strain label starts with 'fit'.
plot_options = dict(scaley=2e-2, scalex=1e-3, yticklabels=0)
sf.plot.plot_genotype(g, row_colors_func=lambda w: w.strain.str.startswith('fit'), **plot_options)

In [None]:
# Same reference + low-entropy-inferred comparison as a separate cell, drawn
# with a smaller scaley (5e-3). `g` is rebuilt here so the cell does not
# depend on an earlier one having been run.
compared_sets = dict(ref=ref, fit=fit_genotypes_filt)
g = sf.data.Genotypes.concat(compared_sets, dim='strain')


def _is_fit_strain(w):
    """Flag strains whose label starts with 'fit' (used to colour rows)."""
    return w.strain.str.startswith('fit')


sf.plot.plot_genotype(g, row_colors_func=_is_fit_strain, scaley=5e-3, scalex=1e-3, yticklabels=0)

In [None]:
# Read the header-less SNP dictionary, index it by position, and keep only
# the rows belonging to species 102492.
snp_dict_columns = ['species_id', 'position', 'contig', 'contig_position', 'ref', 'alt']
all_position_meta = pd.read_table(
    'ref/gtpro/variants_main.covered.hq.snp_dict.tsv',
    names=snp_dict_columns,
).set_index('position')
position_meta = all_position_meta[all_position_meta.species_id.isin([102492])]

position_meta

In [None]:
# Complete-linkage clustering of the reference strains on their precomputed
# pairwise dissimilarities, cutting the tree at a distance of 0.02, then a
# tally of cluster sizes.
ref_diss = ref.pdist("strain")
_ref_clust_options = dict(n_clusters=None, distance_threshold=0.02, linkage='complete')
try:
    # scikit-learn >= 1.2 names this option `metric`; the old `affinity`
    # spelling was deprecated in 1.2 and removed in 1.4.
    _ref_clusterer = AgglomerativeClustering(metric='precomputed', **_ref_clust_options)
except TypeError:
    # Older scikit-learn releases only accept `affinity`.
    _ref_clusterer = AgglomerativeClustering(affinity='precomputed', **_ref_clust_options)
ref_clust = _ref_clusterer.fit_predict(ref_diss)
# Re-attach the labels carried by the dissimilarity matrix's index.
ref_clust = pd.Series(ref_clust, index=ref_diss.index)
ref_clust.value_counts()

In [None]:
# Collapse the reference genotypes to one averaged genotype per cluster: attach
# each strain's cluster label, average within (cluster, position), and expose
# the cluster id under the name `strain`.
_ref_long = ref.data.to_dataframe().join(ref_clust.rename("clust")).reset_index()
_ref_cluster_mean = (
    _ref_long
    .groupby(["clust", "position"])
    # numeric_only=True: the leftover `strain` column may hold non-numeric
    # labels. Older pandas dropped such columns silently; pandas >= 2.0 raises
    # a TypeError instead unless this is requested explicitly.
    .mean(numeric_only=True)
    .rename_axis(index=dict(clust='strain'))
    .squeeze()
)
ref_agg = sf.data.Genotypes(_ref_cluster_mean.to_xarray())

In [None]:
# For each contig, pair up all positions of the cluster-averaged reference
# genotypes and record (distance along the contig, squared correlation).
# The per-contig results are then stacked into one long table `ref_ld`.
ld = {}
ref_positions_by_contig = position_meta.loc[ref_agg.position].groupby('contig')
for contig, pos in ref_positions_by_contig:
    print(contig)
    g = ref_agg.sel(position=pos.index)
    # pdist's 'correlation' metric is 1 - r, so undo that before squaring.
    # Transposed so each row is one position (assumes values are
    # strain x position -- TODO confirm).
    correlation = 1 - pdist(g.values.T, 'correlation')
    r2 = correlation ** 2
    # Absolute difference in contig coordinate for the same position pairs.
    x = pdist(pos.contig_position.values[:, np.newaxis], 'cityblock')
    ld[contig] = (x, r2)
ref_ld = pd.DataFrame(
    np.concatenate([np.column_stack(pair) for pair in ld.values()]),
    columns=['x', 'r2'],
)

In [None]:
# Complete-linkage clustering of the low-entropy inferred strains on their
# precomputed pairwise dissimilarities, cutting the tree at a distance of 0.1,
# then a tally of cluster sizes.
est_diss = fit_genotypes_filt.pdist("strain")
_est_clust_options = dict(n_clusters=None, distance_threshold=0.1, linkage='complete')
try:
    # scikit-learn >= 1.2 names this option `metric`; the old `affinity`
    # spelling was deprecated in 1.2 and removed in 1.4.
    _est_clusterer = AgglomerativeClustering(metric='precomputed', **_est_clust_options)
except TypeError:
    # Older scikit-learn releases only accept `affinity`.
    _est_clusterer = AgglomerativeClustering(affinity='precomputed', **_est_clust_options)
est_clust = _est_clusterer.fit_predict(est_diss)
# Re-attach the labels carried by the dissimilarity matrix's index.
est_clust = pd.Series(est_clust, index=est_diss.index)
est_clust.value_counts()

In [None]:
# Collapse the low-entropy inferred genotypes to one averaged genotype per
# cluster: attach each strain's cluster label, average the `genotypes` column
# within (cluster, position), and expose the cluster id under the name `strain`.
_est_long = fit_genotypes_filt.data.to_dataframe().join(est_clust.rename("clust")).reset_index()
_est_cluster_mean = (
    _est_long
    .groupby(["clust", "position"])
    # Select the value column *before* aggregating so the leftover `strain`
    # column is never averaged (pandas >= 2.0 raises on non-numeric columns).
    .genotypes
    .mean()
    .rename_axis(index=dict(clust='strain'))
)
est_agg = sf.data.Genotypes(_est_cluster_mean.to_xarray())

In [None]:
# Same LD-vs-distance table as for the reference, built from the
# cluster-averaged inferred genotypes: one row per position pair on a contig,
# with columns x (distance along the contig) and r2 (squared correlation).
ld = {}
est_positions = position_meta.loc[est_agg.position]
for contig, pos in est_positions.groupby('contig'):
    print(contig)
    g = est_agg.sel(position=pos.index)
    # 'correlation' distance is 1 - r; recover r and square it.
    r2 = np.square(1 - pdist(g.values.T, 'correlation'))
    coordinates = pos.contig_position.values.reshape(-1, 1)
    x = pdist(coordinates, 'cityblock')
    ld[contig] = (x, r2)
per_contig_tables = [np.stack([x, r2], axis=1) for x, r2 in ld.values()]
est_ld = pd.DataFrame(np.vstack(per_contig_tables), columns=['x', 'r2'])

In [None]:
def _mean_r2_by_distance(ld_table, right, stepsize):
    """Average r2 within consecutive distance bins.

    Parameters
    ----------
    ld_table : DataFrame with columns `x` (pair distance) and `r2`.
    right : exclusive upper limit on `x`; rows at or beyond it are ignored.
    stepsize : bin width; bins are [start, start + stepsize).

    Returns
    -------
    dict mapping each bin start in range(0, right, stepsize) to the mean r2 of
    the rows in that bin (NaN when the bin is empty).
    """
    nearby = ld_table[ld_table.x < right]
    return {
        start: nearby[(nearby.x >= start) & (nearby.x < start + stepsize)].r2.mean()
        for start in range(0, right, stepsize)
    }


stepsize = 1
right = 500

# Binned LD-decay curves for the inferred and the reference genotypes.
bins_est = _mean_r2_by_distance(est_ld, right, stepsize)
bins_ref = _mean_r2_by_distance(ref_ld, right, stepsize)

# The density background is drawn from the *reference* pairs only. This was
# previously implicit, via a scratch variable that happened to be rebound last.
ref_nearby = ref_ld[ref_ld.x < right]

fig = plt.figure(figsize=(10, 5))

plt.hexbin('x', 'r2', data=ref_nearby, cmap='Blues', norm=mpl.colors.PowerNorm(1/3), mincnt=1, gridsize=(30, 10), label='__nolegend__')
plt.colorbar(label='Count')

ax = plt.gca()

ax.plot(pd.Series(bins_est), color='red', label='inferred genotypes')
ax.plot(pd.Series(bins_ref), color='blue', label='reference genotypes')
# Dashed lines: mean r2 over all pairs at any distance, as a background level.
ax.axhline(est_ld.r2.mean(), lw=1, color='red', linestyle='--')
ax.axhline(ref_ld.r2.mean(), lw=1, color='blue', linestyle='--')

ax.legend(title=f'Mean LD at distance X ({stepsize} bp Bin)')  #bbox_to_anchor=(0.85, 1.15), ncol=2)