
Do you plan to include the document context as the BERT paper did? #4

Closed
xuuuluuu opened this issue Mar 26, 2021 · 2 comments
@xuuuluuu

Hi, thanks for the nice framework.

Do you plan to include the document context as the BERT paper did?


dsindex commented Mar 27, 2021

I have a plan to implement the --bert_use_subword_pooling, --bert_use_word_embedding, and --bert_use_doc_context options.

Reference: #1 (comment)


dsindex commented Mar 30, 2021

@xuuuluuu

As I mentioned, I added the --bert_use_doc_context and --bert_use_subword_pooling options.

document context

https://github.com/dsindex/ntagger/blob/master/util_bert.py

 ---------------------------------------------------------------------------
      with --bert_use_doc_context:
     
        with --bert_doc_context_option=1:
                           prev example   example     next examples   
      tokens:        [CLS] p1 p2 p3 p4 p5 x1 x2 x3 x4 n1 n2 n3 n4  m1 m2 m3 ...
      token_idx:     0     1  2  3  4  5  6  7  8  9  10 11 12 13  14 15 16 ...
      input_ids:     x     x  x  x  x  x  x  x  x  x  x  x  x  x   x  x  x  ...
      segment_ids:   0     0  0  0  0  0  0  0  0  0  0  0  0  0   0  0  0  ...
      input_mask:    1     1  1  1  1  1  1  1  1  1  1  1  1  1   1  1  1  ...
      doc2sent_idx:  0     6  7  8  9  0  0  0  0  0  0  0  0  0   0  0  0  ...
      doc2sent_mask: 1     1  1  1  1  0  0  0  0  0  0  0  0  0   0  0  0  ...
        with --bert_doc_context_option=2:
                           prev examples   example    next examples   
     
      input_ids, segment_ids, input_mask are replaced with document-level values,
      and doc2sent_idx is used to slice input_ids, segment_ids, input_mask back to the sentence level.
      ---------------------------------------------------------------------------
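
For reference, here is a minimal sketch of how doc2sent_idx and doc2sent_mask could be used to slice the current example's vectors out of the document-level BERT output. The function and tensor names below are illustrative assumptions, not the actual ntagger code.

```python
import torch

def gather_sentence_tokens(doc_hidden, doc2sent_idx, doc2sent_mask):
    """Slice the current example's token vectors out of document-level BERT output.

    doc_hidden:    (batch, doc_len, hidden)  output over [CLS] + prev + example + next tokens
    doc2sent_idx:  (batch, sent_len)         position of each sentence token inside the document
    doc2sent_mask: (batch, sent_len)         1 for real sentence tokens, 0 for padding
    """
    hidden_size = doc_hidden.size(-1)
    # expand indices so gather selects whole hidden vectors, not single scalars
    idx = doc2sent_idx.unsqueeze(-1).expand(-1, -1, hidden_size)   # (batch, sent_len, hidden)
    sent_hidden = torch.gather(doc_hidden, dim=1, index=idx)       # (batch, sent_len, hidden)
    # zero out padded positions so downstream layers ignore them
    sent_hidden = sent_hidden * doc2sent_mask.unsqueeze(-1).to(sent_hidden.dtype)
    return sent_hidden

# toy usage matching the diagram above ([CLS] at 0, example tokens at positions 6..9)
doc_hidden = torch.randn(1, 17, 768)
doc2sent_idx = torch.tensor([[0, 6, 7, 8, 9, 0, 0, 0]])
doc2sent_mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0, 0]])
sent_hidden = gather_sentence_tokens(doc_hidden, doc2sent_idx, doc2sent_mask)
print(sent_hidden.shape)  # torch.Size([1, 8, 768])
```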

usage

# modify maximum sequence length for document context
$ vi configs/config-bert.json 
  "n_ctx": 512

# for Linear

## preprocessing
$ python preprocess.py --config=configs/config-bert.json --data_dir=data/conll2003 --bert_model_name_or_path=./embeddings/bert-base-cased --bert_use_doc_context --bert_use_subword_pooling --bert_doc_context_option=1

## train
$ python train.py --config=configs/config-bert.json --data_dir=data/conll2003 --save_path=pytorch-model-bert.pt --bert_model_name_or_path=./embeddings/bert-base-cased --bert_output_dir=bert-checkpoint --batch_size=16 --lr=2e-5 --epoch=30 --bert_use_doc_context --bert_use_subword_pooling --bert_disable_lstm

## evaluate
$ python evaluate.py --config=configs/config-bert.json --data_dir=data/conll2003 --model_path=pytorch-model-bert.pt --bert_output_dir=bert-checkpoint --bert_use_doc_context --bert_use_subword_pooling --bert_disable_lstm
$ cd data/conll2003; perl ../../etc/conlleval.pl < test.txt.pred ; cd ../..

# for BiLSTM-CRF + word embedding

## preprocessing
$ python preprocess.py --config=configs/config-bert.json --data_dir=data/conll2003 --bert_model_name_or_path=./embeddings/bert-base-cased --bert_use_doc_context --bert_use_subword_pooling --bert_use_word_embedding --bert_doc_context_option=1

## train
$ python train.py --config=configs/config-bert.json --data_dir=data/conll2003 --save_path=pytorch-model-bert.pt --bert_model_name_or_path=./embeddings/bert-base-cased --bert_output_dir=bert-checkpoint --batch_size=8 --lr=1e-5 --epoch=30 --bert_freezing_epoch=3 --bert_lr_during_freezing=1e-3 --use_crf --bert_use_doc_context --bert_use_subword_pooling --bert_use_word_embedding

## evaluate
$ python evaluate.py --config=configs/config-bert.json --data_dir=data/conll2003 --model_path=pytorch-model-bert.pt --bert_output_dir=bert-checkpoint --use_crf --bert_use_doc_context --bert_use_subword_pooling --bert_use_word_embedding
$ cd data/conll2003; perl ../../etc/conlleval.pl < test.txt.pred ; cd ../..
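
For context, --bert_use_subword_pooling pools subword token vectors back to one vector per word before tagging. Below is a minimal sketch assuming mean pooling over subwords and a hypothetical word-to-subword index map; ntagger's actual pooling scheme may differ (e.g. first-subword selection).

```python
import torch

def pool_subwords(token_hidden, word2token_idx, word2token_mask):
    """Mean-pool subword vectors into word-level vectors.

    token_hidden:    (num_tokens, hidden)      subword-level BERT output for one sentence
    word2token_idx:  (num_words, max_subwords) subword positions belonging to each word (0-padded)
    word2token_mask: (num_words, max_subwords) 1 for real subword positions, 0 for padding
    """
    gathered = token_hidden[word2token_idx]                    # (num_words, max_subwords, hidden)
    mask = word2token_mask.unsqueeze(-1).to(gathered.dtype)    # (num_words, max_subwords, 1)
    summed = (gathered * mask).sum(dim=1)                      # (num_words, hidden)
    counts = mask.sum(dim=1).clamp(min=1)                      # avoid division by zero on padding
    return summed / counts                                     # (num_words, hidden)

# toy usage: 3 words, the second word split into 2 subwords
token_hidden = torch.randn(5, 768)                             # [CLS] w1 w2a w2b w3
word2token_idx = torch.tensor([[1, 0], [2, 3], [4, 0]])
word2token_mask = torch.tensor([[1, 0], [1, 1], [1, 0]])
word_hidden = pool_subwords(token_hidden, word2token_idx, word2token_mask)
print(word_hidden.shape)  # torch.Size([3, 768])
```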

results

  • F1 score measured by conlleval.pl
  • bert-base-cased
    • Linear : 91.45%
    • BiLSTM-CRF : 91.55%
    • document context + subword pooling : 92.23%
    • document context + --bert_doc_context_option=2 : 92.35%
    • document context + subword pooling + word embedding + BiLSTM-CRF : 92.85%
  • bert-large-cased
    • Linear : 91.13%
    • BiLSTM-CRF : 92.02%
    • document context + --bert_doc_context_option=2 : 92.27%
    • document context + subword pooling + word embedding + BiLSTM-CRF : 92.83%
  • xlm-roberta-large
    • Linear : 92.75%
    • document context + --bert_doc_context_option=2 : 93.86%
    • document context + subword pooling + word embedding + BiLSTM-CRF : 93.59%
  • deberta-v2-xlarge
    • Linear : 93.12%
    • document context + --bert_doc_context_option=2 : 94.00%
  • BERT paper
    • bert-base-cased : 92.4%
    • bert-large-cased : 92.8%

[Screenshot, 2021-04-03 9:36 PM]

[Screenshot, 2021-04-03 9:39 PM]

Additionally, with the --bert_use_word_embedding option, you can add GloVe word embedding features to the BERT embeddings at the word level. This is likely to give a better result.
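
As a rough illustration of that word-level concatenation (the dimensions, module names, and the choice to learn the embedding table here are assumptions for illustration, not the exact ntagger implementation):

```python
import torch
import torch.nn as nn

class WordLevelConcat(nn.Module):
    """Concatenate word-level BERT vectors with GloVe embeddings before the tagging head."""

    def __init__(self, vocab_size=400000, glove_dim=300, bert_dim=768):
        super().__init__()
        # in practice this table would be initialized from pretrained GloVe vectors
        self.glove = nn.Embedding(vocab_size, glove_dim, padding_idx=0)

    def forward(self, bert_word_hidden, word_ids):
        # bert_word_hidden: (batch, num_words, bert_dim) BERT output pooled to word level
        # word_ids:         (batch, num_words)           GloVe vocabulary ids per word
        glove_emb = self.glove(word_ids)                            # (batch, num_words, glove_dim)
        return torch.cat([bert_word_hidden, glove_emb], dim=-1)     # (batch, num_words, bert_dim + glove_dim)

# toy usage
layer = WordLevelConcat()
bert_word_hidden = torch.randn(2, 10, 768)
word_ids = torch.randint(0, 400000, (2, 10))
features = layer(bert_word_hidden, word_ids)
print(features.shape)  # torch.Size([2, 10, 1068])
```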
