diff --git a/.gitignore b/.gitignore index 7626168a..70ca3865 100644 --- a/.gitignore +++ b/.gitignore @@ -13,6 +13,7 @@ data8assets/ .autopull_list summer/ test-repo/ +venv/ .ipynb_checkpoints docs/_build \ No newline at end of file diff --git a/nbgitpuller/handlers.py b/nbgitpuller/handlers.py index 289f7228..f83ad7d5 100644 --- a/nbgitpuller/handlers.py +++ b/nbgitpuller/handlers.py @@ -52,7 +52,7 @@ def get(self): try: repo = self.get_argument('repo') - branch = self.get_argument('branch') + branch = self.get_argument('branch', None) depth = self.get_argument('depth', None) if depth: depth = int(depth) @@ -73,7 +73,7 @@ def get(self): self.set_header('content-type', 'text/event-stream') self.set_header('cache-control', 'no-cache') - gp = GitPuller(repo, branch, repo_dir, depth=depth, parent=self.settings['nbapp']) + gp = GitPuller(repo, repo_dir, branch=branch, depth=depth, parent=self.settings['nbapp']) q = Queue() @@ -149,7 +149,7 @@ def get(self): app_env = os.getenv('NBGITPULLER_APP', default='notebook') repo = self.get_argument('repo') - branch = self.get_argument('branch', 'master') + branch = self.get_argument('branch', None) depth = self.get_argument('depth', None) urlPath = self.get_argument('urlpath', None) or \ self.get_argument('urlPath', None) diff --git a/nbgitpuller/pull.py b/nbgitpuller/pull.py index 369c4292..f64b18e3 100644 --- a/nbgitpuller/pull.py +++ b/nbgitpuller/pull.py @@ -66,15 +66,75 @@ def _depth_default(self): where the GitPuller class hadn't been loaded already.""" return int(os.environ.get('NBGITPULLER_DEPTH', 1)) - def __init__(self, git_url, branch_name, repo_dir, **kwargs): - assert git_url and branch_name + def __init__(self, git_url, repo_dir, **kwargs): + assert git_url self.git_url = git_url - self.branch_name = branch_name + self.branch_name = kwargs.pop("branch") + + if self.branch_name is None: + self.branch_name = self.resolve_default_branch() + elif not self.branch_exists(self.branch_name): + raise ValueError(f"Branch: {self.branch_name} -- not found in repo: {self.git_url}") + self.repo_dir = repo_dir newargs = {k: v for k, v in kwargs.items() if v is not None} super(GitPuller, self).__init__(**newargs) + def branch_exists(self, branch): + """ + This checks to make sure the branch we are told to access + exists in the repo + """ + try: + heads = subprocess.run( + ["git", "ls-remote", "--heads", self.git_url], + capture_output=True, + text=True, + check=True + ) + tags = subprocess.run( + ["git", "ls-remote", "--tags", self.git_url], + capture_output=True, + text=True, + check=True + ) + lines = heads.stdout.splitlines() + tags.stdout.splitlines() + branches = [] + for line in lines: + _, ref = line.split() + refs, heads, branch_name = ref.split("/", 2) + branches.append(branch_name) + return branch in branches + except subprocess.CalledProcessError: + m = f"Problem accessing list of branches and/or tags: {self.git_url}" + logging.exception(m) + raise ValueError(m) + + def resolve_default_branch(self): + """ + This will resolve the default branch of the repo in + the case where the branch given does not exist + """ + try: + head_branch = subprocess.run( + ["git", "ls-remote", "--symref", self.git_url, "HEAD"], + capture_output=True, + text=True, + check=True + ) + for line in head_branch.stdout.splitlines(): + if line.startswith("ref:"): + # line resembles --> ref: refs/heads/main HEAD + _, ref, head = line.split() + refs, heads, branch_name = ref.split("/", 2) + return branch_name + raise ValueError(f"default branch not found in {self.git_url}") + except subprocess.CalledProcessError: + m = f"Problem accessing HEAD branch: {self.git_url}" + logging.exception(m) + raise ValueError(m) + def pull(self): """ Pull selected repo from a remote git repository, @@ -243,13 +303,11 @@ def main(): parser = argparse.ArgumentParser(description='Synchronizes a github repository with a local repository.') parser.add_argument('git_url', help='Url of the repo to sync') - parser.add_argument('branch_name', default='master', help='Branch of repo to sync', nargs='?') parser.add_argument('repo_dir', default='.', help='Path to clone repo under', nargs='?') args = parser.parse_args() for line in GitPuller( args.git_url, - args.branch_name, args.repo_dir ).pull(): print(line) diff --git a/nbgitpuller/static/index.js b/nbgitpuller/static/index.js index 53cd0cc0..c85d5897 100644 --- a/nbgitpuller/static/index.js +++ b/nbgitpuller/static/index.js @@ -44,12 +44,14 @@ require([ // Start git pulling handled by SyncHandler, declared in handlers.py var syncUrlParams = { repo: this.repo, - branch: this.branch, targetpath: this.targetpath } if (typeof this.depth !== 'undefined' && this.depth != undefined) { syncUrlParams['depth'] = this.depth; } + if (typeof this.branch !== 'undefined' && this.branch != undefined) { + syncUrlParams['branch'] = this.branch; + } var syncUrl = this.baseUrl + 'git-pull/api?' + $.param(syncUrlParams); this.eventSource = new EventSource(syncUrl); diff --git a/nbgitpuller/templates/status.html b/nbgitpuller/templates/status.html index b20a7b7f..1fcd00dc 100644 --- a/nbgitpuller/templates/status.html +++ b/nbgitpuller/templates/status.html @@ -5,7 +5,7 @@ data-base-url="{{ base_url | urlencode }}" data-repo="{{ repo | urlencode }}" data-path="{{ path | urlencode }}" -data-branch="{{ branch | urlencode }}" +{% if branch %}data-branch="{{ branch | urlencode }}"{% endif %} {% if depth %}data-depth="{{ depth | urlencode }}"{% endif %} data-targetpath="{{ targetpath | urlencode }}" {% endblock %} diff --git a/tests/test_gitpuller.py b/tests/test_gitpuller.py index c336ef06..05c01354 100644 --- a/tests/test_gitpuller.py +++ b/tests/test_gitpuller.py @@ -61,10 +61,10 @@ def push_file(self, path, content): class Puller(Repository): - def __init__(self, remote, path='puller', *args, **kwargs): + def __init__(self, remote, path='puller', branch="master", *args, **kwargs): super().__init__(path) remotepath = "file://%s" % os.path.abspath(remote.path) - self.gp = GitPuller(remotepath, 'master', path, *args, **kwargs) + self.gp = GitPuller(remotepath, path, branch=branch, *args, **kwargs) def pull_all(self): for line in self.gp.pull(): @@ -96,6 +96,47 @@ def test_initialize(): assert puller.git('rev-parse', 'HEAD') == pusher.git('rev-parse', 'HEAD') +def test_branch_exists(): + with Remote() as remote, Pusher(remote) as pusher: + pusher.push_file('README.md', '1') + with Puller(remote, 'puller') as puller: + assert not puller.gp.branch_exists("wrong") + assert puller.gp.branch_exists("master") + + +def test_exception_branch_exists(): + with Remote() as remote, Pusher(remote) as pusher: + pusher.push_file('README.md', '1') + with Puller(remote, 'puller') as puller: + orig_url = puller.gp.git_url + puller.gp.git_url = "" + try: + puller.gp.branch_exists("wrong") + except Exception as e: + assert type(e) == ValueError + puller.gp.git_url = orig_url + + +def test_resolve_default_branch(): + with Remote() as remote, Pusher(remote) as pusher: + pusher.push_file('README.md', '1') + with Puller(remote, 'puller') as puller: + assert puller.gp.resolve_default_branch() == "master" + + +def test_exception_resolve_default_branch(): + with Remote() as remote, Pusher(remote) as pusher: + pusher.push_file('README.md', '1') + with Puller(remote, 'puller') as puller: + orig_url = puller.gp.git_url + puller.gp.git_url = "" + try: + puller.gp.resolve_default_branch() + except Exception as e: + assert type(e) == ValueError + puller.gp.git_url = orig_url + + def test_simple_push_pull(): """ Test the 'happy path' push/pull interaction