diff --git a/mindnlp/utils/torch_proxy.py b/mindnlp/utils/torch_proxy.py
index 20a1413f0..8c0c6ec32 100644
--- a/mindnlp/utils/torch_proxy.py
+++ b/mindnlp/utils/torch_proxy.py
@@ -17,6 +17,10 @@ def find_spec(self, fullname, path, target=None):
         if fullname == proxy_prefix or fullname.startswith(proxy_prefix + "."):
             # Compute the actual target module name
             target_name = fullname.replace(proxy_prefix, target_prefix, 1)
+            try:
+                importlib.import_module(target_name)
+            except Exception as e:
+                raise e
 
             return importlib.machinery.ModuleSpec(
                 name=fullname,
@@ -54,8 +58,11 @@ def __getattr__(_, name):
             # Dynamically import the attribute from the actual module
             try:
                 target_module = importlib.import_module(self.target_name)
+            except ImportError as e:
+                raise AttributeError(f"Target module {self.target_name} could not be imported: {e}") from e
             except Exception as e:
                 raise e
+
             # Handle submodule imports (e.g. torch.nn -> mindnlp.core.nn)
             if hasattr(target_module, name):
                 return getattr(target_module, name)
@@ -64,15 +71,18 @@ def __getattr__(_, name):
             try:
                 submodule_name = f"{self.target_name}.{name}"
                 return importlib.import_module(submodule_name)
-            except ImportError:
+            except ImportError as e:
                 raise AttributeError(
                     f"Module '{self.target_name}' has no attribute '{name}'"
                 )
 
         def __setattr__(_, name, value):
-            target_module = importlib.import_module(self.target_name)
-            if not hasattr(target_module, name):
-                return
+            try:
+                target_module = importlib.import_module(self.target_name)
+                if not hasattr(target_module, name):
+                    return
+            except Exception as e:
+                raise e
             return super().__setattr__(name, value)
 
         # Inherit the original module's special attributes
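
The first hunk eagerly imports the resolved target module inside find_spec, so a broken target package fails at `import torch` time rather than at the first attribute access; the later hunks tighten error handling in the proxy module's __getattr__ and __setattr__ fallbacks. For reference, here is a minimal standalone sketch (not the mindnlp implementation; TARGET and proxy_getattr are illustrative names) of the lookup order __getattr__ implements: try the attribute on the target module first, then treat the name as a submodule, and only then raise AttributeError; the sketch chains the underlying ImportError so the root cause stays visible in the traceback.

```python
# Minimal sketch of the proxy __getattr__ fallback, using the stdlib
# "collections" package as a stand-in target (illustrative, not mindnlp code).
import importlib

TARGET = "collections"  # stand-in for the real target package (e.g. mindnlp.core)

def proxy_getattr(name):
    target_module = importlib.import_module(TARGET)
    # 1) Plain attribute on the target module (e.g. torch.tensor -> core.tensor).
    if hasattr(target_module, name):
        return getattr(target_module, name)
    # 2) Fall back to treating the name as a submodule (e.g. torch.nn -> core.nn).
    try:
        return importlib.import_module(f"{TARGET}.{name}")
    except ImportError as e:
        # Chaining the ImportError keeps the original failure in the traceback.
        raise AttributeError(f"Module '{TARGET}' has no attribute '{name}'") from e

print(proxy_getattr("OrderedDict"))  # attribute hit on the target module
print(proxy_getattr("abc"))          # submodule hit (collections.abc)
# proxy_getattr("missing")           # would raise AttributeError, chained from ImportError
```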