From e3b5a04edb1aff25683ce5457f9b4fd57b4c1bf6 Mon Sep 17 00:00:00 2001
From: Sebastien Fischman
Date: Thu, 7 May 2020 16:21:27 +0200
Subject: [PATCH] fix: allow zero layer

---
 pytorch_tabnet/tab_network.py | 29 +++++++++++++++++++----------
 1 file changed, 19 insertions(+), 10 deletions(-)

diff --git a/pytorch_tabnet/tab_network.py b/pytorch_tabnet/tab_network.py
index 6df98528..15c1131f 100644
--- a/pytorch_tabnet/tab_network.py
+++ b/pytorch_tabnet/tab_network.py
@@ -203,15 +203,18 @@ def __init__(self, input_dim, output_dim, n_d=8, n_a=8,
         self.epsilon = epsilon
         self.n_independent = n_independent
         self.n_shared = n_shared
-        self.virtual_batch_size = virtual_batch_size
+        if self.n_steps <= 0:
+            raise ValueError("n_steps should be a positive integer.")
+        if self.n_independent == 0 and self.n_shared == 0:
+            raise ValueError("n_shared and n_independent can't be both zero.")
+
+        self.virtual_batch_size = virtual_batch_size
         self.embedder = EmbeddingGenerator(input_dim, cat_dims, cat_idxs, cat_emb_dim)
         self.post_embed_dim = self.embedder.post_embed_dim
-
         self.tabnet = TabNetNoEmbeddings(self.post_embed_dim, output_dim, n_d, n_a, n_steps,
                                          gamma, n_independent, n_shared, epsilon,
                                          virtual_batch_size, momentum)
-
         self.initial_bn = BatchNorm1d(self.post_embed_dim, momentum=0.01)
 
         # Defining device
@@ -289,10 +292,9 @@ def __init__(self, input_dim, output_dim, shared_layers, n_glu_independent,
         }
 
         if shared_layers is None:
-            self.shared = None
-            self.specifics = GLU_Block(input_dim, output_dim,
-                                       first=True,
-                                       **params)
+            # no shared layers
+            self.shared = torch.nn.Identity()
+            is_first = True
         else:
             self.shared = GLU_Block(input_dim, output_dim,
                                     first=True,
@@ -300,12 +302,19 @@
                                     n_glu=len(shared_layers),
                                     virtual_batch_size=virtual_batch_size,
                                     momentum=momentum)
-            self.specifics = GLU_Block(output_dim, output_dim,
+            is_first = False
+
+        if n_glu_independent == 0:
+            # no independent layers
+            self.specifics = torch.nn.Identity()
+        else:
+            spec_input_dim = input_dim if is_first else output_dim
+            self.specifics = GLU_Block(spec_input_dim, output_dim,
+                                       first=is_first,
                                        **params)
 
     def forward(self, x):
-        if self.shared is not None:
-            x = self.shared(x)
+        x = self.shared(x)
         x = self.specifics(x)
         return x
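
Note (not part of the patch): a minimal usage sketch of the behaviour the fix enables.
It assumes the TabNet constructor defaults from pytorch_tabnet at the time of this
commit (categorical arguments defaulting to empty lists, forward returning the network
output together with the sparsity loss); those defaults are not visible in the diff
above and are an assumption.

    import torch
    from pytorch_tabnet.tab_network import TabNet

    # Zero independent GLU layers are now accepted: the feature transformer
    # falls back to torch.nn.Identity() instead of building an empty GLU_Block.
    model = TabNet(input_dim=10, output_dim=2, n_independent=0, n_shared=2)
    out, m_loss = model(torch.randn(32, 10))  # assumed return: (output, sparsity loss)

    # Setting both counts to zero is still rejected up front by the new check.
    try:
        TabNet(input_dim=10, output_dim=2, n_independent=0, n_shared=0)
    except ValueError as err:
        print(err)  # n_shared and n_independent can't be both zero.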