Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -172,4 +172,5 @@ tests/transformers/
tests/huggingface_transformers/
.gradio/

huanhuan.json
huanhuan.json
pytorch/
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import mindnlp
from mindnlp.core import distributed as dist
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "openai/gpt-oss-20b"
Expand Down
23,662 changes: 19,954 additions & 3,708 deletions examples/transformers/peft/lora/Qwen2.5-7B-Instruct-Lora.ipynb

Large diffs are not rendered by default.

84 changes: 81 additions & 3 deletions mindnlp/core/_C/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
from typing import Any
from mindspore import Generator as msGenerator
import mindspore

from mindnlp import core
from . import _nn
from ..types import device as device_
from ..configs import DEVICE_TARGET

DEVICE_MAP = {
'GPU': 'cuda',
'Ascend': 'npu',
'CPU': 'cpu'
}


def _jit_set_profiling_executor(mode):
pass

Expand All @@ -30,6 +39,72 @@ def _debug_set_autodiff_subgraph_inlining(mode):

DisableTorchFunctionSubclass = None


class device():
def __init__(self, type=None, index=None):
if type is not None:
if isinstance(type, str):
if ':' in type:
if index is not None:
raise ValueError("`type` must not include an index because index was "
f"passed explicitly: {type}")
_target, _id = type.split(':')
_id = int(_id)
else:
_target = type
_id = None if _target == 'cpu' else 0
elif isinstance(type, device):
if index is not None:
raise ValueError("core.device(): When input is core.device, `index` can not be set.")
_target = type.type
_id = type.index
elif isinstance(type, int):
_id = type
try:
device_target = mindspore.get_current_device().device_target
except:
device_target = mindspore.get_context('device_target')
_target = DEVICE_MAP[device_target]
else:
print(type)
raise TypeError("core.device(): `type` must be type of 'str' or 'core.device'.")
else:
raise ValueError("core.device(): `type` can not be None")

self.type = _target
self.index = _id
if DEVICE_TARGET == 'Ascned' and self.type == 'cuda':
self.type = 'npu'

def __repr__(self):
if self.index is None:
return f"device(type={self.type})"
return f"device(type={self.type}, index={self.index})"

def __eq__(self, __value):
if not isinstance(__value, device):
return False
return hash(self) == hash(__value)

def __hash__(self):
return hash(self.type) ^ hash(self.index)

def __gt__(self, other):
if self.type == 'cpu':
return False
return True

def __enter__(self):
# self.prev_idx = torch.cuda._exchange_device(self.idx)
core._bind.set_device_in_context(self)

def __exit__(self, type: Any, value: Any, traceback: Any):
# self.idx = torch.cuda._maybe_exchange_device(self.prev_idx)
core._bind.set_device_in_context(None)
return False

device_ = device

class Generator(msGenerator):
def __init__(self, device='cpu'):
super().__init__()
Expand All @@ -41,11 +116,14 @@ def __init__(self, device='cpu'):
def device(self):
if hasattr(self, '_device'):
return self._device
return device_('cpu')
return device('cpu')

default_generator = Generator()

class Tag: pass

def _log_api_usage_once(*args):
pass
pass

ScriptDict = dict
ScriptList = list
3 changes: 1 addition & 2 deletions mindnlp/core/_C/_nn.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from mindnlp import core
from ..types import device as device_

def _parse_to(*args, **kwargs):
"""
Expand All @@ -22,7 +21,7 @@ def _parse_to(*args, **kwargs):
device = args[0]
dtype = None
elif isinstance(args[0], (str, int)):
device = device_(args[0])
device = core.device(args[0])
dtype = None
else:
raise TypeError(f"Expected core.dtype or core.device, but got {type(args[0])}")
Expand Down
8 changes: 4 additions & 4 deletions mindnlp/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,16 @@
preserve_format = None
legacy_contiguous_format = None
channels_last_3d = None
memory_format = None

inf = float("inf")
nan = float("nan")

from ._C import *
from ._dtype import *
from ._tensor import Tensor, tensor, is_tensor, \
LongTensor, FloatTensor, BoolTensor, HalfTensor, BFloat16Tensor, IntTensor
from .types import device
from ._C import *
from ._C.size import Size
from .types import device
from .autograd import *
from .ops import *
from .serialization import load, save
Expand All @@ -57,8 +56,9 @@
from .func import vmap
from .configs import set_pyboost

from . import _dynamo
from . import profiler, cuda, amp, compiler, jit, version, __future__, overrides, \
return_types, linalg, fx, backends, testing, nn, fft, _jit_internal, utils, optim
return_types, linalg, fx, backends, nn, fft, _jit_internal, utils, optim, testing
from ._lowrank import svd_lowrank
from .random import get_rng_state, initial_seed, manual_seed, seed, set_rng_state

Expand Down
2 changes: 1 addition & 1 deletion mindnlp/core/_bind.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import ctypes
from typing import Any
from ._dtype import *
from .types import device as device_
from ._C import device as device_
from .configs import ON_A1

DEFAULT_DTYPE, DEFAULT_DEVICE = float32, device_('cpu')
Expand Down
5 changes: 5 additions & 0 deletions mindnlp/core/_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ def __gt__(self, other):

float8_e4m3fn = None # TODO: not support fp8 for now
float8_e5m2 = None
float8_e4m3fnuz = None
float8_e5m2fnuz = None
complex32 = None
cfloat = complex32
cdouble = complex64

uint1 = None
uint2 = None
Expand Down
1 change: 1 addition & 0 deletions mindnlp/core/_dynamo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
)

from . import eval_frame
# from . import config

def reset():
pass
Loading
Loading