From fd445beff6861067c7ff93fd9a6e8f82efdb8aeb Mon Sep 17 00:00:00 2001 From: lvyufeng Date: Thu, 17 Jul 2025 18:52:29 +0800 Subject: [PATCH] fix lora training --- .../roberta_sequence_classification.ipynb | 271 ++++-------------- .../lora/roberta_sequence_classification.py | 161 ++++++++--- mindnlp/core/__init__.py | 4 +- mindnlp/core/autograd/function.py | 1 + mindnlp/utils/safetensors_patch.py | 17 +- mindnlp/utils/torch_proxy.py | 209 +++++++------- 6 files changed, 283 insertions(+), 380 deletions(-) diff --git a/llm/peft/lora/roberta_sequence_classification.ipynb b/llm/peft/lora/roberta_sequence_classification.ipynb index 37189f961..4414f67b5 100644 --- a/llm/peft/lora/roberta_sequence_classification.ipynb +++ b/llm/peft/lora/roberta_sequence_classification.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "id": "9ff5004e", "metadata": {}, "outputs": [ @@ -10,24 +10,31 @@ "name": "stderr", "output_type": "stream", "text": [ - "[WARNING] ME(4008793:281473599340576,MainProcess):2024-09-25-15:21:27.976.434 [mindspore/run_check/_check_version.py:357] MindSpore version 2.3.1 and Ascend AI software package (Ascend Data Center Solution)version 7.5 does not match, the version of software package expect one of ['7.2', '7.3']. Please refer to the match info on: https://www.mindspore.cn/install\n", - "/home/lvyufeng/miniconda3/envs/mindspore/lib/python3.9/site-packages/numpy/core/getlimits.py:549: UserWarning: The value of the smallest subnormal for type is zero.\n", + "/home/lvyufeng/miniconda3/envs/mindspore/lib/python3.11/site-packages/numpy/core/getlimits.py:549: UserWarning: The value of the smallest subnormal for type is zero.\n", " setattr(self, word, getattr(machar, word).flat[0])\n", - "/home/lvyufeng/miniconda3/envs/mindspore/lib/python3.9/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for type is zero.\n", + "/home/lvyufeng/miniconda3/envs/mindspore/lib/python3.11/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for type is zero.\n", " return self._float_to_str(self.smallest_subnormal)\n", - "/home/lvyufeng/miniconda3/envs/mindspore/lib/python3.9/site-packages/numpy/core/getlimits.py:549: UserWarning: The value of the smallest subnormal for type is zero.\n", + "/home/lvyufeng/miniconda3/envs/mindspore/lib/python3.11/site-packages/numpy/core/getlimits.py:549: UserWarning: The value of the smallest subnormal for type is zero.\n", " setattr(self, word, getattr(machar, word).flat[0])\n", - "/home/lvyufeng/miniconda3/envs/mindspore/lib/python3.9/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for type is zero.\n", + "/home/lvyufeng/miniconda3/envs/mindspore/lib/python3.11/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for type is zero.\n", " return self._float_to_str(self.smallest_subnormal)\n", - "[WARNING] ME(4008793:281473599340576,MainProcess):2024-09-25-15:21:30.133.034 [mindspore/run_check/_check_version.py:375] MindSpore version 2.3.1 and \"te\" wheel package version 7.5 does not match. For details, refer to the installation guidelines: https://www.mindspore.cn/install\n", - "[WARNING] ME(4008793:281473599340576,MainProcess):2024-09-25-15:21:30.135.047 [mindspore/run_check/_check_version.py:382] MindSpore version 2.3.1 and \"hccl\" wheel package version 7.5 does not match. 
For details, refer to the installation guidelines: https://www.mindspore.cn/install\n", - "[WARNING] ME(4008793:281473599340576,MainProcess):2024-09-25-15:21:30.135.636 [mindspore/run_check/_check_version.py:396] Please pay attention to the above warning, countdown: 3\n", - "[WARNING] ME(4008793:281473599340576,MainProcess):2024-09-25-15:21:31.137.404 [mindspore/run_check/_check_version.py:396] Please pay attention to the above warning, countdown: 2\n", - "[WARNING] ME(4008793:281473599340576,MainProcess):2024-09-25-15:21:32.139.464 [mindspore/run_check/_check_version.py:396] Please pay attention to the above warning, countdown: 1\n", - "Building prefix dict from the default dictionary ...\n", - "Loading model from cache /tmp/jieba.cache\n", - "Loading model cost 0.924 seconds.\n", - "Prefix dict has been built successfully.\n" + "[WARNING] ME(2453836:281473813987360,MainProcess):2025-07-17-17:13:08.294.000 [mindspore/context.py:1335] For 'context.set_context', the parameter 'ascend_config' will be deprecated and removed in a future version. Please use the api mindspore.device_context.ascend.op_precision.precision_mode(),\n", + " mindspore.device_context.ascend.op_precision.op_precision_mode(),\n", + " mindspore.device_context.ascend.op_precision.matmul_allow_hf32(),\n", + " mindspore.device_context.ascend.op_precision.conv_allow_hf32(),\n", + " mindspore.device_context.ascend.op_tuning.op_compile() instead.\n", + "/home/lvyufeng/miniconda3/envs/mindspore/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "/home/lvyufeng/miniconda3/envs/mindspore/lib/python3.11/site-packages/transformers/utils/generic.py:496: FutureWarning: `core.utils._pytree._register_pytree_node` is deprecated. Please use `core.utils._pytree.register_pytree_node` instead.\n", + " _torch_pytree._register_pytree_node(\n", + "/home/lvyufeng/miniconda3/envs/mindspore/lib/python3.11/site-packages/transformers/utils/generic.py:353: FutureWarning: `core.utils._pytree._register_pytree_node` is deprecated. Please use `core.utils._pytree.register_pytree_node` instead.\n", + " _torch_pytree._register_pytree_node(\n", + "/home/lvyufeng/miniconda3/envs/mindspore/lib/python3.11/site-packages/transformers/utils/generic.py:353: FutureWarning: `core.utils._pytree._register_pytree_node` is deprecated. Please use `core.utils._pytree.register_pytree_node` instead.\n", + " _torch_pytree._register_pytree_node(\n", + "/home/lvyufeng/miniconda3/envs/mindspore/lib/python3.11/site-packages/transformers/utils/generic.py:496: FutureWarning: `core.utils._pytree._register_pytree_node` is deprecated. Please use `core.utils._pytree.register_pytree_node` instead.\n", + " _torch_pytree._register_pytree_node(\n", + "/home/lvyufeng/miniconda3/envs/mindspore/lib/python3.11/site-packages/transformers/utils/generic.py:353: FutureWarning: `core.utils._pytree._register_pytree_node` is deprecated. 
Please use `core.utils._pytree.register_pytree_node` instead.\n", + " _torch_pytree._register_pytree_node(\n" ] } ], @@ -44,7 +51,7 @@ "from mindnlp.dataset import load_dataset\n", "from mindnlp.engine import set_seed\n", "from mindnlp.transformers import AutoModelForSequenceClassification, AutoTokenizer\n", - "from mindnlp.common.optimization import get_linear_schedule_with_warmup\n", + "from mindnlp.transformers.optimization import get_linear_schedule_with_warmup\n", "from mindnlp.peft import (\n", " get_peft_config,\n", " get_peft_model,\n", @@ -57,7 +64,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "id": "e32c4a9e", "metadata": {}, "outputs": [], @@ -71,7 +78,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "id": "622fe9c8", "metadata": {}, "outputs": [], @@ -82,19 +89,10 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "id": "74e9efe0", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/lvyufeng/miniconda3/envs/mindspore/lib/python3.9/site-packages/mindnlp/transformers/tokenization_utils_base.py:1526: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted, and will be then set to `False` by default. \n", - " warnings.warn(\n" - ] - } - ], + "outputs": [], "source": [ "if any(k in model_name_or_path for k in (\"gpt\", \"opt\", \"bloom\")):\n", " padding_side = \"left\"\n", @@ -108,7 +106,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "id": "41a63e71-e7c4-4e5d-9e22-6953d981d4b8", "metadata": {}, "outputs": [ @@ -127,7 +125,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "id": "bd2d7cd5-62b8-4b7a-ac69-338e6319152e", "metadata": {}, "outputs": [], @@ -136,7 +134,7 @@ "\n", "class MapFunc(BaseMapFunction):\n", " def __call__(self, sentence1, sentence2, label, idx):\n", - " outputs = tokenizer(sentence1, sentence2, truncation=True, max_length=None)\n", + " outputs = tokenizer(str(sentence1), str(sentence2), truncation=True, max_length=None)\n", " return outputs['input_ids'], outputs['attention_mask'], label\n", "\n", "\n", @@ -155,7 +153,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "id": "3b1fd5fc-2285-409e-a4e5-cc3c9759d77a", "metadata": {}, "outputs": [ @@ -188,7 +186,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "id": "efb606a2-1fb5-415c-bf12-7e6fd324fe0a", "metadata": { "scrolled": true @@ -200,7 +198,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "id": "a3c15af0", "metadata": {}, "outputs": [ @@ -215,6 +213,10 @@ "name": "stderr", "output_type": "stream", "text": [ + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. 
Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] @@ -223,22 +225,30 @@ "name": "stdout", "output_type": "stream", "text": [ - "trainable params: 1,838,082 || all params: 357,199,876 || trainable%: 0.5145808057335384\n" + "trainable params: 1,838,082 || all params: 357,199,876 || trainable%: 0.5146\n" ] } ], "source": [ - "model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, return_dict=True)\n", + "model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, return_dict=True, attn_implementation='eager')\n", "model = get_peft_model(model, peft_config)\n", "model.print_trainable_parameters()" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "id": "6d3c5edb", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " \n" + ] + } + ], "source": [ "optimizer = AdamW(params=tuple(param for param in model.parameters() if param.requires_grad), lr=lr)\n", "\n", @@ -260,186 +270,7 @@ "name": "stderr", "output_type": "stream", "text": [ - " 1%|█ | 1/115 [00:08<16:03, 8.45s/it]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\\\r" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:10<00:00, 1.63it/s]\n", - " 15%|█████████████████▊ | 2/13 [00:00<00:03, 2.89it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "|\r" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:02<00:00, 6.16it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch 0: {'accuracy': 0.6838235294117647, 'f1': 0.8122270742358079}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00, 2.85it/s]\n", - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 6.96it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch 1: {'accuracy': 0.7426470588235294, 'f1': 0.7741935483870968}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00, 2.82it/s]\n", - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 7.11it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch 2: {'accuracy': 0.8872549019607843, 
'f1': 0.9181494661921709}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:41<00:00, 2.80it/s]\n", - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 6.99it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch 3: {'accuracy': 0.8799019607843137, 'f1': 0.9094269870609981}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00, 2.86it/s]\n", - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 7.12it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch 4: {'accuracy': 0.8848039215686274, 'f1': 0.9124767225325885}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:39<00:00, 2.90it/s]\n", - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 7.18it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch 5: {'accuracy': 0.8921568627450981, 'f1': 0.9191176470588235}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00, 2.84it/s]\n", - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 7.09it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch 6: {'accuracy': 0.9019607843137255, 'f1': 0.9295774647887324}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:43<00:00, 2.64it/s]\n", - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 6.92it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch 7: {'accuracy': 0.8970588235294118, 'f1': 0.9273356401384083}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00, 2.83it/s]\n", - "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 7.05it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch 8: {'accuracy': 0.9142156862745098, 'f1': 0.9369369369369369}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00, 2.86it/s]\n", - 
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00, 6.97it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "epoch 9: {'accuracy': 0.8897058823529411, 'f1': 0.9194991055456172}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:42<00:00, 2.68it/s]\n", - " 8%|████████▉ | 1/13 [00:00<00:04, 3.00it/s]" + " 0%| | 0/115 [00:00 use the model max length (it's actually the default) - outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None) - return outputs +# In[6]: -tokenized_datasets = datasets.map( - tokenize_function, - batched=True, - remove_columns=["idx", "sentence1", "sentence2"], -) +from mindnlp.dataset import BaseMapFunction -# We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the -# transformers library -tokenized_datasets = tokenized_datasets.rename_column("label", "labels") +class MapFunc(BaseMapFunction): + def __call__(self, sentence1, sentence2, label, idx): + outputs = tokenizer(str(sentence1), str(sentence2), truncation=True, max_length=None) + return outputs['input_ids'], outputs['attention_mask'], label -def collate_fn(examples): - return tokenizer.pad(examples, padding="longest", return_tensors="pt") +def get_dataset(dataset, tokenizer): + input_colums=['sentence1', 'sentence2', 'label', 'idx'] + output_columns=['input_ids', 'attention_mask', 'labels'] + dataset = dataset.map(MapFunc(input_colums, output_columns), + input_colums, output_columns) + dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id), + 'attention_mask': (None, 0)}) + return dataset +train_dataset = get_dataset(datasets['train'], tokenizer) +eval_dataset = get_dataset(datasets['validation'], tokenizer) -# Instantiate dataloaders. 
-train_dataloader = DataLoader(tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size)
-eval_dataloader = DataLoader(
-    tokenized_datasets["validation"], shuffle=False, collate_fn=collate_fn, batch_size=batch_size
-)
-model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, return_dict=True)
+# In[7]:
+
+
+print(next(train_dataset.create_dict_iterator()))
+
+
+# In[8]:
+
+
+metric = evaluate.load("glue", task)
+
+
+# In[9]:
+
+model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, return_dict=True, attn_implementation='eager')
 model = get_peft_model(model, peft_config)
 model.print_trainable_parameters()
-model
\ No newline at end of file
+
+
+# In[10]:
+
+
+optimizer = AdamW(params=tuple(param for param in model.parameters() if param.requires_grad), lr=lr)
+
+# Instantiate scheduler
+lr_scheduler = get_linear_schedule_with_warmup(
+    optimizer=optimizer,
+    num_warmup_steps=0.06 * (len(train_dataset) * num_epochs),
+    num_training_steps=(len(train_dataset) * num_epochs),
+)
+
+
+# In[ ]:
+
+
+from mindnlp.core import value_and_grad
+def forward_fn(**batch):
+    outputs = model(**batch)
+    loss = outputs.loss
+    return loss
+
+grad_fn = value_and_grad(forward_fn, tuple(param for param in model.parameters() if param.requires_grad))
+
+for epoch in range(num_epochs):
+    model.set_train()
+    train_total_size = train_dataset.get_dataset_size()
+    for step, batch in enumerate(tqdm(train_dataset.create_dict_iterator(), total=train_total_size)):
+        optimizer.zero_grad()
+        loss = grad_fn(**batch)
+        optimizer.step()
+        lr_scheduler.step()
+
+    model.set_train(False)
+    eval_total_size = eval_dataset.get_dataset_size()
+    for step, batch in enumerate(tqdm(eval_dataset.create_dict_iterator(), total=eval_total_size)):
+        with _no_grad():
+            outputs = model(**batch)
+        predictions = outputs.logits.argmax(axis=-1)
+        references = batch["labels"]
+        metric.add_batch(
+            predictions=predictions,
+            references=references,
+        )
+
+    eval_metric = metric.compute()
+    print(f"epoch {epoch}:", eval_metric)
+
+
+# In[ ]:
+
+
+
+
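The converted script drives optimization manually: `value_and_grad` wraps the forward function and attaches gradients to the trainable parameters, which is why a bare `optimizer.step()` follows the `grad_fn` call. A minimal sketch of the same pattern in isolation (the toy `nn.Linear` model, `CrossEntropyLoss`, and random data are illustrative assumptions about the torch-mirroring `mindnlp.core` API; only the `value_and_grad` and `AdamW` usage mirrors the script above):

    import numpy as np
    import mindspore
    from mindnlp.core import nn, value_and_grad
    from mindnlp.core.optim import AdamW

    model = nn.Linear(4, 2)                 # hypothetical toy model
    criterion = nn.CrossEntropyLoss()       # assumed torch-style loss
    optimizer = AdamW(model.parameters(), lr=1e-3)

    def forward_fn(x, y):
        return criterion(model(x), y)

    # one call runs forward + backward and attaches gradients to the
    # trainable parameters, so a bare optimizer.step() can follow
    grad_fn = value_and_grad(forward_fn, tuple(model.parameters()))

    x = mindspore.Tensor(np.random.randn(8, 4).astype(np.float32))
    y = mindspore.Tensor(np.random.randint(0, 2, (8,)).astype(np.int32))

    optimizer.zero_grad()
    loss = grad_fn(x, y)
    optimizer.step()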
diff --git a/mindnlp/core/__init__.py b/mindnlp/core/__init__.py
index d58db56e2..8ecdbdf5e 100644
--- a/mindnlp/core/__init__.py
+++ b/mindnlp/core/__init__.py
@@ -47,7 +47,7 @@ from .amp import autocast, GradScaler
 from . import profiler, cuda, optim, amp, compiler, jit, version, __future__, overrides, \
-    return_types, linalg, fx, backends, testing
+    return_types, linalg, fx, backends, testing, nn
 from ._lowrank import svd_lowrank
 from .random import get_rng_state, initial_seed, manual_seed, seed, set_rng_state
@@ -95,3 +95,5 @@ def set_autocast_dtype(device_type, dtype):
 
 def get_autocast_dtype(device_type):
     return AUTO_CAST_DTYE[device_type]
+
+__version__ = 'test_version_no_value'
\ No newline at end of file
diff --git a/mindnlp/core/autograd/function.py b/mindnlp/core/autograd/function.py
index 02a1e9855..59e799aa0 100644
--- a/mindnlp/core/autograd/function.py
+++ b/mindnlp/core/autograd/function.py
@@ -47,6 +47,7 @@ def value_and_grad_f(*args, **kwargs):
         if kwargs:
             run_args = args + tuple(kwargs.values())
+        _pynative_executor.check_run(grad_, fn_, params_or_argnums, None, *run_args)
         grads = _pynative_executor.grad(fn_, grad_, params_or_argnums, None, *run_args)
         grads = tuple(mindspore.Tensor(grad) for grad in grads)
         if attach_grads:
diff --git a/mindnlp/utils/safetensors_patch.py b/mindnlp/utils/safetensors_patch.py
index cc610521c..fb80e7fc6 100644
--- a/mindnlp/utils/safetensors_patch.py
+++ b/mindnlp/utils/safetensors_patch.py
@@ -85,13 +85,15 @@ def ndim(self):
     def get(self, *args, **kwargs):
         nbytes = int(np.prod(self.shape)) * np.dtype(self.dtype).itemsize
-        offset = self.start_offset
-        tensor = np.frombuffer(self.buffermmap, dtype=self.dtype, offset=offset,
-                               count=nbytes // np.dtype(self.dtype).itemsize)
+        buffer = bytearray(nbytes)
+        self.bufferfile.seek(self.start_offset)
+        self.bufferfile.readinto(buffer)
+        tensor = np.frombuffer(buffer, dtype=self.dtype)
         tensor = tensor.reshape(self.shape)
         if not SUPPORT_BF16 and self.info["dtype"] == 'BF16':
             tensor = tensor.astype(np.float16)
         tensor = Tensor.from_numpy(tensor)
+
         return tensor
 
     @property
@@ -135,17 +137,18 @@ def getSize(fileobject):
 
 def metadata_validate(metadata):
-    start = 0
+    end = 0
     for key, info in metadata.items():
         s, e = info["data_offsets"]
-        if s != start or e < s:
+        if e < s:
             raise ValueError(f"SafeTensorError::InvalidOffset({key})")
-        start = e
+        if e > end:
+            end = e
         nelements = np.prod(info["shape"])
         nbytes = nelements * _DTYPE_SIZE[info["dtype"]]
         if (e - s) != nbytes:
             raise ValueError("SafeTensorError::TensorInvalidInfo")
-    return start
+    return end
 
 def read_metadata(buffer):
     buffer_len = getSize(buffer)
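The relaxed `metadata_validate` above no longer requires each tensor's `data_offsets` to begin exactly where the previous entry ended; it only rejects ranges whose end precedes their start, and it returns the maximum end offset seen. A self-contained illustration of the new rule (the validator is restated standalone here, with a trimmed `_DTYPE_SIZE` table; the real module's table covers all safetensors dtypes):

    import numpy as np

    _DTYPE_SIZE = {"F32": 4}  # trimmed for the example

    def metadata_validate(metadata):
        end = 0
        for key, info in metadata.items():
            s, e = info["data_offsets"]
            if e < s:
                raise ValueError(f"SafeTensorError::InvalidOffset({key})")
            if e > end:
                end = e
            nbytes = np.prod(info["shape"]) * _DTYPE_SIZE[info["dtype"]]
            if (e - s) != nbytes:
                raise ValueError("SafeTensorError::TensorInvalidInfo")
        return end

    # Entries listed out of storage order: the old contiguous check raised
    # InvalidOffset on "b"; the relaxed check accepts both and returns 32.
    meta = {
        "b": {"dtype": "F32", "shape": [2, 2], "data_offsets": (16, 32)},
        "a": {"dtype": "F32", "shape": [2, 2], "data_offsets": (0, 16)},
    }
    print(metadata_validate(meta))  # -> 32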
diff --git a/mindnlp/utils/torch_proxy.py b/mindnlp/utils/torch_proxy.py
index 2f6177b69..20a1413f0 100644
--- a/mindnlp/utils/torch_proxy.py
+++ b/mindnlp/utils/torch_proxy.py
@@ -2,145 +2,128 @@
 import types
 import importlib
 import importlib.metadata
-from collections import defaultdict
-
-class TorchProxyModule(types.ModuleType):
-    """Recursive proxy module supporting module paths of arbitrary depth."""
-
-    # Cache of proxy modules created so far
-    _proxy_cache = defaultdict(dict)
-
-    def __new__(cls, real_module, proxy_name):
-        """Use the cache to avoid creating duplicate proxies."""
-        # Build the cache key: real module id + proxy name
-        cache_key = (id(real_module), proxy_name)
-
-        # Return the cached instance if one already exists
-        if cache_key in cls._proxy_cache[real_module]:
-            return cls._proxy_cache[real_module][cache_key]
-
-        # Create a new instance and cache it
-        instance = super().__new__(cls, proxy_name)
-        cls._proxy_cache[real_module][cache_key] = instance
-        return instance
-
-    def __init__(self, real_module, proxy_name):
-        """Initialize the proxy module."""
-        super().__init__(proxy_name)
-        self._real_module = real_module
-        self._proxy_name = proxy_name
-        self._submodule_proxies = {}
-
-        # Set the key metadata
-        self.__name__ = proxy_name
-        self.__package__ = proxy_name
-        self.__file__ = ""
-
-    def __getattr__(self, name):
-        """Fetch attributes dynamically, creating submodule proxies on demand."""
-        # 1. Try to fetch the attribute from the real module
+import importlib.abc
+import importlib.machinery
+from types import ModuleType
+
+class RedirectFinder(importlib.abc.MetaPathFinder):
+    def __init__(self, redirect_map):
+        # Redirect rules: proxied module -> actual module
+        self.redirect_map = redirect_map
+
+    def find_spec(self, fullname, path, target=None):
+        # Identify modules that need to be redirected
+        for proxy_prefix, target_prefix in self.redirect_map.items():
+            if fullname == proxy_prefix or fullname.startswith(proxy_prefix + "."):
+                # Compute the actual module name
+                target_name = fullname.replace(proxy_prefix, target_prefix, 1)
+
+                return importlib.machinery.ModuleSpec(
+                    name=fullname,
+                    loader=RedirectLoader(target_name),
+                    is_package=self._is_package(target_name),
+                )
+        return None
+
+    def _is_package(self, module_name):
+        # Check whether the module is a package (i.e. has submodules)
         try:
-            real_attr = getattr(self._real_module, name)
-        except AttributeError:
-            raise AttributeError(
-                f"module '{self._proxy_name}' has no attribute '{name}'"
-            )
+            module = importlib.import_module(module_name)
+            return hasattr(module, "__path__")
+        except ImportError:
+            return False
 
-        # 2. For module attributes, create a recursive proxy
-        if isinstance(real_attr, types.ModuleType):
-            # Build the submodule's proxy name
-            sub_proxy_name = f"{self._proxy_name}.{name}"
-            if name in self._submodule_proxies:
-                return self._submodule_proxies[name]
+class RedirectLoader(importlib.abc.Loader):
+    def __init__(self, target_name):
+        self.target_name = target_name
 
-            # Create (or fetch) the submodule proxy
-            proxy_sub = TorchProxyModule(
-                real_attr,
-                sub_proxy_name
-            )
+    def create_module(self, spec):
+        # Create the proxy module object
+        module = ModuleType(spec.name)
+        module.__spec__ = spec
+        module.__path__ = []
+        module.__loader__ = self
+        module.__package__ = spec.name
+        return module
 
-            self._submodule_proxies[name] = proxy_sub
-            # Cache the submodule proxy
-            self._submodule_proxies[name] = proxy_sub
-            # Register it in sys.modules
-            sys.modules[sub_proxy_name] = proxy_sub
-            # Register it on the parent module
-            setattr(self, name, proxy_sub)
-            return self._submodule_proxies[name]
-
-        # 4. Return anything else as-is
-        return real_attr
-
-    def __setattr__(self, name, value):
-        """Handle attribute assignment."""
-        # Internal attributes are set directly
-        if name in {"_real_module", "_proxy_name", "_submodule_proxies"}:
-            super().__setattr__(name, value)
-            return
-
-        # Everything else is set on the real module
-        if name not in self._submodule_proxies:
-            setattr(self._real_module, name, value)
-
-    def __dir__(self):
-        """Return the real module's attribute list."""
-        return dir(self._real_module)
-
-    def __repr__(self):
-        """Friendly representation of the proxy module."""
-        return f"<TorchProxyModule '{self._proxy_name}'>"
-
-    def __getattribute__(self, name):
-        """Special-case metadata-related attributes."""
-        if name == '__file__':
-            return ''
-        if name == '__package__':
-            return 'torch'
-        if name == '__spec__':
-            return self._create_mock_spec()
-        return super().__getattribute__(name)
+    def exec_module(self, module):
+        # Swap in __class__ dynamically to proxy attribute access
+        class ProxyModule(type(module)):
+            def __getattr__(_, name):
+                # Import the actual module and resolve the attribute dynamically
+                try:
+                    target_module = importlib.import_module(self.target_name)
+                except Exception as e:
+                    raise e
 
+                # Handle submodule imports (e.g. torch.nn -> mindnlp.core.nn)
+                if hasattr(target_module, name):
+                    return getattr(target_module, name)
 
-def initialize_torch_proxy():
-    import mindnlp
-    torch_proxy = TorchProxyModule(mindnlp.core, 'torch')
-    sys.modules["torch"] = torch_proxy
+                # Handle from-imports out of submodules (e.g. from torch.nn import Module)
+                try:
+                    submodule_name = f"{self.target_name}.{name}"
+                    return importlib.import_module(submodule_name)
+                except ImportError:
+                    raise AttributeError(
+                        f"Module '{self.target_name}' has no attribute '{name}'"
+                    )
+
+            def __setattr__(_, name, value):
+                target_module = importlib.import_module(self.target_name)
+                if not hasattr(target_module, name):
+                    return
+                return super().__setattr__(name, value)
+
+        # Inherit the original module's special attributes
+        module.__class__ = ProxyModule
 
-    # Set the required metadata
-    torch_proxy.__version__ = "2.1.1+dev"
-    return torch_proxy
+# Configure the redirect rules
+REDIRECT_MAP = {
+    "torch": "mindnlp.core",
+}
+
+def initialize_torch_proxy():
+    sys.meta_path.insert(0, RedirectFinder(REDIRECT_MAP))
+    import torch
+    torch.__version__ = "2.1.1+dev"
+
 def setup_metadata_patch():
     """Work around importlib.metadata not finding torch."""
     # Save the original functions
     orig_distribution = importlib.metadata.distribution
     orig_distributions = importlib.metadata.distributions
-
+
     # Intercept queries for the torch distribution
     def patched_distribution(dist_name):
         if dist_name == "torch":
             return types.SimpleNamespace(
                 version="2.1.1+dev",
                 metadata={"Name": "torch", "Version": "2.1.1+dev"},
-                read_text=lambda f: f"Name: torch\nVersion: 2.1.1+dev" if f == "METADATA" else None
+                read_text=lambda f: (
+                    f"Name: torch\nVersion: 2.1.1+dev" if f == "METADATA" else None
+                ),
             )
         return orig_distribution(dist_name)
-
+
     # Make sure torch shows up in the distribution list
     def patched_distributions(**kwargs):
         dists = list(orig_distributions(**kwargs))
-        dists.append(types.SimpleNamespace(
-            name="torch",
-            version="2.1.1+dev",
-            metadata={"Name": "torch", "Version": "2.1.1+dev"},
-            files=[],
-            locate_file=lambda p: None,
-            _normalized_name='torch',
-            entry_points=[]
-        ))
+        dists.append(
+            types.SimpleNamespace(
+                name="torch",
+                version="2.1.1+dev",
+                metadata={"Name": "torch", "Version": "2.1.1+dev"},
+                files=[],
+                locate_file=lambda p: None,
+                _normalized_name="torch",
+                entry_points=[],
+            )
+        )
         return dists
-
+
     # Apply the patches
     importlib.metadata.distribution = patched_distribution
     importlib.metadata.distributions = patched_distributions
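With the meta-path finder installed, any import of `torch` or one of its submodules is resolved on demand through `RedirectFinder`, instead of eagerly registering proxy objects in `sys.modules` as the removed `TorchProxyModule` did. A usage sketch (assuming an environment where mindnlp is installed and `mindnlp.core.nn` is importable, which the `__init__.py` change above provides):

    import importlib.metadata
    from mindnlp.utils.torch_proxy import initialize_torch_proxy, setup_metadata_patch

    initialize_torch_proxy()    # installs RedirectFinder at the front of sys.meta_path
    setup_metadata_patch()      # makes importlib.metadata report a "torch" distribution

    import torch                # resolved to mindnlp.core by the finder
    from torch import nn        # attribute access forwarded by ProxyModule.__getattr__
    from torch.nn import Module # from-imports fall back to importlib.import_module

    print(torch.__version__)                     # "2.1.1+dev", set by initialize_torch_proxy
    print(importlib.metadata.version("torch"))   # "2.1.1+dev", from the patched metadata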