diff --git a/src/neptune/new/handler.py b/src/neptune/new/handler.py index 85816f5cd..d32b79351 100644 --- a/src/neptune/new/handler.py +++ b/src/neptune/new/handler.py @@ -24,6 +24,7 @@ Iterable, Iterator, List, + NewType, Optional, Union, ) @@ -69,6 +70,8 @@ if TYPE_CHECKING: from neptune.new.metadata_containers import MetadataContainer + NeptuneObject = NewType("NeptuneObject", MetadataContainer) + def validate_path_not_protected(target_path: str, handler: "Handler"): path_protection_exception = handler._PROTECTED_PATHS.get(target_path) @@ -94,7 +97,7 @@ class Handler: SYSTEM_STAGE_ATTRIBUTE_PATH: NeptuneCannotChangeStageManually, } - def __init__(self, container: "MetadataContainer", path: str): + def __init__(self, container: "NeptuneObject", path: str): super().__init__() self._container = container self._path = path @@ -113,6 +116,24 @@ def __getitem__(self, path: str) -> "Handler": def __setitem__(self, key: str, value) -> None: self[key].assign(value) + def __getattr__(self, item: str): + run_level_methods = {"exists", "get_structure", "get_run_url", "print_structure", "stop", "sync", "wait"} + + if item in run_level_methods: + raise AttributeError( + "You're invoking an object-level method on a handler for a namespace" "inside the object.", + f""" + For example: You're trying run[{self._path}].{item}() + but you probably want run.{item}(). + + To obtain the root object of the namespace handler, you can do: + root_run = run[{self._path}].get_root_object() + root_run.{item}() + """, + ) + + return object.__getattribute__(self, item) + def _get_attribute(self): """Returns Attribute defined in `self._path` or throws MissingFieldException""" attr = self._container.get_attribute(self._path) @@ -125,6 +146,9 @@ def container(self) -> "MetadataContainer": """Returns the container that the attribute is attached to""" return self._container + def get_root_object(self) -> "NeptuneObject": + return self._container + @check_protected_paths def assign(self, value, wait: bool = False) -> None: """Assigns the provided value to the field. @@ -300,6 +324,7 @@ def log( self._container.set_attribute(self._path, attr) attr.log(value, step=step, timestamp=timestamp, wait=wait, **kwargs) + @check_protected_paths def append( self, value: Union[dict, Any], @@ -347,6 +372,7 @@ def append( value = ExtendUtils.validate_and_transform_to_extend_format(value) self.extend(value, step, timestamp, wait, **kwargs) + @check_protected_paths def extend( self, values: ExtendDictT,