Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions examples/pytorch/llm/scripts/cogagent_chat/lora/infer.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Run inference with a cogagent-chat LoRA checkpoint via llm_infer.py.
# Replace the --ckpt_dir placeholder with the checkpoint directory to load.
# Experimental environment: V100, A10, 3090
#
# NOTE: every argument line except the last must end with a backslash; the
# last line must NOT, otherwise whatever follows is swallowed as an argument.
PYTHONPATH=../../.. \
CUDA_VISIBLE_DEVICES=0 \
python llm_infer.py \
    --ckpt_dir "/xxx/xxx/cogagent-chat/vx-xxx/checkpoint-xx" \
    --load_args_from_ckpt_dir true \
    --eval_human true \
    --max_length 4096 \
    --use_flash_attn true \
    --max_new_tokens 2048 \
    --temperature 0.3 \
    --top_p 0.7 \
    --repetition_penalty 1.05 \
    --do_sample true \
    --merge_lora_and_save false
33 changes: 33 additions & 0 deletions examples/pytorch/llm/scripts/cogagent_chat/lora/sft.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# LoRA fine-tuning of cogagent-chat on the capcha-images dataset via llm_sft.py.
# Experimental environment: 2 * A100
# 2 * 45GB
#
# NOTE: every argument line except the last must end with a backslash. A
# missing backslash ends the command early and the remaining lines are run
# as a separate (nonexistent) command, silently dropping those options.
PYTHONPATH=../../.. \
CUDA_VISIBLE_DEVICES=0,1 \
python llm_sft.py \
    --model_type cogagent-chat \
    --sft_type lora \
    --tuner_backend swift \
    --dtype fp16 \
    --output_dir output \
    --dataset capcha-images \
    --train_dataset_sample -1 \
    --num_train_epochs 2 \
    --max_length 1024 \
    --check_dataset_strategy warning \
    --lora_rank 8 \
    --lora_alpha 32 \
    --lora_dropout_p 0.05 \
    --gradient_checkpointing false \
    --batch_size 1 \
    --weight_decay 0.01 \
    --learning_rate 1e-4 \
    --gradient_accumulation_steps 16 \
    --max_grad_norm 0.5 \
    --warmup_ratio 0.03 \
    --eval_steps 100 \
    --save_steps 100 \
    --save_total_limit 2 \
    --logging_steps 10 \
    --push_to_hub false \
    --hub_model_id cogagent-chat-lora \
    --hub_private_repo true \
    --hub_token 'your-sdk-token'
17 changes: 13 additions & 4 deletions swift/llm/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,13 @@ def prepare_model_template(
logger.info(get_model_info(model))
show_layers(model)

template: Template = get_template(args.template_type, tokenizer,
args.system, args.max_length,
args.truncation_strategy)
template: Template = get_template(
args.template_type,
tokenizer,
args.system,
args.max_length,
args.truncation_strategy,
model=model)
args.system = template.default_system
logger.info(f'system: {args.system}')
return model, template
Expand Down Expand Up @@ -175,6 +179,10 @@ def llm_infer(args: InferArguments) -> None:
logger.info(
'The current template only supports single-round dialogues.')
history = []
if 'cogagent' in args.model_type:
image = input('Input an image url<<< ')
from PIL import Image
image = Image.open(image)
while True:
if input_mode == 'S':
query = input('<<< ')
Expand Down Expand Up @@ -210,7 +218,8 @@ def llm_infer(args: InferArguments) -> None:
print(response[print_idx:], end='', flush=True)
print_idx = len(response)
else:
gen = inference_stream(model, template, query, history)
gen = inference_stream(
model, template, query, history, image=image)
for response, new_history in gen:
if len(response) > print_idx:
print(response[print_idx:], end='', flush=True)
Expand Down
10 changes: 7 additions & 3 deletions swift/llm/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,13 @@ def llm_sft(args: SftArguments) -> str:

logger.info(f'train_dataset: {train_dataset}')
logger.info(f'val_dataset: {val_dataset}')
template: Template = get_template(args.template_type, tokenizer,
args.system, args.max_length,
args.truncation_strategy)
template: Template = get_template(
args.template_type,
tokenizer,
args.system,
args.max_length,
args.truncation_strategy,
model=model)
args.system = template.default_system
logger.info(f'system: {args.system}')
if not args.lazy_tokenize:
Expand Down
25 changes: 24 additions & 1 deletion swift/llm/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
def _remove_useless_columns(dataset: HfDataset) -> HfDataset:
    """Drop every column except the recognised ones.

    Columns named ``query``, ``response``, ``system``, ``history`` or
    ``image`` are retained, in the order they appear in
    ``dataset.features``; all other columns are discarded.

    Args:
        dataset: Dataset exposing ``features`` and ``select_columns``.

    Returns:
        Whatever ``dataset.select_columns`` returns for the retained names.
    """
    # 'image' is part of the kept set so that image data is not stripped
    # from datasets that carry it.
    kept_columns = {'query', 'response', 'system', 'history', 'image'}
    k_list = [k for k in dataset.features.keys() if k in kept_columns]
    return dataset.select_columns(k_list)
Expand Down Expand Up @@ -106,6 +106,7 @@ class DatasetName:
# vision
coco_en = 'coco-en'
coco_mini_en = 'coco-mini-en'
capcha_images = 'capcha-images'
# audio
aishell1_zh = 'aishell1-zh'
aishell1_mini_zh = 'aishell1-mini-zh'
Expand Down Expand Up @@ -599,6 +600,28 @@ def _preprocess_sharegpt(dataset: HfDataset) -> HfDataset:
get_dataset_from_repo,
tags=['chat', 'general', 'multi-round'])


def _preprocess_capcha_images(dataset: HfDataset) -> HfDataset:
    """Reshape the captcha dataset into query/response form.

    The ``solution`` column is renamed to ``response`` and every row is
    given the same constant ``query`` value, ``'CAPTCHA:'``.

    Args:
        dataset: Dataset exposing ``rename_columns`` and ``map``.

    Returns:
        The renamed dataset after the per-row mapping has been applied.
    """

    def _with_fixed_query(example):
        # Every sample shares one prompt; only the response differs.
        example['query'] = 'CAPTCHA:'
        return example

    renamed = dataset.rename_columns({'solution': 'response'})
    return renamed.map(_with_fixed_query)


# Register the captcha image dataset under DatasetName.capcha_images, backed
# by the 'AI-ModelScope/captcha-images' repository, with
# _preprocess_capcha_images as its preprocessing function and
# get_dataset_from_repo as its loader.
# NOTE(review): the two list arguments look like (subset, split) pairs for the
# train and validation data respectively -- confirm against register_dataset.
# NOTE(review): the identifier is spelled 'capcha' while the repository is
# 'captcha'; the spelling is part of the public dataset name, so keep it.
register_dataset(
DatasetName.capcha_images,
'AI-ModelScope/captcha-images', [('default', 'train')],
[('default', 'validation')],
_preprocess_capcha_images,
get_dataset_from_repo,
tags=['chat', 'multi-modal', 'vision', '🔥'])

register_dataset(
DatasetName.cls_fudan_news_zh,
'damo/zh_cls_fudan-news', ['train'],
Expand Down
58 changes: 58 additions & 0 deletions swift/llm/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ class ModelType:
# phi
phi2_3b = 'phi2-3b'

cogagent_chat = 'cogagent-chat'
cogagent_vqa = 'cogagent-vqa'

@classmethod
def get_model_name_list(cls) -> List[str]:
res = []
Expand All @@ -172,6 +175,11 @@ class LoRATM(NamedTuple):
qwen = ['c_attn']
polylm = ['c_attn']
bloom = ['query_key_value']
cogagent = [
'vision_expert_query_key_value', 'vision_expert_dense',
'language_expert_query_key_value', 'language_expert_dense', 'query',
'key_value', 'dense'
]
phi = ['Wqkv']


Expand Down Expand Up @@ -318,6 +326,56 @@ def get_model_tokenizer_from_repo(model_dir: str,
return model, tokenizer


@register_model(
    ModelType.cogagent_chat,
    'ZhipuAI/cogagent-chat',
    LoRATM.cogagent,
    TemplateType.cogagent,
    requires=['transformers>=4.36'],
    support_vllm=False)
@register_model(
    ModelType.cogagent_vqa,
    'ZhipuAI/cogagent-vqa',
    LoRATM.cogagent,
    TemplateType.cogagent,
    requires=['transformers>=4.36'],
    support_vllm=False)
def get_model_tokenizer_from_repo_cogagent(
        model_dir: str,
        torch_dtype: Dtype,
        model_kwargs: Dict[str, Any],
        load_model: bool = True,
        model_config=None,
        tokenizer=None,
        automodel_class=AutoModelForCausalLM,
        **kwargs):
    """Build the model/tokenizer pair for the CogAgent checkpoints.

    Args:
        model_dir: Directory passed to ``from_pretrained`` for the config
            and the model weights.
        torch_dtype: Dtype stored on the config and passed to the model
            loader.
        model_kwargs: Extra keyword arguments forwarded to
            ``automodel_class.from_pretrained``.
        load_model: When false, no model is instantiated and ``None`` is
            returned in its place.
        model_config: Optional ready-made config; loaded from ``model_dir``
            when omitted.
        tokenizer: Optional ready-made tokenizer; when omitted it is loaded
            from ``'AI-ModelScope/vicuna-7b-v1.5'`` (not from ``model_dir``)
            with left padding.
        automodel_class: Class whose ``from_pretrained`` creates the model.
        **kwargs: Only ``eos_token`` is read; when present and not ``None``
            it overrides ``tokenizer.eos_token``.

    Returns:
        A ``(model, tokenizer)`` tuple; ``model`` is ``None`` when
        ``load_model`` is false.
    """
    config = model_config
    if config is None:
        config = AutoConfig.from_pretrained(
            model_dir, trust_remote_code=True)
    config.torch_dtype = torch_dtype

    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(
            'AI-ModelScope/vicuna-7b-v1.5',
            trust_remote_code=True,
            padding_side='left')
    eos_override = kwargs.get('eos_token')
    if eos_override is not None:
        tokenizer.eos_token = eos_override

    if not load_model:
        return None, tokenizer

    model = automodel_class.from_pretrained(
        model_dir,
        config=config,
        torch_dtype=torch_dtype,
        trust_remote_code=True,
        **model_kwargs)
    # NOTE(review): emitted only when a model is actually loaded -- confirm
    # this matches the intended placement of the notice.
    logger.info(
        'CogAgent with FusedLayerNorm will cause an training loss of Nan, '
        'to avoid this, please uninstall apex.')
    return model, tokenizer


@register_model(
ModelType.internlm_20b_chat,
'Shanghai_AI_Laboratory/internlm-chat-20b',
Expand Down
Loading