diff --git a/.gitignore b/.gitignore index 7626168a..70ca3865 100644 --- a/.gitignore +++ b/.gitignore @@ -13,6 +13,7 @@ data8assets/ .autopull_list summer/ test-repo/ +venv/ .ipynb_checkpoints docs/_build \ No newline at end of file diff --git a/nbgitpuller/handlers.py b/nbgitpuller/handlers.py index 289f7228..b6c41b62 100644 --- a/nbgitpuller/handlers.py +++ b/nbgitpuller/handlers.py @@ -149,7 +149,7 @@ def get(self): app_env = os.getenv('NBGITPULLER_APP', default='notebook') repo = self.get_argument('repo') - branch = self.get_argument('branch', 'master') + branch = self.get_argument('branch', None) depth = self.get_argument('depth', None) urlPath = self.get_argument('urlpath', None) or \ self.get_argument('urlPath', None) diff --git a/nbgitpuller/pull.py b/nbgitpuller/pull.py index 369c4292..09eecccd 100644 --- a/nbgitpuller/pull.py +++ b/nbgitpuller/pull.py @@ -70,11 +70,60 @@ def __init__(self, git_url, branch_name, repo_dir, **kwargs): assert git_url and branch_name self.git_url = git_url - self.branch_name = branch_name + + if branch_name == "None": + self.branch_name = self.resolve_default_branch() + elif not self.branch_exists(branch_name): + raise ValueError(f"{branch_name}: branch not found in {self.git_url}") + else: + self.branch_name = branch_name + self.repo_dir = repo_dir newargs = {k: v for k, v in kwargs.items() if v is not None} super(GitPuller, self).__init__(**newargs) + def branch_exists(self, branch): + """ + This checks to make sure the branch we are told to access + exists in the repo + """ + p_heads = subprocess.run( + ["git", "ls-remote", "--heads", self.git_url], + capture_output=True, + text=True, + ) + p_tags = subprocess.run( + ["git", "ls-remote", "--tags", self.git_url], + capture_output=True, + text=True, + ) + lines = p_heads.stdout.splitlines() + p_tags.stdout.splitlines() + branches = [] + for line in lines: + _, ref = line.split() + refs, heads, branch_name = ref.split("/", 2) + branches.append(branch_name) + return branch in branches + + def resolve_default_branch(self): + """ + This will resolve the default branch of the repo in + the case where the branch given does not exist + """ + p = subprocess.run( + ["git", "ls-remote", "--symref", self.git_url, "HEAD"], + capture_output=True, + text=True, + ) + + for line in p.stdout.splitlines(): + if line.startswith("ref:"): + # line resembles --> ref: refs/heads/main HEAD + _, ref, head = line.split() + refs, heads, branch_name = ref.split("/", 2) + return branch_name + raise ValueError(f"default branch not found in {self.git_url}") + def pull(self): """ Pull selected repo from a remote git repository, diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000..dba7ac5d --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +traitlets +pytest +-e . diff --git a/tests/test_gitpuller.py b/tests/test_gitpuller.py index c336ef06..93ffb1b6 100644 --- a/tests/test_gitpuller.py +++ b/tests/test_gitpuller.py @@ -61,10 +61,10 @@ def push_file(self, path, content): class Puller(Repository): - def __init__(self, remote, path='puller', *args, **kwargs): + def __init__(self, remote, path='puller', branch="master", *args, **kwargs): super().__init__(path) remotepath = "file://%s" % os.path.abspath(remote.path) - self.gp = GitPuller(remotepath, 'master', path, *args, **kwargs) + self.gp = GitPuller(remotepath, branch, path, *args, **kwargs) def pull_all(self): for line in self.gp.pull(): @@ -96,6 +96,20 @@ def test_initialize(): assert puller.git('rev-parse', 'HEAD') == pusher.git('rev-parse', 'HEAD') +def test_branch_exists(): + with Remote() as remote, Pusher(remote) as pusher: + pusher.push_file('README.md', '1') + with Puller(remote, 'puller') as puller: + assert not puller.gp.branch_exists("wrong") + + +def test_resolve_default_branch(): + with Remote() as remote, Pusher(remote) as pusher: + pusher.push_file('README.md', '1') + with Puller(remote, 'puller') as puller: + assert puller.gp.resolve_default_branch() == "master" + + def test_simple_push_pull(): """ Test the 'happy path' push/pull interaction