In [1]:
%matplotlib inline

import os
import scanpy as sc
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F



following implementation discussed here: <https://towardsdatascience.com/implementing-tabnet-in-pytorch-fc977c383279>

In [20]:
# following implementation discussed here: <https://towardsdatascience.com/implementing-tabnet-in-pytorch-fc977c383279>

# ghost batch norm
class GhostBatchNorm(nn.Module):
    def __init__(self, input_dim, vbs=64, momentum=0.01):
        '''
        Arguments:
          vbs (int): (optional, Default=128) virtual batch size. Must be
            smaller than n_samples in mini-batch
          
        '''
        super().__init__()
        self.BN = nn.BatchNorm1d(input_dim, momentum=momentum)
        self.vbs = vbs
        
    def forward(self, x):
        chunk = torch.chunk(x, x.size(0) // self.vbs, 0)
        res = [self.BN(x_sub) for x_sub in chunk]
        return torch.cat(res,0)
        


implementation of SparseMax
- REF: https://github.com/aced125/sparsemax

In [12]:
# implementation of SparseMax
# REF: https://github.com/aced125/sparsemax

def flatten_all_but_nth_dim(ctx, x: torch.Tensor):
    """
    Flattens tensor in all but 1 chosen dimension.
    Saves necessary context for backward pass and unflattening.
    """

    # transpose batch and nth dim
    x = x.transpose(0, ctx.dim)

    # Get and save original size in context for backward pass
    original_size = x.size()
    ctx.original_size = original_size

    # Flatten all dimensions except nth dim
    x = x.reshape(x.size(0), -1)

    # Transpose flattened dimensions to 0th dim, nth dim to last dim
    return ctx, x.transpose(0, -1)


def unflatten_all_but_nth_dim(ctx, x: torch.Tensor):
    """
    Unflattens tensor using necessary context
    """
    # Tranpose flattened dim to last dim, nth dim to 0th dim
    x = x.transpose(0, 1)

    # Reshape to original size
    x = x.reshape(ctx.original_size)

    # Swap batch dim and nth dim
    return ctx, x.transpose(0, ctx.dim)

class Sparsemax(nn.Module):
    __constants__ = ["dim"]

    def __init__(self, dim=-1):
        """
        Sparsemax class as seen in https://arxiv.org/pdf/1602.02068.pdf
        Parameters
        ----------
        dim: The dimension we want to cast the operation over. Default -1
        """
        super(Sparsemax, self).__init__()
        self.dim = dim

    def __setstate__(self, state):
        self.__dict__.update(state)
        if not hasattr(self, "dim"):
            self.dim = None

    def forward(self, input):
        return SparsemaxFunction.apply(input, self.dim)

    def extra_repr(self):
        return f"dim={self.dim}"


class SparsemaxFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: torch.Tensor, dim: int = -1):
        input_dim = input.dim()
        if input_dim <= dim or dim < -input_dim:
            raise IndexError(
                f"Dimension out of range (expected to be in range of [-{input_dim}, {input_dim - 1}], but got {dim})"
            )

        # Save operating dimension to context
        ctx.needs_reshaping = input_dim > 2
        ctx.dim = dim

        if ctx.needs_reshaping:
            ctx, input = flatten_all_but_nth_dim(ctx, input)

        # Translate by max for numerical stability
        input = input - input.max(-1, keepdim=True).values.expand_as(input)

        zs = input.sort(-1, descending=True).values
        range = torch.arange(1, input.size()[-1] + 1)
        range = range.expand_as(input).to(input)

        # Determine sparsity of projection
        bound = 1 + range * zs
        is_gt = bound.gt(zs.cumsum(-1)).type(input.dtype)
        k = (is_gt * range).max(-1, keepdim=True).values

        # Compute threshold
        zs_sparse = is_gt * zs

        # Compute taus
        taus = (zs_sparse.sum(-1, keepdim=True) - 1) / k
        taus = taus.expand_as(input)

        output = torch.max(torch.zeros_like(input), input - taus)

        # Save context
        ctx.save_for_backward(output)

        # Reshape back to original shape
        if ctx.needs_reshaping:
            ctx, output = unflatten_all_but_nth_dim(ctx, output)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        output, *_ = ctx.saved_tensors

        # Reshape if needed
        if ctx.needs_reshaping:
            ctx, grad_output = flatten_all_but_nth_dim(ctx, grad_output)

        # Compute gradient
        nonzeros = torch.ne(output, 0)
        num_nonzeros = nonzeros.sum(-1, keepdim=True)
        sum = (grad_output * nonzeros).sum(-1, keepdim=True) / num_nonzeros
        grad_input = nonzeros * (grad_output - sum.expand_as(grad_output))

        # Reshape back to original shape
        if ctx.needs_reshaping:
            ctx, grad_input = unflatten_all_but_nth_dim(ctx, grad_input)

        return grad_input, None

In [27]:
# (mask*torch.log(mask+1e-10)).mean() #F(x)= -∑xlog(x+eps)


In [69]:
class AttnTransformer(nn.Module):
    def __init__(self, d_a, inp_dim, relax, vbs=64):
        super().__init__()
        self.fc = nn.Linear(d_a, inp_dim)
        self.BN = GhostBatchNorm(inp_dim, vbs=vbs) # instead of inp_idm, out_dim? otherwise, error in the medium post
        self.sparsemax = Sparsemax()
        self.gamma_r = relax
        
    # a := feature from previous decision step
    def forward(self, a, priors): 
        a = self.BN(self.fc(a)) 
        mask = self.sparsemax(a*priors) 
        priors = priors*(self.gamma_r - mask)  #updating the prior
        return mask

In [61]:
class GLU(nn.Module):
    def __init__(self, inp_dim, out_dim, fc=None, vbs=64):
        super().__init__()
        if fc:
            self.fc = fc
        else:
            self.fc = nn.Linear(inp_dim, out_dim*2)
        self.BN = GhostBatchNorm(out_dim*2, vbs=vbs) 
        self.od = out_dim
        
    def forward(self, x):
        x = self.BN(self.fc(x))
        return x[:, :self.od]*torch.sigmoid(x[:, self.od:])
    
class FeatTransformer(nn.Module):
    def __init__(self, inp_dim, out_dim, shared, n_ind, vbs=64):
        super().__init__()
        first = True
        self.shared = nn.ModuleList()
        if shared:
            self.shared.append(GLU(inp_dim, out_dim, shared[0], vbs=vbs))
            first= False    
            for fc in shared[1:]:
                self.shared.append(GLU(out_dim, out_dim, fc, vbs=vbs))
        else:
            self.shared = None
        self.independ = nn.ModuleList()
        if first:
            self.independ.append(GLU(inp, out_dim, vbs=vbs))
        for x in range(first, n_ind):
            self.independ.append(GLU(out_dim, out_dim, vbs=vbs))
        device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.scale = torch.sqrt(torch.tensor([.5], device=device))
        
    def forward(self, x):
        if self.shared:
            x = self.shared[0](x)
            for glu in self.shared[1:]:
                x = torch.add(x, glu(x))
                x = x*self.scale
        for glu in self.independ:
            x = torch.add(x, glu(x))
            x = x*self.scale
        return x

In [62]:
class DecisionStep(nn.Module):
    def __init__(self, inp_dim, n_d, n_a, shared, n_ind, relax, vbs=64):
        super().__init__()
        self.feat_transformer = FeatTransformer(inp_dim, n_d + n_a, shared, n_ind, vbs)
        self.attn_transformer =  AttnTransformer(n_a, inp_dim, relax, vbs)
        
    def forward(self, x, a, priors):
        mask = self.attn_transformer(a, priors)
        sparse_loss = ((-1) * mask * torch.log(mask + 1e-10)).mean()
        x = self.feat_transformer(x * mask)
        return x, sparse_loss

In [63]:
class TabNet(nn.Module):
    def __init__(self, inp_dim, final_out_dim,
                 n_d=64, n_a=64,
                 n_shared=2, n_ind=2,
                 n_steps=5, relax=1.2, vbs=64):
        super().__init__()
        
        if n_shared>0:
            self.shared = nn.ModuleList()
            self.shared.append(nn.Linear(inp_dim, 2*(n_d + n_a)))
            
            for x in range(n_shared-1):
                self.shared.append(nn.Linear(n_d + n_a, 2*(n_d + n_a)))
        else:
            self.shared = None
        self.first_step = FeatTransformer(inp_dim, n_d+n_a, self.shared, n_ind) 
        self.steps = nn.ModuleList()
        
        for x in range(n_steps - 1):
            self.steps.append(DecisionStep(inp_dim, n_d, n_a, self.shared, n_ind, relax, vbs))
        
        self.fc = nn.Linear(n_d, final_out_dim)
        self.bn = nn.BatchNorm1d(inp_dim)
        self.n_d = n_d
        
    def forward(self, x):
        x = self.bn(x)
        x_a = self.first_step(x)[:,self.n_d:]
        sparse_loss = torch.zeros(1).to(x.device)
        out = torch.zeros(x.size(0),self.n_d).to(x.device)
        priors = torch.ones(x.shape).to(x.device)
        for step in self.steps:
            x_te, l = step(x,x_a,priors)
            out += F.relu(x_te[:,:self.n_d])
            x_a = x_te[:,self.n_d:]
            sparse_loss += l
        return self.fc(out), sparse_loss

In [70]:
model = TabNet(3, 2).to(torch.device('cuda'))
model(c.to(torch.device('cuda'))) 

(tensor([[ 4.3863e-01, -5.7510e-01],
         [-3.5529e-01, -5.2496e-01],
         [-4.1787e-01, -5.9366e-01],
         [ 2.0010e-01, -9.9658e-01],
         [ 6.0293e-01, -6.2393e-01],
         [-1.9921e-01, -2.7227e-01],
         [-1.4657e-01, -1.3763e-01],
         [ 1.0933e-01, -4.4406e-01],
         [ 5.9519e-01, -1.4738e+00],
         [ 4.8816e-01, -8.5086e-01],
         [-6.8618e-02, -6.8230e-01],
         [ 7.9431e-01, -1.6560e+00],
         [-5.2185e-01, -7.0093e-01],
         [-5.5549e-01, -3.9888e-01],
         [-3.0898e-01, -3.9262e-01],
         [-5.7943e-01, -4.5116e-01],
         [-2.3193e-01, -1.1476e+00],
         [ 5.2778e-01, -7.2030e-01],
         [ 2.2045e-01, -4.8974e-01],
         [ 5.2555e-02, -2.8357e-01],
         [-4.4630e-01, -3.1710e-01],
         [-1.5227e-01, -2.6660e-01],
         [-1.7841e-02, -4.4746e-01],
         [-2.3321e-01, -9.1761e-01],
         [ 4.4332e-01, -9.1119e-01],
         [-9.2041e-01, -1.0135e+00],
         [ 2.2039e-01, -7.0681e-01],
 

In [71]:
# # to train:
# critrion = nn.BCELoss()
# model.train()
# optimizer = optim.Adam(model.parameters(),lr=0.007809719000164987,weight_decay=0.00001)
# # sched = optim.lr_scheduler.ReduceLROnPlateau(optimizer,factor=0.1,patience=3,verbose=True)
# for i, tar in train_loader:
#     out, l = model(...)
#     optimizer.zero_grad()
#     loss = critrion(out, tar.to(device)) + l*sparse_constant
#     loss.backward()
#     optimizer.step()
#     # sched.step(losses[-1]) # if on val set 
