## Imports

In [None]:
# Enable IPython's autoreload extension in mode 2 so edited modules are
# re-imported before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import sfacts as sf

In [None]:
import warnings
from functools import partial

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyro
import pyro.distributions as dist
import scipy as sp
import seaborn as sns
import torch
import xarray as xr
from scipy.spatial.distance import pdist
from torch.jit import TracerWarning
from tqdm import tqdm

In [None]:
# Check that the two stick-breaking implementations agree when every one of
# the n betas equals p.
n = 100
# NOTE(review): with torch's default float32 dtype, 1 - 1e-8 likely rounds to
# exactly 1.0 — presumably a deliberate probe of the p -> 1 edge; confirm.
p = torch.tensor(1 - 1e-8, requires_grad=True)

probs_v1 = sf.model.stickbreaking_betas_to_probs(torch.ones(n) * p)
probs_v2 = sf.model.stickbreaking_betas_to_probs2(torch.ones(n) * p)
torch.allclose(probs_v1, probs_v2)

In [None]:
# Gradient of the output at index 2 with respect to p, once per
# implementation. Each call builds its own `torch.ones(n) * p` graph, because
# torch.autograd.grad frees the graph it walks.
tuple(
    torch.autograd.grad(betas_to_probs(torch.ones(n) * p)[2], inputs=p)[0]
    for betas_to_probs in (
        sf.model.stickbreaking_betas_to_probs,
        sf.model.stickbreaking_betas_to_probs2,
    )
)

## Library

## Model Specification

In [None]:
# Histogram of 10000 draws from Beta(epsilon_hyper_alpha, epsilon_hyper_beta).
epsilon_hyper_alpha = 1.5
epsilon_hyper_beta = 1.5 / 0.01
epsilon_hyper_dist = dist.Beta(epsilon_hyper_alpha, epsilon_hyper_beta).expand([10000])
epsilon_hyper_draws = pyro.sample('epsilon_hyper', epsilon_hyper_dist)
plt.hist(epsilon_hyper_draws.cpu().numpy(), bins=100)
None

In [None]:
# Histogram of 1000 draws from NegativeBinomialReparam(10., r=1., eps=1e-5).
nb_test_dist = sf.model.NegativeBinomialReparam(
    torch.tensor(10.),
    r=torch.tensor(1.),
    eps=torch.tensor(1e-5),
).expand([1000])
# Move to CPU before converting to numpy, and end with None so the cell does
# not echo plt.hist's return value — both as in the notebook's other
# histogram cells.
plt.hist(pyro.sample('test', nb_test_dist).cpu().numpy())
None

In [None]:
# Run sfacts' shape_info helper on the model at a small problem size.
sf.pyro_util.shape_info(
    sf.model.model,
    n=100,
    g=200,
    s=20,
)

## Simulation

### SimShape-1: Small study

In [None]:
# Seed pyro's RNG before simulating.
seed = 1
pyro.util.set_rng_seed(seed)

# SimShape-1 dimensions, passed to the model as n, g and s.
n_sim, g_sim, s_sim = 100, 5000, 20

# Hyperparameters handed to condition_model as keyword arguments.
sim1_hyperparameters = dict(
    gamma_hyper=0.01,
    delta_hyper_temp=0.01,
    delta_hyper_p=0.7,
    pi_hyper=0.5,
    rho_hyper=10.,
    mu_hyper_mean=2.,
    mu_hyper_scale=0.5,
    m_hyper_r=10.,
    alpha_hyper_scale=0.5,
    epsilon_hyper_alpha=1.5,
    epsilon_hyper_beta=1.5/0.01,
)

# Condition the model on alpha_hyper_mean, then simulate from the result.
sim1_conditioned_model = sf.model.condition_model(
    sf.model.model,
    data=dict(alpha_hyper_mean=100.),
    n=n_sim,
    g=g_sim,
    s=s_sim,
    **sim1_hyperparameters,
    device='cpu'
)
sim1 = sf.model.simulate(sim1_conditioned_model)

## Visualization

In [None]:
# Upper bounds used when slicing arrays for the plots below.
n_plt, g_plt, s_plt = 100, 200, 20

In [None]:
# pi is sliced [:n_plt, :s_plt] for the initial estimate and is scaled row-wise
# by mu elsewhere in this notebook, so its first axis is bounded by n_plt and
# its second by s_plt. The previous [:s_plt, :n_plt] slice had the two bounds
# swapped and dropped rows whenever s_plt < n_plt.
sf.plot.plot_community(sim1['pi'][:n_plt, :s_plt])

In [None]:
# Turn the sliced y and m into a p estimate, then plot it as a genotype.
sim1_p_estimate = sf.genotype.counts_to_p_estimate(
    sim1['y'][:n_plt, :g_plt],
    sim1['m'][:n_plt, :g_plt],
)
sf.plot.plot_genotype(sim1_p_estimate, linkage_kw=dict(progress=True))

In [None]:
# Similarity plot of the p estimate computed from the sliced y and m.
sim1_p_estimate_for_similarity = sf.genotype.counts_to_p_estimate(
    sim1['y'][:n_plt, :g_plt],
    sim1['m'][:n_plt, :g_plt],
)
sf.plot.plot_genotype_similarity(
    sim1_p_estimate_for_similarity,
    linkage_kw=dict(progress=True),
)

In [None]:
# Simulated gamma, limited to the plotting window.
sim1_gamma_plt = sim1['gamma'][:s_plt, :g_plt]
sf.plot.plot_genotype(sim1_gamma_plt)

In [None]:
# Simulated delta, limited to the plotting window.
sim1_delta_plt = sim1['delta'][:s_plt, :g_plt]
sf.plot.plot_missing(sim1_delta_plt)

In [None]:
# Simulated nu, limited to the plotting window.
sim1_nu_plt = sim1['nu'][:n_plt, :g_plt]
sf.plot.plot_missing(sim1_nu_plt)

In [None]:
# Clustered heatmap of the sliced m on a symmetric-log colour scale.
sim1_m_plt = sim1['m'][:n_plt, :g_plt]
m_color_norm = mpl.colors.SymLogNorm(linthresh=1)
sns.clustermap(sim1_m_plt, norm=m_color_norm)

In [None]:
# Similarity plot of the simulated gamma within the plotting window.
sim1_gamma_plt = sim1['gamma'][:s_plt, :g_plt]
sf.plot.plot_genotype_similarity(sim1_gamma_plt)

In [None]:
# Distribution of the simulated epsilon values (trailing None hides the repr).
sim1_epsilon = sim1['epsilon']
plt.hist(sim1_epsilon, bins=50)
None

In [None]:
# Distribution of the simulated alpha values (trailing None hides the repr).
sim1_alpha = sim1['alpha']
plt.hist(sim1_alpha, bins=20)
None

## Estimation

### Initialization

In [None]:
# Fit on the first g_fit columns only (all of them: sim1['y'].shape[1]) and on
# every row of y.
g_fit = 1000
n_fit = sim1['y'].shape[0]

sim1_y_fit = sim1['y'][:n_fit, :g_fit]
sim1_m_fit = sim1['m'][:n_fit, :g_fit]

# The initializer returns three values: gamma, pi and a third array kept as
# sim1_cdmat.
sim1_init = sf.estimation.initialize_parameters_by_clustering_samples(
    sim1_y_fit,
    sim1_m_fit,
    thresh=0.05,
    additional_strains_factor=0.,
    progress=True,
)
sim1_gamma_init, sim1_pi_init, sim1_cdmat = sim1_init

print(sim1_pi_init.shape)

In [None]:
# Initial gamma estimate, limited to the plotting window.
sim1_gamma_init_plt = sim1_gamma_init[:s_plt, :g_plt]
sf.plot.plot_genotype(sim1_gamma_init_plt)

In [None]:
# Similarity plot over the full (unsliced) initial gamma estimate.
sf.plot.plot_genotype_similarity(
    sim1_gamma_init,
)

In [None]:
# Initial pi estimate, limited to the plotting window.
sim1_pi_init_plt = sim1_pi_init[:n_plt, :s_plt]
sf.plot.plot_community(sim1_pi_init_plt)

### Fitting

In [None]:
# s for the fit is the number of rows in the initial gamma; the fit starts
# from the clustering-based gamma and pi.
s_fit = sim1_gamma_init.shape[0]
initialize_params = dict(gamma=sim1_gamma_init, pi=sim1_pi_init)

# Observed y and m, restricted to the first g_fit columns.
sim1_fit_data = dict(y=sim1['y'][:, :g_fit], m=sim1['m'][:, :g_fit])

# Hyperparameters for estimation (not the same values used to simulate sim1).
sim1_fit_hyperparameters = dict(
    gamma_hyper=0.1,
    pi_hyper=1.0,
    rho_hyper=0.5,
    mu_hyper_mean=5,
    mu_hyper_scale=5.,
    m_hyper_r=10.,
    delta_hyper_temp=0.1,
    delta_hyper_p=0.9,
    alpha_hyper_hyper_mean=100.,
    alpha_hyper_hyper_scale=10.,
    alpha_hyper_scale=0.5,
    epsilon_hyper_alpha=1.5,
    epsilon_hyper_beta=1.5 / 0.01,
)

sim1_fit1, history = sf.estimation.estimate_parameters(
    sf.model.model,
    data=sim1_fit_data,
    n=n_fit,
    g=g_fit,
    s=s_fit,
    **sim1_fit_hyperparameters,
    initialize_params=initialize_params,
    device='cpu',
    lag=100,
    lr=1e-1,
)

### Merging Strains

In [None]:
# Merge similar fitted genotypes; the call returns merged gamma, pi and delta.
sim1_fit1_merged = sf.estimation.merge_similar_genotypes(
    sim1_fit1['gamma'],
    sim1_fit1['pi'],
    delta=sim1_fit1['delta'],
    thresh=0.1,
)
sim1_fit1_gamma_merge, sim1_fit1_pi_merge, sim1_fit1_delta_merge = sim1_fit1_merged

# Row counts of gamma before and after merging.
# print(sim1_gamma_init.shape[0], sim1_fit1['gamma'].shape[0], sim1_fit1_gamma_merge.shape[0])
s_before_merge = sim1_fit1['gamma'].shape[0]
s_after_merge = sim1_fit1_gamma_merge.shape[0]
print(s_before_merge, s_after_merge)

## Evaluation

In [None]:
# Apply mask_missing_genotype to the true gamma/delta (first g_fit columns,
# to match the fit) and to the fitted gamma/delta.
sim1_gamma_adjusted = sf.genotype.mask_missing_genotype(
    sim1['gamma'][:, :g_fit],
    sim1['delta'][:, :g_fit],
)
sim1_fit1_gamma_adjusted = sf.genotype.mask_missing_genotype(
    sim1_fit1['gamma'],
    sim1_fit1['delta'],
)

### Ground Truth

#### Visualization

In [None]:
# Panels to compare: masked true genotypes against masked fitted genotypes.
# Other panels tried previously (currently disabled):
#         fit=sim1_fit1['gamma'][:, :g_plt],
#         init=sim1_gamma_init,
#         merg=sim1_fit1_gamma_merge,
genotype_panels = dict(
    true=sim1_gamma_adjusted[:, :g_plt],
    adj=sim1_fit1_gamma_adjusted[:, :g_plt],
)
sf.plot.plot_genotype_comparison(
    data=genotype_panels,
    linkage_kw=dict(progress=True),
)

In [None]:
# Panels to compare: true pi against fitted pi.
# Other panels tried previously (currently disabled):
#         init=sim1_pi_init,
#         merg=sim1_fit1_pi_merge,
community_panels = dict(
    true=sim1['pi'],
    fit=sim1_fit1['pi'],
)
sf.plot.plot_community_comparison(data=community_panels)

In [None]:
# True vs. fitted epsilon, with a reference line from (0, 0) to (0.04, 0.04).
epsilon_line = [0, 0.04]
plt.scatter(sim1['epsilon'], sim1_fit1['epsilon'])
plt.plot(epsilon_line, epsilon_line)

In [None]:
# True vs. fitted alpha, with a reference line from (0, 0) to (200, 200).
alpha_line = [0, 200]
plt.scatter(sim1['alpha'], sim1_fit1['alpha'])
plt.plot(alpha_line, alpha_line)

In [None]:
# Fitted delta as a heatmap with the colour scale pinned to [0, 1].
sim1_fit1_delta = sim1_fit1['delta']
sns.heatmap(sim1_fit1_delta, vmin=0, vmax=1)

In [None]:
# True vs. fitted mu, with a reference line from (0, 0) to (40, 40).
mu_line = [0, 40]
plt.scatter(sim1['mu'], sim1_fit1['mu'])
plt.plot(mu_line, mu_line)

In [None]:
# TODO: Plot comparing genotype accuracy to true strain abundance
# colored by mean entropy of the estimated genotype masked by delta

#### Fit scores

In [None]:
# NOTE(review): `sample_mean_masked_genotype_entropy` is not defined or
# imported in the visible cells — presumably an sfacts helper; qualify or
# import it so this cell runs on a fresh kernel.
# Scatter fitted alpha against the score, then show the score's mean as the
# cell output (the score is computed twice with identical arguments).
plt.scatter(sim1_fit1['alpha'], sample_mean_masked_genotype_entropy(sim1_fit1['pi'], sim1_fit1['gamma'], sim1_fit1['delta']))
sample_mean_masked_genotype_entropy(sim1_fit1['pi'], sim1_fit1['gamma'], sim1_fit1['delta']).mean()

In [None]:
# NOTE(review): `match_genotypes` is not defined or imported in the visible
# cells — presumably an sfacts helper; qualify or import it so this cell runs
# on a fresh kernel.
# Match masked true genotypes against masked fitted genotypes.
best_hit, best_dist = match_genotypes(sim1_gamma_adjusted[:, :g_fit], sim1_fit1_gamma_adjusted[:, :g_fit])

# Weight each distance by the axis-0 mean of the true pi and sum.
print('weighted_mean_distance:', (best_dist * sim1['pi'].mean(0)).sum())
# x-axis: true pi scaled row-wise by mu, summed over axis 0.
plt.scatter((sim1['pi'] * sim1['mu'].reshape(-1, 1)).sum(0), best_dist)

In [None]:
# Pairwise Bray-Curtis similarity (1 - dissimilarity) between the rows of pi,
# for the true and the fitted communities. `pdist` is
# scipy.spatial.distance.pdist; it must be imported in the notebook's import
# cell, otherwise this cell raises NameError on a fresh kernel.
braycurtis_similarity = {
    label: 1 - pdist(pi, metric='braycurtis')
    for label, pi in (('sim', sim1['pi']), ('fit', sim1_fit1['pi']))
}
bc_sim = braycurtis_similarity['sim']
bc_fit = braycurtis_similarity['fit']

# One point per pair of rows: true similarity vs. fitted similarity.
plt.scatter(bc_sim, bc_fit, marker='.', alpha=0.2)

# NOTE(review): `community_accuracy_test` is not defined or imported in the
# visible cells — presumably an sfacts helper; qualify or import it.
community_accuracy_test(sim1['pi'], sim1_fit1['pi'])

### No Ground Truth

#### Visualization

In [None]:
# Strains that are not representative of true haplotypes
# are high entropy (even after masking with delta)
# and have low estimated total coverage.
# NOTE(review): `match_genotypes` and `mean_masked_genotype_entropy` are not
# defined or imported in the visible cells — presumably sfacts helpers;
# qualify or import them so this cell runs on a fresh kernel.

# Arguments are (fitted, true) here — the reverse of the "Fit scores" cell.
best_true_strain, best_true_strain_dist = match_genotypes(sim1_fit1_gamma_adjusted[:, :g_fit], sim1_gamma_adjusted[:, :g_fit])
# Bare expression that is not the cell's last statement: with Jupyter's
# default settings it displays nothing.
best_true_strain_dist

# x: fitted pi scaled row-wise by fitted mu, summed over axis 0; y: distance
# returned above; colour: the entropy score of each fitted genotype.
plt.scatter((sim1_fit1['pi'] * sim1_fit1['mu'].reshape((-1, 1))).sum(0), best_true_strain_dist, c=mean_masked_genotype_entropy(sim1_fit1['gamma'], sim1_fit1['delta']))

In [None]:
# Masked fitted genotypes within the plotting window. The function is called
# through `sf.plot`, as in the rest of the notebook; the bare name
# `plot_genotype` is never imported and raises NameError on a fresh kernel.
sf.plot.plot_genotype(
    sim1_fit1_gamma_adjusted[:, :g_plt],
    linkage_kw=dict(progress=True),
)

In [None]:
# Fitted pi. The function is called through `sf.plot`, as in the rest of the
# notebook; the bare name `plot_community` is never imported and raises
# NameError on a fresh kernel.
sf.plot.plot_community(sim1_fit1['pi'])

#### Confidence Scores

In [None]:
# NOTE(review): `mean_masked_genotype_entropy` is not defined or imported in
# the visible cells — presumably an sfacts helper; qualify or import it.
# NOTE(review): the score here uses only the first g_plt columns, while other
# cells pass the full fitted gamma/delta — confirm this is intended.
plt.hist(mean_masked_genotype_entropy(sim1_fit1['gamma'][:, :g_plt], sim1_fit1['delta'][:, :g_plt]))
None

In [None]:
# Distribution of the fitted alpha values (trailing None hides the repr).
sim1_fit1_alpha = sim1_fit1['alpha']
plt.hist(sim1_fit1_alpha, bins=20)
None

In [None]:
# Plot only the fitted genotypes whose entropy score is below 0.1.
# NOTE(review): `mean_masked_genotype_entropy` is not defined or imported in
# the visible cells — presumably an sfacts helper; qualify or import it.
low_entropy = mean_masked_genotype_entropy(sim1_fit1['gamma'], sim1_fit1['delta']) < 0.1
# Called through `sf.plot`, as in the rest of the notebook; the bare name
# `plot_genotype` is never imported and raises NameError on a fresh kernel.
sf.plot.plot_genotype(sim1_fit1['gamma'][low_entropy, :g_fit])