From 80f815ee20ee35dec99cd4872fdceb3650b7df01 Mon Sep 17 00:00:00 2001 From: Iliyas Jorio Date: Mon, 4 Aug 2025 17:14:58 +0200 Subject: [PATCH 1/2] Add RemoteCallbacks.push_negotiation --- pygit2/callbacks.py | 25 ++++++++++++++++++++- pygit2/decl/callbacks.h | 5 +++++ pygit2/remotes.py | 28 ++++++++++++++++++++++++ test/test_remote.py | 48 ++++++++++++++++++++++++++++++++++++++--- 4 files changed, 102 insertions(+), 4 deletions(-) diff --git a/pygit2/callbacks.py b/pygit2/callbacks.py index 87914285..3261ce60 100644 --- a/pygit2/callbacks.py +++ b/pygit2/callbacks.py @@ -81,7 +81,7 @@ from pygit2._libgit2.ffi import GitProxyOptionsC from ._pygit2 import CloneOptions, PushOptions - from .remotes import TransferProgress + from .remotes import PushUpdate, TransferProgress # # The payload is the way to pass information from the pygit2 API, through # libgit2, to the Python callbacks. And back. @@ -198,6 +198,15 @@ def certificate_check(self, certificate: None, valid: bool, host: bytes) -> bool raise Passthrough + def push_negotiation(self, updates: list['PushUpdate']) -> None: + """ + During a push, called once between the negotiation step and the upload. + Provides information about what updates will be performed. + + Override with your own function to check the pending updates + and possibly reject them (by raising an exception). + """ + def transfer_progress(self, stats: 'TransferProgress') -> None: """ During the download of new data, this will be regularly called with @@ -427,6 +436,7 @@ def git_push_options(payload, opts=None): opts.callbacks.credentials = C._credentials_cb opts.callbacks.certificate_check = C._certificate_check_cb opts.callbacks.push_update_reference = C._push_update_reference_cb + opts.callbacks.push_negotiation = C._push_negotiation_cb # Per libgit2 sources, push_transfer_progress may incur a performance hit. # So, set it only if the user has overridden the no-op stub. if ( @@ -559,6 +569,19 @@ def _credentials_cb(cred_out, url, username, allowed, data): return 0 +@libgit2_callback +def _push_negotiation_cb(updates, num_updates, data): + from .remotes import PushUpdate + + push_negotiation = getattr(data, 'push_negotiation', None) + if not push_negotiation: + return 0 + + py_updates = [PushUpdate(updates[i]) for i in range(num_updates)] + push_negotiation(py_updates) + return 0 + + @libgit2_callback def _push_update_reference_cb(ref, msg, data): push_update_reference = getattr(data, 'push_update_reference', None) diff --git a/pygit2/decl/callbacks.h b/pygit2/decl/callbacks.h index 9d5409de..64582718 100644 --- a/pygit2/decl/callbacks.h +++ b/pygit2/decl/callbacks.h @@ -16,6 +16,11 @@ extern "Python" int _push_update_reference_cb( const char *status, void *data); +extern "Python" int _push_negotiation_cb( + const git_push_update **updates, + size_t len, + void *data); + extern "Python" int _remote_create_cb( git_remote **out, git_repository *repo, diff --git a/pygit2/remotes.py b/pygit2/remotes.py index 0c7d3c98..d4ddbc55 100644 --- a/pygit2/remotes.py +++ b/pygit2/remotes.py @@ -58,6 +58,34 @@ class LsRemotesDict(TypedDict): oid: Oid +class PushUpdate: + """ + Represents an update which will be performed on the remote during push. + """ + + src_refname: str + """The source name of the reference""" + + dst_refname: str + """The name of the reference to update on the server""" + + src: Oid + """The current target of the reference""" + + dst: Oid + """The new target for the reference""" + + def __init__(self, c_struct: Any) -> None: + src_refname = maybe_string(c_struct.src_refname) + dst_refname = maybe_string(c_struct.dst_refname) + assert src_refname is not None, 'libgit2 returned null src_refname' + assert dst_refname is not None, 'libgit2 returned null dst_refname' + self.src_refname = src_refname + self.dst_refname = dst_refname + self.src = Oid(raw=bytes(ffi.buffer(c_struct.src.id)[:])) + self.dst = Oid(raw=bytes(ffi.buffer(c_struct.dst.id)[:])) + + class TransferProgress: """Progress downloading and indexing data during a fetch.""" diff --git a/test/test_remote.py b/test/test_remote.py index 67a6a34a..09315ee4 100644 --- a/test/test_remote.py +++ b/test/test_remote.py @@ -31,7 +31,7 @@ import pygit2 from pygit2 import Remote, Repository -from pygit2.remotes import TransferProgress +from pygit2.remotes import PushUpdate, TransferProgress from . import utils @@ -406,9 +406,12 @@ def push_transfer_progress( assert origin.branches['master'].target == new_tip_id +@pytest.mark.parametrize('reject_from', ['push_transfer_progress', 'push_negotiation']) def test_push_interrupted_from_callbacks( - origin: Repository, clone: Repository, remote: Remote + origin: Repository, clone: Repository, remote: Remote, reject_from: str ) -> None: + reject_message = 'retreat! retreat!' + tip = clone[clone.head.target] clone.create_commit( 'refs/heads/master', @@ -420,10 +423,15 @@ def test_push_interrupted_from_callbacks( ) class MyCallbacks(pygit2.RemoteCallbacks): + def push_negotiation(self, updates: list[PushUpdate]) -> None: + if reject_from == 'push_negotiation': + raise InterruptedError(reject_message) + def push_transfer_progress( self, objects_pushed: int, total_objects: int, bytes_pushed: int ) -> None: - raise InterruptedError('retreat! retreat!') + if reject_from == 'push_transfer_progress': + raise InterruptedError(reject_message) assert origin.branches['master'].target == tip.id @@ -504,3 +512,37 @@ def test_push_threads(origin: Repository, clone: Repository, remote: Remote) -> callbacks = RemoteCallbacks() remote.push(['refs/heads/master'], callbacks, threads=1) assert callbacks.push_options.pb_parallelism == 1 + + +def test_push_negotiation( + origin: Repository, clone: Repository, remote: Remote +) -> None: + old_tip = clone[clone.head.target] + new_tip_id = clone.create_commit( + 'refs/heads/master', + old_tip.author, + old_tip.author, + 'empty commit', + old_tip.tree.id, + [old_tip.id], + ) + + the_updates: list[PushUpdate] = [] + + class MyCallbacks(pygit2.RemoteCallbacks): + def push_negotiation(self, updates: list[PushUpdate]) -> None: + the_updates.extend(updates) + + assert origin.branches['master'].target == old_tip.id + assert 'new_branch' not in origin.branches + + callbacks = MyCallbacks() + remote.push(['refs/heads/master'], callbacks=callbacks) + + assert len(the_updates) == 1 + assert the_updates[0].src_refname == 'refs/heads/master' + assert the_updates[0].dst_refname == 'refs/heads/master' + assert the_updates[0].src == old_tip.id + assert the_updates[0].dst == new_tip_id + + assert origin.branches['master'].target == new_tip_id From d7b67e4b55e3332b8d1397969a6968dc6e52920d Mon Sep 17 00:00:00 2001 From: Iliyas Jorio Date: Mon, 4 Aug 2025 17:33:19 +0200 Subject: [PATCH 2/2] Allow bypassing automatic connection in Remote.ls_remotes() Remote.ls_remotes() used to force a new connection to the remote. This can be harmful if a connection was already set up for other purposes, e.g. when calling ls_remotes from RemoteCallbacks.push_negotiation. This new argument `ls_remotes(..., connect)` lets you bypass the automatic connection. --- pygit2/remotes.py | 12 ++++++++++-- test/test_remote.py | 16 ++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/pygit2/remotes.py b/pygit2/remotes.py index d4ddbc55..ad34ee45 100644 --- a/pygit2/remotes.py +++ b/pygit2/remotes.py @@ -224,7 +224,10 @@ def fetch( return TransferProgress(C.git_remote_stats(self._remote)) def ls_remotes( - self, callbacks: RemoteCallbacks | None = None, proxy: str | None | bool = None + self, + callbacks: RemoteCallbacks | None = None, + proxy: str | None | bool = None, + connect: bool = True, ) -> list[LsRemotesDict]: """ Return a list of dicts that maps to `git_remote_head` from a @@ -235,9 +238,14 @@ def ls_remotes( callbacks : Passed to connect() proxy : Passed to connect() + + connect : Whether to connect to the remote first. You can pass False + if the remote has already connected. The list remains available after + disconnecting as long as a new connection is not initiated. """ - self.connect(callbacks=callbacks, proxy=proxy) + if connect: + self.connect(callbacks=callbacks, proxy=proxy) refs = ffi.new('git_remote_head ***') refs_len = ffi.new('size_t *') diff --git a/test/test_remote.py b/test/test_remote.py index 09315ee4..5a7a5027 100644 --- a/test/test_remote.py +++ b/test/test_remote.py @@ -204,6 +204,22 @@ def test_ls_remotes(testrepo: Repository) -> None: assert next(iter(r for r in refs if r['name'] == 'refs/tags/v0.28.2')) +@utils.requires_network +def test_ls_remotes_without_implicit_connect(testrepo: Repository) -> None: + assert 1 == len(testrepo.remotes) + remote = testrepo.remotes[0] + + with pytest.raises(pygit2.GitError, match='this remote has never connected'): + remote.ls_remotes(connect=False) + + remote.connect() + refs = remote.ls_remotes(connect=False) + assert refs + + # Check that a known ref is returned. + assert next(iter(r for r in refs if r['name'] == 'refs/tags/v0.28.2')) + + def test_remote_collection(testrepo: Repository) -> None: remote = testrepo.remotes['origin'] assert REMOTE_NAME == remote.name