This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Extend BERT-based classification with customized layers (#4553)
* Extend BERT-based classification with customized layers

* fix bugs and add tests

* increase lr to improve training stability

* upgrading torch version

* adjusting loss value
Golovneva committed May 24, 2022
1 parent da85231 commit 1628d8c
Showing 7 changed files with 141 additions and 45 deletions.
6 changes: 3 additions & 3 deletions .circleci/config.yml
@@ -105,7 +105,7 @@ commands:
- run:
name: Install torch GPU and dependencies
command: |
python -m pip install --progress-bar off torch==1.10.2+cu113 torchvision==0.11.3+cu113 torchaudio==0.10.2+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
python -m pip install --progress-bar off torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
python -m pip install --progress-bar off 'fairscale~=0.4.0'
python -m pip install --progress-bar off pytorch-pretrained-bert
python -m pip install --progress-bar off 'transformers==4.3.3'
@@ -124,7 +124,7 @@ commands:
name: Install torch CPU and dependencies
command: |
python -m pip install --progress-bar off 'transformers==4.3.3'
python -m pip install --progress-bar off 'torch==1.10.2'
python -m pip install --progress-bar off 'torch==1.11.0'
python -c 'import torch; print("Torch version:", torch.__version__)'
python -m torch.utils.collect_env
@@ -134,7 +134,7 @@ commands:
- run:
name: Install torch CPU and dependencies
command: |
python -m pip install --progress-bar off 'torch==1.10.2+cpu' 'torchvision==0.11.3+cpu' 'torchaudio==0.10.2+cpu' -f https://download.pytorch.org/whl/torch_stable.html
python -m pip install --progress-bar off 'torch==1.11.0+cpu' 'torchvision==0.12.0+cpu' 'torchaudio==0.11.0+cpu' -f https://download.pytorch.org/whl/torch_stable.html
python -m pip install --progress-bar off 'transformers==4.3.3'
python -c 'import torch; print("Torch version:", torch.__version__)'
python -m torch.utils.collect_env
91 changes: 76 additions & 15 deletions parlai/agents/bert_classifier/bert_classifier.py
@@ -6,6 +6,8 @@

"""
BERT classifier agent uses bert embeddings to make an utterance-level classification.
This implementation allows the classifier layers to be customized via input arguments.
"""

import os
@@ -32,6 +34,12 @@
)


LINEAR = "linear"
RELU = "relu"

SUPPORTED_LAYERS = [LINEAR, RELU]


class BertClassifierHistory(History):
"""
Handles tokenization history.
@@ -72,6 +80,7 @@ def __init__(self, opt, shared=None):
opt["pretrained_path"] = self.pretrained_path
self.add_cls_token = opt.get("add_cls_token", True)
self.sep_last_utt = opt.get("sep_last_utt", False)
self.classifier_layers = opt.get("classifier_layers", None)
super().__init__(opt, shared)

@classmethod
@@ -90,20 +99,6 @@ def add_cmdline_args(
"""
super().add_cmdline_args(parser, partial_opt=partial_opt)
parser = parser.add_argument_group("BERT Classifier Arguments")
parser.add_argument(
"--type-optimization",
type=str,
default="all_encoder_layers",
choices=[
"additional_layers",
"top_layer",
"top4_layers",
"all_encoder_layers",
"all",
],
help="which part of the encoders do we optimize "
"(defaults to all layers)",
)
parser.add_argument(
"--add-cls-token",
type="bool",
@@ -117,6 +112,13 @@ def add_cmdline_args(
help="separate the last utterance into a different"
"segment with [SEP] token in between",
)
parser.add_argument(
"--classifier-layers",
nargs='+',
type=str,
default=None,
help="list of classifier layers comma-separated with layer's dimension where applicable. For example: linear,64 linear,32 relu",
)
parser.set_defaults(dict_maxexs=0) # skip building dictionary
return parser

@@ -142,12 +144,71 @@ def upgrade_opt(cls, opt_on_disk):

return opt_on_disk

def _get_layer_parameters(self, prev_dimension, output_dimension):
"""
Parse layer definitions from the input.
"""
layers = []
dimensions = []
no_dimension = -1
for layer in self.classifier_layers:
if ',' in layer:
l, d = layer.split(',')
layers.append(l)
dimensions.append((prev_dimension, int(d)))
prev_dimension = int(d)
else:
layers.append(layer)
dimensions.append(no_dimension)
ind = 0
while (
ind < len(dimensions)
and dimensions[len(dimensions) - ind - 1] == no_dimension
):
ind += 1
if (ind == len(dimensions) and prev_dimension == output_dimension) or (
ind < len(dimensions) and dimensions[ind][1] == output_dimension
):
return layers, dimensions

if ind < len(dimensions):
raise Exception(
f"Output layer's dimension does not match number of classes. Found {dimensions[ind][1]}, expected {output_dimension}"
)
raise Exception(
f"Output layer's dimension does not match number of classes. Found {prev_dimension}, expected {output_dimension}"
)

def _map_layer(self, layer: str, dim=None):
"""
Get torch wrappers for nn layers.
"""
if layer == LINEAR:
return torch.nn.Linear(dim[0], dim[1])
elif layer == RELU:
return torch.nn.ReLU(inplace=False)
raise Exception(
"Unrecognized network layer {}. Available options are: {}".format(
layer, ", ".join(SUPPORTED_LAYERS)
)
)

def build_model(self):
"""
Construct the model.
"""
num_classes = len(self.class_list)
return BertWrapper(BertModel.from_pretrained(self.pretrained_path), num_classes)
bert_model = BertModel.from_pretrained(self.pretrained_path)
if self.classifier_layers is not None:
prev_dimension = bert_model.embeddings.word_embeddings.weight.size(1)
layers, dims = self._get_layer_parameters(
prev_dimension=prev_dimension, output_dimension=num_classes
)
decoders = torch.nn.Sequential()
for l, d in zip(layers, dims):
decoders.append(self._map_layer(l, d))
return BertWrapper(bert_model=bert_model, classifier_layer=decoders)
return BertWrapper(bert_model=bert_model, output_dim=num_classes)

def _set_text_vec(self, *args, **kwargs):
obs = super()._set_text_vec(*args, **kwargs)
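For readers skimming the diff, here is a minimal, self-contained sketch of how a `--classifier-layers` specification such as `linear,64 linear,2 relu` resolves into a `torch.nn.Sequential` head. It mirrors the parsing and validation added above in `_get_layer_parameters` and `_map_layer`, but it is not the agent's code path: the helper name `build_head` and the 768-dimensional BERT embedding size are illustrative assumptions.

```python
import torch

LINEAR = "linear"
RELU = "relu"


def build_head(specs, input_dim, num_classes):
    """Build an nn.Sequential head from specs like ["linear,64", "linear,2", "relu"]."""
    modules = []
    prev_dim = input_dim
    for spec in specs:
        if "," in spec:
            name, dim = spec.split(",")
            dim = int(dim)
        else:
            name, dim = spec, None
        if name == LINEAR:
            if dim is None:
                raise ValueError("linear layers need an output dimension, e.g. linear,64")
            modules.append(torch.nn.Linear(prev_dim, dim))
            prev_dim = dim
        elif name == RELU:
            modules.append(torch.nn.ReLU(inplace=False))
        else:
            raise ValueError(f"Unrecognized network layer {name}")
    # Trailing activations carry no dimension, so the check keys off the last linear
    # layer, matching the behaviour of _get_layer_parameters.
    if prev_dim != num_classes:
        raise ValueError(
            f"Output layer's dimension does not match number of classes. "
            f"Found {prev_dim}, expected {num_classes}"
        )
    return torch.nn.Sequential(*modules)


# Two-class head on top of 768-dim BERT embeddings, as in the new test below.
head = build_head(["linear,64", "linear,2", "relu"], input_dim=768, num_classes=2)
print(head(torch.randn(4, 768)).shape)  # torch.Size([4, 2])
```

Under this scheme the final `linear` dimension has to equal the number of classes, which is why the new `test_bertclassifier_with_relu` test below ends its spec with `linear,2` for the two classes `zero` and `one`.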
4 changes: 2 additions & 2 deletions parlai/agents/bert_ranker/README.md
@@ -18,10 +18,10 @@ In order to use those agents you need to install pytorch-pretrained-bert (https:

Train a BiEncoder BERT model on ConvAI2:
```bash
parlai train_model -t convai2 -m bert_ranker/bi_encoder_ranker --batchsize 20 --type-optimization all_encoder_layers -vtim 30 --model-file /tmp/bert_biencoder_test --data-parallel True
parlai train_model -t convai2 -m bert_ranker/bi_encoder_ranker --batchsize 20 -vtim 30 --model-file /tmp/bert_biencoder_test --data-parallel True
```

Train a CrossEncoder BERT model on ConvAI2:
```bash
parlai train_model -t convai2 -m bert_ranker/cross_encoder_ranker --batchsize 2 --type-optimization all_encoder_layers -vtim 30 --model-file /tmp/bert_crossencoder_test --data-parallel True
parlai train_model -t convai2 -m bert_ranker/cross_encoder_ranker --batchsize 2 -vtim 30 --model-file /tmp/bert_crossencoder_test --data-parallel True
```
51 changes: 29 additions & 22 deletions parlai/agents/bert_ranker/helpers.py
@@ -20,7 +20,6 @@
'installed. Install with:\n `pip install transformers`.'
)


import torch


@@ -63,20 +62,6 @@ def add_common_args(parser):
'multiple gpus. NOTE This is incompatible'
' with distributed training',
)
parser.add_argument(
'--type-optimization',
type=str,
default='all_encoder_layers',
choices=[
'additional_layers',
'top_layer',
'top4_layers',
'all_encoder_layers',
'all',
],
help='Which part of the encoders do we optimize. '
'(Default: all_encoder_layers.)',
)
parser.add_argument(
'--bert-aggregation',
type=str,
@@ -97,15 +82,26 @@ def add_common_args(parser):
class BertWrapper(torch.nn.Module):
"""
Adds an optional transformer layer and a linear layer on top of BERT.
Args:
bert_model: pretrained BERT model
output_dim: dimension of the output layer for the default single-linear-layer classifier. Either output_dim or classifier_layer must be specified
classifier_layer: classification layers; can be a single layer or a list of layers (e.g., a ModuleList)
add_transformer_layer: whether an additional transformer layer should be added on top of the pretrained model
layer_pulled: which layer should be pulled from the pretrained model
aggregation: embeddings aggregation (pooling) strategy. Available options are:
"first" (default) - [CLS] representation,
"mean" - average of all embeddings except CLS,
"max" - max of all embeddings except CLS
"""

def __init__(
self,
bert_model,
output_dim,
add_transformer_layer=False,
layer_pulled=-1,
aggregation="first",
bert_model: BertModel,
output_dim: int = -1,
add_transformer_layer: bool = False,
layer_pulled: int = -1,
aggregation: str = "first",
classifier_layer: torch.nn.Module = None,
):
super(BertWrapper, self).__init__()
self.layer_pulled = layer_pulled
@@ -123,7 +119,18 @@ def __init__(
hidden_act='gelu',
)
self.additional_transformer_layer = BertLayer(config_for_one_layer)
self.additional_linear_layer = torch.nn.Linear(bert_output_dim, output_dim)
if classifier_layer is None and output_dim == -1:
raise Exception(
"Either output dimention or classifier layers must be specified"
)
elif classifier_layer is None:
self.additional_linear_layer = torch.nn.Linear(bert_output_dim, output_dim)
else:
self.additional_linear_layer = classifier_layer
if output_dim != -1:
print(
"Both classifier layer and output dimension are specified. Output dimension parameter is ignored."
)
self.bert_model = bert_model

def forward(self, token_ids, segment_ids, attention_mask):
@@ -171,7 +178,7 @@ def forward(self, token_ids, segment_ids, attention_mask):
# Sort of hack to make it work with distributed: this way the pooler layer
# is used for grad computation, even though it does not change anything...
# in practice, it just adds a very (768*768) x (768*batchsize) matmul
result += 0 * torch.sum(output_pooler)
result = result + 0 * torch.sum(output_pooler)
return result


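As a quick illustration of the constructor contract this hunk introduces, the sketch below reproduces only the head-selection branch, with the BERT encoder stubbed out. `HeadOnlyWrapper` is a hypothetical stand-in written for this note, not part of the repository; the real `BertWrapper` additionally runs the encoder, the optional extra transformer layer, and the aggregation step.

```python
import torch


class HeadOnlyWrapper(torch.nn.Module):
    """Stand-in showing how BertWrapper now chooses its classification head."""

    def __init__(self, bert_output_dim=768, output_dim=-1, classifier_layer=None):
        super().__init__()
        if classifier_layer is None and output_dim == -1:
            raise Exception("Either output dimension or classifier layers must be specified")
        elif classifier_layer is None:
            # Backwards-compatible default: a single linear layer, as before this commit.
            self.additional_linear_layer = torch.nn.Linear(bert_output_dim, output_dim)
        else:
            # A custom head takes precedence; output_dim is ignored when both are given.
            self.additional_linear_layer = classifier_layer

    def forward(self, embedding):
        return self.additional_linear_layer(embedding)


# Legacy-style construction: one linear layer sized by the number of classes.
legacy = HeadOnlyWrapper(output_dim=2)
# New-style construction: an arbitrary head, e.g. the Sequential built by the agent's build_model.
custom = HeadOnlyWrapper(
    classifier_layer=torch.nn.Sequential(
        torch.nn.Linear(768, 64), torch.nn.ReLU(), torch.nn.Linear(64, 2)
    )
)
print(legacy(torch.randn(3, 768)).shape, custom(torch.randn(3, 768)).shape)
```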
@@ -1,5 +1,5 @@
dec_emb_loss: 0.0151884
dec_hid_loss: 0.662957
dec_hid_loss: 0.662956
dec_self_attn_loss: 497.628
enc_dec_attn_loss: 230.709
enc_emb_loss: 0.0109334
@@ -6,5 +6,5 @@ enc_emb_loss: 0.00210945
enc_hid_loss: 0.279337
enc_loss: 0.284342
enc_self_attn_loss: 371.567
loss: 11.7943
loss: 11.7944
pred_loss: 6.80161
30 changes: 29 additions & 1 deletion tests/nightly/gpu/test_bert.py
@@ -41,13 +41,41 @@ def test_crossencoder(self):
batchsize=2,
learningrate=1e-3,
gradient_clip=1.0,
type_optimization="all_encoder_layers",
text_truncate=8,
label_truncate=8,
)
)
self.assertGreaterEqual(test['accuracy'], 0.8)

def test_bertclassifier(self):
valid, test = testing_utils.train_model(
dict(
task='integration_tests:classifier',
model='bert_classifier/bert_classifier',
num_epochs=2,
batchsize=2,
learningrate=1e-2,
gradient_clip=1.0,
classes=["zero", "one"],
)
)
self.assertGreaterEqual(test['accuracy'], 0.9)

def test_bertclassifier_with_relu(self):
valid, test = testing_utils.train_model(
dict(
task='integration_tests:classifier',
model='bert_classifier/bert_classifier',
num_epochs=2,
batchsize=2,
learningrate=1e-2,
gradient_clip=1.0,
classes=["zero", "one"],
classifier_layers=["linear,64", "linear,2", "relu"],
)
)
self.assertGreaterEqual(test['accuracy'], 0.9)


if __name__ == '__main__':
unittest.main()
