In [None]:
# Load IPython's autoreload extension so the %autoreload magic is available in later cells.
%load_ext autoreload

In [None]:
import numpy as np
import statsmodels.api as sm
from statsmodels.genmod.families.links import Link, Log as LogLink
import scipy as sp
import scipy.stats
import matplotlib.pyplot as plt
import matplotlib as mpl
import strainzip as sz
import seaborn as sns

from strainzip import model_zoo
import strainzip as sz

import pandas as pd

In [None]:
# A bare %autoreload performs a one-off reload of already-imported modules right now.
# NOTE(review): use `%autoreload 2` instead if modules should be reloaded before every cell.
%autoreload

In [None]:
# Simulation set-up: build a path design matrix, draw latent path depths for a
# set of "active" paths, and generate noisy observed edge depths from them.
# The names defined here (model, alpha, X, beta, y_obs, ...) are reused by later cells.
model = sz.model_zoo.multiplicative_gaussian_noise
seed = 0
alpha = 1e-0  # Small offset for handling 0s in depths
n, m = 3, 4  # In-edges / out-edges
s_samples = 4
sigma = 1e-1  # Scale of the multiplicative noise
depth_multiplier = 1  # Scaling factor for depths
num_excess_paths = 1  # How many extra paths to include beyond correct ones.

# Seed NumPy's global RNG; every draw below consumes from it in order.
np.random.seed(seed)

# One row per edge (n in-edges + m out-edges), one column per in/out pairing.
r_edges, p_paths = (n + m, n * m)
# Only the first element returned by design_paths is used, transposed to (edges, paths).
X = sz.deconvolution.design_paths(n, m)[0].T
assert X.shape == (r_edges, p_paths)

# Select which pairs of in/out edges are "real" and assign them weights across samples.
active_paths = sz.deconvolution.simulate_active_paths(n, m, excess=num_excess_paths)
# Keep only the first element of each returned pair and use it as a row index into beta.
active_paths = [i for i, _ in active_paths]
print(active_paths)
beta = np.zeros((p_paths, s_samples))
beta[active_paths, :] = np.random.lognormal(
    mean=-5, sigma=7, size=(len(active_paths), s_samples)
)
# Rounding to one decimal turns sufficiently small draws into exact zeros.
beta = beta.round(1)  # Structural zeros


# Simulate the observed depth of each edge.
expect = X @ (beta * depth_multiplier)
log_noise = np.random.normal(loc=0, scale=1, size=expect.shape)
# Multiplicative noise: standard-normal draws scaled by sigma, applied on the log scale.
y_obs = expect * np.exp(log_noise * sigma)


# Log-likelihood of the observations at the true (simulated) parameters.
print(-model.negloglik(beta, sigma, y_obs, X, alpha=alpha))

# # Simulate a selection of paths during the estimation procedure.
# # Possibly over-specified. (see `num_excess_paths`)
# _active_paths = list(
#     sorted(
#         set(active_paths)
#         | set(
#             np.random.choice(
#                 [p for p in range(p_paths) if p not in active_paths],
#                 replace=False,
#                 size=num_excess_paths,
#             )
#         )
#     )
# )
# X_reduced = X[:, _active_paths]

# # Estimate model parameters
# beta_est, sigma_est, _ = model.fit(y_obs, X_reduced, alpha=alpha)

# # Calculate likelihood
# loglik = -model.negloglik(beta_est, sigma_est, y_obs, X_reduced, alpha=alpha)
# assert np.isfinite(loglik)

# # Estimate standard errors.
# beta_stderr, sigma_stderr = model.estimate_stderr(
#     y_obs, X_reduced, beta_est, sigma_est, alpha=alpha
# )

# # Check model identifiable.
# assert np.isfinite(beta_stderr).all()
# assert np.isfinite(sigma_stderr)

In [None]:
# Simulated depths of the active paths only, one labelled row per path index,
# drawn on a symmetric-log colour scale.
sns.heatmap(
    pd.DataFrame(beta[active_paths, :], index=active_paths),
    cmap='coolwarm',
    yticklabels=1,
    norm=mpl.colors.SymLogNorm(1, vmin=-5e7, vmax=5e7),
)

In [None]:
# Run path selection on the full design and unpack every returned component.
(
    selected_paths,
    beta_est,
    beta_stderr,
    sigma_est,
    sigma_stderr,
    inv_hessian,
    fit,
) = sz.deconvolution.estimate_paths(
    X,
    y_obs,
    model=sz.model_zoo.multiplicative_gaussian_noise,
    forward_stop=0.2,
    backward_stop=0.01,
    verbose=2,
    alpha=alpha,
)

# Printed in order: selected-but-not-active, selected-and-active, active-but-not-selected.
print(
    set(selected_paths) - set(active_paths),
    set(selected_paths) & set(active_paths),
    set(active_paths) - set(selected_paths),
)

In [None]:
# Ascending union of selected and truly active path indices; used as a shared row index below.
all_paths = sorted(set(selected_paths).union(active_paths))

In [None]:
# Estimated depths re-indexed onto `all_paths`; paths that were not selected are filled with 0.
depth_est = (
    pd.DataFrame(beta_est, index=selected_paths)
    .reindex(all_paths, fill_value=0)
)
sns.heatmap(
    depth_est,
    cmap='coolwarm',
    yticklabels=1,
    norm=mpl.colors.SymLogNorm(1, vmin=-5e7, vmax=5e7),
)

In [None]:
# Simulated depths re-indexed onto `all_paths`; paths that are not active are filled with 0.
depth = (
    pd.DataFrame(beta[active_paths, :], index=active_paths)
    .reindex(all_paths, fill_value=0)
)
sns.heatmap(
    depth,
    cmap='coolwarm',
    yticklabels=1,
    norm=mpl.colors.SymLogNorm(1, vmin=-5e7, vmax=5e7),
)

In [None]:
# Signed estimation error (estimate minus simulated value) per path and sample.
err = depth_est - depth
sns.heatmap(
    err,
    cmap='coolwarm',
    yticklabels=1,
    norm=mpl.colors.SymLogNorm(1, vmin=-5e7, vmax=5e7),
)

In [None]:
# Estimated standard errors re-indexed onto `all_paths`; unselected paths are filled with 0.
err_est = (
    pd.DataFrame(beta_stderr, index=selected_paths)
    .reindex(all_paths, fill_value=0)
)
sns.heatmap(
    err_est,
    cmap='coolwarm',
    yticklabels=1,
    norm=mpl.colors.SymLogNorm(1, vmin=-5e7, vmax=5e7),
)

In [None]:
# Long-format table: one row per (path, sample) with the simulated depth, its
# estimate, the signed error and the estimated standard error.
d = pd.DataFrame(
    {
        'depth': depth.stack(),
        'depth_est': depth_est.stack(),
        'err': err.stack(),
        'stderr_est': err_est.stack(),
    }
)
d = d.rename_axis(['path', 'sample']).reset_index()
# Flag rows belonging to wrongly selected / wrongly dropped paths.
d = d.assign(
    false_positive=d['path'].isin(set(selected_paths) - set(active_paths)),
    false_negative=d['path'].isin(set(active_paths) - set(selected_paths)),
)
xx = np.logspace(-1, 5)

# Error against simulated depth, coloured by the false-positive flag,
# with the lines err = +depth and err = -depth for reference.
plt.scatter('depth', 'err', c='false_positive', data=d)
plt.plot(xx, xx)
plt.plot(xx, -xx)
plt.xscale('symlog', linthresh=1e-1)
plt.yscale('symlog', linthresh=1e-1)

In [None]:
# Rebuild the long-format (path, sample) table of depths, estimates, errors and
# estimated standard errors.
d = pd.DataFrame(
    {
        'depth': depth.stack(),
        'depth_est': depth_est.stack(),
        'err': err.stack(),
        'stderr_est': err_est.stack(),
    }
)
d = d.rename_axis(['path', 'sample']).reset_index()
d = d.assign(
    false_positive=d['path'].isin(set(selected_paths) - set(active_paths)),
    false_negative=d['path'].isin(set(active_paths) - set(selected_paths)),
)
xx = np.logspace(-1, 3)

# Realised error against estimated standard error, with the lines
# err = +stderr and err = -stderr for reference.
plt.scatter('stderr_est', 'err', data=d)
plt.plot(xx, xx)
plt.plot(xx, -xx)
plt.xscale('symlog', linthresh=1e-1)
plt.yscale('symlog', linthresh=1e-1)

# Test Known Specification

In [None]:
# Fit the model on a fixed, hand-picked set of paths and report its log-likelihood.
# Both names refer to the same list object.
_active_paths = _already_active_paths = [2, 4, 6]

# Restrict the design to the chosen path columns.
X_reduced = X[:, _active_paths]

# Estimate model parameters; the third returned value is discarded.
beta_reduced_est, sigma_est, _ = model.fit(y_obs, X_reduced, alpha=alpha)

# Log-likelihood is the negated negative log-likelihood at the fitted parameters.
loglik = -model.negloglik(beta_reduced_est, sigma_est, y_obs, X_reduced, alpha=alpha)
print(loglik, _active_paths)

## Locally Better Set?

In [None]:
# Print only the first item yielded by forward selection when started from `active_paths`.
for i, (p, l) in enumerate(
    sz.deconvolution.iter_forward_greedy_path_selection(
        X,
        y_obs,
        sz.model_zoo.multiplicative_gaussian_noise,
        active_paths=active_paths,
        alpha=1.0,
    )
):
    print(round(l, 0), p)
    # Always true on the first pass, so the loop stops after one item.
    if i >= 0:
        break

In [None]:
# Print only the first item yielded by backward selection when started from `active_paths`.
for i, (p, l) in enumerate(
    sz.deconvolution.iter_backward_greedy_path_selection(
        X,
        y_obs,
        sz.model_zoo.multiplicative_gaussian_noise,
        active_paths=active_paths,
        alpha=1.0,
    )
):
    print(round(l, 0), p)
    # Always true on the first pass, so the loop stops after one item.
    if i >= 0:
        break

# Greedy selection

In [None]:
# Forward greedy selection starting from no paths. For every yielded step, test the
# change in `l` relative to the previous step (the first comparison is against -inf).
prior_l = float('-inf')

for i, (p, l) in enumerate(
    sz.deconvolution.iter_forward_greedy_path_selection(
        X,
        y_obs,
        sz.model_zoo.multiplicative_gaussian_noise,
        active_paths=[],
        alpha=1.0,
    )
):
    # delta_df=s_samples: presumably one coefficient per sample for each path added.
    pvalue = sz.deconvolution.likelihood_ratio_test(l - prior_l, delta_df=s_samples)
    print(round(l, 0), p, pvalue)
    prior_l = l
    # Stop after at most 17 yielded steps.
    if i >= 16:
        break

In [None]:
# Backward greedy selection starting from a hand-picked path set. For every yielded
# step, test the drop in `l` relative to the previous step.
# NOTE(review): with prior_l starting at -inf the first statistic is -inf — confirm intended.
prior_l = float('-inf')

for i, (p, l) in enumerate(
    sz.deconvolution.iter_backward_greedy_path_selection(
        X,
        y_obs,
        sz.model_zoo.multiplicative_gaussian_noise,
        active_paths=[1, 2, 3, 4, 6, 8],
        alpha=1.0,
    )
):
    pvalue = sz.deconvolution.likelihood_ratio_test(prior_l - l, delta_df=s_samples)
    print(round(l, 0), p, pvalue)
    prior_l = l
    # Stop after at most 21 yielded steps.
    if i >= 20:
        break

In [None]:
# import strainzip as sz
import jax.numpy as jnp
from jax import grad, hessian
from jax.tree_util import Partial
import jaxopt

In [None]:
# Smaller problem: 3 in-edges, 3 out-edges, 3 samples.
n = m = 3
s_samples = 3
r_edges = n + m
p_paths = n * m

# First element returned by design_paths, transposed to (edges, paths).
X = sz.deconvolution.design_paths(n, m)[0].T
assert X.shape == (r_edges, p_paths)

sns.heatmap(X)

In [None]:
# Draw the set of active paths and their latent depths for the 3x3 problem.
np.random.seed(0)
# The original called a bare `simulate_active_paths`, which is never defined or
# imported in this notebook; use the strainzip implementation already used above.
# NOTE(review): `excess=0` (no extra paths) is assumed here because the original
# call passed no `excess` argument — confirm against the function's signature.
active_paths = sz.deconvolution.simulate_active_paths(n, m, excess=0)
print(active_paths)
# Keep only the first element of each returned pair and use it as a row index into beta.
active_paths = [i for i, _ in active_paths]

# Inactive paths stay at exactly zero; active ones get log-normal depths per sample.
beta = np.zeros((p_paths, s_samples))
beta[active_paths, :] = np.random.lognormal(mean=-3, sigma=6, size=(len(active_paths), s_samples))
# beta = beta.round(1)

sns.heatmap(beta, norm=mpl.colors.SymLogNorm(1e-5))
# sns.heatmap(beta)

In [None]:
# Noise scale and depth scaling for this simulation.
sigma = 1
depth_multiplier = 1

np.random.seed(2)
# Expected edge depths, then multiplicative noise from standard-normal draws scaled by sigma.
expect = X @ (depth_multiplier * beta)
log_noise = np.random.normal(size=expect.shape)
y_obs = expect * np.exp(sigma * log_noise)

# Linear-scale alternative: sns.heatmap(y_obs)
sns.heatmap(y_obs, norm=mpl.colors.SymLogNorm(1e-2))

In [None]:
np.random.seed(4)

num_excess_paths = 0
# Active paths plus `num_excess_paths` indices drawn without replacement from the
# inactive ones, in ascending order.
_active_paths = sorted(
    set(active_paths).union(
        np.random.choice(
            [p for p in range(p_paths) if p not in active_paths],
            size=num_excess_paths,
            replace=False,
        )
    )
)
p_reduced = len(_active_paths)

# Design restricted to the chosen path columns.
X_reduced = X[:, _active_paths]
print(X_reduced.shape)

sns.heatmap(X_reduced)
# Rank of the reduced design, to compare with its number of columns printed above.
print(np.linalg.matrix_rank(X_reduced))

In [None]:
# The original called `fit_model`, which is never defined or imported in this
# notebook. Use the strainzip model's `fit`, which earlier cells call with the
# same (y, X, alpha=...) arguments and unpack into three values.
# NOTE(review): confirm this is the routine `fit_model` originally referred to.
beta_est, sigma_est, fit = sz.model_zoo.multiplicative_gaussian_noise.fit(
    y_obs, X_reduced, alpha=1e-5
)

In [None]:
# Standard errors from the curvature of the objective at the fitted parameters.
# The original differentiated `model_loss`, which is never defined or imported in
# this notebook. The strainzip model's negative log-likelihood is used instead,
# called positionally exactly as in the earlier `model.negloglik(...)` calls.
# NOTE(review): confirm `negloglik` is the intended objective and is JAX-differentiable.
model_hessian = hessian(
    lambda beta, sigma: sz.model_zoo.multiplicative_gaussian_noise.negloglik(
        beta, sigma, y_obs, X_reduced, alpha=1e-5
    ),
    argnums=[0, 1],
)
# Blocks of second derivatives with respect to (beta, beta), (beta, sigma), (sigma, beta), (sigma, sigma).
(_beta_beta_hess, _beta_sigma_hess), (_sigm_beta_hess, _sigma_sigma_hess) = model_hessian(beta_est, sigma_est)
# Flatten the beta-beta block into a square matrix over all (path, sample) coefficients.
_hess_flat = _beta_beta_hess.reshape((p_reduced*s_samples, -1))
_var_covar_matrix = jnp.linalg.inv(_hess_flat)
_max = np.abs(_var_covar_matrix).max()

# Diagonal of the inverse, reshaped back to (paths, samples), then square-rooted.
_variance = np.diag(_var_covar_matrix).reshape((p_reduced, s_samples))
_stderr = np.sqrt(_variance)
assert np.isfinite(_stderr).all()

# sns.heatmap(_var_covar_matrix, norm=mpl.colors.SymLogNorm(1, vmin=-_max, vmax=_max), cmap='coolwarm')

In [None]:
# Print the flattened Hessian's shape next to its numerical rank so the two can be compared.
print(_hess_flat.shape)
print(np.linalg.matrix_rank(_hess_flat))

In [None]:
# One panel per sample: estimated against simulated depth for the fitted paths,
# with error bars from `_stderr` and a dashed identity line.
fig, axs = plt.subplots(
    nrows=1,
    ncols=s_samples,
    figsize=(3 * s_samples, 3),
    sharex=True,
    sharey=True,
)

for s, ax in enumerate(axs):
    ax.errorbar(
        beta[_active_paths, s].ravel(),
        beta_est[:, s],
        yerr=_stderr[:, s],
        fmt='.',
    )
    ax.plot([0, 2000], [0, 2000], color='k', linestyle='--', lw=1)

# Set on the current axes; the panels share both axes.
plt.xscale('symlog', linthresh=1e-1)
plt.yscale('symlog', linthresh=1e-1)

In [None]:
np.random.seed(0)

# Hand-written design for 2 in-edges and 3 out-edges: one row per edge,
# one column per in/out pairing.
X = np.array(
    [
        [1, 1, 1, 0, 0, 0],
        [0, 0, 0, 1, 1, 1],
        [1, 0, 0, 1, 0, 0],
        [0, 1, 0, 0, 1, 0],
        [0, 0, 1, 0, 0, 1],
    ]
)

# Latent path depths (rows: paths, columns: samples)
beta = 10000 * np.array(
    [
        [1.0, 1.0, 0.0],
        [1.0, 0.0, 0.0],
        [0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0],
        [1.0, 0.0, 1.0],
    ]
)
# Single-sample alternative:
# beta = np.array([
#     [1.],
#     [0.5],
#     [0.],
#     [0.],
#     [0.],
#     [0.],
# ])

n, m = 2, 3

r_edges, p_paths = X.shape
s_samples = beta.shape[1]

assert X.shape == (n + m, n * m)
assert beta.shape == (p_paths, s_samples)

# Noise-free edge depths, and a copy perturbed by multiplicative log-normal noise.
y = X @ beta
y_obs = y * np.random.lognormal(0, sigma=2, size=y.shape)

# Column vector of the noise-free depths (row-major order) and an s_samples x s_samples tiling of X.
# NOTE(review): every tile is X (this is not block-diagonal) and `y_stacked` is built
# from `y`, not `y_obs` — confirm both are intended for the stacked GLM fits below.
y_stacked = y.reshape((-1, 1))
X_stacked = np.tile(X, (s_samples, s_samples))

assert X_stacked.shape[0] == y_stacked.shape[0]

# beta_stacked = beta.reshape((p_paths * s_samples, 1))

# print(X_stacked.shape)
# print(beta_stacked.shape)
# print(y_stacked.shape)

In [None]:
import jax.numpy as jnp
from jax import grad, hessian
from jax.tree_util import Partial

# Offset added inside the log transform (see trsfm_depth) so that a depth of 0 maps to a finite value.
eps = 1

def trsfm_depth(beta, eps):
    """Transform depths to the log scale: ``log(beta + eps)``."""
    shifted = beta + eps
    return jnp.log(shifted)

def inv_trsfm_depth(trsfm_beta, eps):
    """Map log-scale values back to depths: ``exp(trsfm_beta) - eps``."""
    exponentiated = jnp.exp(trsfm_beta)
    return exponentiated - eps

def loss_(trsfm_beta, y, X, eps):
    """Sum of squared differences between ``y`` and the model prediction on the log(. + eps) scale.

    ``trsfm_beta`` holds the coefficients on the transformed scale; they are mapped
    back with ``inv_trsfm_depth`` before being multiplied by the design ``X``.
    """
    predicted = X @ inv_trsfm_depth(trsfm_beta, eps)
    residual = trsfm_depth(y, eps) - trsfm_depth(predicted, eps)
    return jnp.sum(residual ** 2)

# Restrict the toy problem to three hand-picked paths.
active_paths = [0, 1, 5]
p_reduced = len(active_paths)
reduced_X = X[:, active_paths]
# Latent depths of those paths on the transformed (log) scale.
trsfm_reduced_beta = trsfm_depth(beta[active_paths, :], eps)

# loss = Partial(loss_, y=y_obs, X=reduced_X, eps=eps)
# grad_loss = grad(loss)
# hess_loss = hessian(loss)

# loss(trsfm_reduced_beta), grad_loss(trsfm_reduced_beta)

In [None]:
def _pack(beta):
    p, s = beta.shape
    return beta.ravel(), p, s

def _unpack(packed, p, s):
    beta = packed.reshape((p, s))
    return beta

def _loss_packed(packed, y, X, eps, p, s):
    """Evaluate ``loss_`` on a flat parameter vector after restoring its ``(p, s)`` shape."""
    return loss_(_unpack(packed, p, s), y, X, eps)

# Bind the data, design, offset and shape so the objective depends only on the flat parameter vector.
loss = Partial(
    _loss_packed,
    y=y_obs,
    X=reduced_X,
    eps=eps,
    p=p_reduced,
    s=s_samples,
)
grad_loss, hess_loss = grad(loss), hessian(loss)

# Minimise from an all-zero start on the transformed scale, supplying the JAX gradient.
res = sp.optimize.minimize(loss, x0=np.zeros(p_reduced * s_samples), jac=grad_loss)
est_beta_reduced = res.x
# Display the optimum mapped back to the depth scale.
inv_trsfm_depth(_unpack(est_beta_reduced, p_reduced, s_samples), eps)

In [None]:
# Fitted coefficients on the transformed scale, as a (paths, samples) matrix.
_unpack(est_beta_reduced, p_reduced, s_samples)

In [None]:
# Gradient of the loss at the reported optimum, as a (paths, samples) matrix.
_unpack(grad_loss(est_beta_reduced), p_reduced, s_samples)

In [None]:
import seaborn as sns

# Hessian of the loss at the optimum and its inverse.
_this_hessian = hess_loss(est_beta_reduced)
var_covar_matrix = jnp.linalg.inv(_this_hessian)

sns.heatmap(var_covar_matrix, center=0)

# Square roots of the inverse's diagonal, arranged as a (paths, samples) matrix.
print(np.sqrt(_unpack(np.diag(var_covar_matrix), p_reduced, s_samples)))

In [None]:
# The original evaluated `hess_loss(sol.params)`, but `sol` is never defined in this
# notebook. Evaluate at `est_beta_reduced`, the optimum found by the minimiser above.
hess_loss(est_beta_reduced).shape

In [None]:
# Gaussian GLM on the stacked design restricted to paths 0 and 1.
# NOTE(review): `LOG1PLINK`, `Gaussian2` and `Log1pLink` are not defined or imported
# anywhere in the visible notebook — presumably they came from a removed cell; this
# cell raises NameError on a fresh kernel until they are restored.
active_paths = [0, 1]
reduced_X = X[:, active_paths]
# Tiles reduced_X into every block of an s_samples x s_samples grid.
reduced_X_stacked = np.block([[reduced_X] * s_samples] * s_samples)

# The response is scaled by 100000 before LOG1PLINK is applied to it.
model = sm.GLM(LOG1PLINK(y_stacked * 100000), reduced_X_stacked, family=Gaussian2(link=Log1pLink()))
results = model.fit()

print(results.summary())

In [None]:
import pymc as pm
import pytensor.tensor as tt

def log1p_np(x):
    """Return ``log(1 + x)`` elementwise using NumPy.

    Uses ``np.log1p`` rather than ``np.log(x + 1)`` so that inputs much smaller
    than 1 keep their precision instead of being rounded away when 1 is added.
    """
    return np.log1p(x)

def log1p_pm(x):
    """Return ``log(1 + x)`` as a PyTensor expression.

    Uses ``tt.log1p`` rather than ``tt.log(x + 1)`` so that inputs much smaller
    than 1 keep their precision.
    """
    return tt.log1p(x)

# Bayesian version of the reduced fit: path depths and noise scale get log-normal
# priors, and the log1p-transformed observations get a normal likelihood.
active_paths = [0,1,4,5]
reduced_X = X[:, active_paths]

# Noise-free observations from the latent depths scaled by 100 (rebinds `y_obs`).
y_obs = X @ (beta * 100)

with pm.Model() as model0:
    # Design matrix and transformed observations registered as named data containers.
    design = pm.Data('design', value=reduced_X, shape=(r_edges, len(active_paths)), mutable=True)
    observed_trsfm = pm.Data('observed', value=log1p_np(y_obs), shape=(r_edges, s_samples), mutable=True)
    # One positive depth per (path, sample), using LogNormal's default parameters.
    b = pm.LogNormal('b', shape=(len(active_paths), s_samples))
    # Noise scale on the transformed scale.
    s = pm.LogNormal('s', mu=-2, sigma=1)
    # Expected edge depths (rebinds the notebook-level name `expect`).
    expect = design @ b
    lik = pm.Normal('lik', mu=log1p_pm(expect), sigma=s, observed=observed_trsfm)

    # Sample with PyMC's default settings.
    trace = pm.sample()

In [None]:
# Summary table for the trace sampled in the previous cell.
pm.summary(trace)

In [None]:
# Gaussian GLM with a log link on the full stacked design (rebinds `model`).
model = sm.GLM(
    y_stacked,
    X_stacked,
    family=sm.families.Gaussian(link=LogLink()),
)
results = model.fit()

print(results.summary())

In [None]:
# Gaussian GLM with a log link on the stacked design restricted to four paths.
active_paths = [0, 1, 4, 5]
reduced_X = X[:, active_paths]
# Tile reduced_X into every block of an s_samples x s_samples grid.
reduced_X_stacked = np.tile(reduced_X, (s_samples, s_samples))

model = sm.GLM(
    y_stacked,
    reduced_X_stacked,
    family=sm.families.Gaussian(link=LogLink()),
)
results = model.fit()

print(results.summary())

In [None]:
# Gaussian GLM on the stacked design restricted to paths 1, 4 and 5.
# NOTE(review): `LOG1PLINK`, `Gaussian2` and `Log1pLink` are not defined or imported
# anywhere in the visible notebook — this cell raises NameError on a fresh kernel.
active_paths = [1,4,5]
reduced_X = X[:, active_paths]
# Tiles reduced_X into every block of an s_samples x s_samples grid.
reduced_X_stacked = np.block([[reduced_X] * s_samples] * s_samples)

model = sm.GLM(LOG1PLINK(y_stacked), reduced_X_stacked, family=Gaussian2(link=Log1pLink()))
results = model.fit()

print(results.summary())

In [None]:
# Gaussian GLM on the stacked design restricted to paths 0, 1, 2, 4 and 5.
# NOTE(review): `LOG1PLINK`, `Gaussian2` and `Log1pLink` are not defined or imported
# anywhere in the visible notebook — this cell raises NameError on a fresh kernel.
active_paths = [0, 1, 2, 4, 5]
reduced_X = X[:, active_paths]
# Tiles reduced_X into every block of an s_samples x s_samples grid.
reduced_X_stacked = np.block([[reduced_X] * s_samples] * s_samples)

model = sm.GLM(LOG1PLINK(y_stacked), reduced_X_stacked, family=Gaussian2(link=Log1pLink()))
results = model.fit()

print(results.summary())

In [None]:
# Gaussian GLM on the stacked design restricted to paths 1, 2, 4 and 5.
# NOTE(review): `LOG1PLINK`, `Gaussian2` and `Log1pLink` are not defined or imported
# anywhere in the visible notebook — this cell raises NameError on a fresh kernel.
active_paths = [1, 2, 4, 5]
reduced_X = X[:, active_paths]
# Tiles reduced_X into every block of an s_samples x s_samples grid.
reduced_X_stacked = np.block([[reduced_X] * s_samples] * s_samples)

model = sm.GLM(LOG1PLINK(y_stacked), reduced_X_stacked, family=Gaussian2(link=Log1pLink()))
results = model.fit()

print(results.summary())

In [None]:
# Gaussian GLM on the stacked design restricted to paths 4 and 5.
# NOTE(review): `LOG1PLINK`, `Gaussian2` and `Log1pLink` are not defined or imported
# anywhere in the visible notebook — this cell raises NameError on a fresh kernel.
active_paths = [4, 5]
reduced_X = X[:, active_paths]
# Tiles reduced_X into every block of an s_samples x s_samples grid.
reduced_X_stacked = np.block([[reduced_X] * s_samples] * s_samples)

model = sm.GLM(LOG1PLINK(y_stacked), reduced_X_stacked, family=Gaussian2(link=Log1pLink()))
results = model.fit()

print(results.summary())

In [None]:
# The original cell was a dangling `results.` (an unfinished attribute access, which is
# a syntax error). Show the fitted coefficients of the most recent GLM instead.
# NOTE(review): the intended attribute is unknown — adjust if something else was meant.
results.params