Skip to content

Commit

Permalink
Now detecting main branch of repository
Browse files Browse the repository at this point in the history
  • Loading branch information
lorinkoz committed Mar 6, 2021
1 parent 131417d commit c3944bd
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 18 deletions.
7 changes: 0 additions & 7 deletions README.rst
Expand Up @@ -87,13 +87,6 @@ forget to apply them again at the end::
python manage.py unmigrate --fake
python manage.py migrate --fake

No more master
--------------

This package (still) uses ``master`` as the name of the default branch. If that
is no longer the case for your repositories, you can define ``MAIN_BRANCH`` in
your Django settings.

Do you see potential?
---------------------

Expand Down
30 changes: 26 additions & 4 deletions django_unmigrate/core.py
@@ -1,13 +1,12 @@
import os
import sys
from functools import lru_cache

from django.db import connections
from django.db.migrations.loader import MigrationLoader
from git import Repo
from git.exc import GitCommandError

from .settings import MAIN_BRANCH


class GitError(Exception):
message = ""
Expand All @@ -19,19 +18,42 @@ def __str__(self):
return self.message


def get_targets(database="default", ref=MAIN_BRANCH):
@lru_cache(maxsize=None)
def get_main_branch():
"""
Detects main branch of repository.
"""
try:
# Getting repo
pathname = os.path.dirname(sys.argv[0])
repo = Repo(os.path.abspath(pathname), search_parent_directories=True)
branches = {x.strip(" *") for x in repo.git.branch("--list").splitlines()}
if "main" in branches:
return "main"
elif "master" in branches:
return "master"
except GitCommandError as error: # pragma: no cover
raise GitError(str(error))
raise GitError("Unable to detect main branch of repository.") # pragma: no cover


def get_targets(database="default", ref=None):
"""
Produce target migrations from ``database`` and ``ref``.
"""
if ref is None:
ref = get_main_branch()
added_targets = get_added_migrations(ref)
return (added_targets, get_parents_from_targets(added_targets, database))


def get_added_migrations(ref=MAIN_BRANCH):
def get_added_migrations(ref=None):
"""
Detect the added migrations when compared to ``ref``, and return them as
target tuples: ``(app_name, migration_name)``
"""
if ref is None:
ref = get_main_branch()
try:
# Getting repo
pathname = os.path.dirname(sys.argv[0])
Expand Down
6 changes: 3 additions & 3 deletions django_unmigrate/management/commands/unmigrate.py
Expand Up @@ -6,7 +6,6 @@
from django.db import DEFAULT_DB_ALIAS

from django_unmigrate.core import GitError, get_targets
from django_unmigrate.settings import MAIN_BRANCH


class Command(BaseCommand):
Expand All @@ -16,8 +15,9 @@ def add_arguments(self, parser):
parser.add_argument(
"ref",
nargs="?",
default=MAIN_BRANCH,
help="Git ref to compare existing migrations.",
default=None,
help="Git ref to compare existing migrations. "
"Defaults to None, which tries to detect a main or master branch.",
)
parser.add_argument(
"--database",
Expand Down
3 changes: 0 additions & 3 deletions django_unmigrate/settings.py

This file was deleted.

30 changes: 29 additions & 1 deletion dunm_sandbox/tests/test_core.py
@@ -1,14 +1,42 @@
import os
import sys

from django.test import TestCase
from git import Repo

from django_unmigrate.core import GitError, get_added_migrations, get_parents_from_targets, get_targets
from django_unmigrate.core import GitError, get_added_migrations, get_main_branch, get_parents_from_targets, get_targets
from dunm_sandbox.meta import COMMITS, PARENTS


class GetMainBranchTestCase(TestCase):
"""
Tests core.get_main_branch
"""

def test_master(self):
main_branch = get_main_branch()
self.assertEqual(main_branch, "master")

def test_main(self):
pathname = os.path.dirname(sys.argv[0])
repo = Repo(os.path.abspath(pathname), search_parent_directories=True)
repo.git.branch("main")
get_main_branch.cache_clear()
main_branch = get_main_branch()
self.assertEqual(main_branch, "main")
repo.git.branch("-d", "main")
get_main_branch.cache_clear()


class GetAddedMigrationsTestCase(TestCase):
"""
Tests core.get_added_migrations
"""

def test_plain(self):
response = get_added_migrations()
self.assertEqual(response, [])

def test_by_commit(self):
for commit, expected_migrations in COMMITS.items():
response = get_added_migrations(commit)
Expand Down

0 comments on commit c3944bd

Please sign in to comment.