Skip to content

Add overwrite option to prepare #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
author = "Bradley Steinfeld, Sam Prokopchuk, James Reeve"

# The full version, including alpha/beta/rc tags
release = "0.20.3"
release = "0.20.4"


# -- General configuration ---------------------------------------------------
Expand Down
63 changes: 54 additions & 9 deletions skillsnetwork/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,16 +114,27 @@ async def _get_chunks(url: str, chunk_size: int) -> Generator[bytes, None, None]
raise Exception(f"Failed to read dataset at {url}") from None


def _verify_files_dont_exist(paths: Iterable[Union[str, Path]]) -> None:
def _verify_files_dont_exist(
paths: Iterable[Union[str, Path]], remove_if_exist: bool = False
) -> None:
"""
Verifies all paths in 'paths' don't exist.
:param paths: A iterable of strs or pathlib.Paths.
:param remove_if_exist=False: Removes file at path if they already exist.
:returns: None
:raises FileExistsError: On the first path found that already exists.
"""
for path in paths:
if Path(path).exists():
raise FileExistsError(f"Error: File '{path}' already exists.")
path = Path(path)
if path.exists():
if remove_if_exist:
if path.is_symlink():
realpath = path.resolve()
path.unlink(realpath)
else:
shutil.rmtree(path)
else:
raise FileExistsError(f"Error: File '{path}' already exists.")


def _is_file_to_symlink(path: Path) -> bool:
Expand Down Expand Up @@ -188,7 +199,9 @@ async def read(url: str, chunk_size: int = DEFAULT_CHUNK_SIZE) -> bytes:
return b"".join([chunk async for chunk in _get_chunks(url, chunk_size)])


async def prepare(url: str, path: Optional[str] = None, verbose: bool = True) -> None:
async def prepare(
url: str, path: Optional[str] = None, verbose: bool = True, overwrite: bool = False
) -> None:
"""
Prepares a dataset for learners. Downloads a dataset from the given url,
decompresses it if necessary. If not using jupyterlite, will extract to
Expand All @@ -200,6 +213,8 @@ async def prepare(url: str, path: Optional[str] = None, verbose: bool = True) ->

:param url: The URL to download the dataset from.
:param path: The path the dataset will be available at. Current working directory by default.
:param verbose=True: Prints saved path if True.
:param overwrite=False: Overwrites any existing files at destination if they exist.
:raise InvalidURLException: When URL is invalid.
:raise FileExistsError: it raises this when a file to be symlinked already exists.
:raise ValueError: When requested path is in /tmp, or cannot be saved to path.
Expand Down Expand Up @@ -239,7 +254,8 @@ async def prepare(url: str, path: Optional[str] = None, verbose: bool = True) ->
path / child.name
for child in map(Path, tf.getnames())
if len(child.parents) == 1 and _is_file_to_symlink(child)
]
],
overwrite,
) # Only check if top-level fileobject
pbar = tqdm(iterable=tf.getmembers(), total=len(tf.getmembers()))
pbar.set_description(f"Extracting {filename}")
Expand All @@ -253,15 +269,16 @@ async def prepare(url: str, path: Optional[str] = None, verbose: bool = True) ->
path / child.name
for child in map(Path, zf.namelist())
if len(child.parents) == 1 and _is_file_to_symlink(child)
]
],
overwrite,
)
pbar = tqdm(iterable=zf.infolist(), total=len(zf.infolist()))
pbar.set_description(f"Extracting {filename}")
for member in pbar:
zf.extract(member=member, path=extract_dir)
tmp_download_file.unlink()
else:
_verify_files_dont_exist([path / filename])
_verify_files_dont_exist([path / filename], overwrite)
shutil.move(tmp_download_file, extract_dir / filename)

# If in jupyterlite environment, the extract_dir = path, so the files are already there.
Expand All @@ -274,8 +291,36 @@ async def prepare(url: str, path: Optional[str] = None, verbose: bool = True) ->
print(f"Saved to '{relpath(path.resolve())}'")


if _is_jupyterlite():
tqdm.monitor_interval = 0
def setup() -> None:
if _is_jupyterlite():
tqdm.monitor_interval = 0

try:
import sys # pyright: ignore

ipython = get_ipython()

def hide_traceback(
exc_tuple=None,
filename=None,
tb_offset=None,
exception_only=False,
running_compiled_code=False,
):
etype, value, tb = sys.exc_info()
value.__cause__ = None # suppress chained exceptions
return ipython._showtraceback(
etype, value, ipython.InteractiveTB.get_exception_only(etype, value)
)

ipython.showtraceback = hide_traceback

except NameError:
pass


setup()


# For backwards compatibility
download_dataset = download
Expand Down