## Preamble

In [None]:
# Load IPython's autoreload extension in mode 0 (automatic reloading disabled);
# modules are reloaded only when `%autoreload` is run explicitly (see the next cell).
%load_ext autoreload
%autoreload 0

In [None]:
# Reload modules now; re-run this cell after editing sfacts / lib sources.
%autoreload

In [None]:
import os
import sys

# Make the local StrainFacts checkout importable (as `sfacts`).
# The location can be overridden with the STRAINFACTS_PATH environment variable; the default is
# the original author's checkout. The membership check keeps re-runs of this cell from
# appending duplicate entries to sys.path.
STRAINFACTS_PATH = os.environ.get(
    'STRAINFACTS_PATH',
    '/pollard/home/bsmith/Projects/haplo-benchmark/include/StrainFacts',
)
if STRAINFACTS_PATH not in sys.path:
    sys.path.append(STRAINFACTS_PATH)

In [None]:
import xarray as xr
import sqlite3
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import warnings
import torch
import pyro
import scipy as sp
import scipy.stats  # explicitly load the submodule so `sp.stats` is always available

import lib.plot
from scipy.spatial.distance import pdist, squareform
from sklearn.metrics.pairwise import pairwise_distances
from sklearn.cluster import AgglomerativeClustering
from lib.pandas_util import idxwhere


import sfacts as sf

# from lib.project_style import color_palette, major_allele_frequency_bins
# from lib.project_data import metagenotype_db_to_xarray
# from lib.plot import ordination_plot, mds_ordination, nmds_ordination
# import lib.plot
# from lib.plot import construct_ordered_pallete
# from lib.pandas_util import idxwhere

## UCFMT Strain Tracking

In [None]:
# `mgen_ucfmt_100022` is not loaded by any earlier cell, so this cell raised NameError on a
# fresh kernel. Load the metagenotypes explicitly before plotting.
# TODO(review): confirm this filename; it mirrors the naming used for sp-104345 below.
mgen_ucfmt_100022 = sf.data.Metagenotypes.load('data/ucfmt.sp-100022.metagenotype.filt-poly05-cvrg05.nc')

sf.plot.plot_metagenotype(
    mgen_ucfmt_100022.to_world(),
    col_linkage_func=lambda w: w.metagenotypes.linkage("sample"),
)

In [None]:
# Fitted sfacts World for species 100022. An earlier variant used
# 'data/ucfmt.sp-100022.metagenotype.filt-poly05-cvrg15-g2000.fit-sfacts8-s100-seed0.world.nc'.
fit_ucfmt_100022_path = (
    'data/ucfmt.sp-100022.metagenotype.filt-poly05-cvrg05-g2000.fit-sfacts12-s100-g2000-seed0.world.nc'
)
fit_ucfmt_100022 = sf.data.World.load(fit_ucfmt_100022_path)


In [None]:
def _sample_linkage_100022(w):
    """Order sample columns by the linkage of the fit's metagenotypes."""
    return w.metagenotypes.linkage("sample")


sf.plot.plot_community(fit_ucfmt_100022, col_linkage_func=_sample_linkage_100022)

In [None]:
def _position_linkage_100022(w):
    """Order position columns by the linkage of the fit's metagenotypes."""
    return w.metagenotypes.linkage("position")


sf.plot.plot_genotype(
    fit_ucfmt_100022,
    transpose=True,
    col_linkage_func=_position_linkage_100022,
)

In [None]:
# Inspect a single sample: its metagenotype frequency spectrum (log y-axis) and its seven
# highest-weight strains in the fit. Another sample of interest: 'DS0485_002'.
sample = 'SS01068'

sf.plot.plot_metagenotype_frequency_spectrum(fit_ucfmt_100022, sample_list=[sample])
plt.yscale('log')

sample_community_100022 = fit_ucfmt_100022.data.sel(sample=[sample]).communities.to_series()
sample_community_100022.sort_values(ascending=False).head(7)

## Single-cell genomics

In [None]:
# Filtered UCFMT metagenotypes for species 104345.
mgen_ucfmt_104345_path = 'data/ucfmt.sp-104345.metagenotype.filt-poly05-cvrg05.nc'
mgen_ucfmt_104345 = sf.data.Metagenotypes.load(mgen_ucfmt_104345_path)

In [None]:
# Metagenotype heatmap for species 104345 with the focal sample SS01009.m flagged.
# `scaley` shrinks once there are more than 100 positions.
if mgen_ucfmt_104345.sizes['position'] > 1e2:
    _scaley_104345 = 1e-3
else:
    _scaley_104345 = 1e-2

sf.plot.plot_metagenotype(
    mgen_ucfmt_104345.to_world(),
    col_linkage_func=lambda w: mgen_ucfmt_104345.linkage("sample"),
    col_colors_func=lambda w: xr.Dataset({'focal': w.sample.isin(['SS01009.m'])}),
    scaley=_scaley_104345,
)

In [None]:
# Droplet (drplt) metagenotypes for species 104345: keep samples passing
# select_samples_with_coverage(0.01), then reindex onto the metagenome positions
# (positions missing from the droplet data get zero counts).
_drplt_loaded_104345 = sf.data.Metagenotypes.load('data/drplt.sp-104345.metagenotype.nc')
_drplt_covered_104345 = _drplt_loaded_104345.select_samples_with_coverage(0.01)
_drplt_aligned_104345 = _drplt_covered_104345.data.reindex(
    position=mgen_ucfmt_104345.position, fill_value=0
)
drplt_ucfmt_104345 = sf.data.Metagenotypes(_drplt_aligned_104345)

# Mapping table from droplet library to source sample (columns include lib_id, sample_id).
drplt_ucfmt_104345_to_sample = pd.read_table('meta/drplt_to_sample.tsv')

In [None]:
np.random.seed(0)

# Subsample at most 1000 positions before plotting.
nposition = min(drplt_ucfmt_104345.sizes['position'], 1000)
d = drplt_ucfmt_104345.random_sample(position=nposition)


def _drplt_col_colors(w):
    """Column annotations: the two focal metagenome samples and log mean depth."""
    return xr.Dataset(dict(
        focal=w.sample.isin(['mgen_SS01009.m', 'mgen_SS01057.m']),
        depth=w.metagenotypes.mean_depth().pipe(np.log),
    ))


sf.plot.plot_depth(d.to_world(), col_colors_func=_drplt_col_colors)
sf.plot.plot_metagenotype(d.to_world(), col_colors_func=_drplt_col_colors)

In [None]:
# Cluster droplets on `d.total_counts()` using cosine distance and average linkage, cutting
# the tree at `distance_threshold`. Then plot a 1000-position subsample annotated by cluster.
d = drplt_ucfmt_104345
distance_threshold = 0.2

_drplt_agg_model = AgglomerativeClustering(
    n_clusters=None,
    distance_threshold=distance_threshold,
    # NOTE: scikit-learn >= 1.2 renames `affinity` to `metric` (the old name is removed in 1.4).
    affinity='cosine',
    linkage='average',
)
_drplt_labels = _drplt_agg_model.fit_predict(d.total_counts())
drplt_agg = pd.Series(_drplt_labels, d.sample, name='clust')

drplt_agg_pal = lib.plot.construct_ordered_pallete(drplt_agg, cm='hsv')

sf.plot.plot_metagenotype(
    d.random_sample(position=1000).to_world(),
    col_colors_func=lambda d: drplt_agg.to_xarray(),
    row_col_annotation_cmap=mpl.cm.hsv,
)

In [None]:
# NMDS ordination of droplet pairwise distances, colored by agglomerative cluster.
d = drplt_ucfmt_104345
_drplt_dmat = d.pdist()

lib.plot.ordination_plot(
    _drplt_dmat,
    ordin=lib.plot.nmds_ordination,
    ordin_kws={'is_dmat': True},
    meta=drplt_agg.to_frame(),
    colorby='clust',
    color_palette=drplt_agg_pal,
    fill_legend=False,
    scatter_kws={'lw': 0},
)
None

In [None]:
# Dereplicate droplets: pool the allele counts of all droplets that come from the same source
# sample (lib_id -> sample_id via meta/drplt_to_sample.tsv) and fall in the same agglomerative
# cluster. Each (sample_id, clust) pair becomes one pseudo-sample labelled "<sample_id>_<clust>".
drplt_derep_ucfmt_104345 = sf.data.Metagenotypes(
    drplt_ucfmt_104345
    .to_series()
    .reset_index()
    .rename(columns={0: 'tally'})
    .join(drplt_agg, on='sample')
    .join(drplt_ucfmt_104345_to_sample.set_index('lib_id'), on='sample')
    .groupby(['sample_id', 'clust', 'position', 'allele'])
    # Sum only the counts. Summing the whole frame would also aggregate leftover non-key
    # columns (e.g. the droplet 'sample' strings), which pandas >= 2.0 concatenates instead
    # of dropping, and that breaks the `astype(int)` below.
    .tally.sum()
    .astype(int)
    .reset_index()
    .assign(label=lambda x: x.sample_id + '_' + x.clust.astype(str))
    .set_index(['label', 'position', 'allele'])
    .rename_axis(index={'label': 'sample'})
    .tally
    .to_xarray()
)
print(drplt_ucfmt_104345.sizes, drplt_derep_ucfmt_104345.sizes)

In [None]:
# Cluster the dereplicated droplet genotypes (select_samples_with_coverage(0.05), pseudo-count 1)
# into strain types with sf.estimation.strain_cluster(thresh=0.1).
drplt_ucfmt_104345_strain_type = sf.estimation.strain_cluster(
    (
        drplt_derep_ucfmt_104345
        .select_samples_with_coverage(0.05)
        .to_estimated_genotypes(pseudo=1.0)
        .to_world()
    ),
    thresh=0.1,
)
# Later cells join this Series and read it back as the `drplt_type` column.
drplt_ucfmt_104345_strain_type.name = 'drplt_type'

drplt_ucfmt_104345_strain_type_palette = lib.plot.construct_ordered_pallete(
    sorted(drplt_ucfmt_104345_strain_type.unique()), cm='Spectral'
)

# Cluster sizes. This must be the last expression, otherwise Jupyter does not display it.
drplt_ucfmt_104345_strain_type.value_counts()

### Matched metagenomes

In [None]:
def _focal_sample_colors(w):
    """Column annotation flagging the focal metagenome sample SS01009.m."""
    return xr.Dataset({'focal': w.sample.isin(['SS01009.m'])})


# Same overview as earlier, with a smaller `scaley` when there are more than 100 positions.
sf.plot.plot_metagenotype(
    mgen_ucfmt_104345.to_world(),
    col_linkage_func=lambda w: mgen_ucfmt_104345.linkage("sample"),
    col_colors_func=_focal_sample_colors,
    scaley=(5e-4 if mgen_ucfmt_104345.sizes['position'] > 1e2 else 1e-2),
)

In [None]:
# Fit the 'ssdd3_with_error' model (100 strains) on CUDA to a random subsample of at most
# 2000 metagenome positions, annealing rho_hyper and pi_hyper (annealiter=4000).
nposition = min(mgen_ucfmt_104345.sizes['position'], int(2e3))

# Seed before `random_sample` so the position subsample is reproducible
# (presumably drawn from NumPy's global RNG).
np.random.seed(0)

fit_hyperparameters = dict(gamma_hyper=1e-10)
fit_anneal_hyperparameters = dict(
    rho_hyper=dict(name='log', start=10.0, end=0.5, wait_steps=1000),
    pi_hyper=dict(name='log', start=5.0, end=0.5, wait_steps=1000),
)
fit_estimation_kwargs = dict(
    jit=True,
    catch_keyboard_interrupt=True,
    ignore_jit_warnings=True,
    maxiter=int(1e6),
)

fit_ucfmt_104345, history = sf.workflow.fit_metagenotypes_simple(
    sf.model_zoo.NAMED_STRUCTURES['ssdd3_with_error'],
    mgen_ucfmt_104345.random_sample(position=nposition),
    nstrain=100,
    hyperparameters=fit_hyperparameters,
    anneal_hyperparameters=fit_anneal_hyperparameters,
    annealiter=4000,
    estimation_kwargs=fit_estimation_kwargs,
    device='cuda',
)

In [None]:
# Metagenotype error of the fit, keeping only the first element of the returned value.
# NOTE(review): the same World is passed as both arguments, so the fit is scored against its
# own (presumably subsampled) metagenotypes. Confirm this is intended, rather than comparing
# against the full `mgen_ucfmt_104345`.
sf.evaluation.metagenotype_error(fit_ucfmt_104345, fit_ucfmt_104345)[0]

In [None]:
# Loss trace returned by the fit above.
sf.plot.plot_loss_history(history)

In [None]:
# Focal sample: its metagenotype frequency spectrum, plus the strains the fit assigns to it.
sample = 'SS01009.m'

sf.plot.plot_metagenotype_frequency_spectrum(mgen_ucfmt_104345.to_world(), sample_list=[sample])
plt.yscale('log')

_focal_weights = fit_ucfmt_104345.data.sel(sample=sample).communities.to_series()
sample_community = _focal_weights.sort_values(ascending=False)

# Strains with community weight > 0.05 in the focal sample; reused by later cells.
strain_list = idxwhere(sample_community > 0.05)
sample_community.head(7)

In [None]:
# Strain composition of every sample. Columns are ordered by metagenome linkage and the focal
# sample is flagged; rows (strains) are annotated by their weight in the focal sample.
def _focal_col_colors_104345(w):
    return xr.Dataset({'focal': w.sample.isin(['SS01009.m'])})


def _focal_row_colors_104345(w):
    return xr.Dataset({'focal': w.communities.data.sel(sample='SS01009.m').to_series()})


sf.plot.plot_community(
    fit_ucfmt_104345,
    col_linkage_func=lambda w: mgen_ucfmt_104345.linkage("sample"),
    col_colors_func=_focal_col_colors_104345,
    row_colors_func=_focal_row_colors_104345,
)

In [None]:
def _focal_strain_weight(w):
    """Row annotation: each strain's community weight in the focal sample SS01009.m."""
    focal_weights = w.communities.data.sel(sample='SS01009.m').to_series()
    return xr.Dataset({'focal': focal_weights})


sf.plot.plot_genotype(fit_ucfmt_104345, row_colors_func=_focal_strain_weight)

In [None]:
# Distribution of the largest strain weight (max over strains) in each community.
_max_strain_weight_104345 = fit_ucfmt_104345.communities.max("strain").values
_unit_bins = np.linspace(0, 1, num=51)
plt.hist(_max_strain_weight_104345, bins=_unit_bins)

In [None]:
e = fit_ucfmt_104345
sample_list = [s for s in e.sample.values if s in ['SS01009.m']]

# Pool three genotype sources on the fit's positions: strains found in the focal sample
# (`est_`), the naive genotype of the focal metagenome (`mgen_`), and the coverage-filtered,
# dereplicated droplet genotypes (`drplt_`).
w = sf.data.Genotypes.concat(dict(
    est=e.genotypes.mlift('sel', strain=strain_list),
    mgen=e.metagenotypes.mlift('sel', sample=sample_list).mlift('sel', position=e.position).to_estimated_genotypes(),
    drplt=drplt_derep_ucfmt_104345.select_samples_with_coverage(0.05).mlift('sel', position=e.position).to_estimated_genotypes(),
), dim='strain')


def _assign_gtype(x):
    """Return each row's genotype source: its index label up to the first '_'."""
    return x.index.to_series().str.split('_').apply(lambda x: x[0])


fig, ax = plt.subplots(figsize=(5, 5))

# NOTE(review): non-droplet rows get clust = -1 from fillna; confirm the palette handles it.
ax, ordin, *_ = lib.plot.ordination_plot(
    w.pdist(),
    ordin=lib.plot.nmds_ordination,
    meta=pd.DataFrame(dict(
        clust=drplt_ucfmt_104345_strain_type.rename(lambda s: 'drplt_' + s),
        is_est=w.strain.str.startswith('est_').to_series(),
    )).fillna(-1),
    colorby='clust',
    color_palette=drplt_ucfmt_104345_strain_type_palette,
    markerby='is_est',
    marker_palette={True: '>', False: 'o'},
    zorderby='is_est',
    markersizeby='is_est',
    markersize_palette={True: 60, False: 40},
    ordin_kws={'is_dmat': True},
    fill_legend=False,
    scatter_kws=dict(lw=0.5, alpha=0.95),
    ax=ax
)

# The ordination is NMDS (`lib.plot.nmds_ordination`), so label the axes as NMDS dimensions
# rather than principal coordinates. The returned columns are still named 'PC1'/'PC2'.
ax.set_xlabel('NMDS1')
ax.set_ylabel('NMDS2')

# Annotate only the inferred and metagenome-derived genotypes; droplets stay unlabelled.
ordin['gtype'] = _assign_gtype(ordin)
for name, d1 in ordin[ordin.gtype.isin(['est', 'mgen'])].iterrows():
    ax.annotate(name, xy=d1[['PC1', 'PC2']].to_list())
None

In [None]:
e = fit_ucfmt_104345
sample_list = [s for s in e.sample.values if s in ['SS01009.m']]

# Genotype sources stacked along `strain`: focal-sample strains (est_), all fitted strains
# (other_), the naive focal-metagenome genotype (mgen_) and all dereplicated droplets (drplt_).
_est_genotypes = e.genotypes.mlift('sel', strain=strain_list)
_focal_mgen_genotypes = (
    mgen_ucfmt_104345
    .mlift('sel', sample=sample_list)
    .mlift('sel', position=e.position)
    .to_estimated_genotypes()
)
_drplt_genotypes = drplt_derep_ucfmt_104345.mlift('sel', position=e.position).to_estimated_genotypes()

d = sf.data.Genotypes.concat(
    {
        'est': _est_genotypes,
        'other': e.genotypes,
        'mgen': _focal_mgen_genotypes,
        'drplt': _drplt_genotypes,
    },
    dim='strain',
)


def _genotype_source_colors(w):
    """Row annotations: source flags plus the droplet strain type (NaN for non-droplet rows)."""
    drplt_type = drplt_ucfmt_104345_strain_type.rename(lambda s: 'drplt_' + s)
    return xr.Dataset(dict(
        fnd=w.strain.str.startswith('est_'),
        mgen=w.strain.str.startswith('mgen_'),
        drplt=w.strain.to_dataframe().join(drplt_type).drplt_type,
    ))


sf.plot.plot_genotype(
    d.to_world(),
    col_colors_func=None,
    row_colors_func=_genotype_source_colors,
    col_linkage_func=lambda w: e.metagenotypes.linkage("position"),
    scalex=(3e-2 if len(e.position) < 3000 else 1e-3),
    transpose=True,
)

In [None]:
# Genotypes of all fitted strains, flagging those found in the focal sample (`strain_list`).
# `scalex` is sized from the plotted World itself, so this cell does not depend on `e`
# being set by another cell.
sf.plot.plot_genotype(
    fit_ucfmt_104345,
    col_linkage_func=lambda w: w.metagenotypes.linkage("position"),
    row_colors_func=lambda w: xr.Dataset(dict(
        found=w.strain.isin(strain_list),
    )),
    scalex=1e-2 if len(fit_ucfmt_104345.position) < 3000 else 1e-3,
    transpose=True,
)

In [None]:
# Compare droplet genotypes against (a) the strains found in the focal sample and (b) the naive
# focal-metagenome genotype, each as fuzzy and as discretized genotypes.
# Alternative fit: 'data/ucfmt.sp-104345.metagenotype.filt-poly05-cvrg05.fit-sfacts12-s100-g2000-seed3.world.nc'
_fit_ucfmt_104345 = fit_ucfmt_104345

_drplt = (
    drplt_derep_ucfmt_104345
    .select_samples_with_coverage(0.05)
    .mlift('sel', position=_fit_ucfmt_104345.position)
    .to_estimated_genotypes()
    .to_world()
)

# Fuzzy genotypes.
_fit = _fit_ucfmt_104345.genotypes.to_world()
_fit_and_found = _fit_ucfmt_104345.genotypes.mlift('sel', strain=strain_list).to_world()
_mgen = _fit_ucfmt_104345.metagenotypes.mlift('sel', sample=sample_list).to_estimated_genotypes().to_world()

# Discretized counterparts.
_fit_disc = _fit_ucfmt_104345.genotypes.discretized().to_world()
_fit_and_found_disc = _fit_ucfmt_104345.genotypes.mlift('sel', strain=strain_list).discretized().to_world()
_mgen_disc = _fit_ucfmt_104345.metagenotypes.mlift('sel', sample=sample_list).to_estimated_genotypes().discretized().to_world()

bins = np.linspace(0, 0.3, num=51)

# Solid lines: fuzzy genotypes (these carry the legend labels); dashed: discretized.
for _linestyle, _est_world, _mgen_world, _labelled in [
    ('-', _fit_and_found, _mgen, True),
    ('--', _fit_and_found_disc, _mgen_disc, False),
]:
    plt.hist(
        sf.evaluation.match_genotypes(_drplt, _est_world)[1],
        bins=bins, color='tab:blue', label=('estimated' if _labelled else None),
        histtype='step', linestyle=_linestyle,
    )
    plt.hist(
        sf.evaluation.match_genotypes(_drplt, _mgen_world)[1],
        bins=bins, color='tab:red', label=('metagenotype' if _labelled else None),
        histtype='step', linestyle=_linestyle,
    )

# Empty proxy histograms that only add the line-style entries to the legend.
for _linestyle, _style_label in [('-', 'fuzzy'), ('--', 'discretized')]:
    plt.hist([], bins=bins, color='black', histtype='step', linestyle=_linestyle, label=_style_label)
plt.legend()

In [None]:
# Genotypes observed in droplets are more similar to the genotypes inferred by sfacts than to a
# naive genotype constructed from the observed metagenotype, especially for the discretized
# genotypes (consensus sequences).
# Each printed line shows mean(match_genotypes(...)[1]) for the fuzzy and then the discretized
# genotypes: first for the strains found in the focal sample, then for all fitted strains.
for _fuzzy_world, _disc_world in [(_fit_and_found, _fit_and_found_disc), (_fit, _fit_disc)]:
    _fuzzy_mean = sf.evaluation.match_genotypes(_drplt, _fuzzy_world)[1].mean()
    _disc_mean = sf.evaluation.match_genotypes(_drplt, _disc_world)[1].mean()
    print(_fuzzy_mean, _disc_mean)

## Large-scale genetics

### Species 102492

In [None]:
# Pre-computed sfacts fit for species 102492.
_fit_102492_path = 'data/zshi.sp-102492.metagenotype.filt-poly05-cvrg25-g500.fit-sfacts13-s500-g500-seed4.world.nc'
gtpro_fit_102492 = sf.data.World.load(_fit_102492_path)
gtpro_fit_102492.sizes

In [None]:
# Distribution of every entry of the inferred genotype array.
plt.hist(np.ravel(gtpro_fit_102492.genotypes.values), bins=np.linspace(0, 1, num=51))

In [None]:
# Distribution of the largest strain weight (max over strains) per community.
_max_weight_102492 = np.ravel(gtpro_fit_102492.communities.max("strain").values)
plt.hist(_max_weight_102492, bins=np.linspace(0, 1, num=51))

In [None]:
# All inferred genotypes for 102492 (strain tick labels hidden).
sf.plot.plot_genotype(
    gtpro_fit_102492,
    yticklabels=0,
    scaley=2e-2,
)

In [None]:
# Keep only the inferred strains whose genotype entropy is below 0.25.
_low_entropy_102492 = gtpro_fit_102492.genotypes.entropy().to_series() < 0.25
gtpro_fit_102492_filt = gtpro_fit_102492.genotypes.mlift(
    'sel', strain=idxwhere(_low_entropy_102492)
)

In [None]:
# Low-entropy strains only.
_filt_world_102492 = gtpro_fit_102492_filt.to_world()
sf.plot.plot_genotype(_filt_world_102492, yticklabels=0, scaley=2e-2)

#### Reference genomes

In [None]:
# Reference-database genotypes for 102492 are stored as metagenotypes; convert them to
# genotypes without a pseudo-count.
_ref_102492_path = 'data/gtprodb.sp-102492.genotype.nc'
gtpro_ref_102492_mgen = sf.data.Metagenotypes.load(_ref_102492_path)
gtpro_ref_102492 = gtpro_ref_102492_mgen.to_estimated_genotypes(pseudo=0)

In [None]:
# Genotype entropy of naive metagenotype-derived genotypes (two pseudo-counts) vs. the sfacts
# estimates. The three overlaid histograms are labelled so they can be told apart.
entropy_bins = np.linspace(0, 1, num=51)
plt.hist(
    gtpro_fit_102492.metagenotypes.to_estimated_genotypes(pseudo=1).entropy(),
    bins=entropy_bins, label='metagenotype (pseudo=1)',
)
plt.hist(
    gtpro_fit_102492.metagenotypes.to_estimated_genotypes(pseudo=0.01).entropy(),
    bins=entropy_bins, label='metagenotype (pseudo=0.01)',
)
plt.hist(gtpro_fit_102492.genotypes.entropy(), bins=entropy_bins, label='sfacts estimate')
plt.xlabel('Genotype entropy')
plt.legend()

# Estimated genotypes have much lower entropy than the estimates we would have gotten from
# metagenotypes directly (even using a small pseudo-count).

In [None]:
# Naive genotypes estimated directly from the metagenotypes, keeping those with entropy < 0.25.
gtpro_mgen_102492_genotypes = gtpro_fit_102492.metagenotypes.to_estimated_genotypes()
_mgen_low_entropy_102492 = gtpro_mgen_102492_genotypes.entropy().to_series() < 0.25
gtpro_mgen_102492_genotypes_filt = gtpro_mgen_102492_genotypes.mlift(
    'sel', strain=idxwhere(_mgen_low_entropy_102492)
)


In [None]:
# Overlay the filtered sfacts strains (`fit_`) and the reference genotypes (`ref_`) on the
# fitted positions.
g_fit = gtpro_fit_102492_filt
g_ref = gtpro_ref_102492.mlift('sel', position=g_fit.position)
g = sf.data.Genotypes.concat({'fit': g_fit, 'ref': g_ref}, dim='strain')


def _is_fit_strain_102492(w):
    """Row annotation flagging strains that come from the sfacts fit."""
    return xr.Dataset({'fit': w.strain.str.startswith('fit_')})


sf.plot.plot_genotype(
    g.to_world(),
    row_colors_func=_is_fit_strain_102492,
    scalex=4e-2,
    scaley=4e-2,
    yticklabels=0,
)

In [None]:
# The same fit-vs-reference heatmap at a more compact scale.
_compact_world_102492 = g.to_world()
sf.plot.plot_genotype(
    _compact_world_102492,
    row_colors_func=lambda w: xr.Dataset({'fit': w.strain.str.startswith('fit_')}),
    scalex=1e-2,
    scaley=4e-3,
    yticklabels=0,
)

### Species 102506

In [None]:
# Pre-computed sfacts fit for species 102506. An earlier variant used
# 'data/zshi.sp-102506.metagenotype.filt-poly05-cvrg75-g500.fit-sfacts13-s200-g500-seed1.world.nc'.
_fit_102506_path = 'data/zshi.sp-102506.metagenotype.filt-poly05-cvrg25-g500.fit-sfacts13-s500-g500-seed0.world.nc'
gtpro_fit_102506 = sf.data.World.load(_fit_102506_path)
gtpro_fit_102506.sizes

In [None]:
# Distribution of every entry of the inferred genotype array.
plt.hist(np.ravel(gtpro_fit_102506.genotypes.values), bins=np.linspace(0, 1, num=51))

In [None]:
# Distribution of the largest strain weight (max over strains) per community.
_max_weight_102506 = np.ravel(gtpro_fit_102506.communities.max("strain").values)
plt.hist(_max_weight_102506, bins=np.linspace(0, 1, num=51))

In [None]:
# All inferred genotypes for 102506 (strain tick labels hidden).
sf.plot.plot_genotype(
    gtpro_fit_102506,
    yticklabels=0,
    scaley=2e-2,
)

In [None]:
# Keep only the inferred strains whose genotype entropy is below 0.25.
_low_entropy_102506 = gtpro_fit_102506.genotypes.entropy().to_series() < 0.25
gtpro_fit_102506_filt = gtpro_fit_102506.genotypes.mlift(
    'sel', strain=idxwhere(_low_entropy_102506)
)

In [None]:
# Low-entropy strains only.
_filt_world_102506 = gtpro_fit_102506_filt.to_world()
sf.plot.plot_genotype(_filt_world_102506, yticklabels=0, scaley=2e-2)

#### Reference genomes

In [None]:
# Reference-database genotypes for 102506 are stored as metagenotypes; convert them to
# genotypes without a pseudo-count.
_ref_102506_path = 'data/gtprodb.sp-102506.genotype.nc'
gtpro_ref_102506_mgen = sf.data.Metagenotypes.load(_ref_102506_path)
gtpro_ref_102506 = gtpro_ref_102506_mgen.to_estimated_genotypes(pseudo=0)

In [None]:
# Overlay the filtered sfacts strains (`fit_`) and the reference genotypes (`ref_`) on the
# fitted positions.
g_fit = gtpro_fit_102506_filt
g_ref = gtpro_ref_102506.mlift('sel', position=g_fit.position)
g = sf.data.Genotypes.concat({'fit': g_fit, 'ref': g_ref}, dim='strain')


def _is_fit_strain_102506(w):
    """Row annotation flagging strains that come from the sfacts fit."""
    return xr.Dataset({'fit': w.strain.str.startswith('fit_')})


sf.plot.plot_genotype(
    g.to_world(),
    row_colors_func=_is_fit_strain_102506,
    scalex=4e-2,
    scaley=4e-2,
    yticklabels=0,
)

# This clustered heatmap shows that my reconstructions are consistent with what's in the
# database for E. coli: I don't get any strains that are wildly different from those already
# seen (though this may not hold for less well-studied organisms).

In [None]:
# The same fit-vs-reference heatmap at a more compact scale.
_compact_world_102506 = g.to_world()
sf.plot.plot_genotype(
    _compact_world_102506,
    row_colors_func=lambda w: xr.Dataset({'fit': w.strain.str.startswith('fit_')}),
    scalex=1e-2,
    scaley=4e-3,
    yticklabels=0,
)

# I also recapitulate many of the divisions found in the reference data, with fitted strains
# in every part of the tree. By eye, I should get similar estimates of LD (and presumably of
# LD decay with distance along the genome).

In [None]:
# SNP metadata (contig, coordinate on contig, ref/alt alleles) indexed by position,
# restricted to species 102506.
_snp_dict = pd.read_table(
    'ref/gtpro/variants_main.covered.hq.snp_dict.tsv',
    names=['species_id', 'position', 'contig', 'contig_position', 'ref', 'alt'],
).set_index('position')
position_meta_102506 = _snp_dict.loc[_snp_dict.species_id.isin([102506])]

position_meta_102506

In [None]:
# Mean squared genotype difference per position between every pair of reference strains:
# squared Euclidean distance divided by the number of positions.
g = gtpro_ref_102506
_euclidean = pairwise_distances(g.values, metric='euclidean', n_jobs=12)
_n_position = g.sizes['position']
_strain_ids = g.strain.to_series()
strain_diss = pd.DataFrame(_euclidean**2 / _n_position, index=_strain_ids, columns=_strain_ids)

In [None]:
# Complete-linkage clustering of reference strains on the precomputed dissimilarity,
# cut at 0.02.
# NOTE: scikit-learn >= 1.2 renames `affinity` to `metric` (the old name is removed in 1.4).
_ref_clusterer = AgglomerativeClustering(
    n_clusters=None,
    distance_threshold=0.02,
    affinity='precomputed',
    linkage='complete',
)
clust = pd.Series(_ref_clusterer.fit_predict(strain_diss), index=strain_diss.index)

In [None]:
# Number of reference strains in each dereplication cluster.
clust.value_counts()

In [None]:
# Dereplicate reference strains by averaging genotypes within each cluster. The grouper is
# named 'group', so the output dimension is 'group', which is renamed back to 'strain'.
_cluster_grouper = clust.rename('group').to_xarray()
gtpro_ref_102506_derep = sf.data.Genotypes(
    gtpro_ref_102506.data
    .groupby(_cluster_grouper)
    .mean()
    .rename({'group': 'strain'})
)

g_fit = gtpro_fit_102506_filt
g_ref = gtpro_ref_102506_derep.mlift('sel', position=g_fit.position)
g = sf.data.Genotypes.concat({'fit': g_fit, 'ref': g_ref}, dim='strain')

sf.plot.plot_genotype(
    g.to_world(),
    row_colors_func=lambda w: xr.Dataset({'fit': w.strain.str.startswith('fit_')}),
    scalex=4e-2,
    scaley=4e-2,
    yticklabels=0,
)

# This clustered heatmap shows that my reconstructions are consistent with what's in the
# database for E. coli: I don't get any strains that are wildly different from those already
# seen (though this may not hold for less well-studied organisms).

In [None]:
# NMDS ordination of the first 300 reference strains, colored by dereplication cluster.
d = strain_diss.iloc[:300, :300]

# AgglomerativeClustering labels run from 0 to clust.max() inclusive, so the palette needs
# clust.max() + 1 entries; range(clust.max()) would leave the last cluster without a color.
strain_clust_palette = lib.plot.construct_ordered_pallete(range(clust.max() + 1), cm='rainbow')

_ = lib.plot.ordination_plot(
    d,
    meta=clust.to_frame(name='clust'),
    colorby='clust',
    color_palette=strain_clust_palette,
    ordin=lib.plot.nmds_ordination,
    ordin_kws=dict(is_dmat=True),
    scatter_kws=dict(lw=0),
)

# Suppress the per-cluster legend.
plt.legend([])

In [None]:
# TODO: Dereplicate reference strains at some dissimilarity threshold

In [None]:
# `pdist`, `squareform` and `pairwise_distances` are already imported in the setup cell,
# so they are not re-imported here.
g = gtpro_ref_102506

In [None]:
# Contigs with the most SNP positions for species 102506.
position_meta_102506.groupby('contig').size().sort_values(ascending=False).head()

In [None]:
# Pairwise LD (squared Pearson correlation between positions across reference strains) and the
# distance between the same position pairs, computed within each contig and then pooled.
_ld_by_contig = {}
for contig, pos in position_meta_102506.groupby('contig'):
    print(contig)
    g = gtpro_ref_102506.sel(position=pos.index)
    # 'correlation' distance is 1 - r, so 1 - distance recovers r; square it for r^2.
    r2 = (1 - pdist(g.values.T, 'correlation'))**2
    # Absolute difference in contig coordinates, in the same condensed pair order as r2.
    x = pdist(pos.contig_position.values[:, np.newaxis], 'cityblock')
    _ld_by_contig[contig] = (x, r2)

ld = pd.DataFrame(
    np.concatenate([np.column_stack(pair) for pair in _ld_by_contig.values()]),
    columns=['x', 'r2'],
)

In [None]:
# LD decay: r^2 for every locus pair closer than `right`, the mean r^2 in `stepsize`-wide
# distance bins, and a Spearman test of r^2 against distance.
stepsize = 1
right = 100

d = ld[ld.x < right]

# Mean r^2 in each bin [start, start + stepsize). Empty bins stay NaN (gaps in the line).
bins = (
    d.r2
    .groupby((d.x // stepsize * stepsize).astype(int))
    .mean()
    .reindex(range(0, right, stepsize))
)

plt.scatter(
    x='x',
    y='r2',
    data=d,
    s=1,
    alpha=0.05,
    color='black',
    label='__nolegend__',
)
# Empty proxy scatter so the legend shows an opaque marker for the near-transparent points.
plt.scatter([], [], s=10, color='black', label='Locus Pair')
# Build the bin width in the label from `stepsize` so it always matches the binning
# (it was hard-coded as "25 bp" while stepsize was 1).
plt.plot(bins, color='red', label=f'Mean LD ({stepsize} bp Bin)')
plt.axhline(0, lw=1, color='red', linestyle='--')
plt.ylabel(r"LD")
plt.xlabel("Distance")
plt.legend(bbox_to_anchor=(0.85, 1.15), ncol=2)

print(sp.stats.spearmanr(d['x'], d['r2']))