# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Implements a BERT wrapper around a :class:`.ComposerTransformer`."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, Union

import torch
from torchmetrics import MeanSquaredError, Metric, MetricCollection
from torchmetrics.classification.accuracy import Accuracy
from torchmetrics.classification.matthews_corrcoef import MatthewsCorrCoef
from torchmetrics.regression.spearman import SpearmanCorrCoef

from composer.metrics.nlp import BinaryF1Score, LanguageCrossEntropy, MaskedAccuracy
from composer.models.transformer_shared import ComposerTransformer

if TYPE_CHECKING:
    import transformers

    from composer.core.types import Batch

__all__ = ['BERTModel']


class BERTModel(ComposerTransformer):
    """BERT model based on |:hugging_face:| Transformers.

    For more information, see `Transformers <https://huggingface.co/transformers/>`_.

    Args:
        module (transformers.BertModel): An instance of ``BertModel`` that
            contains the forward pass function.
        config (transformers.BertConfig): The ``BertConfig`` object that
            stores information about the model hyperparameters.
        tokenizer (transformers.BertTokenizer): An instance of ``BertTokenizer``,
            necessary to process model inputs.

    To create a BERT model for language model pretraining:

    .. testcode::

        from composer.models import BERTModel
        import transformers

        config = transformers.BertConfig()
        hf_model = transformers.BertLMHeadModel(config=config)
        tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
        model = BERTModel(module=hf_model, config=config, tokenizer=tokenizer)
    """

    def __init__(self,
                 module: transformers.BertModel,
                 config: transformers.BertConfig,
                 tokenizer: Optional[transformers.BertTokenizer] = None) -> None:

        if tokenizer is None:
            model_inputs = {'input_ids', 'attention_mask', 'token_type_ids'}
        else:
            model_inputs = set(tokenizer.model_input_names)

        super().__init__(
            module=module,  # type: ignore (thirdparty)
            config=config,
            model_inputs=model_inputs)

        # Remove 'labels' from the expected inputs since metric computation is
        # handled with TorchMetrics rather than by HuggingFace.
        self.model_inputs.remove('labels')

        # When using Evaluators, the validation metrics represent all possible
        # validation metrics that can be used with the BERT model. The
        # Evaluator class checks whether its metrics are among the model's
        # validation metrics.
        ignore_index = -100
        self.val_metrics = [
            Accuracy(),
            MeanSquaredError(),
            SpearmanCorrCoef(),
            BinaryF1Score(),
            MatthewsCorrCoef(num_classes=config.num_labels),
            LanguageCrossEntropy(ignore_index=ignore_index, vocab_size=config.num_labels),
            MaskedAccuracy(ignore_index=ignore_index),
        ]
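        # These roughly map to the tasks BERT is commonly evaluated on (an
        # illustrative note, assuming GLUE-style fine-tuning): Accuracy,
        # BinaryF1Score, and MatthewsCorrCoef for classification tasks;
        # MeanSquaredError and SpearmanCorrCoef for STS-B-style regression;
        # LanguageCrossEntropy and MaskedAccuracy for masked language modeling.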
        self.train_metrics = []

    def loss(self, outputs: Mapping, batch: Batch) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
        # Rely on the loss that the wrapped HuggingFace model computes during
        # the forward pass; computing the loss here is not yet supported.
        if outputs.get('loss', None) is not None:
            return outputs['loss']
        else:
            raise NotImplementedError('Calculating loss directly not supported yet.')

    def validate(self, batch: Any) -> Any:
        """Runs the validation step.

        Args:
            batch (Dict[str, Tensor]): a dictionary of inputs that the model
                expects, as found in :meth:`.ComposerTransformer.get_model_inputs`.

        Returns:
            tuple (Tensor, Tensor): the output from the forward pass and the correct labels.
            This is fed directly into :meth:`.ComposerModel.metrics`.
        """
        assert self.training is False, 'For validation, model must be in eval mode'

        # Temporary hack until eval on multiple datasets is finished.
        labels = batch.pop('labels')
        output = self.forward(batch)
        output = output['logits']

        # In the single-class (e.g. regression) case, remove the classes dimension.
        if output.shape[1] == 1:
            output = output.squeeze(dim=1)

        return output, labels

    def metrics(self, train: bool = False) -> Union[Metric, MetricCollection]:
        return MetricCollection(self.train_metrics) if train else MetricCollection(self.val_metrics)
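

if __name__ == '__main__':
    # Minimal usage sketch: wrap a HuggingFace sequence-classification head for
    # fine-tuning. The checkpoint name and num_labels below are illustrative
    # assumptions, not requirements of BERTModel; any compatible BERT module,
    # config, and tokenizer will do.
    import transformers

    config = transformers.BertConfig.from_pretrained('bert-base-uncased', num_labels=2)
    hf_model = transformers.BertForSequenceClassification.from_pretrained('bert-base-uncased', config=config)
    tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased')
    model = BERTModel(module=hf_model, config=config, tokenizer=tokenizer)

    # With num_labels=1, validation hits the squeeze path in ``validate``
    # (the STS-B-style regression case); num_labels >= 2 is classification.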