Skip to content

Commit

Permalink
Typing for automation package (#7812)
Browse files Browse the repository at this point in the history
  • Loading branch information
smackesey committed May 18, 2022
1 parent 2cbf3cd commit 1bd83e4
Show file tree
Hide file tree
Showing 15 changed files with 81 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
from .utils import get_python_versions_for_branch

MYPY_EXCLUDES = [
"python_modules/automation",
"python_modules/libraries/dagster-docker",
"examples/docs_snippets",
]

Expand Down
5 changes: 3 additions & 2 deletions python_modules/automation/automation/docker/dagster_docker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import contextlib
import os
from typing import Callable, Dict, List, NamedTuple, Optional
from typing import Callable, Dict, Iterator, List, NamedTuple, Optional

import yaml

Expand All @@ -22,7 +22,7 @@


@contextlib.contextmanager
def do_nothing(_cwd: str) -> Iterator[None]:
    """No-op context manager.

    Args:
        _cwd (str): Accepted and ignored.

    Yields:
        None: control is handed to the ``with`` body exactly once; nothing is set up or torn down.
    """
    yield


Expand Down Expand Up @@ -97,6 +97,7 @@ def aws_image(
check.opt_str_param(python_version, "python_version")
check.opt_str_param(custom_tag, "custom_tag")

tag: Optional[str]
if python_version:
last_updated = self._get_last_updated_for_python_version(python_version)
tag = python_version_image_tag(python_version, last_updated)
Expand Down
27 changes: 15 additions & 12 deletions python_modules/automation/automation/docker/ecr.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,39 @@
import os
import subprocess
from typing import Optional

import dagster._check as check

# We default to using the ECR region here
DEFAULT_AWS_ECR_REGION = "us-west-2"


def aws_ecr_repository(aws_account_id: str, aws_region: str = DEFAULT_AWS_ECR_REGION) -> str:
    """Returns the DNS hostname of the ECR registry for a given AWS account and region.

    Args:
        aws_account_id (str): Account ID, e.g. 123456789000
        aws_region (str): AWS region to use. Defaults to DEFAULT_AWS_ECR_REGION.

    Returns:
        str: DNS hostname of the ECR registry to use, in the form
            "<aws_account_id>.dkr.ecr.<aws_region>.amazonaws.com".
    """
    check.str_param(aws_account_id, "aws_account_id")
    check.str_param(aws_region, "aws_region")

    return f"{aws_account_id}.dkr.ecr.{aws_region}.amazonaws.com"


def get_aws_account_id() -> str:
    """Return the AWS account ID read from the AWS_ACCOUNT_ID environment variable.

    Returns:
        str: The non-empty value of AWS_ACCOUNT_ID.

    Fails the ``check.not_none`` assertion with "must have AWS_ACCOUNT_ID set" when the
    variable is unset or empty.
    """
    # `or None` folds an empty value into "unset". A bare not_none() would accept "" and
    # hand callers an unusable account ID.
    account_id = os.environ.get("AWS_ACCOUNT_ID") or None
    return check.not_none(account_id, "must have AWS_ACCOUNT_ID set")


def get_aws_region() -> str:
    """Return the AWS region to use for ECR.

    Can override ECR region by setting the AWS_REGION environment variable; when it is
    not set, DEFAULT_AWS_ECR_REGION is returned.
    """
    return os.environ.get("AWS_REGION", DEFAULT_AWS_ECR_REGION)


def ensure_ecr_login(aws_region=DEFAULT_AWS_ECR_REGION):
def ensure_ecr_login(aws_region: str = DEFAULT_AWS_ECR_REGION):
check.str_param(aws_region, "aws_region")

cmd = "aws ecr get-login --no-include-email --region {} | sh".format(aws_region)
Expand All @@ -46,9 +44,14 @@ def ensure_ecr_login(aws_region=DEFAULT_AWS_ECR_REGION):
)


def ecr_image(
    image: str, tag: Optional[str], aws_account_id: str, aws_region: str = DEFAULT_AWS_ECR_REGION
) -> str:
    """Build the ECR-qualified reference for an image.

    Args:
        image (str): Image name, placed after the registry hostname.
        tag (Optional[str]): Image tag. When None or empty, no ":<tag>" suffix is appended.
        aws_account_id (str): Account ID used to derive the registry hostname.
        aws_region (str): AWS region to use. Defaults to DEFAULT_AWS_ECR_REGION.

    Returns:
        str: "<registry>/<image>:<tag>", or "<registry>/<image>" when no tag is given,
            where <registry> comes from aws_ecr_repository().
    """
    check.str_param(image, "image")
    # Validate `tag` itself. This check used to be handed `aws_account_id`, so a
    # non-string tag slipped through unvalidated.
    check.opt_str_param(tag, "tag")
    check.str_param(aws_account_id, "aws_account_id")
    check.str_param(aws_region, "aws_region")

    repo = aws_ecr_repository(aws_account_id, aws_region)
    tail = f"{image}:{tag}" if tag else image
    return f"{repo}/{tail}"
19 changes: 10 additions & 9 deletions python_modules/automation/automation/git.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import os
import re
import subprocess
from typing import Optional

from .utils import check_output


def git_check_status():
def git_check_status() -> None:
changes = subprocess.check_output(["git", "status", "--porcelain"]).decode("utf-8").strip()

if changes != "":
Expand All @@ -16,11 +17,11 @@ def git_check_status():
)


def git_user() -> str:
    """Return the configured git ``user.name``, with surrounding whitespace stripped.

    Raises:
        subprocess.CalledProcessError: If ``git config --get user.name`` exits non-zero.
    """
    return subprocess.check_output(["git", "config", "--get", "user.name"]).decode("utf-8").strip()


def git_repo_root(path=None):
def git_repo_root(path: Optional[str] = None) -> str:
if not path:
path = os.getcwd()

Expand All @@ -31,7 +32,7 @@ def git_repo_root(path=None):
)


def git_push(tag=None, dry_run=True, cwd=None):
def git_push(tag: Optional[str] = None, dry_run: bool = True, cwd: Optional[str] = None):
github_token = os.getenv("GITHUB_TOKEN")
github_username = os.getenv("GITHUB_USERNAME")
if github_token and github_username:
Expand Down Expand Up @@ -65,7 +66,7 @@ def git_push(tag=None, dry_run=True, cwd=None):
check_output(["git", "push"], dry_run=dry_run, cwd=cwd)


def get_git_tag():
def get_git_tag() -> str:
try:
git_tag = str(
subprocess.check_output(
Expand All @@ -87,7 +88,7 @@ def get_git_tag():
return git_tag


def get_most_recent_git_tag():
def get_most_recent_git_tag() -> str:
try:
git_tag = (
subprocess.check_output(["git", "describe", "--abbrev=0"], stderr=subprocess.STDOUT)
Expand All @@ -99,7 +100,7 @@ def get_most_recent_git_tag():
return git_tag


def get_git_repo_branch(cwd=None):
def get_git_repo_branch(cwd: Optional[str] = None) -> str:
git_branch = (
subprocess.check_output(["git", "branch", "--show-current"], cwd=cwd)
.decode("utf-8")
Expand All @@ -108,7 +109,7 @@ def get_git_repo_branch(cwd=None):
return git_branch


def set_git_tag(tag, signed=False, dry_run=True):
def set_git_tag(tag: str, signed: bool = False, dry_run: bool = True) -> str:
try:
if signed:
if not dry_run:
Expand Down Expand Up @@ -144,7 +145,7 @@ def set_git_tag(tag, signed=False, dry_run=True):
return tag


def git_commit_updates(repo_dir, message):
def git_commit_updates(repo_dir: str, message: str) -> None:
cmds = [
"git add -A",
'git commit -m "{}"'.format(message),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import os
import re
from typing import Dict, NamedTuple, Set, Tuple
from typing import AbstractSet, Dict, NamedTuple, Tuple, cast

import dagster_graphql_tests
from dagster_graphql.client import client_queries

import dagster._check as check


class LegacyQueryHistoryInfo(NamedTuple):
directory: str
legacy_queries: Set[str]
legacy_queries: AbstractSet[str]

@staticmethod
def get() -> "LegacyQueryHistoryInfo":
Expand Down Expand Up @@ -42,4 +44,9 @@ def serialize_to_query_filename(dagster_version: str, date: str) -> str:


def deserialize_from_query_filename(query_filename: str) -> Tuple[str, str]:
    """Split a query filename into its two '-'-separated parts.

    Args:
        query_filename (str): Filename of the form "<part>-<part>.graphql". The parts are
            presumably (dagster_version, date), mirroring serialize_to_query_filename --
            confirm against that function.

    Returns:
        Tuple[str, str]: The two parts, with the ".graphql" extension removed.

    Fails the ``check.invariant`` assertion if the name does not split into exactly two parts.
    """
    suffix = ".graphql"
    # Remove the extension as a suffix. str.rstrip(".graphql") treats its argument as a
    # character set and would also eat trailing '.', 'g', 'r', 'a', 'p', 'h', 'q', 'l'
    # characters from the second part (e.g. "x-alpha.graphql" -> "x-").
    if query_filename.endswith(suffix):
        stem = query_filename[: -len(suffix)]
    else:
        stem = query_filename

    parts = tuple(stem.split("-"))
    check.invariant(
        len(parts) == 2,
        f"Invalid query filename {query_filename}; must have 2 '-' separated parts.",
    )
    return cast(Tuple[str, str], parts)
27 changes: 17 additions & 10 deletions python_modules/automation/automation/parse_spark_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
"""
import re
import sys
from collections import namedtuple
from enum import Enum
from typing import Any, Dict, List, NamedTuple, Optional, Union, cast

import click
import requests
Expand Down Expand Up @@ -149,8 +149,8 @@ class ConfigType(Enum):
}


class SparkConfig(namedtuple("_SparkConfig", "path default meaning")):
def __new__(cls, path, default, meaning):
class SparkConfig(NamedTuple("_SparkConfig", [("path", str), ("default", str), ("meaning", str)])):
def __new__(cls, path: str, default: object, meaning: str):
# The original documentation strings include extraneous newlines, spaces
return super(SparkConfig, cls).__new__(
cls,
Expand All @@ -160,10 +160,10 @@ def __new__(cls, path, default, meaning):
)

@property
def split_path(self):
def split_path(self) -> List[str]:
return self.path.split(".")

def write(self, printer):
def write(self, printer: IndentingBufferPrinter) -> None:
config_type = CONFIG_TYPES.get(self.path, ConfigType.STRING).value

printer.append("Field(")
Expand All @@ -180,14 +180,21 @@ def write(self, printer):


class SparkConfigNode:
def __init__(self, value=None):

value: Optional[SparkConfig]
children: Dict[str, Any]

def __init__(self, value: Optional[SparkConfig] = None):
self.value = value
self.children = {}

def write(self, printer):
def write(self, printer: IndentingBufferPrinter) -> str:
if not self.children:
assert self.value
self.value.write(printer)
else:
self.children = cast(Dict[str, Union[SparkConfig, SparkConfigNode]], self.children)
retdict: Dict[str, Union[SparkConfig, SparkConfigNode]]
if self.value:
retdict = {"root": self.value}
retdict.update(self.children)
Expand All @@ -213,7 +220,7 @@ def write(self, printer):
return printer.read()


def extract(spark_docs_markdown_text):
def extract(spark_docs_markdown_text: str) -> SparkConfigNode:
import pytablereader as ptr

tables = re.findall(TABLE_REGEX, spark_docs_markdown_text, re.DOTALL | re.MULTILINE)
Expand Down Expand Up @@ -248,7 +255,7 @@ def extract(spark_docs_markdown_text):
return result


def serialize(result):
def serialize(result: SparkConfigNode) -> bytes:
with IndentingBufferPrinter() as printer:
printer.write_header()
printer.line("from dagster import Bool, Field, Float, IntSource, Permissive, StringSource")
Expand All @@ -264,7 +271,7 @@ def serialize(result):


@click.command()
def run():
def run() -> None:
r = requests.get(
"https://raw.githubusercontent.com/apache/spark/{}/docs/configuration.md".format(
SPARK_VERSION
Expand Down
20 changes: 14 additions & 6 deletions python_modules/automation/automation/printer.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,39 @@
import os
import sys
from io import StringIO
from types import TracebackType
from typing import Any, Callable, List, Optional, Type

from dagster.utils.indenting_printer import IndentingPrinter


class IndentingBufferPrinter(IndentingPrinter):
"""Subclass of IndentingPrinter wrapping a StringIO."""

# Backing in-memory buffer that every printed line is written to.
buffer: StringIO

def __init__(self, indent_level: int = 4, current_indent: int = 0):
    """Create the printer with a fresh StringIO as its output sink.

    Args:
        indent_level (int): Forwarded to IndentingPrinter as ``indent_level``.
        current_indent (int): Forwarded to IndentingPrinter as ``current_indent``.
    """
    self.buffer = StringIO()
    # The parent class emits text through this callable; each call appends one
    # newline-terminated string to the buffer.
    self.printer: Callable[[str], Any] = lambda x: self.buffer.write(x + "\n")
    super(IndentingBufferPrinter, self).__init__(
        indent_level=indent_level, printer=self.printer, current_indent=current_indent
    )

def __enter__(self) -> "IndentingBufferPrinter":
    """Enter the ``with`` block; the printer itself is the managed object."""
    return self

def __exit__(
    self,
    _exception_type: Optional[Type[BaseException]],
    _exception_value: Optional[BaseException],
    _traceback: Optional[TracebackType],
) -> None:
    """Close the backing StringIO when the ``with`` block ends.

    The three arguments are the standard context-manager exit arguments: all None when
    the block finished normally, otherwise the exception type, instance and traceback.
    They are ignored here. Returning None means an in-flight exception is not suppressed.
    """
    # Once closed, StringIO.getvalue() raises ValueError, so the buffer must be read
    # before the ``with`` block is left.
    self.buffer.close()

def read(self) -> str:
    """Get the value of the backing StringIO.

    Returns:
        str: Everything written to the buffer so far.

    Raises:
        ValueError: If the buffer has already been closed.
    """
    return self.buffer.getvalue()

def write_header(self):
def write_header(self) -> None:
args = [os.path.basename(sys.argv[0])] + sys.argv[1:]
self.line("'''NOTE: THIS FILE IS AUTO-GENERATED. DO NOT EDIT")
self.blank_line()
Expand Down
12 changes: 5 additions & 7 deletions python_modules/automation/automation/scaffold/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"""


def add_to_examples_json(name):
def add_to_examples_json(name: str) -> None:
with open(EXAMPLES_JSON_PATH, "r", encoding="utf8") as examples_file:
examples = json.load(examples_file)

Expand All @@ -44,15 +44,13 @@ def cli():
@click.option(
"--name", prompt='Name of library (ex: "foo" will create dagster-foo)', help="Name of library"
)
def library(name):
def library(name: str):
"""Scaffolds a Dagster library <NAME> in python_modules/libraries/dagster-<NAME>."""
template_library_path = os.path.join(ASSETS_PATH, "dagster-library-tmpl")
new_template_library_path = os.path.abspath(
"python_modules/libraries/dagster-{name}".format(name=name)
)
new_template_library_path = os.path.abspath(f"python_modules/libraries/dagster-{name}")

if os.path.exists(new_template_library_path):
raise click.UsageError("Library with name {name} already exists".format(name=name))
raise click.UsageError(f"Library with name {name} already exists")

copy_directory(template_library_path, new_template_library_path)

Expand Down Expand Up @@ -126,7 +124,7 @@ def example(name):
print("Added metadata to {path}".format(path=EXAMPLES_JSON_PATH))


def main() -> None:
    """Entry point: wrap the `cli` group in a click CommandCollection and invoke it."""
    click_cli = click.CommandCollection(sources=[cli], help=CLI_HELP)
    click_cli()

Expand Down
2 changes: 1 addition & 1 deletion python_modules/automation/automation/scaffold/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import shutil


def copy_directory(src, dest):
def copy_directory(src: str, dest: str) -> None:
try:
shutil.copytree(src, dest, ignore=shutil.ignore_patterns(".DS_Store"))
# Directories are the same
Expand Down

0 comments on commit 1bd83e4

Please sign in to comment.