In [1]:
import torch.nn.functional as F
import torch
from torch import nn, Tensor


In [2]:
def zinb(
    target: Tensor,
    mu: Tensor,
    theta: Tensor,
    pi: Tensor,
    eps=1e-8,
):
    """
    Computes zero-inflated negative binomial (ZINB) loss.

    This function was modified from scvi-tools.

    Args:
        target (Tensor): Torch Tensor of ground truth data.
        mu (Tensor): Torch Tensor of means of the negative binomial (must have positive support).
        theta (Tensor): Torch Tensor of inverse dispersion parameter (must have positive support).
        pi (Tensor): Torch Tensor of logits of the dropout parameter (real support).
        eps (float, optional): Numerical stability constant. Defaults to 1e-8.

    Returns:
        Tensor: ZINB loss value.
    """
    #  uses log(sigmoid(x)) = -softplus(-x)
    softplus_pi = F.softplus(-pi)
    # eps to make it positive support and taking the log
    log_theta_mu_eps = torch.log(theta + mu + eps)
    pi_theta_log = -pi + theta * (torch.log(theta + eps) - log_theta_mu_eps)

    case_zero = F.softplus(pi_theta_log) - softplus_pi
    mul_case_zero = torch.mul((target < eps).type(torch.float32), case_zero)

    case_non_zero = (
        -softplus_pi
        + pi_theta_log
        + target * (torch.log(mu + eps) - log_theta_mu_eps)
        + torch.lgamma(target + theta)
        - torch.lgamma(theta)
        - torch.lgamma(target + 1)
    )
    mul_case_non_zero = torch.mul((target > eps).type(torch.float32), case_non_zero)

    res = mul_case_zero + mul_case_non_zero
    # we want to minize the loss but maximize the log likelyhood
    return -res.mean()

In [3]:
def zinb_sonnet(
    target: Tensor,
    mu: Tensor,
    theta: Tensor,
    pi: Tensor,
    eps=1e-8,
):
    """
    Computes zero-inflated negative binomial (ZINB) loss updated to improve numerical stability with sonnet

    This function is modified to improve numerical stability and avoid using lgamma.

    Args:
        target (Tensor): Torch Tensor of ground truth data.
        mu (Tensor): Torch Tensor of means of the negative binomial (must have positive support).
        theta (Tensor): Torch Tensor of inverse dispersion parameter (must have positive support).
        pi (Tensor): Torch Tensor of logits of the dropout parameter (real support).
        eps (float, optional): Numerical stability constant. Defaults to 1e-8.

    Returns:
        Tensor: ZINB loss value.
    """
    # Compute log(1 - sigmoid(pi)) more accurately using -softplus(pi)
    log_neg_pi = -F.softplus(pi)
    
    # Compute log(theta + mu) more accurately
    log_theta_mu = torch.log(theta + mu + eps)
    
    # Compute log(1 + mu/theta) more accurately
    log_1_plus_mu_theta = F.softplus(torch.log(mu + eps) - torch.log(theta + eps))
    
    # Compute log likelihood for zero values
    ll_zero = F.softplus(theta * (torch.log(theta + eps) - log_theta_mu) - pi)
    
    # Compute log likelihood for non-zero values
    ll_non_zero = (
        log_neg_pi
        + theta * torch.log(theta + eps)
        - (theta + target) * log_theta_mu
        + target * torch.log(mu + eps)
        - torch.lgamma(target + 1)
        + torch.lgamma(theta + target)
        - torch.lgamma(theta)
    )
    
    # Combine zero and non-zero cases
    ll = torch.where(target < eps, ll_zero, ll_non_zero)
    
    # Return negative mean log-likelihood
    return -ll.mean()

In [4]:
def nb(target: Tensor, mu: Tensor, theta: Tensor, eps=1e-8):
    """
    Computes the negative binomial (NB) loss.

    This function was adapted from scvi-tools.

    Args:
        target (Tensor): Ground truth data.
        mu (Tensor): Means of the negative binomial distribution (must have positive support).
        theta (Tensor): Inverse dispersion parameter (must have positive support).
        eps (float, optional): Numerical stability constant. Defaults to 1e-8.

    Returns:
        Tensor: NB loss value.
    """
    if theta.ndimension() == 1:
        theta = theta.view(1, theta.size(0))

    log_theta_mu_eps = torch.log(theta + mu + eps)
    res = (
        theta * (torch.log(theta + eps) - log_theta_mu_eps)
        + target * (torch.log(mu + eps) - log_theta_mu_eps)
        + torch.lgamma(target + theta)
        - torch.lgamma(theta)
        - torch.lgamma(target + 1)
    )

    return -res.mean()

In [10]:
# Test both functions with the same input
THETA = 10_000 # above this, it gets worse

TARGET = [100,10,10,1,1,0,0,0]
MINPI = 0.01
MAXPI = 100
ERROR = [1,0.1,0.1,0,0,100,100,100]

target = torch.Tensor(TARGET)
mu = torch.Tensor(TARGET)
theta = torch.Tensor([THETA]*len(TARGET))
pi = torch.Tensor([MINPI,MINPI,MINPI,MINPI,MINPI,MAXPI,MAXPI,MAXPI])

# Test original zinb function
original_loss = zinb(target, mu, theta, pi)
print(f"Original ZINB Loss: {original_loss.item()}")

# Test original zinb function with error
original_loss = zinb(target, mu+torch.Tensor(ERROR), theta, pi)
print(f"Original ZINB Loss with error term: {original_loss.item()}")

# Test updated zinb_sonnet function
new_loss = zinb_sonnet(target, mu, theta, pi)
print(f"New ZINB Loss: {new_loss.item()}")

# Test nb function
nb_loss = nb(target, mu, theta)
print(f"NB Loss: {nb_loss.item()}")

Original ZINB Loss: 1.615840196609497
Original ZINB Loss with error term: 1.616633653640747
New ZINB Loss: 1.615966796875
NB Loss: 1.179560899734497
