In [None]:
# Load IPython's autoreload extension so the `%autoreload` magic is available.
%load_ext autoreload

In [None]:
import numpy as np
import statsmodels.api as sm
from statsmodels.genmod.families.links import Link, Log as LogLink
import scipy as sp
import scipy.stats
import matplotlib.pyplot as plt
import matplotlib as mpl
import strainzip as sz
import seaborn as sns

from strainzip import depth_model
import strainzip as sz

import pandas as pd

In [None]:
# Invoke the autoreload magic with no mode argument.
# NOTE(review): presumably used to pick up local edits to `strainzip` while developing — confirm.
%autoreload

In [None]:
import graph_tool as gt

def test_unzip_lolipops():
    """Check that pressed "lollipop" components can be unzipped.

    The fixture has two disconnected components, each a 3-cycle with a single
    tail edge: cycle 0->1->2->0 with the outgoing edge 0->3, and cycle
    4->5->6->4 with the incoming edge 7->4.  Each cycle is asserted to be
    reported as one maximal unitig path; the paths are then pressed, two
    vertices are unzipped, and the final degree table is compared against a
    hard-coded expectation.

    The graph is drawn after validation, after pressing and after unzipping
    so the intermediate states can be inspected in the notebook output.

    Raises:
        AssertionError: if the unitig paths or the final degree table differ
            from the expected values.
    """
    _graph = gt.Graph()
    _graph.add_edge_list(
        [
            # "Out" lollipop: 3-cycle with an edge leaving it.
            (0, 1), (1, 2), (2, 0), (0, 3),
            # "In" lollipop: 3-cycle with an edge entering it.
            (4, 5), (5, 6), (6, 4), (7, 4),
        ]
    )
    # Install an all-True vertex filter before handing the graph to the manager.
    _graph.vp["filter"] = _graph.new_vertex_property("bool", val=True)
    _graph.set_vertex_filter(_graph.vp["filter"])
    gm = sz.graph_manager.GraphManager()
    gm.validate(_graph)
    sz.draw.draw_graph(_graph, ink_scale=1, output_size=(200, 200), vertex_text=_graph.vertex_index)

    # Each cycle must come back as exactly one unitig path; frozensets make the
    # comparison independent of where along the cycle a path starts.
    unitig_paths = [tuple(u) for u in sz.assembly.iter_maximal_unitig_paths(_graph)]
    assert set(frozenset(u) for u in unitig_paths) == {frozenset([0, 1, 2]), frozenset([4, 5, 6])}
    gm.batch_press(_graph, *[(list(path), {}) for path in unitig_paths])
    sz.draw.draw_graph(_graph, ink_scale=1, output_size=(200, 200), vertex_text=_graph.vertex_index)

    # NOTE(review): vertices 8 and 9 are presumably the ones created by pressing
    # the two cycles — confirm against GraphManager.batch_press.
    gm.batch_unzip(_graph, (9, [(9, 9), (7, 9)], {}), (8, [(8, 8), (8, 3)], {}))
    sz.draw.draw_graph(_graph, ink_scale=1, output_size=(200, 200), vertex_text=_graph.vertex_index)

    # NOTE(review): columns are presumably (in-degree, out-degree, count) —
    # confirm against sz.stats.degree_stats.
    expected_degree_table = np.array(
        [
            [0.0, 1.0, 1.0],
            [1.0, 0.0, 1.0],
            [1.0, 1.0, 2.0],
            [1.0, 2.0, 1.0],
            [2.0, 1.0, 1.0],
        ]
    )
    observed_degree_table = sz.stats.degree_stats(_graph).sort_index().reset_index().values
    # assert_array_equal fails with an AssertionError (and a readable diff) on
    # shape mismatches too, where `(a == b).all()` would raise an unrelated
    # error or broadcast silently.
    np.testing.assert_array_equal(observed_degree_table, expected_degree_table)

# Run the check right away so a failure surfaces as soon as this cell executes.
test_unzip_lolipops()

In [None]:
# Simulation settings.
model = sz.depth_model
seed = 1
np.random.seed(seed)  # Seed NumPy's global RNG before any random draws below.

alpha = 1e-0  # Small offset for handling 0s in depths
n = 3  # Number of in-edges
m = 4  # Number of out-edges
s_samples = 10
sigma = 1e-2  # Scale of the multiplicative noise
depth_multiplier = 2  # Scaling factor for depths
num_excess_paths = 1  # How many extra paths to include beyond correct ones.

# The design matrix is expected to have n + m rows and n * m columns.
r_edges = n + m
p_paths = n * m
X = sz.deconvolution.design_paths(n, m)[0]
assert X.shape == (r_edges, p_paths)

# Pick the in/out edge pairings treated as "real" and keep only the first
# element of each item returned by the simulator.
active_paths = [
    path_index
    for path_index, _ in sz.deconvolution.simulate_active_paths(
        n, m, excess=num_excess_paths
    )
]
print(active_paths)

# Per-sample depths: zero everywhere except the active rows, which get
# log-normal draws.
beta = np.zeros((p_paths, s_samples))
beta[active_paths, :] = np.random.lognormal(
    mean=-1,
    sigma=4,
    size=(len(active_paths), s_samples),
)
# Rounding to one decimal turns sufficiently small draws into exact
# (structural) zeros before the depths are scaled.
beta = np.round(beta, 1) * depth_multiplier


# Simulate the observed depth of each edge.
# `expect` is the noise-free product of the design matrix and the path depths.
expect = X @ beta
# Standard-normal draws; scaled by `sigma` and exponentiated they act as
# multiplicative noise on `expect`.
log_noise = np.random.normal(loc=0, scale=1, size=expect.shape)
y_obs = expect * np.exp(sigma * log_noise)


# Print the negated value of `model.loglik` at the simulation's own parameters.
print(
    -model.loglik(beta, sigma, y_obs, X, alpha=alpha)
)

# # Simulate a selection of paths during the estimation procedure.
# # Possibly over-specified. (see `num_excess_paths`)
# _active_paths = list(
#     sorted(
#         set(active_paths)
#         | set(
#             np.random.choice(
#                 [p for p in range(p_paths) if p not in active_paths],
#                 replace=False,
#                 size=num_excess_paths,
#             )
#         )
#     )
# )
# X_reduced = X[:, _active_paths]

# # Estimate model parameters
# beta_est, sigma_est, _ = model.fit(y_obs, X_reduced, alpha=alpha)

# # Calculate likelihood
# loglik = -model.negloglik(beta_est, sigma_est, y_obs, X_reduced, alpha=alpha)
# assert np.isfinite(loglik)

# # Estimate standard errors.
# beta_stderr, sigma_stderr = model.estimate_stderr(
#     y_obs, X_reduced, beta_est, sigma_est, alpha=alpha
# )

# # Check model identifiable.
# assert np.isfinite(beta_stderr).all()
# assert np.isfinite(sigma_stderr)

# Heatmap of the simulated observations on a symmetric-log colour scale.
sns.heatmap(
    data=pd.DataFrame(y_obs),
    cmap="coolwarm",
    norm=mpl.colors.SymLogNorm(linthresh=1e-1, vmin=-1e4, vmax=1e4),
    yticklabels=1,
)

In [None]:
# Same printout as for the true depths, but with an all-zero `beta` of the same shape.
print(
    -model.loglik(
        np.zeros_like(beta),
        sigma,
        y_obs,
        X,
        alpha=alpha,
    )
)


In [None]:
# Heatmap of the simulated depths for the active paths; rows are labelled by path index.
sns.heatmap(
    data=pd.DataFrame(beta[active_paths, :], index=active_paths),
    cmap="coolwarm",
    norm=mpl.colors.SymLogNorm(linthresh=1e-1, vmin=-1e4, vmax=1e4),
    yticklabels=1,
)

In [None]:
# Fit the depth model using only the design-matrix columns of the simulated
# active paths, and time the call.
# NOTE(review): alpha is hard-coded to 5. here, while the simulation settings and
# the other fit/select calls use `alpha` (1e-0) — confirm this difference is intentional.
%time fit = sz.depth_model.fit(y_obs, X[:,active_paths], alpha=5.)

In [None]:
# Heatmap of the fitted depths; rows are labelled by path index.
sns.heatmap(
    data=pd.DataFrame(fit.beta, index=active_paths),
    cmap="coolwarm",
    norm=mpl.colors.SymLogNorm(linthresh=1e-1, vmin=-1e3, vmax=1e3),
    yticklabels=1,
)

In [None]:
# Heatmap of simulated minus fitted depth for each active path and sample.
sns.heatmap(
    data=pd.DataFrame(beta[active_paths] - fit.beta, index=active_paths),
    cmap="coolwarm",
    norm=mpl.colors.SymLogNorm(linthresh=1e-1, vmin=-1e4, vmax=1e4),
    yticklabels=1,
)

In [None]:
# Heatmap of the standard errors reported by the fit.
# The array is wrapped in a DataFrame indexed by `active_paths` so the row
# labels are path indices, matching the other heatmaps of this fit (a bare
# array would be labelled 0..k-1 instead).
# NOTE(review): assumes `fit` is still the model fitted on X[:, active_paths].
sns.heatmap(
    pd.DataFrame(fit.stderr_beta, index=active_paths),
    norm=mpl.colors.SymLogNorm(1e-1, vmin=-1e5, vmax=1e5),
    yticklabels=1,
    cmap='coolwarm',
)

In [None]:
# Long-format table with one row per (path, sample) pair.  The stacked `depth`
# series supplies the index; the raveled arrays are matched to it by position.
d = (
    pd.DataFrame(
        {
            "depth": pd.DataFrame(beta[active_paths], index=active_paths).stack(),
            "depth_est": fit.beta.ravel(),
            "err": (fit.beta - beta[active_paths]).ravel(),
            "stderr_est": fit.stderr_beta.ravel(),
        }
    )
    .rename_axis(["path", "sample"])
    .reset_index()
)

# Estimation error against simulated depth, coloured by the estimated standard error.
plt.scatter(
    "depth",
    "err",
    data=d,
    c="stderr_est",
    norm=mpl.colors.SymLogNorm(linthresh=1e-1),
)
plt.colorbar()

# Reference lines: err = +depth, err = -depth, and err = 0 (dashed).
xx = np.logspace(-3, 5, num=50)
plt.plot(xx, xx)
plt.plot(xx, -xx)
plt.plot(xx, np.zeros_like(xx), linewidth=1, ls="--")
plt.xscale("symlog", linthresh=1e-2)
plt.yscale("symlog", linthresh=1e-2)

In [None]:
# Estimation error against the estimated standard error (also used for colour).
plt.scatter(
    "stderr_est",
    "err",
    data=d,
    c="stderr_est",
    norm=mpl.colors.SymLogNorm(linthresh=1e-1),
)
plt.colorbar()

# Reference lines: err = +stderr, err = -stderr, and err = 0 (dashed).
xx = np.logspace(-4, 5, num=50)
plt.plot(xx, xx)
plt.plot(xx, -xx)
plt.plot(xx, np.zeros_like(xx), linewidth=1, ls="--")
plt.xscale("symlog", linthresh=1e-1)
plt.yscale("symlog", linthresh=1e-1)

In [None]:
# Run path selection against the full design matrix with both stopping
# thresholds set to 0 and verbose output enabled.
selected_paths, delta_aic = sz.deconvolution.select_paths(
    X,
    y_obs,
    model=sz.depth_model,
    alpha=alpha,
    forward_stop=0,
    backward_stop=0,
    verbose=True,
)

# Printed in order: selected but not simulated, present in both, simulated but
# not selected.
print(
    set(selected_paths) - set(active_paths),
    set(selected_paths) & set(active_paths),
    set(active_paths) - set(selected_paths),
)

In [None]:
# No debugger call is kept in this cell: a bare `debug` (automagic `%debug`) can
# open an interactive post-mortem prompt during "Run All" and is a NameError
# when automagic is off.  Run `%debug` by hand after a failure if needed.

In [None]:
# Display the second value returned by the `select_paths` call above.
delta_aic

In [None]:
# Refit the depth model using only the design-matrix columns of the selected
# paths.  This rebinds `fit`, replacing the earlier fit object.
fit = sz.depth_model.fit(
    y_obs,
    X[:, selected_paths],
    alpha=alpha,
)

In [None]:
# Sorted union of the selected and the simulated path indices
# (`sorted` already returns a new list).
all_paths = sorted(set(selected_paths) | set(active_paths))

In [None]:
# Edge depths implied by the refitted model on the selected paths.
y_predict = X[:, selected_paths] @ fit.beta

# Heatmap of predicted minus observed depth.
sns.heatmap(
    data=pd.DataFrame(y_predict - y_obs),
    cmap="coolwarm",
    norm=mpl.colors.SymLogNorm(linthresh=1e-1, vmin=-1e2, vmax=1e2),
    yticklabels=1,
)

In [None]:
# Fitted depths indexed by selected path, expanded to every path in
# `all_paths`; rows for paths that were not selected are filled with 0.
depth_est = pd.DataFrame(fit.beta, index=selected_paths).reindex(
    index=all_paths,
    fill_value=0,
)
sns.heatmap(
    data=depth_est,
    cmap="coolwarm",
    norm=mpl.colors.SymLogNorm(linthresh=1e-1, vmin=-1e2, vmax=1e2),
    yticklabels=1,
)

In [None]:
# Simulated depths indexed by active path, expanded to every path in
# `all_paths`; rows for paths that were not simulated are filled with 0.
depth = pd.DataFrame(beta[active_paths, :], index=active_paths).reindex(
    index=all_paths,
    fill_value=0,
)
sns.heatmap(
    data=depth,
    cmap="coolwarm",
    norm=mpl.colors.SymLogNorm(linthresh=1e-1, vmin=-1e2, vmax=1e2),
    yticklabels=1,
)

In [None]:
# Estimation error per path and sample: fitted depth minus simulated depth.
err = depth_est.sub(depth)
sns.heatmap(
    data=err,
    cmap="coolwarm",
    norm=mpl.colors.SymLogNorm(linthresh=1e-1, vmin=-1e2, vmax=1e2),
    yticklabels=1,
)

In [None]:
# Standard errors reported by the fit, indexed by selected path and expanded
# to every path in `all_paths`; rows for unselected paths are filled with 0.
err_est = pd.DataFrame(fit.stderr_beta, index=selected_paths).reindex(
    index=all_paths,
    fill_value=0,
)
sns.heatmap(
    data=err_est,
    cmap="coolwarm",
    norm=mpl.colors.SymLogNorm(linthresh=1e-1, vmin=-1e3, vmax=1e3),
    yticklabels=1,
)

In [None]:
# Long-format table with one row per (path, sample) pair, plus flags marking
# paths that were selected but not simulated (false positive) or simulated
# but not selected (false negative).
d = (
    pd.DataFrame(
        {
            "depth": depth.stack(),
            "depth_est": depth_est.stack(),
            "err": err.stack(),
            "stderr_est": err_est.stack(),
        }
    )
    .rename_axis(["path", "sample"])
    .reset_index()
    .assign(
        false_positive=lambda frame: frame["path"].isin(
            set(selected_paths) - set(active_paths)
        ),
        false_negative=lambda frame: frame["path"].isin(
            set(active_paths) - set(selected_paths)
        ),
    )
)
xx = np.logspace(-1, 5, num=50)

# Estimation error against simulated depth, coloured by the false-positive flag,
# with err = +depth and err = -depth as reference lines.
plt.scatter("depth", "err", data=d, c="false_positive")
plt.plot(xx, xx)
plt.plot(xx, -xx)
plt.xscale("symlog", linthresh=1e-1)
plt.yscale("symlog", linthresh=1e-1)

In [None]:
# Rebuild the long-format (path, sample) table so this cell does not depend on
# `d` from another cell; the flags mark selected-but-not-simulated and
# simulated-but-not-selected paths.
d = (
    pd.DataFrame(
        {
            "depth": depth.stack(),
            "depth_est": depth_est.stack(),
            "err": err.stack(),
            "stderr_est": err_est.stack(),
        }
    )
    .rename_axis(["path", "sample"])
    .reset_index()
    .assign(
        false_positive=lambda frame: frame["path"].isin(
            set(selected_paths) - set(active_paths)
        ),
        false_negative=lambda frame: frame["path"].isin(
            set(active_paths) - set(selected_paths)
        ),
    )
)
xx = np.logspace(-1, 3, num=50)

# Estimation error against the estimated standard error, with err = +stderr
# and err = -stderr as reference lines.
plt.scatter("stderr_est", "err", data=d)
plt.plot(xx, xx)
plt.plot(xx, -xx)
plt.xscale("symlog", linthresh=1e-1)
plt.yscale("symlog", linthresh=1e-1)