Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions swift/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from .model import (register_model, MODEL_MAPPING, ModelType, get_model_tokenizer, safe_snapshot_download,
HfConfigFactory, ModelInfo, ModelMeta, ModelKeys, register_model_arch, MultiModelKeys,
ModelArch, get_model_arch, MODEL_ARCH_MAPPING, get_model_info_meta, get_model_name, ModelGroup,
Model, get_model_tokenizer_with_flash_attn, get_model_tokenizer_multimodal, load_by_unsloth)
Model, get_model_tokenizer_with_flash_attn, get_model_tokenizer_multimodal, load_by_unsloth,
git_clone_github)
from .dataset import (AlpacaPreprocessor, ResponsePreprocessor, MessagesPreprocessor, AutoPreprocessor,
DATASET_MAPPING, MediaResource, register_dataset, register_dataset_info, EncodePreprocessor,
LazyLLMDataset, ConstantLengthDataset, standard_keys, load_dataset, DATASET_TYPE,
Expand Down Expand Up @@ -51,7 +52,7 @@
'ModelInfo', 'ModelMeta', 'ModelKeys', 'register_model_arch', 'MultiModelKeys', 'ModelArch',
'MODEL_ARCH_MAPPING', 'get_model_arch', 'get_model_info_meta', 'get_model_name', 'register_model',
'ModelGroup', 'Model', 'get_model_tokenizer_with_flash_attn', 'get_model_tokenizer_multimodal',
'load_by_unsloth'
'load_by_unsloth', 'git_clone_github'
],
'dataset': [
'AlpacaPreprocessor', 'ClsPreprocessor', 'ComposePreprocessor', 'MessagesPreprocessor', 'DATASET_MAPPING',
Expand Down
13 changes: 8 additions & 5 deletions swift/llm/dataset/dataset/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,20 @@
from ..register import DatasetMeta, SubsetDataset, register_dataset


def _concat_inst_inp_alpaca_zh(inst: str, inp: str) -> str:
if inp.startswith('输入:'):
inp = inp[3:]
return f'{inst}\n{inp}'
class AlpacaZhPreprocessor(AlpacaPreprocessor):
    """Alpaca preprocessor specialized for the Chinese alpaca-gpt4 data.

    Strips a leading '输入:' label from the input field, then delegates to
    the standard instruction/input concatenation of the base class.
    """

    @classmethod
    def concat_inst_input(cls, instruction, input_):
        marker = '输入:'
        if input_ and input_.startswith(marker):
            input_ = input_[len(marker):]
        return super().concat_inst_input(instruction, input_)


register_dataset(
DatasetMeta(
ms_dataset_id='AI-ModelScope/alpaca-gpt4-data-zh',
hf_dataset_id='llm-wizard/alpaca-gpt4-data-zh',
preprocess_func=AlpacaPreprocessor(concat_inst_input=_concat_inst_inp_alpaca_zh),
preprocess_func=AlpacaZhPreprocessor(),
tags=['chat', 'general', '🔥'],
))

Expand Down
30 changes: 9 additions & 21 deletions swift/llm/dataset/preprocessor/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,34 +312,22 @@ def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]:

class AlpacaPreprocessor(ResponsePreprocessor):

def __init__(self,
*,
concat_inst_input: Union[Callable[[str, str], str]] = '\n',
columns_mapping: Optional[Dict[str, str]] = None,
**kwargs) -> None:
"""Alpaca format preprocessor

Args:
concat_inst_input: The concat sep between instruction and input
"""
super().__init__(columns_mapping=columns_mapping, **kwargs)
self.concat_inst_input = concat_inst_input
@classmethod
def concat_inst_input(cls, instruction, input_):
    """Concatenate an alpaca-style instruction and input into one query.

    When both parts are non-empty they are joined with a newline; otherwise
    whichever of the two is truthy is returned on its own.
    """
    if not instruction:
        query = input_
    elif not input_:
        query = instruction
    else:
        query = f'{instruction}\n{input_}'
    # Sanity check preserved from the original contract; note that assert
    # is skipped under `python -O`.
    assert isinstance(query, str), f'query: {query}'
    return query

def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
instruction = row.pop('instruction', None)
input_ = row.pop('input', None)
output = row.pop('output', None)
if output is not None:
row['response'] = output

if instruction is not None or input_ is not None:
instruction = instruction or ''
input_ = input_ or ''
if isinstance(self.concat_inst_input, str):
query = instruction + self.concat_inst_input + input_
else:
query = self.concat_inst_input(instruction, input_)
row['query'] = query
row['query'] = self.concat_inst_input(instruction, input_)
return super().preprocess(row)


Expand Down
2 changes: 1 addition & 1 deletion swift/llm/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@
get_default_torch_dtype, get_model_info_meta, get_model_name, get_model_tokenizer,
get_model_tokenizer_multimodal, get_model_tokenizer_with_flash_attn, get_model_with_value_head,
load_by_unsloth, register_model)
from .utils import HfConfigFactory, ModelInfo, safe_snapshot_download
from .utils import HfConfigFactory, ModelInfo, git_clone_github, safe_snapshot_download
5 changes: 3 additions & 2 deletions swift/llm/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,8 @@ def git_clone_github(github_url: str,
local_repo_name: Optional[str] = None,
branch: Optional[str] = None,
commit_hash: Optional[str] = None) -> str:
if github_url.endswith('.git'):
github_url = github_url[:-4]
git_cache_dir = os.path.join(get_cache_dir(), '_github')
os.makedirs(git_cache_dir, exist_ok=True)
if local_repo_name is None:
Expand All @@ -282,8 +284,7 @@ def git_clone_github(github_url: str,
local_repo_path = os.path.join(git_cache_dir, local_repo_name)
with safe_ddp_context(hash_id=local_repo_path):
if not os.path.exists(local_repo_path):
if not github_url.endswith('.git'):
github_url = f'{github_url}.git'
github_url = f'{github_url}.git'
command = ['git', '-C', git_cache_dir, 'clone', github_url, local_repo_name]
command_str = f"git -C '{git_cache_dir}' clone '{github_url}' {local_repo_name}"
if branch is not None:
Expand Down
7 changes: 5 additions & 2 deletions tests/general/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@ def test_sft():
# _test_dataset(['AI-ModelScope/Duet-v0.5'])
# _test_dataset(['swift/SlimOrca', 'swift/cosmopedia-100k'])
# _test_dataset(['OmniData/Zhihu-KOL-More-Than-100-Upvotes'])
_test_dataset(['OmniData/Zhihu-KOL'])
# _test_dataset(['AI-ModelScope/alpaca-gpt4-data-zh#1000', 'AI-ModelScope/alpaca-gpt4-data-en#200'])
# _test_dataset(['OmniData/Zhihu-KOL'])
_test_dataset([
'AI-ModelScope/alpaca-gpt4-data-zh#1000', 'AI-ModelScope/alpaca-gpt4-data-en#1000',
'AI-ModelScope/LongAlpaca-12k#1000'
])
# _test_dataset(['swift/Infinity-Instruct:all'])
# _test_dataset(['swift/sharegpt:all'])
# _test_dataset(['AI-ModelScope/sharegpt_gpt4:all'])
Expand Down
9 changes: 4 additions & 5 deletions tests/general/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@

def test_local_dataset():
# please use git clone
local_dataset = '/mnt/nas2/huangjintao.hjt/work/datasets/swift-sft-mixture:firefly#100'
dataset = load_dataset(datasets=[local_dataset], streaming=True)[0]
for i, x in enumerate(dataset):
pass
print(i, x)
from swift.llm import git_clone_github
model_dir = git_clone_github('https://www.modelscope.cn/datasets/swift/swift-sft-mixture.git')
dataset = load_dataset(datasets=[f'{model_dir}:firefly'], streaming=True)[0]
print(next(iter(dataset)))


def test_hub_dataset():
Expand Down
Loading