In [1]:
%load_ext autoreload
%autoreload 2 
#default_exp model

In [2]:
#exporti
from fastai.tabular.all import * 
from tabnet.core import Sparsemax, GBN
from mock import Mock

In [3]:
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

# TabNet 

### Encoder

In [4]:
#exporti
def _initial_block(n_in, n_out):
    return nn.Linear(n_in, 2*n_out, bias=False)

def _rest_block(n):
    return nn.Linear(n, 2*n, bias=False)

def _create_shared_blocks(n_in, n_out, n_shared):
    """Return the list of linear layers that are shared between feature transformers.

    The list always starts with one `_initial_block(n_in, n_out)` and is followed
    by `n_shared - 1` layers created by `_rest_block(n_out)`.
    """
    blocks = [_initial_block(n_in, n_out)]
    blocks.extend(_rest_block(n_out) for _ in range(n_shared - 1))
    return blocks

def _combine_cat_cont(x_cat, x_cont, embeds):
    if x_cat.shape[1] != 0: 
        x = [e(x_cat[:,i]) for i,e in enumerate(embeds)]
        x_cat = torch.cat(x, 1)

    x = torch.cat([x_cat, x_cont], 1)
    
    return x

In [5]:
#export
class TabNetBase(Module):
    """Common base for the TabNet encoder and decoder.

    Keeps the shared hyper-parameters (read back later as `self.n_d`, `self.n_a`,
    `self.n_steps`, ...) and offers a helper that builds a `FeatureTransformer`
    from them.
    """

    def __init__(self, n_d=64, n_a=64, n_steps=3, n_shared_ft_blocks=2, n_independent_ft_blocks=2,
                 virtual_batch_size=128, momentum=0.2, **kwargs):
        # NOTE(review): `momentum` defaults to 0.2 here but to 0.02 in
        # `FeatureTransformer` / `FeatureTransformerBlock` -- confirm which is intended.
        store_attr()

    def _create_feature_transform(self, shared_ft_blocks):
        """Build a `FeatureTransformer` of width `n_d + n_a` on top of `shared_ft_blocks`."""
        width = self.n_d + self.n_a
        return FeatureTransformer(
            n_out=width,
            shared_layers=shared_ft_blocks,
            n_indep=self.n_independent_ft_blocks,
            virtual_batch_size=self.virtual_batch_size,
            momentum=self.momentum,
        )


In [79]:
#exporti
class TabNetEnc(TabNetBase):
    """TabNet encoder: embeds the inputs and runs `n_steps` attention / feature-transform steps.

    `forward(x_cat, x_cont)` returns a tensor of shape `(batch, n_steps, n_d)`
    holding the ReLU-ed decision output of every step. The attention masks of
    the last forward pass are kept in `self.masks` (one entry per step).

    The prior used by the attentive transformers is created fresh on every
    forward pass, so the encoder no longer needs `enc.prior = 1` to be set from
    outside before it is called (setting it is harmless and simply ignored).
    """

    @delegates(TabNetBase.__init__)
    def __init__(self, emb_szs, n_cont, gamma=1.5, **kwargs):
        store_attr()
        super().__init__(**kwargs)

        self.embeds = nn.ModuleList([Embedding(ni, nf) for ni,nf in self.emb_szs])
        self.n_emb = sum(e.embedding_dim for e in self.embeds)
        self.n_features = self.n_emb + self.n_cont

        # The same linear layers are handed to every feature transformer, so
        # their weights are shared across the initial transform and all steps.
        shared_ft_blocks = _create_shared_blocks(self.n_features, self.n_d + self.n_a, self.n_shared_ft_blocks)

        self.initial_ft = self._create_feature_transform(shared_ft_blocks)
        self.initial_bn = BatchNorm(self.n_features, ndim=1)

        self.att_steps = nn.ModuleList([AttentiveTransformer(self.n_a, self.n_features, self.virtual_batch_size,
                                                             self.momentum) for i in range(self.n_steps)])

        self.ft_steps = nn.ModuleList([self._create_feature_transform(shared_ft_blocks) for _ in range(self.n_steps)])

    def _split(self, x):
        """Split a `(batch, n_d + n_a)` tensor into its decision part and its attention part."""
        return x[:, :self.n_d], x[:, self.n_d:]

    def forward(self, x_cat, x_cont):
        x = _combine_cat_cont(x_cat, x_cont, self.embeds)
        x = self.initial_bn(x)
        _, a = self._split(self.initial_ft(x))

        # Start every call from an all-ones prior. Keeping the prior as module
        # state made a second call silently reuse the previous batch's prior
        # (and its autograd graph) unless a caller reset it first.
        prior = torch.ones_like(x)

        masks, output = [], []
        for attentive, feature_transform in zip(self.att_steps, self.ft_steps):
            M = attentive(prior, a)
            masks.append(M)
            # Features selected in this step are scaled by (gamma - M) in the
            # prior that the following steps see.
            prior = (self.gamma - M) * prior
            d, a = self._split(feature_transform(M * x))
            output.append(nn.functional.relu(d))

        # Exposed for consumers such as `MaskRegularizer`.
        self.masks = masks
        return torch.stack(output, dim=1)

In [80]:
#exporti
class FeatureTransformer(Module):
    """Stack of `FeatureTransformerBlock`s with scaled residual connections.

    `shared_layers` are linear layers supplied by the caller, each wrapped in
    its own `FeatureTransformerBlock`; they are followed by `n_indep` blocks
    whose linear layers are created here. The first block is applied directly;
    every later block is applied as `sqrt(0.5) * (x + block(x))`.
    """

    def __init__(self, n_out, shared_layers, n_indep,
                 virtual_batch_size=128, momentum=0.02):
        store_attr()

        shared_layers = [FeatureTransformerBlock(layer, n_out, virtual_batch_size, momentum)
                             for layer in shared_layers]
        independent_layers = [FeatureTransformerBlock(_rest_block(n_out), n_out, virtual_batch_size, momentum)
                              for _ in range(n_indep)]

        self.layers = nn.ModuleList([*shared_layers, *independent_layers])

    def forward(self, x):
        # A plain Python float avoids allocating a tensor and copying it to the
        # device on every call, and it leaves the dtype of `x` untouched.
        scale = 0.5 ** 0.5
        x = self.layers[0](x)

        for layer in self.layers[1:]:
            x = scale * (x + layer(x))

        return x

class FeatureTransformerBlock(nn.Sequential):
    """Sequential block: the given linear layer `fc`, then `GBN(2*n_out)`, then `nn.GLU()`.

    `fc` is expected to produce `2*n_out` features, matching the width the GBN
    layer is built with.
    """

    def __init__(self, fc, n_out,
                 virtual_batch_size=128, momentum=0.02):
        self.n_out = n_out
        norm = GBN(2*n_out, virtual_batch_size=virtual_batch_size, momentum=momentum)
        super().__init__(fc, norm, nn.GLU())


In [87]:
#exporti
class AttentiveTransformer(Module):
    """Produce a feature mask from the attention features `a` and the current `prior`.

    Pipeline: linear layer (`n_in` -> `n_out`), `GBN`, multiplication by
    `prior`, then `Sparsemax`.
    """

    def __init__(self, n_in, n_out, virtual_batch_size, momentum):
        store_attr()
        self.fc = nn.Linear(n_in, n_out)
        self.bn = GBN(n_out, virtual_batch_size, momentum)
        self.sparsemax = Sparsemax()

    def forward(self, prior, a):
        normalized = self.bn(self.fc(a))
        return self.sparsemax(prior * normalized)

In [88]:
# Toy dimensions shared by the smoke tests in this notebook.
N = 3                    # batch size
n_cat = 5                # number of categorical columns
n_cont = 10              # number of continuous columns
n_d = n_a = 7
n_steps = 4
n_out = 10
virtual_batch_size = 5

# Use `n_cont` rather than a repeated literal so the inputs stay consistent
# with the encoder if the dimension above is changed.
x_cont = torch.randn((N, n_cont))
x_cat = torch.randint(high=3, size=(N, n_cat))
enc = TabNetEnc([(3, 10)]*n_cat, n_cont, n_d=n_d, n_a=n_a, n_steps=n_steps)
enc.prior = 1
# The encoder returns one `n_d`-wide output per step and per instance.
test_eq(enc(x_cat, x_cont).shape, (N, n_steps, n_d))

### Classifier Head

In [10]:
#export
class TabNetWithHead(Module):
    """A `TabNetEnc` followed by a head module.

    Defined at module level (not inside `TabNet`) so that model instances can
    be pickled; instances of a function-local class cannot.
    """
    def __init__(self, enc, head): store_attr()
    def forward(self, x_cat, x_cont): return self.head(self.enc(x_cat, x_cont))


@delegates(TabNetBase.__init__)
def TabNet(head_func, emb_szs, n_cont, n_out, **kwargs):
    """Build a TabNet model: a `TabNetEnc` whose output is fed to `head_func(n_out, **kwargs)`.

    `emb_szs`, `n_cont` and `**kwargs` go to the encoder; the same `**kwargs`
    are also passed to `head_func`. The encoder is reachable as `.enc` and the
    head as `.head` on the returned module.
    """
    return TabNetWithHead(TabNetEnc(emb_szs, n_cont, **kwargs), head_func(n_out, **kwargs))

In [11]:
#export 
def _sum_over_steps(x):
    """Sum the encoder output over the step dimension (dim 1)."""
    return x.sum(dim=1)


def linear_head(n_out, n_d=64, **kwargs):
    """Head that sums the per-step encoder outputs and applies `nn.Linear(n_d, n_out)`.

    `n_d` defaults to 64, the same default `TabNetBase` uses, so a model built
    without an explicit `n_d` no longer fails with a missing-argument error.
    Extra keyword arguments are accepted and ignored. A named function is used
    instead of an inline lambda so the head can be pickled.
    """
    return nn.Sequential(Lambda(_sum_over_steps), nn.Linear(n_d, n_out))

In [12]:
# Smoke test: encoder + linear head on the toy inputs defined above.
classifier = TabNet(head_func=linear_head, emb_szs=[(3, 10)]*n_cat, n_cont=n_cont, n_out=n_out, 
                          n_steps=n_steps, n_d=n_d, n_a=n_a, virtual_batch_size=virtual_batch_size)
# Outside a Learner there is no `SetPrior` callback, so the prior is set by hand.
classifier.enc.prior = 1 
# One row of `n_out` outputs per instance.
test_eq(classifier(x_cat, x_cont).shape, (N, n_out))

### Mask Regularization 

In [43]:
#export
class MaskRegularizer(Callback):
    """Add an entropy penalty on the encoder's attention masks to the loss.

    After each loss computation the masks stored in `model.enc.masks` are
    stacked and the term `sum(-M * log(M + eps)) / (n_steps * B)` is computed,
    scaled by `lambda_sparse` and added to `learn.loss`. The value of every
    penalty is also recorded in `loss_regs`.
    """

    def __init__(self, lambda_sparse=1e-4, eps=1e-6):
        store_attr()
        self.loss_regs = []

    def after_loss(self):
        Ms = torch.stack(self.model.enc.masks)
        n_steps, B, _ = Ms.shape
        res = ((Ms + self.eps).log() * (-Ms) / (n_steps * B)).sum()
        # Record a detached copy: the history is only for inspection and must
        # not keep every batch's autograd graph alive.
        self.loss_regs.append(res.detach())
        # NOTE(review): some fastai versions call backward on `learn.loss_grad`
        # rather than `learn.loss` -- confirm the penalty reaches the gradient
        # with the installed version.
        self.learn.loss = self.loss + self.lambda_sparse*res

In [44]:
#export
class SetPrior(Callback):
    """Set the encoder's `prior` attribute to 1 before every batch."""

    def before_batch(self):
        encoder = self.learn.model.enc
        encoder.prior = 1

# Self Supervision 

### Decoder 

For self-supervision, we need a decoder that maps the encoder output back to the input features. 
The decoder receives `x` laid out as (instance_index, step_index, step_result). 
We split it along the step dimension so that the i-th piece holds the i-th step's output for every instance, i.e. (step_index, instance_index, step_result).

In [15]:
#exporti
class TabNetDec(TabNetBase):
    """TabNet decoder: turns the per-step encoder outputs into `n_out` reconstructed features.

    Each of the `n_steps` steps owns a `FeatureTransformer` (built on shared
    linear layers) followed by `nn.Linear(n_d + n_a, n_out)`. The step results
    are summed.
    """

    @delegates(TabNetBase.__init__)
    def __init__(self, n_out, **kwargs):
        store_attr()
        super().__init__(**kwargs)

        shared_ft_blocks = _create_shared_blocks(self.n_d, self.n_d + self.n_a, self.n_shared_ft_blocks)

        self.steps = nn.ModuleList([
                            nn.Sequential(
                                self._create_feature_transform(shared_ft_blocks),
                                nn.Linear(self.n_d+self.n_a, self.n_out)) for _ in range(self.n_steps)
                        ])

    def forward(self, x):
        # Split along the step dimension and drop only that dimension. A bare
        # `squeeze()` would also remove the batch or feature dimension whenever
        # either of them has size 1.
        step_inputs = [piece.squeeze(1) for piece in x.chunk(self.n_steps, dim=1)]

        output = 0
        for step_x, step in zip(step_inputs, self.steps):
            output = output + step(step_x)

        return output

In [16]:
# Smoke test: the decoder maps the encoder output back to one value per input column.
dec = TabNetDec(n_cont+n_cat, n_steps=n_steps, n_d=n_d, n_a=n_a, virtual_batch_size=virtual_batch_size)
# `enc` was already run once in an earlier cell; reset its prior before reusing
# it, exactly as the other smoke tests do before calling their encoder.
enc.prior = 1
test_eq(dec(enc(x_cat, x_cont)).shape, (N, n_cont+n_cat))

### Encoder + Decoder Head = Self Supervised Model

In [17]:
#export
def tabnet_decoder(n_out, **kwargs):
    """Head factory for `TabNet`: return a `TabNetDec` built from `n_out` and the given keyword arguments."""
    decoder = TabNetDec(n_out, **kwargs)
    return decoder

In [18]:
# Smoke test: encoder + decoder head, reconstructing one value per input column.
tbss = TabNet(head_func=tabnet_decoder, emb_szs=[(3, 10)]*n_cat, n_cont=n_cont, n_out=n_cat+n_cont, 
                          n_steps=n_steps, n_d=n_d, n_a=n_a, virtual_batch_size=virtual_batch_size)
# Outside a Learner there is no `SetPrior` callback, so the prior is set by hand.
tbss.enc.prior = 1 
test_eq(tbss(x_cat, x_cont).shape, (N, n_cat+n_cont))

### Mask Generator Callback

In [19]:
#export
def create_mask(size, n_cols):
    """Create a random 0/1 mask of shape `size` with exactly `n_cols` zeros in every row.

    `size` is a 2-d shape `(rows, cols)`. Positions holding a 1 are kept and
    positions holding a 0 are masked out. With `n_cols == 0` the mask is all
    ones. The result is an int32 tensor in both cases.
    """
    # Compare by value: `is 0` tests object identity and triggers a
    # SyntaxWarning on recent Python versions.
    if n_cols == 0: return torch.ones(size, dtype=torch.int)
    rand_mat = torch.rand(*size)
    # Per row, the n_cols-th largest random value; everything strictly below it is kept.
    k_th_quant = torch.topk(rand_mat, n_cols, largest=True)[0][:, None, -1]
    return (rand_mat < k_th_quant).int()

In [70]:
#export
class TabularMasking(Callback):
    """Randomly zero out input columns for self-supervised training.

    Before every batch the categorical and continuous inputs are concatenated,
    multiplied by a mask from `create_mask`, and split back into `learn.xb`.
    The mask is also stored on the loss function as `loss_func.S`.
    With `curriculum=True` the number of masked columns follows a linear ramp
    from 0 to `p * n_cols` over the epochs; otherwise it is drawn from a
    `Binomial(n_cols, p)` for each batch.
    After the prediction, the two target tensors are concatenated into one.
    """

    def __init__(self, p=0.5, curriculum=True):
        store_attr()

    def before_batch(self):
        x_cat, x_cont = self.xb
        n_cat = x_cat.shape[1]
        features = torch.cat([x_cat, x_cont], dim=1)
        n_cols = features.shape[1]

        if self.curriculum:
            ramp = (torch.linspace(0, self.p, steps=self.n_epoch)*n_cols).floor()
            n_masked = ramp[self.epoch].int().item()
        else:
            n_masked = torch.distributions.Binomial(n_cols, probs=self.p).sample().int().item()

        S = to_device(create_mask(features.shape, n_masked), x_cat.device)
        masked = features * S
        self.learn.xb = (masked[:, :n_cat].long(), masked[:, n_cat:])
        # NOTE(review): `S` is 1 on the entries that stay visible and 0 on the
        # masked ones, so a loss weighted by `S` scores the visible entries --
        # confirm this is intended rather than `1 - S`.
        self.learn.loss_func.S = S

    def after_pred(self):
        y_cat, y_cont = self.yb
        merged = torch.cat([y_cat, y_cont], dim=1)
        self.learn.yb = tuplify(merged)

In [72]:
# Drive the callback with a mocked Learner instead of a real training loop.
tm = TabularMasking(0.5)
learn = Mock()
# Two (7, 2) tensors stand in for the categorical and continuous inputs.
learn.xb = torch.ones((7, 2)), torch.ones((7, 2))
learn.n_epoch = 5
learn.epoch = 2
learn.loss_func = Mock()
learn.model = Mock()
tm.learn = learn 
tm.before_batch()
# The mask covers all 4 concatenated columns of the 7 rows.
test_eq(learn.loss_func.S.shape, (7,4))

### Self Supervised Loss Function

Sometimes a feature column is constant across the batch. When that happens, its `norm` becomes 0 -> division by 0 -> `inf` loss. 
We avoid this by leaving out of the loss every column whose norm is very small (close to 0).

In [73]:
#export
class MaskReconstructionLoss(Module):
    """Reconstruction loss normalised per column and weighted by the mask `self.S`.

    For every column the norm is `sqrt(sum((targ - mean(targ))**2))` over the
    batch. Columns whose norm is below 1e-6 are dropped. The loss is the sum of
    `|((preds - targ) * S) / norm|` over the remaining entries.
    `self.S` is not set here; it must be assigned before the loss is called.
    `lambda_reg` and `eps` are stored but not used by `forward`.
    """

    def __init__(self, lambda_reg=1e-4, eps=1e-5): store_attr()

    def forward(self, preds, targ):
        centered = targ - targ.mean(dim=0)
        norm = centered.pow(2).sum(dim=0).sqrt()
        # Columns with a (near-)zero norm would cause a division by zero.
        keep = norm >= 1e-6
        error = ((preds - targ) * self.S)[:, keep]
        return (error / norm[keep]).abs().sum()

In [74]:
# Sanity check: predictions `a` against targets `b` that differ in two entries.
a = tensor([[1,2],[3,4],[5,6]], dtype=float).requires_grad_()
b = tensor([[1,2],[3,3.8],[5.2,6]], dtype=float).requires_grad_()

loss_func = MaskReconstructionLoss()
# The loss reads its mask from the `S` attribute; all ones counts every entry.
loss_func.S = torch.ones_like(a)
loss = loss_func(a,b)
# The loss must be differentiable and finite.
loss.backward()
test_eq(loss < 100, True)
loss

tensor(0.1379, grad_fn=<SumBackward0>)

In [89]:
# Sanity check with a constant target column: the first column of `b` is all 1,
# so its norm is 0 and the loss must stay finite instead of dividing by zero.
a = tensor([[1,2],[3,4],[5,6]], dtype=float).requires_grad_()
b = tensor([[1,2],[1,3.8],[1,6]], dtype=float).requires_grad_()

loss_func = MaskReconstructionLoss()
# The loss reads its mask from the `S` attribute; all ones counts every entry.
loss_func.S = torch.ones_like(a)
loss = loss_func(a,b)
loss.backward()
test_eq(loss < 100, True)
loss

tensor(0.0706, grad_fn=<SumBackward0>)

# Export

In [90]:
from nbdev.export import notebook2script
notebook2script()

Converted 01_core.ipynb.
Converted 02_model.ipynb.
Converted 03_experiments.ipynb.
Converted 04_utils.ipynb.
Converted index.ipynb.
