Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Extend BERT-based classification with customized layers #4553

Merged
merged 5 commits into from May 24, 2022

Conversation

Golovneva
Copy link
Contributor

Patch description
Added functionality to specify custom decoder layers for BERT-based classification. Code is a modification of existing in external ParlAI functions.

Testing steps

parlai train_model -m bert_classifier -t snli --classes 'entailment' 'contradiction' 'neutral' -mf /tmp/BERT_snli -bs 20 --classifier-layers linear,64 linear,3 relu
...
13:11:19 | Current ParlAI commit: 844a027ec81d543477d135a87eb5274ef4c013bd
13:11:19 | Current internal commit: 27aa6546aaec9e2e06069faabf2bb34aec1ba9f7
13:11:19 | Current fb commit: a69320df72c2a0c76873574e941eff3dc380fc4b
13:11:19 | creating task(s): snli
loading: /private/home/olggol/ParlAI/data/SNLI/snli_1.0/snli_1.0_train.jsonl
13:11:25 | training...
13:11:32 | time:7s total_exs:1000 total_steps:50 epochs:0.00
    accuracy   bleu-4  class_contradiction_f1  class_contradiction_prec  class_contradiction_recall  class_entailment_f1  class_entailment_prec  class_entailment_recall  class_neutral_f1  class_neutral_prec  \
       .3310 3.31e-10                  .05556                     .2703                      .03096                .4871                  .3311                    .9207             .0950               .3725
    class_neutral_recall  clen  clip  ctpb  ctps  ctrunc  ctrunclen  exps  exs    f1  gnorm  gpu_mem  llen  loss  lr  ltpb  ltps  ltrunc  ltrunclen  total_train_updates   tpb  tps  ups  weighted_f1
                  .05444 27.29 .1400 565.8  4111       0          0 145.3 1000 .3310  .2950   .09049 3.656 1.103   1 73.12 531.3       0          0                   50 638.9 4642 7.27        .2109
...

Other information

Copy link
Contributor

@klshuster klshuster left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this all looks great! Would it be possible to add a short test?

@@ -90,20 +99,6 @@ def add_cmdline_args(
"""
super().add_cmdline_args(parser, partial_opt=partial_opt)
parser = parser.add_argument_group("BERT Classifier Arguments")
parser.add_argument(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

was this option just never used?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right. I actually also found it in BertWrapper's add_common_args function, also never used, so I'll remove it from there


if ind < len(dimensions):
raise Exception(
"Output layer's dimension does not match number of classes. Found {dimensions[ind][1]}, expected {output_dimension}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: think you're missing f"" string here

"Output layer's dimension does not match number of classes. Found {dimensions[ind][1]}, expected {output_dimension}"
)
raise Exception(
"Output layer's dimension does not match number of classes. Found {prev_dimension}, expected {output_dimension}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

aggregation="first",
bert_model: BertModel,
output_dim: int = -1,
classifier_layer: torch.nn.Module = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: could you please move this arg to be the last one? so that prior calls to this __init__ don't fail

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed issues and added tests

Copy link
Contributor

@klshuster klshuster left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thank you for adding tests!

edit: approving assuming long_gpu_tests pass

@Golovneva Golovneva requested a review from klshuster May 23, 2022 20:46
@Golovneva
Copy link
Contributor Author

I have changed the torch version in CircleCI config to make it work since CR was approved, so re-requesting approval for this change

@Golovneva Golovneva merged commit 1628d8c into main May 24, 2022
@Golovneva Golovneva deleted the olggol/bert-classifier branch May 24, 2022 17:09
kushalarora pushed a commit that referenced this pull request Jun 15, 2022
* Extend BERT-based classification with customized layers

* fix bugs and add tests

* increase lr to improve training stability

* upgrading torch version

* adjusting loss value
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants