
get_text_classifier fails with custom AWD_LSTM #3817

Closed
fabridamicelli opened this issue Oct 13, 2022 · 1 comment · Fixed by #3819
@fabridamicelli (Contributor)
Please confirm you have the latest versions of fastai, fastcore, and nbdev prior to reporting a bug: YES

Describe the bug
The function get_text_classifier from the module text.models.core, which takes the argument arch (e.g. AWD_LSTM), throws a KeyError when a user-instantiated AWD_LSTM is passed (e.g. AWD_LSTM(vocab_sz=100, emb_sz=10, n_hid=2, n_layers=2)).
More precisely, the lookup _model_meta[arch] fails because the custom AWD_LSTM instance is not recognized as being equal to <class 'fastai.text.models.awdlstm.AWD_LSTM'> (the key of the _model_meta dictionary).
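The failure mode can be illustrated without fastai: _model_meta is keyed by the architecture class itself, so an instance of that class is simply not a key of the dictionary. A minimal sketch, using a hypothetical stand-in class (DemoAWD_LSTM is not a fastai name):

```python
# _model_meta maps the architecture *class* to its metadata, mirroring
# how fastai keys its registry. An *instance* hashes as a distinct
# object and is therefore not found, raising KeyError on lookup.
class DemoAWD_LSTM:  # hypothetical stand-in for fastai's AWD_LSTM
    def __init__(self, vocab_sz, emb_sz):
        self.vocab_sz, self.emb_sz = vocab_sz, emb_sz

_model_meta = {DemoAWD_LSTM: {"config_clas": {"emb_sz": 400}}}

print(DemoAWD_LSTM in _model_meta)           # True: the class is a key
print(DemoAWD_LSTM(100, 10) in _model_meta)  # False: an instance is not
```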

To Reproduce
Steps to reproduce the behavior:

from fastai.text.all import *

arch = AWD_LSTM(vocab_sz=100, emb_sz=10, n_hid=10, n_layers=2)
get_text_classifier(arch=arch, vocab_sz=100, n_class=2)

The error can be clearly seen in the linked notebook, which can also be opened directly in Colab.

Expected behavior
The function should return a SequentialRNN instance.

Error with full stack trace

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Input In [13], in <cell line: 4>()
      1 from fastai.text.all import *
      3 arch = AWD_LSTM(vocab_sz=100, emb_sz=10, n_hid=10, n_layers=2)
----> 4 get_text_classifier(arch=arch, vocab_sz=100, n_class=2)

File ~/miniconda3/envs/fai/lib/python3.10/site-packages/fastai/text/models/core.py:158, in get_text_classifier(arch, vocab_sz, n_class, seq_len, config, drop_mult, lin_ftrs, ps, pad_idx, max_len, y_range)
    144 def get_text_classifier(
    145     arch:callable, # Function or class that can generate a language model architecture
    146     vocab_sz:int, # Size of the vocabulary 
   (...)
    155     y_range:tuple=None # Tuple of (low, high) output value bounds
    156 ):
    157     "Create a text classifier from `arch` and its `config`, maybe `pretrained`"
--> 158     meta = _model_meta[arch]
    159     config = ifnone(config, meta['config_clas']).copy()
    160     for k in config.keys():

KeyError: AWD_LSTM(
  (encoder): Embedding(100, 10, padding_idx=1)
  (encoder_dp): EmbeddingDropout(
    (emb): Embedding(100, 10, padding_idx=1)
  )
  (rnns): ModuleList(
    (0): WeightDropout(
      (module): LSTM(10, 10, batch_first=True)
    )
    (1): WeightDropout(
      (module): LSTM(10, 10, batch_first=True)
    )
  )
  (input_dp): RNNDropout()
  (hidden_dps): ModuleList(
    (0): RNNDropout()
    (1): RNNDropout()
  )
)

Additional context
Forum discussion with another report from @machinatoonist (with no solution so far): https://forums.fast.ai/t/how-to-customise-vocab-sz-in-text-classifier-learner/98230.

@Salehbigdeli (Contributor)
There are two problems here:

  1. You need to modify the way you use the API: according to the docs, arch needs to be a class or callable that creates a model, not the model instance you passed. Something like:

from fastai.text.all import *
config = awd_lstm_clas_config.copy()
config.update(emb_sz=10, n_hid=10, n_layers=2)
get_text_classifier(arch=AWD_LSTM, n_class=2, vocab_sz=100, config=config)

  2. The second problem is the API itself. I don't like having to copy the default config (config = awd_lstm_clas_config.copy()) and then update it on the next line. I'm going to create a PR to solve this issue, so you can use the API like:

from fastai.text.all import *
config = dict(emb_sz=10, n_hid=10, n_layers=2)
get_text_classifier(arch=AWD_LSTM, n_class=2, vocab_sz=100, config=config)
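The API change described above implies that a partial user config gets merged over the architecture's defaults. A minimal sketch of that merging behaviour, using illustrative default values rather than fastai's actual ones:

```python
# Sketch (not fastai's actual code) of merging a partial user config
# over architecture defaults: user-supplied keys win, unspecified keys
# fall back to the defaults.
default_config = {"emb_sz": 400, "n_hid": 1152, "n_layers": 3, "pad_token": 1}
user_config = {"emb_sz": 10, "n_hid": 10, "n_layers": 2}

merged = {**default_config, **user_config}
print(merged["emb_sz"])     # 10 (overridden by the user)
print(merged["pad_token"])  # 1  (default kept)
```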

@jph00 jph00 added the bug label Nov 2, 2022