From 1628d8c87e6f13e3f73d2e59acc7bd40770b9943 Mon Sep 17 00:00:00 2001
From: Golovneva <103262907+Golovneva@users.noreply.github.com>
Date: Tue, 24 May 2022 13:09:55 -0400
Subject: [PATCH] Extend BERT-based classification with customized layers
 (#4553)

* Extend BERT-based classification with customized layers

* fix bugs and add tests

* increase lr to improve training stability

* upgrading torch version

* adjusting loss value
---
 .circleci/config.yml                          |  6 +-
 .../agents/bert_classifier/bert_classifier.py | 91 ++++++++++++++++---
 parlai/agents/bert_ranker/README.md           |  4 +-
 parlai/agents/bert_ranker/helpers.py          | 51 ++++++-----
 .../test_anti_scaling/bart_narrow.yml         |  2 +-
 .../test_anti_scaling/transformer_narrow.yml  |  2 +-
 tests/nightly/gpu/test_bert.py                | 30 +++++-
 7 files changed, 141 insertions(+), 45 deletions(-)

diff --git a/.circleci/config.yml b/.circleci/config.yml
index 9362aba0f9e..068455a7334 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -105,7 +105,7 @@ commands:
       - run:
           name: Install torch GPU and dependencies
           command: |
-            python -m pip install --progress-bar off torch==1.10.2+cu113 torchvision==0.11.3+cu113 torchaudio==0.10.2+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
+            python -m pip install --progress-bar off torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
             python -m pip install --progress-bar off 'fairscale~=0.4.0'
             python -m pip install --progress-bar off pytorch-pretrained-bert
             python -m pip install --progress-bar off 'transformers==4.3.3'
@@ -124,7 +124,7 @@ commands:
           name: Install torch CPU and dependencies
           command: |
             python -m pip install --progress-bar off 'transformers==4.3.3'
-            python -m pip install --progress-bar off 'torch==1.10.2'
+            python -m pip install --progress-bar off 'torch==1.11.0'
             python -c 'import torch; print("Torch version:", torch.__version__)'
             python -m torch.utils.collect_env
@@ -134,7 +134,7 @@ commands:
       - run:
           name: Install torch CPU and dependencies
           command: |
-            python -m pip install --progress-bar off 'torch==1.10.2+cpu' 'torchvision==0.11.3+cpu' 'torchaudio==0.10.2+cpu' -f https://download.pytorch.org/whl/torch_stable.html
+            python -m pip install --progress-bar off 'torch==1.11.0+cpu' 'torchvision==0.12.0+cpu' 'torchaudio==0.11.0+cpu' -f https://download.pytorch.org/whl/torch_stable.html
             python -m pip install --progress-bar off 'transformers==4.3.3'
             python -c 'import torch; print("Torch version:", torch.__version__)'
             python -m torch.utils.collect_env
diff --git a/parlai/agents/bert_classifier/bert_classifier.py b/parlai/agents/bert_classifier/bert_classifier.py
index b25e60d7910..a489a6fbfd7 100644
--- a/parlai/agents/bert_classifier/bert_classifier.py
+++ b/parlai/agents/bert_classifier/bert_classifier.py
@@ -6,6 +6,8 @@
 """
 BERT classifier agent uses bert embeddings to make an utterance-level classification.
+
+This implementation allows customizing the classifier layers via input arguments.
 """

 import os
@@ -32,6 +34,12 @@
 )


+LINEAR = "linear"
+RELU = "relu"
+
+SUPPORTED_LAYERS = [LINEAR, RELU]
+
+
 class BertClassifierHistory(History):
     """
     Handles tokenization history.
@@ -72,6 +80,7 @@ def __init__(self, opt, shared=None):
         opt["pretrained_path"] = self.pretrained_path
         self.add_cls_token = opt.get("add_cls_token", True)
         self.sep_last_utt = opt.get("sep_last_utt", False)
+        self.classifier_layers = opt.get("classifier_layers", None)
         super().__init__(opt, shared)

     @classmethod
@@ -90,20 +99,6 @@ def add_cmdline_args(
         """
         super().add_cmdline_args(parser, partial_opt=partial_opt)
         parser = parser.add_argument_group("BERT Classifier Arguments")
-        parser.add_argument(
-            "--type-optimization",
-            type=str,
-            default="all_encoder_layers",
-            choices=[
-                "additional_layers",
-                "top_layer",
-                "top4_layers",
-                "all_encoder_layers",
-                "all",
-            ],
-            help="which part of the encoders do we optimize "
-            "(defaults to all layers)",
-        )
         parser.add_argument(
             "--add-cls-token",
             type="bool",
@@ -117,6 +112,13 @@
             help="separate the last utterance into a different"
             "segment with [SEP] token in between",
         )
+        parser.add_argument(
+            "--classifier-layers",
+            nargs='+',
+            type=str,
+            default=None,
+            help="List of classifier layers; append a comma-separated output "
+            "dimension where applicable, e.g.: linear,64 linear,32 relu",
+        )
         parser.set_defaults(dict_maxexs=0)  # skip building dictionary
         return parser
@@ -142,12 +144,71 @@ def upgrade_opt(cls, opt_on_disk):

         return opt_on_disk

+    def _get_layer_parameters(self, prev_dimension, output_dimension):
+        """
+        Parse layer definitions from the input.
+        """
+        layers = []
+        dimensions = []
+        no_dimension = -1
+        for layer in self.classifier_layers:
+            if ',' in layer:
+                l, d = layer.split(',')
+                layers.append(l)
+                dimensions.append((prev_dimension, int(d)))
+                prev_dimension = int(d)
+            else:
+                layers.append(layer)
+                dimensions.append(no_dimension)
+        # layers without a dimension (e.g. relu) do not change the width, so
+        # prev_dimension is the output size of the whole stack
+        if prev_dimension != output_dimension:
+            raise Exception(
+                f"Output layer's dimension does not match number of classes. "
+                f"Found {prev_dimension}, expected {output_dimension}"
+            )
+        return layers, dimensions
+
+    def _map_layer(self, layer: str, dim=None):
+        """
+        Get torch wrappers for nn layers.
+        """
+        if layer == LINEAR:
+            return torch.nn.Linear(dim[0], dim[1])
+        elif layer == RELU:
+            return torch.nn.ReLU(inplace=False)
+        raise Exception(
+            "Unrecognized network layer {}. Available options are: {}".format(
+                layer, ", ".join(SUPPORTED_LAYERS)
+            )
+        )
+
     def build_model(self):
         """
         Construct the model.
         """
         num_classes = len(self.class_list)
-        return BertWrapper(BertModel.from_pretrained(self.pretrained_path), num_classes)
+        bert_model = BertModel.from_pretrained(self.pretrained_path)
+        if self.classifier_layers is not None:
+            prev_dimension = bert_model.embeddings.word_embeddings.weight.size(1)
+            layers, dims = self._get_layer_parameters(
+                prev_dimension=prev_dimension, output_dimension=num_classes
+            )
+            decoders = torch.nn.Sequential()
+            for l, d in zip(layers, dims):
+                # nn.Sequential.append requires a recent torch (see the CI
+                # torch upgrade in this patch)
+                decoders.append(self._map_layer(l, d))
+            return BertWrapper(bert_model=bert_model, classifier_layer=decoders)
+        return BertWrapper(bert_model=bert_model, output_dim=num_classes)

     def _set_text_vec(self, *args, **kwargs):
         obs = super()._set_text_vec(*args, **kwargs)
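Reviewer's aside (not part of the patch): the sketch below shows what a `--classifier-layers` spec expands to. `build_head` is a hypothetical standalone name, not ParlAI API; it mirrors the `_get_layer_parameters`/`_map_layer` pair above and assumes BERT-base's hidden size of 768.

```python
import torch

def build_head(spec, input_dim, num_classes):
    # Hypothetical standalone mirror of _get_layer_parameters + _map_layer.
    modules = []
    prev = input_dim
    for item in spec:
        if "," in item:
            name, dim = item.split(",")
            if name != "linear":
                raise ValueError(f"Unrecognized network layer {name}")
            modules.append(torch.nn.Linear(prev, int(dim)))
            prev = int(dim)
        elif item == "relu":
            modules.append(torch.nn.ReLU())
        else:
            raise ValueError(f"Unrecognized network layer {item}")
    if prev != num_classes:
        # same validation the agent performs before building the model
        raise ValueError(f"Found {prev}, expected {num_classes}")
    return torch.nn.Sequential(*modules)

# The spec used by the new test, for a 2-class task on BERT-base:
head = build_head(["linear,64", "linear,2", "relu"], input_dim=768, num_classes=2)
print(head)  # Sequential(Linear(768, 64), Linear(64, 2), ReLU())
```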
""" num_classes = len(self.class_list) - return BertWrapper(BertModel.from_pretrained(self.pretrained_path), num_classes) + bert_model = BertModel.from_pretrained(self.pretrained_path) + if self.classifier_layers is not None: + prev_dimension = bert_model.embeddings.word_embeddings.weight.size(1) + layers, dims = self._get_layer_parameters( + prev_dimension=prev_dimension, output_dimension=num_classes + ) + decoders = torch.nn.Sequential() + for l, d in zip(layers, dims): + decoders.append(self._map_layer(l, d)) + return BertWrapper(bert_model=bert_model, classifier_layer=decoders) + return BertWrapper(bert_model=bert_model, output_dim=num_classes) def _set_text_vec(self, *args, **kwargs): obs = super()._set_text_vec(*args, **kwargs) diff --git a/parlai/agents/bert_ranker/README.md b/parlai/agents/bert_ranker/README.md index d5f09bcd699..2594de4466d 100644 --- a/parlai/agents/bert_ranker/README.md +++ b/parlai/agents/bert_ranker/README.md @@ -18,10 +18,10 @@ In order to use those agents you need to install pytorch-pretrained-bert (https: Train a BiEncoder BERT model on ConvAI2: ```bash -parlai train_model -t convai2 -m bert_ranker/bi_encoder_ranker --batchsize 20 --type-optimization all_encoder_layers -vtim 30 --model-file /tmp/bert_biencoder_test --data-parallel True +parlai train_model -t convai2 -m bert_ranker/bi_encoder_ranker --batchsize 20 -vtim 30 --model-file /tmp/bert_biencoder_test --data-parallel True ``` Train a CrossEncoder BERT model on ConvAI2: ```bash -parlai train_model -t convai2 -m bert_ranker/cross_encoder_ranker --batchsize 2 --type-optimization all_encoder_layers -vtim 30 --model-file /tmp/bert_crossencoder_test --data-parallel True +parlai train_model -t convai2 -m bert_ranker/cross_encoder_ranker --batchsize 2 -vtim 30 --model-file /tmp/bert_crossencoder_test --data-parallel True ``` diff --git a/parlai/agents/bert_ranker/helpers.py b/parlai/agents/bert_ranker/helpers.py index a2b59360b54..f18f94d10dc 100644 --- a/parlai/agents/bert_ranker/helpers.py +++ b/parlai/agents/bert_ranker/helpers.py @@ -20,7 +20,6 @@ 'installed. Install with:\n `pip install transformers`.' ) - import torch @@ -63,20 +62,6 @@ def add_common_args(parser): 'multiple gpus. NOTE This is incompatible' ' with distributed training', ) - parser.add_argument( - '--type-optimization', - type=str, - default='all_encoder_layers', - choices=[ - 'additional_layers', - 'top_layer', - 'top4_layers', - 'all_encoder_layers', - 'all', - ], - help='Which part of the encoders do we optimize. ' - '(Default: all_encoder_layers.)', - ) parser.add_argument( '--bert-aggregation', type=str, @@ -97,15 +82,26 @@ def add_common_args(parser): class BertWrapper(torch.nn.Module): """ Adds a optional transformer layer and a linear layer on top of BERT. + Args: + bert_model: pretrained BERT model + output_dim: dimension of the output layer for defult 1 linear layer classifier. Either output_dim or classifier_layer must be specified + classifier_layer: classification layers, can be a signle layer, or list of layers (for ex, ModuleList) + add_transformer_layer: if additional transformer layer should be added on top of the pretrained model + layer_pulled: which layer should be pulled from pretrained model + aggregation: embeddings aggregation (pooling) strategy. 
diff --git a/parlai/agents/bert_ranker/helpers.py b/parlai/agents/bert_ranker/helpers.py
index a2b59360b54..f18f94d10dc 100644
--- a/parlai/agents/bert_ranker/helpers.py
+++ b/parlai/agents/bert_ranker/helpers.py
@@ -20,7 +20,6 @@
         'installed. Install with:\n `pip install transformers`.'
     )

-
 import torch


@@ -63,20 +62,6 @@ def add_common_args(parser):
         'multiple gpus. NOTE This is incompatible'
         ' with distributed training',
     )
-    parser.add_argument(
-        '--type-optimization',
-        type=str,
-        default='all_encoder_layers',
-        choices=[
-            'additional_layers',
-            'top_layer',
-            'top4_layers',
-            'all_encoder_layers',
-            'all',
-        ],
-        help='Which part of the encoders do we optimize. '
-        '(Default: all_encoder_layers.)',
-    )
     parser.add_argument(
         '--bert-aggregation',
         type=str,
@@ -97,15 +82,26 @@
 class BertWrapper(torch.nn.Module):
     """
     Adds an optional transformer layer and a linear layer on top of BERT.
+
+    Args:
+        bert_model: pretrained BERT model
+        output_dim: dimension of the output layer for the default
+            single-linear-layer classifier; either output_dim or
+            classifier_layer must be specified
+        classifier_layer: classification layers; can be a single layer or a
+            container of layers (e.g. ModuleList)
+        add_transformer_layer: whether an additional transformer layer should
+            be added on top of the pretrained model
+        layer_pulled: which layer should be pulled from the pretrained model
+        aggregation: embedding aggregation (pooling) strategy. Available
+            options are: "first" (default) - the [CLS] representation;
+            "mean" - average of all embeddings except [CLS];
+            "max" - max of all embeddings except [CLS]
     """

     def __init__(
         self,
-        bert_model,
-        output_dim,
-        add_transformer_layer=False,
-        layer_pulled=-1,
-        aggregation="first",
+        bert_model: BertModel,
+        output_dim: int = -1,
+        add_transformer_layer: bool = False,
+        layer_pulled: int = -1,
+        aggregation: str = "first",
+        classifier_layer: torch.nn.Module = None,
     ):
         super(BertWrapper, self).__init__()
         self.layer_pulled = layer_pulled
@@ -123,7 +119,18 @@ def __init__(
             hidden_act='gelu',
         )
         self.additional_transformer_layer = BertLayer(config_for_one_layer)
-        self.additional_linear_layer = torch.nn.Linear(bert_output_dim, output_dim)
+        if classifier_layer is None and output_dim == -1:
+            raise Exception(
+                "Either output dimension or classifier layers must be specified"
+            )
+        elif classifier_layer is None:
+            self.additional_linear_layer = torch.nn.Linear(bert_output_dim, output_dim)
+        else:
+            self.additional_linear_layer = classifier_layer
+            if output_dim != -1:
+                print(
+                    "Both classifier layer and output dimension are specified. "
+                    "Output dimension parameter is ignored."
+                )
         self.bert_model = bert_model

     def forward(self, token_ids, segment_ids, attention_mask):
@@ -171,7 +178,7 @@ def forward(self, token_ids, segment_ids, attention_mask):
         # Sort of hack to make it work with distributed: this way the pooler layer
         # is used for grad computation, even though it does not change anything...
         # in practice, it just adds a very (768*768) x (768*batchsize) matmul
-        result += 0 * torch.sum(output_pooler)
+        result = result + 0 * torch.sum(output_pooler)
         return result
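A minimal sketch (again, not part of the patch) of the two ways the updated `BertWrapper` can now be constructed; the pretrained model name and head sizes are placeholder choices:

```python
import torch
from pytorch_pretrained_bert import BertModel
from parlai.agents.bert_ranker.helpers import BertWrapper

# Placeholder checkpoint; any pytorch-pretrained-bert model name works here.
bert = BertModel.from_pretrained("bert-base-uncased")

# Default behavior: a single Linear(hidden_size, num_classes) head.
default_wrapper = BertWrapper(bert_model=bert, output_dim=2)

# New behavior: any custom head passed via classifier_layer.
head = torch.nn.Sequential(
    torch.nn.Linear(768, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 2),
)
custom_wrapper = BertWrapper(bert_model=bert, classifier_layer=head)
```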
diff --git a/tests/nightly/gpu/anti_scaling/test_anti_scaling/bart_narrow.yml b/tests/nightly/gpu/anti_scaling/test_anti_scaling/bart_narrow.yml
index 773c9467a0f..e869dbf89b0 100644
--- a/tests/nightly/gpu/anti_scaling/test_anti_scaling/bart_narrow.yml
+++ b/tests/nightly/gpu/anti_scaling/test_anti_scaling/bart_narrow.yml
@@ -1,5 +1,5 @@
 dec_emb_loss: 0.0151884
-dec_hid_loss: 0.662957
+dec_hid_loss: 0.662956
 dec_self_attn_loss: 497.628
 enc_dec_attn_loss: 230.709
 enc_emb_loss: 0.0109334
diff --git a/tests/nightly/gpu/anti_scaling/test_anti_scaling/transformer_narrow.yml b/tests/nightly/gpu/anti_scaling/test_anti_scaling/transformer_narrow.yml
index a28fe606850..5ab221e0bb7 100644
--- a/tests/nightly/gpu/anti_scaling/test_anti_scaling/transformer_narrow.yml
+++ b/tests/nightly/gpu/anti_scaling/test_anti_scaling/transformer_narrow.yml
@@ -6,5 +6,5 @@ enc_emb_loss: 0.00210945
 enc_hid_loss: 0.279337
 enc_loss: 0.284342
 enc_self_attn_loss: 371.567
-loss: 11.7943
+loss: 11.7944
 pred_loss: 6.80161
diff --git a/tests/nightly/gpu/test_bert.py b/tests/nightly/gpu/test_bert.py
index 33e5c635fb2..b506644214d 100644
--- a/tests/nightly/gpu/test_bert.py
+++ b/tests/nightly/gpu/test_bert.py
@@ -41,13 +41,41 @@ def test_crossencoder(self):
                 batchsize=2,
                 learningrate=1e-3,
                 gradient_clip=1.0,
-                type_optimization="all_encoder_layers",
                 text_truncate=8,
                 label_truncate=8,
             )
         )
         self.assertGreaterEqual(test['accuracy'], 0.8)

+    def test_bertclassifier(self):
+        valid, test = testing_utils.train_model(
+            dict(
+                task='integration_tests:classifier',
+                model='bert_classifier/bert_classifier',
+                num_epochs=2,
+                batchsize=2,
+                learningrate=1e-2,
+                gradient_clip=1.0,
+                classes=["zero", "one"],
+            )
+        )
+        self.assertGreaterEqual(test['accuracy'], 0.9)
+
+    def test_bertclassifier_with_relu(self):
+        valid, test = testing_utils.train_model(
+            dict(
+                task='integration_tests:classifier',
+                model='bert_classifier/bert_classifier',
+                num_epochs=2,
+                batchsize=2,
+                learningrate=1e-2,
+                gradient_clip=1.0,
+                classes=["zero", "one"],
+                classifier_layers=["linear,64", "linear,2", "relu"],
+            )
+        )
+        self.assertGreaterEqual(test['accuracy'], 0.9)
+

 if __name__ == '__main__':
     unittest.main()
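To reproduce locally (assuming ParlAI's usual pytest setup and a GPU machine, since these tests live under tests/nightly/gpu), the new tests can be targeted by name:

```bash
python -m pytest tests/nightly/gpu/test_bert.py -k bertclassifier
```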