12 changes: 12 additions & 0 deletions examples/pytorch/llm/rome_example/request.json
@@ -0,0 +1,12 @@
[
{
"prompt": "{} was the founder of",
"subject": "Steve Jobs",
"target": "Microsoft"
},
{
"prompt": "{} is located in",
"subject": "HangZhou",
"target": "Africa"
}
]
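Each entry is one edit request: {} in the prompt is a placeholder for the subject, and target is the new completion the edit should teach the model. A minimal sketch of how the entries expand (file path as in this example):

import json

with open('rome_example/request.json') as f:
    requests = json.load(f)

for r in requests:
    # '{}' is filled with the subject to form the full prompt
    print(r['prompt'].format(r['subject']), '->', r['target'])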
6 changes: 6 additions & 0 deletions examples/pytorch/llm/rome_infer.py
@@ -0,0 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from swift.llm.run import rome_main

if __name__ == '__main__':
rome_main()
15 changes: 15 additions & 0 deletions examples/pytorch/llm/scripts/llama2_13b_chat/rome.sh
@@ -0,0 +1,15 @@
# Experimental environment: A10
PYTHONPATH=../../.. \
CUDA_VISIBLE_DEVICES=0 \
python rome_infer.py \
--model_id_or_path modelscope/Llama-2-13b-chat-ms \
--model_revision master \
--template_type llama \
--dtype bf16 \
--eval_human true \
--max_new_tokens 128 \
--temperature 0.1 \
--top_k 50 \
--top_p 0.9 \
--do_sample true \
--rome_request_file rome_example/request.json
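Note: the script assumes it is launched from examples/pytorch/llm, so that rome_infer.py and rome_example/request.json resolve against the working directory and PYTHONPATH=../../.. points at the repository root.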
1 change: 1 addition & 0 deletions swift/llm/__init__.py
@@ -1,4 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .infer import llm_infer
from .rome import rome_infer
from .sft import llm_sft
from .utils import *
87 changes: 87 additions & 0 deletions swift/llm/rome.py
@@ -0,0 +1,87 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import json
from typing import Optional

import torch
from modelscope import GenerationConfig

from swift.tuners import Swift
from swift.utils import (get_logger, print_model_info, seed_everything,
show_layers)
from ..tuners.rome import RomeConfig
from .utils import (RomeArguments, Template, get_dataset, get_model_tokenizer,
get_template, inference)

logger = get_logger()


def rome_infer(args: RomeArguments) -> None:
logger.info(f'args: {args}')
    logger.info(
        'ROME does not support quantization for now; '
        'all quantization args will be ignored.')
logger.info(f'device_count: {torch.cuda.device_count()}')
seed_everything(args.seed)

# ### Loading Model and Tokenizer
model_kwargs = {'low_cpu_mem_usage': True, 'device_map': 'auto'}
kwargs = {'use_flash_attn': args.use_flash_attn}
model, tokenizer = get_model_tokenizer(args.model_type, args.torch_dtype,
model_kwargs, **kwargs)

with open(args.rome_request_file, 'r') as f:
request = json.load(f)

    rome_type: Optional[str] = None
if args.model_type in ('llama2-13b-chat', 'llama2-13b', 'llama-13b-chat',
'llama-13b'):
rome_type = 'llama-13b'
elif args.model_type in ('llama2-7b-chat', 'llama2-7b', 'llama-7b-chat',
'llama-7b'):
rome_type = 'llama-7b'

config = RomeConfig(
model_type=rome_type,
knowledge=request,
tokenizer=tokenizer,
)
model = Swift.prepare_model(model, config, inference_mode=True)

show_layers(model)
print_model_info(model)

# ### Inference
template: Template = get_template(args.template_type, tokenizer,
args.system, args.max_length)
generation_config = GenerationConfig(
max_length=None,
max_new_tokens=args.max_new_tokens,
temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
do_sample=args.do_sample,
repetition_penalty=args.repetition_penalty,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id)
logger.info(f'generation_config: {generation_config}')
if args.overwrite_generation_config:
generation_config.save_pretrained(args.ckpt_dir)
model.generation_config = generation_config

if args.eval_human:
while True:
query = input('<<< ')
data = {'query': query}
input_ids = template.encode(data)['input_ids']
inference(input_ids, model, tokenizer, args.stream)
else:
_, val_dataset = get_dataset(args.dataset, args.dataset_test_ratio,
args.dataset_seed)
mini_val_dataset = val_dataset.select(
range(min(args.show_dataset_sample, val_dataset.shape[0])))
for data in mini_val_dataset:
response = data['response']
data['response'] = None
input_ids = template.encode(data)['input_ids']
inference(input_ids, model, tokenizer, args.stream)
print()
print(f'[LABELS]{response}')
print('-' * 80)
# input('next[ENTER]')
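For orientation, the editing core in isolation; a minimal sketch using the same APIs as rome_infer above (the model type, dtype handling, and paths are illustrative assumptions):

import json

from swift.llm.utils import get_model_tokenizer
from swift.tuners import Swift
from swift.tuners.rome import RomeConfig

# torch_dtype=None is an assumption that the default dtype is acceptable
model, tokenizer = get_model_tokenizer(
    'llama2-7b-chat', None, {'low_cpu_mem_usage': True, 'device_map': 'auto'})
with open('rome_example/request.json') as f:
    requests = json.load(f)
config = RomeConfig(
    model_type='llama-7b',  # the ROME hparams family, as mapped from args.model_type above
    knowledge=requests,
    tokenizer=tokenizer)
# prepare_model applies the edits directly; no training is involved
model = Swift.prepare_model(model, config, inference_mode=True)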
5 changes: 3 additions & 2 deletions swift/llm/run.py
@@ -1,6 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from swift.llm import (InferArguments, SftArguments, get_main, llm_infer,
llm_sft)
from swift.llm import (InferArguments, RomeArguments, SftArguments, get_main,
llm_infer, llm_sft, rome_infer)

sft_main = get_main(SftArguments, llm_sft)
infer_main = get_main(InferArguments, llm_infer)
rome_main = get_main(RomeArguments, rome_infer)
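rome_main mirrors sft_main and infer_main: parse the command line into a RomeArguments and dispatch to rome_infer. A hedged sketch of driving it from Python, under the assumption that get_main-style entry points read sys.argv (flag values taken from the example script):

import sys

from swift.llm.run import rome_main

# Assumption: rome_main() parses sys.argv, like the CLI entry point in rome_infer.py
sys.argv = [
    'rome_infer.py',
    '--model_id_or_path', 'modelscope/Llama-2-13b-chat-ms',
    '--template_type', 'llama',
    '--eval_human', 'true',
    '--rome_request_file', 'rome_example/request.json',
]
rome_main()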
16 changes: 15 additions & 1 deletion swift/llm/sft.py
@@ -54,7 +54,7 @@ def llm_sft(args: SftArguments) -> str:
model_kwargs, **kwargs)

# ### Preparing LoRA
if args.sft_type == 'lora' or args.sft_type == 'longlora':
if args.sft_type in ('lora', 'qalora', 'longlora'):
if args.resume_from_checkpoint is None:
if 'ALL' in args.lora_target_modules:
assert len(args.lora_target_modules) == 1
@@ -88,6 +88,20 @@ def llm_sft(args: SftArguments) -> str:
use_flash_attn=args.use_flash_attn)
model = Swift.prepare_model(model, longlora_config)
logger.info(f'longlora_config: {longlora_config}')
elif args.sft_type == 'qalora':
assert getattr(
model, 'quantization_method',
None) == 'gptq', 'qalora must be used with auto_gptq'
lora_kwargs = {}
lora_config = LoRAConfig(
r=args.lora_rank,
target_modules=args.lora_target_modules,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout_p,
use_qa_lora=True,
**lora_kwargs)
model = Swift.prepare_model(model, lora_config)
logger.info(f'lora_config: {lora_config}')
else:
model = Swift.from_pretrained(
model, args.resume_from_checkpoint, is_trainable=True)
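QA-LoRA reuses the LoRA machinery with use_qa_lora=True, and the assertion above enforces that the base model is GPTQ-quantized via auto_gptq. A minimal sketch of the prepare step in isolation (the LoRAConfig import location and the hyperparameter values are assumptions):

from swift.tuners import LoRAConfig, Swift  # LoRAConfig export path assumed

lora_config = LoRAConfig(
    r=8,
    target_modules=['q_proj', 'v_proj'],
    lora_alpha=32,
    lora_dropout=0.05,
    use_qa_lora=True)  # group-wise adapters aligned with the GPTQ quantization groups
# `model` must already be GPTQ-quantized (model.quantization_method == 'gptq')
model = Swift.prepare_model(model, lora_config)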
2 changes: 1 addition & 1 deletion swift/llm/utils/__init__.py
@@ -1,5 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .argument import InferArguments, SftArguments
from .argument import InferArguments, RomeArguments, SftArguments
from .dataset import (DATASET_MAPPING, AlpacaPreprocessor,
ConversationsPreprocessor, DatasetName,
GetDatasetFunction, get_dataset, get_dataset_from_repo,
35 changes: 32 additions & 3 deletions swift/llm/utils/argument.py
@@ -31,7 +31,8 @@ class SftArguments:
model_cache_dir: Optional[str] = None

sft_type: str = field(
default='lora', metadata={'choices': ['longlora', 'lora', 'full']})
default='lora',
metadata={'choices': ['lora', 'longlora', 'qalora', 'full']})
tuner_backend: str = field(
default='swift', metadata={'choices': ['swift', 'peft']})
template_type: Optional[str] = field(
@@ -158,7 +159,7 @@ def init_argument(self):
# Make sure to set the same output_dir when using DDP.
self.output_dir = broadcast_string(self.output_dir)

if self.sft_type == 'lora' or self.sft_type == 'longlora':
if self.sft_type in ('lora', 'longlora', 'qalora'):
if self.learning_rate is None:
self.learning_rate = 1e-4
if self.only_save_model is None:
@@ -224,7 +225,8 @@ class InferArguments:
model_revision: Optional[str] = None

sft_type: str = field(
default='lora', metadata={'choices': ['longlora', 'lora', 'full']})
default='lora',
metadata={'choices': ['lora', 'longlora', 'qalora', 'full']})
template_type: Optional[str] = field(
default=None,
metadata={
@@ -291,6 +293,33 @@ def init_argument(self):
self.max_length = None


@dataclass
class RomeArguments(InferArguments):

    rome_request_file: Optional[str] = field(
        default=None,
        metadata={
            'help':
            'Path to the ROME request file; see the documentation '
            'for the expected format.'
        })

def init_argument(self):
        # Unlike __post_init__, this method can be called manually.
handle_compatibility(self)
set_model_type(self)
handle_dir(self)

self.torch_dtype, _, _ = select_dtype(self)
if self.template_type is None:
self.template_type = MODEL_MAPPING[self.model_type]['template']
logger.info(f'Setting template_type: {self.template_type}')

assert isinstance(self.dataset, (list, tuple))
if self.max_length == -1:
self.max_length = None


dtype_mapping_reversed = {v: k for k, v in dtype_mapping.items()}


4 changes: 2 additions & 2 deletions swift/tuners/rome/rome.py
@@ -147,7 +147,7 @@ def execute_rome(
layer,
context_template,
)
logger.info('Left vector shape:', left_vector.shape)
logger.info(f'Left vector shape: {left_vector.shape}')
right_vector: torch.Tensor = compute_v(
model,
tok,
@@ -157,7 +157,7 @@
left_vector,
context_template,
)
logger.info('Right vector shape:', right_vector.shape)
logger.info(f'Right vector shape: {right_vector.shape}')
right_vector = right_vector.to(left_vector.dtype)

with torch.no_grad():
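The left and right vectors logged above are the entire edit: ROME writes their outer product into the targeted MLP projection as a rank-one update. A toy sketch of the shape arithmetic (dimensions illustrative; the scaling ROME applies is omitted):

import torch

hidden, inter = 16, 64             # illustrative sizes
W = torch.zeros(inter, hidden)     # stand-in for the edited MLP weight
left = torch.randn(inter)          # key direction for the subject
right = torch.randn(hidden)        # value direction encoding the new fact
with torch.no_grad():
    W += torch.outer(left, right)  # rank-one knowledge edit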
2 changes: 1 addition & 1 deletion swift/tuners/rome/rome_hparams.py
@@ -46,6 +46,6 @@ def from_name(cls, name: str):
mlp_module_tmp='model.layers.{}.mlp',
))
else:
raise NotImplementedError
raise NotImplementedError(f'{name} not supported.')

return cls(**data)