# TabNet

For tabular data, the most common approach, and the winning method in many Kaggle competitions, is tree-based models and their ensembles. XGBoost and LightGBM are two examples that have dominated many winning solutions. In recent years there has been an effort to develop deep learning algorithms for tabular data. One successful effort is TabNet, from Google Cloud AI. An important characteristic of this algorithm is that 'it combines the features of neural networks to fit very complex functions and the feature selection property of tree-based algorithms.' Additionally, in contrast with tree-based models, which can only perform feature selection globally, TabNet performs instance-wise feature selection: at each decision step it computes a separate feature mask for every sample. Most importantly, TabNet provides interpretability, which is a key desirable property of any machine learning algorithm.

<figure><center>
<img src = https://miro.medium.com/max/700/1*twB1nZHPN5Cuxu2h_jpEPg.png>
<figcaption> Source: https://arxiv.org/pdf/1908.07442v1.pdf </figcaption> </center></figure>
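To make the difference between global and instance-wise feature selection concrete, here is a toy NumPy sketch (not part of the original notebook and not TabNet itself). A global selector keeps the same columns for every row, while an instance-wise selector derives a separate mask for each row from that row's own values. The scoring rules used here (largest column variance for the global mask, largest absolute value per row for the instance-wise mask) are arbitrary stand-ins for TabNet's learned attentive transformer.

```python
import numpy as np

def global_mask(X, k):
    """Global selection: keep the same k columns (largest variance) for every row."""
    keep = np.argsort(X.var(axis=0))[::-1][:k]
    mask = np.zeros(X.shape[1])
    mask[keep] = 1.0
    return np.tile(mask, (X.shape[0], 1))  # identical mask repeated for each sample

def instance_mask(X, k):
    """Instance-wise selection: keep the k columns with the largest |value| in each row."""
    keep = np.argsort(-np.abs(X), axis=1)[:, :k]
    mask = np.zeros(X.shape, dtype=float)
    np.put_along_axis(mask, keep, 1.0, axis=1)
    return mask

X = np.array([[5.0, 0.1, 0.2, 0.0],
              [0.0, 0.3, 4.0, 0.1],
              [0.2, 6.0, 0.1, 0.3]])
G = global_mask(X, 2)
I = instance_mask(X, 2)
print(G)  # every row selects the same two columns
print(I)  # each row selects its own two columns
```

In the global case every row of `G` is identical, whereas each row of `I` picks the features that matter for that particular sample, which is the property TabNet's per-sample masks provide.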

The TabNet implementation below is largely adapted from this [notebook](https://www.kaggle.com/samratthapa/tabnet-implementation/notebook?scriptVersionId=46472520).

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import time

import torch
import torch.nn as nn
from torch.utils.data import Dataset,DataLoader
from sklearn.preprocessing import QuantileTransformer


from sklearn import preprocessing
import torch.optim as optim

In [None]:
import torch.nn.functional as F


In [None]:
!pip install optuna
import optuna
import plotly as pl


In [None]:
from sklearn.model_selection import StratifiedKFold


In [None]:
import lightgbm as lgb

In [None]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'


In [None]:
! unzip lish-moa.zip
test_fea = pd.read_csv('test_features.csv')
train_fea = pd.read_csv('train_features.csv')
train_tar_nonsco = pd.read_csv('train_targets_nonscored.csv')
train_tar_sco = pd.read_csv('train_targets_scored.csv')
submission = pd.read_csv('sample_submission.csv')


Archive:  lish-moa.zip
  inflating: sample_submission.csv   
  inflating: test_features.csv       
  inflating: train_drug.csv          
  inflating: train_features.csv      
  inflating: train_targets_nonscored.csv  
  inflating: train_targets_scored.csv  


In [None]:
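import random

def seed_everything(seed=1062):
    # Seed Python's, NumPy's and PyTorch's (CPU and all GPUs) random number generators
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Ask cuDNN for deterministic algorithms (may be slower)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
seed_everything(seed=1062)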

In [None]:
class Sparsemax(nn.Module):
    """Sparsemax: a softmax-like projection onto the probability simplex that can output exact zeros."""
    def __init__(self, dim=None):
        super().__init__()
        self.dim = -1 if dim is None else dim

    def forward(self, input):
        # Move the target dimension to the front and flatten the rest, so that
        # each row of `input` is one set of logits to normalise
        input = input.transpose(0, self.dim)
        original_size = input.size()
        input = input.reshape(input.size(0), -1)
        input = input.transpose(0, 1)
        dim = 1

        number_of_logits = input.size(dim)

        # Shift by the row maximum for numerical stability (sparsemax is shift-invariant)
        input = input - torch.max(input, dim=dim, keepdim=True)[0].expand_as(input)
        zs = torch.sort(input=input, dim=dim, descending=True)[0]
        # 1..K, created on the input's device rather than the global `device`
        rng = torch.arange(start=1, end=number_of_logits + 1, step=1,
                           device=input.device, dtype=input.dtype).view(1, -1)
        rng = rng.expand_as(zs)

        # Support size k = max{k : 1 + k*z_(k) > z_(1) + ... + z_(k)}
        bound = 1 + rng * zs
        cumulative_sum_zs = torch.cumsum(zs, dim)
        is_gt = torch.gt(bound, cumulative_sum_zs).type(input.type())
        k = torch.max(is_gt * rng, dim, keepdim=True)[0]

        # Threshold tau = (sum of the k largest logits - 1) / k
        zs_sparse = is_gt * zs
        taus = (torch.sum(zs_sparse, dim, keepdim=True) - 1) / k
        taus = taus.expand_as(input)
        self.output = torch.max(torch.zeros_like(input), input - taus)

        # Undo the reshaping
        output = self.output.transpose(0, 1)
        output = output.reshape(original_size)
        output = output.transpose(0, self.dim)
        return output

    def backward(self, grad_output):
        # Closed-form sparsemax gradient, kept for reference only: autograd never calls
        # nn.Module.backward and differentiates the operations in forward() instead
        dim = 1
        nonzeros = torch.ne(self.output, 0)
        sum = torch.sum(grad_output * nonzeros, dim=dim, keepdim=True) / torch.sum(nonzeros, dim=dim, keepdim=True)
        self.grad_input = nonzeros * (grad_output - sum.expand_as(grad_output))
        return self.grad_input
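To see what the `Sparsemax` module computes, here is a NumPy reference for a single vector of logits (an illustrative sketch that the model does not use; the notebook's `AttentionTransformer` gates with a sigmoid and leaves its `Sparsemax` line commented out). It sorts the logits, finds the largest support size k with 1 + k·z_(k) > z_(1) + ... + z_(k), sets τ = (z_(1) + ... + z_(k) − 1)/k and returns max(z − τ, 0). For the logits [1.0, 0.5, −1.0] this gives k = 2, τ = 0.25 and the output [0.75, 0.25, 0], which sums to 1 and, unlike softmax, contains an exact zero.

```python
import numpy as np

def sparsemax_np(z):
    """Reference sparsemax for a 1-D array of logits."""
    z = np.asarray(z, dtype=float)
    zs = np.sort(z)[::-1]                  # logits in descending order
    ks = np.arange(1, len(z) + 1)          # candidate support sizes 1..K
    cumsum = np.cumsum(zs)
    support = 1 + ks * zs > cumsum         # True for every k inside the support
    k = ks[support][-1]                    # largest such k
    tau = (cumsum[k - 1] - 1) / k          # threshold
    return np.maximum(z - tau, 0.0)

p = sparsemax_np([1.0, 0.5, -1.0])
print(p, p.sum())
```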

In [None]:
def initialize_non_glu(module,inp_dim,out_dim):
    gain = np.sqrt((inp_dim+out_dim)/np.sqrt(4*inp_dim))
    torch.nn.init.xavier_normal_(module.weight, gain=gain)
    
class GBN(nn.Module):
    """Ghost batch norm: apply BatchNorm1d separately to virtual batches of about vbs rows."""
    def __init__(self,inp,vbs=128,momentum=0.01):
        super().__init__()
        self.bn = nn.BatchNorm1d(inp,momentum=momentum)
        self.vbs = vbs
    def forward(self,x):
        chunk = torch.chunk(x,max(1,x.size(0)//self.vbs),0)
        res = [self.bn(y) for y in chunk]
        return torch.cat(res,0)

class GLU(nn.Module):
    """Gated linear unit: Linear -> GBN -> first half * sigmoid(second half)."""
    def __init__(self,inp_dim,out_dim,fc=None,vbs=128):
        super().__init__()
        # `fc` may be a Linear layer shared with other decision steps; otherwise create a private one
        if fc is not None:
            self.fc = fc
        else:
            self.fc = nn.Linear(inp_dim,out_dim*2)
        self.bn = GBN(out_dim*2,vbs=vbs)
        self.od = out_dim
    def forward(self,x):
        x = self.bn(self.fc(x))
        return x[:,:self.od]*torch.sigmoid(x[:,self.od:])
    

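class FeatureTransformer(nn.Module):
    def __init__(self,inp_dim,out_dim,shared,n_ind,vbs=128):
        super().__init__()
        first = True
        self.shared = nn.ModuleList()
        if shared:
            # GLU blocks built on Linear layers that are shared by all decision steps
            self.shared.append(GLU(inp_dim,out_dim,shared[0],vbs=vbs))
            first = False
            for fc in shared[1:]:
                self.shared.append(GLU(out_dim,out_dim,fc,vbs=vbs))
        else:
            self.shared = None
        # Step-specific (independent) GLU blocks
        self.independ = nn.ModuleList()
        if first:
            # Without shared blocks, the first independent block maps inp_dim -> out_dim
            self.independ.append(GLU(inp_dim,out_dim,vbs=vbs))
        for x in range(int(first), n_ind):
            self.independ.append(GLU(out_dim,out_dim,vbs=vbs))
        # Residual sums are scaled by sqrt(0.5) to keep their variance roughly constant
        self.scale = 0.5 ** 0.5
    def forward(self,x):
        independ = self.independ
        if self.shared:
            x = self.shared[0](x)
            for glu in self.shared[1:]:
                x = torch.add(x, glu(x))
                x = x*self.scale
        else:
            # The first independent block changes the width, so it has no residual connection
            x = independ[0](x)
            independ = independ[1:]
        for glu in independ:
            x = torch.add(x, glu(x))
            x = x*self.scale
        return x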

class AttentionTransformer(nn.Module):
    def __init__(self,inp_dim,out_dim,relax,vbs=128):
        super().__init__()
        self.fc = nn.Linear(inp_dim,out_dim)
        self.bn = GBN(out_dim,vbs=vbs)
#         self.smax = Sparsemax()
        self.r = relax
    def forward(self,a,priors):
        a = self.bn(self.fc(a))
        # This adaptation gates the features with a sigmoid (the paper uses sparsemax)
        mask = torch.sigmoid(a*priors)
        # Prior-scale update P <- P * (relax - M); returned so that the next step uses it
        priors = priors*(self.r-mask)
        return mask, priors

class DecisionStep(nn.Module):
    def __init__(self,inp_dim,n_d,n_a,shared,n_ind,relax,vbs=128):
        super().__init__()
        self.fea_tran = FeatureTransformer(inp_dim,n_d+n_a,shared,n_ind,vbs)
        self.atten_tran = AttentionTransformer(n_a,inp_dim,relax,vbs)
    def forward(self,x,a,priors):
        mask, priors = self.atten_tran(a,priors)
        # Mean entropy of the mask, usable as a sparsity regulariser
        loss = ((-1)*mask*torch.log(mask+1e-10)).mean()
        x = self.fea_tran(x*mask)
        return x,loss,priors

class TabNet(nn.Module):
    def __init__(self,inp_dim,final_out_dim,n_d=64,n_a=64,n_shared=2,n_ind=2,n_steps=5,relax=1.2,vbs=128):
        super().__init__()
        if n_shared>0:
            # Linear layers shared by the feature transformers of every step
            self.shared = nn.ModuleList()
            self.shared.append(nn.Linear(inp_dim,2*(n_d+n_a)))
            for x in range(n_shared-1):
                self.shared.append(nn.Linear(n_d+n_a,2*(n_d+n_a)))
        else:
            self.shared=None
        self.first_step = FeatureTransformer(inp_dim,n_d+n_a,self.shared,n_ind,vbs)
        self.steps = nn.ModuleList()
        for x in range(n_steps-1):
            self.steps.append(DecisionStep(inp_dim,n_d,n_a,self.shared,n_ind,relax,vbs))
        self.fc = nn.Linear(n_d,final_out_dim)
        self.bn = nn.BatchNorm1d(inp_dim)
        self.n_d = n_d
    def forward(self,x):
        x = self.bn(x)
        # Attention input for the first decision step
        x_a = self.first_step(x)[:,self.n_d:]
        loss = torch.zeros(1, device=x.device)
        out = torch.zeros(x.size(0), self.n_d, device=x.device)
        # Prior scales start at 1: every feature is equally available
        priors = torch.ones_like(x)
        for step in self.steps:
            x_te,l,priors = step(x,x_a,priors)
            # Sum the ReLU'd decision outputs of all steps
            out += F.relu(x_te[:,:self.n_d])
            x_a = x_te[:,self.n_d:]
            loss += l
        return self.fc(out),loss
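In the TabNet paper, the attentive transformer scales each step's mask by a prior that records how much each feature has already been used: M[i] = sparsemax(P[i−1] · h_i(a[i−1])) and P[i] = P[i−1] · (γ − M[i]), where γ is the relaxation factor (`relax` in the code). The toy NumPy sketch below uses hand-picked hard masks instead of learned ones (an assumption made purely for illustration) to show the effect of γ: with γ = 1 a feature that was fully selected once gets prior 0 and cannot be selected again, while with γ = 1.5 it keeps some weight. It also evaluates the mean-entropy term that `DecisionStep` returns as its sparsity loss.

```python
import numpy as np

def update_priors(masks, gamma):
    """Apply P <- P * (gamma - M) for a sequence of masks; return the priors after each step."""
    P = np.ones_like(masks[0])
    history = []
    for M in masks:
        P = P * (gamma - M)
        history.append(P.copy())
    return history

def mask_entropy(M, eps=1e-10):
    """Mean of -M*log(M), the same expression DecisionStep computes."""
    return float(np.mean(-M * np.log(M + eps)))

# Two decision steps over 3 features: step 1 fully selects feature 0, step 2 feature 1
masks = [np.array([1.0, 0.0, 0.0]), np.array([0.0, 1.0, 0.0])]
strict = update_priors(masks, gamma=1.0)
relaxed = update_priors(masks, gamma=1.5)
print(strict[-1])   # features 0 and 1 can no longer be attended to
print(relaxed[-1])  # used features keep part of their weight

print(mask_entropy(np.array([1.0, 0.0, 0.0])))    # one-hot mask: entropy close to 0
print(mask_entropy(np.array([1/3, 1/3, 1/3])))    # uniform mask: maximal entropy
```

A sparse (peaked) mask has low entropy, so adding this term to the loss with a positive weight pushes the masks toward selecting few features per sample.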

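`GBN` implements ghost batch normalization: the incoming batch is split into "virtual" batches of about `vbs` rows with `torch.chunk`, and in training mode each chunk is normalized with its own mean and variance. The NumPy sketch below mimics only that training-time normalization; it omits the learned scale/shift and the running statistics, and uses `np.array_split` as a stand-in for `torch.chunk` (which can size the chunks slightly differently when the batch is not divisible). The data are made up so that the two halves of the batch come from different distributions.

```python
import numpy as np

def ghost_batch_norm(x, vbs, eps=1e-5):
    """Training-time ghost BN without affine parameters: normalize each virtual batch separately."""
    n_chunks = max(1, x.shape[0] // vbs)           # same chunk-count rule as GBN.forward
    chunks = np.array_split(x, n_chunks, axis=0)
    out = [(c - c.mean(axis=0)) / np.sqrt(c.var(axis=0) + eps) for c in chunks]
    return np.concatenate(out, axis=0), [len(c) for c in chunks]

rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(0, 1, (256, 4)), rng.normal(5, 2, (256, 4))])
y, sizes = ghost_batch_norm(x, vbs=128)
print(sizes)  # four virtual batches of 128 rows
```

Each 128-row block of `y` ends up with zero mean and unit variance on its own, which is the regularizing "noisier statistics" effect ghost batch norm is used for.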
In [None]:
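class TabNetWithEmbed(nn.Module):
    """TabNet preceded by 1-d embeddings of the categorical inputs."""
    def __init__(self,inp_dim,final_out_dim,n_d=64,n_a=64,n_shared=2,n_ind=2,n_steps=5,relax=1.2,vbs=128):
        super().__init__()
        # inp_dim = (number of embedded categorical values) + (number of continuous columns)
        self.tabnet = TabNet(inp_dim,final_out_dim,n_d,n_a,n_shared,n_ind,n_steps,relax,vbs)
        # One 1-d embedding per categorical input (2 and 3 levels). Assigning them as
        # attributes registers their weights in model.parameters()
        self.cat_embed = []
        self.emb1 = nn.Embedding(2,1)
        self.emb3 = nn.Embedding(3,1)
        self.cat_embed.append(self.emb1)
        self.cat_embed.append(self.emb3)
        
    def forward(self,catv,contv):
        catv = catv.to(device)
        contv = contv.to(device)
        # Embed each categorical column and concatenate with the continuous features
        embeddings = [embed(catv[:,idx]) for idx,embed in enumerate(self.cat_embed)]
        catv = torch.cat(embeddings,1)
        x = torch.cat((catv,contv),1).contiguous()
        x,l = self.tabnet(x)
        # Sigmoid gives an independent probability per target, as BCELoss expects
        return torch.sigmoid(x),l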
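`TabNetWithEmbed` replaces each categorical code with a learned 1-dimensional embedding and concatenates the result with the continuous columns. That is why the training loop builds it with `inp_dim = train_df.shape[1] - 1`: of the first three columns, two (`cp_type`, `cp_time`, selected by `cat_col = [0,1]`) become one embedding value each, and the third (`cp_dose`) is skipped because the continuous slice starts at column 3. The NumPy sketch below reproduces that bookkeeping; the table size and the random lookup tables standing in for the trained `nn.Embedding` weights are made-up assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
n_rows, n_cols = 5, 10                       # toy sizes; the real table is much wider
X = rng.normal(size=(n_rows, n_cols))
X[:, 0] = rng.integers(0, 2, n_rows)         # cp_type code in {0, 1}
X[:, 1] = rng.integers(0, 3, n_rows)         # cp_time code in {0, 1, 2}
X[:, 2] = rng.integers(0, 2, n_rows)         # cp_dose code (not fed to the model)

cat_col = [0, 1]
# Stand-ins for nn.Embedding(2, 1) and nn.Embedding(3, 1)
tables = [rng.normal(size=(2, 1)), rng.normal(size=(3, 1))]

cat = X[:, cat_col].astype(int)
embedded = np.concatenate([tab[cat[:, i]] for i, tab in enumerate(tables)], axis=1)
model_input = np.concatenate([embedded, X[:, 3:]], axis=1)
print(model_input.shape)  # n_cols - 1 features per row
```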

In [None]:
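class DrugData(Dataset):
    """Rows of the feature matrix paired with their scored targets."""
    def __init__(self, df, out, train=True):
        self.df = df          # feature matrix (NumPy array)
        self.out = out        # target matrix (NumPy array)
        self.train = train
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self,idx):
        x = torch.from_numpy(self.df[idx,:]).float()
        if self.train:
            # Mild label smoothing for training: 0 -> 0.00001, 1 -> 0.99999
            tar = np.where(self.out[idx].reshape(-1)==0,0.00001,0.99999)
            return x,torch.tensor(tar).float()
        # Validation keeps the raw 0/1 targets, so the loss is measured against the true labels
        return x,torch.tensor(self.out[idx].reshape(-1)).float()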
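During training, `DrugData` replaces hard 0/1 targets with 0.00001/0.99999, a very mild form of label smoothing. The NumPy sketch below illustrates what that changes, using a hand-written binary cross-entropy `bce` (a helper introduced here; it uses the same formula as `nn.BCELoss` but clips probabilities instead of clamping the logarithm). With a hard target the loss keeps shrinking as the prediction approaches 1, whereas with the smoothed target the minimum sits at p = 0.99999 and the loss rises again for more extreme predictions, which gently discourages saturated outputs.

```python
import numpy as np

def bce(p, t, eps=1e-12):
    """Mean binary cross-entropy, same formula as torch.nn.BCELoss (simple clipping)."""
    p = np.clip(np.asarray(p, dtype=float), eps, 1 - eps)
    t = np.asarray(t, dtype=float)
    return float(np.mean(-(t * np.log(p) + (1 - t) * np.log(1 - p))))

raw = np.array([0.0, 1.0])
smoothed = np.where(raw == 0, 0.00001, 0.99999)   # same mapping as DrugData.__getitem__

for p in (0.999, 0.99999, 0.9999999):
    print(p, bce([p], [1.0]), bce([p], [smoothed[1]]))
```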

In [None]:
# Encode the three categorical columns as integer codes. Assigning the result back
# (instead of calling replace(..., inplace=True) on a single column) avoids pandas'
# chained-assignment pitfalls, which silently do nothing under copy-on-write.
dose_map = {'D1':0,'D2':1}
time_map = {24:0,72:1,48:2}
type_map = {'trt_cp':0,'ctl_vehicle':1}
for df in (train_fea, test_fea):
    df['cp_dose'] = df['cp_dose'].map(dose_map)
    df['cp_time'] = df['cp_time'].map(time_map)
    df['cp_type'] = df['cp_type'].map(type_map)

train_fea = train_fea.drop(columns='sig_id')
test_fea = test_fea.drop(columns='sig_id')

In [None]:
cat_col = [0,2]
num_col = len(train_fea.columns)

In [None]:
train_tar = train_tar_sco.drop(columns='sig_id').values

In [None]:
from sklearn.feature_selection import VarianceThreshold
from sklearn.decomposition import PCA

In [None]:
data = pd.concat([train_fea,test_fea],ignore_index=True)
g = [x for x in data.columns if x.startswith('g-')]  # gene-expression features
c = [x for x in data.columns if x.startswith('c-')]  # cell-viability features

In [None]:
# Rank-gauss every g-/c- feature: map it through its empirical quantiles onto a normal distribution
for col in g + c:
    sel = QuantileTransformer(n_quantiles=1000,random_state=0,output_distribution='normal')
    data[col] = sel.fit_transform(data[col].to_numpy().reshape(-1,1)).ravel()
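`QuantileTransformer(output_distribution='normal')` maps each value through the feature's empirical CDF and then through the inverse standard-normal CDF, a transform often called "rank-gauss" on Kaggle. The simplified sketch below shows the same idea on a heavily skewed column. It uses exact ranks instead of scikit-learn's interpolation over `n_quantiles` quantiles, and Python's `statistics.NormalDist` for the inverse CDF; both are simplifications assumed for this illustration.

```python
import numpy as np
from statistics import NormalDist

def rank_gauss(x):
    """Map values to standard-normal scores through their ranks."""
    x = np.asarray(x, dtype=float)
    ranks = np.argsort(np.argsort(x))            # rank of each value, 0..n-1
    u = (ranks + 0.5) / len(x)                   # empirical CDF values strictly inside (0, 1)
    inv = NormalDist().inv_cdf
    return np.array([inv(v) for v in u])

rng = np.random.default_rng(0)
skewed = rng.exponential(scale=1.0, size=1000)   # long right tail: mean > median
z = rank_gauss(skewed)
print(np.mean(skewed), np.median(skewed))
print(np.mean(z), np.median(z))                  # both essentially 0 after the transform
```

The transform is monotone, so the ordering of the samples is preserved while the long tail is pulled into a symmetric, normal-shaped distribution.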

In [None]:
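# Append PCA summaries of the c- (15 components) and g- (50 components) blocks.
# Prefixing the column names avoids duplicate integer column labels after concat.
pca_c = PCA(n_components=15)
extra_c = pd.DataFrame(pca_c.fit_transform(data[c])).add_prefix('pca_c_')
pca_g = PCA(n_components=50)
extra_g = pd.DataFrame(pca_g.fit_transform(data[g])).add_prefix('pca_g_')
data = pd.concat((data,extra_c,extra_g),axis=1)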
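The PCA cell appends low-dimensional summaries of the strongly correlated `g-` and `c-` blocks. PCA centres the data and projects it onto the top right-singular vectors; the NumPy sketch below is a simplified stand-in for `sklearn.decomposition.PCA` (which additionally handles solver choice and a sign convention for the components). On made-up data where 20 columns are driven by only 2 latent factors, two components capture almost all of the variance.

```python
import numpy as np

def pca_np(X, n_components):
    """Return projected scores and explained-variance ratios via SVD of the centred data."""
    Xc = X - X.mean(axis=0)
    U, S, Vt = np.linalg.svd(Xc, full_matrices=False)
    scores = Xc @ Vt[:n_components].T
    ratio = S**2 / np.sum(S**2)
    return scores, ratio[:n_components]

rng = np.random.default_rng(0)
latent = rng.normal(size=(500, 2))                            # 2 hidden factors
loadings = rng.normal(size=(2, 20))
X = latent @ loadings + 0.05 * rng.normal(size=(500, 20))     # 20 correlated columns plus small noise
scores, ratio = pca_np(X, n_components=2)
print(scores.shape, ratio.sum())
```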

In [None]:
!pip install umap-learn
from umap import UMAP

# Append UMAP embeddings of the c- (15 dims) and g- (50 dims) blocks, with prefixed column names
umap_c = UMAP(random_state=256,n_components=15)
extra_c = pd.DataFrame(umap_c.fit_transform(data[c])).add_prefix('umap_c_')
umap_g = UMAP(random_state=256,n_components=50)
extra_g = pd.DataFrame(umap_g.fit_transform(data[g])).add_prefix('umap_g_')
data = pd.concat((data,extra_c,extra_g),axis=1)



In [None]:
train_df = data.iloc[:len(train_fea),:].values
test_df= data.iloc[len(train_fea):,:].values

In [None]:
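# The folds are split with a constant dummy label (np.zeros in the training loop),
# so stratification has no effect and this behaves like an unshuffled 20-fold KFold
kfold = StratifiedKFold(n_splits=20)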

In [None]:
loss_func = nn.BCELoss()

In [None]:
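# Randomly permute the feature columns after the three categorical ones, using the
# same permutation for train and test so that the columns stay aligned
ra = np.arange(3,train_df.shape[1])
np.random.shuffle(ra)
train_df[:,3:] = train_df[:,ra]
test_df[:,3:] = test_df[:,ra]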
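The shuffle permutes only the feature columns after the three categorical ones, and it reuses the same index array `ra` for the train and test matrices, so column j still refers to the same feature in both. A small NumPy check with toy arrays (the names mirror the cell; the sizes are made up) confirms that property. Note that `train[:, ra]` uses advanced indexing, which returns a copy, so assigning it back into `train[:, 3:]` is safe.

```python
import numpy as np

rng = np.random.default_rng(0)
train = np.tile(np.arange(8, dtype=float), (4, 1))   # column j holds the value j in every row
test = np.tile(np.arange(8, dtype=float), (2, 1))

ra = np.arange(3, train.shape[1])
rng.shuffle(ra)
train[:, 3:] = train[:, ra]     # advanced indexing copies, so the in-place assignment is safe
test[:, 3:] = test[:, ra]       # same permutation for the test matrix

print(train[0], test[0])        # identical column order in both matrices
```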

In [None]:
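# Categorical inputs for the embeddings: column 0 (cp_type) and column 1 (cp_time).
# Column 2 (cp_dose) is not used, since the continuous inputs start at column 3.
cat_col = [0,1]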

In [None]:
seed_everything(1006)
submission.iloc[:,1:]=0
n_folds = kfold.get_n_splits()
for train,test in kfold.split(train_df,np.zeros(len(train_df))):
    batch_size=512
    sparse_constant = 0  # weight of the mask-entropy regulariser; 0 disables it
    model = TabNetWithEmbed(train_df.shape[1]-1,train_tar.shape[1],n_d=128,n_a=16,n_shared=1,n_ind=4,n_steps=3,relax=1.5,vbs=64)
    model.to(device)
    torch.cuda.empty_cache()
    optimizer = optim.Adam(model.parameters(),lr=0.007809719000164987,weight_decay=0.00001)
    sched = optim.lr_scheduler.ReduceLROnPlateau(optimizer,factor=0.1,patience=3,verbose=True)
    train_dataset = DrugData(train_df[train],train_tar[train])
    valid_dataset = DrugData(train_df[test],train_tar[test],False)
    train_loader = DataLoader(train_dataset,batch_size=batch_size,num_workers=4,shuffle=True)
    valid_loader = DataLoader(valid_dataset,batch_size=batch_size,num_workers=4,shuffle=False)
    losses=[]
    valid_losses = []
    train_losses = []
    t = time.time()
    for x in range(24):
        # Training pass
        model.train()
        train_loss=0.
        for inp,tar in train_loader:
            model.zero_grad()
            out,l = model(inp[:,cat_col].long(),inp[:,3:])
            loss = loss_func(out,tar.to(device))  # + l*sparse_constant (disabled)
            loss.backward()
            optimizer.step()
            train_loss+=loss.item()*tar.size(0)
        # Validation pass: eval mode (BatchNorm uses its running statistics), no gradients
        model.eval()
        valid_loss=0.
        with torch.no_grad():
            for inp,tar in valid_loader:
                out,_ = model(inp[:,cat_col].long(),inp[:,3:])
                loss = loss_func(out,tar.to(device))
                valid_loss += loss.item()*tar.size(0)
        losses.append(valid_loss/len(valid_dataset))
        valid_losses.append(losses[-1])
        train_losses.append(train_loss/len(train_dataset))
        print('%d epoch, %.8f valid_loss, %.8f training_loss %fsec time'% (x+1, losses[-1], train_loss/len(train_dataset), (time.time() - t)))
        sched.step(losses[-1])
        t = time.time()
    print("completed training one fold -------------- ")
    model.eval()
    with torch.no_grad():
        preds = model(torch.from_numpy(test_df[:,cat_col]).long(),torch.from_numpy(test_df[:,3:]).float())[0]
    # Average the test predictions over folds: each fold contributes 1/n_folds
    submission.iloc[:,1:] += preds.cpu().numpy() / n_folds


This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.



1 epoch, 0.01925417 valid_loss, 0.05950320 training_loss 8.809434sec time
2 epoch, 0.01831277 valid_loss, 0.01888210 training_loss 8.653638sec time
3 epoch, 0.01761686 valid_loss, 0.01796183 training_loss 8.874561sec time
4 epoch, 0.01711001 valid_loss, 0.01706849 training_loss 8.690574sec time
5 epoch, 0.01646078 valid_loss, 0.01642215 training_loss 8.690804sec time
6 epoch, 0.01649438 valid_loss, 0.01607803 training_loss 8.756106sec time
7 epoch, 0.01595966 valid_loss, 0.01580850 training_loss 8.805627sec time
8 epoch, 0.01605164 valid_loss, 0.01554960 training_loss 8.814434sec time
9 epoch, 0.01586758 valid_loss, 0.01542189 training_loss 8.763968sec time
10 epoch, 0.01566753 valid_loss, 0.01522294 training_loss 8.800066sec time
11 epoch, 0.01582895 valid_loss, 0.01516107 training_loss 8.755854sec time
12 epoch, 0.01568054 valid_loss, 0.01502120 training_loss 8.703819sec time
13 epoch, 0.01559140 valid_loss, 0.01491379 training_loss 8.760062sec time
14 epoch, 0.01563122 valid_loss, 0
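Each of the 20 folds trains its own model, and the test predictions of all folds are combined in `submission`. Because every model outputs probabilities, the combined prediction should be their mean rather than their sum: a sum of 20 probabilities can exceed 1 and is no longer a valid input for the log loss. A minimal NumPy sketch with made-up per-fold predictions illustrates the accumulate-and-divide pattern.

```python
import numpy as np

rng = np.random.default_rng(0)
n_folds, n_rows, n_targets = 20, 3, 4                                   # toy sizes
fold_preds = rng.uniform(0.0, 1.0, size=(n_folds, n_rows, n_targets))  # stand-in model outputs

total = np.zeros((n_rows, n_targets))
for p in fold_preds:
    total += p / n_folds                 # accumulate the running mean, not the sum

print(total.max() <= 1.0, np.allclose(total, fold_preds.mean(axis=0)))
```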

In [None]:
losses

[0.01955430627557911,
 0.01862123093935622,
 0.017992009635732956,
 0.017171083108968095,
 0.016643987339334327,
 0.016531902144686514,
 0.016118551026873228,
 0.01611169914738471,
 0.015719288448263116,
 0.015592357763597945,
 0.015905841045519883,
 0.015454335668820793,
 0.015428442558070191,
 0.01542883472783225,
 0.01530035108703525,
 0.015491006911552254,
 0.015324883094104399,
 0.015497832480786728,
 0.015269148376371179,
 0.015505443225387766,
 0.01519961832639049,
 0.015453900891442258,
 0.015230949244955007,
 0.01527463062940275]
