In [1]:
from __future__ import annotations

from typing import Literal

import matplotlib.pyplot as plt
import torch


In [2]:
# %%
ex_dir = torch.tensor([1, 5, 1], dtype=torch.float32)


In [3]:
# %%
def condition_mv(mean, cov, val, keep: int = 0):
    val = torch.atleast_1d(val)
    m = mean.shape[0]
    c = val.shape[0]

    mu_1, mu_2 = mean[:-c], mean[-c:]
    cov_11 = cov[:-c, :-c].reshape((m - c, m - c))
    cov_12 = cov[:-c, -c:].reshape((m - c, c))
    cov_22 = cov[-c:, -c:].reshape((c, c))
    term = cov_12 @ torch.linalg.inv(cov_22)
    cov_star = cov_11 - term @ cov_12.T
    mu_star = mu_1 + term @ (val - mu_2)
    if keep:
        mu_star = torch.cat((mu_star, val[:keep]))
        cov_star = torch.nn.functional.pad(cov_star, (0, keep, 0, keep))

    assert mu_star.shape == (m - c + keep,), f"{mu_star.shape=}"
    assert cov_star.shape == (m - c + keep, m - c + keep), f"{cov_star.shape=}"
    return mu_star, cov_star


In [4]:
# %%
def condition_subspace(cov: torch.Tensor, orthog_dir: torch.Tensor, val: float):
    dim = orthog_dir.shape[0]
    assert orthog_dir.shape == (dim,)
    assert cov.shape == (dim, dim), f"{cov.shape=}"

    expand = torch.cat((torch.eye(dim), orthog_dir.reshape(1, -1)), dim=0)
    mean = torch.zeros((dim + 1,), dtype=cov.dtype, device=cov.device)
    cov = expand @ cov @ expand.T
    return condition_mv(mean, cov, val, keep=0)


In [10]:
# %%
ex_cov = torch.tensor([[2, 1, 1], [1, 2.4, 0], [1, 0, 1.5]], dtype=torch.float32)
ex_mean, ex_cov = condition_subspace(
    cov=ex_cov,
    orthog_dir=ex_dir,
    val=1.2,
)
map_a, map_b = 0.08163265, 0.64114136


TypeError: atleast_1d() received an invalid combination of arguments - got (float), but expected one of:
 * (Tensor input)
      didn't match because some of the arguments have invalid types: (!float!)
 * (tuple of Tensors tensors)
      didn't match because some of the arguments have invalid types: (!float!)


In [None]:
# %%
def wavy_kernel(x: torch.Tensor, y: torch.Tensor):
    return (1 + x.reshape(-1, 1) @ y.reshape(1, -1)) ** 2 + map_a * (
        torch.sin(2 * torch.pi * x.flatten() + map_b).reshape(-1, 1)
        * torch.sin(2 * torch.pi * y.flatten() + map_b).reshape(1, -1)
    )


In [None]:
# %%
def gen_unconditional_samples(
    orthog_dir: torch.Tensor,
    q_hat: float,
    l_size: int,
    sample_shape: tuple[int, ...],
):
    space = torch.linspace(0, 1, steps=l_size)
    mean = torch.zeros_like(space)
    cov = wavy_kernel(space, space)
    if orthog_dir is not None:
        mag = torch.linalg.norm(orthog_dir)
        mean, cov = condition_subspace(
            cov=cov, orthog_dir=orthog_dir / mag, val=q_hat / mag
        )
    distribution = torch.distributions.MultivariateNormal(mean, covariance_matrix=cov)
    samples = distribution.sample(sample_shape)
    return samples


In [None]:
# %%
def plot_unconditional_samples(
    ax: plt.Axes = None,
    orthog_dir: torch.Tensor | Literal["w"] = None,
    q_hat: float = 1.0,
    l_size: int = 101,
    sample_shape: tuple[int, ...] = (5,),
):
    if orthog_dir == "w":
        orthog_dir = torch.full((l_size,), 1 / (l_size - 1))
        orthog_dir[0] *= 0.5
        orthog_dir[-1] *= 0.5
    if ax is None:
        _, ax = plt.subplots()

    samples = gen_unconditional_samples(orthog_dir, q_hat, l_size, sample_shape)
    assert torch.isfinite(samples).all(), "Invalid values sampled"
    for sample in samples:
        ax.plot(torch.linspace(0, 1, l_size).numpy(), sample.numpy(), label="")

    if orthog_dir is not None:
        assert torch.isclose(
            orthog_dir @ samples.reshape(-1, l_size).T, q_hat, atol=0.01
        ).all(), f"Invalid conditioning, {orthog_dir @ samples.T}"
        ax.set_title(f"Samples of $f ~|~ X, \\hat q = {q_hat}$")
    else:
        ax.set_title("Unconditional samples of $f ~|~ X$")

    return samples


In [None]:
# %%
_, axs = plt.subplots(nrows=3, figsize=(5, 12), tight_layout=True)
for w_ax, target_q in zip(axs, (0, 5, 10)):
    plot_unconditional_samples(ax=w_ax, orthog_dir="w", q_hat=target_q, l_size=101)
plt.show()


In [None]:
# %%
def shuffle_condition(
    mean: torch.Tensor,
    cov: torch.Tensor,
    indices: torch.Tensor,
    vals: torch.Tensor,
    orthog_dir: torch.Tensor = None,
    q_hat: float = 1.0,
):
    is_conditioned = torch.zeros_like(mean)
    is_conditioned[indices] = 1
    is_conditioned = is_conditioned.argsort().reshape(-1, 1)
    reverse_is = torch.empty(len(mean), dtype=torch.int64)
    reverse_is[is_conditioned.flatten()] = torch.arange(0, len(mean), dtype=torch.int64)

    mean = mean[is_conditioned.flatten()]
    cov = cov[is_conditioned.flatten()][:, is_conditioned.flatten()]

    if orthog_dir is not None:
        expand = torch.cat((torch.eye(len(mean)), orthog_dir.reshape(1, -1)), dim=0)
        mean = expand @ mean
        cov = expand @ cov @ expand.T
        vals = torch.cat((vals, torch.tensor([q_hat], dtype=torch.float32)))

    keep = len(vals) - (1 if orthog_dir is not None else 0)
    mean, cov = condition_mv(mean, cov, vals, keep=keep)

    mean = mean[reverse_is]
    cov = cov[reverse_is][:, reverse_is]

    return mean, cov


In [None]:
# %%
def conditional_samples(
    xs: torch.Tensor,
    ys: torch.Tensor,
    ax: plt.Axes = None,
    l_size: int = 101,
    orthog_dir: torch.Tensor | Literal["w"] = None,
    q_hat: float = 1.0,
    sample_shape: tuple[int, ...] = (5,),
):
    if ax is None:
        _, ax = plt.subplots()

    indices = torch.round(xs * 100).clamp(0, 100).to(torch.int64)
    space = torch.linspace(0, 1, steps=l_size)
    mean = torch.zeros_like(space)
    cov = wavy_kernel(space, space)
    cov[indices, indices] += 0.01
    if orthog_dir is not None:
        if orthog_dir == "w":
            orthog_dir = torch.full_like(space, 1 / (l_size - 1))
            orthog_dir[0] *= 0.5
            orthog_dir[-1] *= 0.5
        mag = torch.linalg.norm(orthog_dir)
        mean, cov = shuffle_condition(
            mean, cov, indices, ys, orthog_dir=orthog_dir / mag, q_hat=q_hat / mag
        )
        ax.set_title(f"Samples of $f ~|~ D, X, \\hat q = {q_hat}$")
    else:
        ax.set_title("Unconditional samples of $f ~|~ D, X$")
        mean, cov = shuffle_condition(mean, cov, indices, ys)

    distribution = torch.distributions.MultivariateNormal(mean, covariance_matrix=cov)
    samples = distribution.sample(sample_shape)
    assert torch.isfinite(samples).all(), "Invalid values sampled"
    ax.scatter(xs.numpy(), ys.numpy(), marker="x", c="red", label="D")
    for sample in samples:
        ax.plot(space.numpy(), sample.numpy(), label="")

    diag_low, diag_high = torch.distributions.Normal(
        loc=mean, scale=cov.diagonal()
    ).icdf(torch.tensor([0.05, 0.95]).reshape(2, 1))
    error_samples = distribution.sample((5000,))
    error_low, error_high = torch.quantile(
        error_samples, torch.tensor([0.05, 0.95]), dim=0
    )
    ax.fill_between(
        space.numpy(),
        error_low.numpy(),
        error_high.numpy(),
        alpha=0.2,
        label="Numerical 0.05:0.95",
    )
    ax.fill_between(
        space.numpy(),
        diag_low.numpy(),
        diag_high.numpy(),
        alpha=0.3,
        label="Diagonal 0.05:0.95",
    )
    ax.legend()

    if orthog_dir is not None:
        assert torch.isclose(
            orthog_dir @ samples.reshape(-1, l_size).T, q_hat, atol=0.01
        ).all(), f"Invalid conditioning, {orthog_dir @ samples.T}"
    global ex_dir
    ex_dir = orthog_dir
    return samples


In [None]:
# %%
data_xs = torch.tensor([0, 0.25, 0.5], dtype=torch.float32)
data_ys = torch.tensor([1.46, 0.93, 2.76], dtype=torch.float32)

_, axs = plt.subplots(nrows=2, figsize=(5, 8), tight_layout=True)
cs_1 = conditional_samples(data_xs, data_ys, l_size=101, q_hat=2, ax=axs[0])
cs_2 = conditional_samples(
    data_xs, data_ys, l_size=101, q_hat=2, orthog_dir="w", ax=axs[1]
)

# %%