32 changes: 0 additions & 32 deletions swift/llm/sft.py
@@ -6,7 +6,6 @@
import json
import numpy as np
import torch
import torch.distributed as dist
from modelscope import BitsAndBytesConfig, GenerationConfig
from transformers import IntervalStrategy
from transformers.integrations import is_deepspeed_zero3_enabled
@@ -27,17 +26,6 @@
print_example, set_generation_config, sort_by_max_length,
stat_dataset)

SUPPORT_XTUNER = False

try:
from xtuner.parallel.sequence import *
# datasets is required in Xtuner
from datasets import Dataset
from xtuner.dataset.huggingface import pack_dataset
SUPPORT_XTUNER = True
except ImportError:
pass

logger = get_logger()


@@ -208,25 +196,6 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
dataset_info['train_dataset'] = stat_dataset(train_dataset)
if val_dataset is not None:
dataset_info['val_dataset'] = stat_dataset(val_dataset)
if args.pack_to_max_length:
assert SUPPORT_XTUNER, \
('Please install XTuner first to pack dataset to `max_length`.'
'`pip install -U \'xtuner[deepspeed]\'`')
if dist.get_rank() == 0:
ds = [i[0] for i in train_dataset.data]
train_dataset = Dataset.from_list(ds)
train_dataset = pack_dataset(
train_dataset,
max_length=args.max_length,
use_varlen_attn=False,
shuffle_before_pack=True,
map_num_proc=16)
objects = [train_dataset]
train_dataset.save_to_disk('alpaca_pack')
else:
objects = [None]
dist.broadcast_object_list(objects, src=0)
train_dataset = objects[0]
else:
dataset_info = None
td0, tkwargs0 = template.encode(train_dataset[0])
@@ -267,7 +236,6 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
trainer_kwargs['check_model'] = False

trainer = Seq2SeqTrainer(
sequence_parallel_size=args.sequence_parallel_size,
model=model,
args=training_args,
data_collator=data_collator,
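For reference, the packing path removed from `swift/llm/sft.py` built the packed dataset on rank 0 with XTuner's `pack_dataset` and then broadcast it to the remaining ranks. Below is a minimal sketch of that rank-0-then-broadcast pattern using only `torch.distributed` and `datasets`; the greedy `pack_rows` helper is illustrative and is not XTuner's implementation.

```python
# Sketch only: greedy packing on rank 0, then broadcast to all other ranks.
# Assumes torch.distributed is already initialized and rows are tokenized dicts.
from typing import Any, Dict, List

import torch.distributed as dist
from datasets import Dataset


def pack_rows(rows: List[Dict[str, Any]], max_length: int) -> Dataset:
    """Concatenate consecutive samples until adding one would exceed max_length."""
    packed, buf = [], {'input_ids': [], 'labels': []}
    for row in rows:
        if buf['input_ids'] and len(buf['input_ids']) + len(row['input_ids']) > max_length:
            packed.append(buf)
            buf = {'input_ids': [], 'labels': []}
        buf['input_ids'] += list(row['input_ids'])
        buf['labels'] += list(row['labels'])
    if buf['input_ids']:
        packed.append(buf)
    return Dataset.from_list(packed)


def pack_on_rank0(rows: List[Dict[str, Any]], max_length: int) -> Dataset:
    # Only rank 0 performs the packing; the result is broadcast as a Python object.
    objects = [pack_rows(rows, max_length)] if dist.get_rank() == 0 else [None]
    dist.broadcast_object_list(objects, src=0)
    return objects[0]
```

Broadcasting the packed dataset as a single Python object keeps every rank consistent at the cost of serializing it once per run; the removed call to XTuner's `pack_dataset` additionally exposed `use_varlen_attn` and `map_num_proc` options.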
12 changes: 0 additions & 12 deletions swift/llm/tuner.py
@@ -19,15 +19,6 @@
from .utils import (SftArguments, find_all_linears, find_embedding, find_ln,
is_adapter)

SUPPORT_XTUNER = False

try:
from xtuner.model.modules.dispatch import dispatch_modules
from xtuner.parallel.sequence import *
SUPPORT_XTUNER = True
except ImportError:
pass

logger = get_logger()


@@ -208,9 +199,6 @@ def prepare_model(model, args: SftArguments):
model.load_state_dict(state_dict, False)
# release memory
del state_dict
if SUPPORT_XTUNER:
dispatch_modules(model)
logger.info('Dispatch modules for sequence parallel.')
else:
raise ValueError(f'args.sft_type: {args.sft_type}')

4 changes: 0 additions & 4 deletions swift/llm/utils/argument.py
@@ -491,10 +491,6 @@ class SftArguments(ArgumentsBase):
# fsdp config file
fsdp_config: Optional[str] = None

# xtuner config
sequence_parallel_size: int = 1
pack_to_max_length: bool = False

def handle_dataset_mixture(self, train_dataset: HfDataset) -> None:
if train_dataset is None:
return train_dataset
45 changes: 0 additions & 45 deletions swift/llm/utils/template.py
@@ -16,17 +16,6 @@
from swift.torchacc_utils import pad_and_split_batch
from swift.utils import get_dist_setting, use_torchacc

SUPPORT_XTUNER = False

try:
from xtuner.parallel.sequence import (pad_for_sequence_parallel,
split_for_sequence_parallel,
get_sequence_parallel_group,
get_sequence_parallel_world_size)
SUPPORT_XTUNER = True
except ImportError:
pass

DEFAULT_SYSTEM = 'You are a helpful assistant.'
History = List[Union[Tuple[str, str], List[str]]]

@@ -432,31 +421,6 @@ def _concat_tokenizer_kwargs(
assert len(old_tokenizer_kwargs) == 0
return curr_tokenizer_kwargs

def _pad_and_split_for_sequence_parallel(self, tokenizer, input_ids,
labels, position_ids,
attention_mask, loss_scale):
input_ids = pad_for_sequence_parallel(
input_ids, padding_value=tokenizer.pad_token_id, dim=-1)
labels = pad_for_sequence_parallel(labels, padding_value=-100, dim=-1)
position_ids = pad_for_sequence_parallel(
position_ids, padding_value=0, dim=-1)
attention_mask = pad_for_sequence_parallel(
attention_mask, padding_value=0, dim=-1)

sp_group = get_sequence_parallel_group()
input_ids = split_for_sequence_parallel(
input_ids, dim=1, sp_group=sp_group)
labels = split_for_sequence_parallel(labels, dim=1, sp_group=sp_group)
position_ids = split_for_sequence_parallel(
position_ids, dim=1, sp_group=sp_group)
if loss_scale is not None:
loss_scale = pad_for_sequence_parallel(
loss_scale, padding_value=0., dim=-1)
loss_scale = split_for_sequence_parallel(
loss_scale, dim=1, sp_group=sp_group)

return input_ids, labels, position_ids, attention_mask, loss_scale

def data_collator(self,
batch: List[Dict[str, Any]],
padding_to: Optional[int] = None) -> Dict[str, Any]:
@@ -506,19 +470,10 @@ def data_collator(self,
padding_to, input_ids, attention_mask, labels, loss_scale,
self.max_length, self.tokenizer, rank, world_size)

bs, seq_len = input_ids.shape
position_ids = torch.arange(seq_len).unsqueeze(0).long().repeat(bs, 1)

if get_sequence_parallel_world_size() > 1:
input_ids, labels, position_ids, attention_mask, loss_scale = \
self._pad_and_split_for_sequence_parallel(
tokenizer, input_ids, labels, position_ids, attention_mask, loss_scale)

res = {
'input_ids': input_ids,
'attention_mask': attention_mask,
'labels': labels,
'position_ids': position_ids,
}
if loss_scale is not None:
res['loss_scale'] = loss_scale
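For reference, the `_pad_and_split_for_sequence_parallel` helper removed above pads each per-batch tensor so the sequence length is divisible by the sequence-parallel world size, then gives every rank its contiguous slice along the sequence dimension. A minimal sketch of that idea in plain PyTorch (the `sp_world_size` and `sp_rank` values would come from the sequence-parallel process group):

```python
# Sketch only: pad the last (sequence) dimension to a multiple of the
# sequence-parallel world size, then keep this rank's contiguous slice.
import torch
import torch.nn.functional as F


def pad_and_split(x: torch.Tensor, sp_world_size: int, sp_rank: int,
                  padding_value: float = 0) -> torch.Tensor:
    pad = (-x.shape[-1]) % sp_world_size  # extra positions needed to reach a multiple
    if pad:
        x = F.pad(x, (0, pad), value=padding_value)
    chunk = x.shape[-1] // sp_world_size
    return x[..., sp_rank * chunk:(sp_rank + 1) * chunk]
```

The removed helper applied this with `pad_token_id` for `input_ids`, `-100` for `labels`, and `0` for `position_ids`, `attention_mask`, and `loss_scale`, so padded positions do not contribute to the loss.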
118 changes: 4 additions & 114 deletions swift/trainers/trainers.py
@@ -4,7 +4,6 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
from peft import PeftModel
from torch import Tensor, nn
from torch.nn import CrossEntropyLoss
@@ -15,8 +14,7 @@
from transformers.modeling_utils import unwrap_model
from transformers.models.auto.modeling_auto import \
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.trainer_utils import seed_worker
from transformers.utils import is_peft_available, is_torch_xla_available
from transformers.utils import is_peft_available

from swift.torchacc_utils import (ta_eval_dataloader, ta_test_dataloader,
ta_train_dataloader)
@@ -30,30 +28,14 @@
except ImportError:
from transformers.deepspeed import is_deepspeed_zero3_enabled

if is_torch_xla_available():
import torch_xla.core.xla_model as xm

SUPPORT_XTUNER = False

try:
from xtuner.parallel.sequence import (init_sequence_parallel,
SequenceParallelSampler,
reduce_sequence_parallel_loss,
get_sequence_parallel_world_size,
get_sequence_parallel_group)
from mmengine.device.utils import get_max_cuda_memory
SUPPORT_XTUNER = True
except ImportError:
pass


class Trainer(PushToMsHubMixin, SwiftMixin, HfTrainer):
pass


class Seq2SeqTrainer(PushToMsHubMixin, SwiftMixin, HfSeq2SeqTrainer):

def __init__(self, sequence_parallel_size=1, *args, **kwargs):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# performance
self.perf: Dict[str, Any] = {
@@ -67,9 +49,6 @@ def __init__(self, sequence_parallel_size=1, *args, **kwargs):
self.model, 'get_trainable_parameters') else None,
}
self._acc = torch.tensor(0.).to(self.args.device)
if SUPPORT_XTUNER:
self.sequence_parallel_size = sequence_parallel_size
init_sequence_parallel(sequence_parallel_size)

def train(self, *args, **kwargs) -> torch.Tensor:
res = super().train(*args, **kwargs)
@@ -226,7 +205,6 @@ def compute_scaled_loss(self, labels: torch.Tensor,
return loss.mean()

def compute_loss(self, model, inputs, return_outputs=None):
assert 'labels' in inputs
if not hasattr(self, '_custom_metrics'):
self._custom_metrics = {}

@@ -262,17 +240,9 @@ def compute_loss(self, model, inputs, return_outputs=None):
else:
loss = outputs['loss'] if isinstance(outputs, dict) else outputs[0]

preds = outputs.logits.argmax(dim=2)[..., :-1]
if labels is None:
labels = inputs['labels']

if SUPPORT_XTUNER:
# reduce loss for logging correctly
num_tokens = (labels != -100).sum()
loss = reduce_sequence_parallel_loss(loss, num_tokens,
get_sequence_parallel_group())

preds = outputs.logits.argmax(dim=2)[..., :-1]

labels = labels[..., 1:]
masks = labels != -100
acc_strategy = getattr(self.args, 'acc_strategy', 'token')
@@ -296,90 +266,10 @@
'acc'] + acc / self.args.gradient_accumulation_steps
return (loss, outputs) if return_outputs else loss

# Support logging cuda memory usage
# hacky: Override Trainer's private method
def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch,
ignore_keys_for_eval):
if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
if is_torch_xla_available():
xm.mark_step()

logs: Dict[str, float] = {}

# all_gather + mean() to get average loss over all processes
tr_loss_scalar = self._nested_gather(tr_loss).mean().item()

# reset tr_loss to zero
tr_loss -= tr_loss

logs['loss'] = round(
tr_loss_scalar /
(self.state.global_step - self._globalstep_last_logged), 4)
if grad_norm is not None:
logs['grad_norm'] = grad_norm.detach().item() if isinstance(
grad_norm, torch.Tensor) else grad_norm
logs['learning_rate'] = self._get_learning_rate()
logs['memory'] = get_max_cuda_memory()

self._total_loss_scalar += tr_loss_scalar
self._globalstep_last_logged = self.state.global_step
self.store_flos()

self.log(logs)

metrics = None
if self.control.should_evaluate:
metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
self._report_to_hp_search(trial, self.state.global_step, metrics)

# Run delayed LR scheduler now that metrics are populated
if isinstance(self.lr_scheduler,
torch.optim.lr_scheduler.ReduceLROnPlateau):
metric_to_check = self.args.metric_for_best_model
if not metric_to_check.startswith('eval_'):
metric_to_check = f'eval_{metric_to_check}'
self.lr_scheduler.step(metrics[metric_to_check])

if self.control.should_save:
self._save_checkpoint(model, trial, metrics=metrics)
self.control = self.callback_handler.on_save(
self.args, self.state, self.control)

def get_train_dataloader(self):

if not use_torchacc():
# modified from HFTrainer.get_train_dataloader
# RandomSampler -> SequenceParallelSampler
if trainer.is_datasets_available():
import datasets
if self.train_dataset is None:
raise ValueError('Trainer: training requires a train_dataset.')

train_dataset = self.train_dataset
data_collator = self.data_collator
if trainer.is_datasets_available() and isinstance(
train_dataset, datasets.Dataset):
train_dataset = self._remove_unused_columns(
train_dataset, description='training')
else:
data_collator = self._get_collator_with_removed_columns(
data_collator, description='training')

dataloader_params = {
'batch_size': self._train_batch_size,
'collate_fn': data_collator,
'num_workers': self.args.dataloader_num_workers,
'pin_memory': self.args.dataloader_pin_memory,
'persistent_workers': self.args.dataloader_persistent_workers,
}

if not isinstance(train_dataset, torch.utils.data.IterableDataset):
dataloader_params['sampler'] = SequenceParallelSampler(
train_dataset, seed=1024)
dataloader_params['drop_last'] = self.args.dataloader_drop_last
dataloader_params['worker_init_fn'] = seed_worker

return DataLoader(train_dataset, **dataloader_params)
return super().get_train_dataloader()

else:
if trainer.is_datasets_available():