Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mindnlp/core/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,10 @@ def data(self, new_value):
Tensor.data = data
StubTensor.data = data

Tensor.narrow = ops.narrow
StubTensor.narrow = ops.narrow


def _rebuild_from_type_v2(func, new_type, args, state):
ret = func(*args)
return ret
48 changes: 0 additions & 48 deletions mindnlp/core/dispatcher.py

This file was deleted.

3 changes: 0 additions & 3 deletions mindnlp/core/distributed/c10d/process_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@
from typing import List, Optional, Dict, Any
from enum import Enum

from mindnlp.core.executor import execute


class BackendType(Enum):
UNDEFINED = 0
GLOO = 1
Expand Down
41 changes: 0 additions & 41 deletions mindnlp/core/executor.py

This file was deleted.

41 changes: 13 additions & 28 deletions mindnlp/core/nn/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,9 @@
from typing import Optional, Tuple, List
import numpy as np
from mindspore import ops, mint
from mindspore.ops.auto_generate.gen_arg_handler import dtype_to_type_id
from mindspore.common.generator import default_generator
from mindspore.ops._primitive_cache import _get_cache_prim

from mindnlp import core
from mindnlp.core.executor import execute
from ..configs import DEVICE_TARGET, ON_ORANGE_PI, use_pyboost, ON_A1

generator_step_ = 12
Expand Down Expand Up @@ -237,28 +234,6 @@ def apply_rotary_pos_emb(query, key, cos, sin, position_ids, cos_format=0):
query, key, cos, sin, position_ids, cos_format
)

def _reflection_pad(input, pad):
"""reflection pad"""
out = input
if len(pad) == 2:
out = execute('reflection_pad_1d', input, pad)
elif len(pad) == 4:
out = execute('reflection_pad_2d', input, pad)
else:
out = execute('reflection_pad_3d', input, pad)
return out

def _replication_pad(input, pad):
"""replication pad"""
out = input
if len(pad) == 2:
out = execute('replication_pad_1d', input, pad)
elif len(pad) == 4:
out = execute('replication_pad_2d', input, pad)
else:
out = execute('replication_pad_3d', input, pad)
return out

def pad(input, pad, mode='constant', value=0.0):
if sum(pad) == 0:
return input
Expand All @@ -268,7 +243,16 @@ def pad(input, pad, mode='constant', value=0.0):
return mint.nn.functional.pad(input, pad, mode, value)
if mode in ['reflect', 'circular']:
return ops.pad(input, pad, mode)
return ops.pad(input, pad, mode, value)
new_pad = ()
for idx, pad_v in enumerate(pad):
if pad_v < 0:
dim = idx // 2
input = input.narrow(dim, 0, input.shape[dim] + pad_v)
pad_v = 0
new_pad += (pad_v,)
if sum(new_pad) == 0:
return input
return ops.pad(input, new_pad, mode, value)

def nll_loss(input, target, weight=None, ignore_index=-100, reduction='mean', label_smoothing=0.0):
return _inner_nll_loss(input, target, weight, ignore_index, reduction, label_smoothing)
Expand Down Expand Up @@ -656,8 +640,9 @@ def scaled_dot_product_attention(query, key, value, attn_mask, dropout_p, is_cau
query = query / scaling_factor

if is_causal:
L = query.shape[-2], S = key.shape[-2]
attn_mask = ops.ones((L, S), mindspore.bool_).tril()
L = query.shape[-2]
S = key.shape[-2]
attn_mask = ops.ones((L, S), core.bool_).tril()

attn = ops.matmul(query, key.swapaxes(-2, -1) / scaling_factor)
if attn_mask is not None:
Expand Down
10 changes: 8 additions & 2 deletions tests/run_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,17 @@ def run_tests():
"""
# 获取命令行参数(排除脚本名本身)
pytest_args = sys.argv[1:]

skip_ut = "not sdpa " \
"and not headmasking " \
"and not gradient_checkpointing " \
"and not retain_grad " \
"and not data_parallel"

pytest_args.extend(['-k', skip_ut])
if not pytest_args:
print("未提供参数,默认运行当前目录下所有测试")
print("使用示例: python run_test.py -v tests/")

# 执行测试并获取退出码
exit_code = pytest.main(pytest_args)

Expand Down
Loading