Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Branches #221

Merged
merged 4 commits into from May 28, 2013
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
152 changes: 152 additions & 0 deletions src/branch.c
Expand Up @@ -31,7 +31,10 @@
#include "reference.h"
#include "utils.h"


extern PyObject *GitError;
extern PyTypeObject ReferenceType;


PyDoc_STRVAR(Branch_delete__doc__,
"delete()\n"
Expand Down Expand Up @@ -105,6 +108,151 @@ PyObject* Branch_rename(Branch *self, PyObject *args)
}


PyDoc_STRVAR(Branch_branch_name__doc__,
"The name of the local or remote branch.");

PyObject* Branch_branch_name__get__(Branch *self)
{
int err;
const char *c_name;

CHECK_REFERENCE(self);

err = git_branch_name(&c_name, self->reference);
if (err == GIT_OK)
return to_unicode(c_name, NULL, NULL);
else
return Error_set(err);
}


PyDoc_STRVAR(Branch_remote_name__doc__,
"The name of the remote that the remote tracking branch belongs to.");

PyObject* Branch_remote_name__get__(Branch *self)
{
int err;
const char *branch_name;
char *c_name = NULL;

CHECK_REFERENCE(self);

branch_name = git_reference_name(self->reference);
// get the length of the remote name
err = git_branch_remote_name(NULL, 0, self->repo->repo, branch_name);
if (err < GIT_OK)
return Error_set(err);

// get the actual remote name
c_name = calloc(err, sizeof(char));
if (c_name == NULL)
return PyErr_NoMemory();

err = git_branch_remote_name(c_name,
err * sizeof(char),
self->repo->repo,
branch_name);
if (err < GIT_OK) {
free(c_name);
return Error_set(err);
}

PyObject *py_name = to_unicode(c_name, NULL, NULL);
free(c_name);

return py_name;
}


PyDoc_STRVAR(Branch_upstream__doc__,
"The branch supporting the remote tracking branch or None if this is not a "
"remote tracking branch. Set to None to unset.");

PyObject* Branch_upstream__get__(Branch *self)
{
int err;
git_reference *c_reference;

CHECK_REFERENCE(self);

err = git_branch_upstream(&c_reference, self->reference);
if (err == GIT_ENOTFOUND)
Py_RETURN_NONE;
else if (err < GIT_OK)
return Error_set(err);

return wrap_branch(c_reference, self->repo);
}

int Branch_upstream__set__(Branch *self, Reference *py_ref)
{
int err;
const char *branch_name = NULL;

CHECK_REFERENCE_INT(self);

if ((PyObject *)py_ref != Py_None) {
if (!PyObject_TypeCheck(py_ref, (PyTypeObject *)&ReferenceType)) {
PyErr_SetObject(PyExc_TypeError, (PyObject *)py_ref);
return -1;
}

CHECK_REFERENCE_INT(py_ref);
err = git_branch_name(&branch_name, py_ref->reference);
if (err < GIT_OK) {
Error_set(err);
return -1;
}
}

err = git_branch_set_upstream(self->reference, branch_name);
if (err < GIT_OK) {
Error_set(err);
return -1;
}

return 0;
}


PyDoc_STRVAR(Branch_upstream_name__doc__,
"The name of the reference supporting the remote tracking branch.");

PyObject* Branch_upstream_name__get__(Branch *self)
{
int err;
const char *branch_name;
char *c_name = NULL;

CHECK_REFERENCE(self);

branch_name = git_reference_name(self->reference);
// get the length of the upstream name
err = git_branch_upstream_name(NULL, 0, self->repo->repo, branch_name);
if (err < GIT_OK)
return Error_set(err);

// get the actual upstream name
c_name = calloc(err, sizeof(char));
if (c_name == NULL)
return PyErr_NoMemory();

err = git_branch_upstream_name(c_name,
err * sizeof(char),
self->repo->repo,
branch_name);
if (err < GIT_OK) {
free(c_name);
return Error_set(err);
}

PyObject *py_name = to_unicode(c_name, NULL, NULL);
free(c_name);

return py_name;
}


PyMethodDef Branch_methods[] = {
METHOD(Branch, delete, METH_NOARGS),
METHOD(Branch, is_head, METH_NOARGS),
Expand All @@ -113,6 +261,10 @@ PyMethodDef Branch_methods[] = {
};

PyGetSetDef Branch_getseters[] = {
GETTER(Branch, branch_name),
GETTER(Branch, remote_name),
GETSET(Branch, upstream),
GETTER(Branch, upstream_name),
{NULL}
};

Expand Down
43 changes: 43 additions & 0 deletions test/test_branch.py
Expand Up @@ -110,6 +110,15 @@ def test_branch_rename_fails_with_invalid_names(self):
self.assertRaises(ValueError,
lambda: original_branch.rename('abc@{123'))

def test_branch_name(self):
branch = self.repo.lookup_branch('master')
self.assertEqual(branch.branch_name, 'master')
self.assertEqual(branch.name, 'refs/heads/master')

branch = self.repo.lookup_branch('i18n')
self.assertEqual(branch.branch_name, 'i18n')
self.assertEqual(branch.name, 'refs/heads/i18n')


class BranchesEmptyRepoTestCase(utils.EmptyRepoTestCase):
def setUp(self):
Expand All @@ -131,6 +140,40 @@ def test_listall_branches(self):
branches = sorted(self.repo.listall_branches(pygit2.GIT_BRANCH_REMOTE))
self.assertEqual(branches, ['origin/master'])

def test_branch_remote_name(self):
self.repo.remotes[0].fetch()
branch = self.repo.lookup_branch('origin/master',
pygit2.GIT_BRANCH_REMOTE)
self.assertEqual(branch.remote_name, 'origin')

def test_branch_upstream(self):
self.repo.remotes[0].fetch()
remote_master = self.repo.lookup_branch('origin/master',
pygit2.GIT_BRANCH_REMOTE)
master = self.repo.create_branch('master',
self.repo[remote_master.target.hex])

self.assertTrue(master.upstream is None)
master.upstream = remote_master
self.assertEqual(master.upstream.branch_name, 'origin/master')

def set_bad_upstream():
master.upstream = 2.5
self.assertRaises(TypeError, set_bad_upstream)

master.upstream = None
self.assertTrue(master.upstream is None)

def test_branch_upstream_name(self):
self.repo.remotes[0].fetch()
remote_master = self.repo.lookup_branch('origin/master',
pygit2.GIT_BRANCH_REMOTE)
master = self.repo.create_branch('master',
self.repo[remote_master.target.hex])

master.upstream = remote_master
self.assertEqual(master.upstream_name, 'refs/remotes/origin/master')


if __name__ == '__main__':
unittest.main()