Commit

fix: allow zero layer
Optimox authored and eduardocarvp committed May 11, 2020
1 parent 8092324 commit e3b5a04
Showing 1 changed file with 19 additions and 10 deletions.
29 changes: 19 additions & 10 deletions pytorch_tabnet/tab_network.py
@@ -203,15 +203,18 @@ def __init__(self, input_dim, output_dim, n_d=8, n_a=8,
         self.epsilon = epsilon
         self.n_independent = n_independent
         self.n_shared = n_shared
-        self.virtual_batch_size = virtual_batch_size
 
+        if self.n_steps <= 0:
+            raise ValueError("n_steps should be a positive integer.")
+        if self.n_independent == 0 and self.n_shared == 0:
+            raise ValueError("n_shared and n_independant can't be both zero.")
+
+        self.virtual_batch_size = virtual_batch_size
         self.embedder = EmbeddingGenerator(input_dim, cat_dims, cat_idxs, cat_emb_dim)
         self.post_embed_dim = self.embedder.post_embed_dim
-
         self.tabnet = TabNetNoEmbeddings(self.post_embed_dim, output_dim, n_d, n_a, n_steps,
                                          gamma, n_independent, n_shared, epsilon,
                                          virtual_batch_size, momentum)
-
         self.initial_bn = BatchNorm1d(self.post_embed_dim, momentum=0.01)
 
         # Defining device
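
The first hunk adds constructor-time validation: n_steps must be positive, and n_independent or n_shared may now be zero individually, as long as they are not both zero. A minimal usage sketch, assuming the class above is TabNet from pytorch_tabnet/tab_network.py and that the remaining constructor arguments keep their defaults (the import path and argument values below are assumptions, not part of the commit):

# Sketch only: illustrates the new checks, not the library's documented API.
from pytorch_tabnet.tab_network import TabNet

# Accepted after this fix: zero independent GLU layers, shared layers only.
model = TabNet(input_dim=10, output_dim=2, n_independent=0, n_shared=2)

# Still rejected: both zero would leave no feature transformer at all.
try:
    TabNet(input_dim=10, output_dim=2, n_independent=0, n_shared=0)
except ValueError as err:
    print(err)  # n_shared and n_independant can't be both zero.
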
@@ -289,23 +292,29 @@ def __init__(self, input_dim, output_dim, shared_layers, n_glu_independent,
         }
 
         if shared_layers is None:
-            self.shared = None
-            self.specifics = GLU_Block(input_dim, output_dim,
-                                       first=True,
-                                       **params)
+            # no shared layers
+            self.shared = torch.nn.Identity()
+            is_first = True
         else:
             self.shared = GLU_Block(input_dim, output_dim,
                                     first=True,
                                     shared_layers=shared_layers,
                                     n_glu=len(shared_layers),
                                     virtual_batch_size=virtual_batch_size,
                                     momentum=momentum)
-            self.specifics = GLU_Block(output_dim, output_dim,
+            is_first = False
+
+        if n_glu_independent == 0:
+            # no independent layers
+            self.specifics = torch.nn.Identity()
+        else:
+            spec_input_dim = input_dim if is_first else output_dim
+            self.specifics = GLU_Block(spec_input_dim, output_dim,
+                                       first=is_first,
                                        **params)
 
     def forward(self, x):
-        if self.shared is not None:
-            x = self.shared(x)
+        x = self.shared(x)
         x = self.specifics(x)
         return x
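
The second hunk drops the None sentinel in favor of torch.nn.Identity, so forward no longer has to branch on whether a stage exists, and the new is_first flag keeps the independent block's input dimension correct when it is the first block applied. A self-contained sketch of the same pattern; FeatBlock and TwoStageTransformer are illustrative stand-ins, not classes from this repository:

import torch


class FeatBlock(torch.nn.Module):
    """Toy stand-in for a GLU block: one linear layer plus a ReLU."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.fc = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return torch.relu(self.fc(x))


class TwoStageTransformer(torch.nn.Module):
    """A shared stage followed by an independent stage; either may be absent."""

    def __init__(self, input_dim, output_dim, use_shared=True, n_independent=1):
        super().__init__()
        # nn.Identity() is a no-op module, so forward() never has to test for None.
        self.shared = FeatBlock(input_dim, output_dim) if use_shared else torch.nn.Identity()
        if n_independent == 0:
            self.specifics = torch.nn.Identity()
        else:
            # Mirrors the is_first bookkeeping in the diff: if the shared stage
            # already projected the input, the independent stage starts from output_dim.
            spec_input_dim = output_dim if use_shared else input_dim
            self.specifics = FeatBlock(spec_input_dim, output_dim)

    def forward(self, x):
        x = self.shared(x)
        x = self.specifics(x)
        return x


x = torch.randn(4, 8)
print(TwoStageTransformer(8, 16, use_shared=False)(x).shape)                  # torch.Size([4, 16])
print(TwoStageTransformer(8, 16, use_shared=True, n_independent=0)(x).shape)  # torch.Size([4, 16])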

