In [1]:
import tensorflow as tf
from tensorflow.keras import layers, models

class GLU(tf.keras.layers.Layer):
    """Gated linear unit layer.

    A single dense projection produces ``2 * units`` features, which are cut
    into two equal halves along the last axis. The first half is returned
    multiplied element-wise by the sigmoid of the second half.

    Args:
        units: Width of each half, i.e. the size of the layer's last output axis.
    """

    def __init__(self, units):
        super().__init__()
        self.units = units

    def build(self, input_shape):
        # One projection supplies both the value half and the gate half.
        self.dense = tf.keras.layers.Dense(self.units * 2)

    def call(self, inputs):
        projected = self.dense(inputs)
        values, gates = tf.split(projected, 2, axis=-1)
        return values * tf.sigmoid(gates)

class GhostBatchNormalization(tf.keras.layers.Layer):
    """Batch normalization applied separately to fixed-size slices of a batch.

    During training the batch is cut along axis 0 into consecutive "ghost
    batches" of ``virtual_batch_size`` samples (the final slice may be
    smaller), and the shared ``BatchNormalization`` layer is applied to each
    slice on its own. Outside of training the whole batch is normalized in a
    single call.

    Args:
        virtual_batch_size: Number of samples per ghost batch. ``None``
            disables slicing, so the layer acts as plain batch normalization.
        momentum: Momentum forwarded to ``tf.keras.layers.BatchNormalization``.
    """

    def __init__(self, virtual_batch_size=64, momentum=0.99):
        super().__init__()
        self.virtual_batch_size = virtual_batch_size
        self.momentum = momentum
        self.bn = tf.keras.layers.BatchNormalization(momentum=self.momentum)

    def call(self, inputs, training=None):
        """Normalize ``inputs``, per ghost batch when training.

        Args:
            inputs: Tensor whose axis 0 is the batch axis.
            training: Training-mode flag; expected to be a Python bool or
                ``None`` (``None`` is handled like inference here and passed
                through to the wrapped layer unchanged).

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        # Slicing is only meaningful when batch statistics are being computed;
        # skipping it otherwise also lifts any constraint on the batch size.
        if self.virtual_batch_size is None or not training:
            return self.bn(inputs, training=training)

        batch_size = inputs.shape[0]
        # tf.split needs the slice sizes at graph-construction time. When the
        # batch dimension is not statically known (or already fits within one
        # ghost batch) the whole batch is normalized at once.
        if batch_size is None or batch_size <= self.virtual_batch_size:
            return self.bn(inputs, training=training)

        # Pass explicit slice *sizes*: a bare integer would be interpreted by
        # tf.split as the *number* of equal slices instead.
        full_chunks, remainder = divmod(batch_size, self.virtual_batch_size)
        size_splits = [self.virtual_batch_size] * full_chunks
        if remainder:
            size_splits.append(remainder)

        chunks = tf.split(inputs, num_or_size_splits=size_splits, axis=0)
        normalized = [self.bn(chunk, training=training) for chunk in chunks]
        return tf.concat(normalized, axis=0)



In [2]:
class FeatureTransformer(tf.keras.layers.Layer):
    """Sequential stack of GLU blocks, each followed by ghost batch normalization.

    Args:
        units: Output width passed to every ``GLU`` in the stack.
        n_glus: Number of (GLU, GhostBatchNormalization) pairs to chain.
    """

    def __init__(self, units, n_glus):
        super().__init__()
        self.units = units
        self.n_glus = n_glus
        # All GLUs are created before the normalization layers.
        self.glu_layers = [GLU(units) for _ in range(n_glus)]
        self.bn_layers = [GhostBatchNormalization() for _ in range(n_glus)]

    def call(self, inputs):
        hidden = inputs
        for idx in range(len(self.glu_layers)):
            gated = self.glu_layers[idx](hidden)
            hidden = self.bn_layers[idx](gated)
        return hidden

class AttentiveTransformer(tf.keras.layers.Layer):
    """Produces a softmax-normalized mask over ``units`` features.

    The input goes through a linear dense layer and a ghost batch
    normalization layer; the result is normalized with a softmax over the
    last axis, so each row of the output sums to 1.

    Args:
        units: Size of the last axis of the produced mask.
    """

    def __init__(self, units):
        super().__init__()
        self.units = units
        self.fc = tf.keras.layers.Dense(self.units, activation=None)
        self.bn = GhostBatchNormalization()

    def call(self, inputs):
        x = self.fc(inputs)
        x = self.bn(x)
        # Use the stateless op rather than constructing a new Keras Softmax
        # layer object on every forward pass.
        return tf.nn.softmax(x, axis=-1)


In [10]:

class TabNetEncoder(tf.keras.Model):
    """Keras model that wraps a single ``FeatureTransformer``.

    Args:
        feature_dim: Width handed to the feature transformer's GLU blocks.
        n_glus: Number of GLU blocks in the feature transformer.
    """

    def __init__(self, feature_dim, n_glus):
        super().__init__()
        self.feature_transformer = FeatureTransformer(feature_dim, n_glus)

    def call(self, inputs):
        # The encoder output is exactly the feature transformer's output.
        return self.feature_transformer(inputs)

# Batch dimension left unspecified; 42 features per sample.
input_shape = (None, 42)

# Instantiate the encoder and create its weights for that input shape.
encoder = TabNetEncoder(feature_dim=input_shape[-1], n_glus=4)
encoder.build(input_shape)

# Print the layer / parameter overview.
encoder.summary()


Model: "tab_net_encoder_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 feature_transformer_4 (Fea  multiple                  15120     
 tureTransformer)                                                
                                                                 
Total params: 15120 (59.06 KB)
Trainable params: 14784 (57.75 KB)
Non-trainable params: 336 (1.31 KB)
_________________________________________________________________


In [11]:
class TabNetDecoder(tf.keras.Model):
    """Keras model that wraps a single ``FeatureTransformer``.

    Args:
        feature_dim: Width handed to the feature transformer's GLU blocks.
        n_glus: Number of GLU blocks in the feature transformer.
    """

    def __init__(self, feature_dim, n_glus):
        super().__init__()
        self.feature_transformer = FeatureTransformer(feature_dim, n_glus)

    def call(self, inputs):
        # The decoder output is exactly the feature transformer's output.
        return self.feature_transformer(inputs)

# Instantiate the decoder with the same dimensions as the encoder.
decoder = TabNetDecoder(
    feature_dim=42,
    n_glus=4,
)

# Create the weights for the shared input shape, then print the overview.
decoder.build(input_shape)
decoder.summary()

Model: "tab_net_decoder_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 feature_transformer_5 (Fea  multiple                  15120     
 tureTransformer)                                                
                                                                 
Total params: 15120 (59.06 KB)
Trainable params: 14784 (57.75 KB)
Non-trainable params: 336 (1.31 KB)
_________________________________________________________________


In [12]:
class TabNet(tf.keras.Model):
    """Multi-step model built from an encoder, an attentive mask, and a decoder.

    The input is encoded once. Then, for ``n_steps`` iterations, a mask is
    computed from the current representation, the representation is multiplied
    by that mask and passed through the decoder. The same attentive
    transformer and decoder instances are reused at every step.

    Args:
        encoder: Model applied once to the raw inputs.
        decoder: Model applied to the masked representation at every step.
        feature_dim: Width of the mask produced by the attentive transformer.
        n_steps: Number of mask/decode iterations.
    """

    def __init__(self, encoder, decoder, feature_dim, n_steps):
        super().__init__()
        self.n_steps = n_steps
        self.feature_dim = feature_dim
        # Sublayers are attached in the order encoder -> attentive -> decoder.
        self.encoder = encoder
        self.attentive_transformer = AttentiveTransformer(self.feature_dim)
        self.decoder = decoder

    def call(self, inputs):
        """Run all steps.

        Returns:
            A tuple ``(outputs, masks)``: the sum of the decoder outputs over
            all steps, and the list of per-step masks.
        """
        hidden = self.encoder(inputs)
        step_outputs, masks = [], []

        for _ in range(self.n_steps):
            step_mask = self.attentive_transformer(hidden)
            masks.append(step_mask)
            # The decoder output becomes the representation for the next step.
            hidden = self.decoder(hidden * step_mask)
            step_outputs.append(hidden)

        summed = tf.reduce_sum(step_outputs, axis=0)
        return summed, masks

# Batch dimension left unspecified; 42 features per sample.
input_shape = (None, 42)

# Combine the previously built encoder and decoder into a 5-step TabNet.
tabnet = TabNet(
    encoder=encoder,
    decoder=decoder,
    feature_dim=42,
    n_steps=5,
)

# Wrap it in a functional model; only the summed output is exposed, the
# per-step masks are computed but not part of the model outputs.
inputs = tf.keras.Input(shape=input_shape[1:])
outputs, masks = tabnet(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.compile(optimizer="adam", loss="mse")

# Display the model's architecture
model.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_4 (InputLayer)        [(None, 42)]              0         
                                                                 
 tab_net_3 (TabNet)          ((None, 42),              32214     
                              [(None, 42),                       
                              (None, 42),                        
                              (None, 42),                        
                              (None, 42),                        
                              (None, 42)])                       
                                                                 
Total params: 32214 (125.84 KB)
Trainable params: 31458 (122.88 KB)
Non-trainable params: 756 (2.95 KB)
_________________________________________________________________
