In [1]:
%load_ext autoreload
%autoreload 2 
#default_exp model

In [2]:
#exporti
from fastai.tabular.all import * 
from tabnet.core import Sparsemax, GBN

In [3]:
from IPython.core.interactiveshell import InteractiveShell
# Notebook display setting: show the value of every top-level expression in a cell,
# not only the last one.
InteractiveShell.ast_node_interactivity = "all"

# TabNet 

### Encoder

In [4]:
#exporti
def _initial_block(n_in, n_out):
    return nn.Linear(n_in, 2*n_out, bias=False)

def _rest_block(n):
    return nn.Linear(n, 2*n, bias=False)

def _create_shared_blocks(n_in, n_out, n_shared):
    """Build the list of linear layers that are shared between feature transformers.

    The first layer maps `n_in` -> `2*n_out`; it is followed by `n_shared-1`
    layers mapping `n_out` -> `2*n_out`. At least the first layer is always
    created, even when `n_shared` is below 1.
    """
    blocks = [_initial_block(n_in, n_out)]
    for _ in range(n_shared - 1):
        blocks.append(_rest_block(n_out))
    return blocks

def _combine_cat_cont(x_cat, x_cont, embeds):
    if len(x_cat) != 0: 
        x = [e(x_cat[:,i]) for i,e in enumerate(embeds)]
        x = torch.cat(x, 1)

    if len(x_cont) != 0: 
        x = torch.cat([x, x_cont], 1) if len(x_cat) != 0 else x_cont

    return x

In [5]:
#export
class TabNetBase(Module):
    """Hyper-parameter holder shared by the TabNet encoder and decoder.

    All constructor arguments are stored as attributes of the same name:
    `n_d`, `n_a`, `n_steps`, `n_shared_ft_blocks`, `n_independent_ft_blocks`,
    `virtual_batch_size` and `momentum`.
    """
    # NOTE(review): `momentum` defaults to 0.2 here while `FeatureTransformer` and
    # `FeatureTransformerBlock` default to 0.02 -- confirm which value is intended.
    def __init__(
        self,
        n_d=64,
        n_a=64,
        n_steps=3,
        n_shared_ft_blocks=2,
        n_independent_ft_blocks=2,
        virtual_batch_size=128,
        momentum=0.2,
    ):
        store_attr()

    def _create_feature_transform(self, shared_ft_blocks):
        "Build a `FeatureTransformer` of width `n_d + n_a` on top of `shared_ft_blocks`."
        return FeatureTransformer(
            n_out=self.n_d + self.n_a,
            shared_layers=shared_ft_blocks,
            n_indep=self.n_independent_ft_blocks,
            virtual_batch_size=self.virtual_batch_size,
            momentum=self.momentum,
        )



In [57]:
#exporti
class TabNetEnc(TabNetBase):
    """TabNet encoder: sequential attention over the input features.

    emb_szs: iterable of `(n_categories, embedding_dim)` pairs, one per categorical column.
    n_cont:  number of continuous columns.
    gamma:   factor used when updating the attention prior after each step,
             `prior <- (gamma - M) * prior`.
    kwargs:  forwarded to `TabNetBase.__init__`.

    `forward(x_cat, x_cont)` returns a tensor of shape `(batch, n_steps, n_d)` holding
    the ReLU-ed decision output of every step.
    """

    @delegates(TabNetBase.__init__)
    def __init__(self, emb_szs, n_cont, gamma=1.5, **kwargs):
        store_attr()
        super().__init__(**kwargs)

        self.embeds = nn.ModuleList([Embedding(ni, nf) for ni,nf in self.emb_szs])
        self.n_emb = sum(e.embedding_dim for e in self.embeds)
        self.n_features = self.n_emb + self.n_cont

        # The same linear layers are handed to every feature transformer below,
        # so their weights are shared across the initial transform and all steps.
        shared_ft_blocks = _create_shared_blocks(self.n_features, self.n_d + self.n_a, self.n_shared_ft_blocks)

        self.initial_ft = self._create_feature_transform(shared_ft_blocks)
        self.initial_bn = BatchNorm(self.n_features, ndim=1)

        self.att_steps = nn.ModuleList([AttentiveTransformer(self.n_a, self.n_features, self.virtual_batch_size,
                                                             self.momentum) for _ in range(self.n_steps)])

        self.ft_steps = nn.ModuleList([self._create_feature_transform(shared_ft_blocks) for _ in range(self.n_steps)])

    def _split(self, x):
        "Split a feature-transformer output into its decision part (first `n_d` columns) and the rest."
        return x[:, :self.n_d], x[:, self.n_d:]

    def forward(self, x_cat, x_cont):
        x = _combine_cat_cont(x_cat, x_cont, self.embeds)
        x = self.initial_bn(x)

        # Only the attention part of the initial transform is used; it seeds the first mask.
        _, a = self._split(self.initial_ft(x))

        # Allocate the prior from the combined input rather than from `x_cont`, so it
        # follows the actual features' device/dtype even when `x_cont` is empty.
        prior = torch.ones(self.n_features, device=x.device, dtype=x.dtype)

        output = []
        for attentive, feature_transform in zip(self.att_steps, self.ft_steps):
            M = attentive(prior, a)
            prior = (self.gamma - M) * prior
            d, a = self._split(feature_transform(M * x))
            output.append(nn.functional.relu(d))

        return torch.stack(output, dim=1)



In [58]:
#exporti
class FeatureTransformer(Module):
    """Stack of `FeatureTransformerBlock`s: shared layers first, then independent ones.

    n_out:              width produced by every block.
    shared_layers:      linear layers (possibly shared with other transformers); each is
                        wrapped in its own `FeatureTransformerBlock`.
    n_indep:            number of additional blocks owning a fresh `_rest_block(n_out)`.
    virtual_batch_size: forwarded to the `GBN` of every block.
    momentum:           forwarded to the `GBN` of every block.

    The first block is applied directly; every following block is applied as a residual,
    `x <- sqrt(0.5) * (x + block(x))`.
    """
    def __init__(self, n_out, shared_layers, n_indep,
                 virtual_batch_size=128, momentum=0.02):
        store_attr()

        # Forward the normalisation settings; otherwise the blocks silently fall back
        # to their own defaults and the values given here have no effect.
        block_kwargs = dict(virtual_batch_size=virtual_batch_size, momentum=momentum)

        shared_blocks = [FeatureTransformerBlock(layer, n_out, **block_kwargs) for layer in shared_layers]
        independent_blocks = [FeatureTransformerBlock(_rest_block(n_out), n_out, **block_kwargs)
                              for _ in range(n_indep)]

        self.layers = nn.ModuleList([*shared_blocks, *independent_blocks])

    def forward(self, x):
        # Plain Python constant: no per-call tensor allocation or host-to-device copy.
        scale = 0.5 ** 0.5
        x = self.layers[0](x)

        for layer in self.layers[1:]:
            x = scale * (x + layer(x))

        return x

class FeatureTransformerBlock(nn.Sequential):
    """`fc` -> `GBN(2*n_out)` -> `nn.GLU()` as a single sequential module.

    fc:    layer applied first; the `GBN` that follows is built for `2*n_out` features.
    n_out: stored on the instance as `n_out`.
    """
    def __init__(self, fc, n_out,
                 virtual_batch_size=128, momentum=0.02):
        norm = GBN(2*n_out, virtual_batch_size=virtual_batch_size, momentum=momentum)
        gate = nn.GLU()
        super().__init__(fc, norm, gate)
        self.n_out = n_out


In [59]:
#exporti
class AttentiveTransformer(Module):
    """Turn the attention features `a` into a feature mask.

    The mask is `sparsemax(prior * GBN(Linear(a)))`, where the linear layer maps
    `n_a` inputs to `n_in` outputs.
    """

    def __init__(self, n_a, n_in, virtual_batch_size, momentum):
        store_attr()
        self.fc = nn.Linear(n_a, n_in)
        self.bn = GBN(n_in, virtual_batch_size, momentum)
        self.sparsemax = Sparsemax()

    def forward(self, prior, a):
        normalised = self.bn(self.fc(a))
        # Scale by the prior before the sparsemax, then return the resulting mask.
        return self.sparsemax(prior * normalised)

In [17]:
# Toy sizes reused by the smoke tests further down in this notebook.
N = 3
n_cat = 5
n_cont = 10
n_d = 7
n_a = 7
n_steps = 4
n_out = 10
virtual_batch_size = 5

x_cont = torch.randn(N, n_cont)
x_cat = torch.randint(high=3, size=(N, n_cat))
enc = TabNetEnc(emb_szs=[(3, 10)]*n_cat, n_cont=n_cont, n_d=n_d, n_a=n_a, n_steps=n_steps)

# One `n_d`-wide decision vector per step and per instance.
test_eq(enc(x_cat, x_cont).shape, (N, n_steps, n_d))

### Classifier Head

In [106]:
#export
class TabNetWithHead(Module):
    """TabNet encoder followed by a head: `head(enc(x_cat, x_cont))`.

    Defined at module level (not inside `TabNet`) because classes local to a function
    cannot be pickled, which would prevent saving the whole model object.
    """
    def __init__(self, enc, head): store_attr()
    def forward(self, x_cat, x_cont): return self.head(self.enc(x_cat, x_cont))


@delegates(TabNetBase.__init__)
def TabNet(head_func, emb_szs, n_cont, n_out, **kwargs):
    """Build a `TabNetEnc` and attach the head produced by `head_func`.

    head_func: callable invoked as `head_func(n_out, **kwargs)`; must return a module
               that consumes the encoder output.
    emb_szs, n_cont: forwarded to `TabNetEnc`.
    n_out:     forwarded to `head_func`.
    kwargs:    passed to BOTH `TabNetEnc` and `head_func`, so the head sees the same
               hyper-parameters as the encoder.
    """
    return TabNetWithHead(TabNetEnc(emb_szs, n_cont, **kwargs), head_func(n_out, **kwargs))

In [107]:
#export 
def linear_head(n_out, n_d=64, **kwargs):
    """Head that sums the encoder output over dim 1 and applies `nn.Linear(n_d, n_out)`.

    n_out:  number of outputs of the linear layer.
    n_d:    input width of the linear layer. Defaults to 64, the default `n_d` of
            `TabNetBase`, so that `TabNet(linear_head, ...)` also works when the caller
            relies on the encoder's default instead of passing `n_d` explicitly.
    kwargs: accepted and ignored, since `TabNet` passes the encoder's hyper-parameters too.
    """
    return nn.Sequential(Lambda(lambda x: x.sum(dim=1)), nn.Linear(n_d, n_out))

In [108]:
# Smoke test: encoder + linear head yields one `n_out`-wide row per instance.
classifier = TabNet(
    head_func=linear_head,
    emb_szs=[(3, 10)]*n_cat,
    n_cont=n_cont,
    n_out=n_out,
    n_steps=n_steps,
    n_d=n_d,
    n_a=n_a,
    virtual_batch_size=virtual_batch_size,
)
test_eq(classifier(x_cat, x_cont).shape, (N, n_out))

# Self Supervision 

### Decoder 

For self-supervision, we need a decoder.
The decoder receives the encoder output `x` laid out as (instance_index, step_index, step_result).
We split it along the step dimension so that the i-th piece is the batch of the i-th step's outputs, i.e. a (step_index, instance_index, step_result) layout.

In [None]:
#exporti
class TabNetDec(TabNetBase):
    """TabNet decoder: one feature transformer + linear layer per step, summed over steps.

    n_out:  width of the reconstruction produced by every step (and of the final sum).
    kwargs: forwarded to `TabNetBase.__init__`.

    `forward(x)` expects `x` indexed as `(instance, step, features)` with `n_d` features
    per step, and returns the sum over steps of each step's output.
    """

    @delegates(TabNetBase.__init__)
    def __init__(self, n_out, **kwargs):
        store_attr()
        super().__init__(**kwargs)

        # The same linear layers are handed to every step's feature transformer.
        shared_ft_blocks = _create_shared_blocks(self.n_d, self.n_d + self.n_a, self.n_shared_ft_blocks)

        self.steps = nn.ModuleList([
                            nn.Sequential(
                                self._create_feature_transform(shared_ft_blocks),
                                nn.Linear(self.n_d+self.n_a, self.n_out)) for _ in range(self.n_steps)
                        ])

    def forward(self, x):
        output = 0

        # `unbind(dim=1)` removes exactly the step dimension. A bare `squeeze()` would
        # also drop the batch (or feature) dimension whenever its size is 1.
        for step_input, step in zip(x.unbind(dim=1), self.steps):
            output = output + step(step_input)

        return output

In [None]:
# Smoke test: the decoder maps the encoder output back to one value per input column.
dec = TabNetDec(n_cont+n_cat, n_steps=n_steps, n_d=n_d, n_a=n_a, virtual_batch_size=virtual_batch_size)
# The encoder takes the categorical and continuous tensors separately.
test_eq(dec(enc(x_cat, x_cont)).shape, (N, n_cont+n_cat))

### Encoder + Decoder Head = Self Supervised Model

In [None]:
#export
def tabnet_decoder(n_out, **kwargs):
    "Head factory for `TabNet`: build a `TabNetDec` with `n_out` outputs from `kwargs`."
    return TabNetDec(n_out, **kwargs)

In [None]:
# Smoke test: encoder + decoder head reconstructs one value per input column.
tbss = TabNet(head_func=tabnet_decoder, emb_szs=[(3, 10)]*n_cat, n_cont=n_cont, n_out=n_cat+n_cont, 
                          n_steps=n_steps, n_d=n_d, n_a=n_a, virtual_batch_size=virtual_batch_size)
# The model's forward takes the categorical and continuous tensors separately.
test_eq(tbss(x_cat, x_cont).shape, (N, n_cat+n_cont))

### Mask Generator Callback

In [None]:
#export
def create_mask(size, n_cols, device=None):
    """Random keep-mask with exactly `n_cols` entries switched off in every row.

    size:   2-d shape `(rows, cols)` of the mask.
    n_cols: number of entries per row to mask out; must not exceed `cols`.
    device: optional device on which to create the mask (default: torch's default device).

    Returns a tensor of shape `size`. For `n_cols == 0` it is an all-ones float tensor;
    otherwise it is a bool tensor that is False at the `n_cols` masked positions of each
    row and True elsewhere.
    """
    # Value comparison: `is 0` tests object identity and only works by accident.
    if n_cols == 0: return torch.ones(size, device=device)
    rand_mat = torch.rand(*size, device=device)
    # Per-row threshold = the n_cols-th largest random value, kept as a column vector.
    kth_largest = torch.topk(rand_mat, n_cols, largest=True)[0][:, None, -1]
    # Entries at or above the threshold (the top n_cols of the row) come out False.
    return rand_mat < kth_largest

In [None]:
#export
class TabularMasking(Callback):
    """Callback that masks input columns for self-supervised training.

    p: fraction of columns masked at the last epoch. The number of masked columns per
       row is `floor(linspace(0, p, n_epoch)[epoch] * n_cols)`, i.e. it ramps up from 0
       at the first epoch.

    Before each batch the categorical and continuous inputs are concatenated, multiplied
    by a random mask from `create_mask`, split again, and the mask is stored on the loss
    function as `loss_func.M`. After the prediction the targets are concatenated into a
    single tensor.
    """

    def __init__(self, p=0.5):
        store_attr()

    def before_batch(self):
        x_cat, x_cont = self.xb
        xb = torch.cat([x_cat, x_cont], dim=1)
        n_cols = xb.shape[1]
        n_masked = (torch.linspace(0, self.p, steps=self.n_epoch)*n_cols).floor()[self.epoch].int().item()
        # `create_mask` allocates on the default device; move the mask next to the batch
        # so the multiplication below also works when the batch lives on a GPU.
        M = create_mask(xb.shape, n_masked).to(xb.device)
        xb = xb * M
        # Masked positions are now 0; restore the integer dtype of the categorical part.
        self.learn.xb = (xb[:, :x_cat.shape[1]].long(), xb[:, x_cat.shape[1]:])
        self.learn.loss_func.M = M

    def after_pred(self):
        y_cat, y_cont = self.yb
        self.learn.yb = tuplify(torch.cat([y_cat, y_cont], dim=1))

In [None]:
# Exercise `before_batch` against a minimal stand-in for a learner.
tm = TabularMasking(0.5)
# An empty namedtuple class serves as a bare attribute holder.
learn = namedtuple('a', '')
# Two (7, 2) tensors play the roles of x_cat and x_cont.
learn.xb = torch.ones((7, 2)), torch.ones((7, 2))
learn.n_epoch = 5
learn.epoch = 2
# Any object accepting attribute assignment works: the callback sets `loss_func.M`.
learn.loss_func = lambda x: x
# The stand-in must be attached before the hook runs, since the hook reads from it.
tm.learn = learn 
tm.before_batch()

### Self Supervised Loss Function

In [None]:
#export
class MaskReconstructionLoss(Module):
    """Sum of absolute reconstruction errors, masked by `M` and normalised per column.

    eps: lower bound applied to the per-column norm, so that a column that is constant
         within the batch (norm 0) does not produce inf/NaN.

    The mask `M` is not created here: it must be assigned to the instance before the
    loss is called (`TabularMasking.before_batch` sets `loss_func.M`).
    """

    def __init__(self, eps=1e-8): store_attr()

    def forward(self, preds, targ):
        # Per-column spread of the targets within the batch.
        norm = (targ - targ.mean(dim=0)).pow(2).sum(dim=0).sqrt()
        norm = norm.clamp_min(self.eps)
        # NOTE(review): with the mask from `create_mask`, M is truthy on the columns that
        # were KEPT, so this weights the error on visible entries -- confirm whether the
        # complement (the masked-out entries) was intended.
        error_masked = (preds - targ) * self.M

        return (error_masked / norm).abs().sum()

In [None]:
# Smoke test: the loss runs and is differentiable w.r.t. both predictions and targets.
a = tensor([[1,2],[3,4]], dtype=float).requires_grad_()
b = tensor([[1,2],[3,3.8]], dtype=float).requires_grad_()

loss_func = MaskReconstructionLoss()
# The mask is normally set by the TabularMasking callback; here every entry counts.
loss_func.M = torch.ones_like(a)
loss_func(a,b).backward()

# Export

In [109]:
#hide
from nbdev.export import notebook2script
# Run nbdev's notebook export; its "Converted ..." log is the cell output below.
notebook2script()

Converted 01_core.ipynb.
Converted 02_model.ipynb.
No export destination, ignored:
#exporti
from fastai.tabular.all import * 
from tabnet.utils import *
from tabnet.model import *
Converted 03_experiments.ipynb.
Converted 04_self_supervision.ipynb.
Converted 04_utils.ipynb.
Converted Untitled.ipynb.
Converted index.ipynb.
