This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

TabNet Classification Broken #153

Closed
aribornstein opened this issue Feb 28, 2021 · 1 comment

@aribornstein
Contributor

aribornstein commented Feb 28, 2021

🐛 Bug

Tabular classification throws an IndexError on prediction.

To Reproduce

Steps to reproduce the behavior:
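from flash.tabular import TabularClassifier

# Load the published checkpoint, then predict on the Titanic test set.
model = TabularClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/tabnet_classification_model.pt")
predictions = model.predict("../input/titanic/test.csv")
print(predictions)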

Running this raises the following error:

IndexError                                Traceback (most recent call last)
<ipython-input-18-17cd8b85d1d8> in <module>
      1 model = TabularClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/tabnet_classification_model.pt")
----> 2 predictions = model.predict("../input/titanic/test.csv")
      3 print(predictions)

/opt/conda/lib/python3.7/site-packages/flash/tabular/classification/model.py in predict(self, x, batch_idx, skip_collate_fn, dataloader_idx, data_pipeline)
     80         data_pipeline = data_pipeline or self.data_pipeline
     81         batch = x if skip_collate_fn else data_pipeline.collate_fn(x)
---> 82         predictions = self.forward(batch)
     83         return data_pipeline.uncollate_fn(predictions)
     84 

/opt/conda/lib/python3.7/site-packages/flash/tabular/classification/model.py in forward(self, x_in)
     86         # TabNet takes single input, x_in is composed of (categorical, numerical)
     87         x = torch.cat([x for x in x_in if x.numel()], dim=1)
---> 88         return self.model(x)[0]
     89 
     90     @classmethod

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

/opt/conda/lib/python3.7/site-packages/pytorch_tabnet/tab_network.py in forward(self, x)
    580 
    581     def forward(self, x):
--> 582         x = self.embedder(x)
    583         return self.tabnet(x)
    584 

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

/opt/conda/lib/python3.7/site-packages/pytorch_tabnet/tab_network.py in forward(self, x)
    847             else:
    848                 cols.append(
--> 849                     self.embeddings[cat_feat_counter](x[:, feat_init_idx].long())
    850                 )
    851                 cat_feat_counter += 1

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/sparse.py in forward(self, input)
    124         return F.embedding(
    125             input, self.weight, self.padding_idx, self.max_norm,
--> 126             self.norm_type, self.scale_grad_by_freq, self.sparse)
    127 
    128     def extra_repr(self) -> str:

/opt/conda/lib/python3.7/site-packages/torch/nn/functional.py in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   1850         # remove once script supports set_grad_enabled
   1851         _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 1852     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
   1853 
   1854 

IndexError: index out of range in self
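
The last frame above is the embedding lookup, and "index out of range in self" is the message PyTorch gives when an index passed to an embedding is not a valid row of its table. One possible cause, which is an assumption and not something confirmed in this thread, is that a categorical column in test.csv is encoded to a code the checkpoint's embedding table was not sized for. The sketch below shows one way to check for that before calling predict. It is pure Python; `find_out_of_range_codes`, `cardinalities`, and `encoded_rows` are hypothetical names and data, not part of Flash or pytorch_tabnet.

```python
from typing import Dict, List, Sequence, Set


def find_out_of_range_codes(
    rows: Sequence[Sequence[int]],
    cardinalities: Sequence[int],
) -> Dict[int, List[int]]:
    """Return {column index: sorted offending codes}.

    A code is offending when it falls outside [0, cardinality) for its
    column, i.e. when it would not be a valid row of an embedding table
    with `cardinality` rows.
    """
    bad: Dict[int, Set[int]] = {}
    for row in rows:
        for col, (code, size) in enumerate(zip(row, cardinalities)):
            if not 0 <= code < size:
                bad.setdefault(col, set()).add(code)
    return {col: sorted(codes) for col, codes in bad.items()}


# Hypothetical data: two categorical columns whose embedding tables
# are assumed to have been sized for 3 and 2 categories at training time.
cardinalities = [3, 2]
encoded_rows = [
    [0, 1],
    [2, 0],
    [3, 1],  # code 3 is out of range for a table with 3 rows
    [1, 2],  # code 2 is out of range for a table with 2 rows
]

problems = find_out_of_range_codes(encoded_rows, cardinalities)
print(problems)  # {0: [3], 1: [2]}
```

If a check like this reports offending codes, the encoding applied at prediction time does not match the one the checkpoint was trained with. If it reports none, the cause lies elsewhere.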

aribornstein added the "bug / fix" and "help wanted" labels on Feb 28, 2021
aribornstein changed the title from "TabNet Classfiction Broken" to "TabNet Classification Broken" on Feb 28, 2021
edenlightning added this to the 0.2 milestone on Mar 22, 2021
edenlightning changed the milestone from 0.2 to 0.3 on Apr 19, 2021
@ethanwharris
Collaborator

The tabular_classification.py example is working for me on master, so I'm closing this. Please re-open if the issue persists 😃
