In [None]:
#| default_exp losses_metrics

# Libraries

In [2]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F

import math
import sklearn
import numpy as np

from fastai.callback.core import Callback

In [1]:
#| hide
import nbdev; nbdev.nbdev_export()

# $\beta$-VAE loss

In [3]:
#| export
def kl_divergence(mu, 
                  logvar):
    '''
    KL divergence between the diagonal Gaussian N(mu, exp(logvar)) and the standard-normal prior N(0, 1).

    Inputs are expected as (batch, dim) tensors. 4-D inputs are collapsed to (batch, dim) with `view`,
    which only succeeds when the two trailing dimensions have size 1 (e.g. (batch, dim, 1, 1)).

    Returns a tuple (total_kld, dimension_wise_kld, mean_kld):
      total_kld          -- shape (1,): KL summed over latent dims, averaged over the batch
      dimension_wise_kld -- shape (dim,): per-dim KL averaged over the batch
      mean_kld           -- shape (1,): KL averaged over latent dims, then over the batch
    '''
    def _as_2d(t):
        # Collapse a 4-D (batch, dim, 1, 1) tensor to (batch, dim); leave anything else untouched.
        return t.view(t.size(0), t.size(1)) if t.dim() == 4 else t

    assert mu.size(0) != 0

    mu, logvar = _as_2d(mu), _as_2d(logvar)

    # Closed-form Gaussian KL for every (sample, dim) entry.
    kld_elem = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())

    total_kld = kld_elem.sum(dim=1).mean(dim=0, keepdim=True)
    dimension_wise_kld = kld_elem.mean(dim=0)
    mean_kld = kld_elem.mean(dim=1).mean(dim=0, keepdim=True)
    return total_kld, dimension_wise_kld, mean_kld

In [7]:
#| export
class beta_mse_loss(torch.nn.modules.loss._Loss):
    '''
    beta-VAE objective: reconstruction loss plus a weighted KL(q(z|x) || N(0, I)) term.

    Two KL weightings are supported (`loss_objective`):
      * 'H' (Higgins et al. 2017):  rec + beta * KL
      * 'B' (Burgess et al. 2017):  rec + gamma * |KL - C|, with the capacity
        C = clamp(C_max / C_stop_iter * global_iter, 0, C_max). `global_iter` must be advanced
        externally (e.g. by the `betaLoss_C_scheduler` callback).

    The reconstruction term is MSE (`recon_objective='mse'`, honouring `reduction`; or the 1Konny
    sigmoid + sum / batch variant if `mse_as_konny`) or, otherwise, BCE-with-logits summed over
    all elements.
    '''
    
    def __init__(self, 
                 loss_objective = 'H', # 'H' for original (Higgins et al.2017), 'B' for version (Burgess et al. 2017)
                 beta = None, # Regularizing factor in betaVAE
                 gamma = None, # Regularizing factor in Burgess betaVAE
                 C_max = None, # Maximum value of C
                 C_stop_iter = None, # When to reach the C_max
                 recon_objective = 'mse', # either 'mse' (Mean Squared Error) or 'bce' (Binary Cross Entropy)
                 reduction = None, # reduction applied in the mse
                 mse_as_konny = False # if True, adapts the MSE loss to have exactly what they have in https://github.com/1Konny/Beta-VAE/blob/master/solver.py
                 ) -> None:
        
        super().__init__()

        # Otherwise keep the default set by _Loss.__init__
        if reduction is not None:
            self.reduction = reduction
        
        self.loss_objective = loss_objective
    
        # For original objective
        self.beta = beta
        # For update objective
        self.gamma = gamma
        self.C_max = C_max
        self.C_stop_iter = C_stop_iter
        self.global_iter = 0

        self.mse_as_konny = mse_as_konny
        
        self.recon_objective = recon_objective
        
    def forward(self,
                input: tuple, # prediction of the model given as recon, mu, logvar
                target: torch.Tensor, # target for the reconstruction
                separate_loss = False # if giving the two parts of the loss separatedly
                ) -> torch.Tensor:
        '''
        Returns `rec_loss + kl_term`, or the tuple `(rec_loss, kl_term)` if `separate_loss`.

        Raises ValueError if `loss_objective` is neither 'H' nor 'B'. Previously such a value
        silently returned None.
        '''
        if self.loss_objective not in ('H', 'B'):
            raise ValueError(f"loss_objective must be 'H' or 'B', got {self.loss_objective!r}")

        # Separate the input into different variables
        recon, mu, logvar = input
        
        if recon.shape != target.shape:
            # Most typically, we will have a input that has an extra dimension due to the number of channels
            recon = recon.squeeze(1)

        #### COMPUTE RECONSTRUCTION LOSS ####
        if self.recon_objective == 'mse':
            if self.mse_as_konny:
                recon = torch.sigmoid(recon)
                rec_loss = F.mse_loss(recon, target, reduction = 'sum').div(target.shape[0])
            else:
                rec_loss = F.mse_loss(recon, target, reduction = self.reduction)
        else:
            # The former call passed the deprecated `size_average=False`, which overrode `reduction`
            # and always summed. Spell the sum out explicitly: same result, no deprecation warning.
            rec_loss = F.binary_cross_entropy_with_logits(recon, target, reduction = 'sum')

        #### COMPUTE LATENT LOSS ####
        total_kld, _, _ = kl_divergence(mu, logvar)

        #### MERGE BOTH LOSSES ####
        if self.loss_objective == 'H': # as in original betaVAE
            kl_term = self.beta*total_kld
        else: # 'B', as in Burgess betaVAE
            # Capacity grows linearly with global_iter (advanced by the betaLoss_C_scheduler callback)
            C = torch.clamp(torch.tensor([self.C_max/self.C_stop_iter*self.global_iter]), 0, self.C_max).to(total_kld.device)
            kl_term = self.gamma*(total_kld-C).abs()

        if separate_loss:
            return rec_loss, kl_term
        return rec_loss + kl_term

### Testing

In [8]:
x = torch.rand((12,32))
bs = x.shape[0]; z_dim = 5
mu = torch.zeros((bs, z_dim))
logvar = torch.zeros_like(mu)+0.2
input = (x, mu, logvar)
loss = beta_mse_loss(loss_objective='H', beta = 1)
out_H = loss(input, x, separate_loss = False)
# recon == target, so the MSE term is 0 and only the analytic KL remains:
# with mu = 0 it is 0.5 * (exp(logvar) - 1 - logvar) per latent dim, summed over z_dim.
assert torch.allclose(out_H, torch.tensor([z_dim * 0.5 * (math.exp(0.2) - 1.2)]))
out_H

tensor([0.0535])

In [9]:
loss = beta_mse_loss(loss_objective='B', C_max = 20, gamma = 1, C_stop_iter=100)
# At global_iter = 0 the capacity C is 0, so with gamma == beta this equals the 'H' objective
out_B = loss(input, x, separate_loss = False)
assert torch.allclose(out_B, beta_mse_loss(loss_objective='H', beta = 1)(input, x))
out_B

tensor([0.0535])

# $\beta$-TC-VAE loss (total correlation)

In [10]:
#| export
class btcvae_loss(nn.modules.loss._Loss):
    """
    Beta-TC-VAE loss with the same interface as `beta_mse_loss`.

    Parameters
    ----------
    n_data : int
        Total number of training samples (for minibatch-weighted estimates).
    alpha : float
        Weight for mutual information term I[z;x].
    beta : float
        Weight for total correlation term TC[z].
    gamma : float
        Weight for dimension-wise KL term sum_i KL[q(z_i) || p(z_i)].
    is_mss : bool
        If True, use minibatch stratified sampling (MSS). Else, use weighted (MWS).
    steps_anneal : int or None
        If set, linearly anneal the gamma term from 0→1 over this many global steps.
    recon_objective : {'mse','bce'}
    reduction : {'none','mean','sum'} or None
    mse_as_konny : bool
        Match 1Konny’s Beta-VAE MSE variant (sigmoid+sum/bs).
    """

    def __init__(self,
                 n_data: int,
                 alpha: float = 1.0,
                 beta: float = 6.0,
                 gamma: float = 1.0,
                 is_mss: bool = True,
                 steps_anneal: int | None = None,
                 recon_objective: str = 'mse',
                 reduction: str | None = None,
                 mse_as_konny: bool = False) -> None:
        super().__init__()

        # config matching your previous loss style
        if reduction is not None:
            self.reduction = reduction

        self.recon_objective = recon_objective
        self.mse_as_konny = mse_as_konny

        # BTC-VAE specifics
        self.n_data = int(n_data)
        self.alpha = float(alpha)
        self.beta = float(beta)
        self.gamma = float(gamma)
        self.is_mss = bool(is_mss)

        # annealing (mirrors your global_iter pattern)
        self.steps_anneal = steps_anneal if steps_anneal is None else int(steps_anneal)
        self.global_iter = 0  # update this externally if you want annealing

    # ------------------------ helpers ------------------------
    @staticmethod
    def _log_density_gaussian(x, mu, logvar):
        # returns log N(x | mu, diag(exp(logvar))) per-sample, per-dim
        return -0.5 * (math.log(2 * math.pi) + logvar + (x - mu) ** 2 / logvar.exp())

    @staticmethod
    def _matrix_log_density_gaussian(z, mu, logvar):
        """
        Compute pairwise log-prob: for each sample z_b, its log prob under every q(z|x_j).
        z:      (B, D)
        mu:     (B, D)
        logvar: (B, D)
        returns: (B, B, D) where [b, j, d] = log q(z_b[d] | x_j)
        """
        B, D = z.shape
        z_ = z.unsqueeze(1)         # (B, 1, D)
        mu_ = mu.unsqueeze(0)       # (1, B, D)
        lv_ = logvar.unsqueeze(0)   # (1, B, D)
        return -0.5 * (math.log(2 * math.pi) + lv_ + (z_ - mu_) ** 2 / lv_.exp())

    @staticmethod
    def _logsumexp(x, dim=-1):
        m, _ = torch.max(x, dim=dim, keepdim=True)
        return m + torch.log(torch.sum(torch.exp(x - m), dim=dim, keepdim=True))

    def _estimate_log_qz_terms(self, z, mu, logvar):
        """
        Estimate:
          log_qz      ~ log q(z)            (joint)
          log_prod_qzi~ sum_i log q(z_i)    (product of marginals)
          log_q_zCx   ~ log q(z|x)          (conditional, per-sample)
        Using MWS or MSS estimators as in Chen et al. (2018).
        """
        B, D = z.shape
        # log q(z|x): own-sample conditional (sum over dims)
        log_q_zCx = self._log_density_gaussian(z, mu, logvar).sum(dim=1)  # (B,)

        # pairwise per-dim log q(z_i | x_j)
        log_qzi_xj = self._matrix_log_density_gaussian(z, mu, logvar)     # (B, B, D)

        # weights
        if self.is_mss:
            # MSS: uniform over the batch (stratified)
            # log(1/B) weights for each j
            logiw = -math.log(B)
            logiw_vec = z.new_full((B,), logiw)
        else:
            # MWS: importance weights scaled by dataset size
            # log(1/N) for each j, but summed over j in minibatch => add log(B)
            # Equivalent to log(N) correction via logsumexp trick
            # Here we compute log(1/N) + logsumexp over j ⇒ subtract log(N) after logsumexp
            logiw_vec = z.new_full((B,), -math.log(self.n_data))

        # joint: log q(z) = log ∑_j q(z|x_j) * w_j
        # per-sample: for each b, sum over j then over dims
        # log ∑_j exp(∑_d log q(z_b[d] | x_j)) + log w_j
        log_qz_j = log_qzi_xj.sum(dim=2)  # (B, B) dim-sum first
        log_qz_weighted = log_qz_j + logiw_vec.unsqueeze(0)  # broadcast over j
        log_qz = self._logsumexp(log_qz_weighted, dim=1).squeeze(-1)  # (B,)

        # product of marginals: ∑_i log q(z_i) with q(z_i) = ∑_j q(z_i|x_j) * w_j
        # compute per-dim LSE, then sum over dims
        log_qzi_j = log_qzi_xj  # (B, B, D)
        log_qzi_weighted = log_qzi_j + logiw_vec.view(1, B, 1)
        log_qzi = self._logsumexp(log_qzi_weighted, dim=1).squeeze(1)  # (B, D)
        log_prod_qzi = log_qzi.sum(dim=1)  # (B,)

        return log_qz, log_prod_qzi, log_q_zCx

    def _reconstruction(self, recon, target):
        # mirror your handling of shapes/objectives
        if recon.shape != target.shape:
            recon = recon.squeeze(1)

        if self.recon_objective == 'mse':
            if self.mse_as_konny:
                recon_act = torch.sigmoid(recon)
                return F.mse_loss(recon_act, target, reduction='sum').div(target.shape[0])
            else:
                # use the reduction attribute if user set it, else default 'mean'
                reduction = getattr(self, 'reduction', 'mean')
                return F.mse_loss(recon, target, reduction=reduction)

        # BCE with logits (matches your previous style)
        reduction = getattr(self, 'reduction', 'mean')
        return F.binary_cross_entropy_with_logits(recon, target, reduction=reduction)

    def _anneal(self):
        if self.steps_anneal is None or self.steps_anneal <= 0:
            return 1.0
        # linear 0→1 over steps_anneal using self.global_iter
        t = min(max(self.global_iter, 0), self.steps_anneal)
        return float(t) / float(self.steps_anneal)

    # ------------------------ public API ------------------------
    def forward(self,
                input: tuple,           # (recon, mu, logvar)
                target: torch.Tensor,   # reconstruction target
                separate_loss: bool = False
                ) -> torch.Tensor:

        recon, mu, logvar = input

        # 1) reconstruction loss
        rec_loss = self._reconstruction(recon, target)

        # 2) sample z ~ q(z|x) using reparam trick
        eps = torch.randn_like(mu)
        z = mu + eps * (0.5 * logvar).exp()

        # 3) compute MI, TC, and dimension-wise KL pieces
        log_pz = (-0.5 * (math.log(2 * math.pi) + z ** 2)).sum(dim=1)  # standard normal prior
        log_qz, log_prod_qzi, log_q_zCx = self._estimate_log_qz_terms(z, mu, logvar)

        mi = (log_q_zCx - log_qz).mean()           # I[z;x]
        tc = (log_qz - log_prod_qzi).mean()        # TC[z]
        dwkl = (log_prod_qzi - log_pz).mean()      # ∑ KL[q(z_i)||p(z_i)]

        # 4) annealed gamma on dwkl (optional)
        anneal = self._anneal()

        reg = self.alpha * mi + self.beta * tc + anneal * self.gamma * kl_divergence(mu, logvar)[0]
        total = rec_loss + reg

        if separate_loss:
            return rec_loss, reg
        return total

### Testing

In [12]:
torch.manual_seed(0)  # btcvae_loss samples z internally; seed so the cell is reproducible
x = torch.rand((50,32))
bs = x.shape[0]; z_dim = 5
mu = torch.zeros((bs, z_dim))
logvar = torch.zeros_like(mu)+0.2
model_out = (x+torch.randn_like(x)*0.5, mu, logvar)
loss = beta_mse_loss(loss_objective='H', beta = 1)

# With alpha = beta = 0, gamma = 1 and the default analytic-KL gamma term,
# the TC loss reduces to the plain beta-VAE loss with beta = 1.
loss_tc = btcvae_loss(n_data = x.shape[0], 
                        alpha = 0,
                        beta = 0,
                        gamma = 1)

rec_bvae, kl_bvae = loss(model_out, x, separate_loss = True)
rec_tc, reg_tc = loss_tc(model_out, x, separate_loss = True)
# Compare each component with a tolerance instead of exact tuple-of-tensors equality
assert torch.allclose(rec_bvae, rec_tc)
assert torch.allclose(kl_bvae, reg_tc)

# Mutual Information Gap

In [None]:
#| export
def get_mig(z, inp, true_dimensions, bins = 20, return_entropy = False, normalized_MI = False):
    '''
    Mutual Information Gap (MIG) between latent codes `z` and ground-truth factors `inp`.

    Each latent dimension, and each of the first `true_dimensions` columns of `inp`, is
    discretised into `bins` equal-width histogram bins. The mutual information (or normalised MI
    if `normalized_MI`) is computed with scikit-learn for every (factor, latent) pair. For each
    factor, the gap between its two most informative latents is divided by the factor's entropy
    (its MI with itself), and the gaps are averaged.

    Parameters
    ----------
    z : (n_samples, z_dim) numpy array or torch tensor (detached and moved to CPU internally)
    inp : (n_samples, >= true_dimensions) numpy array or torch tensor
    true_dimensions : int, number of leading columns of `inp` treated as factors
    bins : int, number of histogram bins used for discretisation
    return_entropy : bool, also return the per-factor entropies
    normalized_MI : bool, use normalized_mutual_info_score instead of mutual_info_score

    Returns
    -------
    mi : (true_dimensions, z_dim) ndarray
    mig : float
    entropy : list of float (only if `return_entropy`)

    Raises
    ------
    ValueError : if `z` has fewer than two latent dimensions (the gap is undefined).
    '''
    # Import the submodule explicitly: a bare `import sklearn` does not necessarily load `sklearn.metrics`.
    from sklearn import metrics
    score = metrics.normalized_mutual_info_score if normalized_MI else metrics.mutual_info_score

    def _to_numpy(a):
        # Accept tensors on any device (and with grad) as well as array-likes
        if isinstance(a, torch.Tensor):
            return a.detach().cpu().numpy()
        return np.asarray(a)

    def _discretize(v):
        # Bin index of every value using `bins` equal-width bins
        return np.digitize(v, np.histogram(v, bins = bins)[1][:-1])

    z_np = _to_numpy(z)
    inp_np = _to_numpy(inp)
    z_dim = z_np.shape[-1]
    if z_dim < 2:
        raise ValueError(f'MIG needs at least 2 latent dimensions, got {z_dim}')

    # Discretise every factor and latent once, instead of once per (factor, latent) pair
    factors = [_discretize(inp_np[:, idx_c]) for idx_c in range(true_dimensions)]
    latents = [_discretize(z_np[:, idx_z]) for idx_z in range(z_dim)]

    mi = np.zeros((true_dimensions, z_dim))
    for idx_c, cj in enumerate(factors):
        for idx_z, zi_d in enumerate(latents):
            mi[idx_c, idx_z] = score(cj, zi_d)

    # Entropy of each factor, computed as its (normalised) MI with itself
    entropy = [score(cj, cj) for cj in factors]

    mi_s = np.sort(mi, axis = -1)
    mig = np.mean((mi_s[:,-1]-mi_s[:,-2])/np.array(entropy))

    if return_entropy:
        return mi, mig, entropy
    return mi, mig

# Callbacks

## Capacity ($C$) scheduler for the Burgess et al. $\beta$-VAE

In [None]:
#| export
class betaLoss_C_scheduler(Callback):
    '''
    Updates the C parameter of the Burgess et al. (2018) version of the betaVAE.

    After every *training* batch, increments `global_iter` on the learner's loss function.
    `beta_mse_loss` with loss_objective='B' reads that counter to grow the capacity C linearly.
    '''
    def after_batch(self):
        # fastai also runs `after_batch` on validation batches (Callback.run_valid defaults to True);
        # only optimisation steps should advance the capacity schedule.
        if not self.training:
            return
        self.learn.loss_func.global_iter += 1

## Saving logvars during training

In [None]:
#| export
class save_logvars(Callback):
    '''
    At the end of every epoch, record what the model's `E_a` module outputs for a fixed set of actions.

    `self.logvars` is allocated in `before_fit` with shape (n_epoch, num_actions, model.dim_z).
    Row `epoch` is filled at the end of that epoch.
    '''
    
    def __init__(self, actions = torch.tensor([[0,1],
                                               [1,0]], dtype=torch.float32)
                ):
        
        self.actions = actions
        self.num_actions = actions.shape[0]
     
    def before_fit(self):        
        self.logvars = torch.ones((self.learn.n_epoch, self.num_actions, self.learn.model.dim_z))        
    
    def after_epoch(self):
        # Send the actions to wherever the model lives instead of hard-coding 'cuda'.
        # Assumes the model has at least one parameter (TODO confirm for all model variants).
        device = next(self.learn.model.parameters()).device
        with torch.no_grad():
            out = self.learn.model.E_a(self.actions.to(device))
        # Store a detached CPU copy so the buffer does not keep an autograd graph alive
        self.logvars[self.learn.epoch] = out.detach().cpu()

In [None]:
#| export
class save_air_logvars(Callback):
    ''' 
    Save the logvars for a AIR model with logvars as biases (AIR_v0 type).

    Training batches are counted across the whole fit. At batches 0, save_each, 2*save_each, ...
    a CPU copy of `model.logvar` is written to the next row of `self.logvars` (shape (n_saves, z_dim)).
    '''
    def __init__(self, 
                 save_each = 500 # every how many iterations to save the logvars
                ):
        self.save_each = save_each  
        self.save_idx = 0

    def before_fit(self):
        # Exactly one row per snapshot taken at training batches 0, save_each, ... of the fit
        n_saves = math.ceil(len(self.learn.dls[0]) * self.learn.n_epoch / self.save_each)
        self.logvars = torch.zeros(n_saves, self.learn.model.z_dim)
        self.save_idx = 0
        # Own counter: `learn.iter` restarts every epoch and also counts validation batches
        self.n_train_batches = 0
        
    def after_batch(self):
        # fastai also fires `after_batch` during validation; only count training batches
        if not self.training:
            return
        if self.n_train_batches % self.save_each == 0:
            snapshot = self.learn.model.logvar.detach().cpu()
            if self.save_idx < len(self.logvars):
                self.logvars[self.save_idx] = snapshot
            else:
                # More training batches than planned (e.g. a longer fit): grow the buffer
                self.logvars = torch.cat((self.logvars, snapshot.unsqueeze(0)), dim = 0)
            self.save_idx += 1
        self.n_train_batches += 1
            
    

In [None]:
#| export         
class save_bvae_logvars(Callback):
    
    '''
    Save the logvars for a typical bvae model.
    This implies performing a forward pass of the model with a random subset of the dataset.

    Training batches are counted across the whole fit. At batches 0, save_each, 2*save_each, ...
    `size_batch_logvars` random items of `dls.dataset` are passed through `model.forward` in eval
    mode without gradients. The third output (logvars) is stored in `self.logvars`, shape
    (n_saves, size_batch_logvars, z_dim).
    ''' 
    
    def __init__(self,
                 size_batch_logvars = 1000, # size of the training set that will be used for predicting the logvars
                 save_each = 500 # every how many iterations to save the logvars
                ):
        self.size_batch_logvars = size_batch_logvars
        self.save_each = save_each
        self.save_idx = 0

    def before_fit(self):
        # Exactly one slot per snapshot taken at training batches 0, save_each, ... of the fit
        n_saves = math.ceil(len(self.learn.dls[0]) * self.learn.n_epoch / self.save_each)
        self.logvars = torch.zeros(n_saves, 
                                   self.size_batch_logvars, 
                                   self.learn.model.z_dim)
        self.save_idx = 0
        # Own counter: `learn.iter` restarts every epoch and also counts validation batches
        self.n_train_batches = 0
        
    def after_batch(self):        
        # fastai also fires `after_batch` during validation; only count training batches
        if not self.training:
            return
        if self.n_train_batches % self.save_each == 0:
            dataset = self.learn.dls.dataset
            idx = torch.randint(0, len(dataset), size = (self.size_batch_logvars,))
            # NOTE(review): assumes dataset items are already on the model's device — confirm
            batch = dataset[idx][0]
            self.learn.model.eval()
            try:
                # `forward` is called directly (as before) so module hooks are not triggered
                with torch.no_grad():
                    _, _, logvars = self.learn.model.forward(batch)
            finally:
                # Always restore training mode, even if the forward pass fails
                self.learn.model.train()

            snapshot = logvars.detach().cpu()
            if self.save_idx < len(self.logvars):
                self.logvars[self.save_idx] = snapshot
            else:
                # More training batches than planned: grow the buffer
                self.logvars = torch.cat((self.logvars, snapshot.unsqueeze(0)), dim = 0)
            self.save_idx += 1
        self.n_train_batches += 1
            

## Save separate losses during training

In [None]:
#| export
class save_separate_losses(Callback):
    '''
    After every epoch, evaluate the learner's loss function on `loader_test` with `separate_loss=True`.

    For each of the two returned components, appends the mean across test batches to a list:
    `self.mse_loss` for the reconstruction term and `self.kl_loss` for the regularisation term.
    A loader that yields no batches records nan.
    '''
    
    def __init__(self, loader_test):
        
        self.loader_test = loader_test
        self.mse_loss = []
        self.kl_loss = []

    def after_epoch(self):
        rec_terms, reg_terms = [], []
        # Pure evaluation: no autograd graph is needed, so don't build one
        with torch.no_grad():
            for batch in self.loader_test:
                pred = self.learn.model(batch[0])
                rec, reg = self.learn.loss_func.forward(input = pred, target = batch[1], separate_loss = True)
                # .mean() is a no-op for the usual scalar / (1,)-shaped losses
                rec_terms.append(rec.detach().cpu().mean())
                reg_terms.append(reg.detach().cpu().mean())

        self.mse_loss.append(self._batch_mean(rec_terms))
        self.kl_loss.append(self._batch_mean(reg_terms))

    @staticmethod
    def _batch_mean(values):
        # Mean over per-batch scalars; nan when the loader produced no batches
        return float(torch.stack(values).mean()) if values else float('nan')