# In [None]:
import numpy as np
import torch
import cv2

# In [None]:
import numpy as np
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from scipy.stats import chi2

def _clip(v, maxv, minv=0):
    return min(maxv, max(minv, v))

def slice_3d_normal_at_z_value(mu3, Sigma3, Vz, normalize=False):
    """
    Slice a 3D normal distribution at Z = Vz, with variables ordered [Z, Y, X].

    The joint density factorises as p(z, y, x) = p(z) * p(y, x | z), so the
    slice at ``Z = Vz`` is a 2D Gaussian over ``[Y, X]`` (the conditional
    distribution) multiplied by a scale factor coming from the marginal of Z.

    Parameters
    ----------
    mu3 : array-like, shape (3,)
        Mean vector in [μz, μy, μx] order.

    Sigma3 : array-like, shape (3, 3)
        Covariance matrix in [Z, Y, X] variable order.

    Vz : float
        The fixed value for Z (slice at z = Vz).

    normalize : bool, default False
        If False (default), ``C = exp(-0.5 * (Vz - μz)**2 / Σzz)``, i.e. the
        marginal of Z *without* its normalising constant (equal to 1 at
        ``Vz == μz``). If True, ``C`` is the true marginal PDF of Z at ``Vz``.

    Returns
    -------
    C : float
        Scale factor of the slice (see ``normalize``).

    mu2 : ndarray, shape (2,)
        Conditional mean vector [μy|Vz, μx|Vz].

    cov2 : ndarray, shape (2, 2)
        Conditional covariance matrix Σ_{YX|Z}; it does not depend on ``Vz``.

    Raises
    ------
    ValueError
        If ``mu3``/``Sigma3`` do not have shapes (3,)/(3, 3), or if the
        variance of Z is not strictly positive.
    """
    mu3 = np.asarray(mu3, dtype=float)
    Sigma3 = np.asarray(Sigma3, dtype=float)
    if mu3.shape != (3,) or Sigma3.shape != (3, 3):
        raise ValueError(
            f"Expected mu3 of shape (3,) and Sigma3 of shape (3, 3), "
            f"got {mu3.shape} and {Sigma3.shape}."
        )

    mu_z = mu3[0]
    mu_yx = mu3[1:]               # [μy, μx]

    Sigma_zz = Sigma3[0, 0]       # scalar variance of Z
    Sigma_zyx = Sigma3[0, 1:]     # cross-covariance Z vs [Y, X]
    Sigma_yx = Sigma3[1:, 1:]     # 2x2 [Y, X] block

    if not Sigma_zz > 0:
        raise ValueError(f"Variance of Z must be strictly positive, got {Sigma_zz}.")

    # 1. Scale factor from the marginal density of Z at Vz
    diff_z = Vz - mu_z
    C = np.exp(-0.5 * (diff_z**2) / Sigma_zz)
    if normalize:
        C = C / np.sqrt(2 * np.pi * Sigma_zz)

    # 2. Conditional mean: μ_YX + Σ_YX,Z Σ_ZZ^{-1} (Vz - μz)
    mu2 = mu_yx + (Sigma_zyx / Sigma_zz) * diff_z

    # 3. Conditional covariance (Schur complement): Σ_YX - Σ_YX,Z Σ_ZZ^{-1} Σ_Z,YX
    cov2 = Sigma_yx - np.outer(Sigma_zyx, Sigma_zyx) / Sigma_zz

    return C, mu2, cov2


def _check_valid_covariance_matrix(matrix: torch.Tensor):
    """Validates if a tensor is a valid 3x3 covariance matrix.

    Raises:
        TypeError: If the input is not a torch.Tensor.
        ValueError: If the matrix is not 3x3 or is not a valid covariance matrix.
    """

    if not isinstance(matrix, torch.Tensor):
        raise TypeError("Input must be a torch.Tensor.")
    dof = matrix.shape[0]

    if matrix.shape != (dof, dof):
        raise ValueError(f"Matrix must be {dof}x{dof}, but got shape {matrix.shape}.")

    if not torch.allclose(matrix, matrix.T):
        raise ValueError("Matrix is not symmetric.")

    try:
        torch.linalg.cholesky(matrix)
    except RuntimeError as e:
        if "positive definite" in str(e):
            raise ValueError("Matrix is not positive definite.")
        elif "singular" in str(e):
            print("Matrix is singular (checking eigenvalues for semi-definiteness).")
            eigenvalues = torch.linalg.eigvals(matrix)
            if not torch.all(eigenvalues.real >= 0):
                raise ValueError(
                    "Matrix is not positive semi-definite. Eigenvalues:", eigenvalues
                )
        else:
            raise RuntimeError(
                f"An unexpected error occurred during Cholesky decomposition: {e}"
            )

def cal_range_in_conf_interval(
    cov: torch.Tensor,
    sigma_scale_factor: float | None = None,
    conf_interval: float | None = 0.999,
) -> torch.Tensor:
    """
    Given a normal distribution with covariance matrix `cov`.
    This function calculate value ranges (ellipsoid radius) along base axes X, Y, Z
    lies within a confident inverval, e.g 99.9% as default
    This is a generalized version of 3-sigma rules where `cov` matrix is just 1x1 matrix
    of 1D variance, conf_interval ~ 99.73%, this function should return
    `3 * sigma = 3*sqrt(cov[0,0])`

    Args:
        cov: covariance matrix of shape (D, D) where D is the number of dimensions
        sigma_scale_factor: value of `sigma_scale_factor` sigma-rule, e.g 3.0 for dof=1
        conf_interval: the confident interval

    Returns:
        Tensor R of shape (D,) specify the range along D base axes.
        The region within confident interval then be [mean - R, mean + R]
    """
    dof = cov.shape[0]
    assert cov.shape == (dof, dof)

    if sigma_scale_factor is not None:
        assert conf_interval is None and sigma_scale_factor > 0
    else:
        # chi-squared critical value for K degrees of freedom
        sigma_scale_factor = chi2.ppf(conf_interval, df=dof) ** 0.5

    eigenvalues, eigenvectors = torch.linalg.eigh(cov)  # faster for symmetric matrices
    assert torch.all(eigenvalues >= 0.0), f"{cov}"
    ax_lengths = torch.sqrt(eigenvalues) * sigma_scale_factor
    ax_vecs = ax_lengths[None] * eigenvectors  # each colume is an axis vector
    ret = ax_vecs.abs().max(dim=-1)[0]
    return ret

    
def generate_2d_gaussian_heatmap(
    heatmap_size: Tuple[int, int, int],
    keypoints: torch.Tensor,
    stride: int = 1,
    covariance: torch.Tensor | None = None,
    dtype=torch.float32,
    sigma_scale_factor: float | None = None,
    conf_interval: float | None = 0.999,
    lower=0.0,
    upper=1.0,
    same_std=False,
    add_offset=True,
    validate_cov_mat=False,
) -> torch.Tensor:
    """Generate a 2D multivariate Gaussian heatmap from keypoints with arbitrary covariance.

    Each keypoint is drawn as the unnormalised Gaussian ``exp(-0.5 * d**2)``,
    where ``d`` is the Mahalanobis distance. It is restricted to the confidence
    ellipse ``d <= k`` (``k`` from ``sigma_scale_factor`` / ``conf_interval``).
    Values are mapped so that the ellipse boundary becomes ``lower`` and the
    exact (sub-pixel) keypoint location becomes ``upper``. Keypoints of the
    same class are merged with an element-wise maximum.

    Args:
        heatmap_size: (C, X, Y) where C is the number of classes; X and Y are
            at full resolution and are divided by ``stride``.
        keypoints: (N, D) where D is one of:
            2 -> (x, y)                                 class 0, uses ``covariance``
            3 -> (x, y, class)                          uses ``covariance``
            6 -> (x, y, cov_xx, cov_yy, cov_xy, class)  per-keypoint covariance
            Coordinates and per-keypoint covariances are divided by ``stride``
            and ``stride**2`` respectively.
        stride: downsampling factor between full resolution and the heatmap.
        covariance: (2, 2) covariance shared by all keypoints; required when D
            is 2 or 3. It is used as-is (not divided by ``stride**2``) —
            presumably expressed in heatmap units; TODO confirm.
        dtype: output heatmap dtype, e.g. torch.float16 to save memory.
        sigma_scale_factor: Mahalanobis radius ``k`` of the drawn ellipse, e.g.
            ~3.72 for dof=2 at 99.9%. Requires ``conf_interval=None``.
        conf_interval: confidence level covered by the drawn ellipse, used when
            ``sigma_scale_factor`` is None.
        lower: min heatmap value (background), for label smoothing.
        upper: max heatmap value (at the keypoint), for label smoothing.
        same_std: if True, replace each covariance by an isotropic one whose
            std is the smallest principal std (sqrt of the smallest eigenvalue).
        add_offset: whether to add 0.5 to (x, y) to convert discrete pixel
            indices to continuous coordinates (pixel i has centre i + 0.5).
        validate_cov_mat: whether to validate each covariance matrix.

    Returns:
        Tensor of shape (C, X // stride, Y // stride) with dtype ``dtype``.

    Raises:
        ValueError: on invalid ``lower``/``upper``, malformed keypoints, a
            missing or mis-shaped covariance, or an out-of-range class index.
    """
    if not 0 <= lower < upper <= 1.0:
        raise ValueError(f"Expected 0 <= lower < upper <= 1, got lower={lower}, upper={upper}.")
    if keypoints.ndim != 2 or keypoints.shape[1] not in (2, 3, 6):
        raise ValueError(
            f"keypoints must have shape (N, 2), (N, 3) or (N, 6), got {tuple(keypoints.shape)}."
        )
    num_cols = keypoints.shape[1]
    # Float copy: never modify the caller's tensor, and allow integer input.
    keypoints = keypoints.detach().to(dtype=torch.float64, copy=True)

    # voxel indices -> floating point coordinate
    # origin is top-left corner of top-left pixel;
    # center of first pixel (top-left) has coordinate (0.5, 0.5).
    # Only the two coordinate columns: column 2 holds the class id or cov_xx.
    if add_offset:
        keypoints[:, :2] += 0.5

    C = heatmap_size[0]
    if stride > 1 and any(e % stride != 0 for e in heatmap_size[1:]):
        raise ValueError(f"Spatial size {tuple(heatmap_size[1:])} must be divisible by stride={stride}.")
    X, Y = [round(e / stride) for e in heatmap_size[1:]]

    keypoints[:, :2] /= stride
    if num_cols == 6:
        # Covariances scale with the square of the coordinate scale.
        keypoints[:, 2:5] /= stride * stride
        shared_cov = None
    else:
        if covariance is None:
            raise ValueError("`covariance` is required for keypoints without per-keypoint covariance.")
        shared_cov = torch.as_tensor(covariance, dtype=torch.float32)
        if shared_cov.shape != (2, 2):
            raise ValueError(f"`covariance` must have shape (2, 2), got {tuple(shared_cov.shape)}.")

    if sigma_scale_factor is not None:
        if conf_interval is not None or not sigma_scale_factor > 0:
            raise ValueError(
                "sigma_scale_factor must be > 0 and conf_interval must be None "
                "when sigma_scale_factor is given."
            )
    else:
        if conf_interval is None:
            raise ValueError("One of sigma_scale_factor / conf_interval must be given.")
        # chi-squared critical value for 2 degrees of freedom
        sigma_scale_factor = float(chi2.ppf(conf_interval, df=2) ** 0.5)
    # Gaussian value on the boundary of the drawn confidence ellipse (< 1).
    boundary_value = float(np.exp(-0.5 * sigma_scale_factor**2))
    value_scale = float(upper) - float(lower)

    # 2D grid of pixel centres, shape (X, Y, 2)
    xs = torch.linspace(0.5, X - 0.5, X)
    ys = torch.linspace(0.5, Y - 0.5, Y)
    grid = torch.stack(torch.meshgrid([xs, ys], indexing="ij"), dim=-1)

    heatmap = torch.full((C, X, Y), lower, dtype=dtype)

    for kpt in keypoints.tolist():
        if num_cols == 6:
            x, y, cov_xx, cov_yy, cov_xy, kpt_cls = kpt
            cov_mat = torch.tensor(
                [
                    [cov_xx, cov_xy],
                    [cov_xy, cov_yy],
                ],
                dtype=torch.float32,
            )
        else:
            x, y = kpt[:2]
            kpt_cls = kpt[2] if num_cols == 3 else 0
            cov_mat = shared_cov

        if validate_cov_mat:
            _check_valid_covariance_matrix(cov_mat)

        kpt_cls = int(kpt_cls)
        if not 0 <= kpt_cls < C:
            raise ValueError(f"Keypoint class {kpt_cls} is out of range [0, {C}).")

        if not same_std:
            # half-widths of the bounding box of the confidence ellipse
            range_x, range_y = cal_range_in_conf_interval(
                cov_mat,
                conf_interval=None,
                sigma_scale_factor=sigma_scale_factor,
            ).tolist()
            if not (range_x > 0 and range_y > 0):
                raise ValueError(f"Degenerate covariance matrix: {cov_mat}")
        else:
            # smallest std along the principal axes = sqrt(smallest eigenvalue)
            std = float(torch.linalg.eigvalsh(cov_mat).min().clamp(min=0).sqrt())
            if not std > 0:
                raise ValueError(f"Degenerate covariance matrix: {cov_mat}")
            cov_mat = torch.eye(2, dtype=torch.float32) * std**2
            range_x = range_y = std * sigma_scale_factor

        # slices to crop -> reduce computation
        x_min = _clip(round(x - range_x), X)
        x_max = _clip(round(x + range_x), X)
        y_min = _clip(round(y - range_y), Y)
        y_max = _clip(round(y + range_y), Y)
        if (x_max <= x_min) or (y_max <= y_min):
            continue

        sigma_inv = torch.linalg.inv(cov_mat)
        grid_diff = grid[x_min:x_max, y_min:y_max] - torch.tensor(
            [x, y], dtype=torch.float32
        )  # (h, w, 2)

        squared_mahalanobis = torch.einsum(
            "...i,ij,...j->...", grid_diff, sigma_inv, grid_diff
        )  # (h, w)
        gaussian = torch.exp(-0.5 * squared_mahalanobis)  # in (0, 1], 1 at the keypoint
        # Map [boundary_value, 1] -> [0, 1] and zero everything outside the
        # ellipse. Fixed reference values (instead of the patch min/max) keep the
        # result independent of border clipping and avoid 0/0 on 1-pixel patches.
        gaussian = ((gaussian - boundary_value) / (1.0 - boundary_value)).clamp_(0.0, 1.0)
        gaussian = lower + value_scale * gaussian  # in [lower, upper]

        heatmap_patch = heatmap[kpt_cls, x_min:x_max, y_min:y_max]
        torch.maximum(heatmap_patch, gaussian.to(dtype), out=heatmap_patch)

    return heatmap

# In [None]:
# Sanity check: slice an axis-aligned 3D Gaussian (variables ordered Z, Y, X)
# at z = 60. With a diagonal covariance the slice keeps mean [60, 70] and
# covariance diag(25**2, 15**2); only the scale factor C depends on z.
mean = [50, 60, 70]
# Alternative correlated covariance kept for experimentation:
# cov = np.array([[2.322318, 0.416754, 1.497221 ],
#   [0.416754, 0.746716, 0.440805 ],
#   [1.497221, 0.440805, 2.132053 ]])
cov = np.diag(np.array([20, 25, 15]) ** 2)

slice_3d_normal_at_z_value(mean, cov, 60)

# In [None]:
# Render a 3D Gaussian into a (Z, Y, X) volume slice by slice: each z-slice of
# a 3D normal is a 2D Gaussian over (Y, X) scaled by the unnormalised marginal
# density of Z (see `slice_3d_normal_at_z_value`).
agg_heatmap = torch.zeros((128, 512, 512), dtype=torch.float32)
# [mu_z, mu_y, mu_x]
# NOTE(review): mu_z = 170 lies outside the 128-slice volume, so only the
# Gaussian's tail is rendered — confirm this is intended.
mean = [170, 200, 400]
cov = np.array([
    [20 ** 2, 0, 0],
    [0, 25 ** 2, 0],
    [0, 0, 15 ** 2],
])

for z_value in range(agg_heatmap.shape[0]):
    # Evaluate at the voxel centre (z + 0.5), matching the in-plane convention
    # of `generate_2d_gaussian_heatmap` where pixel i has centre i + 0.5.
    C, mu2, cov2 = slice_3d_normal_at_z_value(mean, cov, z_value + 0.5)
    if not C > 0:
        # Slice too far from the mean (C underflowed): nothing to draw, and the
        # heatmap generator requires upper > lower.
        continue
    # mu2 / cov2 are in [Y, X] order, so heatmap axis 1 is Y and axis 2 is X.
    keypoints = torch.tensor(
        [[mu2[0], mu2[1], cov2[0, 0], cov2[1, 1], cov2[0, 1], 0]],
        dtype=torch.float32,
    )
    heatmap = generate_2d_gaussian_heatmap(
        heatmap_size=(1, *agg_heatmap.shape[1:]),
        keypoints=keypoints,
        stride=1,
        covariance=None,
        dtype=torch.float32,
        sigma_scale_factor=None,
        conf_interval=0.999,
        lower=0.0,
        upper=float(C),
        same_std=False,
        add_offset=False,
        validate_cov_mat=True,
    )
    print(z_value, heatmap.min(), heatmap.max())
    agg_heatmap[z_value] = heatmap[0]

from itkwidgets import view

# Round (rather than truncate) to 8-bit for display.
agg_heatmap = (agg_heatmap * 255).round().clamp_(0, 255).to(torch.uint8).cpu().numpy()
view(
    image=agg_heatmap,
    cmap='Grayscale',
)