In [None]:
# Reload edited modules (e.g. cellij) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd

In [None]:
import torch

import pyro
import pyro.distributions as dist
from pyro.distributions import constraints
from pyro.infer.autoguide.guides import AutoNormal, AutoDelta

from pyro.optim import Adam
from pyro.infer import SVI

In [None]:
import cellij

In [None]:
# Turn on Pyro's runtime validation checks while developing the model.
pyro.enable_validation(is_validate=True)

In [None]:
from tqdm import tqdm

In [None]:
%matplotlib
%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
import seaborn as sns

# Global plotting theme: whitegrid style, notebook context, thick plot lines.
sns.set_theme(
    context="notebook",
    style="whitegrid",
    font_scale=1.0,
    rc={"lines.linewidth": 5},
)

In [None]:
from cellij.core.synthetic import DataGenerator

In [None]:
# Four sample groups and four feature groups (views), each of a slightly different size.
n_samples = list(range(101, 105))
n_features = list(range(201, 205))
dg = DataGenerator(n_samples, n_features)
# NOTE(review): no explicit seed is set, so the synthetic data presumably changes
# on every run — check whether DataGenerator/generate accepts a seed.
rng = dg.generate(all_combs=True)

In [None]:
# Ground-truth latent z from the data generator, centred at 0 with a diverging colormap.
ax = sns.heatmap(dg.z, center=0, cmap="vlag")
ax.set_title("Ground truth z");

In [None]:
# Ground-truth weights w from the data generator, centred at 0 with a diverging colormap.
ax = sns.heatmap(dg.w, center=0, cmap="vlag")
ax.set_title("Ground truth w");

In [None]:
# Build the generative model with one entry per sample group and per feature group (view).
# Both groups and views use the prior sequence Normal, Laplace, Horseshoe, Horseshoe.
group_names = [f"group_{g}" for g in range(dg.n_sample_groups)]
view_names = [f"view_{m}" for m in range(dg.n_feature_groups)]
prior_sequence = ["Normal", "Laplace", "Horseshoe", "Horseshoe"]

model = cellij.core._pyro_models.Generative(
    n_factors=dg.n_factors,
    obs_dict={name: dg.n_samples[g] for g, name in enumerate(group_names)},
    feature_dict={name: dg.n_features[m] for m, name in enumerate(view_names)},
    likelihoods=dict.fromkeys(view_names, "Normal"),
    factor_priors={f"group_{i}": prior for i, prior in enumerate(prior_sequence)},
    weight_priors={f"view_{i}": prior for i, prior in enumerate(prior_sequence)},
    device=torch.device("cpu"),
)

# Two guides for the same model: Pyro's AutoNormal and cellij's custom Guide.
autonormal_guide = AutoNormal(model)
guide = cellij.core._pyro_guides.Guide(model)

In [None]:
# Draw one sample from the model and from each guide and list every site's shape.
# Calls run in the same order as before (model, guide, AutoNormal), so the comparison
# below sees `model.sample_dict` in the state left after all three calls.
for source_name, sample_fn in [
    ("model", model),
    ("guide", guide),
    ("autonormal_guide", autonormal_guide),
]:
    for site, value in sample_fn().items():
        print(source_name, site, value.shape)

# Check that the cellij guide samples have the same shapes as the model's.
# Print the site name with the result so a mismatch can be located.
for site, value in guide.sample_dict.items():
    print(site, value.shape == model.sample_dict[site].shape)

In [None]:
# Compare the cellij guide's site shapes against AutoNormal's.
# Draw a single sample from each guide up front instead of re-running AutoNormal per site.
guide_sample = guide()
autonormal_sample = autonormal_guide()
for k, v in guide_sample.items():
    print(k, v.shape)
    # Report sites that AutoNormal does not produce instead of raising KeyError.
    if k in autonormal_sample:
        print(v.shape == autonormal_sample[k].shape)
    else:
        print("missing in AutoNormal")

In [None]:
# Nested observations: data["group_<g>"]["view_<m>"] holds dg.ys[g][m] converted via torch.Tensor.
data = {
    f"group_{group_idx}": {
        f"view_{view_idx}": torch.Tensor(dg.ys[group_idx][view_idx])
        for view_idx in range(dg.n_feature_groups)
    }
    for group_idx in range(dg.n_sample_groups)
}

In [None]:
# Clean start: drop any parameters left in the param store (e.g. from the guide calls
# above) and seed Pyro's RNGs so training is reproducible.
print("Cleaning parameter store")
pyro.clear_param_store()
SEED = 0
pyro.set_rng_seed(SEED)

# NOTE(review): this scales the ELBO by 1 / len(dg.n_samples), presumably the number
# of sample groups rather than the total number of samples — confirm the intent.
scale = 1.0 / len(dg.n_samples)

# Wrap under new names instead of rebinding `model`/`guide`. Rebinding stacks another
# scale wrapper each time this cell is re-run, and may hide attributes such as
# `model.factor_priors` that later cells inspect.
scaled_model = pyro.poutine.scale(model, scale=scale)
scaled_guide = pyro.poutine.scale(guide, scale=scale)

optimizer = Adam({"lr": 0.005, "betas": (0.95, 0.999)})
svi = SVI(scaled_model, scaled_guide, optimizer, loss=pyro.infer.Trace_ELBO(num_particles=1))

In [None]:
# Run SVI and record the per-step loss returned by svi.step.
N_STEPS = 1000
LOG_EVERY = 5  # refresh the progress-bar postfix every LOG_EVERY steps

elbo_history = []
pbar = tqdm(range(N_STEPS))
for iteration_idx in pbar:
    elbo = svi.step(data)
    elbo_history.append(elbo)
    if iteration_idx % LOG_EVERY == 0:
        pbar.set_postfix({"ELBO": elbo})

fig, ax = plt.subplots()
ax.plot(elbo_history)
ax.set(xlabel="SVI step", ylabel="loss (svi.step)", title="Training loss");

In [None]:
# Shapes of the sample sites recorded by group_2's factor prior (configured as Horseshoe).
group_2_prior = model.factor_priors["group_2"]
{site: tensor.shape for site, tensor in group_2_prior.sample_dict.items()}

In [None]:
# Shapes of the sample sites recorded by view_2's weight prior (configured as Horseshoe).
view_2_prior = model.weight_priors["view_2"]
{site: tensor.shape for site, tensor in view_2_prior.sample_dict.items()}

In [None]:
# Names of all parameters currently in Pyro's param store.
# `.keys()` replaces the deprecated `get_all_param_names()`; sorted for easier scanning.
sorted(pyro.get_param_store().keys())

In [None]:
# Shape of the AutoNormal loc for z_group_2_caux.
# NOTE(review): SVI trains `guide` (the cellij Guide), not `autonormal_guide`; this
# lookup raises KeyError unless a param with this exact name is in the store — confirm.
pyro.get_param_store()["AutoNormal.locs.z_group_2_caux"].shape

In [None]:
# Shape of the AutoNormal loc for w_view_2_caux.
# NOTE(review): SVI trains `guide` (the cellij Guide), not `autonormal_guide`; this
# lookup raises KeyError unless a param with this exact name is in the store — confirm.
pyro.get_param_store()["AutoNormal.locs.w_view_2_caux"].shape

In [None]:
param_store = pyro.get_param_store()


def get_loc(site: str) -> np.ndarray:
    """Return param "AutoNormal.locs.<site>" from the param store as a squeezed numpy array."""
    return param_store.get_param(f"AutoNormal.locs.{site}").squeeze().detach().numpy()


# NOTE(review): only group_0's factors are extracted, while the ground-truth `dg.z`
# shown next to it presumably covers all groups — confirm the comparison is intended.
z_hat = get_loc("z_group_0")

# One estimated weight matrix per view, so `ws_hat` has the same length as the
# view loop that plots it (dg.n_feature_groups).
ws_hat = []
for view_idx in range(dg.n_feature_groups):
    site = f"w_view_{view_idx}"
    if f"AutoNormal.locs.{site}_unconstrained" in param_store:
        # Horseshoe-style views: rebuild w as unconstrained * lambdas.
        # NOTE(review): ignores any other scale terms the prior may use — confirm.
        ws_hat.append(get_loc(f"{site}_unconstrained") * get_loc(f"{site}_lambdas"))
    else:
        ws_hat.append(get_loc(site))

In [None]:
# Ground-truth z again, for visual comparison with the estimate in the next cell.
ax = sns.heatmap(dg.z, center=0, cmap="vlag")
ax.set_title("Ground truth z");

In [None]:
# Estimated z (group_0 locs read from the param store).
ax = sns.heatmap(z_hat, center=0, cmap="vlag")
ax.set_title("Estimated z (group_0)");

In [None]:
# Side-by-side heatmaps of true vs. estimated weights for every view.
# Fail early with a clear message if an estimate is missing for some view.
assert len(ws_hat) == dg.n_feature_groups, (
    f"ws_hat has {len(ws_hat)} views but the data has {dg.n_feature_groups}"
)
for m in range(dg.n_feature_groups):
    fig, (ax_true, ax_hat) = plt.subplots(1, 2, figsize=(12, 4))
    sns.heatmap(dg.ws[m], center=0, cmap="vlag", ax=ax_true)
    ax_true.set_title(f"view_{m}: ground truth")
    # NOTE(review): the estimate may be transposed relative to dg.ws[m] — confirm orientation.
    sns.heatmap(ws_hat[m], center=0, cmap="vlag", ax=ax_hat)
    ax_hat.set_title(f"view_{m}: estimate")
    plt.show()