Skip to content

Commit

Permalink
fix: importance indexing fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
Optimox committed Jan 6, 2020
1 parent 54188db commit a8382c3
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions pytorch_tabnet/tab_network.py
Expand Up @@ -166,14 +166,16 @@ def __init__(self, input_dim, output_dim, n_d=8, n_a=8,

def apply_embeddings(self, x):
"""Apply embdeddings to raw inputs"""
# Getting categorical data
cat_cols = []
for icat, cat_idx in enumerate(self.cat_idxs):
cat_col = x[:, cat_idx].long()
cat_col = self.embeddings[icat](cat_col)
cat_cols.append(cat_col)
post_embeddings = torch.cat([x[:, self.continuous_idx].float()] + cat_cols, dim=1)
post_embeddings = post_embeddings.float()
cols = []
cat_feat_counter = 0
for feat_init_idx, is_continuous in enumerate(self.continuous_idx):
# Enumerate through continuous idx boolean mask to apply embeddings
if is_continuous:
cols.append(x[:, feat_init_idx].view(-1, 1))
else:
cols.append(self.embeddings[cat_feat_counter](x[:, feat_init_idx].long()))
cat_feat_counter += 1
post_embeddings = torch.cat(cols, dim=1).float()
return post_embeddings

def forward(self, x):
Expand Down

0 comments on commit a8382c3

Please sign in to comment.