In [67]:
%load_ext autoreload
%autoreload 2 
%nbdev_default_export tabnet

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Cells will be exported to nbdev_template.tabnet,
unless a different module is specified after an export flag: `%nbdev_export special.module`


In [2]:
#export
from fastai.tabular.all import *

# Creating the model 

The TabNet model consists of the following parts (which we'll implement):
1. Attentive Transformer - aka the `MaskBlock`, which creates the feature mask from the previous step's output.
1. Feature Transformer - which takes the masked features and processes them into a better representation. Internally it's made up of multiple `FeatureTransformerBlock`s.

These two are the building blocks of `TabNetStep`, which represents a single decision step of a `TabNet`.

Todo: 
1. Use `ModuleDict` 
1. Use GLU
1. Scale the residual sums by $\sqrt{0.5}$
1. save masks for interpretation

The mask block creates a mask over the `n_features` input features from the vector it receives from the previous step (of size `n_transformed_features`).

In [95]:
#exporti 

class MaskBlock(Module):
    "Attentive transformer: turns the previous step's output into a mask over the `n_features` inputs."

    def __init__(self, n_features, n_transformed_features, max_func=nn.Softmax(dim=1), gamma=1):
        store_attr('gamma,max_func')
        # Projects the previous step's `n_transformed_features` vector to one logit per input feature.
        self.fc = LinBnDrop(n_transformed_features, n_features, lin_first=True)
        # Default prior scales: all ones (no feature used yet). This is never overwritten in `forward`,
        # so no state or autograd graph leaks from one batch into the next.
        self.prior = torch.ones(n_features, dtype=torch.float)

    def forward(self, prev, prior=None):
        "Mask for `prev`. `prior` must broadcast against the `fc` output; it defaults to `self.prior` (all ones)."
        # `self.prior` is a plain tensor attribute, so it does not follow the module across devices.
        if prior is None: prior = self.prior.to(prev.device)
        return self._calc_mask(prev, prior)

    def _calc_mask(self, prev, prior):
        "Apply `max_func` to the prior-scaled logits."
        return self.max_func(self.fc(prev) * prior)

    def _calc_prior_scales(self, mask, prior=None):
        "Prior for a following step: `(gamma - mask) * prior`, shrinking the scale of features `mask` weighted heavily."
        if prior is None: prior = self.prior.to(mask.device)
        return (self.gamma - mask) * prior

In [None]:
n_samples, n_features, n_transformed_features = 3, 32, 8
mask_block = MaskBlock(n_features, n_transformed_features)
prev_step_out = torch.randn((n_samples, n_transformed_features), dtype=torch.float)
# The default softmax runs over dim 1, so every row of the mask should sum to one.
test_close(mask_block(prev_step_out).sum(dim=1), 1)

The `FeatureTransformerBlock` ends with a GLU, which halves the size of its input. To get an output of size `n_transformed_features`, the linear layers therefore output twice that size.

In [84]:
#exporti 

class FeatureTransformerBlock(Module):
    "Two `LinBnDrop -> GLU` stages joined by a residual sum; emits `n_transformed_features` features per row."

    def __init__(self, n_features, n_transformed_features, add_input=False):
        store_attr('n_features,n_transformed_features,add_input')
        # Each linear layer emits double the target width because the trailing GLU halves it.
        width = 2*self.n_transformed_features
        self.b1 = self._inner_block(self.n_features, width)
        self.b2 = self._inner_block(self.n_transformed_features, width)

    def forward(self, inp):
        # Optional residual from the block input (its width must match or broadcast).
        first = self.b1(inp) + inp if self.add_input else self.b1(inp)
        return first + self.b2(first)

    def _inner_block(self, n_in, n_out):
        "`LinBnDrop(n_in, n_out)` with the linear layer first, then a GLU that halves `n_out`."
        return nn.Sequential(LinBnDrop(n_in, n_out, lin_first=True), nn.GLU())

In [85]:
ftb_input = torch.randn((n_samples, n_features), dtype=torch.float)
ftb = FeatureTransformerBlock(n_features, n_transformed_features)
# The GLU halves the doubled linear width, so each row comes out with `n_transformed_features` values.
test_eq(ftb(ftb_input).shape, (n_samples, n_transformed_features))

In [86]:
#exporti 

class FeatureTransformer(nn.Sequential):
    "Chain of `n_trans_blocks` `FeatureTransformerBlock`s mapping `n_features` to `n_transformed_features`."

    def __init__(self, n_features, n_transformed_features, n_trans_blocks=2, shared_block=None):
        # The first block (`n_features` -> `n_transformed_features`) can be passed in so that several
        # transformers reuse the same weights; otherwise this transformer builds its own.
        first = shared_block if shared_block is not None else self._create_block(n_features, n_transformed_features)
        # Remaining blocks keep the width constant and add their input back (residual).
        rest = [self._create_block(n_transformed_features, n_transformed_features, add_input=True)
                for _ in range(n_trans_blocks-1)]
        super().__init__(first, *rest)

    @property
    def shared_block(self):
        "The first (shareable) block. It is exposed read-only so it is not registered, and run, a second time."
        return self[0]

    def _create_block(self, n_features, n_transformed_features, **kwargs):
        return FeatureTransformerBlock(n_features, n_transformed_features, **kwargs)

In [87]:
trans_input = torch.randn((n_samples, n_features), dtype=torch.float)
feat_trans = FeatureTransformer(n_features, n_transformed_features, n_trans_blocks=2)
# Stacking blocks keeps the output width at `n_transformed_features`.
test_eq(feat_trans(trans_input).shape, (n_samples, n_transformed_features))

In [88]:
#exporti 

def _split(x):
    N = x.shape[-1]//2
    return x[:,:N], x[:,N:]

In [89]:
# Wrap `_split` in a fastai `Lambda` module so it can sit inside an `nn.Sequential`.
Split = Lambda(_split)
halves = Split(torch.randn((3, 4)))
test_eq(len(halves), 2)
test_eq(halves[0].shape, (3, 2))

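Notice that the `MaskBlock`'s input from the previous step has size `n_transformed_features//2`: each step splits its transformed features in half, activating one half as its output and passing the other half on as the next step's input.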

In [90]:
#exporti 

class TabNetStep(Module):
    "One decision step: mask the features, transform them, then split the result into `(output, next_input)`."

    @delegates(FeatureTransformer)
    def __init__(self, n_features, n_transformed_features, activ=nn.ReLU(), **kwargs):
        store_attr('activ')
        # The mask block reads the half of the previous transform that is handed on between steps.
        self.mask = MaskBlock(n_features, n_transformed_features//2)
        self.feat_trans = FeatureTransformer(n_features, n_transformed_features, **kwargs)

    def forward(self, features, prev):
        "Return `(activ(output), next_input)`, the two halves of the transformed masked features."
        masked = self.mask(prev) * features
        # Call `_split` directly: the `Split` wrapper is defined in an unexported notebook cell,
        # so it does not exist in the exported module.
        output, next_input = _split(self.feat_trans(masked))
        return self.activ(output), next_input
        


In [91]:
step_features = torch.randn((n_samples, n_features))
tab_step = TabNetStep(n_features, n_transformed_features)
# The step's mask block consumes half of the previous step's transformed features.
prev_half = torch.randn((n_samples, n_transformed_features//2))
step_out = tab_step(features=step_features, prev=prev_half)
test_eq(len(step_out), 2)
test_eq(torch.cat(step_out, dim=1).shape, (n_samples, n_transformed_features))

In [92]:
#export

class TabNet(Module):
    
    @delegates(TabNetStep)
    def __init__(self, n_features, n_transformed_features, n_steps=2, **kwargs):        
        
        self.init = nn.Sequential(
            BatchNorm(n_features, ndim=1),
            FeatureTransformer(n_features, n_transformed_features), 
            Split,
        )
        
        self.steps = nn.ModuleList(
            [TabNetStep(n_features, n_transformed_features, **kwargs) for i in range(n_steps)]
        )
        
    
    def forward(self, features):
        outputs = []
  
        output, next_input = self.init(features)     
        
        for step in self.steps:        
            output, next_input = step(features, next_input)
            outputs.append(output)
    

        outputs = torch.cat(outputs, dim=0)
        return torch.sum(outputs, dim=0)

In [93]:
a = torch.randn((n_samples, n_features))
tabnet = TabNet(n_features, n_transformed_features)
test_eq(len(tabnet(a)), n_transformed_features//2)

# Export

In [96]:
# Export the flagged cells of the project's notebooks to the library (see the "Converted ..." log).
from nbdev.export import notebook2script as export_notebooks
export_notebooks()

Converted 00_review_prev_work.ipynb.
Converted 01_model.ipynb.
Converted Untitled.ipynb.
Converted index.ipynb.
