From 051e3574aa0fc6923a3c6ab9a4405679149f8e8b Mon Sep 17 00:00:00 2001
From: lvyufeng
Date: Tue, 2 Sep 2025 17:55:19 +0800
Subject: [PATCH] fix APIs for Gemma3 (g-class) models

---
 mindnlp/{factory => }/cli.py                  |   0
 mindnlp/core/_functorch/apis.py               |   1 -
 mindnlp/core/_prims/ascend.py                 |   8 +
 mindnlp/core/_prims/meta.py                   |  14 +-
 mindnlp/core/_prims/numpy.py                  |  66 +++++++-
 mindnlp/core/_tensor.py                       |  17 +-
 mindnlp/core/nn/attention/__init__.py         |   7 +-
 mindnlp/core/nn/functional.py                 |  47 ++++++
 mindnlp/core/ops/creation.py                  |  14 +-
 mindnlp/core/ops/reduction.py                 |   2 +-
 mindnlp/core/special/__init__.py              |   4 +-
 mindnlp/factory/__init__.py                   |   0
 mindnlp/transformers/__init__.py              |  18 ++-
 mindnlp/transformers/masking_utils.py         | 150 ++++++++++++++++--
 .../transformers/models/gemma3/__init__.py    |  36 +++++
 15 files changed, 349 insertions(+), 35 deletions(-)
 rename mindnlp/{factory => }/cli.py (100%)
 delete mode 100644 mindnlp/factory/__init__.py
 create mode 100644 mindnlp/transformers/models/gemma3/__init__.py

diff --git a/mindnlp/factory/cli.py b/mindnlp/cli.py
similarity index 100%
rename from mindnlp/factory/cli.py
rename to mindnlp/cli.py
diff --git a/mindnlp/core/_functorch/apis.py b/mindnlp/core/_functorch/apis.py
index 943849fa2..bce5bb20e 100644
--- a/mindnlp/core/_functorch/apis.py
+++ b/mindnlp/core/_functorch/apis.py
@@ -1,5 +1,4 @@
 from typing import Callable
-
 import mindspore
 
 def vmap(
diff --git a/mindnlp/core/_prims/ascend.py b/mindnlp/core/_prims/ascend.py
index e07834b94..45ff6758f 100644
--- a/mindnlp/core/_prims/ascend.py
+++ b/mindnlp/core/_prims/ascend.py
@@ -396,3 +396,11 @@ def cross(input, other, dim=None, *, out=None):
     return pyboost_inner_prim.cross_impl(input, other, dim)
 
 __all__.append('cross')
+
+def logit(input, eps):
+    if eps is None:
+        eps = -1.0  # a negative eps disables input clamping in ops.Logit
+    logit_ = _get_cache_prim(ops.Logit)(eps).set_device('Ascend')
+    return logit_(input)
+
+__all__.append('logit')
diff --git a/mindnlp/core/_prims/meta.py b/mindnlp/core/_prims/meta.py
index 396e8471b..bd5098b54 100644
--- a/mindnlp/core/_prims/meta.py
+++ b/mindnlp/core/_prims/meta.py
@@ -65,7 +65,9 @@ def getitem(input, slice):
 __all__.append('getitem')
 
 def sub_ext(input, other, alpha):
-    return input
+    if isinstance(input, core.Tensor):
+        return input
+    return other
 
 __all__.append('sub_ext')
 
@@ -341,3 +343,13 @@ def reverse_v2(input, dims):
     return input
 
 __all__.append('reverse_v2')
+
+def rsqrt(input):
+    return input
+
+__all__.append('rsqrt')
+
+def bitwise_xor_tensor(input, other):
+    return input
+
+__all__.append('bitwise_xor_tensor')
diff --git a/mindnlp/core/_prims/numpy.py b/mindnlp/core/_prims/numpy.py
index 8efdfcffb..bcd2a09f4 100644
--- a/mindnlp/core/_prims/numpy.py
+++ b/mindnlp/core/_prims/numpy.py
@@ -753,10 +753,13 @@ def std(input, dim, correction, keepdim):
 __all__.append('std')
 
 def meshgrid(tensors, indexing):
-    out = np.meshgrid(*[t.numpy() for t in tensors], indexing=indexing)
-    if not isinstance(out, np.ndarray):
-        out = np.array(out)
-    return core.Tensor.from_numpy(out)
+    outs = np.meshgrid(*[t.numpy() for t in tensors], indexing=indexing)
+    new_outs = ()
+    for out in outs:
+        if not isinstance(out, np.ndarray):
+            out = np.array(out)
+        new_outs += (core.Tensor.from_numpy(out),)
+    return new_outs
 
 __all__.append('meshgrid')
 
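
The `meshgrid` rewrite matters because `np.meshgrid` returns one array per input: wrapping that list in a single `np.array` (the old behavior) stacks the grids into one array with an extra leading axis, while the torch-style contract this backend mimics is a tuple of tensors. A minimal numpy-only sketch of the difference:

```python
import numpy as np

x, y = np.arange(3), np.arange(2)
outs = np.meshgrid(x, y, indexing="ij")  # list of two (3, 2) arrays

stacked = np.array(outs)   # old behavior: a single (2, 3, 2) array
as_tuple = tuple(outs)     # new behavior: one array per input

print(stacked.shape)                 # (2, 3, 2)
print([o.shape for o in as_tuple])   # [(3, 2), (3, 2)]
```
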
@@ -809,3 +812,58 @@ def reverse_v2(input, dims):
     return core.Tensor.from_numpy(out)
 
 __all__.append('reverse_v2')
+
+def rsqrt(input):
+    out = np.reciprocal(np.sqrt(input.numpy()))
+    if not isinstance(out, np.ndarray):
+        out = np.array(out)
+    return core.Tensor.from_numpy(out)
+
+__all__.append('rsqrt')
+
+def bitwise_xor_tensor(input, other):
+    out = np.bitwise_xor(input.numpy(), other.numpy())
+    return core.Tensor.from_numpy(out)
+
+__all__.append('bitwise_xor_tensor')
+
+def minimum(input, other):
+    out = np.minimum(input.numpy(), other.numpy())
+    return core.Tensor.from_numpy(out)
+
+__all__.append('minimum')
+
+def prod_ext(input, dim, keepdim, dtype):
+    out = np.prod(input.numpy(), axis=dim, keepdims=keepdim)  # NOTE: `dtype` is accepted for API parity but not applied here
+    return core.Tensor.from_numpy(out)
+
+__all__.append('prod_ext')
+
+def select(condition, input, other):
+    if not isinstance(input, numbers.Number):
+        input = input.numpy()
+    if not isinstance(other, numbers.Number):
+        other = other.numpy()
+
+    out = np.where(condition.numpy(), input, other)
+    return core.Tensor.from_numpy(out)
+
+__all__.append('select')
+
+def dense(input, weight, bias):
+    output = np.dot(input.numpy(), weight.numpy().T)
+    if bias is not None:
+        output += bias.numpy()
+    return core.Tensor.from_numpy(output)
+
+__all__.append('dense')
+
+def dropout_ext(input, p):
+    if p != 0:
+        mask = (np.random.rand(*input.shape) < (1 - p))
+        out = input.numpy() * mask / (1 - p)
+        return core.Tensor.from_numpy(out), core.Tensor.from_numpy(mask)
+    else:
+        return input, None
+
+__all__.append('dropout_ext')
diff --git a/mindnlp/core/_tensor.py b/mindnlp/core/_tensor.py
index 81f42d2fc..841b5d5e9 100644
--- a/mindnlp/core/_tensor.py
+++ b/mindnlp/core/_tensor.py
@@ -196,7 +196,7 @@ def __len__(self):
         return self.shape[0]
 
     def __repr__(self) -> str:
-        self.data_sync(True)
+        # self.data_sync(True)
         return Tensor_.__repr__(self)[:-1] + f', device={self.device})'
 
     def __format__(self, format_spec):
@@ -982,8 +982,8 @@ def diagnoal(self, offset=0, dim1=0, dim2=1):
 
     # Tensor.div
-    def div(self, other):
-        return ops.div(self, other)
+    def div(self, other, rounding_mode=None):
+        return ops.div(self, other, rounding_mode=rounding_mode)
 
     # Tensor.div_
     def div_(self, other):
@@ -1257,13 +1257,18 @@ def index_add_(self, dim, index, source, *, alpha=1):
 
     # Tensor.index_add
     def index_add(self, dim, index, source, *, alpha=1):
-        return ops.index_add(self, dim, source, alpha=alpha)
+        return ops.index_add(self, dim, index, source, alpha=alpha)
 
     # Tensor.index_copy_
-
+    def index_copy_(self, dim, index, tensor2):
+        return self.copy_(self.index_copy(dim, index, tensor2))
 
     # Tensor.index_copy
-
+    def index_copy(self, dim, index, tensor2):
+        original_values_at_index = self.index_select(dim, index)  # assumes `index` has no repeated entries
+        result = self.index_add(dim, index, -original_values_at_index)
+        result.index_add_(dim, index, tensor2)
+        return result
 
     # Tensor.index_fill_
diff --git a/mindnlp/core/nn/attention/__init__.py b/mindnlp/core/nn/attention/__init__.py
index 7d0e25b2d..508d79c18 100644
--- a/mindnlp/core/nn/attention/__init__.py
+++ b/mindnlp/core/nn/attention/__init__.py
@@ -1,5 +1,8 @@
+import contextlib
+
 class SDPBackend:
-    pass
+    MATH = 0
 
+@contextlib.contextmanager
 def sdpa_kernel(*args, **kwargs):
-    pass
+    yield {}
\ No newline at end of file
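
The attention stub above mirrors the shape of `torch.nn.attention.sdpa_kernel`, which callers enter as a context manager to pin an SDPA backend; making it a real (no-op) context manager keeps such code importable and runnable. A usage sketch, assuming mindnlp is installed:

```python
from mindnlp.core.nn.attention import SDPBackend, sdpa_kernel

# Code written against the torch API typically looks like this; under
# mindnlp the context manager is a no-op that simply yields {}.
with sdpa_kernel(SDPBackend.MATH):
    ...  # attention calls go here
```
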
diff --git a/mindnlp/core/nn/functional.py b/mindnlp/core/nn/functional.py
index 638457a67..976066d93 100644
--- a/mindnlp/core/nn/functional.py
+++ b/mindnlp/core/nn/functional.py
@@ -1590,9 +1590,56 @@ def pixel_shuffle(input, upscale_factor):
 def pixel_unshuffle(input, downscale_factor):
     return ops.pixel_unshuffle(input, downscale_factor)
 
+def getWH(input):
+    """Get a [W, H] tensor from the input image."""
+    H, W = input.size()[-2:]
+    return core.tensor([[W, H]], dtype=core.float32, device=input.device)
+
+def center_of(input):
+    """Return the [(W-1)/2, (H-1)/2] center tensor of the input image."""
+    if input.dim() == 4:
+        H, W = input.size()[-2:]
+        shape = [[W, H]]
+    else:
+        D, H, W = input.size()[-3:]
+        shape = [[W, H, D]]
+    return core.tensor(shape, dtype=core.float32, device=input.device).sub_(1).div_(2)
+
+def u(s, a: float = -0.75):  # cubic convolution (Keys) kernel; a=-0.75 matches PyTorch's bicubic
+    s2, s3 = s**2, s**3
+    l1 = (a+2)*s3 - (a+3)*s2 + 1          # |s| <= 1 branch
+    l2 = a*s3 - (5*a)*s2 + (8*a)*s - 4*a  # 1 < |s| < 2 branch
+    return l1.where(s <= 1, l2)
+
+def bicubic_grid_sample(input, grid, padding_mode: str = 'zeros', align_corners: bool = False):
+    """Bicubic grid_sample for 4D inputs, built on nearest-mode grid_sample."""
+    kernel_size = 4
+    if not align_corners:
+        grid = grid * getWH(input) / getWH(input).sub_(1)
+    center = center_of(input)
+    abs_loc = ((grid + 1) * center).unsqueeze(-1)
+
+    locs = abs_loc.floor() + core.tensor([-1, 0, 1, 2], device=grid.device)
+
+    loc_w, loc_h = locs.detach().flatten(0, 2).unbind(dim=-2)
+    loc_w = loc_w.reshape(-1, 1, kernel_size).expand(-1, kernel_size, -1)
+    loc_h = loc_h.reshape(-1, kernel_size, 1).expand(-1, -1, kernel_size)
+    loc_grid = core.stack([loc_w, loc_h], dim=-1)
+    loc_grid = loc_grid.view(grid.size(0), -1, 1, 2)/center - 1
+
+    selected = grid_sample(input, loc_grid.detach(), mode='nearest',
+                           padding_mode=padding_mode, align_corners=True)
+    patch = selected.view(input.size()[:2]+grid.size()[1:3]+(kernel_size,)*2)
+
+    mat_r, mat_l = u(core.abs(abs_loc - locs.detach())).unbind(dim=-2)
+    output = core.einsum('bhwl,bchwlr,bhwr->bchw', mat_l, patch, mat_r)
+    return output
+
 def grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=False):
     align_corners = False if align_corners is None else align_corners
     if input.ndim == 4:
+        if mode == 'bicubic':
+            return bicubic_grid_sample(input, grid, padding_mode, align_corners)
         return execute('grid_sampler_2d', input, grid, mode, padding_mode, align_corners)
     return execute('grid_sampler_3d', input, grid, mode, padding_mode, align_corners)
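
`u` above is the cubic convolution kernel (Keys, 1981) with a = -0.75, the same constant PyTorch uses for bicubic interpolation; `bicubic_grid_sample` gathers a 4x4 neighborhood with nearest-mode `grid_sample` and blends it with these weights. A numpy sketch checking the kernel's expected properties:

```python
import numpy as np

def u(s, a=-0.75):
    s2, s3 = s**2, s**3
    l1 = (a + 2) * s3 - (a + 3) * s2 + 1          # |s| <= 1
    l2 = a * s3 - 5 * a * s2 + 8 * a * s - 4 * a  # 1 < |s| < 2
    return np.where(s <= 1, l1, l2)

print(u(np.array([0.0, 1.0, 2.0])))  # [1. 0. 0.] -- passes through the samples

t = 0.5  # fractional offset of the sample point
taps = np.array([1 + t, t, 1 - t, 2 - t])  # distances to the 4 neighbors
print(u(taps).sum())  # 1.0 -- the four weights form a partition of unity
```
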
diff --git a/mindnlp/core/ops/creation.py b/mindnlp/core/ops/creation.py
index 3348df4c3..f598df76a 100644
--- a/mindnlp/core/ops/creation.py
+++ b/mindnlp/core/ops/creation.py
@@ -39,9 +39,9 @@ def zeros(*size, out=None, dtype=None, layout=None, device=None, requires_grad=F
     if isinstance(device, str):
         device = core.device(device)
 
-    if isinstance(size[0], (tuple, list)):
+    if len(size) > 0 and isinstance(size[0], (tuple, list)):
         size = size[0]
-
+
     new_size = ()
     for s in size:
         if not isinstance(s, int):
@@ -139,6 +139,10 @@ def linspace(start, end, steps, *, out=None, dtype=None, layout=None, device=Non
     if isinstance(device, str):
         device = core.device(device)
 
+    start = start.item() if isinstance(start, (core.Tensor, np.integer)) else start
+    end = end.item() if isinstance(end, (core.Tensor, np.integer)) else end
+    steps = steps.item() if isinstance(steps, (core.Tensor, np.integer)) else steps
+
     output = execute('lin_space_ext', start, end, steps, dtype, device=device,
                      requires_grad=requires_grad, user_created=True)
     if out is None:
@@ -154,6 +158,8 @@ def eye(n, m=None, *, out=None, dtype=None, layout=None, device=None, requires_g
         device = get_device_in_context()
     if dtype is None:
         dtype = get_default_dtype()
+    if m is None:
+        m = n
     output = execute('eye', n, m, dtype, device=device,
                      requires_grad=requires_grad, user_created=True)
     if out is None:
@@ -194,8 +200,8 @@ def empty_like(input, *, dtype=None, layout=None, device=None, requires_grad=Fal
 
 # full
 def full(size, fill_value, *, out=None, dtype=None, layout=None, device=None, requires_grad=False):
-    if dtype is None:
-        dtype = get_default_dtype()
+    # if dtype is None:
+    #     dtype = get_default_dtype()
     if device is None:
         device = get_device_in_context()
     if device.type == 'cpu':
diff --git a/mindnlp/core/ops/reduction.py b/mindnlp/core/ops/reduction.py
index 5b0501bb4..db15f2bb3 100644
--- a/mindnlp/core/ops/reduction.py
+++ b/mindnlp/core/ops/reduction.py
@@ -152,7 +152,7 @@ def nansum(input, dim=None, keepdim=False, *, dtype=None):
 
 # prod
 def prod(input, dim=None, keepdim=False, *, dtype=None):
-    return execute('prod_ext', input, dim, keepdim,dtype)
+    return execute('prod_ext', input, dim, keepdim, dtype)
 
 # quantile
diff --git a/mindnlp/core/special/__init__.py b/mindnlp/core/special/__init__.py
index fc5983c91..36a850ef4 100644
--- a/mindnlp/core/special/__init__.py
+++ b/mindnlp/core/special/__init__.py
@@ -1,4 +1,4 @@
-from mindspore import ops
+from ..executor import execute
 
 def logit(input, eps=None, *, out=None):
-    return ops.logit(input, eps)
+    return execute('logit', input, eps)
diff --git a/mindnlp/factory/__init__.py b/mindnlp/factory/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/mindnlp/transformers/__init__.py b/mindnlp/transformers/__init__.py
index 752c6214e..a815632f7 100644
--- a/mindnlp/transformers/__init__.py
+++ b/mindnlp/transformers/__init__.py
@@ -1,4 +1,5 @@
 import sys
+from packaging import version
 from mindnlp.core.configs import ON_ORANGE_PI
 from mindnlp.utils.import_utils import *
 from mindnlp.utils.import_utils import _LazyModule
@@ -4102,12 +4103,13 @@
 
 from . import ms_utils
-from .masking_utils import create_causal_mask, create_sliding_window_causal_mask
+from .masking_utils import create_causal_mask, create_sliding_window_causal_mask, create_masks_for_generate
 from .modeling_utils import construct_pipeline_parallel_model, _load_pretrained_model_wrapper, \
     _get_resolved_checkpoint_files_wrapper
 from .tokenization_utils import apply_chat_template_wrapper
 from .trainer import training_step
 from .generation import *
+from .models.gemma3 import token_type_ids_mask_function
 
 # redirect mindnlp.transformers to transformers
 import transformers
@@ -4134,9 +4136,14 @@ def empty_fn(*args, **kwargs):
 from ..utils.decorators import dtype_wrapper, patch_dtype_wrapper, patch_wrappers
 
 patch_dtype_wrapper(transformers.AutoModel, 'from_pretrained')
-patch_dtype_wrapper(transformers.modeling_utils.PreTrainedModel, 'from_pretrained',
-                    [transformers.modeling_utils.restore_default_torch_dtype]
-                    )
+if version.parse(transformers.__version__) >= version.parse('4.56.0'):
+    patch_dtype_wrapper(transformers.modeling_utils.PreTrainedModel, 'from_pretrained',
+                        [transformers.modeling_utils.restore_default_dtype]
+                        )
+else:
+    patch_dtype_wrapper(transformers.modeling_utils.PreTrainedModel, 'from_pretrained',
+                        [transformers.modeling_utils.restore_default_torch_dtype]
+                        )
 
 patch_wrappers(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model',
                [_load_pretrained_model_wrapper])
@@ -4152,6 +4159,8 @@ def empty_fn(*args, **kwargs):
 transformers.modeling_utils.caching_allocator_warmup = empty_fn
 transformers.masking_utils.create_causal_mask = create_causal_mask
 transformers.masking_utils.create_sliding_window_causal_mask = create_sliding_window_causal_mask
+transformers.masking_utils.create_masks_for_generate = create_masks_for_generate
+transformers.generation.utils.create_masks_for_generate = create_masks_for_generate
 transformers.trainer.Trainer.training_step = training_step
 
 # for ORANGE_PI
@@ -4160,3 +4169,4 @@ def empty_fn(*args, **kwargs):
 
 # add mindnlp.transformers modules/attrs to lazymodule
 # setattr(sys.modules[__name__], 'test_ms_model', test_ms_model)
+transformers.models.gemma3.modeling_gemma3.token_type_ids_mask_function = token_type_ids_mask_function
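
The `from_pretrained` patch is gated on transformers 4.56.0 because that release renamed the helper from `restore_default_torch_dtype` to `restore_default_dtype`, as the branch above reflects. `packaging.version` gives real version ordering, which naive string comparison gets wrong:

```python
from packaging import version

print(version.parse("4.56.0.dev0") < version.parse("4.56.0"))  # True: pre-releases sort first
print("4.9.0" > "4.56.0")                                      # True: lexicographic, wrong
print(version.parse("4.9.0") > version.parse("4.56.0"))        # False: 9 < 56, correct
```
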
diff --git a/mindnlp/transformers/masking_utils.py b/mindnlp/transformers/masking_utils.py
index 8440b60a5..74ac4e0d2 100644
--- a/mindnlp/transformers/masking_utils.py
+++ b/mindnlp/transformers/masking_utils.py
@@ -21,8 +21,7 @@
 # Register a fake type to avoid crashing for annotations and `isinstance` checks
 BlockMask = core.Tensor
 
-_is_torch_greater_or_equal_than_2_6 = False
-
+_is_torch_greater_or_equal_than_2_6 = True
 
 def and_masks(*mask_functions: list[Callable]) -> Callable:
     """Returns a mask function that is the intersection of provided mask functions"""
@@ -32,12 +31,11 @@ def and_masks(*mask_functions: list[Callable]) -> Callable:
     def and_mask(batch_idx, head_idx, q_idx, kv_idx):
         result = q_idx.new_ones((), dtype=core.bool)
         for mask in mask_functions:
-            result = result & mask(batch_idx, head_idx, q_idx, kv_idx)
+            result = result & mask(batch_idx, head_idx, q_idx, kv_idx).to(result.device)
         return result
 
     return and_mask
 
-
 def or_masks(*mask_functions: list[Callable]) -> Callable:
     """Returns a mask function that is the union of provided mask functions"""
     if not all(callable(arg) for arg in mask_functions):
@@ -46,12 +44,11 @@ def or_masks(*mask_functions: list[Callable]) -> Callable:
     def or_mask(batch_idx, head_idx, q_idx, kv_idx):
         result = q_idx.new_zeros((), dtype=core.bool)
         for mask in mask_functions:
-            result = result | mask(batch_idx, head_idx, q_idx, kv_idx)
+            result = result | mask(batch_idx, head_idx, q_idx, kv_idx).to(result.device)
         return result
 
     return or_mask
 
-
 def causal_mask_function(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
     """
     This creates a basic lower-diagonal causal mask.
@@ -110,14 +107,13 @@ def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
 
     return inner_mask
 
-
 def packed_sequence_mask_function(packed_sequence_mask: core.Tensor) -> Callable:
     """
     This returns the mask function corresponding to a 2D packed sequence mask.
     """
 
     def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
-        return packed_sequence_mask[batch_idx, q_idx] == packed_sequence_mask[batch_idx, kv_idx]
+        return packed_sequence_mask[:, q_idx] == packed_sequence_mask[:, kv_idx]
 
     return inner_mask
 
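
`and_masks`/`or_masks` combine elementwise boolean predicates over `(batch_idx, head_idx, q_idx, kv_idx)` index tensors; the new `.to(result.device)` keeps the accumulator and each sub-mask on one device. A self-contained numpy sketch of the composition idea (simplified helper names, for illustration only):

```python
import numpy as np
from functools import reduce

def causal(b, h, q, kv):               # lower-triangular predicate
    return kv <= q

def padding(valid):                    # valid: (batch, kv_len) bool array
    return lambda b, h, q, kv: valid[b, kv]

def and_masks(*fns):
    return lambda b, h, q, kv: reduce(np.logical_and, (f(b, h, q, kv) for f in fns))

valid = np.array([[True, True, False]])          # last kv position is padding
q, kv = np.arange(3)[:, None], np.arange(3)[None, :]
print(and_masks(causal, padding(valid))(0, 0, q, kv).astype(int))
# [[1 0 0]
#  [1 1 0]
#  [1 1 0]]
```
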
@@ -221,6 +217,136 @@ def _ignore_causal_mask_sdpa(
 
     return False
 
+def sdpa_mask_recent_torch(
+    batch_size: int,
+    cache_position: core.Tensor,
+    kv_length: int,
+    kv_offset: int = 0,
+    mask_function: Callable = causal_mask_function,
+    attention_mask: Optional[core.Tensor] = None,
+    local_size: Optional[int] = None,
+    allow_is_causal_skip: bool = True,
+    **kwargs,
+) -> Optional[core.Tensor]:
+    """
+    Create a 4D boolean mask of shape `(batch_size, 1, query_length, kv_length)` where a value of True indicates that
+    the element should take part in the attention computation, and False that it should not.
+    This function can only be used with torch>=2.5, as the context manager is otherwise not available.
+
+    Args:
+        batch_size (`int`):
+            The batch size of the input sequence.
+        cache_position (`torch.Tensor`):
+            A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
+        kv_length (`int`):
+            The size that the key and value states will have during the attention computation.
+        kv_offset (`int`, optional):
+            An optional offset to indicate at which first position the key and values states will refer to.
+        mask_function (`Callable`):
+            The mask factory function describing the mask pattern.
+        attention_mask (`torch.Tensor`, optional):
+            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
+        local_size (`int`, optional):
+            The size of the local attention, if we do not use full attention. This is used only if `allow_is_causal_skip=True`
+            to try to skip mask creation if possible.
+        allow_is_causal_skip (`bool`, optional):
+            Whether to allow returning `None` for the mask under conditions where we can use the `is_causal` argument in
+            `torch.sdpa` instead. Defaults to `True`.
+        allow_torch_fix (`bool`, optional):
+            Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older
+            versions. We need an arg to skip it when using eager. By default `True`.
+
+
+    ## Creating a simple causal mask:
+
+    To create the following causal mask:
+
+        0 ■ ⬚ ⬚ ⬚ ⬚
+        1 ■ ■ ⬚ ⬚ ⬚
+        2 ■ ■ ■ ⬚ ⬚
+        3 ■ ■ ■ ■ ⬚
+        4 ■ ■ ■ ■ ■
+
+    You can do
+
+    ```python
+    >>> create_4d_causal_mask(batch_size=1, cache_position=torch.arange(5), kv_length=5)
+    >>> tensor([[[[ True, False, False, False, False],
+          [ True,  True, False, False, False],
+          [ True,  True,  True, False, False],
+          [ True,  True,  True,  True, False],
+          [ True,  True,  True,  True,  True]]]])
+    ```
+
+    ## Creating a sliding window mask:
+
+    To create the following sliding window mask (`sliding_window=3`):
+
+        0 ■ ⬚ ⬚ ⬚ ⬚
+        1 ■ ■ ⬚ ⬚ ⬚
+        2 ■ ■ ■ ⬚ ⬚
+        3 ⬚ ■ ■ ■ ⬚
+        4 ⬚ ⬚ ■ ■ ■
+
+    You can do
+
+    ```python
+    >>> create_4d_causal_mask(batch_size=1, cache_position=torch.arange(5), kv_length=5, mask_function=sliding_window_causal_mask_function(3))
+    >>> tensor([[[[ True, False, False, False, False],
+          [ True,  True, False, False, False],
+          [ True,  True,  True, False, False],
+          [False,  True,  True,  True, False],
+          [False, False,  True,  True,  True]]]])
+    ```
+
+    ## Creating a chunked attention mask
+
+    To create the following chunked attention mask (`chunk_size=3`):
+
+        0 ■ ⬚ ⬚ ⬚ ⬚
+        1 ■ ■ ⬚ ⬚ ⬚
+        2 ■ ■ ■ ⬚ ⬚
+        3 ⬚ ⬚ ⬚ ■ ⬚
+        4 ⬚ ⬚ ⬚ ■ ■
+
+    You can do
+
+    ```python
+    >>> create_4d_causal_mask(batch_size=1, cache_position=torch.arange(5), kv_length=5, mask_function=chunked_causal_mask_function(3))
+    >>> tensor([[[[ True, False, False, False, False],
+        [ True,  True, False, False, False],
+        [ True,  True,  True, False, False],
+        [False, False, False,  True, False],
+        [False, False, False,  True,  True]]]])
+    ```
+
+    """
+    q_length = cache_position.shape[0]
+    # Potentially pad the 2D mask, and slice it correctly
+    padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=False)
+
+    # Under specific conditions, we can avoid materializing the mask, instead relying on the `is_causal` argument
+    if allow_is_causal_skip and _ignore_causal_mask_sdpa(padding_mask, q_length, kv_length, kv_offset, local_size):
+        return None
+
+    # Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`
+    # but without data-dependent slicing (i.e. torch.compile friendly)
+    kv_arange = core.arange(kv_length, device=cache_position.device)
+    kv_arange += kv_offset
+
+    # Potentially add the padding 2D mask
+    if padding_mask is not None:
+        mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
+
+    batch_arange = core.arange(batch_size, device=cache_position.device)
+    head_arange = core.arange(1, device=cache_position.device)
+    # This creates the 4D mask easily. Note that we need this context manager as vmap cannot handle slicing a tensor from
+    # a scalar tensor (it internally calls `.item()`, which vmap does not allow), but this context works around it.
+    # We don't need to add an offset to the mask_function either, as we vmap directly the correct indices for q and kv
+    causal_mask = _vmap_for_bhqkv(mask_function)(batch_arange, head_arange, cache_position, kv_arange)
+
+    return causal_mask
+
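
Because `mask_function` only does elementwise comparisons, vmapping it over the four index tensors is equivalent to broadcasting index grids against each other. A numpy sketch that reproduces the 5x5 causal example from the docstring above:

```python
import numpy as np

batch_size, q_len, kv_len, kv_offset = 1, 5, 5, 0
cache_position = np.arange(q_len)
kv_arange = np.arange(kv_len) + kv_offset

causal = kv_arange[None, :] <= cache_position[:, None]   # (q, kv)
mask = np.broadcast_to(causal[None, None], (batch_size, 1, q_len, kv_len))
print(mask[0, 0].astype(int))
# [[1 0 0 0 0]
#  [1 1 0 0 0]
#  [1 1 1 0 0]
#  [1 1 1 1 0]
#  [1 1 1 1 1]]
```
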
 
 def sdpa_mask_older_torch(
     batch_size: int,
     cache_position: core.Tensor,
@@ -283,9 +409,14 @@ def sdpa_mask_older_torch(
     # as vmap cannot handle slicing a tensor from scalar tensor (it internally calls `.item()` which vmap does not allow
     # However, in more recent version of Pytorch, a trick was introduced to handle it - which is the reason we have
     # `sdpa_mask_recent_torch`, as it allows more general `mask_function`
+    # causal_mask = mask_function(None, None, cache_position, kv_arange)
     causal_mask = mask_function(None, None, cache_position.reshape(cache_position.shape[0], 1), kv_arange.reshape(1, kv_arange.shape[0]))
     # causal_mask = _vmap_for_bhqkv(mask_function, bh_indices=False)(None, None, cache_position, kv_arange)
-    causal_mask = causal_mask[None, None, :, :].expand(batch_size, -1, -1, -1)
+    if causal_mask.ndim == 2:
+        causal_mask = causal_mask[None, None, :, :].expand(batch_size, -1, -1, -1)
+    elif causal_mask.ndim == 3:
+        causal_mask = causal_mask[:, None, :, :].expand(batch_size, -1, -1, -1)
+
     if padding_mask is not None:
         causal_mask = causal_mask * padding_mask[:, None, None, :]
 
@@ -300,7 +431,6 @@
 # (especially mask_function indexing a tensor, such as the padding mask function)
 sdpa_mask = sdpa_mask_older_torch
 
-
 def eager_mask(
     batch_size: int,
     cache_position: core.Tensor,
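
The new `ndim` branch in `sdpa_mask_older_torch` exists because a `mask_function` that indexes a batched tensor, like the reworked `packed_sequence_mask_function` which now returns `packed_sequence_mask[:, q_idx] == packed_sequence_mask[:, kv_idx]`, produces a 3D `(batch, q, kv)` mask, while plain causal predicates produce a 2D `(q, kv)` one. A numpy sketch of both shapes:

```python
import numpy as np

q = np.arange(4).reshape(4, 1)    # cache_position.reshape(q_len, 1)
kv = np.arange(4).reshape(1, 4)   # kv_arange.reshape(1, kv_len)

causal_2d = kv <= q                          # (4, 4): expanded via [None, None]
packed = np.array([[0, 0, 1, 1],             # (batch, seq) segment ids
                   [0, 1, 1, 1]])
packed_3d = packed[:, q] == packed[:, kv]    # (2, 4, 4): expanded via [:, None]

print(causal_2d.shape, packed_3d.shape)  # (4, 4) (2, 4, 4)
```
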
+ """ + # Do not return an additional mask in this case + if token_type_ids is None: + return None + + def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: + # If it's 1 for both query and key/value, we are in an image block + # NOTE: static cache shape goes beyond input seq length, while token_type_ids.shape[1] == input seq length + # Since vmap doesn't support `if statement` we workaround it with `torch.where` + safe_idx = core.where(kv_idx < token_type_ids.shape[1], kv_idx, 0) + token_type_ids_at_kv_idx = token_type_ids[:, safe_idx] + token_type_ids_at_kv_idx = core.where(kv_idx < token_type_ids.shape[1], token_type_ids_at_kv_idx, 0) + + image_group_ids_at_kv_idx = image_group_ids[:, safe_idx] + image_group_ids_at_kv_idx = core.where(kv_idx < image_group_ids.shape[1], image_group_ids_at_kv_idx, -1) + + is_image_block = (token_type_ids[:, q_idx] == 1) & (token_type_ids_at_kv_idx == 1) + same_image_block = image_group_ids[:, q_idx] == image_group_ids_at_kv_idx + + # This is bidirectional attention whenever we are dealing with image tokens + return is_image_block & same_image_block + + return inner_mask \ No newline at end of file