In [1]:
import torch
import torch.nn as nn

In [2]:
def relu_evidence(y):
    """Map raw network outputs ``y`` to non-negative evidence via ReLU."""
    evidence = torch.relu(y)
    return evidence

def softplus_evidence(y):
    """Map raw network outputs ``y`` to positive evidence via softplus."""
    evidence = nn.functional.softplus(y)
    return evidence

def kl_divergence(alpha, device='cpu'):
    """Return KL(Dir(alpha) || Dir(1)) for every row of ``alpha``.

    Computes, per row,
    ``lgamma(S) - sum(lgamma(alpha)) + sum(lgamma(1)) - lgamma(K)
    + sum((alpha - 1) * (digamma(alpha) - digamma(S)))``
    where ``S`` is the row sum of ``alpha`` and ``K`` is the number of classes.

    Args:
        alpha: Dirichlet concentration parameters. The class count is read from
            the last dimension and all reductions run over ``dim=1``, so the
            two only agree for a 2-D ``(N, K)`` tensor.
        device: Kept for backward compatibility and otherwise ignored. The
            constant ``ones`` tensor is always created on ``alpha``'s device;
            previously a mismatching ``device`` (e.g. the ``'cpu'`` default with
            a CUDA ``alpha``) raised a device-mismatch error.

    Returns:
        Tensor of shape ``(N, 1)`` holding the divergence for each row.
    """
    num_classes = alpha.shape[-1]
    # Build the reference concentration on the same device as the input so the
    # arithmetic below never mixes devices.
    ones = torch.ones([1, num_classes], dtype=torch.float32, device=alpha.device)
    sum_alpha = torch.sum(alpha, dim=1, keepdim=True)
    # Difference of the log-normalisers of Dir(alpha) and Dir(1).
    first_term = (
        torch.lgamma(sum_alpha)
        - torch.lgamma(alpha).sum(dim=1, keepdim=True)
        + torch.lgamma(ones).sum(dim=1, keepdim=True)
        - torch.lgamma(ones.sum(dim=1, keepdim=True))
    )
    second_term = (
        (alpha - ones)
        .mul(torch.digamma(alpha) - torch.digamma(sum_alpha))
        .sum(dim=1, keepdim=True)
    )
    kl = first_term + second_term
    return kl

def loglikelihood_loss(y, alpha, device='cpu'):
    """Return the per-row squared-error plus variance term for ``alpha``.

    With ``S`` the row sum of ``alpha`` (over ``dim=1``), the result is
    ``sum((y - alpha / S) ** 2) + sum(alpha * (S - alpha) / (S * S * (S + 1)))``,
    each sum taken over ``dim=1`` with ``keepdim=True``.

    ``device`` is accepted but not used by this function.
    """
    strength = alpha.sum(dim=1, keepdim=True)
    expected_prob = alpha / strength

    squared_error = (y - expected_prob).pow(2).sum(dim=1, keepdim=True)
    variance = (
        alpha * (strength - alpha) / (strength * strength * (strength + 1))
    ).sum(dim=1, keepdim=True)

    return squared_error + variance

def edl_loss(func, y, alpha, epoch_num, annealing_step, device=None):
    """Return the per-row evidential loss built from ``func`` plus an annealed KL term.

    Args:
        func: Element-wise tensor function applied to ``S`` and ``alpha``
            (callers in this file pass ``torch.log`` or ``torch.digamma``).
        y: Targets with the same shape as ``alpha``; used as multiplicative
            weights, so presumably one-hot — confirm against callers.
        alpha: Concentration parameters; reductions run over ``dim=1``.
        epoch_num: Current epoch, used only for the annealing coefficient.
        annealing_step: Epoch count over which the KL weight ramps up to 1.
            Must be non-zero (it is used as a divisor).
        device: Kept for backward compatibility and otherwise ignored; the KL
            term is always computed on ``alpha``'s device. Previously the
            ``None`` default was forwarded as-is, which built constants on the
            default device and failed for inputs living elsewhere.

    Returns:
        Tensor of shape ``(N, 1)``: ``A + annealing_coef * KL``.
    """
    S = torch.sum(alpha, dim=1, keepdim=True)

    A = torch.sum(y * (func(S) - func(alpha)), dim=1, keepdim=True)

    # KL weight grows linearly with the epoch and saturates at 1.
    annealing_coef = torch.min(
        torch.tensor(1.0, dtype=torch.float32),
        torch.tensor(epoch_num / annealing_step, dtype=torch.float32),
    )

    # Where y == 1 the entry collapses to 1; where y == 0 alpha is kept.
    kl_alpha = (alpha - 1) * (1 - y) + 1
    kl_div = annealing_coef * kl_divergence(kl_alpha, device=kl_alpha.device)
    return A + kl_div

def mse_loss(y, alpha, epoch_num, annealing_step=10, device='cpu'):
    """Return the per-row ``loglikelihood_loss`` plus an annealed KL term.

    Args:
        y: Targets with the same shape as ``alpha``; presumably one-hot —
            confirm against callers.
        alpha: Concentration parameters; reductions run over ``dim=1``.
        epoch_num: Current epoch, used only for the annealing coefficient.
        annealing_step: Epoch count over which the KL weight ramps up to 1.
            Must be non-zero (it is used as a divisor).
        device: Kept for backward compatibility and otherwise ignored; helper
            tensors are created on ``alpha``'s device. Previously the ``'cpu'``
            default was forwarded to ``kl_divergence`` and raised a
            device-mismatch error for non-CPU inputs.

    Returns:
        Tensor of shape ``(N, 1)``.
    """
    # Always follow the input's device instead of the (possibly stale) argument.
    tensor_device = alpha.device
    loglikelihood = loglikelihood_loss(y, alpha, device=tensor_device)

    # KL weight grows linearly with the epoch and saturates at 1.
    annealing_coef = torch.min(
        torch.tensor(1.0, dtype=torch.float32),
        torch.tensor(epoch_num / annealing_step, dtype=torch.float32),
    )

    # Where y == 1 the entry collapses to 1; where y == 0 alpha is kept.
    kl_alpha = (alpha - 1) * (1 - y) + 1
    kl_div = annealing_coef * kl_divergence(kl_alpha, device=tensor_device)
    return loglikelihood + kl_div

def edl_log_loss(output, target, epoch_num, annealing_step, device=None):
    """Return the mean ``edl_loss`` (with ``torch.log``) for ReLU evidence.

    ``alpha`` is ``relu_evidence(output) + 1``; ``target`` is passed to
    ``edl_loss`` unchanged, and the per-row losses are averaged into a scalar.
    """
    alpha = relu_evidence(output) + 1
    per_row = edl_loss(
        torch.log, target, alpha, epoch_num, annealing_step, device
    )
    return per_row.mean()

class PenalizedTanh(torch.nn.Module):
    """Penalized tanh activation: ``tanh(x)`` for ``x > 0``, ``0.25 * tanh(x)`` otherwise."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the activation element-wise and return a new tensor.

        The input is left untouched. The previous implementation wrote the
        result back into ``x`` via masked assignment, which silently changed
        the caller's tensor and raised under autograd when ``x`` was a leaf
        requiring grad.
        """
        squashed = torch.tanh(x)
        # Non-positive inputs get the damped (0.25x) branch.
        return torch.where(x > 0, squashed, 0.25 * squashed)

class EvidentialMSELoss():
    """
    Callable evidential loss with configurable defaults.

    This is a class so that default parameters (such as ``evidence_activation_fx``)
    can be fixed once and the object then handed to a fitter. Because of the
    ``__call__`` method, an ``EvidentialMSELoss`` instance can still be called
    like a plain function; calling it dispatches to ``edl_mse_loss``
    (``edl_digamma_loss`` is available as an alternative method).
    """
    def __init__(self, evidence_activation_fx=torch.nn.functional.softplus, annealing_step=10, device='cpu'):
        """
        Args:
            evidence_activation_fx: Callable mapping raw outputs to evidence.
            annealing_step: Forwarded to the loss functions as the KL annealing horizon.
            device: Device on which the one-hot targets are created.
        """
        self.evidence_activation_fx = evidence_activation_fx
        self.annealing_step = annealing_step
        self.device = device

    def _one_hot_embedding(self, labels, num_classes):
        """Return rows of a ``num_classes`` identity matrix (on ``self.device``) selected by ``labels``."""
        y = torch.eye(num_classes).float().to(self.device)
        return y[labels]

    @staticmethod
    def _apply_weights_and_reduce(loss, losswts, reduction):
        """Scale the per-sample ``(N, 1)`` loss by ``losswts`` and optionally average it.

        Only ``reduction == 'mean'`` reduces; any other value leaves the
        per-sample losses in place. The result is always flattened to 1-D.
        """
        if losswts is not None:
            # One weight per sample; out-of-place so the input is not modified.
            loss = loss * losswts.view(-1, 1)
        if reduction == 'mean':
            loss = torch.mean(loss)
        return loss.flatten()

    def edl_mse_loss(self, output, target, epoch_num, device='cpu', losswts=None, reduction='mean'):
        """Evidential loss based on ``mse_loss``.

        Args:
            output: Raw model outputs; the last dimension is the class count.
            target: Class indices used to build one-hot targets.
            epoch_num: Current epoch (drives KL annealing).
            device: Kept for backward compatibility and otherwise ignored; the
                loss is computed on the device of ``output``. Previously the
                ``'cpu'`` default clashed with a non-CPU ``self.device``.
            losswts: Optional per-sample weights.
            reduction: ``'mean'`` averages; anything else returns per-sample values.

        Returns:
            1-D tensor (a single element when ``reduction == 'mean'``).
        """
        y_onehot = self._one_hot_embedding(target, num_classes=output.shape[-1])
        evidence = self.evidence_activation_fx(output)
        alpha = evidence + 1
        loss = mse_loss(y_onehot, alpha, epoch_num, self.annealing_step, device=alpha.device)
        return self._apply_weights_and_reduce(loss, losswts, reduction)

    def edl_digamma_loss(self, output, target, epoch_num, device='cpu', losswts=None, reduction='mean'):
        """Evidential loss based on ``edl_loss`` with ``torch.digamma``.

        Takes the same arguments and returns the same shape as ``edl_mse_loss``.
        The per-sample loss is weighted directly; the earlier reshape to
        ``(-1, num_classes)`` mis-broadcast against the ``(N, 1)`` weights.
        """
        y_onehot = self._one_hot_embedding(target, num_classes=output.shape[-1])
        evidence = self.evidence_activation_fx(output)
        alpha = evidence + 1
        loss = edl_loss(torch.digamma, y_onehot, alpha, epoch_num, self.annealing_step, alpha.device)
        return self._apply_weights_and_reduce(loss, losswts, reduction)

    def __call__(self, *args, **kwargs):
        """Forward all arguments to ``edl_mse_loss``."""
        return self.edl_mse_loss(*args, **kwargs)

In [None]:
# Instantiate the loss with ReLU as the evidence activation (other arguments
# keep their defaults); the instance is not bound to a name here.
EvidentialMSELoss(evidence_activation_fx=torch.nn.functional.relu)

In [3]:
import fastai