feat: add entmax as parameter
Optimox authored and eduardocarvp committed Jun 18, 2020
1 parent d6fbf90 commit 96c8a74
Showing 4 changed files with 40 additions and 15 deletions.
3 changes: 3 additions & 0 deletions README.md
@@ -141,6 +141,9 @@ You can also get comfortable with how the code works by playing with the **noteb
- device_name : str (default='auto')
'cpu' for cpu training, 'gpu' for gpu training, 'auto' to automatically detect gpu.

- mask_type: str (default='sparsemax')
Either "sparsemax" or "entmax" : this is the masking function to use for selecting features

## Fit parameters

- X_train : np.array
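For reference, a minimal sketch of how the `mask_type` option documented above would be passed from user code, assuming the `TabNetClassifier` wrapper exported by `pytorch_tabnet.tab_model` (the class name and the commented `fit` call are taken from the example notebook and are assumptions, not part of this hunk):

```python
import torch
from pytorch_tabnet.tab_model import TabNetClassifier  # assumed import path

# Use the entmax masking function instead of the default sparsemax
# when the attentive transformer selects features at each step.
clf = TabNetClassifier(
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=2e-2),
    mask_type='entmax',  # or 'sparsemax' (the default)
)

# Training then proceeds unchanged, e.g.:
# clf.fit(X_train, y_train, X_valid, y_valid)
```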
4 changes: 3 additions & 1 deletion census_example.ipynb
@@ -154,7 +154,9 @@
" optimizer_params=dict(lr=2e-2),\n",
" scheduler_params={\"step_size\":50, # how to use learning rate scheduler\n",
" \"gamma\":0.9},\n",
" scheduler_fn=torch.optim.lr_scheduler.StepLR)"
" scheduler_fn=torch.optim.lr_scheduler.StepLR,\n",
" mask_type='entmax' # \"sparsemax\"\n",
" )"
]
},
{
8 changes: 5 additions & 3 deletions pytorch_tabnet/tab_model.py
@@ -23,7 +23,8 @@ def __init__(self, n_d=8, n_a=8, n_steps=3, gamma=1.3, cat_idxs=[], cat_dims=[],
optimizer_fn=torch.optim.Adam,
optimizer_params=dict(lr=2e-2),
scheduler_params=None, scheduler_fn=None,
device_name='auto'):
device_name='auto',
mask_type="sparsemax"):
""" Class for TabNet model
Parameters
@@ -51,7 +52,7 @@ def __init__(self, n_d=8, n_a=8, n_steps=3, gamma=1.3, cat_idxs=[], cat_dims=[],
self.device_name = device_name
self.scheduler_params = scheduler_params
self.scheduler_fn = scheduler_fn

self.mask_type = mask_type
self.seed = seed
torch.manual_seed(self.seed)
# Defining device
@@ -133,7 +134,8 @@ def fit(self, X_train, y_train, X_valid=None, y_valid=None, loss_fn=None,
epsilon=self.epsilon,
virtual_batch_size=self.virtual_batch_size,
momentum=self.momentum,
device_name=self.device_name).to(self.device)
device_name=self.device_name,
mask_type=self.mask_type).to(self.device)

self.reducing_matrix = create_explain_matrix(self.network.input_dim,
self.network.cat_emb_dim,
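One detail worth noting from the hunks above: the wrapper only stores `mask_type` in `__init__` and forwards it when `fit` builds the `TabNet` network, so an unsupported value is rejected at fit time rather than at construction. A hedged sketch of how that surfaces, assuming the `TabNetClassifier` subclass and that `fit` accepts plain numpy arrays as in the signature shown above:

```python
import numpy as np
from pytorch_tabnet.tab_model import TabNetClassifier  # assumed import path

X_train = np.random.rand(256, 10).astype(np.float32)
y_train = np.random.randint(0, 2, 256)

clf = TabNetClassifier(mask_type="softmax")  # unsupported value, but no error yet

try:
    # The network (and its AttentiveTransformer) is only built here,
    # so this is where the unsupported mask_type raises.
    clf.fit(X_train, y_train)
except NotImplementedError as err:
    print(err)
```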
40 changes: 29 additions & 11 deletions pytorch_tabnet/tab_network.py
@@ -43,7 +43,8 @@ def __init__(self, input_dim, output_dim,
n_d=8, n_a=8,
n_steps=3, gamma=1.3,
n_independent=2, n_shared=2, epsilon=1e-15,
virtual_batch_size=128, momentum=0.02):
virtual_batch_size=128, momentum=0.02,
mask_type="sparsemax"):
"""
Defines main part of the TabNet network without the embedding layers.
@@ -70,6 +71,8 @@ def __init__(self, input_dim, output_dim,
Number of independent GLU layer in each GLU block (default 2)
- epsilon: float
Avoid log(0), this should be kept very low
- mask_type: str
Either "sparsemax" or "entmax" : this is the masking function to use
"""
super(TabNetNoEmbeddings, self).__init__()
self.input_dim = input_dim
@@ -82,7 +85,7 @@ def __init__(self, input_dim, output_dim,
self.n_independent = n_independent
self.n_shared = n_shared
self.virtual_batch_size = virtual_batch_size

self.mask_type = mask_type
self.initial_bn = BatchNorm1d(self.input_dim, momentum=0.01)

if self.n_shared > 0:
@@ -113,7 +116,8 @@ def __init__(self, input_dim, output_dim,
momentum=momentum)
attention = AttentiveTransformer(n_a, self.input_dim,
virtual_batch_size=self.virtual_batch_size,
momentum=momentum)
momentum=momentum,
mask_type=self.mask_type)
self.feat_transformers.append(transformer)
self.att_transformers.append(attention)

@@ -179,7 +183,8 @@ class TabNet(torch.nn.Module):
def __init__(self, input_dim, output_dim, n_d=8, n_a=8,
n_steps=3, gamma=1.3, cat_idxs=[], cat_dims=[], cat_emb_dim=1,
n_independent=2, n_shared=2, epsilon=1e-15,
virtual_batch_size=128, momentum=0.02, device_name='auto'):
virtual_batch_size=128, momentum=0.02, device_name='auto',
mask_type="sparsemax"):
"""
Defines TabNet network
@@ -212,6 +217,8 @@ def __init__(self, input_dim, output_dim, n_d=8, n_a=8,
Number of independent GLU layer in each GLU block (default 2)
- n_shared : int
Number of shared GLU layers in each GLU block (default 2)
- mask_type: str
Either "sparsemax" or "entmax" : this is the masking function to use
- epsilon: float
Avoid log(0), this should be kept very low
"""
@@ -229,6 +236,7 @@ def __init__(self, input_dim, output_dim, n_d=8, n_a=8,
self.epsilon = epsilon
self.n_independent = n_independent
self.n_shared = n_shared
self.mask_type = mask_type

if self.n_steps <= 0:
raise ValueError("n_steps should be a positive integer.")
@@ -240,7 +248,7 @@ def __init__(self, input_dim, output_dim, n_d=8, n_a=8,
self.post_embed_dim = self.embedder.post_embed_dim
self.tabnet = TabNetNoEmbeddings(self.post_embed_dim, output_dim, n_d, n_a, n_steps,
gamma, n_independent, n_shared, epsilon,
virtual_batch_size, momentum)
virtual_batch_size, momentum, mask_type)

# Defining device
if device_name == 'auto':
@@ -261,7 +269,10 @@ def forward_masks(self, x):


class AttentiveTransformer(torch.nn.Module):
def __init__(self, input_dim, output_dim, virtual_batch_size=128, momentum=0.02):
def __init__(self, input_dim, output_dim,
virtual_batch_size=128,
momentum=0.02,
mask_type="sparsemax"):
"""
Initialize an attention transformer.
@@ -273,23 +284,30 @@ def __init__(self, input_dim, output_dim, virtual_batch_size=128, momentum=0.02)
Output size
- momentum : float
Float value between 0 and 1 which will be used for momentum in batch norm
- mask_type: str
Either "sparsemax" or "entmax" : this is the masking function to use
"""
super(AttentiveTransformer, self).__init__()
self.fc = Linear(input_dim, output_dim, bias=False)
initialize_non_glu(self.fc, input_dim, output_dim)
self.bn = GBN(output_dim, virtual_batch_size=virtual_batch_size,
momentum=momentum)

# Sparsemax
self.sp_max = sparsemax.Sparsemax(dim=-1)
# Entmax
# self.sp_max = sparsemax.Entmax15(dim=-1)
if mask_type == "sparsemax":
# Sparsemax
self.selector = sparsemax.Sparsemax(dim=-1)
elif mask_type == "entmax":
# Entmax
self.selector = sparsemax.Entmax15(dim=-1)
else:
raise NotImplementedError("Please choose either sparsemax" +
"or entmax as masktype")

def forward(self, priors, processed_feat):
x = self.fc(processed_feat)
x = self.bn(x)
x = torch.mul(x, priors)
x = self.sp_max(x)
x = self.selector(x)
return x


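To make the difference between the two masking functions concrete, here is a small sketch comparing their outputs on toy attention logits, assuming the `sparsemax` module used by `tab_network.py` is importable as `pytorch_tabnet.sparsemax` and exposes the `Sparsemax` and `Entmax15` modules referenced above:

```python
import torch
from pytorch_tabnet import sparsemax  # assumed module path

# Toy attention logits for 5 features of a single sample.
logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])

sparse_mask = sparsemax.Sparsemax(dim=-1)(logits)  # mask_type="sparsemax" (default)
entmax_mask = sparsemax.Entmax15(dim=-1)(logits)   # mask_type="entmax"

# Both yield non-negative weights that sum to 1 and can assign exact zeros;
# sparsemax is typically the sparser of the two, while entmax 1.5 sits
# between softmax and sparsemax in how aggressively it zeroes features out.
print(sparse_mask)
print(entmax_mask)
```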
