Skip to content

Commit

Permalink
[Projects] Fix second create_remote not overriding current remote (#5245
Browse files Browse the repository at this point in the history
)
  • Loading branch information
yaelgen committed Mar 18, 2024
1 parent daae2b7 commit 7218958
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 1 deletion.
52 changes: 51 additions & 1 deletion mlrun/projects/project.py
Expand Up @@ -2403,13 +2403,47 @@ def pull(
clone_zip(url, self.spec.context, self._secrets)

def create_remote(self, url, name="origin", branch=None):
"""create remote for the project git
"""Create remote for the project git
This method creates a new remote repository associated with the project's Git repository.
If a remote with the specified name already exists, it will not be overwritten.
If you wish to update the URL of an existing remote, use the `set_remote` method instead.
:param url: remote git url
:param name: name for the remote (default is 'origin')
:param branch: Git branch to use as source
"""
self.set_remote(url, name=name, branch=branch, overwrite=False)

def set_remote(self, url, name="origin", branch=None, overwrite=True):
"""Create or update a remote for the project git repository.
This method allows you to manage remote repositories associated with the project.
It checks if a remote with the specified name already exists.
If a remote with the same name does not exist, it will be created.
If a remote with the same name already exists,
the behavior depends on the value of the 'overwrite' flag.
:param url: remote git url
:param name: name for the remote (default is 'origin')
:param branch: Git branch to use as source
:param overwrite: if True (default), updates the existing remote with the given URL if it already exists.
if False, raises an error when attempting to create a remote with a name that already exists.
:raises MLRunConflictError: If a remote with the same name already exists and overwrite
is set to False.
"""
self._ensure_git_repo()
if self._remote_exists(name):
if overwrite:
self.spec.repo.delete_remote(name)
else:
raise mlrun.errors.MLRunConflictError(
f"Remote '{name}' already exists in the project, "
f"each remote in the project must have a unique name."
"Use 'set_remote' with 'override=True' inorder to update the remote, or choose a different name."
)
self.spec.repo.create_remote(name, url=url)
url = url.replace("https://", "git://")
if not branch:
Expand All @@ -2422,6 +2456,22 @@ def create_remote(self, url, name="origin", branch=None):
self.spec._source = self.spec.source or url
self.spec.origin_url = self.spec.origin_url or url

def remove_remote(self, name):
"""Remove a remote from the project's Git repository.
This method removes the remote repository associated with the specified name from the project's Git repository.
:param name: Name of the remote to remove.
"""
if self._remote_exists(name):
self.spec.repo.delete_remote(name)
else:
logger.warning(f"The remote '{name}' does not exist. Nothing to remove.")

def _remote_exists(self, name):
"""Check if a remote with the given name already exists"""
return any(remote.name == name for remote in self.spec.repo.remotes)

def _ensure_git_repo(self):
if self.spec.repo:
return
Expand Down
118 changes: 118 additions & 0 deletions tests/projects/test_project.py
Expand Up @@ -1739,6 +1739,124 @@ def test_project_create_remote():
assert "mlrun-remote" in [remote.name for remote in project.spec.repo.remotes]


@pytest.mark.parametrize(
"url,set_url,name,set_name,overwrite,expected_url,expected",
[
# Remote doesn't exist, create normally
(
"https://github.com/mlrun/some-git-repo.git",
"https://github.com/mlrun/some-other-git-repo.git",
"mlrun-remote",
"mlrun-another-remote",
False,
"https://github.com/mlrun/some-other-git-repo.git",
does_not_raise(),
),
# Remote exists, overwrite False, raise MLRunConflictError
(
"https://github.com/mlrun/some-git-repo.git",
"https://github.com/mlrun/some-git-other-repo.git",
"mlrun-remote",
"mlrun-remote",
False,
"https://github.com/mlrun/some-git-repo.git",
pytest.raises(mlrun.errors.MLRunConflictError),
),
# Remote exists, overwrite True, update remote
(
"https://github.com/mlrun/some-git-repo.git",
"https://github.com/mlrun/some-git-other-repo.git",
"mlrun-remote",
"mlrun-remote",
True,
"https://github.com/mlrun/some-git-other-repo.git",
does_not_raise(),
),
],
)
def test_set_remote_as_update(
url, set_url, name, set_name, overwrite, expected_url, expected
):
with tempfile.TemporaryDirectory() as tmp_dir:
# create a project
project_name = "project-name"
project = mlrun.get_or_create_project(project_name, context=tmp_dir)

project.create_remote(
url=url,
name=name,
)
with expected:
project.set_remote(
url=set_url,
name=set_name,
overwrite=overwrite,
)

if name != set_name:
assert project.spec.repo.remote(name).url == url
assert project.spec.repo.remote(set_name).url == expected_url


@pytest.mark.parametrize(
"url,name,expected",
[
# Remote doesn't exist, create normally
(
"https://github.com/mlrun/some-other-git-repo.git",
"mlrun-remote2",
does_not_raise(),
),
# Remote exists, raise MLRunConflictError
(
"https://github.com/mlrun/some-git-repo.git",
"mlrun-remote",
pytest.raises(mlrun.errors.MLRunConflictError),
),
],
)
def test_create_remote(url, name, expected):
with tempfile.TemporaryDirectory() as tmp_dir:
# create a project
project_name = "project-name"
project = mlrun.get_or_create_project(project_name, context=tmp_dir)

project.create_remote(
url="https://github.com/mlrun/some-git-repo.git",
name="mlrun-remote",
)

with expected:
project.create_remote(
url=url,
name=name,
)
assert project.spec.repo.remote(name).url == url


@pytest.mark.parametrize(
"name",
[
# Remote exists
"mlrun-remote",
# Remote doesn't exist
"non-existent-remote",
],
)
def test_remove_remote(name):
with tempfile.TemporaryDirectory() as tmp_dir:
# create a project
project_name = "project-name"
project = mlrun.get_or_create_project(project_name, context=tmp_dir)

project.create_remote(
url="https://github.com/mlrun/some-git-repo.git",
name="mlrun-remote",
)
project.remove_remote(name)
assert name not in project.spec.repo.remotes


@pytest.mark.parametrize(
"source_url, pull_at_runtime, base_image, image_name, target_dir",
[
Expand Down

0 comments on commit 7218958

Please sign in to comment.