diff --git a/examples/pytorch/llm/scripts/cogagent_chat/lora/infer.sh b/examples/pytorch/llm/scripts/cogagent_chat/lora/infer.sh
new file mode 100644
index 0000000000..4d0e48de20
--- /dev/null
+++ b/examples/pytorch/llm/scripts/cogagent_chat/lora/infer.sh
@@ -0,0 +1,15 @@
+# Experimental environment: V100, A10, 3090
+PYTHONPATH=../../.. \
+CUDA_VISIBLE_DEVICES=0 \
+python llm_infer.py \
+    --ckpt_dir "/xxx/xxx/cogagent-chat/vx-xxx/checkpoint-xx" \
+    --load_args_from_ckpt_dir true \
+    --eval_human true \
+    --max_length 4096 \
+    --use_flash_attn true \
+    --max_new_tokens 2048 \
+    --temperature 0.3 \
+    --top_p 0.7 \
+    --repetition_penalty 1.05 \
+    --do_sample true \
+    --merge_lora_and_save false \
diff --git a/examples/pytorch/llm/scripts/cogagent_chat/lora/sft.sh b/examples/pytorch/llm/scripts/cogagent_chat/lora/sft.sh
new file mode 100644
index 0000000000..0b642444db
--- /dev/null
+++ b/examples/pytorch/llm/scripts/cogagent_chat/lora/sft.sh
@@ -0,0 +1,33 @@
+# Experimental environment: 2 * A100
+# 2 * 45GB
+PYTHONPATH=../../.. \
+CUDA_VISIBLE_DEVICES=0,1 \
+python llm_sft.py \
+    --model_type cogagent-chat \
+    --sft_type lora \
+    --tuner_backend swift \
+    --dtype fp16 \
+    --output_dir output \
+    --dataset capcha-images \
+    --train_dataset_sample -1 \
+    --num_train_epochs 2 \
+    --max_length 1024 \
+    --check_dataset_strategy warning \
+    --lora_rank 8 \
+    --lora_alpha 32 \
+    --lora_dropout_p 0.05 \
+    --gradient_checkpointing false \
+    --batch_size 1 \
+    --weight_decay 0.01 \
+    --learning_rate 1e-4 \
+    --gradient_accumulation_steps 16 \
+    --max_grad_norm 0.5 \
+    --warmup_ratio 0.03 \
+    --eval_steps 100 \
+    --save_steps 100 \
+    --save_total_limit 2 \
+    --logging_steps 10 \
+    --push_to_hub false \
+    --hub_model_id cogagent-chat-lora \
+    --hub_private_repo true \
+    --hub_token 'your-sdk-token' \
diff --git a/swift/llm/infer.py b/swift/llm/infer.py
index 20bc708002..2e6176e2d2 100644
--- a/swift/llm/infer.py
+++ b/swift/llm/infer.py
@@ -140,9 +140,13 @@ def prepare_model_template(
 
     logger.info(get_model_info(model))
     show_layers(model)
-    template: Template = get_template(args.template_type, tokenizer,
-                                      args.system, args.max_length,
-                                      args.truncation_strategy)
+    template: Template = get_template(
+        args.template_type,
+        tokenizer,
+        args.system,
+        args.max_length,
+        args.truncation_strategy,
+        model=model)
     args.system = template.default_system
     logger.info(f'system: {args.system}')
     return model, template
@@ -175,6 +179,10 @@ def llm_infer(args: InferArguments) -> None:
             logger.info(
                 'The current template only supports single-round dialogues.')
         history = []
+        if 'cogagent' in args.model_type:
+            image = input('Input an image path<<< ')
+            from PIL import Image
+            image = Image.open(image)
         while True:
             if input_mode == 'S':
                 query = input('<<< ')
@@ -210,7 +218,8 @@ def llm_infer(args: InferArguments) -> None:
                         print(response[print_idx:], end='', flush=True)
                         print_idx = len(response)
                 else:
-                    gen = inference_stream(model, template, query, history)
+                    gen = inference_stream(
+                        model, template, query, history, image=image)
                     for response, new_history in gen:
                         if len(response) > print_idx:
                             print(response[print_idx:], end='', flush=True)
diff --git a/swift/llm/sft.py b/swift/llm/sft.py
index 01c97e889e..1cec1b0b0b 100644
--- a/swift/llm/sft.py
+++ b/swift/llm/sft.py
@@ -172,9 +172,13 @@ def llm_sft(args: SftArguments) -> str:
 
     logger.info(f'train_dataset: {train_dataset}')
     logger.info(f'val_dataset: {val_dataset}')
-    template: Template = get_template(args.template_type, tokenizer,
-                                      args.system, args.max_length,
-                                      args.truncation_strategy)
+    template: Template = get_template(
+        args.template_type,
+        tokenizer,
+        args.system,
+        args.max_length,
+        args.truncation_strategy,
+        model=model)
     args.system = template.default_system
     logger.info(f'system: {args.system}')
     if not args.lazy_tokenize:
diff --git a/swift/llm/utils/dataset.py b/swift/llm/utils/dataset.py
index 0845fe03ea..d6a8803625 100644
--- a/swift/llm/utils/dataset.py
+++ b/swift/llm/utils/dataset.py
@@ -28,7 +28,7 @@ def _remove_useless_columns(dataset: HfDataset) -> HfDataset:
     k_list = []
     for k in dataset.features.keys():
-        if k in {'query', 'response', 'system', 'history'}:
+        if k in {'query', 'response', 'system', 'history', 'image'}:
             k_list.append(k)
     dataset = dataset.select_columns(k_list)
     return dataset
 
@@ -106,6 +106,7 @@ class DatasetName:
     # vision
     coco_en = 'coco-en'
     coco_mini_en = 'coco-mini-en'
+    capcha_images = 'capcha-images'
     # audio
     aishell1_zh = 'aishell1-zh'
     aishell1_mini_zh = 'aishell1-mini-zh'
@@ -599,6 +600,28 @@ def _preprocess_sharegpt(dataset: HfDataset) -> HfDataset:
     get_dataset_from_repo,
     tags=['chat', 'general', 'multi-round'])
 
+
+def _preprocess_capcha_images(dataset: HfDataset) -> HfDataset:
+    dataset = dataset.rename_columns({
+        'solution': 'response',
+    })
+
+    def add_system(row):
+        row['query'] = 'CAPTCHA:'
+        return row
+
+    dataset = dataset.map(add_system)
+    return dataset
+
+
+register_dataset(
+    DatasetName.capcha_images,
+    'AI-ModelScope/captcha-images', [('default', 'train')],
+    [('default', 'validation')],
+    _preprocess_capcha_images,
+    get_dataset_from_repo,
+    tags=['chat', 'multi-modal', 'vision', '🔥'])
+
 register_dataset(
     DatasetName.cls_fudan_news_zh,
     'damo/zh_cls_fudan-news', ['train'],
diff --git a/swift/llm/utils/model.py b/swift/llm/utils/model.py
index 6b4483c124..5ddc05745e 100644
--- a/swift/llm/utils/model.py
+++ b/swift/llm/utils/model.py
@@ -154,6 +154,9 @@ class ModelType:
     # phi
     phi2_3b = 'phi2-3b'
 
+    cogagent_chat = 'cogagent-chat'
+    cogagent_vqa = 'cogagent-vqa'
+
     @classmethod
     def get_model_name_list(cls) -> List[str]:
         res = []
@@ -172,6 +175,11 @@ class LoRATM(NamedTuple):
     qwen = ['c_attn']
     polylm = ['c_attn']
     bloom = ['query_key_value']
+    cogagent = [
+        'vision_expert_query_key_value', 'vision_expert_dense',
+        'language_expert_query_key_value', 'language_expert_dense', 'query',
+        'key_value', 'dense'
+    ]
     phi = ['Wqkv']
 
 
@@ -318,6 +326,56 @@ def get_model_tokenizer_from_repo(model_dir: str,
     return model, tokenizer
 
 
+@register_model(
+    ModelType.cogagent_chat,
+    'ZhipuAI/cogagent-chat',
+    LoRATM.cogagent,
+    TemplateType.cogagent,
+    requires=['transformers>=4.36'],
+    support_vllm=False)
+@register_model(
+    ModelType.cogagent_vqa,
+    'ZhipuAI/cogagent-vqa',
+    LoRATM.cogagent,
+    TemplateType.cogagent,
+    requires=['transformers>=4.36'],
+    support_vllm=False)
+def get_model_tokenizer_from_repo_cogagent(
+        model_dir: str,
+        torch_dtype: Dtype,
+        model_kwargs: Dict[str, Any],
+        load_model: bool = True,
+        model_config=None,
+        tokenizer=None,
+        automodel_class=AutoModelForCausalLM,
+        **kwargs):
+    """Load CogAgent from `model_dir`; its tokenizer comes from an independent repository."""
+    if model_config is None:
+        model_config = AutoConfig.from_pretrained(
+            model_dir, trust_remote_code=True)
+    model_config.torch_dtype = torch_dtype
+    if tokenizer is None:
+        tokenizer = AutoTokenizer.from_pretrained(
+            'AI-ModelScope/vicuna-7b-v1.5',
+            trust_remote_code=True,
+            padding_side='left')
+    eos_token = kwargs.get('eos_token')
+    if eos_token is not None:
+        tokenizer.eos_token = eos_token
+    model = None
+    if load_model:
+        model = automodel_class.from_pretrained(
+            model_dir,
+            config=model_config,
+            torch_dtype=torch_dtype,
+            trust_remote_code=True,
+            **model_kwargs)
+    logger.info(
+        'CogAgent with FusedLayerNorm will cause a training loss of NaN; '
+        'to avoid this, please uninstall apex.')
+    return model, tokenizer
+
+
 @register_model(
     ModelType.internlm_20b_chat,
     'Shanghai_AI_Laboratory/internlm-chat-20b',
diff --git a/swift/llm/utils/template.py b/swift/llm/utils/template.py
index 924e65c562..97e1bebf53 100644
--- a/swift/llm/utils/template.py
+++ b/swift/llm/utils/template.py
@@ -35,6 +35,7 @@ class TemplateType:
     deepseek = 'deepseek'
     codefuse_codellama = 'codefuse-codellama'
     deepseek_coder = 'deepseek-coder'
+    cogagent = 'cogagent'
 
     @classmethod
     def get_template_name_list(cls) -> List[str]:
@@ -125,7 +126,8 @@ def _concat_context_list(
 def _encode_context_list(
     tokenizer: PreTrainedTokenizerBase,
     context_list: List[Context],
-    compute_loss_idx: Optional[List[int]] = None
+    compute_loss_idx: Optional[List[int]] = None,
+    **kwargs,
 ) -> Tuple[List[int], Optional[List[int]], Dict[str, Any]]:
     input_ids: List[int] = []
     labels: List[int] = []
@@ -154,6 +156,7 @@ def _encode_context_list(
                         [old_audio_info[k], audio_info[k]], dim=0)
                 for k in ['audio_span_tokens', 'audio_urls']:
                     old_audio_info[k] = old_audio_info[k] + audio_info[k]
+
             token_list = tokenizer(
                 context,
                 return_attention_mask=False,
@@ -293,13 +296,13 @@ def __init__(self,
         self.use_default_system = True
         self._is_init = False
 
-    def _init_template(
-        self,
-        tokenizer: PreTrainedTokenizerBase,
-        default_system: Optional[str] = None,
-        max_length: Optional[int] = None,
-        truncation_strategy: Literal['delete', 'truncation_left'] = 'delete'
-    ) -> None:
+    def _init_template(self,
+                       tokenizer: PreTrainedTokenizerBase,
+                       default_system: Optional[str] = None,
+                       max_length: Optional[int] = None,
+                       truncation_strategy: Literal[
+                           'delete', 'truncation_left'] = 'delete',
+                       **kwargs) -> None:
         assert self._is_init is False
         self._is_init = True
         self.tokenizer = tokenizer
@@ -334,6 +337,173 @@ def encode(self, example: Dict[str,
                                  self.truncation_strategy)
 
 
+class CogAgentTemplate(Template):
+    LANGUAGE_TOKEN_TYPE = 0
+    VISION_TOKEN_TYPE = 1
+
+    def _init_template(self,
+                       tokenizer: PreTrainedTokenizerBase,
+                       default_system: Optional[str] = None,
+                       max_length: Optional[int] = None,
+                       truncation_strategy: Literal[
+                           'delete', 'truncation_left'] = 'delete',
+                       **kwargs) -> None:
+        self.model = kwargs.pop('model')
+        self.suffix = [tokenizer.eos_token]
+        super()._init_template(tokenizer, default_system, max_length,
+                               truncation_strategy)
+
+    @staticmethod
+    def vqa_history_to_prompt(history, query):
+        # Only single-round chat is supported in vqa mode.
+        prompt = 'Question: '
+        # for i, (old_query, response) in enumerate(history):
+        #     prompt += old_query + " Short answer: " + response + " Question: "
+        prompt += query + ' Short answer:'
+        return prompt
+
+    @staticmethod
+    def chat_old_history_to_prompt(history, query):
+        prompt = 'Question: '
+        for i, (old_query, response) in enumerate(history):
+            prompt += old_query + ' Answer: ' + response + '\nQuestion: '
+        prompt += query + ' Answer:'
+        return prompt
+
+    @staticmethod
+    def chat_history_to_prompt(history, query):
+        prompt = ' [INST] '
+        for i, (old_query, response) in enumerate(history):
+            prompt += old_query + ' [/INST] ' + response + ' [INST] '
+        prompt += query + ' [/INST] '
+        return prompt
+
+    @staticmethod
+    def base_history_to_prompt(history, query):
+        prompt = query
+        return prompt
+
+    _history_to_prompt = {
+        'base': base_history_to_prompt,
+        'chat': chat_history_to_prompt,
+        'chat_old': chat_old_history_to_prompt,
+        'vqa': vqa_history_to_prompt
+    }
+
+    def build_conversation_input_ids(
+        self,
+        tokenizer: 'PreTrainedTokenizer',
+        *,
+        query: str,
+        label: Optional[str] = None,
+        history: Optional[List[Tuple[str, str]]] = None,
+        images: Optional[List['PIL.Image']] = None,
+        template_version: Optional[Literal['base', 'chat', 'vqa']] = None,
+    ):
+        from torchvision import transforms
+        image_size: int = self.model.config.vision_config['image_size']
+        cross_image_size: int = self.model.config.cross_image_size
+        patch_size: int = self.model.config.vision_config['patch_size']
+        template_version = template_version or self.model.config.template_version
+        assert images is None or len(
+            images) <= 1, 'multi-image input is not supported yet.'
+        history = history or []
+        text = self._history_to_prompt[template_version](history, query)
+
+        input_ids = [tokenizer.bos_token_id]
+        token_type_ids = [self.LANGUAGE_TOKEN_TYPE]
+        if images is not None and len(images) == 1:
+            ori = images
+            # vision
+            transform = transforms.Compose([
+                transforms.Resize(
+                    (image_size, image_size),
+                    interpolation=transforms.InterpolationMode.BICUBIC),
+                transforms.ToTensor(),
+                transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
+                                     (0.26862954, 0.26130258, 0.27577711)),
+            ])
+            images = [transform(ori[0])]
+            cross_transform = transforms.Compose([
+                transforms.Resize(
+                    (cross_image_size, cross_image_size),
+                    interpolation=transforms.InterpolationMode.BICUBIC),
+                transforms.ToTensor(),
+                transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
+                                     (0.26862954, 0.26130258, 0.27577711)),
+            ])
+            cross_images = [cross_transform(ori[0])]
+            # language
+            vision_token_num = (image_size // patch_size) * (image_size
+                                                             // patch_size) + 2
+            input_ids += [tokenizer.pad_token_id] * vision_token_num
+            token_type_ids += [self.VISION_TOKEN_TYPE] * vision_token_num
+        text_ids = tokenizer.encode(text, add_special_tokens=False)
+        train = label is not None
+        label_ids = tokenizer.encode(
+            label, add_special_tokens=False) if train else []
+        if len(text_ids) + len(input_ids) + len(
+                label_ids) > self.max_length - 1:
+            if self.truncation_strategy == 'delete' or (
+                    len(input_ids) + len(label_ids) >= self.max_length - 1):
+                return None
+            else:
+                text_ids = text_ids[-(self.max_length - len(input_ids)
+                                      - len(label_ids) - 1):]
+
+        input_ids += text_ids
+        if train:
+            labels = [-100] * len(input_ids) + label_ids + [
+                tokenizer.eos_token_id
+            ]
+            input_ids += label_ids + [tokenizer.eos_token_id]
+            token_type_ids += [self.LANGUAGE_TOKEN_TYPE] * (
+                len(text_ids) + len(label_ids) + 1)
+        else:
+            token_type_ids += [self.LANGUAGE_TOKEN_TYPE] * len(text_ids)
+        attention_mask = [1] * len(input_ids)
+
+        if len(input_ids) < self.max_length and train:
+            padding_len = self.max_length - len(input_ids)
+            input_ids += [tokenizer.pad_token_id] * padding_len
+            token_type_ids += [self.LANGUAGE_TOKEN_TYPE] * padding_len
+            attention_mask += [0] * padding_len
+            if label_ids:
+                labels += [-100] * padding_len
+
+        if train:
+            return {
+                'input_ids': torch.tensor(input_ids, dtype=torch.long),
+                'token_type_ids':
+                torch.tensor(token_type_ids, dtype=torch.long),
+                'attention_mask':
+                torch.tensor(attention_mask, dtype=torch.long),
+                'images': images,
+                'cross_images': cross_images,
+                'labels': labels,
+            }
+        else:
+            return {
+                'input_ids':
+                torch.tensor(input_ids, dtype=torch.long),
+                'token_type_ids':
+                torch.tensor(token_type_ids, dtype=torch.long).unsqueeze(0),
+                'attention_mask':
+                torch.tensor(attention_mask,
+                             dtype=torch.long).unsqueeze(0),
+                'images': [images],
+                'cross_images': [cross_images],
+            }
+
+    def encode(self, example: Dict[str,
+                                   Any]) -> Dict[str, Optional[List[int]]]:
+        return self.build_conversation_input_ids(
+            self.tokenizer,
+            query=example['query'],
+            label=example.get('response'),
+            history=example.get('history'),
+            images=[example['image'].convert('RGB')])
+
+
 TEMPLATE_MAPPING: Dict[str, Dict[str, Any]] = {}
 
 
@@ -488,17 +658,21 @@ def register_template(template_type: str,
     Template(['{{SYSTEM}}'], ['### Human: {{QUERY}}\n\n### Assistant: '],
              ['<|endoftext|>'], ['<|endoftext|>'], ''))
 
+register_template(TemplateType.cogagent,
+                  CogAgentTemplate([], [], [], [], None, []))
+
 
 def get_template(
     template_type: str,
     tokenizer: PreTrainedTokenizerBase,
     default_system: Optional[str] = None,
     max_length: Optional[int] = None,
-    truncation_strategy: Literal['delete', 'truncation_left'] = 'delete'
+    truncation_strategy: Literal['delete', 'truncation_left'] = 'delete',
+    **kwargs,
 ) -> Template:
     template_info = TEMPLATE_MAPPING[template_type]
     template = deepcopy(template_info['template'])
     template._init_template(tokenizer, default_system, max_length,
-                            truncation_strategy)
+                            truncation_strategy, **kwargs)
     template.template_type = template_type
     return template
diff --git a/swift/llm/utils/utils.py b/swift/llm/utils/utils.py
index 95d207a362..f13283b60e 100644
--- a/swift/llm/utils/utils.py
+++ b/swift/llm/utils/utils.py
@@ -342,6 +342,13 @@ def data_collate_fn(batch: List[Dict[str, Any]],
             get_audio_info(tokenizer, audio_info=b['audio_info'])
             for b in batch
         ]
+    if batch[0].get('images') is not None:
+        res['images'] = [b['images'] for b in batch]
+    if batch[0].get('cross_images') is not None:
+        res['cross_images'] = [b['cross_images'] for b in batch]
+    if batch[0].get('token_type_ids') is not None:
+        res['token_type_ids'] = torch.stack(
+            [b['token_type_ids'] for b in batch])
     return res
 
 
@@ -446,6 +453,7 @@ def inference_stream(
     query: str,
     history: Optional[History] = None,
     system: Optional[str] = None,
+    image: Optional['Image'] = None,
     *,
     generation_config: Optional[GenerationConfig] = None,
     stop_words: Optional[List[StopWords]] = None,
@@ -460,13 +468,18 @@ def inference_stream(
     else:
         history = deepcopy(history)
     example = {'query': query, 'history': history, 'system': system}
+    if image is not None:
+        example['image'] = image
     inputs = template.encode(example)
     audio_info = inputs.get('audio_info')  # Compatible with qwen-audio
     input_ids = inputs['input_ids']
     tokenizer = template.tokenizer
     device = next(model.parameters()).device
     input_ids = torch.tensor(input_ids)[None].to(device)
-    attention_mask = torch.ones_like(input_ids).to(device)
+    if 'attention_mask' not in inputs:
+        attention_mask = torch.ones_like(input_ids).to(device)
+    else:
+        attention_mask = inputs['attention_mask'].to(device)
     model.eval()
     if generation_config is None:
         generation_config = getattr(model, 'generation_config', None)
@@ -487,6 +500,16 @@ def inference_stream(
         stop_words.append(template.suffix[-1])
     decode_kwargs = {}
     model_kwargs = {}
+    if 'token_type_ids' in inputs:
+        model_kwargs['token_type_ids'] = inputs['token_type_ids'].to(device)
+    if 'images' in inputs:
+        model_kwargs['images'] = [[
+            inputs['images'][0][0].to(device).to(torch.float16)
+        ]]
+    if 'cross_images' in inputs:
+        model_kwargs['cross_images'] = [[
+            inputs['cross_images'][0][0].to(device).to(torch.float16)
+        ]]
     if audio_info is not None:
         audio_info = get_audio_info(tokenizer, audio_info=audio_info)
         decode_kwargs['audio_info'] = audio_info
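
For reference, a minimal usage sketch of how the pieces added by this patch fit together; this is an illustration, not part of the diff. It assumes the existing swift.llm helpers get_model_tokenizer, get_template and inference_stream behave as used elsewhere in the repo; the image file 'captcha.png', the sample response 'abc123' and the device_map setting are placeholders.

from PIL import Image

from swift.llm import (ModelType, TemplateType, get_model_tokenizer,
                       get_template, inference_stream)

# Load cogagent-chat; per the patch, the tokenizer is pulled from the separate
# AI-ModelScope/vicuna-7b-v1.5 repo by get_model_tokenizer_from_repo_cogagent.
model, tokenizer = get_model_tokenizer(
    ModelType.cogagent_chat, model_kwargs={'device_map': 'auto'})

# get_template now forwards extra kwargs to _init_template, so the model can be
# handed to CogAgentTemplate, which reads vision_config['image_size'],
# cross_image_size and patch_size from it when building inputs.
template = get_template(
    TemplateType.cogagent, tokenizer, max_length=4096, model=model)

# Training-style encoding: a capcha-images row after preprocessing carries
# 'query' (always 'CAPTCHA:'), 'response' (renamed from 'solution') and 'image'.
row = {
    'query': 'CAPTCHA:',
    'response': 'abc123',                  # placeholder label
    'image': Image.open('captcha.png'),    # placeholder image file
}
encoded = template.encode(row)
# Expected keys on the train path: input_ids, token_type_ids, attention_mask,
# images, cross_images, labels (prompt and vision positions masked with -100).
print(list(encoded.keys()))

# Streaming inference: the image is passed through inference_stream into
# example['image'] and from there into CogAgentTemplate.encode.
response = None
for response, _ in inference_stream(
        model, template, 'CAPTCHA:', image=Image.open('captcha.png')):
    pass
print(response)

Note that the template keeps the preprocessed image tensors on CPU; inference_stream moves them to the model device and casts them to float16, as the utils.py hunk above shows.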