-
Notifications
You must be signed in to change notification settings - Fork 462
/
Copy pathmixin.py
357 lines (325 loc) · 17.4 KB
/
mixin.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from huggingface/transformers.
import inspect
import os
import shutil
import time
from copy import copy
from types import MethodType
from typing import Callable, Dict, List, Optional, Tuple, Union
import safetensors
import torch
import torch.nn as nn
import transformers
from datasets import Dataset as HfDataset
from modelscope import check_local_model_is_latest
from packaging import version
from peft import PeftModel
from torch.nn import Module
from transformers import PreTrainedModel
from transformers.data.data_collator import DataCollator
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import unwrap_model
from transformers.trainer import TrainerCallback
from transformers.trainer_utils import EvalPrediction
from transformers.utils import is_torch_npu_available
from swift.hub import get_hub
from swift.llm import Template
from swift.plugin import MeanMetric, compute_acc, extra_tuners
from swift.tuners import SwiftModel
from swift.utils import get_logger, is_mp_ddp, use_torchacc
from swift.utils.torchacc_utils import ta_trim_graph
from .arguments import TrainingArguments
from .utils import can_return_loss, find_labels, get_function, is_instance_of_ms_model
# `trl` is optional. When it is missing (or fails while importing), the sentinel below is None and the
# value-head (reward model) save path in `SwiftMixin._save_model` is simply skipped.
try:
    from trl import AutoModelForCausalLMWithValueHead
except (ImportError, RuntimeError):
    AutoModelForCausalLMWithValueHead = None

logger = get_logger()
class SwiftMixin:
    """Trainer mixin adding ms-swift behaviour on top of a HuggingFace ``Trainer`` class.

    Every override delegates to ``super()`` using ``transformers.Trainer`` argument and method names, so this
    class is meant to precede a ``Trainer`` (sub)class in the MRO.

    It adds:
    - hub patching (``self.hub.patch_hub()``) around init / train / push;
    - PiSSA / OLoRA / LoRA-GA initial and converted adapter saving;
    - reward-model (value head) saving;
    - a DeepSpeed ZeRO-3 state-dict consolidation tweak;
    - extra training logs (custom metrics, memory, speed);
    - plugin optimizers;
    - xtuner sequence parallelism.
    """

    def __init__(self,
                 model: Union[PreTrainedModel, Module] = None,
                 args: TrainingArguments = None,
                 data_collator: Optional[DataCollator] = None,
                 train_dataset: Optional[HfDataset] = None,
                 eval_dataset: Optional[Union[HfDataset, Dict[str, HfDataset]]] = None,
                 template: Optional[Template] = None,
                 model_init: Optional[Callable[[], PreTrainedModel]] = None,
                 compute_loss_func: Optional[Callable] = None,
                 compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
                 callbacks: Optional[List[TrainerCallback]] = None,
                 optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
                 preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
                 **kwargs) -> None:
        """Set up swift-specific state, then forward the HF ``Trainer`` arguments to ``super().__init__``.

        Args:
            model: the model to train. It must expose ``model_meta``. If it has ``model_dir`` and
                ``args.check_model`` is set, the local copy is checked with ``check_local_model_is_latest``.
            args: swift ``TrainingArguments``. Also read here: ``check_model`` and ``sequence_parallel_size``
                (values > 1 initialise xtuner sequence parallelism).
            template: swift ``Template``. Its ``tokenizer`` is passed to the base trainer; the template itself
                is kept on ``self.template`` for post-encode hooks and checkpoint saving.
            compute_loss_func: stored on ``self.compute_loss_func`` after ``super().__init__`` instead of being
                forwarded to it (presumably for compatibility with transformers versions lacking it).
            Remaining arguments (and ``**kwargs``) are forwarded unchanged to the base trainer.
        """
        if args.check_model and hasattr(model, 'model_dir'):
            check_local_model_is_latest(
                model.model_dir, user_agent={
                    'invoked_by': 'local_trainer',
                    'third_party': 'swift',
                })
        # Extra metrics (filled by `_compute_acc`), averaged and reset each time a log is emitted.
        self._custom_metrics = {}
        self.template = template
        # Largest value returned by `get_max_cuda_memory` so far (GiB).
        self.max_memory = 0
        self.hub = get_hub()
        if args.sequence_parallel_size > 1:
            from swift.trainers.xtuner import init_sequence_parallel_xtuner
            init_sequence_parallel_xtuner(args.sequence_parallel_size)

        self.model_meta = model.model_meta
        with self.hub.patch_hub():
            super().__init__(
                model=model,
                args=args,
                data_collator=data_collator,
                train_dataset=train_dataset,
                eval_dataset=eval_dataset,
                tokenizer=template.tokenizer,
                model_init=model_init,
                compute_metrics=compute_metrics,
                callbacks=callbacks,
                optimizers=optimizers,
                preprocess_logits_for_metrics=preprocess_logits_for_metrics,
                **kwargs)

        self.compute_loss_func = compute_loss_func
        # If the instance's `forward` differs from the class-level one (i.e. it was patched), re-derive
        # `label_names` / `can_return_loss` from the patched model. Presumably the base trainer's values
        # reflect the original signature.
        if get_function(model.__class__.forward) is not get_function(model.forward):
            self.label_names = find_labels(model) or ['labels']
            self.can_return_loss = can_return_loss(model)
        self.start_time = time.time()

    def _save_initial_model(self, output_dir):
        """Save a freshly initialised PiSSA / OLoRA / LoRA-GA adapter to ``<output_dir>/initial_model``.

        ``_save_converted_model`` later passes this directory (as ``<dirname(checkpoint_dir)>/initial_model``) to
        ``save_pretrained(path_initial_model_for_weight_conversion=...)``.

        The adapter is saved with ``init_lora_weights=True``. The original value is restored afterwards, even if
        saving raises. Models that are not ``PeftModel``, or use other ``init_lora_weights`` values, are ignored.
        """
        # pissa/olora/lora-ga
        model = unwrap_model(self.model)
        if not isinstance(model, PeftModel):
            return
        config = model.peft_config.get('default')
        init_lora_weights = getattr(config, 'init_lora_weights', None)
        if not (isinstance(init_lora_weights, str)
                and any(s in init_lora_weights for s in ('pissa', 'olora', 'lora-ga'))):
            return
        config.init_lora_weights = True
        try:
            model.save_pretrained(os.path.join(output_dir, 'initial_model'))
        finally:
            # Never leave the live config mutated, even on a failed save.
            config.init_lora_weights = init_lora_weights

    def _save_converted_model(self, output_dir):
        """For PiSSA / OLoRA / LoRA-GA adapters, also save a converted copy to ``<output_dir>/converted/default``.

        The conversion is driven by ``<dirname(output_dir)>/initial_model`` (written by ``_save_initial_model``).
        The ``default`` adapter config is snapshotted with a shallow ``copy`` before saving, and the snapshot is
        re-registered afterwards.

        Other string values of ``init_lora_weights`` (e.g. 'gaussian') are ignored, and no ``converted``
        directory is created for them.

        Raises:
            RuntimeError: LoRA-GA is requested but the ``lora_ga`` package cannot be imported.
        """
        # pissa/olora/lora-ga
        model = unwrap_model(self.model)
        if not isinstance(model, PeftModel):
            return
        config = model.peft_config.get('default')
        init_lora_weights = getattr(config, 'init_lora_weights', None)
        if not isinstance(init_lora_weights, str):
            return
        is_lora_ga = 'lora-ga' in init_lora_weights
        if not is_lora_ga and 'pissa' not in init_lora_weights and 'olora' not in init_lora_weights:
            return

        # Snapshot the config so it can be put back after saving. This presumably guards against
        # `save_pretrained` altering the registered config during weight conversion.
        config = copy(config)
        converted_default_dir = os.path.join(output_dir, 'converted', 'default')
        initial_model_dir = os.path.join(os.path.dirname(output_dir), 'initial_model')
        if is_lora_ga:
            # Keep the `try` to the import only, so unrelated ImportErrors raised while saving are not
            # misreported as a missing LoRA-GA installation.
            try:
                from lora_ga.entrypoint import LoraGAContext
            except ImportError as e:
                error_message = ('\n'
                                 "Since 'LoRA-GA' is not implemented by PEFT, you will need to install it directly "
                                 'from GitHub.\n'
                                 "Command: 'pip install git+https://github.com/lxline/LoRA-GA.git'.\n")
                logger.info(error_message)
                raise RuntimeError(error_message) from e
            os.makedirs(os.path.dirname(converted_default_dir), exist_ok=True)
            with LoraGAContext(model):
                model.save_pretrained(
                    converted_default_dir,
                    path_initial_model_for_weight_conversion=initial_model_dir,
                )
        else:
            os.makedirs(os.path.dirname(converted_default_dir), exist_ok=True)
            model.save_pretrained(
                converted_default_dir,
                path_initial_model_for_weight_conversion=initial_model_dir,
            )
        model.peft_config['default'] = config

    def _load_optimizer_and_scheduler(self, *args, **kwargs):
        """Load optimizer/scheduler state via the base trainer, then apply a model-parallel + DDP fix-up.

        When ``is_mp_ddp()`` is true, an optimizer state's ``step`` tensor is moved to CPU if any other tensor in
        that state lives on a device that is neither ``step``'s device nor CPU.

        Non-tensor state entries (and states whose ``step`` is not a tensor) are skipped rather than crashing.
        """
        super()._load_optimizer_and_scheduler(*args, **kwargs)
        if not is_mp_ddp():
            return
        # fix mp+ddp adamw
        cpu = torch.device('cpu')
        for state in self.optimizer.state.values():
            step = state.get('step')
            if not isinstance(step, torch.Tensor):
                continue
            # not on the same device
            other_devices = {t.device for t in state.values() if isinstance(t, torch.Tensor)} - {step.device, cpu}
            if other_devices:
                state['step'] = step.to('cpu')

    def _save_model(self, output_dir: Optional[str] = None, state_dict=None):
        """Write model weights to ``output_dir``, dispatching on the type of ``self.model``.

        - Wrapped / unsupported models: saved through the unwrapped module when it is a supported class.
          Otherwise only the raw state dict is written.
        - ``AutoModelForCausalLMWithValueHead`` (reward model): the decoder is saved via ``save_pretrained`` and
          the ``v_head.*`` weights go to ``value_head.safetensors`` / ``value_head.bin``.
        - ModelScope models, ``extra_tuners`` train types and everything else use the matching
          ``save_pretrained``.

        ``state_dict``, when given (e.g. the ZeRO-3 consolidated dict passed in by the base trainer), is used
        instead of reading ``self.model.state_dict()``.
        """
        # model
        supported_classes = (SwiftModel, PreTrainedModel, PeftModel)
        if AutoModelForCausalLMWithValueHead is not None:
            supported_classes = supported_classes + (AutoModelForCausalLMWithValueHead, )
        save_safetensors = self.args.save_safetensors

        if not isinstance(self.model, supported_classes):
            if state_dict is None:
                state_dict = self.model.state_dict()
            _unwrap_model = unwrap_model(self.model)
            if isinstance(_unwrap_model, supported_classes):
                _unwrap_model.save_pretrained(output_dir, state_dict=state_dict, safe_serialization=save_safetensors)
            else:
                logger.info('Trainer.model is not a `PreTrainedModel`, only saving its state dict.')
                if save_safetensors:
                    from safetensors.torch import save_file
                    save_file(state_dict, os.path.join(output_dir, 'model.safetensors'))
                else:
                    torch.save(state_dict, os.path.join(output_dir, 'pytorch_model.bin'))
        elif AutoModelForCausalLMWithValueHead and isinstance(self.model, AutoModelForCausalLMWithValueHead):
            # save reward model
            # Respect a caller-supplied state dict: under ZeRO-3 the live module only holds partitioned
            # parameters, so re-reading `self.model.state_dict()` would discard the gathered weights.
            if state_dict is None:
                state_dict = self.model.state_dict()
            decoder_state_dict, v_head_state_dict = {}, {}
            for name, param in state_dict.items():
                if name.startswith('v_head.'):
                    v_head_state_dict[name] = param
                else:
                    decoder_state_dict[name.replace('pretrained_model.', '', 1)] = param
            self.model.pretrained_model.save_pretrained(
                output_dir, state_dict=decoder_state_dict or None, safe_serialization=save_safetensors)
            if save_safetensors:
                from safetensors.torch import save_file
                save_file(
                    v_head_state_dict, os.path.join(output_dir, 'value_head.safetensors'), metadata={'format': 'pt'})
            else:
                torch.save(v_head_state_dict, os.path.join(output_dir, 'value_head.bin'))
        elif is_instance_of_ms_model(self.model):
            PreTrainedModel.save_pretrained(
                self.model, output_dir, state_dict=state_dict, safe_serialization=save_safetensors)
        elif self.args.train_type in extra_tuners:
            extra_tuners[self.args.train_type].save_pretrained(
                self.model, output_dir, state_dict=state_dict, safe_serialization=save_safetensors)
        else:
            self.model.save_pretrained(output_dir, state_dict=state_dict, safe_serialization=save_safetensors)

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Compatible with swift and peft: write a complete checkpoint directory.

        Steps:
        1. Save the weights (``_save_model``).
        2. Write ``training_args.bin``.
        3. Save the converted adapter (``_save_converted_model``).
        4. Copy ``args.json`` and move ``predict.jsonl`` from ``output_dir``'s parent directory into it.
        5. For non-adapter models only, save the template processor and
           ``model_meta.additional_saved_files`` via ``save_checkpoint``.
        """
        # If we are executing this function, we are the process zero, so we don't check for that.
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        self._save_model(output_dir, state_dict)
        # training_args.bin
        torch.save(self.args, os.path.join(output_dir, 'training_args.bin'))
        self._save_converted_model(output_dir)

        parent_dir = os.path.dirname(output_dir)
        # args.json (copied, so it stays available for later checkpoints)
        args_path = os.path.join(parent_dir, 'args.json')
        if os.path.exists(args_path):
            shutil.copy(args_path, os.path.join(output_dir, 'args.json'))
        # predict.jsonl (moved into this checkpoint)
        predict_jsonl = os.path.join(parent_dir, 'predict.jsonl')
        if os.path.exists(predict_jsonl):
            shutil.move(predict_jsonl, os.path.join(output_dir, 'predict.jsonl'))

        is_adapter = isinstance(self.model, (SwiftModel, PeftModel))
        # tokenizer
        if not is_adapter:
            from swift.llm import save_checkpoint
            additional_saved_files = self.model_meta.additional_saved_files
            save_checkpoint(None, self.template.processor, output_dir, additional_saved_files=additional_saved_files)

    def _fix_zero3_gather_all_parameters(self) -> None:
        """Patch ``self.deepspeed._zero3_consolidated_16bit_state_dict`` to choose ``exclude_frozen_parameters``.

        The patch is applied only when ZeRO-3 is enabled and the installed method accepts
        ``exclude_frozen_parameters``. It is applied at most once per engine; the original method is kept as
        ``_zero3_consolidated_16bit_state_dict_origin``.

        After patching, frozen parameters are excluded for ``PeftModel`` and for ``SwiftModel`` with additional
        modules, and included for every other model.
        """
        if not is_deepspeed_zero3_enabled() or hasattr(self.deepspeed, '_zero3_consolidated_16bit_state_dict_origin'):
            return
        parameters = inspect.signature(self.deepspeed._zero3_consolidated_16bit_state_dict).parameters
        if 'exclude_frozen_parameters' not in parameters:
            return

        def _zero3_consolidated_16bit_state_dict(model, exclude_frozen_parameters=False):
            # `model` is `self.deepspeed` (bound with MethodType below). The caller's flag is intentionally
            # overridden by a decision based on the type of the wrapped model.
            unwrapped = unwrap_model(model)
            exclude_frozen_parameters = bool((isinstance(unwrapped, SwiftModel) and unwrapped.has_additional_modules)
                                             or isinstance(unwrapped, PeftModel))
            return model._zero3_consolidated_16bit_state_dict_origin(exclude_frozen_parameters)

        self.deepspeed._zero3_consolidated_16bit_state_dict_origin = (
            self.deepspeed._zero3_consolidated_16bit_state_dict)
        self.deepspeed._zero3_consolidated_16bit_state_dict = MethodType(_zero3_consolidated_16bit_state_dict,
                                                                         self.deepspeed)

    def _save_checkpoint(self, *args, **kwargs):
        """Save a checkpoint via the base trainer and log its location.

        Before saving, record ``<output_dir>/checkpoint-<global_step>`` in ``self.state.last_model_checkpoint``
        and apply the ZeRO-3 consolidation patch.
        """
        self.state.last_model_checkpoint = os.path.join(self.args.output_dir, f'checkpoint-{self.state.global_step}')
        self._fix_zero3_gather_all_parameters()
        result = super()._save_checkpoint(*args, **kwargs)
        logger.info(f'Saving model checkpoint to {self.state.last_model_checkpoint}')
        return result

    def train(self, *args, **kwargs):
        """Run the base training loop with the hub patched.

        For multimodal models, the template's post-encode hook is first registered on the trainer's
        ``model`` / ``ref_model`` / ``reward_model`` / ``value_model`` modules (de-duplicated). The initial adapter
        is then saved (``_save_initial_model``).

        ``template.remove_post_encode_hook()`` is called afterwards even if training raises.
        """
        if self.model_meta.is_multimodal:
            model_attrs = {'model', 'ref_model', 'reward_model', 'value_model'}
            # dict.fromkeys de-duplicates (e.g. ref_model may be the same module as model) in a stable order.
            models = list(
                dict.fromkeys(v for k, v in self.__dict__.items() if isinstance(v, nn.Module) and k in model_attrs))
            self.template.register_post_encode_hook(models)
            logger.info(f'Successfully registered post_encode hook: {[model.__class__.__name__ for model in models]}')
        self._save_initial_model(self.args.output_dir)
        try:
            with self.hub.patch_hub():
                return super().train(*args, **kwargs)
        finally:
            self.template.remove_post_encode_hook()

    def push_to_hub(self, *args, **kwargs):
        """Delegate to the base ``push_to_hub`` with the hub patched."""
        with self.hub.patch_hub():
            return super().push_to_hub(*args, **kwargs)

    def get_max_cuda_memory(self, device: Optional[Union[torch.device, int]] = None) -> float:
        """Return ``torch.cuda.max_memory_reserved`` in GiB.

        The value is for ``device``, or summed over all visible CUDA devices when ``device`` is None. It also
        updates ``self.max_memory`` with the running maximum of the returned values.
        """
        if device is None:
            mems = [torch.cuda.max_memory_reserved(device=idx) for idx in range(torch.cuda.device_count())]
        else:
            mems = [torch.cuda.max_memory_reserved(device=device)]
        mem = sum(mems) / 1024**3
        self.max_memory = max(self.max_memory, mem)
        return mem

    def _maybe_log_save_evaluate(self, tr_loss, *args, **kwargs):
        """Emit swift's training log, then delegate to the base method.

        The log contains: loss, custom metrics, grad norm, learning rate, memory and speed. When a log is due,
        it is emitted here and ``self.control.should_log`` is cleared, so the base method only handles
        saving / evaluation.

        ``args`` are the base method's remaining positional arguments. For transformers >= 4.38 the first one
        (or the ``grad_norm`` keyword) is the gradient norm.
        """
        if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
            self.control.should_log = False

            # all_gather + mean() to get average loss over all processes
            tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
            loss = tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged)
            logs: Dict[str, float] = {'loss': loss}  # loss first

            for k, metric in self._custom_metrics.items():
                value = metric.compute()
                if len(value) == 1:
                    logs[k] = next(iter(value.values()))
                else:
                    for k_suffix, val in value.items():
                        logs[f'{k}_{k_suffix}'] = val
                metric.reset()

            if version.parse(transformers.__version__) >= version.parse('4.38'):
                grad_norm = args[0] if args else kwargs.get('grad_norm')
                if grad_norm is not None:
                    logs['grad_norm'] = grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm
            logs['learning_rate'] = self._get_learning_rate()
            if not is_torch_npu_available():
                logs['memory(GiB)'] = round(self.get_max_cuda_memory(), 2)

            elapse_time = time.time() - self.start_time
            logs['train_speed(iter/s)'] = round(self.state.global_step / elapse_time, 6)
            logs = {k: v for k, v in logs.items() if v is not None}

            # Reset the accumulated loss in place (the caller keeps a reference to this tensor).
            tr_loss -= tr_loss
            self._total_loss_scalar += tr_loss_scalar
            self._globalstep_last_logged = self.state.global_step
            self.store_flos()
            self.log(logs)
        super()._maybe_log_save_evaluate(tr_loss, *args, **kwargs)

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """Create the optimizer and LR scheduler.

        When ``args.optimizer`` names a plugin in ``swift.plugin.optimizers_map``, that callback is used.
        The base trainer's ``create_optimizer`` / ``create_scheduler`` fill in whichever of the two the
        callback returned as None. Without a plugin, the base implementation is used.
        """
        if self.args.optimizer is None:
            super().create_optimizer_and_scheduler(num_training_steps=num_training_steps)
            return
        from swift.plugin import optimizers_map
        optimizer_callback = optimizers_map[self.args.optimizer]
        self.optimizer, self.lr_scheduler = optimizer_callback(self.args, self.model, self.train_dataset)
        if self.optimizer is None:
            self.create_optimizer()
        if self.lr_scheduler is None:
            self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)

    def _get_train_sampler(self, train_dataset=None) -> Optional[torch.utils.data.Sampler]:
        """Return the training sampler.

        With ``args.train_sampler_random`` the base trainer's train sampler is used; otherwise the base
        trainer's evaluation sampler is reused for the training dataset.

        ``train_dataset`` is optional, so the override works with transformers versions that call
        ``_get_train_sampler()`` and with newer ones that pass the dataset. It defaults to ``self.train_dataset``.
        """
        if self.args.train_sampler_random:
            if train_dataset is None:
                return super()._get_train_sampler()
            return super()._get_train_sampler(train_dataset)
        if train_dataset is None:
            train_dataset = self.train_dataset
        return self._get_eval_sampler(train_dataset)

    def get_train_dataloader(self):
        """Return the base train dataloader, or the xtuner one when ``args.sequence_parallel_size`` > 1."""
        if self.args.sequence_parallel_size == 1:
            return super().get_train_dataloader()
        from swift.trainers.xtuner import get_xtuner_train_dataloader
        return get_xtuner_train_dataloader(self)

    def _compute_acc(self, outputs, labels) -> None:
        """Every ``args.acc_steps`` global steps, record ``compute_acc`` metrics into ``self._custom_metrics``.

        Predictions are ``outputs.logits.argmax(dim=-1)``. Each metric is accumulated in a ``MeanMetric``, which
        ``_maybe_log_save_evaluate`` averages and resets when logging.
        """
        args = self.args
        if self.state.global_step % args.acc_steps != 0:
            return
        # Only take the argmax on steps where accuracy is actually recorded.
        preds = outputs.logits.argmax(dim=-1)
        if use_torchacc():
            ta_trim_graph()
            preds = preds.to('cpu')
            labels = labels.to('cpu')
        metrics = compute_acc(preds, labels, acc_strategy=args.acc_strategy, is_encoder_decoder=args.is_encoder_decoder)
        for k, v in metrics.items():
            if k not in self._custom_metrics:
                self._custom_metrics[k] = MeanMetric(nan_value=None)
            self._custom_metrics[k].update(v)