## Preamble

In [None]:
# Load IPython's autoreload extension but disable automatic reloading (mode 0);
# project modules are then only reloaded when a bare `%autoreload` is run explicitly.
%load_ext autoreload
%autoreload 0

In [None]:
# Manually reload modules once (automatic reloading is disabled); re-run this cell
# after editing project code.
%autoreload

In [None]:
import os
import sys

# Make the StrainFacts source tree importable (presumably provides the `sfacts`
# package imported below). Override the location with the STRAINFACTS_PATH
# environment variable; the default is the original hard-coded checkout.
STRAINFACTS_PATH = os.environ.get(
    'STRAINFACTS_PATH',
    '/pollard/home/bsmith/Projects/haplo-benchmark/include/StrainFacts',
)
# Guard so that re-executing the cell does not keep appending duplicate entries.
if STRAINFACTS_PATH not in sys.path:
    sys.path.append(STRAINFACTS_PATH)

In [None]:
import xarray as xr
import sqlite3
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import warnings
import torch
import pyro
import scipy as sp
from tqdm import tqdm

import lib.plot
from scipy.spatial.distance import pdist, squareform
from sklearn.metrics.pairwise import pairwise_distances
from sklearn.cluster import AgglomerativeClustering
from lib.pandas_util import idxwhere

import sfacts as sf

from tqdm import tqdm

# from lib.project_style import color_palette, major_allele_frequency_bins
# from lib.project_data import metagenotype_db_to_xarray
# from lib.plot import ordination_plot, mds_ordination, nmds_ordination
# import lib.plot
# from lib.plot import construct_ordered_pallete
# from lib.pandas_util import idxwhere

## Load Data

In [None]:
species_id = 102492

# Load the fitted StrainFacts world for this species; positions are cast to int.
fit_path = f'data/zshi.sp-{species_id}.metagenotype.filt-poly05-cvrg25.fit-sfacts44-s200-g5000-seed0.refit-sfacts41-g10000-seed0.world.nc'
fit = sf.data.World.load(fit_path)
fit.data['position'] = fit.data.position.astype(int)
print(fit.sizes)


# Keep only strains whose maximum abundance over samples exceeds the threshold,
# then renormalize the retained community fractions to sum to one over strains.
cull_threshold = 0.05

is_retained = fit.communities.max("sample") > cull_threshold
fit_communities = fit.communities.mlift('sel', strain=is_retained)
# Largest abundance mass removed by the cull.
print((1 - fit_communities.sum("strain")).max())
fit_communities = sf.Communities(fit_communities.data / fit_communities.sum("strain"))
fit_genotypes = fit.genotypes.mlift('sel', strain=fit_communities.strain)

# Rebuild the world from the culled communities/genotypes and the original metagenotypes.
fit = sf.World.from_combined(fit_communities, fit_genotypes, fit.metagenotypes)
print(fit.sizes)

In [None]:
# Reference genotypes (presumably GT-Pro database strains, per the filename),
# restricted to the positions present in the fitted genotypes.
ref = sf.data.Metagenotypes.load(f'data/gtprodb.sp-{species_id}.genotype.nc')
ref_at_fit_positions = ref.mlift('sel', position=fit_genotypes.position)
ref_genotypes = ref_at_fit_positions.to_estimated_genotypes(pseudo=0)
ref_genotypes.sizes

In [None]:
# SNP metadata (contig and within-contig coordinate) for this species, indexed by `position`.
snp_dict_columns = ['species_id', 'position', 'contig', 'contig_position', 'ref', 'alt']
position_meta = (
    pd.read_table('ref/gtpro/variants_main.covered.hq.snp_dict.tsv', names=snp_dict_columns)
    .set_index('position')
    .loc[lambda df: df.species_id.isin([species_id])]
)

position_meta

In [None]:
# Pairwise strain distances on discretized genotypes, for the inferred and reference sets.
fit_dist, ref_dist = (genos.discretized().pdist() for genos in (fit_genotypes, ref_genotypes))

In [None]:
import inspect

# Strains whose genotype distance (average linkage) is below `dedup_thresh` are
# merged into one cluster.
dedup_thresh = 0.05  # / distance_proportionality

# scikit-learn 1.2 renamed AgglomerativeClustering's `affinity` argument to
# `metric`, and 1.4 removed `affinity`; use whichever the installed version accepts.
_AGGLOM_METRIC_KWARG = (
    'metric'
    if 'metric' in inspect.signature(AgglomerativeClustering).parameters
    else 'affinity'
)


def _dedup_clusters(dist, index):
    """Cluster strains from a precomputed distance matrix.

    Args:
        dist: precomputed pairwise distances between strains.
        index: strain labels, in the same order as the rows of `dist`.

    Returns:
        pd.Series of integer cluster labels indexed by `index`.
    """
    model = AgglomerativeClustering(
        distance_threshold=dedup_thresh,
        n_clusters=None,
        linkage='average',
        **{_AGGLOM_METRIC_KWARG: 'precomputed'},
    )
    return pd.Series(model.fit_predict(dist), index=index)


fit_dedup_clust = _dedup_clusters(fit_dist, fit_genotypes.strain.astype(int))
ref_dedup_clust = _dedup_clusters(ref_dist, ref_genotypes.strain)

In [None]:
# Merge near-duplicate inferred strains: average each position's genotype across
# the strains in a dedup cluster; the cluster label becomes the new strain id.
# Groups along rows rather than `groupby(..., axis='columns')`, which is
# deprecated as of pandas 2.1.
fit_genotypes_dedup = sf.Genotypes(
    fit_genotypes.to_series()
    .unstack('position')          # rows: strain, columns: position
    .groupby(fit_dedup_clust)     # key is aligned on the strain index
    .mean()
    .rename(index=int)
    .rename_axis(index='strain')
    .stack()
    .to_xarray()
)

In [None]:
# Merge near-duplicate reference strains: average each position's genotype across
# the strains in a dedup cluster; the cluster label becomes the new strain id.
# Groups along rows rather than `groupby(..., axis='columns')`, which is
# deprecated as of pandas 2.1.
ref_genotypes_dedup = sf.Genotypes(
    ref_genotypes.to_series()
    .unstack('position')          # rows: strain, columns: position
    .groupby(ref_dedup_clust)     # key is aligned on the strain index
    .mean()
    .rename(index=int)
    .rename_axis(index='strain')
    .stack()
    .to_xarray()
)

In [None]:
# Strain counts (before, after) deduplication: inferred genotypes first, then reference.
tuple(
    (before.sizes['strain'], after.sizes['strain'])
    for before, after in [(fit_genotypes, fit_genotypes_dedup), (ref_genotypes, ref_genotypes_dedup)]
)

In [None]:
# Sample a subset of positions from the deduplicated reference genotypes and compute
# LD (squared correlation) for every pair of sampled positions. The inferred genotypes
# are taken at exactly the same positions, so r2_ref[k] and r2_fit[k] describe the
# same pair of SNPs.
# FIXME: May not be a stable estimate until many positions are included (20k)
# NOTE(review): `random_sample` is called without an explicit seed here, so the
# sampled positions may differ between runs — confirm.
n_sampled_positions = 3681
_g_ref = ref_genotypes_dedup.random_sample(position=n_sampled_positions)
_g_fit = fit_genotypes_dedup.sel(position=_g_ref.position)

r2_ref = (1 - pdist(_g_ref.values.T, 'correlation')) ** 2
r2_fit = (1 - pdist(_g_fit.values.T, 'correlation')) ** 2

In [None]:
# Mean pairwise LD: (reference, inferred).
tuple(map(np.mean, (r2_ref, r2_fit)))

In [None]:
# Median pairwise LD: (reference, inferred).
tuple(np.median(r2) for r2 in (r2_ref, r2_fit))

In [None]:
# 90th-percentile pairwise LD: (reference, inferred).
tuple(np.quantile(r2, 0.9) for r2 in (r2_ref, r2_fit))

In [None]:
from scipy.stats import mannwhitneyu, wilcoxon

# Paired signed-rank test of LD, reference vs inferred: r2_ref[k] and r2_fit[k]
# are computed for the same pair of SNP positions.
# Taking `r2[:n]` from a condensed pdist vector selects only pairs that involve the
# first few sampled positions, which is neither representative nor independent.
# Draw a seeded, uniform random subsample of pairs instead.
n = 10000
_pair_rng = np.random.default_rng(0)
_n_pairs = min(len(r2_ref), len(r2_fit))
_pair_idx = _pair_rng.choice(_n_pairs, size=min(n, _n_pairs), replace=False)
wilcoxon(r2_ref[_pair_idx], r2_fit[_pair_idx])

In [None]:
# LD decay, reference genotypes: for every pair of SNPs on the same contig, record
# the distance between their contig coordinates and their LD (squared correlation).
ld = {}
ref_positions_by_contig = position_meta.loc[ref_genotypes_dedup.position].groupby('contig')
for contig, pos in ref_positions_by_contig:
    print(contig)
    contig_genotypes = ref_genotypes_dedup.sel(position=pos.index)
    pair_r2 = (1 - pdist(contig_genotypes.values.T, 'correlation')) ** 2
    pair_dist = pdist(pos.contig_position.values[:, np.newaxis], 'cityblock')
    ld[contig] = (pair_dist, pair_r2)
ref_ld = pd.DataFrame(
    np.concatenate([np.column_stack(pair) for pair in ld.values()]),
    columns=['x', 'r2'],
)

In [None]:
# LD decay, inferred genotypes: for every pair of SNPs on the same contig, record
# the distance between their contig coordinates and their LD (squared correlation).
ld = {}
fit_positions_by_contig = position_meta.loc[fit_genotypes_dedup.position].groupby('contig')
for contig, pos in fit_positions_by_contig:
    print(contig)
    contig_genotypes = fit_genotypes_dedup.sel(position=pos.index)
    pair_r2 = (1 - pdist(contig_genotypes.values.T, 'correlation')) ** 2
    pair_dist = pdist(pos.contig_position.values[:, np.newaxis], 'cityblock')
    ld[contig] = (pair_dist, pair_r2)
fit_ld = pd.DataFrame(
    np.concatenate([np.column_stack(pair) for pair in ld.values()]),
    columns=['x', 'r2'],
)

In [None]:
max_dist = 12000

# fit_ld and ref_ld are combined column-wise below, which is only meaningful if both
# tables list the same SNP pairs in the same order. Check (as a necessary condition)
# that their pairwise distances match, rather than silently pairing unrelated rows.
assert len(fit_ld) == len(ref_ld) and np.array_equal(fit_ld.x.values, ref_ld.x.values), \
    'fit_ld and ref_ld are not aligned on the same SNP pairs'

# 90th-percentile LD at each exact pairwise distance below `max_dist`.
ld_profile = (
    pd.DataFrame(dict(
        x=fit_ld.x,
        fit=fit_ld.r2,
        ref=ref_ld.r2))
    [lambda x: x.x < max_dist]
    .groupby('x')
    .quantile(0.9)
)

In [None]:
# 2D histogram of (pairwise distance, LD) for the inferred genotypes, normalized
# within each distance bin.
# Number of log-spaced distance bin edges and linear r^2 bin edges. `nx` was
# previously declared as 121 but ignored in favor of a hard-coded 51; 51 is kept
# so the bins are unchanged.
nx, ny = 51, 51
xlim = np.array([0.5, 1e5])
ylim = np.array([0, 1])
# Flooring to integer distances merges duplicate edges at short range, so there can
# be fewer than nx - 1 distance bins.
xbins = np.unique(np.floor(np.logspace(*np.log10(xlim), num=nx)).astype(int))
ybins = np.linspace(*ylim, num=ny)

d = fit_ld

hist, xedges, yedges = np.histogram2d(d['x'], d['r2'], bins=(xbins, ybins))
extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
# Rows: upper edge of each r^2 bin; columns: left edge of each distance bin.
hist = pd.DataFrame(hist, columns=yedges[1:], index=xedges[:-1]).T
# Each distance bin (column) sums to one; empty bins become NaN (0 / 0).
norm_hist = hist / hist.sum()

In [None]:
# TODO: Calculate the 90th percentile within the distance window from the histogram

# 90th-percentile LD within each distance bin (left, right] defined by consecutive
# `xbins` edges, for the inferred (fit) and reference (ref) genotypes.
def _bin_q90(ld_table, lo, hi):
    in_bin = (ld_table.x > lo) & (ld_table.x <= hi)
    return ld_table[in_bin].r2.quantile(0.9)


hist_bin_quantile90 = pd.DataFrame(
    [
        (lo, hi, _bin_q90(fit_ld, lo, hi), _bin_q90(ref_ld, lo, hi))
        for lo, hi in zip(xbins[:-1], xbins[1:])
    ],
    columns=['left', 'right', 'fit', 'ref'],
)

In [None]:
# LD-decay figure with all overlays, on a background of the per-distance-bin LD density
# of the inferred genotypes: 90th-percentile LD per exact distance (points), per
# distance bin (lines), and over all sampled SNP pairs (dashed), for reference (blue)
# and inferred (red) genotypes.
fig, ax = plt.subplots()

# Draw the histogram on its true bin edges. Passing one coordinate per cell
# (norm_hist.columns / .index are left x-edges and upper y-edges) made matplotlib
# either center cells on those edges or drop the last row/column, shifting the
# density by up to a full bin.
mesh = ax.pcolormesh(
    xedges, yedges, norm_hist.values,
    shading='flat', norm=mpl.colors.LogNorm(vmin=1e-3, vmax=1.0), cmap='binary',
)
ax.set_xscale('log')
fig.colorbar(mesh, ax=ax)

# Marker size is 20 / distance, so points at short distances are drawn larger.
profile_points = ld_profile.reset_index().assign(s=lambda df: 20 / df['x'])
ax.scatter('x', 'ref', data=profile_points, color='blue', label='reference genotypes', s='s', alpha=0.5)
ax.scatter('x', 'fit', data=profile_points, color='red', label='inferred genotypes', s='s', alpha=0.5)

ax.plot('right', 'ref', data=hist_bin_quantile90, lw=1, color='blue', alpha=1)
ax.plot('right', 'fit', data=hist_bin_quantile90, lw=1, color='red', alpha=1)

ax.axhline(np.quantile(r2_ref, 0.9), lw=1, color='blue', linestyle='--')
ax.axhline(np.quantile(r2_fit, 0.9), lw=1, color='red', linestyle='--')

ax.set_xlabel('Genomic Distance (bp)')
ax.set_ylabel('LD ($r^2$)')
ax.set_ylim(0, 1.03)
ax.set_xlim(0.9, 1e4)

fig.savefig(f'fig/ld_decay_{species_id}_90th.pdf', dpi=400)

In [None]:
# LD-decay figure with the background density only: per-distance-bin LD density of
# the inferred genotypes, with no percentile overlays.
fig, ax = plt.subplots()

# Draw the histogram on its true bin edges. Passing one coordinate per cell
# (norm_hist.columns / .index are left x-edges and upper y-edges) made matplotlib
# either center cells on those edges or drop the last row/column, shifting the
# density by up to a full bin.
mesh = ax.pcolormesh(
    xedges, yedges, norm_hist.values,
    shading='flat', norm=mpl.colors.LogNorm(vmin=1e-3, vmax=1.0), cmap='binary',
)
ax.set_xscale('log')
fig.colorbar(mesh, ax=ax)

ax.set_xlabel('Genomic Distance (bp)')
ax.set_ylabel('LD ($r^2$)')
ax.set_ylim(0, 1.03)
ax.set_xlim(0.9, 1e4)

fig.savefig(f'fig/ld_decay_{species_id}_no_trends.png', dpi=400)

In [None]:
# LD-decay figure with inferred-genotype overlays only, on a background of the
# per-distance-bin LD density of the inferred genotypes: 90th-percentile LD per exact
# distance (points), per distance bin (line), and over all sampled SNP pairs (dashed).
fig, ax = plt.subplots()

# Draw the histogram on its true bin edges. Passing one coordinate per cell
# (norm_hist.columns / .index are left x-edges and upper y-edges) made matplotlib
# either center cells on those edges or drop the last row/column, shifting the
# density by up to a full bin.
mesh = ax.pcolormesh(
    xedges, yedges, norm_hist.values,
    shading='flat', norm=mpl.colors.LogNorm(vmin=1e-3, vmax=1.0), cmap='binary',
)
ax.set_xscale('log')
fig.colorbar(mesh, ax=ax)

# Marker size is 20 / distance, so points at short distances are drawn larger.
profile_points = ld_profile.reset_index().assign(s=lambda df: 20 / df['x'])
ax.scatter('x', 'fit', data=profile_points, color='red', label='inferred genotypes', s='s', alpha=0.5)

ax.plot('right', 'fit', data=hist_bin_quantile90, lw=1, color='red', alpha=1)

ax.axhline(np.quantile(r2_fit, 0.9), lw=1, color='red', linestyle='--')

ax.set_xlabel('Genomic Distance (bp)')
ax.set_ylabel('LD ($r^2$)')
ax.set_ylim(0, 1.03)
ax.set_xlim(0.9, 1e4)

fig.savefig(f'fig/ld_decay_{species_id}_no_ref.png', dpi=400)

In [None]:
# LD-decay figure with all overlays, on a background of the per-distance-bin LD density
# of the inferred genotypes: 90th-percentile LD per exact distance (points), per
# distance bin (lines), and over all sampled SNP pairs (dashed), for reference (blue)
# and inferred (red) genotypes.
# NOTE(review): this reproduces the full-overlay figure drawn in an earlier cell and
# saves to the same `_90th.pdf` path, overwriting that file.
fig, ax = plt.subplots()

# Draw the histogram on its true bin edges. Passing one coordinate per cell
# (norm_hist.columns / .index are left x-edges and upper y-edges) made matplotlib
# either center cells on those edges or drop the last row/column, shifting the
# density by up to a full bin.
mesh = ax.pcolormesh(
    xedges, yedges, norm_hist.values,
    shading='flat', norm=mpl.colors.LogNorm(vmin=1e-3, vmax=1.0), cmap='binary',
)
ax.set_xscale('log')
fig.colorbar(mesh, ax=ax)

# Marker size is 20 / distance, so points at short distances are drawn larger.
profile_points = ld_profile.reset_index().assign(s=lambda df: 20 / df['x'])
ax.scatter('x', 'ref', data=profile_points, color='blue', label='reference genotypes', s='s', alpha=0.5)
ax.scatter('x', 'fit', data=profile_points, color='red', label='inferred genotypes', s='s', alpha=0.5)

ax.plot('right', 'ref', data=hist_bin_quantile90, lw=1, color='blue', alpha=1)
ax.plot('right', 'fit', data=hist_bin_quantile90, lw=1, color='red', alpha=1)

ax.axhline(np.quantile(r2_ref, 0.9), lw=1, color='blue', linestyle='--')
ax.axhline(np.quantile(r2_fit, 0.9), lw=1, color='red', linestyle='--')

ax.set_xlabel('Genomic Distance (bp)')
ax.set_ylabel('LD ($r^2$)')
ax.set_ylim(0, 1.03)
ax.set_xlim(0.9, 1e4)

fig.savefig(f'fig/ld_decay_{species_id}_90th.pdf', dpi=400)

In [None]:
# LD-decay figure with per-distance points only: background per-distance-bin LD density
# of the inferred genotypes plus the 90th-percentile LD at each exact distance for
# reference (blue) and inferred (red) genotypes. No binned lines or overall percentiles.
fig, ax = plt.subplots()

# Draw the histogram on its true bin edges. Passing one coordinate per cell
# (norm_hist.columns / .index are left x-edges and upper y-edges) made matplotlib
# either center cells on those edges or drop the last row/column, shifting the
# density by up to a full bin.
mesh = ax.pcolormesh(
    xedges, yedges, norm_hist.values,
    shading='flat', norm=mpl.colors.LogNorm(vmin=1e-3, vmax=1.0), cmap='binary',
)
ax.set_xscale('log')
fig.colorbar(mesh, ax=ax)

# Marker size is 20 / distance, so points at short distances are drawn larger.
profile_points = ld_profile.reset_index().assign(s=lambda df: 20 / df['x'])
ax.scatter('x', 'ref', data=profile_points, color='blue', label='reference genotypes', s='s', alpha=0.5)
ax.scatter('x', 'fit', data=profile_points, color='red', label='inferred genotypes', s='s', alpha=0.5)

ax.set_xlabel('Genomic Distance (bp)')
ax.set_ylabel('LD ($r^2$)')
ax.set_ylim(0, 1.03)
ax.set_xlim(0.9, 1e4)

fig.savefig(f'fig/ld_decay_{species_id}_90th_novec.png', dpi=400)

In [None]:
# LD-decay figure without per-distance points, on a background of the per-distance-bin
# LD density of the inferred genotypes: 90th-percentile LD per distance bin (lines) and
# over all sampled SNP pairs (dashed), for reference (blue) and inferred (red) genotypes.
fig, ax = plt.subplots()

# Draw the histogram on its true bin edges. Passing one coordinate per cell
# (norm_hist.columns / .index are left x-edges and upper y-edges) made matplotlib
# either center cells on those edges or drop the last row/column, shifting the
# density by up to a full bin.
mesh = ax.pcolormesh(
    xedges, yedges, norm_hist.values,
    shading='flat', norm=mpl.colors.LogNorm(vmin=1e-3, vmax=1.0), cmap='binary',
)
ax.set_xscale('log')
fig.colorbar(mesh, ax=ax)

ax.plot('right', 'ref', data=hist_bin_quantile90, lw=1, color='blue', alpha=1)
ax.plot('right', 'fit', data=hist_bin_quantile90, lw=1, color='red', alpha=1)

ax.axhline(np.quantile(r2_ref, 0.9), lw=1, color='blue', linestyle='--')
ax.axhline(np.quantile(r2_fit, 0.9), lw=1, color='red', linestyle='--')

ax.set_xlabel('Genomic Distance (bp)')
ax.set_ylabel('LD ($r^2$)')
ax.set_ylim(0, 1.03)
ax.set_xlim(0.9, 1e4)

fig.savefig(f'fig/ld_decay_{species_id}_90th_noscatter.pdf', dpi=400)

In [None]:
# Standalone legend (blue = reference, red = inferred) saved as its own figure file,
# matching the colors used in the LD-decay plots.
fig, ax = plt.subplots(figsize=(2, 1))

# Empty lines carry only the legend entries.
for entry_label, entry_color in [('reference', 'blue'), ('inferred', 'red')]:
    ax.plot([], color=entry_color, label=entry_label, alpha=0.7)
ax.legend()

ax.axis('off')
fig.savefig(f'fig/ld_decay_{species_id}_legend.pdf', dpi=400)

In [None]:
# LD_{90,1/2}: the smallest pairwise distance at which the 90th-percentile LD profile
# is at or below 0.5, reported separately for each column (fit, ref).
# `(ld_profile > 0.5).idxmin()` silently returned the first distance when a profile
# never dropped to 0.5; here that case yields NaN instead. NaN profile values are not
# counted as having decayed.
ld_profile.apply(lambda col: col.index[(col <= 0.5).to_numpy()].min())

In [None]:
# First five rows (shortest distances) of the LD profile.
ld_profile.iloc[:5]

In [None]:
# Recap: 90th-percentile pairwise LD over the sampled SNP pairs (reference, inferred).
(np.quantile(r2_ref, q=0.9), np.quantile(r2_fit, q=0.9))