This notebook demonstrates a two-state (Off/On) transcriptional bursting model with Poisson transcript production and detection thinning, used to describe stochastic gene expression in single-cell transcriptomics. In this notebook, we:

1. Initialize the model with biologically realistic parameters.
2. Plot the probability mass function (PMF) of:
   - individual burst sizes, and
   - total observed transcript counts over a fixed observation window of length $L$.
3. Compute analytical moments (mean, variance, burstiness index).
4. Simulate synthetic datasets using the model.
5. Fit the identifiable parameters — the mean burst size $\lambda = p\,r/k_\text{off}$ and the burst frequency $k_\text{on}$ — by the method of moments.

## Model overview

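Mathematically, the model is a compound Poisson process. Bursts arrive as a Poisson process. The number of observed transcripts in each burst is Poisson with a mean proportional to the burst's random (exponential) duration. This makes the burst size a mixed Poisson (Poisson–exponential, i.e. geometric) random variable with mean $p\,r/k_\text{off}$: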

1. Genes stochastically switch between Off and On states.
2. When On, transcripts are produced at rate $r$, but each transcript is detected only with probability $p$ (thinning).
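3. Bursts are initiated at rate $k_\text{on}$, so the number of bursts in a window of length $L$ is Poisson with mean $k_\text{on} L$. Burst durations are exponentially distributed with rate $k_\text{off}$ (mean $1/k_\text{off}$).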

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from burst_models.poisson_model import PoissonBurstModel

In [None]:
def simulate_data(
    model: "PoissonBurstModel",
    L: float,
    n_samples: int,
    rng: "np.random.Generator | None" = None,
) -> np.ndarray:
    """
    Simulate total observed transcript counts over window L.

    Generative process, repeated independently for each sample:
      1) Draw number of bursts B ~ Poisson(k_on * L)
      2) For each burst i:
           - Draw duration t_i ~ Exponential(rate=k_off)
           - Draw observed transcripts N_i ~ Poisson(p * r * t_i)
      3) Sum over bursts: X = sum_i N_i

    Parameters
    ----------
    model : PoissonBurstModel
        Source of the rates; only its ``k_on``, ``k_off``, ``r`` and ``p``
        attributes are read.
    L : float
        Observation window length.
    n_samples : int
        Number of independent samples to draw.
    rng : numpy.random.Generator, optional
        Random source. When omitted, NumPy's legacy global generator
        (``np.random``) is used, so ``np.random.seed`` still controls the
        output exactly as before.

    Returns
    -------
    np.ndarray
        Integer array of shape ``(n_samples,)``, one total count per sample.
    """
    # The np.random module and np.random.Generator expose the same
    # poisson/exponential call signatures, so either can drive the draws.
    source = np.random if rng is None else rng
    burst_rate = model.k_on * L
    mean_duration = 1.0 / model.k_off
    detected_rate = model.r * model.p

    samples = []
    for _ in range(n_samples):
        # 1) number of bursts in window
        B = source.poisson(burst_rate)
        if B == 0:
            samples.append(0)
            continue
        # 2) burst durations
        durations = source.exponential(scale=mean_duration, size=B)
        # 3) transcripts per burst (with detection thinning)
        counts = source.poisson(lam=detected_rate * durations)
        samples.append(counts.sum())
    # Explicit integer dtype: np.array([]) would otherwise be float64 when
    # n_samples == 0.
    return np.array(samples, dtype=np.int64)

In [None]:
# --- Model parameters ---
# Burst initiation rate and off-switching rate.
k_on = 1.0
k_off = 1.5
# Transcription rate while active, and per-transcript detection probability.
r = 5.0
p = 0.8
# Length of the observation window.
L = 10.0

In [None]:
# Instantiate the model from the parameter values defined in the previous cell.
model = PoissonBurstModel(
    k_on=k_on,
    k_off=k_off,
    r=r,
    p=p,
)

In [None]:
# Burst-size PMF on k = 0..20.
k_vals = np.arange(0, 21)
burst_pmf = [model.burst_size_pmf(k) for k in k_vals]

# Size the x-range for the total-count PMF from the analytic moments
# (mean + 4 SD) instead of a fixed 0..50, which cuts off a visible part of the
# right tail at the baseline parameters; never use less than the original 0..50.
mean_x, var_x, _burstiness_x = model.first_moments(L)
x_max = max(50, int(np.ceil(mean_x + 4.0 * np.sqrt(var_x))))
x_vals = np.arange(0, x_max + 1)
total_pmf = [model.total_transcripts_pmf(x, L) for x in x_vals]
# Sanity check: how much probability mass the plotted range captures.
print(f"Probability mass captured on 0..{x_max}: {sum(total_pmf):.4f}")

In [None]:
# Left: PMF of a single burst's observed size; right: PMF of the total
# observed count over the window L.
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# `use_line_collection` was deprecated in Matplotlib 3.6 and removed in 3.8
# (stem always draws a LineCollection now); passing it raises TypeError.
axes[0].stem(k_vals, burst_pmf, basefmt=" ")
axes[0].set_xlabel("Burst size k")
axes[0].set_ylabel("P(N=k)")
axes[0].set_title("Burst-size PMF")

axes[1].stem(x_vals, total_pmf, basefmt=" ")
axes[1].set_xlabel("Total observed transcripts x")
axes[1].set_ylabel("P(X=x)")
axes[1].set_title(f"Total-transcripts PMF (L={L})")

plt.tight_layout()
plt.show()

In [None]:
# --- First moments ---
# first_moments(L) yields the mean, the variance and Var/Mean of the total
# observed count over the window L.
moments = model.first_moments(L)
mean, var, burstiness = moments

print(f"Analytical results over L={L}:")
print(f"  Mean        = {mean:.3f}")
print(f"  Variance    = {var:.3f}")
print(f"  Burstiness  = Var/Mean = {burstiness:.3f}\n")

In [None]:
# --- Simulate synthetic data ---
# One sample = the total observed count in one independent window of length L.
n_samples = 1000
data = simulate_data(model=model, L=L, n_samples=n_samples)

# Empirical sample moments; the variance uses ddof=1 (unbiased estimator).
print(
    f"Simulated {n_samples} samples → sample mean = {data.mean():.3f}, sample var = {data.var(ddof=1):.3f}\n"
)

In [None]:
# --- Fit by method of moments ---
# Counts over a window depend on (p, r, k_off) only through the mean burst
# size λ = p·r/k_off, so the fit returns λ and k_on rather than all four rates.
observations = data.tolist()
lambda_hat, k_on_hat = PoissonBurstModel.fit_moments(observations, L)

print("Fitted by method of moments:")
print(f"  λ (p·r/k_off)  = {lambda_hat:.3f}")
print(f"  k_on           = {k_on_hat:.3f}")

In [None]:
# Plot histogram of the simulated totals with one bar per integer count.
# Equal-width float bins (e.g. bins=20) straddle integer values unevenly,
# so some bins collect more integers than others and show sawtooth artefacts.
fig, ax = plt.subplots()
edges = np.arange(data.max(initial=0) + 2) - 0.5  # bins centred on 0, 1, ..., max
ax.hist(data, bins=edges)
ax.set_xlabel("Total observed transcripts x")
ax.set_ylabel("Count")
ax.set_title(f"Simulated data over L={L}")
plt.show()

# Sweep analysis

In [None]:
# Baseline parameters shared by every sweep; each subplot row varies one of them.
baseline = {"k_on": 1.0, "k_off": 1.5, "r": 5.0, "p": 0.8, "L": 10.0}

n_sweep = 5  # values per swept parameter (= columns in the plot grid)


def _sweep(center, rel_width=0.5, n=n_sweep, upper=None):
    """Return ``n`` evenly spaced values centred on ``center``.

    The half-width is ``rel_width * center``. When ``upper`` is given, the
    half-width is shrunk so the largest value does not exceed ``upper``; the
    baseline value therefore stays in the middle of the sweep.
    """
    half = rel_width * center
    if upper is not None:
        half = min(half, upper - center)
    return np.linspace(center - half, center + half, n)


sweeps = {
    "k_on": _sweep(baseline["k_on"]),
    "k_off": _sweep(baseline["k_off"]),
    "r": _sweep(baseline["r"]),
    # p is a detection probability: a ±50% sweep around 0.8 would reach 1.2,
    # so cap the sweep at 1.0.
    "p": _sweep(baseline["p"], upper=1.0),
}
variables = list(sweeps.keys())

k_vals = np.arange(0, 21)
x_vals = np.arange(0, 51)


def make_model(params):
    """Construct a PoissonBurstModel from the rate entries of ``params``.

    Only ``k_on``, ``k_off``, ``r`` and ``p`` are forwarded; any other keys
    (such as ``L``) are ignored.
    """
    rate_keys = ("k_on", "k_off", "r", "p")
    return PoissonBurstModel(**{key: params[key] for key in rate_keys})


def plot_grid(plot_fn, title):
    """Draw one subplot per (swept parameter, value) pair.

    Rows follow ``variables`` and columns follow the values in ``sweeps[var]``.
    Each cell gets a model built from ``baseline`` with that one parameter
    replaced, and ``plot_fn(ax, model, params)`` draws into it.

    Parameters
    ----------
    plot_fn : callable
        Called as ``plot_fn(ax, model, params)`` for every subplot.
    title : str
        Figure-level title.
    """
    # Derive the grid shape from the sweep definitions instead of hard-coding
    # 4x5, so changing the sweeps cannot index past the axes array.
    nrows = len(variables)
    ncols = max(len(sweeps[v]) for v in variables)
    fig, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=(3 * ncols, 3 * nrows),
        sharex=False,
        sharey=False,
        squeeze=False,  # always a 2-D array, even with a single row/column
    )
    fig.suptitle(title, fontsize=16)
    for i, var in enumerate(variables):
        values = sweeps[var]
        for j, val in enumerate(values):
            # parameter set for this subplot: baseline with one entry replaced
            params = {**baseline, var: val}
            model = make_model(params)
            ax = axes[i, j]
            plot_fn(ax, model, params)
            ax.set_title(f"{var}={val:.2f}")
        # Hide cells left empty when this sweep is shorter than the widest one.
        for j in range(len(values), ncols):
            axes[i, j].axis("off")
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.show()

In [None]:
def plot_burst_pmf(ax, model, params):
    """Stem-plot the burst-size PMF P(N=k) of ``model`` for k in ``k_vals``.

    ``params`` is unused here; it is part of the ``plot_grid`` callback
    signature.
    """
    pmf = [model.burst_size_pmf(k) for k in k_vals]
    # `use_line_collection` was deprecated in Matplotlib 3.6 and removed in
    # 3.8 (a LineCollection is always used now); passing it raises TypeError.
    ax.stem(k_vals, pmf, basefmt=" ")
    ax.set_xlabel("k")
    ax.set_ylabel("P(N=k)")


plot_grid(plot_burst_pmf, "Burst-size PMF across parameter sweeps")

In [None]:
def plot_total_pmf(ax, model, params):
    """Stem-plot the total-count PMF P(X=x) over window ``params["L"]``.

    The PMF is evaluated for x in ``x_vals``.
    """
    pmf = [model.total_transcripts_pmf(x, params["L"]) for x in x_vals]
    # `use_line_collection` was deprecated in Matplotlib 3.6 and removed in
    # 3.8 (a LineCollection is always used now); passing it raises TypeError.
    ax.stem(x_vals, pmf, basefmt=" ")
    ax.set_xlabel("x")
    ax.set_ylabel("P(X=x)")


plot_grid(plot_total_pmf, "Total-transcripts PMF (L fixed)")

In [None]:
def simulate_data(model, L, n_samples=500, rng=None):
    """
    Simulate total observed transcript counts over window L (500 samples by default).

    For each sample: draw B ~ Poisson(k_on * L) bursts, give each burst a
    duration t_i ~ Exponential(rate=k_off) and N_i ~ Poisson(p * r * t_i)
    observed transcripts, and return X = sum_i N_i.

    Parameters
    ----------
    model : PoissonBurstModel
        Only its ``k_on``, ``k_off``, ``r`` and ``p`` attributes are read.
    L : float
        Observation window length.
    n_samples : int, default 500
        Number of independent samples to draw.
    rng : numpy.random.Generator, optional
        Random source. When omitted, the legacy global ``np.random`` stream is
        used, so ``np.random.seed`` still controls the output.

    Returns
    -------
    np.ndarray
        Integer array of shape ``(n_samples,)``.
    """
    # np.random (module) and Generator share the poisson/exponential signatures.
    source = np.random if rng is None else rng
    burst_rate = model.k_on * L
    mean_duration = 1.0 / model.k_off
    detected_rate = model.r * model.p

    samples = []
    for _ in range(n_samples):
        B = source.poisson(burst_rate)
        if B == 0:
            samples.append(0)
            continue
        durations = source.exponential(scale=mean_duration, size=B)
        counts = source.poisson(lam=detected_rate * durations)
        samples.append(counts.sum())
    # Explicit integer dtype: np.array([]) would be float64 for n_samples == 0.
    return np.array(samples, dtype=np.int64)


def plot_histogram(ax, model, params):
    """Histogram 500 simulated total counts for ``model`` over window ``params["L"]``."""
    data = simulate_data(model, params["L"], n_samples=500)
    # Integer-centred bins (one per count value): equal-width float bins such
    # as bins=20 straddle integers unevenly and produce sawtooth artefacts.
    edges = np.arange(data.max(initial=0) + 2) - 0.5
    ax.hist(data, bins=edges)
    ax.set_xlabel("x")
    ax.set_ylabel("Count")


plot_grid(plot_histogram, "Simulated-data histograms over sweeps")