# Soft Decision Tree Loss

The soft decision tree is trained with a loss that minimizes the cross entropy between each leaf's distribution, weighted by that leaf's path probability, and the target distribution. For a single training case with input vector $x$ and target distribution $T$, the loss is:

![loss](images/loss.png)
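Below is a minimal PyTorch sketch of this data term. It assumes the tree outputs leaf distributions of shape `(batch, n_leaves, n_classes)` and path probabilities of shape `(batch, n_leaves)`, which is the layout `SDTLoss` below expects. The helper name `sdt_data_loss` is hypothetical. When there is a single leaf reached with probability 1, the loss reduces to ordinary cross entropy, and that gives a quick sanity check.

```python
import torch
import torch.nn.functional as F

def sdt_data_loss(leaf_probs, path_prob, target):
    """Hypothetical helper: path-probability-weighted cross entropy.

    leaf_probs: (batch, n_leaves, n_classes), each leaf's distribution Q^l
    path_prob:  (batch, n_leaves), probability P^l(x) of reaching each leaf
    target:     (batch,) integer class labels
    """
    n_classes = leaf_probs.shape[-1]
    # one-hot target distribution T: (batch, n_classes)
    t = F.one_hot(target, n_classes).to(leaf_probs.dtype)
    # sum_k T_k log Q^l_k for every leaf: (batch, n_leaves)
    log_lik = (torch.log(leaf_probs) * t.unsqueeze(1)).sum(dim=-1)
    # weight each leaf by its path probability, then average over the batch
    return -(path_prob * log_lik).sum(dim=1).mean()

# sanity check: one leaf reached with probability 1 is plain cross entropy
torch.manual_seed(0)
logits = torch.randn(4, 1, 3)
target = torch.tensor([0, 2, 1, 2])
leaf_probs = logits.softmax(dim=-1)
path_prob = torch.ones(4, 1)
loss = sdt_data_loss(leaf_probs, path_prob, target)
ref = F.cross_entropy(logits.squeeze(1), target)
print(torch.allclose(loss, ref))  # True
```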

## Regularizer

To avoid getting stuck at poor solutions during training, the authors introduced a
penalty term that encourages each internal node to make equal use of its
left and right sub-trees. Without this penalty, the tree tended to get stuck on
plateaus in which one or more internal nodes always assigned almost all
the probability to one of their sub-trees, so the gradient of the logistic for that
decision was always very close to zero. The penalty is the cross entropy between
the desired average distribution $(0.5, 0.5)$ for the two sub-trees and the actual
average distribution $(\alpha_i, 1 - \alpha_i)$, where $\alpha_i$ for node $i$ is given by

![regularizer](images/regularizer.png)
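Here is a small sketch of the penalty for one internal node. It assumes, going by the `self.numers.sum(dim=0) / self.denoms.sum(dim=0)` line in `SDTLoss` below, that $\alpha_i$ is estimated over a mini-batch as $\sum_x P^i(x)\,p_i(x) / \sum_x P^i(x)$. In that expression, $P^i(x)$ is the probability of reaching node $i$, and $p_i(x)$ is the probability that the node routes $x$ to one of its sub-trees. Which sub-tree counts as "one" is an assumption about the model code. The helper name `node_penalty` is hypothetical. The penalty reaches its minimum, $\lambda \log 2$, when the node splits its probability mass evenly, and it grows as $\alpha_i$ approaches 0 or 1.

```python
import math
import torch

def node_penalty(reach_prob, split_prob, lam):
    """Hypothetical helper: regularizer for a single internal node i.

    reach_prob: (batch,) path probability P^i(x) of reaching node i
    split_prob: (batch,) probability p_i(x) that node i routes x to one sub-tree
    lam:        penalty strength lambda for this node
    """
    # batch estimate of the average routing probability alpha_i
    alpha = (reach_prob * split_prob).sum() / reach_prob.sum()
    # cross entropy between (0.5, 0.5) and (alpha, 1 - alpha), scaled by lambda
    return -lam * (0.5 * torch.log(alpha) + 0.5 * torch.log(1 - alpha))

reach = torch.tensor([1.0, 0.5, 0.25, 0.25])
balanced = node_penalty(reach, torch.full((4,), 0.5), lam=0.1)
skewed = node_penalty(reach, torch.full((4,), 0.95), lam=0.1)
print(balanced.item(), 0.1 * math.log(2))  # both about 0.0693, the minimum
print(skewed.item() > balanced.item())     # True: imbalance is penalized
```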

In [None]:
# default_exp loss

In [None]:
#hide
IN_COLAB = 'google.colab' in str(get_ipython())
if IN_COLAB:
  !pip3 install -Uqq fastbook

In [None]:
#hide
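if IN_COLAB:
  from pathlib import Path
  from google.colab import drive
  from nbdev.imports import Config
  # the project lives on Google Drive, so mount it before changing directory
  drive.mount('/content/drive')
  project_path = Path('/content/drive/My Drive/Colab Notebooks/github/sdt')
  get_ipython().run_line_magic('cd', str(project_path))
  # then move into the notebooks folder configured for nbdev
  get_ipython().run_line_magic('cd', str(Config().nbs_path))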

/content/drive/My Drive/Colab Notebooks/github/sdt
/content/drive/My Drive/Colab Notebooks/github/sdt


In [None]:
#hide
if IN_COLAB:
  from google.colab import drive
  drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#export
from fastai.vision.all import *

## Loss

In [None]:
#export
class SDTLoss(Module):
  def __init__(self, lambda_):
    super().__init__()
    
    self.lambda_ = lambda_
  
  # path probabilities P^l(x) of reaching each leaf, shape (batch, n_leaves)
  def set_path_prob(self, path_prob): self.path_prob = path_prob
  # numerator and denominator terms of alpha_i, summed over the batch in forward
  def set_regularizer(self, numers, denoms): self.numers, self.denoms = numers, denoms
  
  def forward(self, output, target):
    # output: (batch, n_leaves, n_classes) leaf distributions Q^l
    # target: (batch,) integer class labels
    # one-hot encode the targets on the same device and dtype as the output
    target_ohe = output.new_zeros((output.shape[0], output.shape[-1]))
    target_ohe.scatter_(1, target.view(-1, 1).long(), 1)
    target_ohe = target_ohe.unsqueeze(dim=2)  # (batch, n_classes, 1)

    # sum_k T_k log Q^l_k for every leaf -> (batch, n_leaves)
    log_output = torch.log(output)
    res = torch.bmm(log_output, target_ohe).squeeze(dim=2)

    # weight each leaf's log-likelihood by its path probability P^l(x)
    res = (self.path_prob * res).sum(dim=1)

    # regularizer: per-node average routing probability alpha_i
    alphas = self.numers.sum(dim=0) / self.denoms.sum(dim=0)

    # lambda decays as 2^(-d) with node depth d (nodes in breadth-first order)
    lambdas_ = torch.ones_like(alphas)
    n_levels = (len(lambdas_) + 1).bit_length() - 1
    for d in range(n_levels):
      start_index = 2 ** d - 1
      lambdas_[start_index:start_index + 2 ** d] = 2 ** (-d)
    
    lambdas_ = self.lambda_ * lambdas_
    
    # cross entropy between (0.5, 0.5) and (alpha_i, 1 - alpha_i)
    reg_inner_term = (0.5 * torch.log(alphas) + 0.5 * torch.log(1 - alphas)) * lambdas_
    C = -reg_inner_term.sum()
    
    return -res.mean() + C
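The loop in `forward` that fills `lambdas_` gives each internal node a penalty weight of $\lambda \cdot 2^{-d}$, where $d$ is the node's depth. This assumes nodes are stored in breadth-first order, with the root at index 0 and its children at 1 and 2, so that depth $d$ occupies indices $2^d - 1$ through $2^{d+1} - 2$. The standalone sketch below reproduces that schedule. `depth_lambdas` is a hypothetical name, and the tree is assumed to be complete, with $2^D - 1$ internal nodes.

```python
import torch

def depth_lambdas(n_internal, lam):
    """Hypothetical helper: per-node penalty weights lam * 2**(-depth).

    Assumes a complete binary tree whose n_internal = 2**D - 1 internal
    nodes are stored in breadth-first order.
    """
    weights = torch.ones(n_internal)
    n_levels = (n_internal + 1).bit_length() - 1  # D = log2(n_internal + 1)
    for d in range(n_levels):
        start = 2 ** d - 1                      # first index at depth d
        weights[start:start + 2 ** d] = 2.0 ** (-d)
    return lam * weights

print(depth_lambdas(7, lam=1.0))
# tensor([1.0000, 0.5000, 0.5000, 0.2500, 0.2500, 0.2500, 0.2500])
```

Shallow nodes therefore carry the strongest push toward balanced splits. Deeper nodes see fewer examples, so their $\alpha_i$ estimates are noisier.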

## Export

In [None]:
#hide
from nbdev.export import notebook2script
notebook2script()

Converted 00_data.ipynb.
Converted 01_model.ipynb.
Converted 02_loss.ipynb.
Converted 03_train.ipynb.
Converted index.ipynb.
