Skip to content

Commit

Permalink
[Projects] Adding allow_cross_project flag for creating project fro…
Browse files Browse the repository at this point in the history
…m template (#5453)
  • Loading branch information
roei3000b committed Apr 30, 2024
1 parent be38973 commit 64a306a
Show file tree
Hide file tree
Showing 21 changed files with 388 additions and 101 deletions.
20 changes: 19 additions & 1 deletion docs/projects/project-setup.md
Original file line number Diff line number Diff line change
Expand Up @@ -215,4 +215,22 @@ train_function = project.set_function(
)
train_function.with_limits(gpus=gpus, cpu=cpu, mem=mem)
train_function.save()
```
```

### Loading a project from a template
You can load a project from a template (an existing project yaml) under a different name only if you do one of the following:
1. Set `allow_cross_project=True` and pass the new project name.
2. Change the name in the project yaml file, or delete the file.
3. Use a different context directory.

```python
import mlrun

project = mlrun.load_project(
name="my-project",
context="./src", # assuming here there is a project.yaml with name that is not my-project
allow_cross_project=True,
)
```

**Note:** This also applies to the `get_or_create_project` function.
66 changes: 48 additions & 18 deletions mlrun/projects/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,14 +207,16 @@ def setup(project):
"Unsupported option, cannot use subpath argument with project templates"
)
if from_template.endswith(".yaml"):
project = _load_project_file(from_template, name, secrets)
project = _load_project_file(
from_template, name, secrets, allow_cross_project=True
)
elif from_template.startswith("git://"):
clone_git(from_template, context, secrets, clone=True)
shutil.rmtree(path.join(context, ".git"))
project = _load_project_dir(context, name)
project = _load_project_dir(context, name, allow_cross_project=True)
elif from_template.endswith(".zip"):
clone_zip(from_template, context, secrets)
project = _load_project_dir(context, name)
project = _load_project_dir(context, name, allow_cross_project=True)
else:
raise ValueError("template must be a path to .yaml or .zip file")
project.metadata.name = name
Expand Down Expand Up @@ -296,6 +298,7 @@ def load_project(
save: bool = True,
sync_functions: bool = False,
parameters: dict = None,
allow_cross_project: bool = False,
) -> "MlrunProject":
"""Load an MLRun project from git or tar or dir
Expand Down Expand Up @@ -342,6 +345,8 @@ def setup(project):
:param save: whether to save the created project and artifact in the DB
:param sync_functions: sync the project's functions into the project object (will be saved to the DB if save=True)
:param parameters: key/value pairs to add to the project.spec.params
:param allow_cross_project: if True, allow loading a project yaml whose name differs from ``name`` and override
it with ``name``. Requiring the flag ensures the caller is aware that an existing project yaml is being used
as the baseline for a new project with a different name
:returns: project object
"""
Expand All @@ -357,7 +362,7 @@ def setup(project):
if url:
url = str(url) # to support path objects
if is_yaml_path(url):
project = _load_project_file(url, name, secrets)
project = _load_project_file(url, name, secrets, allow_cross_project)
project.spec.context = context
elif url.startswith("git://"):
url, repo = clone_git(url, context, secrets, clone)
Expand All @@ -384,7 +389,7 @@ def setup(project):
repo, url = init_repo(context, url, init_git)

if not project:
project = _load_project_dir(context, name, subpath)
project = _load_project_dir(context, name, subpath, allow_cross_project)

if not project.metadata.name:
raise ValueError("Project name must be specified")
Expand Down Expand Up @@ -438,6 +443,7 @@ def get_or_create_project(
from_template: str = None,
save: bool = True,
parameters: dict = None,
allow_cross_project: bool = False,
) -> "MlrunProject":
"""Load a project from MLRun DB, or create/import if it does not exist
Expand Down Expand Up @@ -482,12 +488,12 @@ def setup(project):
:param from_template: path to project YAML file that will be used as from_template (for new projects)
:param save: whether to save the created project in the DB
:param parameters: key/value pairs to add to the project.spec.params
:param allow_cross_project: if True, allow loading a project yaml whose name differs from ``name`` and override
it with ``name``. Requiring the flag ensures the caller is aware that an existing project yaml is being used
as the baseline for a new project with a different name
:returns: project object
"""
context = context or "./"
spec_path = path.join(context, subpath or "", "project.yaml")
load_from_path = url or path.isfile(spec_path)
try:
# load project from the DB.
# use `name` as `url` as we load the project from the DB
Expand All @@ -503,13 +509,15 @@ def setup(project):
# only loading project from db so no need to save it
save=False,
parameters=parameters,
allow_cross_project=allow_cross_project,
)
logger.info("Project loaded successfully", project_name=name)
return project

except mlrun.errors.MLRunNotFoundError:
logger.debug("Project not found in db", project_name=name)

spec_path = path.join(context, subpath or "", "project.yaml")
load_from_path = url or path.isfile(spec_path)
# do not nest under "try" or else the exceptions raised below will be logged along with the "not found" message
if load_from_path:
# loads a project from archive or local project.yaml
Expand All @@ -525,6 +533,7 @@ def setup(project):
user_project=user_project,
save=save,
parameters=parameters,
allow_cross_project=allow_cross_project,
)

logger.info(
Expand Down Expand Up @@ -599,7 +608,7 @@ def setup(project):
return project


def _load_project_dir(context, name="", subpath=""):
def _load_project_dir(context, name="", subpath="", allow_cross_project=False):
subpath_str = subpath or ""

# support both .yaml and .yml file extensions
Expand All @@ -613,7 +622,7 @@ def _load_project_dir(context, name="", subpath=""):
with open(project_file_path) as fp:
data = fp.read()
struct = yaml.load(data, Loader=yaml.FullLoader)
project = _project_instance_from_struct(struct, name)
project = _project_instance_from_struct(struct, name, allow_cross_project)
project.spec.context = context
elif function_files := glob.glob(function_file_path):
function_path = function_files[0]
Expand Down Expand Up @@ -686,19 +695,32 @@ def _delete_project_from_db(project_name, secrets, deletion_strategy):
return db.delete_project(project_name, deletion_strategy=deletion_strategy)


def _load_project_file(url, name="", secrets=None):
def _load_project_file(url, name="", secrets=None, allow_cross_project=False):
    """Load a project object from a project yaml file located at ``url``.

    :param url:                 location of the project yaml; it is read with ``get_object``
    :param name:                project name to apply; when empty, the name found in the yaml is kept
    :param secrets:             secrets handed to ``get_object`` when reading ``url``
    :param allow_cross_project: forwarded to ``_project_instance_from_struct``; controls whether ``name``
                                may override a different name found in the yaml

    :returns: the project object built by ``_project_instance_from_struct``
    :raises FileNotFoundError: when ``get_object`` cannot find anything at ``url``
    """
    try:
        obj = get_object(url, secrets)
    except FileNotFoundError as exc:
        raise FileNotFoundError(f"cant find project file at {url}") from exc
    # NOTE(review): if `url` can point at content that is not trusted, consider yaml.safe_load here -
    # PyYAML recommends the safe loader for untrusted input
    struct = yaml.load(obj, Loader=yaml.FullLoader)
    return _project_instance_from_struct(struct, name, allow_cross_project)


def _project_instance_from_struct(struct, name):
struct.setdefault("metadata", {})["name"] = name or struct.get("metadata", {}).get(
"name", ""
)
def _project_instance_from_struct(struct, name, allow_cross_project):
name_from_struct = struct.get("metadata", {}).get("name", "")
if name and name_from_struct and name_from_struct != name:
if allow_cross_project:
logger.warn(
"Project name is different than specified on its project yaml. Overriding.",
existing_name=name_from_struct,
overriding_name=name,
)
else:
raise ValueError(
f"project name mismatch, {name_from_struct} != {name}, please do one of the following:\n"
"1. Set the `allow_cross_project=True` when loading the project.\n"
f"2. Delete the existing project yaml, or ensure its name is equal to {name}.\n"
"3. Use different project context dir."
)
struct.setdefault("metadata", {})["name"] = name or name_from_struct
return MlrunProject.from_dict(struct)


Expand Down Expand Up @@ -1814,10 +1836,18 @@ def reload(self, sync=False, context=None) -> "MlrunProject":
"""
context = context or self.spec.context
if context:
project = _load_project_dir(context, self.metadata.name, self.spec.subpath)
project = _load_project_dir(
context,
self.metadata.name,
self.spec.subpath,
allow_cross_project=False,
)
else:
project = _load_project_file(
self.spec.origin_url, self.metadata.name, self._secrets
self.spec.origin_url,
self.metadata.name,
self._secrets,
allow_cross_project=False,
)
project.spec.source = self.spec.source
project.spec.repo = self.spec.repo
Expand Down
6 changes: 4 additions & 2 deletions tests/common_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def aioresponses_mock():

@pytest.fixture
def ensure_default_project() -> mlrun.projects.project.MlrunProject:
return mlrun.get_or_create_project("default")
return mlrun.get_or_create_project("default", allow_cross_project=True)


@pytest.fixture()
Expand Down Expand Up @@ -599,7 +599,9 @@ def rundb_mock() -> RunDBMock:

# Create the default project to mimic real MLRun DB (the default project is always available for use):
with tempfile.TemporaryDirectory() as tmp_dir:
mlrun.get_or_create_project("default", context=tmp_dir)
mlrun.get_or_create_project(
"default", context=tmp_dir, allow_cross_project=True
)

yield mock_object

Expand Down
4 changes: 2 additions & 2 deletions tests/integration/sdk_api/artifacts/test_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def setup_method(self, method, extra_env=None):
def test_artifacts(self):
db = mlrun.get_run_db()
prj, tree, key, body = "p9", "t19", "k802", "tomato"
mlrun.get_or_create_project(prj, "./")
mlrun.get_or_create_project(prj, "./", allow_cross_project=True)
artifact = mlrun.artifacts.Artifact(key, body, target_path="/a.txt")

db.store_artifact(key, artifact, tree=tree, project=prj)
Expand All @@ -60,7 +60,7 @@ def test_artifacts(self):

def test_list_artifacts_filter_by_kind(self):
prj, tree, key, body = "p9", "t19", "k802", "tomato"
mlrun.get_or_create_project(prj, context="./")
mlrun.get_or_create_project(prj, context="./", allow_cross_project=True)
model_artifact = mlrun.artifacts.model.ModelArtifact(
key, body, target_path="/a.txt"
)
Expand Down
4 changes: 3 additions & 1 deletion tests/integration/sdk_api/httpdb/test_exception_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ def test_exception_handling(self):
handlers to be triggered and verifies that for all of them the actual error details are returned in the response and
that the client successfully parses them and raises the right error class
"""
mlrun.get_or_create_project("some-project", context="./")
mlrun.get_or_create_project(
"some-project", context="./", allow_cross_project=True
)
# log_and_raise - mlrun code uses log_and_raise (common) which raises fastapi.HTTPException because we're
# sending a store artifact request with an invalid json body
# This practically verifies that log_and_raise puts the kwargs under the details
Expand Down
34 changes: 25 additions & 9 deletions tests/integration/sdk_api/projects/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ def test_sync_functions(self):
project_function_object = project.spec._function_objects
project_file_path = pathlib.Path(tests.conftest.results) / "project.yaml"
project.export(str(project_file_path))
imported_project = mlrun.load_project("./", str(project_file_path))
imported_project = mlrun.load_project(
"./", str(project_file_path), allow_cross_project=True
)
assert imported_project.spec._function_objects == {}
imported_project.sync_functions()
_assert_project_function_objects(imported_project, project_function_object)
Expand Down Expand Up @@ -201,7 +203,7 @@ def test_overwrite_project_failure(self):
def test_load_project_from_db(self):
project_name = "some-project"
mlrun.new_project(project_name)
mlrun.load_project(".", f"db://{project_name}")
mlrun.load_project(".", f"db://{project_name}", allow_cross_project=True)

def test_load_project_with_save(self):
project_name = "some-project"
Expand All @@ -212,30 +214,42 @@ def test_load_project_with_save(self):
imported_project_name = "imported-project"
# load the project without saving it
mlrun.load_project(
"./", str(project_file_path), name=imported_project_name, save=False
"./",
str(project_file_path),
name=imported_project_name,
save=False,
allow_cross_project=True,
)

# loading the project from the db is expected to fail, since the earlier load was not saved
with pytest.raises(mlrun.errors.MLRunNotFoundError):
mlrun.load_project(".", f"db://{imported_project_name}", save=False)
mlrun.load_project(
".",
f"db://{imported_project_name}",
save=False,
allow_cross_project=True,
)

# loading project and saving
expected_project = mlrun.load_project(
"./", str(project_file_path), name=imported_project_name
"./",
str(project_file_path),
name=imported_project_name,
allow_cross_project=True,
)

# loading project from db, expected to succeed
loaded_project_from_db = mlrun.load_project(
".", f"db://{imported_project_name}", save=False
".", f"db://{imported_project_name}", save=False, allow_cross_project=True
)
_assert_projects(expected_project, loaded_project_from_db)

def test_get_project(self):
project_name = "some-project"
# create an empty project
mlrun.get_or_create_project(project_name)
mlrun.get_or_create_project(project_name, allow_cross_project=True)
# get it from the db
project = mlrun.get_or_create_project(project_name)
project = mlrun.get_or_create_project(project_name, allow_cross_project=True)

# verify default values
assert project.metadata.name == project_name
Expand All @@ -250,7 +264,9 @@ def test_get_project(self):
def test_set_project_secrets(self):
# A basic test verifying that we can access (mocked) project-secrets functionality in integration tests.
project_name = "some-project"
project_object = mlrun.get_or_create_project(project_name)
project_object = mlrun.get_or_create_project(
project_name, allow_cross_project=True
)

secrets = {"secret1": "value1", "secret2": "value2"}
project_object.set_secrets(secrets)
Expand Down
10 changes: 6 additions & 4 deletions tests/integration/sdk_api/run/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class TestMain(tests.integration.sdk_api.base.TestMLRunIntegration):

def custom_setup(self):
# ensure default project exists
mlrun.get_or_create_project("default")
mlrun.get_or_create_project("default", allow_cross_project=True)

def test_main_run_basic(self):
out = self._exec_run(
Expand Down Expand Up @@ -368,7 +368,7 @@ def test_main_run_archive_subdir(self):
assert out.find("state: completed") != -1, out

def test_main_local_project(self):
mlrun.get_or_create_project("testproject")
mlrun.get_or_create_project("testproject", allow_cross_project=True)
project_path = str(self.assets_path)
args = "-f simple -p x=2 --dump"
out = self._exec_main("run", args.split(), cwd=project_path)
Expand Down Expand Up @@ -432,7 +432,7 @@ def test_main_env_file(self):

def test_main_run_function_from_another_project(self):
# test running function from another project and validate that the function is stored in the current project
project = mlrun.get_or_create_project("first-project")
project = mlrun.get_or_create_project("first-project", allow_cross_project=True)

fn = mlrun.code_to_function(
name="new-func",
Expand All @@ -444,7 +444,9 @@ def test_main_run_function_from_another_project(self):
fn.save()

# create another project
project2 = mlrun.get_or_create_project("second-project")
project2 = mlrun.get_or_create_project(
"second-project", allow_cross_project=True
)

# from the second project - run the function that we stored in the first project
args = (
Expand Down
2 changes: 1 addition & 1 deletion tests/package/test_context_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def test_custom_packagers(
:param is_mandatory: If the packager is mandatory for the run or not. Mandatory packagers will always raise
exception if they couldn't be collected.
"""
project = mlrun.get_or_create_project(name="default")
project = mlrun.get_or_create_project(name="default", allow_cross_project=True)
project.add_custom_packager(
packager=packager,
is_mandatory=is_mandatory,
Expand Down
4 changes: 3 additions & 1 deletion tests/package/test_packagers.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ def _setup_test(
mlrun.mlconf.packagers.pack_tuples = True

# Create a project for this tester:
project = mlrun.get_or_create_project(name="default", context=test_directory)
project = mlrun.get_or_create_project(
name="default", context=test_directory, allow_cross_project=True
)

# Create a MLRun function using the tester source file (all the functions must be located in it):
return project.set_function(
Expand Down

0 comments on commit 64a306a

Please sign in to comment.