## Preamble

In [None]:
import pandas as pd
from lib.util import info, idxwhere
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import scipy as sp

from scipy.spatial.distance import pdist, squareform

import pyro
import pyro.distributions as dist
import torch
from functools import partial
import arviz as az
from pyro.ops.contract import einsum
import seaborn as sns
from tqdm import tqdm
import xarray as xr

import sqlite3

def rss(x, y):
    """Return the root of the summed squared differences between ``x`` and ``y``.

    Note: despite the name, the square root is taken, so this is the Euclidean
    (L2) distance between ``x`` and ``y`` rather than the residual sum of
    squares itself.
    """
    squared_diff = np.square(x - y)
    return np.sqrt(np.sum(squared_diff))

def binary_entropy(p):
    """Return the binary (Bernoulli) entropy of ``p`` in bits.

    H(p) = -p*log2(p) - (1-p)*log2(1-p), using the standard convention
    0*log2(0) = 0, so H(0) == H(1) == 0 instead of NaN. Applies elementwise
    to array-like ``p``.
    """
    # xlogy(x, y) = x*log(y), defined as 0 when x == 0 -> no 0*(-inf) NaNs.
    from scipy.special import xlogy

    q = 1 - p
    # Natural-log terms converted to bits by dividing by ln(2).
    return -(xlogy(p, p) + xlogy(q, q)) / np.log(2)

def plot_loss_history(loss_history):
    """Plot a loss trace relative to its minimum on a log-scaled y axis.

    The trace is shifted down by its minimum, and the subtracted minimum is shown
    in the title. A dashed grey diagonal from (0, n) to (n, 0), where n is the
    trace length, is drawn for reference.

    Returns the current matplotlib Axes.
    """
    floor = loss_history.min()
    n_steps = len(loss_history)
    plt.plot(loss_history - floor)
    ref_x = np.linspace(0, n_steps, num=1000)
    ref_y = np.linspace(n_steps, 0, num=1000)
    plt.plot(ref_x, ref_y, lw=1, linestyle='--', color='grey')
    plt.title(f'+{floor:0.3e}')
    plt.yscale('log')
    return plt.gca()

def mean_residual_count(expect_frac, obs_count, m):
    """Return the m-weighted mean absolute residual between observed and expected counts.

    Computes ``sum(|obs_count - expect_frac * m|) / sum(m)``.

    Parameters
    ----------
    expect_frac : array_like
        Expected fraction for each entry.
    obs_count : array_like
        Observed count for each entry.
    m : array_like
        Total count for each entry (broadcast against the other inputs).

    Returns
    -------
    float
        Entries whose residual is NaN (e.g. NaN in ``expect_frac``) contribute 0.
    """
    expect_frac = np.asarray(expect_frac, dtype=float)
    obs_count = np.asarray(obs_count, dtype=float)
    m = np.asarray(m, dtype=float)
    # Equal to m * |obs/m - expect| wherever m > 0, but never divides. Entries
    # with m == 0 therefore need no 0/0 special-casing. Scalar inputs also work,
    # unlike in-place boolean-mask assignment.
    residual = np.abs(obs_count - expect_frac * m)
    return np.nansum(residual) / m.sum()

def linear_distance(linear_index):
    """Return the all-pairs absolute-difference matrix for a 1-D coordinate Series.

    ``linear_index`` is a pandas Series of coordinates. The result is a square
    DataFrame whose index and columns are both the Series' index. Entry (i, j)
    is ``|x_i - x_j|``, computed as a city-block distance.
    """
    coords = linear_index.to_frame()
    labels = coords.index
    condensed = pdist(coords, metric='cityblock')
    return pd.DataFrame(squareform(condensed), index=labels, columns=labels)

# Render figures at 120 dpi.
mpl.rcParams.update({'figure.dpi': 120})

## Metadata

In [None]:
# Per-position SNP metadata from the GT-Pro database (headerless TSV), indexed by
# `position`. It carries a `species_id` column and is filtered to one species below.
position_meta_ = pd.read_table(
    '/pollard/data/gt-pro-db/variants_main.covered.hq.snp_dict.tsv',
    names=['species_id', 'position', 'contig', 'contig_position', 'ref', 'alt']
).set_index('position')

## Load data

In [None]:
# Species to analyse, and its fitted model opened with xarray.
# NOTE(review): the filename suffix presumably encodes the fit's hyperparameters — confirm.
species_id = 100022

fit = xr.open_dataset(f'data/core.sp-{species_id}.gtpro-pileup.filt.sfacts-s3000-g6000-gamma3-rho1-pi2-eps1000-alph100.nc')
fit

In [None]:
# Keep only the metadata rows belonging to the species being analysed.
position_meta = position_meta_.loc[position_meta_['species_id'] == species_id]

## Check fit

In [None]:
# ELBO trace of the fit, shifted by its minimum (see plot_loss_history).
plot_loss_history(fit.elbo_trace.values)

In [None]:
# Mean absolute residual: sum |y - p*m| / sum m (same quantity as mean_residual_count).
np.abs(fit.y - (fit.p * fit.m)).sum() / fit.m.sum()

In [None]:
# Distribution across libraries of each library's largest strain fraction (pi), in 10 bins on [0, 1].
plt.hist(fit.pi.max('strain'), bins=np.linspace(0, 1, num=11))
#plt.yscale('log')
# Bare `None` suppresses the cell's output repr.
None

In [None]:
# Each strain's maximum fraction across libraries, sorted in descending order.
plt.plot(fit.pi.max('library_id').to_series().sort_values(ascending=False).values)
plt.axhline(0, lw=1, c='k', linestyle='--')

In [None]:
# log10 of the number of libraries in which each strain exceeds 10% abundance.
# Strains never above the threshold have a count of 0 (log10 -> -inf). That
# makes plt.hist fail on a non-finite bin range, so those strains are excluded.
n_libs_per_strain = (fit.pi > 0.1).sum('library_id').to_series()
plt.hist(np.log10(n_libs_per_strain[n_libs_per_strain > 0]), bins=50)

In [None]:
#fit.gamma.sel(strain=fit.pi.max('library_id'))

In [None]:
# Histogram of each strain's maximum fraction across libraries (log-scaled counts).
plt.hist(fit.pi.max('library_id'), bins=20)
plt.yscale('log')
# Bare `None` suppresses the cell's output repr.
None

In [None]:
# Number of libraries in which each strain exceeds 10% abundance (log-scaled counts).
plt.hist((fit.pi > 0.1).sum('library_id'), bins=50)
plt.yscale('log')

## Missing genome fragments:

In [None]:
# Ratio y / m rescaled from [0, 1] to [-1, 1], with NaN (e.g. from m == 0) filled by 0.
# Transposed so the columns are libraries; only libraries whose mean m across
# positions exceeds 1 are clustered.
# NOTE(review): y / m is presumably an allele fraction — confirm against the model.
d = ((fit.y / fit.m) * 2 - 1).to_pandas().sort_index(axis=1).T.fillna(0)
libs = idxwhere((fit.m.mean('position') > 1).to_series())

#fig = plt.figure(figsize=(3, 5))
sns.clustermap(d[libs], metric='cosine', cmap='coolwarm', vmin=-1, vmax=1)

## Linkage Disequilibrium

In [None]:
# Genotype heatmap for strains with pi > 0.5 in more than 4 libraries.
# gamma is rescaled as 2*gamma - 1 for the diverging colormap (presumably from
# [0, 1] to [-1, 1]), then transposed so rows are the non-strain dimension.
# `axis=` must be passed by keyword: pandas >= 2.0 rejects positional sort_index args.
sns.clustermap(
    (
        ((fit.gamma * 2) - 1)
        .sel(
            strain=idxwhere(
                ((fit.pi > 0.5)
                 .sum('library_id') > 4
                ).to_series()),
        ).T
    ).to_pandas().sort_index(axis=0),
    metric='cosine',
    cmap='coolwarm',
    row_cluster=True,
)

In [None]:
def pos_psim(gamma):
    """Return the pairwise squared Pearson correlation (r^2) between rows of ``gamma``.

    Parameters
    ----------
    gamma : xarray.DataArray or pandas.DataFrame
        2-D table with one item per row. Objects with a ``to_pandas()`` method
        (e.g. an xarray DataArray) are converted first; anything else is used
        as-is. In this notebook it is called with ``fit.gamma.T``, presumably
        positions x strains — confirm against the dataset.

    Returns
    -------
    pandas.DataFrame
        Square matrix indexed on both axes by the row labels of ``gamma``.
        The diagonal is 1. Pairs involving a constant row may be NaN, because
        their correlation is undefined.
    """
    table = gamma.to_pandas() if hasattr(gamma, 'to_pandas') else gamma
    # SciPy's 'correlation' distance is 1 - r, so (1 - d)**2 recovers r^2.
    r_squared = (1 - squareform(pdist(table, metric='correlation'))) ** 2
    return pd.DataFrame(r_squared, index=table.index, columns=table.index)

# def geno_cosine_psim(gamma):
#     gamma = gamma.to_pandas()
#     return pd.DataFrame((1 - squareform(pdist((gamma * 2) - 1, metric='cosine'))),
#                         index=gamma.index,
#                         columns=gamma.index)

# strain_sim = geno_cosine_psim(fit.gamma)
# r^2 between every pair of rows of fit.gamma.T (presumably positions), across all
# strains, with rows and columns sorted by label. `axis=` is passed by keyword
# because pandas >= 2.0 rejects positional sort_index arguments.
position_sim = pos_psim(fit.gamma.T).sort_index().sort_index(axis=1)

In [None]:
# Per-contig SNP counts: all positions listed for this species (total_count) vs.
# those present in the fit (fit_count; 0 for contigs with none), largest first.
# `.size()` counts group rows directly. This avoids `apply(len)`, which pandas
# >= 2.2 warns about because apply would operate on the grouping column.
snp_info = (
    position_meta
    .groupby('contig')
    .size()
    .to_frame(name='total_count')
    .assign(
        fit_count=
        position_meta.loc[fit.gamma.position]
        .groupby('contig')
        .size()
    ).fillna(0)
).sort_values('fit_count', ascending=False)

snp_info.head(10)

In [None]:
# |contig_position_i - contig_position_j| for every pair of fitted positions,
# sorted on both axes. Only meaningful for pairs on the same contig (masked below).
# `axis=` is passed by keyword because pandas >= 2.0 rejects positional sort_index args.
position_ldist_ = linear_distance(
    position_meta.loc[fit.position]['contig_position']
).sort_index().sort_index(axis=1)

In [None]:
import patsy

# 1.0 where two fitted positions lie on the same contig, else 0.0 (diagonal 1.0).
# Contig labels are compared directly instead of one-hot encoding them and
# taking 1 - Jaccard distance. The values are the same, the work is O(n^2)
# rather than O(n^2 * n_contigs), and no rows with a missing contig are
# silently dropped (which would misalign the index/columns assigned below).
contig_of_position = position_meta.loc[fit.position]['contig'].values
same_contig = pd.DataFrame(
    (contig_of_position[:, None] == contig_of_position[None, :]).astype(float),
    index=fit.position, columns=fit.position,
)
#sns.heatmap(same_contig.sort_index().sort_index(1))

### All strains

In [None]:
# Keep linear distances only for same-contig pairs; other pairs become NaN.
# DataFrame.where aligns `same_contig` on labels, so no n^2 stack/unstack is needed.
position_ldist = position_ldist_.where(same_contig.astype(bool))

In [None]:
# Sanity check: the distance and similarity matrices must line up label-for-label
# before their entries are paired below.
assert position_ldist.shape == position_sim.shape
assert (position_ldist.index == position_sim.index).all()
assert (position_ldist.columns == position_sim.columns).all()

In [None]:
# Visual check of the masked distance matrix (first 1000 positions), on a cube-root colour scale.
sns.heatmap(position_ldist.iloc[:1000,:1000], norm=mpl.colors.PowerNorm(1/3))

In [None]:
# One row per unordered pair of distinct positions, taken from the upper triangle
# (k=1) in the same order as scipy's condensed `squareform` output.
# Indexing the matrices directly keeps cross-contig distances as NaN without a
# -1 sentinel. It also avoids squareform's symmetry check, which raises when
# r^2 contains NaN (e.g. a position constant across strains).
pair_idx = np.triu_indices(len(position_ldist), k=1)
d = pd.DataFrame(dict(
    dist=position_ldist.values[pair_idx],
    rsq=position_sim.loc[position_ldist.index, position_ldist.index].values[pair_idx],
))

sns.jointplot(x='dist', y='rsq', data=d, kind='hex',
              joint_kws=dict(norm=mpl.colors.SymLogNorm(linthresh=1)))

In [None]:
# For pairs closer than each cutoff, print the cutoff, Pearson r, Spearman rho and
# Spearman p-value between distance and r^2, tab-separated.
for bin_size in [10, 25, 50, 100, 200, 500, 1000, 5000, 10000]:
    near = d[d.dist < bin_size]
    pearson_r = sp.stats.pearsonr(near.dist, near.rsq)[0]
    spearman = sp.stats.spearmanr(near.dist, near.rsq)
    print(bin_size, pearson_r, spearman.correlation, spearman.pvalue, sep='\t')

In [None]:
# Zoom on pairs whose distance is below 100 (contig_position units).
sns.jointplot(x='dist', y='rsq', data=d[d.dist < 100], kind='hex',
              joint_kws=dict(norm=mpl.colors.SymLogNorm(linthresh=10)))

In [None]:
# Marginal distribution of r^2 over all position pairs (50 bins on [0, 1]).
plt.hist(
    d.rsq,
    orientation='horizontal',
    bins=np.linspace(0, 1, num=51),
)

In [None]:
# r^2 vs distance for pairs closer than 200, with a regression on log(distance).
sns.regplot(
    x='dist', y='rsq', data=d[d.dist < 200], color='black', logx=True, scatter_kws=dict(s=1, alpha=0.25), x_jitter=True,
    line_kws=dict(color='darkblue'), ci=0,
)

### Within strain clusters

In [None]:
from sklearn.cluster import AgglomerativeClustering
from scripts.strain_facts import genotype_distance
from scipy.spatial.distance import pdist, squareform

# Pairwise genotype distances between rows of fit["gamma"] (presumably strains — confirm).
geno_dist = squareform(
    pdist(fit["gamma"], metric=genotype_distance)
)

info("Clustering.")
# scikit-learn 1.2 renamed `affinity` to `metric`, and 1.4 removed `affinity`.
# Try the new keyword first and fall back for older releases.
_clust_kwargs = dict(n_clusters=None, linkage="complete", distance_threshold=0.05)
try:
    _clust_model = AgglomerativeClustering(metric="precomputed", **_clust_kwargs)
except TypeError:
    _clust_model = AgglomerativeClustering(affinity="precomputed", **_clust_kwargs)

# Label each cluster assignment with the coordinate of the gamma row it came from.
# `idxwhere(clust...)` then yields coordinate values that `.sel(strain=...)` can
# use, rather than positional integers.
_row_dim = fit["gamma"].dims[0]
clust = pd.Series(
    _clust_model.fit(geno_dist).labels_,
    index=fit["gamma"][_row_dim].values,
)

In [None]:
# Number of strains assigned to each genotype cluster.
clust.value_counts()

In [None]:
# Genotype heatmap (2*gamma - 1) restricted to strains in clusters 8, 4 and 50.
# `axis=` must be passed by keyword: pandas >= 2.0 rejects positional sort_index args.
sns.clustermap(
    (
        ((fit.gamma * 2) - 1)
        .sel(
            strain=idxwhere(clust.isin([8, 4, 50]))
        ).T
    ).to_pandas().sort_index(axis=0),
    metric='cosine',
    cmap='coolwarm',
    row_cluster=True,
)

In [None]:
# r^2 vs distance using all strains.
# `axis=` is passed by keyword because pandas >= 2.0 rejects positional sort_index args.
psim = pos_psim(fit.gamma.T).sort_index().sort_index(axis=1)

# One row per unordered pair of distinct positions (upper triangle, k=1). Direct
# indexing keeps NaN distances/r^2 instead of failing squareform's symmetry check.
pair_idx = np.triu_indices(len(position_ldist), k=1)
d = pd.DataFrame(dict(
    dist=position_ldist.values[pair_idx],
    rsq=psim.loc[position_ldist.index, position_ldist.index].values[pair_idx],
))

sns.regplot(x='dist', y='rsq', data=d[d.dist < 200], logx=True, scatter_kws=dict(s=1, alpha=0.5, color='black'), x_jitter=True)

In [None]:
# r^2 vs distance using only strains in genotype cluster 8.
# `axis=` is passed by keyword because pandas >= 2.0 rejects positional sort_index args.
psim = pos_psim(fit.gamma.sel(strain=idxwhere(clust.isin([8]))).T).sort_index().sort_index(axis=1)

# One row per unordered pair of distinct positions (upper triangle, k=1). Direct
# indexing tolerates NaN r^2, which a small strain subset makes likely, whereas
# squareform's symmetry check would raise.
pair_idx = np.triu_indices(len(position_ldist), k=1)
d = pd.DataFrame(dict(
    dist=position_ldist.values[pair_idx],
    rsq=psim.loc[position_ldist.index, position_ldist.index].values[pair_idx],
))

sns.regplot(x='dist', y='rsq', data=d[d.dist < 200], logx=True, scatter_kws=dict(s=1, alpha=0.5, color='black'), x_jitter=True)

In [None]:
# r^2 vs distance using only strains in genotype cluster 4.
# `axis=` is passed by keyword because pandas >= 2.0 rejects positional sort_index args.
psim = pos_psim(fit.gamma.sel(strain=idxwhere(clust.isin([4]))).T).sort_index().sort_index(axis=1)

# One row per unordered pair of distinct positions (upper triangle, k=1). Direct
# indexing tolerates NaN r^2, which a small strain subset makes likely, whereas
# squareform's symmetry check would raise.
pair_idx = np.triu_indices(len(position_ldist), k=1)
d = pd.DataFrame(dict(
    dist=position_ldist.values[pair_idx],
    rsq=psim.loc[position_ldist.index, position_ldist.index].values[pair_idx],
))

sns.regplot(x='dist', y='rsq', data=d[d.dist < 200], logx=True, scatter_kws=dict(s=1, alpha=0.5, color='black'), x_jitter=True)

### Alternative functional forms

In [None]:
# scipy.special has no `sigm`. The logistic sigmoid is `expit`, the inverse of
# `logit`; both are used in the model fit below.
sp.special.expit

In [None]:
import statsmodels.formula.api as smf


# `axis=` is passed by keyword because pandas >= 2.0 rejects positional sort_index args.
psim = pos_psim(fit.gamma.T).sort_index().sort_index(axis=1)

# One row per unordered pair of distinct positions (upper triangle, k=1). Direct
# indexing keeps NaN distances/r^2 instead of failing squareform's symmetry check.
pair_idx = np.triu_indices(len(position_ldist), k=1)
d = pd.DataFrame(dict(
    dist=position_ldist.values[pair_idx],
    rsq=psim.loc[position_ldist.index, position_ldist.index].values[pair_idx],
))

dmax = 100

# Scatter
plt.scatter('dist', 'rsq', data=d[d.dist < dmax], s=1, alpha=0.25, color='black')

# Best fit: logit(r^2) = a + b * log(dist), on pairs closer than dmax
# (NaN distances compare False and are excluded by the subset).
# NOTE(review): r^2 of exactly 0 or 1 gives logit = -/+inf, which OLS will reject.
xx = np.linspace(1, dmax, num=dmax)

lm = smf.ols('sp.special.logit(rsq) ~ np.log(dist)', data=d, subset=d.dist < dmax).fit()
yy = sp.special.expit(lm.params['Intercept'] + lm.params['np.log(dist)'] * np.log(xx))
plt.plot(xx, yy)
plt.xscale('log')
plt.yscale('logit')

lm.summary()

In [None]:
# Residual diagnostics for the logit(r^2) ~ log(dist) fit: Pearson residuals vs fitted values.
# Data vectors are passed by keyword; seaborn >= 0.12 rejects them positionally.
sns.regplot(x=lm.predict(), y=lm.resid_pearson, lowess=True, scatter_kws=dict(s=1, alpha=0.5, color='black'), x_jitter=True)
plt.axhline(0, lw=1, linestyle='--', color='grey')