In [37]:
%load_ext autoreload
%autoreload 2 
#default_exp model

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [38]:
#exporti
from fastai.tabular.all import * 
from tabnet.sparsemax import Sparsemax

In [3]:
! pip install -e ../../libraries/fastai ../../libraries/fastcore 

Processing /home/jupyter/libraries/fastcore
Obtaining file:///home/jupyter/libraries/fastai
Building wheels for collected packages: fastcore
  Building wheel for fastcore (setup.py) ... [?25ldone
[?25h  Created wheel for fastcore: filename=fastcore-1.1.3-py3-none-any.whl size=43299 sha256=4ddad464f8e42683a92d4fd7a5da5c7c50ef95edbc21e89baee6116c30eb48bd
  Stored in directory: /tmp/pip-ephem-wheel-cache-08qajf4h/wheels/f5/2c/77/44ecc13f884dd9f69af0b9d1cffa0421c2b261c7e7ff267cbd
Successfully built fastcore
Installing collected packages: fastcore, fastai
  Attempting uninstall: fastcore
    Found existing installation: fastcore 1.1.3
    Uninstalling fastcore-1.1.3:
      Successfully uninstalled fastcore-1.1.3
  Attempting uninstall: fastai
    Found existing installation: fastai 2.0.16
    Uninstalling fastai-2.0.16:
      Successfully uninstalled fastai-2.0.16
  Running setup.py develop for fastai
Successfully installed fastai fastcore-1.1.3


# Creating the model 


In [39]:
class GBN(torch.nn.Module):
    """
    Ghost Batch Normalization (https://arxiv.org/abs/1705.08741).

    The batch is cut along dim 0 into at most
    ``ceil(batch_size / virtual_batch_size)`` pieces, the same batch-norm
    layer is applied to every piece on its own, and the normalised pieces
    are concatenated back in their original order.
    """

    def __init__(self, input_dim, virtual_batch_size=128, momentum=0.02):
        super().__init__()
        self.input_dim, self.virtual_batch_size = input_dim, virtual_batch_size
        self.bn = BatchNorm(input_dim, momentum=momentum, ndim=1)

    def forward(self, x):
        # Number of "ghost" batches requested for this input.
        n_chunks = int(np.ceil(x.shape[0] / self.virtual_batch_size))
        normalized = []
        for piece in torch.chunk(x, n_chunks, dim=0):
            normalized.append(self.bn(piece))
        return torch.cat(normalized, dim=0)


In [59]:
class TabNet(Module):
    """
    Sequential-attention model for tabular data, assembled from the blocks in this notebook.

    Categorical columns are embedded, concatenated with the continuous columns and
    batch-normalised. At each of the `n_steps` decision steps an `AttentiveTransformer`
    produces a mask `M` over the input features, the masked input is passed through a
    `FeatureTransformer`, and the ReLU of its first `n_d` outputs is added to a running
    sum. A final linear layer maps that sum to `out_features`.

    Args:
        emb_szs: iterable of `(ni, nf)` pairs, one per categorical column, passed to `Embedding(ni, nf)`.
        n_cont: number of continuous input columns.
        out_features: width of the model output.
        n_d: number of feature-transformer outputs used for the decision (summed into the output).
        n_a: number of feature-transformer outputs fed to the next attentive transformer.
        n_steps: number of decision steps.
        n_shared_ft_blocks: number of shared blocks created *in addition to* the first,
            input-projecting shared block (so `1 + n_shared_ft_blocks` shared blocks in total).
        n_independent_ft_blocks: number of step-specific blocks in every `FeatureTransformer`.
        gamma: coefficient in the prior update `prior = (gamma - M) * prior`.
        virtual_batch_size: virtual batch size forwarded to the transformer blocks.
    """

    def __init__(self, emb_szs, n_cont, out_features, n_d, n_a, n_steps, n_shared_ft_blocks=2,
                         n_independent_ft_blocks=2, gamma=1.5, virtual_batch_size=128):
        store_attr()

        self.embeds = nn.ModuleList([Embedding(ni, nf) for ni,nf in emb_szs])
        self.n_emb = sum(e.embedding_dim for e in self.embeds)
        self.n_features = self.n_emb + n_cont

        intermediate_features = n_d + n_a
        # Kept as a plain list on purpose: the very same block objects are handed to every
        # FeatureTransformer below, which registers them inside its own nn.Sequential.
        self.shared_ft_blocks = [FeatureTransformerBlock(self.n_features, intermediate_features, virtual_batch_size,
                                                         is_first=True)] + \
                                [FeatureTransformerBlock(intermediate_features, intermediate_features, virtual_batch_size,
                                                         is_first=False) for _ in range(n_shared_ft_blocks)]

        self.initial_bn = BatchNorm(self.n_features, ndim=1)
        self.initial_ft = FeatureTransformer(self.shared_ft_blocks, n_d, n_a, n_independent_ft_blocks, virtual_batch_size)

        self.att_steps = nn.ModuleList([AttentiveTransformer(n_a, self.n_features,virtual_batch_size)
                                            for i in range(self.n_steps)])
        self.ft_steps = nn.ModuleList([FeatureTransformer(self.shared_ft_blocks, n_d, n_a,
                                                          n_independent_ft_blocks, virtual_batch_size)
                                            for i in range(self.n_steps)])

        self.final_fc = nn.Linear(n_d, out_features)

    def forward(self, x_cat, x_cont):
        """Return the `(batch, out_features)` output for categorical codes `x_cat` and continuous values `x_cont`."""
        x = self._combine_cat_cont(x_cat, x_cont)

        output = 0
        x = self.initial_bn(x)
        # Only the attention half of the initial transformer output is used.
        _, a = self.initial_ft(x)

        # Take the device from `x`, the tensor the prior is actually combined with,
        # rather than from `x_cont`.
        prior = torch.ones(self.n_features, device=x.device)

        for i in range(self.n_steps):
            M = self.att_steps[i](a, prior)
            # Scale the prior down for features that received a large mask value at this step.
            prior = (self.gamma - M)*prior
            res = M * x
            d, a = self.ft_steps[i](res)
            output = output + nn.functional.relu(d)

        res = self.final_fc(output)
        return res

    def _combine_cat_cont(self, x_cat, x_cont):
        """
        Build the flat feature tensor: embedded categorical columns followed by `x_cont`.

        Raises:
            ValueError: if the model has neither embeddings nor continuous columns
                (previously this surfaced as an UnboundLocalError).
        """
        if self.n_emb == 0 and self.n_cont == 0:
            raise ValueError("TabNet needs at least one categorical or continuous input feature")
        if self.n_emb == 0:
            # No embeddings: `x_cat` is not read at all.
            return x_cont

        pieces = [e(x_cat[:,i]) for i,e in enumerate(self.embeds)]
        if self.n_cont != 0:
            pieces.append(x_cont)
        return torch.cat(pieces, 1)

In [60]:
class AttentiveTransformer(Module):
    """
    Turns the attention features `a` of the previous step into a mask over the input features.

    `a` is projected from `n_a` to `in_features` columns by a linear layer, normalised
    with `GBN`, multiplied element-wise by `prior`, and passed through `Sparsemax`.
    """

    def __init__(self, n_a, in_features, virtual_batch_size):
        store_attr()
        self.fc = nn.Linear(n_a, in_features)
        self.bn = GBN(in_features, virtual_batch_size)
        self.sparsemax = Sparsemax()

    def forward(self, a, prior):
        projected = self.bn(self.fc(a))
        return self.sparsemax(prior * projected)

In [61]:
class FeatureTransformer(Module):
    """
    Chain of the given shared blocks followed by `n_independent_ft_blocks` blocks owned by this instance.

    `forward` returns the output split column-wise into `(d, a)`: the first `n_d`
    columns and the remaining columns.
    """

    def __init__(self, shared_blocks, n_d, n_a, n_independent_ft_blocks, virtual_batch_size):
        store_attr()
        width = n_d + n_a
        own_blocks = []
        for _ in range(n_independent_ft_blocks):
            own_blocks.append(FeatureTransformerBlock(width, width, virtual_batch_size, False))
        # Shared blocks run first, then the blocks specific to this transformer.
        self.steps = nn.Sequential(*shared_blocks, *own_blocks)

    def forward(self, x):
        hidden = self.steps(x)
        split_at = self.n_d
        return hidden[:, :split_at], hidden[:, split_at:]

In [62]:
class FeatureTransformerBlock(Module):
    """
    Two stacked Linear -> GBN -> GLU sub-blocks with scaled residual connections.

    The first sub-block maps `in_features` to `intermediate_features`; its residual
    connection is skipped when `is_first` is true. The second sub-block always adds
    its input back. Every residual sum is multiplied by `norm`.
    """

    def __init__(self, in_features, intermediate_features, virtual_batch_size, is_first, norm=math.sqrt(0.5)):
        store_attr()

        self.block1 = self._create_inner_block(is_first=True)
        self.block2 = self._create_inner_block(is_first=False)

    def forward(self, x):
        first = self.block1(x)
        if not self.is_first:
            first = (first + x) * self.norm
        second = self.block2(first)
        return (second + first) * self.norm

    def _create_inner_block(self, is_first):
        """Build one Linear -> GBN -> GLU stack; `is_first` selects `in_features` as the input width."""
        width = self.intermediate_features
        fan_in = self.in_features if is_first else width
        # The linear layer emits 2*width columns because GLU halves its last dimension.
        layers = [
            nn.Linear(fan_in, 2 * width),
            GBN(2 * width, self.virtual_batch_size),
            nn.GLU(),
        ]
        return nn.Sequential(*layers)
        
        
    

In [63]:
# Small synthetic sizes shared by the shape smoke tests in this cell and the next one.
N = 3  # rows in the random batch
n_features = 32  # input columns
n_d = n_a = 7 
n_steps = 3
out_features = 10
virtual_batch_size = 5

# A block built with is_first=True should map n_features columns to n_d+n_a columns.
a = torch.randn((N, n_features))
ft = FeatureTransformerBlock(n_features, n_d+n_a, virtual_batch_size, is_first=True)
test_eq(ft(a).shape, (N, n_d+n_a))

In [64]:
# End-to-end shape check with no embeddings (emb_szs=[]): all n_features columns are continuous.
a = torch.randn((N, n_features))
tabnet = TabNet([], n_features, out_features, n_d, n_a, n_steps, virtual_batch_size=virtual_batch_size)
# With no embeddings the first argument (x_cat) is never read, so `a` is passed for both inputs.
test_eq(tabnet(a, a).shape, (N, out_features))

# Testing it out 

In [65]:
# Relative directory passed as `dest` to untar_data when the dataset is downloaded below.
data_dir = Path('./data')

In [66]:
def extract_gzip(file, dest=None):
    """
    Decompress the single gzip file `file` into the directory `dest`.

    The output file is named after `file` with its last suffix removed
    (e.g. `covtype.data.gz` -> `covtype.data`).

    Args:
        file: path (str or Path) of the gzip file to decompress.
        dest: directory (str or Path) to write into; it is created if missing.
            Defaults to the directory that contains `file`.
    """
    import gzip
    import shutil
    from pathlib import Path

    file = Path(file)
    # The previous `dest or Path(dest)` called Path(None) when `dest` was omitted and
    # left a str `dest` unconverted, so `dest / file.stem` failed in both cases.
    dest = file.parent if dest is None else Path(dest)
    dest.mkdir(parents=True, exist_ok=True)
    with gzip.open(file, 'rb') as f_in, open(dest / file.stem, 'wb') as f_out:
        shutil.copyfileobj(f_in, f_out)

In [67]:
# Gzipped covtype data file hosted on the UCI machine-learning archive.
forest_type_url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/covtype/covtype.data.gz'
# extract_gzip replaces untar_data's default extractor because the download is a plain .gz file.
forest_path = untar_data(forest_type_url, dest=data_dir, extract_func=extract_gzip)

In [68]:
# Name given to the label column (the last column of the file, see feature_columns below).
target = "Covertype"

# Columns treated as categorical: 4 wilderness-area columns and 40 soil-type columns.
# NOTE(review): presumably 0/1 indicator columns - confirm against the dataset description.
cat_names = [
    "Wilderness_Area1", "Wilderness_Area2", "Wilderness_Area3",
    "Wilderness_Area4", "Soil_Type1", "Soil_Type2", "Soil_Type3", "Soil_Type4",
    "Soil_Type5", "Soil_Type6", "Soil_Type7", "Soil_Type8", "Soil_Type9",
    "Soil_Type10", "Soil_Type11", "Soil_Type12", "Soil_Type13", "Soil_Type14",
    "Soil_Type15", "Soil_Type16", "Soil_Type17", "Soil_Type18", "Soil_Type19",
    "Soil_Type20", "Soil_Type21", "Soil_Type22", "Soil_Type23", "Soil_Type24",
    "Soil_Type25", "Soil_Type26", "Soil_Type27", "Soil_Type28", "Soil_Type29",
    "Soil_Type30", "Soil_Type31", "Soil_Type32", "Soil_Type33", "Soil_Type34",
    "Soil_Type35", "Soil_Type36", "Soil_Type37", "Soil_Type38", "Soil_Type39",
    "Soil_Type40"
]

# Columns treated as continuous.
cont_names = [
    "Elevation", "Aspect", "Slope", "Horizontal_Distance_To_Hydrology",
    "Vertical_Distance_To_Hydrology", "Horizontal_Distance_To_Roadways",
    "Hillshade_9am", "Hillshade_Noon", "Hillshade_3pm",
    "Horizontal_Distance_To_Fire_Points"
]

# Full header for the header-less CSV; the order here must match the column order in the file.
feature_columns = (
    cont_names + cat_names + [target])

In [69]:
# The file has no header row, so column names are supplied explicitly.
# NOTE(review): `df.head()` is not the last expression of the cell, so its result is not displayed.
df = pd.read_csv(forest_path, header=None, names=feature_columns); df.head()
# Rebinds `df` to the output of df_shrink, so re-running this cell alone shrinks the already-shrunk frame.
df = df_shrink(df)

In [70]:
# Preprocessing steps handed to TabularPandas in the next cell.
procs = [Categorify, FillMissing, Normalize]
# Random train/validation split over the row indices of df, with 0.05 passed as the validation fraction.
# NOTE(review): no seed is passed, so the split presumably differs between runs - confirm.
splits = RandomSplitter(0.05)(range_of(df))

In [71]:
# Wrap the frame with its preprocessing, column roles, categorical target and the split.
to = TabularPandas(df, procs, cat_names, cont_names, y_names=target, y_block = CategoryBlock(), splits=splits)
# Batch size is 64*64*4 = 16384 rows.
dls = to.dataloaders(bs=64*64*4)

In [72]:
# Embedding sizes come from the processed data; dls.c is passed as the model's out_features.
model = TabNet(get_emb_sz(to), len(to.cont_names), dls.c, n_d=64, n_a=64, n_steps=5, virtual_batch_size=256)
# Adam with weight decay 0.01 and eps 1e-5.
opt_func = partial(Adam, wd=0.01, eps=1e-5)
learn = Learner(dls, model, CrossEntropyLossFlat(), opt_func=opt_func, lr=3e-2, metrics=[accuracy])

In [73]:
# Train for 10 epochs. In the recorded output below, train_loss falls steadily while valid_loss
# and accuracy (about 0.64-0.67) do not improve after the first epochs.
learn.fit_one_cycle(10)

epoch,train_loss,valid_loss,accuracy,time
0,1.338366,0.884798,0.67105,00:24
1,0.98464,0.768904,0.67222,00:24
2,0.854257,0.862537,0.645473,00:24
3,0.783051,0.769091,0.66475,00:24
4,0.741406,0.788258,0.664131,00:24
5,0.712884,0.787519,0.662788,00:24
6,0.693834,0.859104,0.65673,00:24
7,0.681376,0.887326,0.644819,00:24
8,0.67375,0.852126,0.640516,00:24
9,0.668546,0.864051,0.649776,00:24


# Export

In [None]:
# Export the notebook's flagged cells to the library via nbdev.
# NOTE(review): presumably writes the module named by `#default_exp model` in the first cell - confirm.
from nbdev.export import notebook2script
notebook2script()