diff --git a/pygit2/callbacks.py b/pygit2/callbacks.py index 5a040fbe6..05d02fdd3 100644 --- a/pygit2/callbacks.py +++ b/pygit2/callbacks.py @@ -632,7 +632,7 @@ def _checkout_progress_cb(path, completed_steps, total_steps, data: CheckoutCall data.checkout_progress(maybe_string(path), completed_steps, total_steps) -def _git_checkout_options(callbacks=None, strategy=None, directory=None, paths=None): +def _git_checkout_options(callbacks=None, strategy=None, directory=None, paths=None, c_checkout_options_ptr=None): if callbacks is None: payload = CheckoutCallbacks() else: @@ -642,7 +642,10 @@ def _git_checkout_options(callbacks=None, strategy=None, directory=None, paths=N handle = ffi.new_handle(payload) # Create the options struct to pass - opts = ffi.new('git_checkout_options *') + if not c_checkout_options_ptr: + opts = ffi.new('git_checkout_options *') + else: + opts = c_checkout_options_ptr check_error(C.git_checkout_init_options(opts, 1)) # References we need to keep to strings and so forth @@ -714,23 +717,25 @@ def git_stash_apply_options(callbacks=None, reinstate_index=False, strategy=None if callbacks is None: callbacks = StashApplyCallbacks() - # First, set up checkout_options - payload = _git_checkout_options(callbacks=callbacks, strategy=strategy, directory=directory, paths=paths) - assert payload == callbacks - - # Now set up the rest of stash options + # Set up stash options # TODO: git_stash_apply_init_options is deprecated (along with a bunch of other git_XXX_init_options functions) - stash_options = ffi.new('git_stash_apply_options *') - check_error(C.git_stash_apply_init_options(stash_options, 1)) + stash_apply_options = ffi.new('git_stash_apply_options *') + check_error(C.git_stash_apply_init_options(stash_apply_options, 1)) flags = reinstate_index * C.GIT_STASH_APPLY_REINSTATE_INDEX - stash_options.flags = flags + stash_apply_options.flags = flags + + # Now set up checkout options + c_checkout_options_ptr = ffi.addressof(stash_apply_options.checkout_options) + payload = _git_checkout_options(callbacks=callbacks, strategy=strategy, directory=directory, paths=paths, c_checkout_options_ptr=c_checkout_options_ptr) + assert payload == callbacks + assert payload.checkout_options == c_checkout_options_ptr # Set up stash progress callback if the user has provided their own if type(callbacks).stash_apply_progress != StashApplyCallbacks.stash_apply_progress: - stash_options.progress_cb = C._stash_apply_progress_cb - stash_options.progress_payload = payload._ffi_handle + stash_apply_options.progress_cb = C._stash_apply_progress_cb + stash_apply_options.progress_payload = payload._ffi_handle # Give back control - payload.stash_options = stash_options + payload.stash_apply_options = stash_apply_options yield payload diff --git a/pygit2/repository.py b/pygit2/repository.py index 13d2f9a82..b33b3d560 100644 --- a/pygit2/repository.py +++ b/pygit2/repository.py @@ -1112,7 +1112,7 @@ def stash_apply(self, index=0, **kwargs): >>> repo.stash_apply(strategy=GIT_CHECKOUT_ALLOW_CONFLICTS) """ with git_stash_apply_options(**kwargs) as payload: - err = C.git_stash_apply(self._repo, index, payload.stash_options) + err = C.git_stash_apply(self._repo, index, payload.stash_apply_options) payload.check_error(err) def stash_drop(self, index=0): @@ -1133,7 +1133,7 @@ def stash_pop(self, index=0, **kwargs): For arguments, see Repository.stash_apply(). """ with git_stash_apply_options(**kwargs) as payload: - err = C.git_stash_pop(self._repo, index, payload.stash_options) + err = C.git_stash_pop(self._repo, index, payload.stash_apply_options) payload.check_error(err) # diff --git a/test/test_repository.py b/test/test_repository.py index 61bfe1d1c..a5bde8e63 100644 --- a/test/test_repository.py +++ b/test/test_repository.py @@ -414,6 +414,40 @@ def stash_apply_progress(self, progress: int): assert 1 == len(repo_stashes) assert repo_stashes[0].message == "On master: custom stash message" +def test_stash_apply_checkout_options(testrepo): + sig = pygit2.Signature(name='Stasher', email='stasher@example.com', time=1641000000, offset=0) + + hello_txt = Path(testrepo.workdir) / 'hello.txt' + + # some changes to working dir + with hello_txt.open('w') as f: + f.write('stashed content') + + # create the stash + testrepo.stash(sig, include_untracked=True, message="custom stash message") + + # define callbacks that raise an InterruptedError when checkout detects a conflict + class MyStashApplyCallbacks(pygit2.StashApplyCallbacks): + def checkout_notify(self, why, path, baseline, target, workdir): + if why == pygit2.GIT_CHECKOUT_NOTIFY_CONFLICT: + raise InterruptedError("Applying the stash would create a conflict") + + # overwrite hello.txt so that applying the stash would create a conflict + with hello_txt.open('w') as f: + f.write('conflicting content') + + # apply the stash with the default (safe) strategy; + # the callbacks should detect a conflict on checkout + with pytest.raises(InterruptedError): + testrepo.stash_apply(strategy=pygit2.GIT_CHECKOUT_SAFE, callbacks=MyStashApplyCallbacks()) + + # hello.txt should be intact + with hello_txt.open('r') as f: assert f.read() == 'conflicting content' + + # force apply the stash; this should work + testrepo.stash_apply(strategy=pygit2.GIT_CHECKOUT_FORCE, callbacks=MyStashApplyCallbacks()) + with hello_txt.open('r') as f: assert f.read() == 'stashed content' + def test_revert(testrepo): master = testrepo.head.peel()