Merged
23 commits
618eaaa
add transformers in gitignore
DaozeZhang Sep 10, 2024
8fd4051
fix a typo bug in text-caps
DaozeZhang Sep 11, 2024
f7a3eea
Merge branch 'modelscope:main' into main
DaozeZhang Sep 20, 2024
add49e5
add .run into gitignore
DaozeZhang Sep 20, 2024
4f43621
Merge commit 'f7a3eeaf0c57a100e985276fa783239f21fa9ee6'
DaozeZhang Sep 20, 2024
40d3ad0
add vlmeval to gitignore
DaozeZhang Sep 30, 2024
a3ae3ca
Merge commit '9bae58c9dc3015855bcd6a75794eb9c3947cb476'
DaozeZhang Sep 30, 2024
818d0fd
add my_model/ to gitignore
DaozeZhang Oct 18, 2024
619a4bd
Merge commit '704381f2e27b29c94b97b7c12d1e8a2dd886fdff'
DaozeZhang Oct 18, 2024
6bbd9f7
Merge branch 'modelscope:main' into main
DaozeZhang Oct 21, 2024
9ce9e9f
Merge branch 'modelscope:main' into main
DaozeZhang Oct 22, 2024
b7afb5e
Merge commit 'd42f8b5c3d3af7ee7ef58ff3ebdcb8ed790afb68'
DaozeZhang Oct 23, 2024
b22f7f3
Merge branch 'modelscope:main' into main
DaozeZhang Oct 23, 2024
9ef7838
Merge commit '44302aba29a8909892322ec18dfdf7398c5cec88'
DaozeZhang Oct 31, 2024
7d45ea3
Merge branch 'modelscope:main' into main
DaozeZhang Oct 31, 2024
e4d420e
Merge commit '53790bd87fb1327f83837a779a55d90e029101d4'
DaozeZhang Nov 8, 2024
a2a5c15
Merge branch 'modelscope:main' into main
DaozeZhang Feb 17, 2025
3017915
Merge commit 'a2a5c15cb2da3e75d11581230e0c310d7c4f5e5b'
DaozeZhang Feb 18, 2025
8926d3d
Merge commit '96eeecc6af0bf3fcfbec564eb6f5d66a8c8e0593'
DaozeZhang Feb 19, 2025
3b721f6
support generation using Janus Pro
DaozeZhang Feb 21, 2025
4d4037a
update comment
DaozeZhang Feb 21, 2025
61d5842
change some var name, add test_gene.py
DaozeZhang Feb 21, 2025
eac536f
change format
DaozeZhang Feb 21, 2025
16 changes: 15 additions & 1 deletion swift/llm/template/base.py
@@ -33,7 +33,7 @@ class MaxLengthError(ValueError):


class Template(ProcessorMixin):
special_tokens = ['<image>', '<video>', '<audio>', '<bbox>', '<ref-object>', '<cot-process>']
special_tokens = ['<image>', '<video>', '<audio>', '<bbox>', '<ref-object>', '<cot-process>', '<start-image>']
special_keys = ['images', 'videos', 'audios', 'objects']

image_placeholder = ['<image>']
@@ -173,6 +173,7 @@ def _preprocess_inputs(
) -> None:
if self.model_meta.is_multimodal:
self._replace_image_tags(inputs)
self._replace_start_image_tags(inputs)
images = inputs.images
load_images = self.load_images or self.mode in {'vllm', 'lmdeploy'}
load_images_origin = load_images
@@ -217,6 +218,19 @@ def _replace_image_tags(inputs: StdTemplateInputs):
assert not inputs.images, f'images: {images}, inputs.images: {inputs.images}'
inputs.images = images

@staticmethod
def _replace_start_image_tags(inputs: StdTemplateInputs):
# compat
generate_mode = False
for message in inputs.messages:
content = message['content']
if not isinstance(content, str):
continue
if content.strip().endswith('<start-image>'):
Collaborator:

Drop the strip() here.

Collaborator Author (@DaozeZhang, Feb 21, 2025):

For this case, the user might put <start-image> at the end and then add a few trailing spaces (I often do that myself 😂). Do we need to handle that?

generate_mode = True
message['content'] = re.sub('<start-image>', '', content).strip() # remove the <start-image>
Collaborator (@Jintao-Huang, Feb 21, 2025):

Suggestion: message['content'][:-len('<start-image>')]
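For reference, a tiny illustrative sketch of the trade-off discussed in this thread (the example string is made up, not part of the diff):

import re

content = 'Draw a cat <start-image>  '  # user left trailing spaces after the tag
print(re.sub('<start-image>', '', content).strip())  # 'Draw a cat' (current PR behavior)
print(content[:-len('<start-image>')])  # 'Draw a cat <s', so the slice only works if content ends exactly with the tag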

inputs.generate_mode = generate_mode

def _rlhf_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
chosen_inputs, rejected_inputs = inputs, deepcopy(inputs)
assert chosen_inputs.rejected_response is not None, f'inputs: {inputs}'
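Based on the _replace_start_image_tags change above, a request switches to image generation when a message's content ends with <start-image>. A minimal sketch of such a request (the prompt text is invented for illustration):

messages = [{'role': 'user', 'content': 'A watercolor lighthouse at dawn <start-image>'}]
# During preprocessing the tag is stripped and inputs.generate_mode is set to True,
# so the text that reaches the model is just 'A watercolor lighthouse at dawn'.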
195 changes: 160 additions & 35 deletions swift/llm/template/template/deepseek.py
@@ -1,10 +1,14 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

import numpy as np
import torch
import torch.nn as nn
from PIL import Image

from swift.utils import get_env_args
from ..base import Template
from ..constant import LLMTemplateType, MLLMTemplateType
from ..register import TemplateMeta, register_template
@@ -40,6 +44,9 @@ class DeepseekTemplateMeta(TemplateMeta):
class DeepseekVLTemplate(Template):
image_placeholder = ['<image_placeholder>']
skip_prompt = False
use_model = True

image_token_num_per_image: int = 576

def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
is_janus = getattr(self, 'is_janus', False)
@@ -48,48 +55,166 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
images = inputs.images
processor = self.processor
input_ids, labels = encoded['input_ids'], encoded['labels']
idx_list = findall(input_ids, processor.image_id) # '<image_placeholder>'
new_input_ids, new_labels = [], []
lo = 0
for hi in idx_list:
new_input_ids += input_ids[lo:hi]

if not inputs.generate_mode: # understanding task
idx_list = findall(input_ids, processor.image_id) # '<image_placeholder>'
new_input_ids, new_labels = [], []
lo = 0
for hi in idx_list:
new_input_ids += input_ids[lo:hi]
if labels is not None:
new_labels += labels[lo:hi]
image_tokens = [processor.image_id] * processor.num_image_tokens
if is_janus:
image_tokens = [processor.image_start_id] + image_tokens + [processor.image_end_id]
new_input_ids += image_tokens
new_labels += [-100] * len(image_tokens)
lo = hi + 1
new_input_ids += input_ids[lo:]
if labels is not None:
new_labels += labels[lo:hi]
image_tokens = [processor.image_id] * processor.num_image_tokens
new_labels += labels[lo:]
else:
new_labels = None
if is_janus:
image_tokens = [processor.image_start_id] + image_tokens + [processor.image_end_id]
new_input_ids += image_tokens
new_labels += [-100] * len(image_tokens)
lo = hi + 1
new_input_ids += input_ids[lo:]
if labels is not None:
new_labels += labels[lo:]
from janus.models.processing_vlm import VLChatProcessorOutput
else:
from deepseek_vl.models.processing_vlm import VLChatProcessorOutput

images_outputs = processor.image_processor(images, return_tensors='pt')
output = VLChatProcessorOutput(
sft_format=None,
input_ids=torch.tensor(new_input_ids),
pixel_values=images_outputs.pixel_values,
num_image_tokens=torch.tensor([processor.num_image_tokens] * len(idx_list)))
encoded = {'output': output, 'input_ids': new_input_ids, 'labels': new_labels}
return encoded

else: # image generation task
if self.is_training:
raise NotImplementedError('Image generation is only supported at inference time for Janus-series models.')
sft_format = self.tokenizer.decode(input_ids)
prompt = sft_format + processor.image_start_tag
input_ids = processor.tokenizer.encode(prompt)
input_ids = torch.LongTensor(input_ids)

encoded = {'input_ids': input_ids, 'labels': labels, 'generate_mode': inputs.generate_mode}
return encoded

def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
if not inputs.get('generate_mode'):
inputs['pixel_values'] = inputs['pixel_values'].to(dtype=self.config.torch_dtype)
inputs_embeds = model.prepare_inputs_embeds(**inputs)
return {'inputs_embeds': inputs_embeds}
else:
new_labels = None
if is_janus:
from janus.models.processing_vlm import VLChatProcessorOutput
return inputs

def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
gene_img_list = [b.get('generate_mode') for b in batch]
if all(gene_img_list):
generate_mode = True
elif not any(gene_img_list):
generate_mode = False
else:
from deepseek_vl.models.processing_vlm import VLChatProcessorOutput
raise NotImplementedError('Mixing understanding and image-generation tasks in one batch is not supported.')

images_outputs = processor.image_processor(images, return_tensors='pt')
output = VLChatProcessorOutput(
sft_format=None,
input_ids=torch.tensor(new_input_ids),
pixel_values=images_outputs.pixel_values,
num_image_tokens=torch.tensor([processor.num_image_tokens] * len(idx_list)))
encoded = {'output': output, 'input_ids': new_input_ids, 'labels': new_labels}
return encoded
if not generate_mode:
output = self.fetch_inputs(batch, ['output'])['output']
batched_output = dict(self.processor.batchify(output))
res = super()._data_collator(batch, padding_to=padding_to)
return {**batched_output, **res}
else:
res = super()._data_collator(batch, padding_to=padding_to)
res['generate_mode'] = generate_mode
return res

def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
inputs['pixel_values'] = inputs['pixel_values'].to(dtype=self.config.torch_dtype)
inputs_embeds = model.prepare_inputs_embeds(**inputs)
return {'inputs_embeds': inputs_embeds}
def generate(self, model, *args, **kwargs):
if not kwargs.get('generate_mode'):
return model.generate(*args, **kwargs)

def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
output = self.fetch_inputs(batch, ['output'])['output']
batched_output = dict(self.processor.batchify(output))
res = super()._data_collator(batch, padding_to=padding_to)
return {**batched_output, **res}
else:
# number of images to generate per prompt; named parallel_size in the original Janus code
parallel_size = kwargs['generation_config'].num_return_sequences
temperature = kwargs['generation_config'].temperature
cfg_weight = get_env_args('cfg_weight', float, 5.0)

input_ids = kwargs['input_ids'] # [bsz, max_input_token_num]
bsz, max_input_token_num = input_ids.shape
tokens = torch.zeros((bsz, parallel_size * 2, max_input_token_num),
dtype=torch.int).cuda() # [bsz, parallel_size*2, max_input_token_num]
for i in range(parallel_size * 2):
tokens[:, i, :] = input_ids
if i % 2 != 0:
tokens[:, i, 1:-1] = self.processor.pad_id

inputs_embeds = model.language_model.get_input_embeddings()(
tokens) # [bsz, parallel_size*2, max_input_token_num, 2048]

generated_tokens = torch.zeros(
(bsz, parallel_size, self.image_token_num_per_image),
dtype=torch.int).cuda() # [bsz, parallel_size, image_token_num_per_image], placeholder for the generated tokens

# flatten the first two dimensions into a single batch dimension
inputs_embeds = inputs_embeds.reshape(bsz * parallel_size * 2, max_input_token_num, -1)
generated_tokens = generated_tokens.reshape(bsz * parallel_size, self.image_token_num_per_image)

for i in range(self.image_token_num_per_image): # generate the image tokens autoregressively
outputs = model.language_model.model(
inputs_embeds=inputs_embeds,
use_cache=True,
past_key_values=outputs.past_key_values if i != 0 else None)
hidden_states = outputs.last_hidden_state

logits = self.model.gen_head(hidden_states[:, -1, :])
logit_cond = logits[0::2, :]
logit_uncond = logits[1::2, :]

logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
probs = torch.softmax(logits / temperature, dim=-1)

next_token = torch.multinomial(probs, num_samples=1)
generated_tokens[:, i] = next_token.squeeze(dim=-1) # generated_tokens: [bsz * parallel_size, image_token_num_per_image]

next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
img_embeds = model.prepare_gen_img_embeds(next_token) # [parallel_size * 2, 2048]
inputs_embeds = img_embeds.unsqueeze(dim=1) # [parallel_size * 2, 1, 2048]

# the original first two dimensions are not restored here; left for the upper layer to handle
# inputs_embeds = inputs_embeds.reshape(bsz, parallel_size*2, -1)
# generated_tokens = generated_tokens.reshape(bsz, parallel_size, self.image_token_num_per_image)

return {'sequences': generated_tokens}

def decode(self, generate_ids: List[int], is_finished: bool = True, tokenizer_kwargs=None, **kwargs) -> Any:
if not kwargs['template_inputs'].generate_mode:
return super().decode(generate_ids, is_finished, tokenizer_kwargs, **kwargs)

else:
img_size = get_env_args('img_size', int, 384)
patch_size = 16

num_to_decode = 1 # for now, generate_ids is a 1D list

generate_ids = torch.tensor(generate_ids).unsqueeze(0) # [num_to_decode=1, self.image_token_num_per_image]

dec = self.model.gen_vision_model.decode_code(
generate_ids.to(dtype=torch.int),
shape=[num_to_decode, 8, img_size // patch_size, img_size // patch_size])
dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1) # [num_to_decode, H, W, ch=3]

dec = np.clip((dec + 1) / 2 * 255, 0, 255)

visual_img = np.zeros((num_to_decode, img_size, img_size, 3), dtype=np.uint8)
visual_img[:, :, :] = dec

img_list = []
for i in range(num_to_decode):
cur_img = Image.fromarray(visual_img[i])
img_list.append({'type': 'image', 'image': cur_img})

os.makedirs('generated_images', exist_ok=True)
cur_img.save(os.path.join('generated_images', f'img_{i}.jpg'))

return img_list


@dataclass
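The generate override above applies classifier-free guidance: every prompt is duplicated into a conditional copy and an unconditional (pad-filled) copy, and their logits are mixed before sampling. A minimal standalone sketch of that mixing step (the function name and tensor layout are illustrative, not the PR's exact variables):

import torch

def cfg_mix(logits: torch.Tensor, cfg_weight: float = 5.0) -> torch.Tensor:
    # logits: [2 * n_images, vocab]; even rows are conditional, odd rows unconditional
    logit_cond = logits[0::2, :]
    logit_uncond = logits[1::2, :]
    return logit_uncond + cfg_weight * (logit_cond - logit_uncond)

# Sampling then follows the diff:
#   probs = torch.softmax(cfg_mix(logits) / temperature, dim=-1)
#   next_token = torch.multinomial(probs, num_samples=1)

Since cfg_weight and img_size are read through get_env_args, they can presumably be overridden via environment variables rather than code changes.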
32 changes: 32 additions & 0 deletions tests/test_align/test_template/test_gene.py
@@ -0,0 +1,32 @@
import os

import torch

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['SWIFT_DEBUG'] = '1'


def test_deepseek_janus_pro_gene():
from swift.llm import infer_main, InferArguments
args = InferArguments(
# model='deepseek-ai/Janus-Pro-1B',
model='/mnt/nas1/.cache/modelscope/hub/deepseek-ai/Janus-Pro-1B',
Collaborator:

Fix the path.

infer_backend='pt')
infer_main(args)


def test_emu3_gen(infer_backend):
from swift.llm import infer_main, InferArguments
args = InferArguments(
model='BAAI/Emu3-Gen',
infer_backend=infer_backend,
stream=False,
use_chat_template=False,
top_k=2048,
max_new_tokens=40960)
infer_main(args)


if __name__ == '__main__':
# test_emu3_gen('pt')
test_deepseek_janus_pro_gene()
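Putting the pieces together, image generation appears to be triggered purely by the prompt: a request that ends with <start-image> is routed to the generation branch, and decoded images are written to generated_images/. A minimal usage sketch (the model id and prompt are illustrative; substitute a local path if the weights are cached locally):

from swift.llm import InferArguments, infer_main

args = InferArguments(model='deepseek-ai/Janus-Pro-1B', infer_backend='pt')
infer_main(args)
# At the interactive prompt, a request such as
#   'A cat sitting on a windowsill at sunset <start-image>'
# switches the template into the image-generation branch.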