fix more type errors after upgrading mypy
jfischer committed Mar 14, 2020
1 parent dd6c944 commit 8b91a77
Showing 13 changed files with 46 additions and 41 deletions.
2 changes: 1 addition & 1 deletion dataworkspaces/api.py
@@ -122,7 +122,7 @@ def restore(tag_or_hash:str, workspace_uri_or_path:Optional[str]=None,

def make_lineage_table(workspace_uri_or_path:Optional[str]=None,
tag_or_hash:Optional[str]=None, verbose:bool=False) \
- -> Iterable[Tuple[str, str, str, str, Optional[List[str]]]]:
+ -> Iterable[Tuple[str, str, str, Optional[List[str]]]]:
"""Make a table of the lineage for each resource.
The columns are: ref, lineage type, details, inputs
"""
2 changes: 1 addition & 1 deletion dataworkspaces/commands/delete_snapshot.py
@@ -14,7 +14,7 @@ def delete_snapshot_command(workspace:Workspace, tag_or_hash:str,
md = mixin.get_snapshot_by_tag_or_hash(tag_or_hash)
snapshot_name = '%s (Tagged as: %s)' % (md.hashval[0:7], ', '.join(md.tags)) \
if md.tags is not None \
- else md.hashvale
+ else md.hashval
if not workspace.batch:
if not click.confirm("Should I delete snapshot %s? This is not reversible." % snapshot_name):
raise UserAbort()
7 changes: 4 additions & 3 deletions dataworkspaces/commands/push.py
@@ -5,7 +5,7 @@

from dataworkspaces.errors import ConfigurationError
from dataworkspaces.workspace import Workspace, LocalStateResourceMixin,\
- SyncedWorkspaceMixin, CentralWorkspaceMixin
+ SyncedWorkspaceMixin, CentralWorkspaceMixin, Resource

def build_resource_list(workspace:Workspace, only:Optional[List[str]], skip:Optional[List[str]]) \
-> List[str]:
@@ -64,11 +64,12 @@ def push_command(workspace:Workspace, only:Optional[List[str]]=None, skip:Option
click.echo("No resources to push.")
return 0
else:
print("Pushing resources: %s" % ', '.join([r.name for r in resource_list]))
print("Pushing resources: %s" % ', '.join([cast(Resource,r).name for r in resource_list]))
workspace.push_resources(resource_list)
elif isinstance(workspace, SyncedWorkspaceMixin):
if len(resource_list)>0:
click.echo("Pushing workspace and resources: %s" % ', '.join([r.name for r in resource_list]))
click.echo("Pushing workspace and resources: %s" % ', '.join([cast(Resource,r).name
for r in resource_list]))
elif not only_workspace:
click.echo("No resources to push, will still push workspace")
else:
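
Note: the push.py change shows the cast() idiom used throughout this commit. The resource list is typed by its mixin element type, which mypy does not know carries a .name attribute, so cast(Resource, r) asserts the concrete type for the checker at no runtime cost. A minimal, self-contained sketch of the same idiom (class and function names here are illustrative, not the dataworkspaces API):

    from typing import List, cast

    class Resource:
        def __init__(self, name: str) -> None:
            self.name = name

    class LocalStateResourceMixin:
        """Marker mixin; concrete resource classes also inherit from Resource."""

    class GitResource(Resource, LocalStateResourceMixin):
        pass

    def resource_names(resources: List[LocalStateResourceMixin]) -> List[str]:
        # Each element is typed only as the mixin, which has no .name attribute.
        # cast() tells mypy the concrete type; it performs no runtime check.
        return [cast(Resource, r).name for r in resources]

    print(resource_names([GitResource("code"), GitResource("results")]))  # ['code', 'results']
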
3 changes: 2 additions & 1 deletion dataworkspaces/commands/restore.py
@@ -106,7 +106,8 @@ def fmt_rlist(rnames):
raise UserAbort()

# do the work!
- mixin.restore(md.hashval, restore_hashes, restore_resource_list)
+ mixin.restore(md.hashval, restore_hashes,
+               cast(List[SnapshotResourceMixin], restore_resource_list))
workspace.save("Restore to %s" % md.hashval)


4 changes: 2 additions & 2 deletions dataworkspaces/commands/snapshot.py
@@ -58,7 +58,7 @@ def snapshot_command(workspace:Workspace, tag:Optional[str]=None, message:str=''
# Remove existing tag if present
if tag is not None:
try:
- existing_tag_md = mixin.get_snapshot_by_tag(tag)
+ existing_tag_md = mixin.get_snapshot_by_tag(tag) # type: Optional[SnapshotMetadata]
except ConfigurationError:
existing_tag_md = None
if existing_tag_md is not None:
@@ -75,7 +75,7 @@ def snapshot_command(workspace:Workspace, tag:Optional[str]=None, message:str=''
(md, manifest) = mixin.snapshot(tag, message)

try:
- old_md = mixin.get_snapshot_metadata(md.hashval)
+ old_md = mixin.get_snapshot_metadata(md.hashval) # type: Optional[SnapshotMetadata]
except:
old_md = None
if old_md is not None:
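
Note: both snapshot.py hunks use the same fix. The variable is first bound to a non-Optional value inside a try block and then to None in the except handler, so without an annotation mypy infers the narrower type from the first assignment and rejects the None. The codebase uses type comments rather than variable annotations; a rough sketch of the pattern, with a hypothetical lookup function standing in for the snapshot mixin:

    from typing import Optional

    class SnapshotMetadata:
        pass

    class ConfigurationError(Exception):
        pass

    def get_snapshot_by_tag(tag: str) -> SnapshotMetadata:
        # Stand-in for the real lookup; raises when the tag is unknown.
        raise ConfigurationError("no snapshot tagged %s" % tag)

    def tag_exists(tag: str) -> bool:
        try:
            md = get_snapshot_by_tag(tag)  # type: Optional[SnapshotMetadata]
        except ConfigurationError:
            # Without the type comment above, mypy infers md as SnapshotMetadata
            # and flags this assignment as an incompatible type.
            md = None
        return md is not None

    print(tag_exists("v1"))  # False with the stand-in lookup
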
3 changes: 2 additions & 1 deletion dataworkspaces/kits/jupyter.py
@@ -253,7 +253,7 @@ def __init__(self, shell):
if self.disabled:
print("Loaded Data Workspaces magic commands in disabled state.", file=sys.stderr)
return
- self._snapshot_args = None
+ self._snapshot_args = None # type: Optional[argparse.Namespace]
def target_func(comm, open_msg):
self.comm = comm
@comm.on_msg
@@ -295,6 +295,7 @@ def _recv(msg):
elif msg_type=='snapshot':
cell = data['cell']
try:
+ assert self._snapshot_args is not None
r = take_snapshot(self.dws_jupyter_info.workspace_dir,
tag=self._snapshot_args.tag,
message=self._snapshot_args.message)
2 changes: 2 additions & 0 deletions dataworkspaces/kits/scikit_learn.py
@@ -122,6 +122,7 @@ def load_dataset_from_resource(resource_name:str, subpath:Optional[str]=None,
raise ConfigurationError("Unable to instantiate a data set for resource '%s': currently not supported for non-local resources"%
resource_name)
local_path = r.get_local_path_if_any()
+ assert local_path is not None
dataset_path = join(local_path, subpath) if subpath is not None else local_path
result = {} # this will be the args to the result Bunch
# First load data and target files, which are required
@@ -412,6 +413,7 @@ def _init_dws_state(self):
self.results_resource)

def _save_model(self):
+ assert self.model_save_file
if not self.model_save_file.endswith('.joblib'):
model_save_file = self.model_save_file + '.joblib'
else:
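
Note: the two scikit_learn.py additions are assert-based narrowing. get_local_path_if_any() and model_save_file are declared Optional, and mypy flags their use even on paths where earlier checks guarantee a value; a bare assert narrows the type for the checker and fails fast if that assumption is ever wrong. A small sketch under that assumption (the helper here is a placeholder, not the real resource API):

    from os.path import join
    from typing import Optional

    def get_local_path_if_any() -> Optional[str]:
        # Placeholder: the real method returns None for non-local resources.
        return "/tmp/my-resource"

    def dataset_path(subpath: Optional[str] = None) -> str:
        local_path = get_local_path_if_any()
        # join() needs str, not Optional[str]; the assert narrows the type for
        # mypy and raises AssertionError if the value is unexpectedly None.
        assert local_path is not None
        return join(local_path, subpath) if subpath is not None else local_path

    print(dataset_path("data.csv"))  # /tmp/my-resource/data.csv
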
6 changes: 2 additions & 4 deletions dataworkspaces/kits/tensorflow.py
@@ -160,9 +160,7 @@ def _verify_eager_if_dataset(x, y, api_resource):
to evaluate the dataset outside the tensor graph.
"""
if (not USING_TENSORFLOW2) and \
- (isinstance(x, tensorflow.data.Dataset) or
-  isinstance(y, tensorflow.data.Dataset)) and \
- (not tensorflow.executing_eagerly()):
+ (isinstance(x, tensorflow.data.Dataset) or isinstance(y, tensorflow.data.Dataset)) and (not tensorflow.executing_eagerly()): # type: ignore
raise NotSupportedError("Using an API resource ("+ api_resource.name+
") with non-eager datasets is not "+
"supported with TensorFlow 1.x.")
@@ -412,7 +410,7 @@ def __init__(self,*args,**kwargs):
save_freq=checkpoint_config.save_freq,
results_resource=results_resource,
workspace_dir=workspace_dir,
- verbose=verbose)
+ verbose=verbose) # type: Optional[DwsModelCheckpoint]
else:
self.checkpoint_cb = None
def compile(self, optimizer,
4 changes: 2 additions & 2 deletions dataworkspaces/kits/wrapper_utils.py
@@ -83,11 +83,11 @@ def _add_to_hash(array_data, hash_state):
# the columns of a CSV.
for column in array_data.values():
_add_to_hash(column, hash_state)
- elif (tensorflow is not None) and isinstance(array_data, tensorflow.data.Dataset):
+ elif (tensorflow is not None) and isinstance(array_data, tensorflow.data.Dataset): # type: ignore
# We need to iterate through the dataset, to force an eager evaluation
for t in array_data:
_add_to_hash(t, hash_state)
- elif (tensorflow is not None) and isinstance(array_data, tensorflow.Tensor):
+ elif (tensorflow is not None) and isinstance(array_data, tensorflow.Tensor): # type: ignore
if hasattr(array_data, 'numpy'):
_add_to_hash(array_data.numpy(), hash_state)
else:
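
Note: wrapper_utils.py imports tensorflow optionally and falls back to None when it is missing, so mypy cannot verify attribute access such as tensorflow.data.Dataset on a possibly-None (or unstubbed) module; the line-level # type: ignore silences exactly those checks while the runtime guard keeps the code safe. A sketch of that guarded-import pattern, assuming tensorflow may or may not be installed:

    from typing import Any

    try:
        import tensorflow
    except ImportError:
        tensorflow = None  # type: ignore  # rebinding a module name to None upsets mypy

    def describe(value: Any) -> str:
        # The `is not None` guard makes this safe without tensorflow installed;
        # the ignore comment covers the attribute access mypy cannot verify.
        if (tensorflow is not None) and isinstance(value, tensorflow.Tensor):  # type: ignore
            return "tensorflow tensor"
        return "something else"

    print(describe([1, 2, 3]))  # something else
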
8 changes: 0 additions & 8 deletions dataworkspaces/resources/rclone_resource.py
@@ -135,14 +135,6 @@ def push(self) -> None:
"""
pass



-    def add(self):
-        print("rclone: Add is called")
-        self.add_from_remote()

def snapshot_precheck(self) -> None:
pass

8 changes: 4 additions & 4 deletions dataworkspaces/utils/file_utils.py
@@ -112,13 +112,13 @@ def convert(self, value, param, ctx):
while not isdir(parent_dir) and parent_dir!='/':
parent_dir = parent_path(parent_dir)
if not exists(parent_dir):
self.fail('%s "%s" does not exist.' % (self.path_type, parent_dir), param, ctx)
self.fail('%s "%s" does not exist.' % (self.path_type, parent_dir), param, ctx) # type: ignore
if not isdir(parent_dir):
self.fail('%s "%s" is a file.' % (self.path_type, parent_dir), param, ctx)
self.fail('%s "%s" is a file.' % (self.path_type, parent_dir), param, ctx) # type: ignore
if not os.access(parent_dir, os.W_OK):
self.fail('%s "%s" is not writable.' % (self.path_type, parent_dir), param, ctx)
self.fail('%s "%s" is not writable.' % (self.path_type, parent_dir), param, ctx) # type: ignore
if (self.must_be_outside_of_workspace is not None) and \
commonpath([self.must_be_outside_of_workspace, rv]) in (self.must_be_outside_of_workspace, rv):
self.fail('%s must be outside of workspace "%s"' %
- (self.path_type, self.must_be_outside_of_workspace), param, ctx)
+ (self.path_type, self.must_be_outside_of_workspace), param, ctx) # type: ignore
return rv
2 changes: 1 addition & 1 deletion dataworkspaces/utils/lineage_utils.py
@@ -877,7 +877,7 @@ def snapshot_lineage(self, instance:str, snapshot_hash:str,

@abstractmethod
def restore_lineage(self, instance:str, snapshot_hash:str,
- resources_to_restore:List[str], vebose:bool=False) -> None:
+ resources_to_restore:List[str], verbose:bool=False) -> None:
"""Restore the lineage for the specified resources from the specified snapshot.
Any existing entries for the specified resources should first be cleared.
Then, any entries for those resources copied to the current lineage.
36 changes: 23 additions & 13 deletions dataworkspaces/workspace.py
@@ -355,7 +355,7 @@ def validate_local_path_for_resource(self, proposed_resource_name:str,
if not isinstance(r, LocalStateResourceMixin) or \
r.get_local_path_if_any() is None:
continue
- other_real_path = os.path.realpath(r.get_local_path_if_any())
+ other_real_path = os.path.realpath(cast(str, r.get_local_path_if_any()))
common = os.path.commonpath([real_local_path, other_real_path])
if other_real_path==common or real_local_path==common:
raise ConfigurationError("Proposed path %s for resource %s, conflicts with local path %s for resource %s"%
@@ -679,7 +679,7 @@ def upload_file(self, src_local_path:str,
pass

@abstractmethod
- def read_results_file(self, subpath:str) -> Union[JSONDict,JSONList]:
+ def read_results_file(self, subpath:str) -> Union[JSONDict]:
"""Read and parse json results data from the specified path
in the resource. If the path does not exist or is not a file
throw a ConfigurationError.
@@ -846,7 +846,9 @@ def pull_resources(self, resource_list:List[LocalStateResourceMixin]) -> None:
self._pull_resources_precheck(resource_list)
assert isinstance(self, Workspace)
for r in resource_list:
print("[pull] pulling resource %s" %r.name) # XXX
assert isinstance(r, Resource)
if self.verbose:
print("[pull] pulling resource %s" %r.name)
r.pull()

# We need to clear the current lineage for pulled resources since we
@@ -855,7 +857,7 @@ def pull_resources(self, resource_list:List[LocalStateResourceMixin]) -> None:
instance = self.get_instance()
lstore = self.get_lineage_store()
for r in resource_list:
print("[pull] clearing resource %s" %r.name) # XXX
assert isinstance(r, Resource)
if self.verbose:
print("Clearing lineage on resource %s" % r.name)
lstore.clear_entry(instance, ResourceRef(r.name, None))
@@ -966,7 +968,7 @@ def matches_partial_hash(self, partial_hash):
"""A partial hash matches if the full hash starts with it,
normalizing to lower case.
"""
- return True if self.hashval.startwith(partial_hash.lower()) else False
+ return True if self.hashval.startswith(partial_hash.lower()) else False

def to_json(self) -> JSONDict:
v = {
@@ -1130,7 +1132,7 @@ def snapshot(self, tag:Optional[str]=None, message:str='') -> Tuple[SnapshotMeta
print("Cleared lineage for results resource %s" % rname)
return metadata, manifest_bytes

- def _restore_precheck(self, restore_hashes:Dict[str,str],
+ def _restore_precheck(self, restore_hashes:Dict[str,Optional[str]],
restore_resources:List['SnapshotResourceMixin']) -> None:
"""Run any prechecks before restoring to the specified hash value
(aka certificate). This should throw a ConfigurationError if the
@@ -1141,9 +1143,11 @@ def _restore_precheck(self, restore_hashes:Dict[str,str],
This method is called by restore()
"""
for r in restore_resources:
- r.restore_precheck(restore_hashes[cast(Resource, r).name])
+ hashval = restore_hashes[cast(Resource, r).name]
+ assert hashval is not None
+ r.restore_precheck(hashval)

- def restore(self, snapshot_hash:str, restore_hashes:Dict[str,str],
+ def restore(self, snapshot_hash:str, restore_hashes:Dict[str,Optional[str]],
restore_resources:List['SnapshotResourceMixin']) -> None:
"""Restore the specified resources to the specified hashes.
The list should have been previously filtered to include only
@@ -1152,13 +1156,15 @@ def restore(self, snapshot_hash:str, restore_hashes:Dict[str,str],
self._restore_precheck(restore_hashes, restore_resources)

for r in restore_resources:
- r.restore(restore_hashes[cast(Resource, r).name])
+ hashval = restore_hashes[cast(Resource, r).name]
+ assert hashval is not None
+ r.restore(hashval)

if self.supports_lineage():
assert isinstance(self, Workspace)
self.get_lineage_store().restore_lineage(self.get_instance(),
snapshot_hash,
- [r.name for r in
+ [cast(Resource,r).name for r in
restore_resources],
verbose=self.verbose)

@@ -1205,7 +1211,7 @@ def _get_snapshot_manifest_as_bytes(self, hash_val:str) -> bytes:
"""
pass

- def get_snapshot_manifest(self, hash_val:str) -> JSONDict:
+ def get_snapshot_manifest(self, hash_val:str) -> JSONList:
"""Returns the snapshot manifest for the given hash
as a parsed JSON structure. The top-level dict maps
resource names resource parameters.
@@ -1263,8 +1269,12 @@ def delete_snapshot(self, hash_val:str, include_resources=False)-> None:
for rname in to_delete:
r = cast(Workspace, self).get_resource(rname)
if isinstance(r, SnapshotResourceMixin):
- r.delete_snapshot(md.hashval, md.restore_hashes[rname],
-                   md.relative_destination_path)
+ delete_hash = md.restore_hashes[rname]
+ if delete_hash is not None:
+     r.delete_snapshot(md.hashval, delete_hash,
+                       md.relative_destination_path)
+ else:
+     print("Cannot delete snapshot for resource %s, no restore hash" % rname)
self._delete_snapshot_metadata_and_manifest(hash_val)
if self.supports_lineage():
instance = cast(Workspace, self).get_instance()
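
Note: the workspace.py changes follow from widening restore_hashes to Dict[str, Optional[str]]. Once None is an admissible value, every consumer must either assert it away (restore, _restore_precheck) or branch on it (delete_snapshot) before handing the hash to an API that expects str. A compact sketch of the delete-side guard, with placeholder names:

    from typing import Dict, Optional

    def delete_resource_snapshot(rname: str, restore_hash: str) -> None:
        # Placeholder for the per-resource delete_snapshot() call.
        print("deleting snapshot of %s at %s" % (rname, restore_hash))

    def delete_snapshots(restore_hashes: Dict[str, Optional[str]]) -> None:
        for rname, delete_hash in restore_hashes.items():
            # With Optional values, mypy requires the None case to be handled
            # before the value can be passed where a plain str is expected.
            if delete_hash is not None:
                delete_resource_snapshot(rname, delete_hash)
            else:
                print("Cannot delete snapshot for resource %s, no restore hash" % rname)

    delete_snapshots({"code": "a1b2c3d", "results": None})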
