This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Extend BERT-based classification with customized layers #4553

Merged
merged 5 commits on May 24, 2022
Changes from 1 commit
91 changes: 76 additions & 15 deletions parlai/agents/bert_classifier/bert_classifier.py
@@ -6,6 +6,8 @@

"""
BERT classifier agent uses bert embeddings to make an utterance-level classification.

This implementation allows the classifier layers to be customized via input arguments.
"""

import os
@@ -32,6 +34,12 @@
)


LINEAR = "linear"
RELU = "relu"

SUPPORTED_LAYERS = [LINEAR, RELU]


class BertClassifierHistory(History):
"""
Handles tokenization history.
@@ -72,6 +80,7 @@ def __init__(self, opt, shared=None):
opt["pretrained_path"] = self.pretrained_path
self.add_cls_token = opt.get("add_cls_token", True)
self.sep_last_utt = opt.get("sep_last_utt", False)
self.classifier_layers = opt.get("classifier_layers", None)
super().__init__(opt, shared)

@classmethod
@@ -90,20 +99,6 @@ def add_cmdline_args(
"""
super().add_cmdline_args(parser, partial_opt=partial_opt)
parser = parser.add_argument_group("BERT Classifier Arguments")
parser.add_argument(
Contributor:
was this option just never used?

Contributor Author:
Right. I actually also found it in BertWrapper's add_common_args function, also never used, so I'll remove it from there

"--type-optimization",
type=str,
default="all_encoder_layers",
choices=[
"additional_layers",
"top_layer",
"top4_layers",
"all_encoder_layers",
"all",
],
help="which part of the encoders do we optimize "
"(defaults to all layers)",
)
parser.add_argument(
"--add-cls-token",
type="bool",
@@ -117,6 +112,13 @@
help="separate the last utterance into a different"
"segment with [SEP] token in between",
)
parser.add_argument(
"--classifier-layers",
nargs='+',
type=str,
default=None,
help="list of classifier layers comma-separated with layer's dimension where applicable. For example: linear,64 linear,32 relu",
)
parser.set_defaults(dict_maxexs=0) # skip building dictionary
return parser
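
As a quick illustration of the new flag (a sketch, not part of the diff): with nargs='+', a value such as --classifier-layers linear,64 linear,32 relu reaches the agent as a list of strings, and each entry is either "name,dim" or a bare layer name. The snippet below only mimics that split; the opt dict and the example spec are assumptions for illustration.

# Hypothetical sketch of the parsed flag value; not code from this PR.
opt = {"classifier_layers": ["linear,64", "linear,32", "relu"]}

for spec in opt["classifier_layers"]:
    name, _, dim = spec.partition(",")  # "linear,64" -> ("linear", "64"); "relu" -> ("relu", "")
    print(name, dim or "<no dimension>")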

@@ -142,12 +144,71 @@ def upgrade_opt(cls, opt_on_disk):

return opt_on_disk

def _get_layer_parameters(self, prev_dimension, output_dimension):
"""
Parse layer definitions from the input.
"""
layers = []
dimensions = []
no_dimension = -1
for layer in self.classifier_layers:
if ',' in layer:
l, d = layer.split(',')
layers.append(l)
dimensions.append((prev_dimension, int(d)))
prev_dimension = int(d)
else:
layers.append(layer)
dimensions.append(no_dimension)
ind = 0
while (
ind < len(dimensions)
and dimensions[len(dimensions) - ind - 1] == no_dimension
):
ind += 1
if (ind == len(dimensions) and prev_dimension == output_dimension) or (
ind < len(dimensions) and dimensions[ind][1] == output_dimension
):
return layers, dimensions

if ind < len(dimensions):
raise Exception(
"Output layer's dimension does not match number of classes. Found {dimensions[ind][1]}, expected {output_dimension}"
Contributor:
nit: think you're missing f"" string here

)
raise Exception(
"Output layer's dimension does not match number of classes. Found {prev_dimension}, expected {output_dimension}"
Contributor:
ditto

)

def _map_layer(self, layer: str, dim=None):
"""
Get torch wrappers for nn layers.
"""
if layer == LINEAR:
return torch.nn.Linear(dim[0], dim[1])
elif layer == RELU:
return torch.nn.ReLU(inplace=False)
raise Exception(
"Unrecognized network layer {}. Available options are: {}".format(
layer, ", ".join(SUPPORTED_LAYERS)
)
)

def build_model(self):
"""
Construct the model.
"""
num_classes = len(self.class_list)
return BertWrapper(BertModel.from_pretrained(self.pretrained_path), num_classes)
bert_model = BertModel.from_pretrained(self.pretrained_path)
if self.classifier_layers is not None:
prev_dimension = bert_model.embeddings.word_embeddings.weight.size(1)
layers, dims = self._get_layer_parameters(
prev_dimension=prev_dimension, output_dimension=num_classes
)
decoders = torch.nn.Sequential()
for l, d in zip(layers, dims):
decoders.append(self._map_layer(l, d))
return BertWrapper(bert_model=bert_model, classifier_layer=decoders)
return BertWrapper(bert_model=bert_model, output_dim=num_classes)

def _set_text_vec(self, *args, **kwargs):
obs = super()._set_text_vec(*args, **kwargs)
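For reference, a rough standalone sketch of the classifier head that build_model assembles from such a spec; the helper name and the 768-dimensional input are illustrative assumptions (768 matches bert-base embeddings), and the parsing below only mirrors _get_layer_parameters and _map_layer above rather than reusing them.

import torch

def build_classifier_head(specs, input_dim, num_classes):
    # specs e.g. ["linear,64", "relu", "linear,2"]; the last linear layer
    # must end on num_classes, matching the check in _get_layer_parameters.
    modules = []
    prev = input_dim
    for spec in specs:
        if spec.startswith("linear"):
            _, out = spec.split(",")
            modules.append(torch.nn.Linear(prev, int(out)))
            prev = int(out)
        elif spec == "relu":
            modules.append(torch.nn.ReLU(inplace=False))
        else:
            raise ValueError(f"Unrecognized layer: {spec}")
    if prev != num_classes:
        raise ValueError("Last linear layer must output num_classes")
    return torch.nn.Sequential(*modules)

# build_classifier_head(["linear,64", "relu", "linear,2"], 768, 2)
# -> Sequential(Linear(768, 64), ReLU(), Linear(64, 2))
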
37 changes: 29 additions & 8 deletions parlai/agents/bert_ranker/helpers.py
@@ -20,7 +20,6 @@
'installed. Install with:\n `pip install transformers`.'
)


import torch


@@ -97,15 +96,26 @@ def add_common_args(parser):
class BertWrapper(torch.nn.Module):
"""
Adds an optional transformer layer and a linear layer on top of BERT.
Args:
bert_model: pretrained BERT model
output_dim: dimension of the output layer for the default single linear layer classifier. Either output_dim or classifier_layer must be specified
classifier_layer: classification layers; can be a single layer or a list of layers (e.g., ModuleList)
add_transformer_layer: whether an additional transformer layer should be added on top of the pretrained model
layer_pulled: which layer should be pulled from pretrained model
aggregation: embeddings aggregation (pooling) strategy. Available options are:
(default)"first" - [CLS] representation,
"mean" - average of all embeddings except CLS,
"max" - max of all embeddings except CLS
"""

def __init__(
self,
bert_model,
output_dim,
add_transformer_layer=False,
layer_pulled=-1,
aggregation="first",
bert_model: BertModel,
output_dim: int = -1,
classifier_layer: torch.nn.Module = None,
Contributor:
nit: could you please move this arg to be the last one? so that prior calls to this __init__ don't fail

Contributor Author:
fixed issues and added tests

add_transformer_layer: bool = False,
layer_pulled: int = -1,
aggregation: str = "first",
):
super(BertWrapper, self).__init__()
self.layer_pulled = layer_pulled
@@ -123,7 +133,18 @@ def __init__(
hidden_act='gelu',
)
self.additional_transformer_layer = BertLayer(config_for_one_layer)
self.additional_linear_layer = torch.nn.Linear(bert_output_dim, output_dim)
if classifier_layer is None and output_dim == -1:
raise Exception(
"Either output dimention or classifier layers must be specified"
)
elif classifier_layer is None:
self.additional_linear_layer = torch.nn.Linear(bert_output_dim, output_dim)
else:
self.additional_linear_layer = classifier_layer
if output_dim != -1:
print(
"Both classifier layer and output dimension are specified. Output dimension parameter is ignored."
)
self.bert_model = bert_model

def forward(self, token_ids, segment_ids, attention_mask):
@@ -171,7 +192,7 @@ def forward(self, token_ids, segment_ids, attention_mask):
# Sort of hack to make it work with distributed: this way the pooler layer
# is used for grad computation, even though it does not change anything...
# in practice, it just adds a very (768*768) x (768*batchsize) matmul
result += 0 * torch.sum(output_pooler)
result = result + 0 * torch.sum(output_pooler)
return result
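
A rough sketch of the pooling strategies named in the BertWrapper docstring above, not the file's actual forward pass (which also handles the attention mask and layer selection); embeddings is assumed to be a (batch, seq_len, hidden) tensor with [CLS] at index 0.

import torch

def aggregate(embeddings: torch.Tensor, strategy: str = "first") -> torch.Tensor:
    if strategy == "first":
        return embeddings[:, 0]  # [CLS] representation
    if strategy == "mean":
        return embeddings[:, 1:].mean(dim=1)  # average, excluding [CLS]
    if strategy == "max":
        return embeddings[:, 1:].max(dim=1).values  # max, excluding [CLS]
    raise ValueError(f"Unknown aggregation: {strategy}")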


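Finally, a hedged usage sketch of the updated BertWrapper signature described above; the model name, the 768 -> 64 -> 2 head, and the two-class setup are assumptions, and the transformers package plus ParlAI must be importable for this to run.

import torch
from transformers import BertModel
from parlai.agents.bert_ranker.helpers import BertWrapper

bert = BertModel.from_pretrained("bert-base-uncased")
head = torch.nn.Sequential(
    torch.nn.Linear(768, 64),
    torch.nn.ReLU(inplace=False),
    torch.nn.Linear(64, 2),  # output dimension must equal the number of classes
)

# Either output_dim or classifier_layer must be given; if both are passed,
# output_dim is ignored with a printed warning, per the __init__ above.
model = BertWrapper(bert_model=bert, classifier_layer=head)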