Skip to content

Commit

Permalink
Merge branch 'main' into find-nested
Browse files Browse the repository at this point in the history
* main:
  Add type annotations to models
  • Loading branch information
jacebrowning committed Mar 12, 2022
2 parents 36913db + b511a1b commit d69cc41
Show file tree
Hide file tree
Showing 9 changed files with 470 additions and 321 deletions.
2 changes: 2 additions & 0 deletions .pylint.ini
Expand Up @@ -133,6 +133,8 @@ disable=
unsubscriptable-object,
too-many-instance-attributes,
too-many-lines,
line-too-long,
too-many-locals,

# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Expand Up @@ -104,7 +104,7 @@ endif
# TESTS #######################################################################

RANDOM_SEED ?= $(shell date +%s)
FAILURES := .cache/v/cache/lastfailed
FAILURES := .cache/pytest/v/cache/lastfailed

PYTEST_OPTIONS := --random --random-seed=$(RANDOM_SEED)
ifdef DISABLE_COVERAGE
Expand Down
2 changes: 1 addition & 1 deletion gitman/commands.py
Expand Up @@ -33,7 +33,7 @@ def init(*, force: bool = False):
rev="master",
)
config.sources.append(source)
source = source.lock(
source = source.lock( # type: ignore
rev="ebbbf773431ba07510251bb03f9525c7bab2b13a", verify_rev=False
)
config.sources_locked.append(source)
Expand Down
67 changes: 41 additions & 26 deletions gitman/models/config.py
@@ -1,14 +1,14 @@
import os
import sys
from typing import List, Optional
from typing import Iterator, List, Optional

import log
from datafiles import datafile, field

from .. import common, exceptions, shell
from ..decorators import preserve_cwd
from .group import Group
from .source import Source
from .source import Identity, Source


@datafile("{self.root}/{self.filename}", defaults=True, manual=True)
Expand Down Expand Up @@ -58,11 +58,11 @@ def validate(self):
).format(source.name)
raise exceptions.InvalidConfig(msg)

def get_path(self, name=None):
def get_path(self, name: Optional[str] = None) -> str:
"""Get the full path to a dependency or internal file."""
base = self.location_path
if name == "__config__":
return self.path
return self.path # type: ignore
if name == "__log__":
return self.log_path
if name:
Expand All @@ -71,24 +71,24 @@ def get_path(self, name=None):

def install_dependencies(
self,
*names,
depth=None,
update=True,
recurse=False,
force=False,
force_interactive=False,
fetch=False,
clean=True,
skip_changes=False,
skip_default_group=False,
): # pylint: disable=too-many-locals
*names: str,
depth: Optional[int] = None,
update: bool = True,
recurse: bool = False,
force: bool = False,
force_interactive: bool = False,
fetch: bool = False,
clean: bool = True,
skip_changes: bool = False,
skip_default_group: bool = False,
) -> int:
"""Download or update the specified dependencies."""
if depth == 0:
log.info("Skipped directory: %s", self.location_path)
return 0

sources = self._get_sources(use_locked=False if update else None)
sources_filter = self._get_sources_filtered(
sources_filter = self._get_sources_filter(
*names, sources=sources, skip_default_group=skip_default_group
)

Expand All @@ -113,6 +113,7 @@ def install_dependencies(
clean=clean,
skip_changes=skip_changes,
)
assert self.root, f"Missing root: {self}"
source.create_links(self.root, force=force)
common.newline()
count += 1
Expand Down Expand Up @@ -143,14 +144,20 @@ def install_dependencies(
return count

@preserve_cwd
def run_scripts(self, *names, depth=None, force=False, show_shell_stdout=False):
def run_scripts(
self,
*names: str,
depth: Optional[int] = None,
force: bool = False,
show_shell_stdout: bool = False,
) -> int:
"""Run scripts for the specified dependencies."""
if depth == 0:
log.info("Skipped directory: %s", self.location_path)
return 0

sources = self._get_sources()
sources_filter = self._get_sources_filtered(
sources_filter = self._get_sources_filter(
*names, sources=sources, skip_default_group=False
)

Expand Down Expand Up @@ -199,7 +206,9 @@ def _remap_names_and_revs(cls, names):
name_rev_map[base_name] = rev
return name_rev_map.keys(), name_rev_map

def lock_dependencies(self, *names, obey_existing=True, skip_changes=False):
def lock_dependencies(
self, *names: str, obey_existing: bool = True, skip_changes: bool = False
) -> int:
"""Lock down the immediate dependency versions."""
sources_to_install, source_to_install_revs = self._remap_names_and_revs(
[*names]
Expand All @@ -210,7 +219,7 @@ def lock_dependencies(self, *names, obey_existing=True, skip_changes=False):
if len(sources_to_install) == 0:
skip_default = False

sources_filter = self._get_sources_filtered(
sources_filter = self._get_sources_filter(
*sources_to_install, sources=sources, skip_default_group=skip_default
)

Expand Down Expand Up @@ -284,7 +293,9 @@ def get_top_level_dependencies(self):

common.dedent()

def get_dependencies(self, depth=None, nested=True, allow_dirty=True):
def get_dependencies(
self, depth: Optional[int] = None, nested: bool = True, allow_dirty: bool = True
) -> Iterator[Identity]:
"""Yield the path, repository, and hash of each dependency."""
if not os.path.exists(self.location_path):
return
Expand Down Expand Up @@ -318,13 +329,13 @@ def get_dependencies(self, depth=None, nested=True, allow_dirty=True):

common.dedent()

def log(self, message="", *args):
def log(self, message: str = "", *args):
"""Append a message to the log file."""
os.makedirs(self.location_path, exist_ok=True)
with open(self.log_path, "a") as outfile:
outfile.write(message.format(*args) + "\n")

def _get_sources(self, *, use_locked=None):
def _get_sources(self, *, use_locked: Optional[bool] = None) -> List[Source]:
"""Merge source lists using the requested section as the base."""
if use_locked is True:
if self.sources_locked:
Expand All @@ -351,7 +362,9 @@ def _get_sources(self, *, use_locked=None):

return sources + extras

def _get_sources_filtered(self, *names, sources, skip_default_group):
def _get_sources_filter(
self, *names: str, sources: List[Source], skip_default_group: bool
) -> List[str]:
"""Get a filtered subset of sources."""
names_list = list(names)

Expand All @@ -368,7 +381,7 @@ def _get_sources_filtered(self, *names, sources, skip_default_group):
)

if not sources_filter:
sources_filter = [source.name for source in sources]
sources_filter = [source.name for source in sources if source.name]

return list(set(sources_filter))

Expand All @@ -383,7 +396,9 @@ def _get_nested(self, *, allow_dirty: bool):
yield from config.get_dependencies(allow_dirty=allow_dirty)


def load_config(start=None, *, search=True) -> Optional[Config]:
def load_config(
start: Optional[str] = None, *, search: bool = True
) -> Optional[Config]:
"""Load the config for the current project."""
start = os.path.abspath(start) if start else _resolve_current_directory()

Expand Down
49 changes: 31 additions & 18 deletions gitman/models/source.py
@@ -1,11 +1,14 @@
import os
from collections import namedtuple
from dataclasses import dataclass, field
from typing import List, Optional

import log

from .. import common, exceptions, git, shell

Identity = namedtuple("Identity", ["path", "url", "rev"])


@dataclass
class Link:
Expand Down Expand Up @@ -38,11 +41,10 @@ def __post_init__(self):
self.type = self.type or "git"

def __repr__(self):
return "<source {}>".format(self)
return f"<source {self}>"

def __str__(self):
pattern = "['{t}'] '{r}' @ '{v}' in '{d}'"
return pattern.format(t=self.type, r=self.repo, v=self.rev, d=self.name)
return f"{self.repo!r} @ {self.rev!r} in {self.name!r}"

def __eq__(self, other):
return self.name == other.name
Expand All @@ -55,11 +57,11 @@ def __lt__(self, other):

def update_files(
self,
force=False,
force_interactive=False,
fetch=False,
clean=True,
skip_changes=False,
force: bool = False,
force_interactive: bool = False,
fetch: bool = False,
clean: bool = True,
skip_changes: bool = False,
):
"""Ensure the source matches the specified revision."""
log.info("Updating source files...")
Expand Down Expand Up @@ -136,7 +138,7 @@ def update_files(
self.type, self.repo, self.name, fetch=fetch, clean=clean, rev=self.rev
)

def create_links(self, root, force=False):
def create_links(self, root: str, *, force: bool = False):
"""Create links from the source to target directory."""
if not self.links:
return
Expand All @@ -145,9 +147,9 @@ def create_links(self, root, force=False):
target = os.path.join(root, os.path.normpath(link.target))
relpath = os.path.relpath(os.getcwd(), os.path.dirname(target))
source = os.path.join(relpath, os.path.normpath(link.source))
create_sym_link(source, target, force)
create_sym_link(source, target, force=force)

def run_scripts(self, force=False, show_shell_stdout=False):
def run_scripts(self, force: bool = False, show_shell_stdout: bool = False):
log.info("Running install scripts...")

# Enter the working tree
Expand Down Expand Up @@ -177,7 +179,12 @@ def run_scripts(self, force=False, show_shell_stdout=False):
raise exceptions.ScriptFailure(msg)
common.newline()

def identify(self, allow_dirty=True, allow_missing=True, skip_changes=False):
def identify(
self,
allow_dirty: bool = True,
allow_missing: bool = True,
skip_changes: bool = False,
) -> Identity:
"""Get the path and current repository URL and hash."""
assert self.name
if os.path.isdir(self.name):
Expand All @@ -197,30 +204,36 @@ def identify(self, allow_dirty=True, allow_missing=True, skip_changes=False):
if allow_dirty:
common.show(self.DIRTY, color="git_dirty", log=False)
common.newline()
return path, url, self.DIRTY
return Identity(path, url, self.DIRTY)

if skip_changes:
msg = ("Skipped lock due to uncommitted changes " "in {}").format(
os.getcwd()
)
common.show(msg, color="git_changes")
common.newline()
return path, url, self.DIRTY
return Identity(path, url, self.DIRTY)

msg = "Uncommitted changes in {}".format(os.getcwd())
raise exceptions.UncommittedChanges(msg)

rev = git.get_hash(self.type, _show=True)
common.show(rev, color="git_rev", log=False)
common.newline()
return path, url, rev
return Identity(path, url, rev)

if allow_missing:
return os.getcwd(), "<missing>", self.UNKNOWN
return Identity(os.getcwd(), "<missing>", self.UNKNOWN)

raise self._invalid_repository

def lock(self, rev=None, allow_dirty=False, skip_changes=False, verify_rev=True):
def lock(
self,
rev: Optional[str] = None,
allow_dirty: bool = False,
skip_changes: bool = False,
verify_rev: bool = True,
) -> Optional["Source"]:
"""Create a locked source object.
Return a locked version of the current source if not dirty
Expand Down Expand Up @@ -267,7 +280,7 @@ def _invalid_repository(self):
return exceptions.InvalidRepository(msg)


def create_sym_link(source, target, force):
def create_sym_link(source: str, target: str, *, force: bool):
log.info("Creating a symbolic link...")

if os.path.islink(target):
Expand Down
2 changes: 1 addition & 1 deletion gitman/tests/test_models_source.py
Expand Up @@ -38,7 +38,7 @@ def test_init_rev(self):

def test_repr(self, source):
"""Verify sources can be represented."""
assert "<source ['git'] 'repo' @ 'rev' in 'name'>" == repr(source)
assert "<source 'repo' @ 'rev' in 'name'>" == repr(source)

def test_eq(self, source):
source2 = copy(source)
Expand Down

0 comments on commit d69cc41

Please sign in to comment.