# In [1]:
import torch

def regression_heteroscedastic_loss(true, mean, log_var, metric="rmse"):
    '''
    Regression loss in which every prediction is weighted by its own
    predicted uncertainty.

    ARGUMENTS:
    true: true values. Tensor (batch_size x number of outputs)
    mean: predictions. Tensor (batch_size x number of outputs)
    log_var: Logarithms of uncertainty estimates. Tensor (batch_size x number of outputs)
    metric: "mae" or "rmse". A falsy value (None, "") is treated as "rmse",
            which is also the default.
            "mae":  sqrt(2 * exp(-log_var)) * |true - mean| + log_var / 2
            "rmse": exp(-log_var) * (true - mean) ** 2 + log_var
                    (squared error; no square root is taken)

    OUTPUTS:
    loss. Tensor (0): per-output terms summed over dimension 1, then
    averaged over dimension 0.

    RAISES:
    ValueError if metric is neither "mae", "rmse" nor falsy.
    '''
    use_squared_error = metric == "rmse" or not metric
    if metric != "mae" and not use_squared_error:
        # Fail loudly: silently returning None would only surface later
        # (e.g. at loss.backward()) with a much less helpful error.
        raise ValueError(f"Metric has to be 'rmse' or 'mae', got {metric!r}")

    precision = torch.exp(-log_var)
    residual = true - mean
    if use_squared_error:
        per_output = precision * residual ** 2 + log_var
    else:
        per_output = (2 * precision) ** 0.5 * torch.abs(residual) + log_var / 2

    # Sum over outputs, then average over the batch.
    return torch.mean(torch.sum(per_output, 1), 0)

def regression_homoscedastic_loss(true, mean, metric="rmse"):
    '''
    Regression loss without any uncertainty weighting.

    ARGUMENTS:
    true: true values. Tensor (batch_size x number of outputs)
    mean: predictions. Tensor (batch_size x number of outputs)
    metric: "mae" or "rmse". A falsy value (None, "") is treated as "rmse",
            which is also the default.
            "mae":  |true - mean|
            "rmse": (true - mean) ** 2  (squared error; no square root is taken)

    OUTPUTS:
    loss. Tensor (0): per-output terms summed over dimension 1, then
    averaged over dimension 0.

    RAISES:
    ValueError if metric is neither "mae", "rmse" nor falsy.
    '''
    use_squared_error = metric == "rmse" or not metric
    if metric != "mae" and not use_squared_error:
        # Fail loudly: silently returning None would only surface later
        # (e.g. at loss.backward()) with a much less helpful error.
        raise ValueError(f"Metric has to be 'rmse' or 'mae', got {metric!r}")

    residual = true - mean
    if use_squared_error:
        per_output = residual ** 2
    else:
        per_output = torch.abs(residual)

    # Sum over outputs, then average over the batch.
    return torch.mean(torch.sum(per_output, 1), 0)