"""
Models module
"""
import os
import torch
from transformers import (
AutoConfig,
AutoModel,
AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoTokenizer,
)
from .onnx import OnnxModel


class Models:
    """
    Utility methods for working with machine learning models
    """

    @staticmethod
    def checklength(config, tokenizer):
        """
        Checks the length for a Hugging Face Transformers tokenizer using a Hugging Face Transformers config. Copies the
        max_position_embeddings parameter if the tokenizer has no max_length set. This helps with backwards compatibility
        with older tokenizers.

        Args:
            config: transformers config
            tokenizer: transformers tokenizer
        """

        # Unpack nested config, handles passing model directly
        if hasattr(config, "config"):
            config = config.config

        # Copy max_position_embeddings when the tokenizer reports the "no limit" sentinel value
        if (
            hasattr(config, "max_position_embeddings")
            and tokenizer
            and hasattr(tokenizer, "model_max_length")
            and tokenizer.model_max_length == int(1e30)
        ):
            tokenizer.model_max_length = config.max_position_embeddings
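
    # Usage sketch (illustrative; "bert-base-uncased" is an assumed example model):
    #
    #   config = AutoConfig.from_pretrained("bert-base-uncased")
    #   tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    #   Models.checklength(config, tokenizer)
    #
    # If the tokenizer reported the transformers "no limit" sentinel (int(1e30)),
    # model_max_length now mirrors config.max_position_embeddings.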

    @staticmethod
    def maxlength(config, tokenizer):
        """
        Gets the best max length to use for generate calls. This method will return config.max_length if it's set. Otherwise, it will return
        tokenizer.model_max_length.

        Args:
            config: transformers config
            tokenizer: transformers tokenizer
        """

        # Unpack nested config, handles passing model directly
        if hasattr(config, "config"):
            config = config.config

        # Get non-defaulted fields
        keys = config.to_diff_dict()

        # Use config.max_length if not set to default value, else use tokenizer.model_max_length if available
        return config.max_length if "max_length" in keys or not hasattr(tokenizer, "model_max_length") else tokenizer.model_max_length
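
    # Usage sketch (illustrative; "gpt2" is an assumed example model):
    #
    #   config = AutoConfig.from_pretrained("gpt2")
    #   tokenizer = AutoTokenizer.from_pretrained("gpt2")
    #   Models.maxlength(config, tokenizer)
    #
    # Returns config.max_length only when it was explicitly set (i.e. it appears in
    # config.to_diff_dict()), otherwise falls back to tokenizer.model_max_length.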

    @staticmethod
    def deviceid(gpu):
        """
        Translates input gpu argument into a device id.

        Args:
            gpu: True/False if GPU should be enabled, also supports a device id/string/instance

        Returns:
            device id
        """

        # Return if this is already a torch device
        # pylint: disable=E1101
        if isinstance(gpu, torch.device):
            return gpu

        # Always return -1 if gpu is None or an accelerator device is unavailable
        if gpu is None or not Models.hasaccelerator():
            return -1

        # Default to device 0 if gpu is True and not otherwise specified
        if isinstance(gpu, bool):
            return 0 if gpu else -1

        # Return gpu as device id if gpu flag is an int
        return int(gpu)
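
    # Behavior sketch, assuming a machine with an available accelerator:
    #
    #   Models.deviceid(True)   # 0 (default to first device)
    #   Models.deviceid(False)  # -1 (CPU)
    #   Models.deviceid(1)      # 1
    #   Models.deviceid(None)   # -1
    #
    # With no accelerator available, every call above returns -1.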

    @staticmethod
    def device(deviceid):
        """
        Gets a tensor device.

        Args:
            deviceid: device id

        Returns:
            tensor device
        """

        # Torch device
        # pylint: disable=E1101
        return deviceid if isinstance(deviceid, torch.device) else torch.device(Models.reference(deviceid))

    @staticmethod
    def reference(deviceid):
        """
        Gets a tensor device reference.

        Args:
            deviceid: device id

        Returns:
            device reference
        """

        return (
            deviceid
            if isinstance(deviceid, str)
            else "cpu"
            if deviceid < 0
            else f"cuda:{deviceid}"
            if torch.cuda.is_available()
            else "mps"
            if Models.hasmpsdevice()
            else Models.finddevice()
        )
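
    # Resolution sketch, following the conditional chain above:
    #
    #   Models.reference("cuda:1")  # strings pass through unchanged
    #   Models.reference(-1)        # "cpu"
    #   Models.reference(0)         # "cuda:0" with CUDA, "mps" on Apple Silicon,
    #                               # otherwise the first alternative device found (i.e. "xpu")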

    @staticmethod
    def hasaccelerator():
        """
        Checks if there is an accelerator device available.

        Returns:
            True if an accelerator device is available, False otherwise
        """

        return torch.cuda.is_available() or Models.hasmpsdevice() or bool(Models.finddevice())

    @staticmethod
    def hasmpsdevice():
        """
        Checks if there is an MPS device available.

        Returns:
            True if an MPS device is available, False otherwise
        """

        return os.environ.get("PYTORCH_MPS_DISABLE") != "1" and torch.backends.mps.is_available()
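
    # Opt-out sketch: setting the PYTORCH_MPS_DISABLE environment variable
    # forces this check to return False, even on Apple Silicon:
    #
    #   os.environ["PYTORCH_MPS_DISABLE"] = "1"
    #   Models.hasmpsdevice()  # False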

    @staticmethod
    def finddevice():
        """
        Attempts to find an alternative accelerator device.

        Returns:
            name of first alternative accelerator available or None if not found
        """

        return next((device for device in ["xpu"] if hasattr(torch, device) and getattr(torch, device).is_available()), None)

    @staticmethod
    def load(path, config=None, task="default", modelargs=None):
        """
        Loads a machine learning model. Handles multiple model frameworks (ONNX, Transformers).

        Args:
            path: path to model
            config: path to model configuration
            task: task name used to lookup model type
            modelargs: additional model keyword arguments

        Returns:
            machine learning model
        """

        # Detect ONNX models
        if isinstance(path, bytes) or (isinstance(path, str) and os.path.isfile(path)):
            return OnnxModel(path, config)

        # Return path, if path isn't a string
        if not isinstance(path, str):
            return path

        # Transformer models
        models = {
            "default": AutoModel.from_pretrained,
            "question-answering": AutoModelForQuestionAnswering.from_pretrained,
            "summarization": AutoModelForSeq2SeqLM.from_pretrained,
            "text-classification": AutoModelForSequenceClassification.from_pretrained,
            "zero-shot-classification": AutoModelForSequenceClassification.from_pretrained,
        }

        # Pass modelargs as keyword arguments
        modelargs = modelargs if modelargs else {}

        # Load model for supported tasks. Return path for unsupported tasks.
        return models[task](path, **modelargs) if task in models else path
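
    # Usage sketch (illustrative; paths and model names are assumed examples):
    #
    #   model = Models.load("distilbert-base-cased-distilled-squad", task="question-answering")
    #   model = Models.load("local/model.onnx", config="local/config.json")  # local file loads as OnnxModel
    #
    # Unsupported task names fall through and return path unchanged.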

    @staticmethod
    def tokenizer(path, **kwargs):
        """
        Loads a tokenizer from path.

        Args:
            path: path to tokenizer
            kwargs: optional additional keyword arguments

        Returns:
            tokenizer
        """

        return AutoTokenizer.from_pretrained(path, **kwargs) if isinstance(path, str) else path

    @staticmethod
    def task(path, **kwargs):
        """
        Attempts to detect the model task from path.

        Args:
            path: path to model
            kwargs: optional additional keyword arguments

        Returns:
            inferred model task
        """

        # Get model configuration
        config = None
        if isinstance(path, (list, tuple)) and hasattr(path[0], "config"):
            config = path[0].config
        elif isinstance(path, str):
            config = AutoConfig.from_pretrained(path, **kwargs)

        # Attempt to resolve task using configuration
        task = None
        if config:
            architecture = config.architectures[0] if config.architectures else None
            if architecture:
                if any(x for x in ["LMHead", "CausalLM"] if x in architecture):
                    task = "language-generation"
                elif "QuestionAnswering" in architecture:
                    task = "question-answering"
                elif "ConditionalGeneration" in architecture:
                    task = "sequence-sequence"

        return task
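
    # Detection sketch (illustrative; model names are assumed examples):
    #
    #   Models.task("gpt2")                     # "language-generation" (GPT2LMHeadModel)
    #   Models.task("facebook/bart-large-cnn")  # "sequence-sequence" (BartForConditionalGeneration)
    #   Models.task("distilbert-base-cased-distilled-squad")  # "question-answering"
    #
    # Returns None when no architecture matches.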