This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Hierarchical intent and slot filling demo is broken #1012

Closed
ButteredGroove opened this issue Sep 27, 2019 · 5 comments

@ButteredGroove

ButteredGroove commented Sep 27, 2019

Steps to reproduce

Follow the instructions here: https://pytext.readthedocs.io/en/master/hierarchical_intent_slot_tutorial.html

Observed Results

$ pytext train < pytext/demo/configs/rnng.json
/home/user/.local/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:541: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) /'(1,)type'.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
/home/user/.local/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:542: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) /'(1,)type'.
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
/home/user/.local/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:543: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) /'(1,)type'.
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
/home/user/.local/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:544: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) /'(1,)type'.
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
/home/user/.local/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:545: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) /'(1,)type'.
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
/home/user/.local/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:550: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) /'(1,)type'.
  np_resource = np.dtype([("resource", np.ubyte, 1)])
Patch PathManager with python builtins
Patch PathManager with python builtins
No config file specified, reading from stdin
WARNING - Applying old config adapter for version=12. Please consider migrating your old configs to the latest version.
WARNING - Applying old config adapter for version=13. Please consider migrating your old configs to the latest version.
WARNING - Applying old config adapter for version=14. Please consider migrating your old configs to the latest version.
WARNING - Applying old config adapter for version=15. Please consider migrating your old configs to the latest version.
WARNING - Applying old config adapter for version=16. Please consider migrating your old configs to the latest version.

===Starting training...

Parameters: PyTextConfig:
    debug_path: /tmp/model.debug
    distributed_world_size: 1
    export_caffe2_path: None
    export_onnx_path: /tmp/model.onnx
    export_torchscript_path: None
    include_dirs: None
    load_snapshot_path:
    modules_save_dir:
    random_seed: None
    report_eval_results: False
    save_all_checkpoints: False
    save_module_checkpoints: False
    save_snapshot_path: /tmp/model.pt
    task: SemanticParsingTask.Config:
        data: Data.Config:
            batcher: PoolingBatcher.Config:
                eval_batch_size: 1
                num_shuffled_pools: 1
                pool_num_batches: 10000
                test_batch_size: 1
                train_batch_size: 1
            in_memory: True
            sort_key: None
            source: TSVDataSource.Config:
                column_mapping: {}
                delimiter:
                drop_incomplete_rows: False
                eval_filename: /home/user/dataset/eval.tsv
                field_names: ['text', 'tokenized_text', 'seqlogical']
                quoted: False
                test_filename: /home/user/dataset/test.tsv
                train_filename: /home/user/dataset/train.tsv
        metric_reporter: CompositionalMetricReporter.Config:
            output_path: /tmp/test_out.txt
            pep_format: False
            text_column_name: tokenized_text
        model: RNNGParser.Config:
            ablation: AblationParams:
                use_action: True
                use_buffer: True
                use_last_open_NT_feature: False
                use_stack: True
            beam_size: 1
            compositional_type: CompositionalType.SUM
            constraints: RNNGConstraints:
                ignore_loss_for_unsupported: False
                intent_slot_nesting: True
                no_slots_inside_unsupported: True
            dropout: 0.34
            embedding: WordEmbedding.Config:
                embed_dim: 100
                embedding_init_range: None
                embedding_init_strategy: EmbedInitStrategy.RANDOM
                export_input_names: ['tokens_vals']
                freeze: False
                load_path: None
                lowercase_tokens: True
                min_freq: 1
                mlp_layer_dims: []
                padding_idx: None
                pretrained_embeddings_path:
                save_path: None
                shared_module_key: None
                vocab_file:
                vocab_from_all_data: False
                vocab_from_pretrained_embeddings: False
                vocab_from_train_data: True
                vocab_size: 0
            inputs: ModelInput:
                actions: AnnotationNumberizer.Config:
                    column: seqlogical
                tokens: TokenTensorizer.Config:
                    add_bos_token: False
                    add_eos_token: False
                    column: tokenized_text
                    max_seq_len: None
                    tokenizer: Tokenizer.Config:
                        lowercase: True
                        split_regex: \s+
                    use_eos_token_for_bos: False
                    vocab: VocabConfig:
                        build_from_data: True
                        size_from_data: 0
                        vocab_files: []
            lstm: BiLSTM.Config:
                bidirectional: True
                dropout: 0.34
                freeze: False
                load_path: None
                lstm_dim: 164
                num_layers: 2
                pack_sequence: True
                save_path: None
                shared_module_key: None
            max_open_NT: 10
            top_k: 1
            version: 2
        trainer: HogwildTrainer.Config:
            num_workers: 1
            real_trainer: TaskTrainer.Config:
                do_eval: True
                early_stop_after: 0
                epochs: 1
                max_clip_norm: None
                num_accumulated_batches: 1
                num_batches_per_epoch: None
                num_samples_to_log_progress: 1000
                optimizer: Adam.Config:
                    eps: 1e-08
                    lr: 0.001
                    weight_decay: 1e-05
                report_train_metrics: False
                scheduler: None
                sparsifier: None
                target_time_limit_seconds: None
    test_out_path: /tmp/test_out.txt
    torchscript_quantize: False
    use_config_from_snapshot: True
    use_cuda_if_available: True
    use_deterministic_cudnn: False
    use_fp16: False
    use_tensorboard: True
    version: 17


        # for debug of GPU
        use_cuda_if_available: True
        device_id: 0
        world_size: 1
        torch.cuda.is_available(): True
        cuda.CUDA_ENABLED: True
        cuda.DISTRIBUTED_WORLD_SIZE: 1

# for debug of FP16: fp16_enabled=False
Creating task: SemanticParsingTask...
Traceback (most recent call last):
  File "/home/user/.local/bin/pytext", line 11, in <module>
    load_entry_point('pytext-nlp', 'console_scripts', 'pytext')()
  File "/home/user/.local/lib/python3.7/site-packages/click/core.py", line 764, in __call__
    return self.main(*args, **kwargs)
  File "/home/user/.local/lib/python3.7/site-packages/click/core.py", line 717, in main
    rv = self.invoke(ctx)
  File "/home/user/.local/lib/python3.7/site-packages/click/core.py", line 1137, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/home/user/.local/lib/python3.7/site-packages/click/core.py", line 956, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/home/user/.local/lib/python3.7/site-packages/click/core.py", line 555, in invoke
    return callback(*args, **kwargs)
  File "/home/user/.local/lib/python3.7/site-packages/click/decorators.py", line 17, in new_func
    return f(get_current_context(), *args, **kwargs)
  File "/home/user/pytext/pytext/main.py", line 364, in train
    train_model(config, metric_channels=metric_channels)
  File "/home/user/pytext/pytext/workflow.py", line 99, in train_model
    config, dist_init_url, device_id, rank, world_size, metric_channels, metadata
  File "/home/user/pytext/pytext/workflow.py", line 140, in prepare_task
    config.task, metadata=metadata, rank=rank, world_size=world_size
  File "/home/user/pytext/pytext/task/task.py", line 43, in create_task
    world_size=world_size,
  File "/home/user/pytext/pytext/config/component.py", line 154, in create_component
    return cls.from_config(config, *args, **kwargs)
  File "/home/user/pytext/pytext/task/new_task.py", line 105, in from_config
    tensorizers, data = cls._init_tensorizers(config, tensorizers, rank, world_size)
  File "/home/user/pytext/pytext/task/new_task.py", line 148, in _init_tensorizers
    init_tensorizers=init_tensorizers,
  File "/home/user/pytext/pytext/config/component.py", line 154, in create_component
    return cls.from_config(config, *args, **kwargs)
  File "/home/user/pytext/pytext/data/data.py", line 256, in from_config
    **kwargs,
  File "/home/user/pytext/pytext/data/data.py", line 285, in __init__
    self.tensorizers, full_train_data, init_tensorizers_from_scratch
  File "/home/user/pytext/pytext/data/tensorizers.py", line 1410, in initialize_tensorizers
    for row in data_source:
  File "/home/user/pytext/pytext/data/sources/data_source.py", line 227, in _convert_raw_source
    example = self._read_example(row)
  File "/home/user/pytext/pytext/data/sources/data_source.py", line 207, in _read_example
    example[name] = self.load(value, self.schema[name])
  File "/home/user/pytext/pytext/data/sources/data_source.py", line 248, in load
    return converter(value)
  File "/home/user/pytext/pytext/data/sources/data_source.py", line 313, in load_json
    return json.loads(s)
  File "/usr/local/lib/python3.7/json/__init__.py", line 348, in loads
    return _default_decoder.decode(s)
  File "/usr/local/lib/python3.7/json/decoder.py", line 337, in decode
    obj, end = self.raw_decode(s, idx=_w(s, 0).end())
  File "/usr/local/lib/python3.7/json/decoder.py", line 355, in raw_decode
    raise JSONDecodeError("Expecting value", s, err.value) from None
json.decoder.JSONDecodeError: Expecting value: line 1 column 2 (char 1)
Exception ignored in: <generator object AnnotationNumberizer.initialize at 0x7f7d7dc08c00>
Traceback (most recent call last):
  File "/home/user/pytext/pytext/data/tensorizers.py", line 1310, in initialize
    self.shift_idx = self.vocab.idx[SHIFT]
KeyError: 'SHIFT'

Expected Results

This:
https://pytext.readthedocs.io/en/master/hierarchical_intent_slot_tutorial.html#test-the-model-interactively-against-input-utterances

Relevant Code

See attached.
pytext_bug.txt

@ButteredGroove
Author

I get the same error if I have apex installed as well.
pytext_apex_bug.txt

@snisarg
Contributor

snisarg commented Oct 2, 2019

@ButteredGroove can you try from github master instead of the pip release and let us know?

@snisarg snisarg self-assigned this Oct 2, 2019
@ButteredGroove
Author

ButteredGroove commented Oct 2, 2019

Hi @snisarg. The issue is in GitHub master. Please refer to line 75 of https://github.com/facebookresearch/pytext/files/3664614/pytext_apex_bug.txt or lines 1 through 12 of https://github.com/facebookresearch/pytext/files/3664372/pytext_bug.txt

@hengchao0248

hengchao0248 commented Oct 14, 2019

I hit the same bug.
In pytext/data/tensorizers.py, around line 1313:

    try:
        while True:
            row = yield
            annotation = Annotation(row[self.column])
            actions = annotation.tree.to_actions()
            self.vocab_builder.add_all(actions)
    except GeneratorExit:
        self.vocab = self.vocab_builder.make_vocab()
        print(self.vocab.idx)  # maybe the bug is here: self.vocab.idx is {}
        self.shift_idx = self.vocab.idx[SHIFT]
        self.reduce_idx = self.vocab.idx[REDUCE]
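
For anyone else reading the trace: the `KeyError: 'SHIFT'` looks like a follow-on symptom rather than the root cause. Here is a minimal sketch of the mechanism, assuming (as the first traceback suggests) that the data source raised before any row reached the generator. `ToyNumberizer` and its dict-based vocab are hypothetical stand-ins, not PyText's classes.

```python
SHIFT = "SHIFT"
REDUCE = "REDUCE"


class ToyNumberizer:
    """Hypothetical stand-in for a tensorizer that builds a vocab from rows."""

    def __init__(self):
        self.seen = []       # actions collected from the rows sent so far
        self.vocab_idx = {}  # action -> index, filled in when the generator closes

    def initialize(self):
        try:
            while True:
                row = yield
                self.seen.extend(row["actions"])
        except GeneratorExit:
            # Runs when the generator is closed, even if no row was ever sent.
            unique = dict.fromkeys(self.seen)  # de-duplicate, keep first-seen order
            self.vocab_idx = {action: i for i, action in enumerate(unique)}
            self.shift_idx = self.vocab_idx[SHIFT]  # KeyError if the vocab is empty
            self.reduce_idx = self.vocab_idx[REDUCE]


def close_without_rows():
    """Close the generator before any row arrives, as after a reader failure."""
    numberizer = ToyNumberizer()
    gen = numberizer.initialize()
    next(gen)  # prime the generator up to the first `yield`
    try:
        gen.close()
    except KeyError as err:
        return repr(err)
    return None


def close_with_rows():
    """Send one row containing both actions, then close normally."""
    numberizer = ToyNumberizer()
    gen = numberizer.initialize()
    next(gen)
    gen.send({"actions": [SHIFT, REDUCE]})
    gen.close()
    return numberizer.shift_idx, numberizer.reduce_idx
```

In the sketch the error surfaces from an explicit `gen.close()`. In the reported run the generator was presumably finalized during cleanup after the `JSONDecodeError`, which would explain why Python printed it as "Exception ignored in" instead of raising it.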

@snisarg
Contributor

snisarg commented Oct 17, 2019

Thanks for the detailed traces. I can reproduce it now. Something is wrong with the tensorizer that I can't pinpoint just by looking at the code. I'll prioritize this and debug it, and report back soon.

sdwivedi pushed a commit to sdwivedi/pytext that referenced this issue Nov 20, 2019
…arch#1012)

In daae0c3, the column_schema for
AnnotationNumberizer changed from str to List[str], which led to the
broken demo
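
The `JSONDecodeError` in the original trace is consistent with this explanation. The following is only a sketch, assuming that the `seqlogical` column holds a bracketed annotation string and that a column declared as `List[str]` is parsed with `json.loads` (the traceback passes through `load_json`). The sample annotation and the helper `try_load_as_json` are hypothetical and not taken from the reporter's dataset or from PyText.

```python
import json

# Hypothetical annotation in bracket notation; not copied from the reporter's data.
annotation = "[IN:GET_EVENT what is happening tonight ]"


def try_load_as_json(value):
    """Return the parsed value, or the decoder's error message on failure."""
    try:
        return json.loads(value)
    except json.JSONDecodeError as err:
        return str(err)


# The leading "[" opens a JSON array, but "I" cannot start a JSON value,
# so decoding fails at the second character of the string.
result = try_load_as_json(annotation)
print(result)
```

The message matches the one in the report ("line 1 column 2 (char 1)"), which suggests the failing row began with an opening bracket followed by a non-JSON token.
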
ButteredGroove changed the title from "Heirarchical intent and slot filling demo is broken" to "Hierarchical intent and slot filling demo is broken" on Nov 21, 2019