In [1]:
import numpy as np
import torch
# import torch.distributions.constraints as constraints
# import torch.distributions as tdist
import pyro
import pyro.infer
import pyro.optim
import pyro.distributions as dist
from pyro.infer.autoguide import AutoNormal
# import matplotlib.pyplot as plt
# from tqdm import tqdm
# import pickle as pkl

# Fix Pyro's RNG seed so that re-running the notebook reproduces the same inference results.
pyro.set_rng_seed(101)

In [6]:
# Placeholder container for the model inputs (see the description below);
# it must be filled with tensors before the inference cell can run.
data = dict()

The `data` dict (all values are `torch.Tensor`s) contains:
- `reads_edited`, shape `[len(tags), len(orfs), len(replicates)]`
- `reads_unedited`, shape `[len(tags), len(orfs), len(replicates)]`
- `size_factors`, shape `[len(tags), len(orfs), len(replicates)]`
- `fit_dispersions`, shape `[len(tags), len(orfs)]`

It also needs:
- `D_Exp`: the design matrix for the experimental covariates (paired with the `β_Exp` coefficients)
- a design matrix for the interaction covariates (paired with the `β_Int` coefficients)

In [5]:
def model_NB_mixture(data):
    """Negative-binomial model of edited and unedited read counts.

    Each (replicate, ORF, tag) cell has an edited count x_E and an unedited
    count x_U. Both are negative binomial, share a per-(ORF, tag) dispersion
    φ (passed as ``total_count``), and have means

        E[x_E] = s * q * pi,    E[x_U] = s * q * (1 - pi),

    where:
      - s is the size factor;
      - log q is the linear predictor of the experimental coefficients
        (β_Exp) against ``D_Exp``;
      - pi = sigmoid(η), with η the linear predictor of the interaction
        coefficients (β_Int) against the interaction design matrix.

    Args:
        data: dict of torch tensors:
            - "reads_edited", "reads_unedited", "size_factors":
              shape [n_tags, n_orfs, n_reps];
            - "fit_dispersions": shape [n_tags, n_orfs];
            - "D_Exp": design matrix for the experimental covariates;
            - "D_Int" (optional): design matrix for the interaction
              covariates; defaults to "D_Exp" when absent.

    Plates: tags -> dim -1, ORFs -> dim -2, replicates -> dim -3, so the
    ELBO needs ``max_plate_nesting=3``.
    """
    # Std-dev (not variance) of log φ around the pre-fit dispersions.
    dispersion_scale = 1

    n_tags, n_orfs, n_reps = data["reads_edited"].size()

    # The inputs are stored tag-major [tags, orfs, reps], but the plates below
    # put tags on the rightmost dim. Reorder to (reps, orfs, tags) and
    # (orfs, tags) so every tensor lines up with the plate dims.
    reads_E = data["reads_edited"].permute(2, 1, 0)
    reads_U = data["reads_unedited"].permute(2, 1, 0)
    log_size = data["size_factors"].permute(2, 1, 0).log()
    log_fit_dispersion = data["fit_dispersions"].transpose(0, 1).log()

    D_exp = data["D_Exp"]
    D_int = data.get("D_Int", D_exp)

    with pyro.plate("Experiments (ABE-Tags)", n_tags, dim=-1):
        # All betas per experiment
        # TODO: unsure of dists here
        beta_exp_0 = pyro.sample("β_Exp_0", dist.Normal(0, 1))
        # β_Exp_1, β_Exp_2, ...: sample them and append to the list below.
        # torch.stack (not torch.Tensor([...])) keeps the samples in the
        # autograd graph, so gradients reach the guide parameters.
        beta_exp = torch.stack([beta_exp_0], dim=-1)  # (n_tags, K_exp)

        with pyro.plate("MCP-ORFs", n_orfs, dim=-2):
            # TODO: unsure of dists here
            beta_int_0 = pyro.sample("β_Int_0", dist.Normal(0, 1))
            # β_Int_1, β_Int_2, ...: sample them and append to the list below.
            beta_int = torch.stack([beta_int_0], dim=-1)  # (n_orfs, n_tags, K_int)

            dispersion = pyro.sample(
                "φ", dist.LogNormal(log_fit_dispersion, dispersion_scale)
            )
            assert dispersion.shape == (n_orfs, n_tags)

            # Row-wise linear predictors. Assumes D_Exp is (n_tags, K_exp) and
            # D_Int broadcasts against (n_orfs, n_tags, K_int) — TODO confirm.
            log_q = (beta_exp * D_exp).sum(-1)  # (n_tags,)
            eta = (beta_int * D_int).sum(-1)  # (n_orfs, n_tags)

            with pyro.plate("Replicates", n_reps, dim=-3):
                # NegativeBinomial(total_count=φ, logits=l) has mean φ·exp(l),
                # so l = log(mean) - log(φ). logsigmoid(η) and logsigmoid(-η)
                # are log(pi) and log(1 - pi): pi always stays inside (0, 1),
                # so the log never goes NaN.
                base = log_q + log_size - dispersion.log()
                logits_E = base + torch.nn.functional.logsigmoid(eta)
                logits_U = base + torch.nn.functional.logsigmoid(-eta)
                assert logits_E.shape == logits_U.shape == (n_reps, n_orfs, n_tags)

                dist_U = dist.NegativeBinomial(dispersion, logits=logits_U)
                dist_E = dist.NegativeBinomial(dispersion, logits=logits_E)

                pyro.sample("x_U", dist_U, obs=reads_U)
                pyro.sample("x_E", dist_E, obs=reads_E)


# pyro.render_model(model_NB_mixture, render_params=True)


In [None]:
guide = AutoNormal(model_NB_mixture, init_scale=0.01)

pyro.clear_param_store()
initial_lr = 0.05
gamma = 0.1  # final learning rate will be gamma * initial_lr
num_steps = 2000
# ClippedAdam multiplies the lr by `lrd` after every step. After num_steps
# steps the lr is therefore initial_lr * lrd**num_steps == gamma * initial_lr.
lrd = gamma ** (1 / num_steps)
svi = pyro.infer.SVI(
    model=model_NB_mixture,
    guide=guide,
    optim=pyro.optim.ClippedAdam({"lr": initial_lr, "lrd": lrd}),
    # the model nests three plates (tags, ORFs, replicates)
    loss=pyro.infer.Trace_ELBO(max_plate_nesting=3),
)


def _guide_param(kind, site):
    """Return a detached copy of AutoNormal's ``kind`` param for ``site``.

    ``kind`` is "locs" or "scales". Names follow the
    "AutoNormal.<kind>.<site>" pattern, where <site> is the model's
    sample-site name. Detaching stops the recorded history from keeping
    autograd graphs alive.
    """
    return pyro.param(f"AutoNormal.{kind}.{site}").detach().clone()


# keep track of anything meaningful here
beta1s, beta1_vars, beta0s, beta0_vars, dm, dv, api, bpi, pis, losses = (
    [] for _ in range(10)
)

for t in range(num_steps):
    losses.append(svi.step(data))
    param_store = pyro.get_param_store()
    if t == 1:
        # List the registered parameter names once, to check the names used below.
        print(sorted(param_store.keys()))
    # NOTE(review): mapping the old names beta0 -> β_Exp_0 and beta1 -> β_Int_0
    # onto the current model's sample sites is an assumption — confirm.
    beta0s.append(_guide_param("locs", "β_Exp_0"))
    beta0_vars.append(_guide_param("scales", "β_Exp_0"))
    beta1s.append(_guide_param("locs", "β_Int_0"))
    beta1_vars.append(_guide_param("scales", "β_Int_0"))
    # presumably in AutoNormal's unconstrained (log) space for the positive φ
    dm.append(_guide_param("locs", "φ"))
    dv.append(_guide_param("scales", "φ"))
    # model_NB_mixture registers no "alpha_pi"/"beta_pi" params. Record them
    # only if a model/guide variant defines them.
    if "alpha_pi" in param_store and "beta_pi" in param_store:
        a = pyro.param("alpha_pi").detach().clone()
        b = pyro.param("beta_pi").detach().clone()
        api.append(a)
        bpi.append(b)
        pis.append(a / (a + b))