Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion diffusion/janus/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@
conversations=conversation, images=pil_images, force_batchify=True
).to(vl_gpt.device)


# # run image encoder to get the image embeddings
inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)

# # run the model to get the response
outputs = vl_gpt.language_model.generate(
inputs_embeds=inputs_embeds,
Expand Down
114 changes: 84 additions & 30 deletions mindnlp/core/_functorch/apis.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import functools
from typing import Callable
from mindnlp import core

from .vmap import (
_check_randomness_arg,
vmap_impl,
)

def vmap(
func: Callable,
Expand All @@ -9,43 +14,92 @@ def vmap(
*,
chunk_size=None,
) -> Callable:
# from torch.compiler import is_compiling

# _check_randomness_arg(randomness)
# if not (chunk_size is None or chunk_size > 0):
# raise ValueError(
# f"vmap: chunk_size should be None or greater than 0. (got {chunk_size})"
# )

# def wrapped(*args, **kwargs):
# return vmap_impl(
# func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs
# )

# if not is_compiling():
# wrapped = functools.wraps(func)(wrapped)

# return wrapped
def batched_func(*args):
# 统一处理in_dims格式
if not isinstance(in_dims, tuple):
# 标准化输入维度参数
if isinstance(in_dims, int):
in_dims_tuple = (in_dims,) * len(args)
else:
in_dims_tuple = in_dims

# 验证输入维度一致性
batch_sizes = set()
for i, (arg, dim) in enumerate(zip(args, in_dims_tuple)):
if len(in_dims_tuple) != len(args):
raise ValueError(f"输入的in_dims长度({len(in_dims_tuple)})与参数数量({len(args)})不匹配")

# 识别并验证批处理大小
batch_size = None
for i, dim in enumerate(in_dims_tuple):
if dim is not None:
batch_sizes.add(arg.shape[dim])
if batch_size is None:
batch_size = args[i].shape[dim]
elif args[i].shape[dim] != batch_size:
raise ValueError(f"不一致的批处理大小: "
f"参数 {i} 有大小 {args[i].shape[dim]}, "
f"期望 {batch_size}")

# 如果没有批处理维度,设置批处理大小为1
if batch_size is None:
batch_size = 1

# 重新排列所有输入,使批处理维度位于第0位
reordered_args = []
reshaped_shapes = []

if len(batch_sizes) > 1:
raise ValueError(f"不一致的批处理大小: {batch_sizes}")
batch_size = next(iter(batch_sizes)) if batch_sizes else 1
for arg, dim in zip(args, in_dims_tuple):
if dim is None:
# 无批处理维度:添加伪批处理维度并进行广播
# 使用unsqueeze而不是expand来保持梯度
expanded = arg.unsqueeze(0)
if batch_size > 1:
expanded = expanded.expand(batch_size, *[-1]*arg.ndim)
reordered_args.append(expanded)
reshaped_shapes.append(None) # 标记为无原始维度
else:
# 有批处理维度:将其移动到维度0
# 创建新维度顺序: [dim, 0, 1, ..., dim-1, dim+1, ...]
dims_order = [dim] + [d for d in range(arg.ndim) if d != dim]
permuted = arg.permute(*dims_order)
reordered_args.append(permuted)
reshaped_shapes.append(arg.shape) # 保存原始形状

# 处理函数可能返回元组的情况
result = func(*reordered_args)

# 收集单个样本的结果
results = []
for b in range(batch_size):
# 为当前批次构造输入
single_args = []
for arg, dim in zip(args, in_dims_tuple):
if dim is None:
single_args.append(arg)
else:
# 切片获取当前批次的样本
slices = [slice(None)] * arg.ndim
slices[dim] = b
single_args.append(arg[tuple(slices)])
# 调整输出维度的函数
def adjust_out_dims(tensor):
if tensor.size(0) != batch_size:
# 如果函数返回了标量或没有批处理维度
return tensor.unsqueeze(out_dims).expand(
*[batch_size if i == out_dims else -1
for i in range(tensor.ndim + 1)]
)

# 调用原始函数
result = func(*single_args)
results.append(result)
if out_dims == 0:
return tensor

# 创建将维度0移动到out_dims位置的新顺序
new_order = list(range(1, tensor.ndim))
new_order.insert(out_dims, 0)
return tensor.permute(*new_order)

# 堆叠结果并调整维度
stacked = core.stack(results, dim=out_dims)
return stacked
# 处理不同类型输出
if isinstance(result, tuple):
return tuple(adjust_out_dims(r) for r in result)
else:
return adjust_out_dims(result)

return batched_func
Loading
Loading