In [3]:
%load_ext autoreload
%autoreload 2
#default_exp core

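This is a `sparsemax` implementation taken from the [`pytorch_tabnet` package](https://github.com/dreamquark-ai/tabnet/blob/5df2dd19d9fe10e38f4ba13a4c2d6f9707e3f65d/pytorch_tabnet/sparsemax.py) (dreamquark-ai/tabnet). Sparsemax was introduced by Martins & Astudillo (2016) in [*From Softmax to Sparsemax*](https://arxiv.org/abs/1602.02068v2).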

In [20]:
#exporti
import torch
from torch import nn
from fastai.tabular.all import * 
from torch.autograd import Function

In [21]:
#exporti
def _make_ix_like(input, dim=0):
    d = input.size(dim)
    rho = torch.arange(1, d + 1, device=input.device, dtype=input.dtype)
    view = [1] * input.dim()
    view[0] = -1
    return rho.view(view).transpose(0, dim)


In [22]:
#exporti
class SparsemaxFunction(Function):
    """
    An implementation of sparsemax (Martins & Astudillo, 2016). See
    :cite:`DBLP:journals/corr/MartinsA16` for detailed description.
    By Ben Peters and Vlad Niculae
    """

    @staticmethod
    def forward(ctx, input, dim=-1):
        """sparsemax: normalizing sparse transform (a la softmax)
        Parameters:
            input (Tensor): any shape; it is NOT modified
            dim: dimension along which to apply sparsemax
        Returns:
            output (Tensor): same shape as input, non-negative; entries below
            the per-slice threshold are exactly 0
        """
        ctx.dim = dim
        max_val, _ = input.max(dim=dim, keepdim=True)
        # Same numerical stability trick as for softmax. Done out of place:
        # an in-place `-=` would silently mutate the caller's tensor.
        shifted = input - max_val
        tau, supp_size = SparsemaxFunction._threshold_and_support(shifted, dim=dim)
        output = torch.clamp(shifted - tau, min=0)
        ctx.save_for_backward(supp_size, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Gradient of sparsemax with respect to its input.

        On the support (output != 0) the gradient is grad_output minus its
        mean over the support; off the support it is 0. `dim` gets no gradient.
        """
        supp_size, output = ctx.saved_tensors
        dim = ctx.dim
        grad_input = grad_output.clone()
        grad_input[output == 0] = 0

        # supp_size already has size 1 along `dim`, so keepdim=True lines the
        # two tensors up exactly. A bare `.squeeze()` would also drop unrelated
        # size-1 axes (e.g. a (B, 1, N) input) and broadcast to a wrong shape.
        v_hat = grad_input.sum(dim=dim, keepdim=True) / supp_size.to(output.dtype)
        grad_input = torch.where(output != 0, grad_input - v_hat, grad_input)
        return grad_input, None

    @staticmethod
    def _threshold_and_support(input, dim=-1):
        """Sparsemax building block: compute the threshold and support size
        Args:
            input: any dimension
            dim: dimension along which to apply the sparsemax
        Returns:
            tau: threshold, same dtype as `input` and size 1 along `dim`;
                sparsemax(input) == clamp(input - tau, min=0)
            support_size: integer tensor with size 1 along `dim` holding the
                number of entries in the support of each slice
        """

        input_srt, _ = torch.sort(input, descending=True, dim=dim)
        input_cumsum = input_srt.cumsum(dim) - 1
        rhos = _make_ix_like(input, dim)
        # The k-th largest entry z_(k) is in the support iff
        # k * z_(k) > (z_(1) + ... + z_(k)) - 1.
        support = rhos * input_srt > input_cumsum

        support_size = support.sum(dim=dim).unsqueeze(dim)
        # tau = (sum of the top-k entries - 1) / k, where k is the support size.
        tau = input_cumsum.gather(dim, support_size - 1)
        tau /= support_size.to(input.dtype)
        return tau, support_size



In [23]:
#export
sparsemax = SparsemaxFunction.apply


class Sparsemax(Module):
    """Layer form of `sparsemax`: a sparse, softmax-like normalization.

    Args:
        dim: dimension along which sparsemax is applied (default: the last one).
    """

    def __init__(self, dim=-1):
        # No explicit super().__init__(): fastai's `Module` base takes care of
        # the nn.Module initialisation.
        self.dim = dim

    def forward(self, input):
        axis = self.dim
        return SparsemaxFunction.apply(input, axis)


In [24]:
torch.manual_seed(0)  # fixed seed so re-running the notebook reproduces this example
a = torch.randn(5); print(a)
sparsemax(a)

tensor([ 0.4707, -2.2783, -0.2929,  0.4337, -0.6759])


tensor([0.5185, 0.0000, 0.0000, 0.4815, 0.0000])

In [25]:
# One clearly dominant entry: sparsemax puts all of the mass on it.
a = torch.tensor([-100.0, 2.0, 3.0, 4.0])
sparsemax(a)

tensor([0., 0., 0., 1.])

In [31]:
#export
class GBN(Module):
    """Ghost Batch Normalization (https://arxiv.org/abs/1705.08741).

    The incoming batch is split along dim 0 into ceil(batch_size /
    virtual_batch_size) chunks with `Tensor.chunk`, so every chunk has at most
    `virtual_batch_size` rows. Each chunk goes through the same shared
    `BatchNorm` layer, and the results are concatenated back in order.

    Args:
        input_dim: number of features, passed to `BatchNorm`.
        virtual_batch_size: maximum number of rows per ghost batch.
        momentum: momentum passed to `BatchNorm`.
    """

    def __init__(self, input_dim, virtual_batch_size=128, momentum=0.02):
        self.input_dim = input_dim
        self.virtual_batch_size = virtual_batch_size
        self.bn = BatchNorm(self.input_dim, momentum=momentum, ndim=1)

    def forward(self, x):
        # Integer ceil-division. Clamp to at least one chunk so that an empty
        # batch is handed to the norm layer instead of hitting `chunk(0)`,
        # which raises.
        n_chunks = max(1, int(-(-x.shape[0] // self.virtual_batch_size)))
        chunks = x.chunk(n_chunks, 0)
        return torch.cat([self.bn(x_) for x_ in chunks], dim=0)


# Export

In [32]:
# Regenerate the library modules from the exported cells of every notebook in the project (nbdev).
from nbdev.export import notebook2script
notebook2script()

Converted 00_review_prev_work.ipynb.
Converted 01_core.ipynb.
Converted 01_model.ipynb.
Converted 03_experiments.ipynb.
Converted index.ipynb.
