
Commit

feat: fix shared layers with independent batchnorm
eduardocarvp committed Feb 7, 2020
1 parent dbe476a commit 5f0e43f
Showing 1 changed file with 56 additions and 30 deletions.
86 changes: 56 additions & 30 deletions pytorch_tabnet/tab_network.py
@@ -127,15 +127,16 @@ def __init__(self, input_dim, output_dim, n_d=8, n_a=8,
self.initial_bn = BatchNorm1d(self.post_embed_dim, momentum=0.01)

if self.n_shared > 0:
shared_feat_transform = GLU_Block(self.post_embed_dim,
n_d+n_a,
n_glu=self.n_shared,
virtual_batch_size=self.virtual_batch_size,
first=True,
momentum=momentum,
device=self.device)
shared_feat_transform = torch.nn.ModuleList()
for i in range(self.n_shared):
if i == 0:
shared_feat_transform.append(Linear(self.post_embed_dim, 2*(n_d + n_a), bias=False))
else:
shared_feat_transform.append(Linear(n_d + n_a, 2*(n_d + n_a), bias=False))

else:
shared_feat_transform = None

self.initial_splitter = FeatTransformer(self.post_embed_dim, n_d+n_a, shared_feat_transform,
n_glu=self.n_independent,
virtual_batch_size=self.virtual_batch_size,
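
The hunk above swaps the single shared GLU_Block for a plain torch.nn.ModuleList of Linear layers: only the fully connected weights are meant to be shared across decision steps, while batch normalization is created per step further down. A minimal, self-contained sketch of that construction (not part of the diff; the dimension values are illustrative, and each Linear outputs 2*(n_d + n_a) because the GLU gate later splits its input in half):

import torch
from torch.nn import Linear

# Illustrative dimensions standing in for self.post_embed_dim, n_d, n_a, self.n_shared.
post_embed_dim, n_d, n_a, n_shared = 20, 8, 8, 2

shared_feat_transform = torch.nn.ModuleList()
for i in range(n_shared):
    if i == 0:
        # First shared layer maps the embedded input to 2*(n_d + n_a) features.
        shared_feat_transform.append(Linear(post_embed_dim, 2 * (n_d + n_a), bias=False))
    else:
        # Later shared layers map the (n_d + n_a)-sized GLU output back to 2*(n_d + n_a).
        shared_feat_transform.append(Linear(n_d + n_a, 2 * (n_d + n_a), bias=False))

print([tuple(layer.weight.shape) for layer in shared_feat_transform])
# [(32, 20), (32, 16)]  -- Linear stores weights as (out_features, in_features)
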
@@ -244,7 +245,7 @@ def forward(self, priors, processed_feat):


class FeatTransformer(torch.nn.Module):
def __init__(self, input_dim, output_dim, shared_blocks, n_glu,
def __init__(self, input_dim, output_dim, shared_layers, n_glu,
virtual_batch_size=128, momentum=0.02, device='cpu'):
super(FeatTransformer, self).__init__()
"""
@@ -256,26 +257,27 @@ def __init__(self, input_dim, output_dim, shared_blocks, n_glu,
Input size
- output_dim : int
Output size
- shared_blocks : torch.nn.Module
- shared_layers : torch.nn.ModuleList
The shared layers that should be common to every step
- momentum : float
Float value between 0 and 1 which will be used for momentum in batch norm
"""

self.shared = shared_blocks
if self.shared is not None:
for l in self.shared.glu_layers:
l.bn = GBN(2*output_dim, virtual_batch_size=virtual_batch_size,
momentum=momentum, device=device)

if self.shared is None:
if shared_layers is None:
self.specifics = GLU_Block(input_dim, output_dim,
n_glu=n_glu,
first=True,
virtual_batch_size=virtual_batch_size,
momentum=momentum,
device=device)
else:
self.shared = GLU_Block(input_dim, output_dim,
n_glu=n_glu,
first=True,
shared_layers=shared_layers,
virtual_batch_size=virtual_batch_size,
momentum=momentum,
device=device)
self.specifics = GLU_Block(output_dim, output_dim,
n_glu=n_glu,
virtual_batch_size=virtual_batch_size,
@@ -284,7 +286,11 @@ def __init__(self, input_dim, output_dim, shared_blocks, n_glu,

def forward(self, x):
if self.shared is not None:
# print('-------before----------')
# print(self.shared.glu_layers[0].bn.bn.running_mean)
x = self.shared(x)
# print('-------after-----------')
# print(self.shared.glu_layers[0].bn.bn.running_mean)
x = self.specifics(x)
return x
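
Because each FeatTransformer builds its own GLU_Block around the same shared_layers list, the Linear weights are literally the same Python objects across steps while every step owns its own GBN. A usage sketch, assuming this revision of pytorch_tabnet is importable (the dimensions are illustrative, not taken from the diff):

import torch
from pytorch_tabnet.tab_network import FeatTransformer

post_embed_dim, n_d, n_a = 20, 8, 8
shared_layers = torch.nn.ModuleList([
    torch.nn.Linear(post_embed_dim, 2 * (n_d + n_a), bias=False),
    torch.nn.Linear(n_d + n_a, 2 * (n_d + n_a), bias=False),
])

# Two "steps" built from the same shared list.
step_0 = FeatTransformer(post_embed_dim, n_d + n_a, shared_layers, n_glu=2)
step_1 = FeatTransformer(post_embed_dim, n_d + n_a, shared_layers, n_glu=2)

# The shared fully connected layers are the very same objects in both steps...
assert step_0.shared.glu_layers[0].fc is step_1.shared.glu_layers[0].fc
# ...but each step instantiates its own ghost batch norm, so running statistics stay independent.
assert step_0.shared.glu_layers[0].bn is not step_1.shared.glu_layers[0].bn

x = torch.randn(32, post_embed_dim)
out = step_0(x)  # expected shape (32, n_d + n_a) after the shared and step-specific GLU blocks
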

@@ -293,24 +299,41 @@ class GLU_Block(torch.nn.Module):
"""
Independent GLU block, specific to each step
"""
def __init__(self, input_dim, output_dim, n_glu=2, first=False,
def __init__(self, input_dim, output_dim, n_glu=2, first=False, shared_layers=None,
virtual_batch_size=128, momentum=0.02, device='cpu'):
super(GLU_Block, self).__init__()
self.first = first
self.shared_layers = shared_layers
self.n_glu = n_glu
self.glu_layers = torch.nn.ModuleList()
self.scale = torch.sqrt(torch.FloatTensor([0.5]).to(device))
for glu_id in range(self.n_glu):
if glu_id == 0:
self.glu_layers.append(GLU_Layer(input_dim, output_dim,
virtual_batch_size=virtual_batch_size,
momentum=momentum,
device=device))
else:
self.glu_layers.append(GLU_Layer(output_dim, output_dim,
virtual_batch_size=virtual_batch_size,
momentum=momentum,
device=device))

if shared_layers:
for glu_id in range(self.n_glu):
if glu_id == 0:
self.glu_layers.append(GLU_Layer(input_dim, output_dim,
fc=shared_layers[glu_id],
virtual_batch_size=virtual_batch_size,
momentum=momentum,
device=device))
else:
self.glu_layers.append(GLU_Layer(output_dim, output_dim,
fc=shared_layers[glu_id],
virtual_batch_size=virtual_batch_size,
momentum=momentum,
device=device))
else:
for glu_id in range(self.n_glu):
if glu_id == 0:
self.glu_layers.append(GLU_Layer(input_dim, output_dim,
virtual_batch_size=virtual_batch_size,
momentum=momentum,
device=device))
else:
self.glu_layers.append(GLU_Layer(output_dim, output_dim,
virtual_batch_size=virtual_batch_size,
momentum=momentum,
device=device))

def forward(self, x):
if self.first: # the first layer of the block has no scale multiplication
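
The remainder of GLU_Block.forward is collapsed in this view. Given the self.scale = sqrt(0.5) buffer defined in __init__ and the comment above, the following is a plausible reconstruction of the residual pattern the hidden lines implement — an assumption for illustration, not copied from the commit:

import torch

def glu_block_forward(glu_layers, scale, x, first):
    # Assumption: each GLU layer output is added back to its input and rescaled by
    # sqrt(0.5) so the variance does not grow with depth; the first layer of a
    # "first" block gets no residual because it changes the feature dimension.
    if first:
        x = glu_layers[0](x)
        layers_left = range(1, len(glu_layers))
    else:
        layers_left = range(len(glu_layers))
    for glu_id in layers_left:
        x = torch.add(x, glu_layers[glu_id](x)) * scale
    return x
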
@@ -326,12 +349,15 @@ def forward(self, x):


class GLU_Layer(torch.nn.Module):
def __init__(self, input_dim, output_dim,
def __init__(self, input_dim, output_dim, fc=None,
virtual_batch_size=128, momentum=0.02, device='cpu'):
super(GLU_Layer, self).__init__()

self.output_dim = output_dim
self.fc = Linear(input_dim, 2*output_dim, bias=False)
if fc:
self.fc = fc
else:
self.fc = Linear(input_dim, 2*output_dim, bias=False)
initialize_glu(self.fc, input_dim, 2*output_dim)

self.bn = GBN(2*output_dim, virtual_batch_size=virtual_batch_size,
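
Whether self.fc is a shared layer or a freshly created Linear, it outputs 2*output_dim features and feeds a ghost batch norm of the same width; the GLU non-linearity then consumes half of those features as a sigmoid gate. The layer's forward pass is not part of this hunk, so this is only a sketch of the standard GLU gating it is expected to apply:

import torch

def glu(z, output_dim):
    # GLU(z) = z[:, :output_dim] * sigmoid(z[:, output_dim:])
    # i.e. the first half carries the values, the second half gates them.
    return z[:, :output_dim] * torch.sigmoid(z[:, output_dim:])

z = torch.randn(4, 2 * 16)   # e.g. the output of fc followed by bn, with output_dim = 16
print(glu(z, 16).shape)      # torch.Size([4, 16])
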
