Skip to content

Commit

Permalink
add stricter type checking
Browse files: browse the repository at this point in the history
  • Loading branch information
jfischer committed Mar 14, 2020
1 parent a152a24 commit dd6c944
Show file tree
Hide file tree
Showing 9 changed files with 35 additions and 28 deletions.
2 changes: 1 addition & 1 deletion dataworkspaces/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class SnapshotInfo(NamedTuple):
to :func:`~get_snapshot_history`
"""
snapshot_number: int
hashval : int
hashval : str
tags : List[str]
timestamp: str
message: str
Expand Down
9 changes: 5 additions & 4 deletions dataworkspaces/backends/git.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,10 @@ def __init__(self, workspace:'Workspace'):
join(workspace.workspace_dir, SNAPSHOT_LINEAGE_DIR_PATH))
self.workspace = workspace

def _add_to_git(self, path):
git_add(self.workspace.workspace_dir,
[get_subpath_from_absolute(self.workspace.workspace_dir, path)],
def _add_to_git(self, path:str):
ws_dir = cast(str, self.workspace.workspace_dir)
git_add(ws_dir,
[get_subpath_from_absolute(ws_dir, path)], # type: ignore
verbose=self.workspace.verbose)

def _save_rfile_to_snapshot(self, resource_name:str,
Expand Down Expand Up @@ -94,7 +95,7 @@ def delete_snapshot_lineage(self, instance:str, snapshot_hash:str) -> None:
class Workspace(ws.Workspace, ws.SyncedWorkspaceMixin, ws.SnapshotWorkspaceMixin):
def __init__(self, workspace_dir:str, batch:bool=False,
verbose:bool=False):
self.workspace_dir = workspace_dir
self.workspace_dir = workspace_dir # type: str
cf_data = self._load_json_file(CONFIG_FILE_PATH)
super().__init__(cf_data['name'], cf_data['dws-version'], batch, verbose)
self.global_params = cf_data['global_params']
Expand Down
7 changes: 3 additions & 4 deletions dataworkspaces/commands/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,16 @@
"A snapshot with this hash already exists. Do you want to update "+\
"the message from '%s' to '%s'?"

def merge_snapshot_metadata(old, new, batch):
def merge_snapshot_metadata(old:SnapshotMetadata, new:SnapshotMetadata, batch:bool) \
-> SnapshotMetadata:
"""Merge two snapshot metadatas for when someone creates
a snapshot without making changes. They might have
added more tags or changed the message.
"""
assert old.hashval == new.hashval
tags = old.tags + [tag for tag in new.tags
if tag not in old.tags]
if old.message!=new.message and (new.message is not None) \
(new.message!='') and (not batch) and \
click.confirm(_CONF_MESSAGE%(old.message, new.message)):
if old.message!=new.message and (new.message is not None) and (new.message!='') and (batch is False) and click.confirm(_CONF_MESSAGE%(old.message, new.message)): # type:ignore
message = new.message
else:
message = old.message
Expand Down
10 changes: 5 additions & 5 deletions dataworkspaces/dws.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import click
import re
from os.path import isdir, join, abspath, expanduser, basename, curdir
from typing import Optional
from typing import Optional, Union
from argparse import Namespace
from collections.abc import Sequence

Expand Down Expand Up @@ -159,15 +159,15 @@ def convert(self, value, param, ctx):
class ResourceParamType(click.ParamType):
name = 'resources'

def convert(self, value, param, ctx):
def convert(self, value:Union[str,Sequence], param:Optional[click.Parameter], ctx):
parsed = []
if isinstance(value, str):
rl = value.lower().split(',')
rl = value.lower().split(',') # type: Sequence[str]
elif isinstance(value, Sequence):
rl = value
else:
self.failed("Invalid resource role list '%s', must be a string or a sequence"
% str(value))
self.fail("Invalid resource role list '%s', must be a string or a sequence"
% str(value))
for r in rl:
if r=='all':
return [r for r in RESOURCE_ROLE_CHOICES]
Expand Down
4 changes: 2 additions & 2 deletions dataworkspaces/kits/scikit_learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ def _save_model(self):
target_name = model_save_file
cast(FileResourceMixin, resource).upload_file(tempname, target_name)
finally:
if exists(tempname):
if (tempname is not None) and exists(tempname):
os.remove(tempname)
if self.verbose:
print("dws> saved model file to %s:%s" %
Expand Down Expand Up @@ -497,7 +497,7 @@ def score(self, X, y, sample_weight=None):
api_resource.pop_hash_state()
predictions = self.predictor.predict(X)
if isinstance(self.metrics, str):
metrics_inst = _METRICS[self.metrics](y, predictions, sample_weight=sample_weight)
metrics_inst = _METRICS[self.metrics](y, predictions, sample_weight=sample_weight) # type: ignore
else:
metrics_inst = self.metrics(y, predictions, sample_weight=sample_weight)
self._dws_state.write_metrics_and_complete(metrics_inst.to_dict())
Expand Down
2 changes: 1 addition & 1 deletion dataworkspaces/kits/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def __getitem__(self, idx):
_add_to_hash(inputs, self.hash_state)
_add_to_hash(targets, self.hash_state)
if sample_weights is not None:
_add_to_hash(sample_weights)
_add_to_hash(sample_weights, self.hash_state)
return v

def __len__(self):
Expand Down
13 changes: 8 additions & 5 deletions dataworkspaces/resources/git_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ def switch_git_branch_if_needed(local_path, branch, verbose, ok_if_not_present=F
switch_git_branch(local_path, branch, verbose)

class GitLocalPathType(LocalPathType):
def __init__(self, remote_url, verbose):
def __init__(self, remote_url:str, verbose:bool):
super().__init__()
self.remote_url = remote_url
self.verbose = verbose
Expand All @@ -377,7 +377,7 @@ def convert(self, value, param, ctx):
remote = get_remote_origin(rv, verbose=self.verbose)
if remote!=self.remote_url:
self.fail('%s "%s" is a git repo with remote origin "%s", but dataworkspace has remote "%s"'%
(self.path_type, rv), param, ctx)
(self.path_type, rv, self.remote_url, remote), param, ctx)
return rv


Expand All @@ -388,9 +388,11 @@ def from_command_line(self, role, name, workspace,
arguments"""
workspace.validate_local_path_for_resource(name, local_path)
lpr = realpath(local_path)
wspath = realpath(workspace.get_workspace_local_path_if_any()) \
if workspace.get_workspace_local_path_if_any() is not None else None
if not is_git_repo(local_path):
if isinstance(workspace, git_backend.Workspace) and \
lpr.startswith(realpath(workspace.get_workspace_local_path_if_any())):
if isinstance(workspace, git_backend.Workspace) and wspath is not None and \
lpr.startswith(wspath):
if branch!='master':
raise ConfigurationError("Only the branch 'master' is available for resources that are within the workspace's git repository")
elif read_only:
Expand All @@ -402,7 +404,8 @@ def from_command_line(self, role, name, workspace,
# The local path is a git repo. Double-check that it isn't already part
# of the workspace's repo. If it is, you will get an error when cloning.
if isinstance(workspace, git_backend.Workspace) and \
lpr.startswith(realpath(workspace.get_workspace_local_path_if_any())) and \
wspath is not None and \
lpr.startswith(wspath) and \
is_file_tracked_by_git(local_path, workspace.get_workspace_local_path_if_any(),
verbose=workspace.verbose):
raise ConfigurationError("%s is a git repository, but also part of the parent workspace's repo"%(local_path))
Expand Down
14 changes: 8 additions & 6 deletions tests/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,16 @@ test: clean mypy pyflakes
./testcli.sh --batch
python -m unittest $(UNIT_TESTS)

MYPY=mypy --config-file=$(shell pwd)/mypy.ini

mypy:
cd $(DATAWORKSPACES); export MYPYPATH=$(MYPYPATH); mypy workspace.py dws.py lineage.py api.py
cd $(DATAWORKSPACES); export MYPYPATH=$(MYPYPATH); mypy utils/
cd $(DATAWORKSPACES); export MYPYPATH=$(MYPYPATH); mypy backends/
cd $(DATAWORKSPACES); export MYPYPATH=$(MYPYPATH); mypy resources/
cd $(DATAWORKSPACES); export MYPYPATH=$(MYPYPATH); mypy commands/
cd $(DATAWORKSPACES); export MYPYPATH=$(MYPYPATH); $(MYPY) workspace.py dws.py lineage.py api.py
cd $(DATAWORKSPACES); export MYPYPATH=$(MYPYPATH); $(MYPY) utils/
cd $(DATAWORKSPACES); export MYPYPATH=$(MYPYPATH); $(MYPY) backends/
cd $(DATAWORKSPACES); export MYPYPATH=$(MYPYPATH); $(MYPY) resources/
cd $(DATAWORKSPACES); export MYPYPATH=$(MYPYPATH); $(MYPY) commands/
cd $(DATAWORKSPACES)/kits; export MYPYPATH=$(MYPYPATH); \
mypy --ignore-missing-imports $(MYPY_KITS)
$(MYPY) --ignore-missing-imports $(MYPY_KITS)

pyflakes:
cd $(DATAWORKSPACES); pyflakes workspace.py dws.py lineage.py api.py
Expand Down
2 changes: 2 additions & 0 deletions tests/mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[mypy]
check_untyped_defs = True

0 comments on commit dd6c944

Please sign in to comment.