This repository has been archived by the owner on Nov 3, 2023. It is now read-only.
Extend BERT-based classification with customized layers #4553
Merged
Changes from 1 commit
Commits (5):
d99616e  Extend BERT-based classification with customized layers (Golovneva)
ef1c2c0  fix bugs and add tests (Golovneva)
4fa0d08  increase lr to improve training stability (Golovneva)
6f96100  upgrading torch version (Golovneva)
57d4aef  adjusting loss value (Golovneva)
First file in the diff:
@@ -6,6 +6,8 @@
 """
 BERT classifier agent uses BERT embeddings to make an utterance-level classification.
+
+This implementation allows customizing the classifier layers via input arguments.
 """

 import os
@@ -32,6 +34,12 @@
 )


+LINEAR = "linear"
+RELU = "relu"
+
+SUPPORTED_LAYERS = [LINEAR, RELU]
+
+
 class BertClassifierHistory(History):
     """
     Handles tokenization history.
@@ -72,6 +80,7 @@ def __init__(self, opt, shared=None):
         opt["pretrained_path"] = self.pretrained_path
         self.add_cls_token = opt.get("add_cls_token", True)
         self.sep_last_utt = opt.get("sep_last_utt", False)
+        self.classifier_layers = opt.get("classifier_layers", None)
         super().__init__(opt, shared)

     @classmethod
@@ -90,20 +99,6 @@ def add_cmdline_args(
         """
         super().add_cmdline_args(parser, partial_opt=partial_opt)
         parser = parser.add_argument_group("BERT Classifier Arguments")
-        parser.add_argument(
-            "--type-optimization",
-            type=str,
-            default="all_encoder_layers",
-            choices=[
-                "additional_layers",
-                "top_layer",
-                "top4_layers",
-                "all_encoder_layers",
-                "all",
-            ],
-            help="which part of the encoders do we optimize "
-            "(defaults to all layers)",
-        )
         parser.add_argument(
             "--add-cls-token",
             type="bool",
@@ -117,6 +112,13 @@ def add_cmdline_args(
             help="separate the last utterance into a different "
             "segment with [SEP] token in between",
         )
+        parser.add_argument(
+            "--classifier-layers",
+            nargs='+',
+            type=str,
+            default=None,
+            help="list of classifier layers; each item is a layer name, optionally "
+            "followed by a comma and its output dimension, e.g.: linear,64 linear,32 relu",
+        )
         parser.set_defaults(dict_maxexs=0)  # skip building dictionary
         return parser
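For concreteness, an illustrative reading of this flag (not stated in the diff): `--classifier-layers linear,64 linear,32 relu` arrives in `opt['classifier_layers']` as the list `['linear,64', 'linear,32', 'relu']`, since `nargs='+'` collects space-separated tokens; each token is a layer name optionally followed by a comma and its output dimension.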
@@ -142,12 +144,71 @@ def upgrade_opt(cls, opt_on_disk):

         return opt_on_disk

+    def _get_layer_parameters(self, prev_dimension, output_dimension):
+        """
+        Parse layer definitions from the input.
+        """
+        layers = []
+        dimensions = []
+        no_dimension = -1
+        for layer in self.classifier_layers:
+            if ',' in layer:
+                l, d = layer.split(',')
+                layers.append(l)
+                dimensions.append((prev_dimension, int(d)))
+                prev_dimension = int(d)
+            else:
+                layers.append(layer)
+                dimensions.append(no_dimension)
+        ind = 0
+        while (
+            ind < len(dimensions)
+            and dimensions[len(dimensions) - ind - 1] == no_dimension
+        ):
+            ind += 1
+        last_fixed = len(dimensions) - ind - 1  # last layer that declares a dimension
+        if (ind == len(dimensions) and prev_dimension == output_dimension) or (
+            ind < len(dimensions) and dimensions[last_fixed][1] == output_dimension
+        ):
+            return layers, dimensions
+
+        if ind < len(dimensions):
+            raise Exception(
+                f"Output layer's dimension does not match number of classes. "
+                f"Found {dimensions[last_fixed][1]}, expected {output_dimension}"
+            )

[Review comment] nit: think you're missing the f-string prefix here

+        raise Exception(
+            f"Output layer's dimension does not match number of classes. "
+            f"Found {prev_dimension}, expected {output_dimension}"
+        )

[Review comment] ditto
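To make the parsing concrete, here is a hedged walkthrough of what `_get_layer_parameters` produces; the 768 encoder width and 2-class output are assumed example values, not taken from the diff:

```python
# Assumed setup: prev_dimension=768 (BERT hidden size), output_dimension=2,
# and --classifier-layers linear,64 relu linear,2. The method returns:
#   layers     == ["linear", "relu", "linear"]
#   dimensions == [(768, 64), -1, (64, 2)]
# Each "linear,<d>" spec consumes the running dimension and updates it to <d>;
# a bare "relu" records the no_dimension sentinel (-1), since ReLU is shapeless.
# The while-loop then skips shapeless layers from the end so the last layer that
# declares a dimension can be checked against the number of classes.
```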
+    def _map_layer(self, layer: str, dim=None):
+        """
+        Get torch wrappers for nn layers.
+        """
+        if layer == LINEAR:
+            return torch.nn.Linear(dim[0], dim[1])
+        elif layer == RELU:
+            return torch.nn.ReLU(inplace=False)
+        raise Exception(
+            "Unrecognized network layer {}. Available options are: {}".format(
+                layer, ", ".join(SUPPORTED_LAYERS)
+            )
+        )
+
     def build_model(self):
         """
         Construct the model.
         """
         num_classes = len(self.class_list)
-        return BertWrapper(BertModel.from_pretrained(self.pretrained_path), num_classes)
+        bert_model = BertModel.from_pretrained(self.pretrained_path)
+        if self.classifier_layers is not None:
+            prev_dimension = bert_model.embeddings.word_embeddings.weight.size(1)
+            layers, dims = self._get_layer_parameters(
+                prev_dimension=prev_dimension, output_dimension=num_classes
+            )
+            decoders = torch.nn.Sequential()
+            for l, d in zip(layers, dims):
+                decoders.append(self._map_layer(l, d))
+            return BertWrapper(bert_model=bert_model, classifier_layer=decoders)
+        return BertWrapper(bert_model=bert_model, output_dim=num_classes)
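As a sketch, under the same assumptions as above (768-dim encoder, 2 classes, `--classifier-layers linear,64 relu linear,2`; all illustrative), the custom branch assembles roughly this head:

```python
import torch

# Roughly what build_model's loop appends into the Sequential container:
decoders = torch.nn.Sequential(
    torch.nn.Linear(768, 64),
    torch.nn.ReLU(inplace=False),
    torch.nn.Linear(64, 2),
)
```

Note that `torch.nn.Sequential.append` only exists in newer PyTorch releases, which is presumably what the "upgrading torch version" commit in this PR addresses.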

     def _set_text_vec(self, *args, **kwargs):
         obs = super()._set_text_vec(*args, **kwargs)
Second file in the diff:
@@ -20,7 +20,6 @@
         'installed. Install with:\n `pip install transformers`.'
     )

-
 import torch

@@ -97,15 +96,26 @@ def add_common_args(parser):
 class BertWrapper(torch.nn.Module):
     """
     Adds an optional transformer layer and a linear layer on top of BERT.

+    Args:
+        bert_model: pretrained BERT model
+        output_dim: dimension of the output layer for the default single-linear-layer
+            classifier. Either output_dim or classifier_layer must be specified.
+        classifier_layer: classification layers; can be a single layer or a list of
+            layers (for example, a ModuleList)
+        add_transformer_layer: whether an additional transformer layer should be added
+            on top of the pretrained model
+        layer_pulled: which layer should be pulled from the pretrained model
+        aggregation: embedding aggregation (pooling) strategy. Available options are:
+            "first" (default) - the [CLS] representation,
+            "mean" - average of all embeddings except [CLS],
+            "max" - max of all embeddings except [CLS]
     """

     def __init__(
         self,
-        bert_model,
-        output_dim,
-        add_transformer_layer=False,
-        layer_pulled=-1,
-        aggregation="first",
+        bert_model: BertModel,
+        output_dim: int = -1,
+        classifier_layer: torch.nn.Module = None,
+        add_transformer_layer: bool = False,
+        layer_pulled: int = -1,
+        aggregation: str = "first",
     ):
         super(BertWrapper, self).__init__()
         self.layer_pulled = layer_pulled

[Review comment, on classifier_layer] nit: could you please move this arg to be the last one? So that prior calls to this constructor with positional arguments don't break.

[Author] fixed issues and added tests
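As an aside, a minimal sketch of the three aggregation strategies the docstring lists (illustrative only; it assumes `embeddings` of shape `[batch, seq_len, hidden]` with [CLS] at index 0 and ignores padding masks, which the real forward pass has to handle):

```python
import torch

embeddings = torch.randn(4, 16, 768)      # [batch, seq_len, hidden], [CLS] at index 0
first = embeddings[:, 0]                   # "first": the [CLS] representation
mean = embeddings[:, 1:].mean(dim=1)       # "mean": average of all embeddings except [CLS]
mx = embeddings[:, 1:].max(dim=1).values   # "max": elementwise max, excluding [CLS]
```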
@@ -123,7 +133,18 @@ def __init__(
                 hidden_act='gelu',
             )
             self.additional_transformer_layer = BertLayer(config_for_one_layer)
-        self.additional_linear_layer = torch.nn.Linear(bert_output_dim, output_dim)
+        if classifier_layer is None and output_dim == -1:
+            raise Exception(
+                "Either output dimension or classifier layers must be specified"
+            )
+        elif classifier_layer is None:
+            self.additional_linear_layer = torch.nn.Linear(bert_output_dim, output_dim)
+        else:
+            self.additional_linear_layer = classifier_layer
+            if output_dim != -1:
+                print(
+                    "Both classifier layer and output dimension are specified. Output dimension parameter is ignored."
+                )
         self.bert_model = bert_model
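A hypothetical usage sketch of the two construction paths (the variable names and the 768/2 dimensions are assumptions for illustration, not from the diff):

```python
import torch

# Assume `bert` is a BertModel already loaded via BertModel.from_pretrained(...).
default_head = BertWrapper(bert_model=bert, output_dim=2)  # single Linear(768, 2) classifier

custom = torch.nn.Sequential(
    torch.nn.Linear(768, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 2),
)
custom_head = BertWrapper(bert_model=bert, classifier_layer=custom)  # output_dim is ignored
```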

     def forward(self, token_ids, segment_ids, attention_mask):

@@ -171,7 +192,7 @@ def forward(self, token_ids, segment_ids, attention_mask):
         # Sort of hack to make it work with distributed: this way the pooler layer
         # is used for grad computation, even though it does not change anything...
         # in practice, it just adds a very (768*768) x (768*batchsize) matmul
-        result += 0 * torch.sum(output_pooler)
+        result = result + 0 * torch.sum(output_pooler)
         return result
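The switch from `result += ...` to `result = result + ...` presumably avoids modifying `result` in place: the out-of-place form creates a fresh tensor, which sidesteps autograd errors when the original tensor is still needed for gradient computation (this rationale is inferred, not stated in the PR).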
[Review comment, on the removed --type-optimization argument] was this option just never used?

[Author] Right. I actually also found it in BertWrapper's add_common_args function, also never used, so I'll remove it from there.