
Commit

feat: add attentive embeddings
Optimox committed Mar 18, 2020
1 parent f83ffad commit c8bd369
Showing 3 changed files with 42 additions and 7 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -18,6 +18,7 @@ torch="^1.2"
tqdm="^4.36"
scikit_learn=">0.21"
scipy=">1.4"
torch-scatter="2.0.4"

[tool.poetry.dev-dependencies]
jupyter="1.0.0"
44 changes: 39 additions & 5 deletions pytorch_tabnet/tab_network.py
@@ -2,6 +2,7 @@
from torch.nn import Linear, BatchNorm1d, ReLU
import numpy as np
from pytorch_tabnet import sparsemax
from torch_scatter import scatter_max, scatter_mean


def initialize_non_glu(module, input_dim, output_dim):
@@ -47,7 +48,7 @@ def __init__(self, input_dim, output_dim,
n_d=8, n_a=8,
n_steps=3, gamma=1.3,
n_independent=2, n_shared=2, epsilon=1e-15,
virtual_batch_size=128, momentum=0.02):
virtual_batch_size=128, momentum=0.02, embedding_groups=None):
"""
Defines main part of the TabNet network without the embedding layers.
@@ -74,6 +75,8 @@ def __init__(self, input_dim, output_dim,
Number of independent GLU layers in each GLU block (default 2)
- epsilon: float
Avoid log(0), this should be kept very low
- embedding_groups : list of int
List of int mapping each post-embedding feature back to its original column
"""
super(TabNetNoEmbeddings, self).__init__()
self.input_dim = input_dim
@@ -86,6 +89,7 @@ def __init__(self, input_dim, output_dim,
self.n_independent = n_independent
self.n_shared = n_shared
self.virtual_batch_size = virtual_batch_size
self.embedding_groups = embedding_groups

if self.n_shared > 0:
shared_feat_transform = torch.nn.ModuleList()
@@ -115,7 +119,8 @@ def __init__(self, input_dim, output_dim,
momentum=momentum)
attention = AttentiveTransformer(n_a, self.input_dim,
virtual_batch_size=self.virtual_batch_size,
momentum=momentum)
momentum=momentum,
embedding_groups=self.embedding_groups)
self.feat_transformers.append(transformer)
self.att_transformers.append(attention)

@@ -228,9 +233,20 @@ def __init__(self, input_dim, output_dim, n_d=8, n_a=8,
else:
self.post_embed_dim = self.input_dim + np.sum(cat_emb_dim) - len(cat_emb_dim)
self.post_embed_dim = np.int(self.post_embed_dim)

# creates embedding groups
dict_emb = {idx: dim for idx, dim in zip(cat_idxs, self.cat_emb_dims)}
embedding_groups = []
for i in range(input_dim):
emb_size = dict_emb.get(i, None)
if emb_size:
embedding_groups.extend([i]*emb_size)
else:
embedding_groups.append(i)

self.tabnet = TabNetNoEmbeddings(self.post_embed_dim, output_dim, n_d, n_a, n_steps,
gamma, n_independent, n_shared, epsilon,
virtual_batch_size, momentum)
virtual_batch_size, momentum, embedding_groups)
self.initial_bn = BatchNorm1d(self.post_embed_dim, momentum=0.01)

# Defining device
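
For reference, a minimal standalone sketch (not part of the diff) of the mapping the loop above builds, under assumed values input_dim=4, cat_idxs=[1, 3], cat_emb_dims=[3, 2]: every post-embedding feature is tagged with the index of the original column it comes from, so the list has exactly post_embed_dim entries.

# Hypothetical values, mirroring the loop above (not part of the commit).
input_dim = 4                      # raw input columns 0..3
cat_idxs = [1, 3]                  # categorical columns
cat_emb_dims = [3, 2]              # their embedding sizes
dict_emb = {idx: dim for idx, dim in zip(cat_idxs, cat_emb_dims)}

embedding_groups = []
for i in range(input_dim):
    emb_size = dict_emb.get(i, None)
    if emb_size:
        embedding_groups.extend([i] * emb_size)  # one entry per embedding dimension
    else:
        embedding_groups.append(i)               # numerical column keeps a single slot

print(embedding_groups)  # [0, 1, 1, 1, 2, 3, 3] -> length equals post_embed_dim (4 + 5 - 2 = 7)
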
@@ -263,7 +279,8 @@ def forward(self, x):


class AttentiveTransformer(torch.nn.Module):
def __init__(self, input_dim, output_dim, virtual_batch_size=128, momentum=0.02):
def __init__(self, input_dim, output_dim, virtual_batch_size=128, momentum=0.02,
embedding_groups=None):
"""
Initialize an attention transformer.
@@ -281,9 +298,12 @@ def __init__(self, input_dim, output_dim, virtual_batch_size=128, momentum=0.02)
initialize_non_glu(self.fc, input_dim, output_dim)
self.bn = GBN(output_dim, virtual_batch_size=virtual_batch_size,
momentum=momentum)
self.embedding_groups = embedding_groups

# Sparsemax
self.sp_max = sparsemax.Sparsemax(dim=-1)
if embedding_groups is not None:
self.group_tensor = torch.LongTensor(embedding_groups)
# Entmax
# self.sp_max = sparsemax.Entmax15(dim=-1)

@@ -292,7 +312,21 @@ def forward(self, priors, processed_feat):
x = self.bn(x)
x = torch.mul(x, priors)
x = self.sp_max(x)
return x

# group embeddings so that attention over embeddings of the same column is consistent
if self.embedding_groups is not None:
# max
# vals, _ = scatter_max(x, self.group_tensor, out=None)
# out = vals[:, self.group_tensor]

# mean
vals = scatter_mean(x, self.group_tensor, out=None)
out = vals[:, self.group_tensor]
# normalize output
out = 1/torch.sum(out, dim=1)[:,None]*out
return out
else:
return x
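
For illustration, a minimal standalone sketch (not part of the commit, with arbitrary example values) of what this grouping does to a sparsemax mask: attention assigned to the embedding dimensions of one categorical column is averaged, broadcast back to each of those dimensions, and renormalized so every sample's mask still sums to 1. It assumes torch and torch_scatter are installed.

# Hypothetical example, assuming torch and torch_scatter (2.x) are available.
import torch
from torch_scatter import scatter_mean

# sparsemax output for one sample over 5 post-embedding features;
# features 1-3 all come from the same categorical column (group 1)
x = torch.tensor([[0.10, 0.20, 0.30, 0.10, 0.30]])
group_tensor = torch.LongTensor([0, 1, 1, 1, 2])

vals = scatter_mean(x, group_tensor)        # per-group means: [[0.10, 0.20, 0.30]]
out = vals[:, group_tensor]                 # broadcast back:  [[0.10, 0.20, 0.20, 0.20, 0.30]]
out = out / torch.sum(out, dim=1)[:, None]  # same normalization as above; mask sums to 1
print(out)                                  # tensor([[0.1000, 0.2000, 0.2000, 0.2000, 0.3000]])
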


class FeatTransformer(torch.nn.Module):
4 changes: 2 additions & 2 deletions pytorch_tabnet/utils.py
@@ -75,7 +75,7 @@ def create_dataloaders(X_train, y_train, X_valid, y_valid, weights, batch_size):
"""
if weights == 0:
train_dataloader = DataLoader(TorchDataset(X_train, y_train),
batch_size=batch_size, shuffle=True)
batch_size=batch_size, shuffle=True, drop_last=True)
else:
if weights == 1:
class_sample_count = np.array(
@@ -92,7 +92,7 @@ def create_dataloaders(X_train, y_train, X_valid, y_valid, weights, batch_size):
samples_weight = np.array([weights[t] for t in y_train])
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
train_dataloader = DataLoader(TorchDataset(X_train, y_train),
batch_size=batch_size, sampler=sampler)
batch_size=batch_size, sampler=sampler, drop_last=True)

valid_dataloader = DataLoader(TorchDataset(X_valid, y_valid),
batch_size=batch_size, shuffle=False)
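
A likely motivation for adding drop_last=True to the training loaders (an assumption, not stated in the commit): a trailing batch of size 1 cannot be normalized by BatchNorm1d / Ghost Batch Norm in training mode, so the incomplete final batch is dropped. A minimal sketch of the effect, with made-up sizes:

# Hypothetical illustration of drop_last (not part of the commit).
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(1025, 10))  # 1025 samples, batch_size 128
print(len(DataLoader(dataset, batch_size=128)))                  # 9 batches; the last one holds a single sample
print(len(DataLoader(dataset, batch_size=128, drop_last=True)))  # 8 full batches; the partial batch is skipped
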
