Skip to content

Commit

Permalink
Change extract file to use Path
Browse files Browse the repository at this point in the history
  • Loading branch information
humrochagf committed Feb 14, 2021
1 parent 7c8f8b2 commit 560b068
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 14 deletions.
14 changes: 8 additions & 6 deletions revelation/utils.py
Expand Up @@ -71,13 +71,15 @@ def move_and_replace(src: Path, dst: Path):
shutil.rmtree(src) # remove the dir structure from the source


def extract_file(compressed_file, path="."):
def extract_file(compressed_file: Path, path: Path = Path(".")) -> Path:
"""Extract function to extract from zip or tar file"""
if os.path.isfile(compressed_file):
if tarfile.is_tarfile(compressed_file):
path = path.resolve()

if compressed_file.is_file():
if tarfile.is_tarfile(str(compressed_file)):
with tarfile.open(compressed_file, "r:gz") as tfile:
basename = tfile.members[0].name
tfile.extractall(path + "/")
basename = tfile.getnames()[0]
tfile.extractall(str(path.resolve()))
elif zipfile.is_zipfile(compressed_file):
with zipfile.ZipFile(compressed_file, "r") as zfile:
basename = zfile.namelist()[0]
Expand All @@ -87,7 +89,7 @@ def extract_file(compressed_file, path="."):
else:
raise FileNotFoundError(f"{compressed_file} is not a valid file")

return os.path.abspath(os.path.join(path, basename))
return path / basename


def normalize_newlines(text):
Expand Down
12 changes: 4 additions & 8 deletions tests/test_utils.py
Expand Up @@ -35,9 +35,7 @@ def test_extract_file_zipfile(
):
src_files = sorted((f.name for f in presentation.root.iterdir()))

extracted_dir = Path(
extract_file(presentation_zip, str(presentation.parent))
)
extracted_dir = Path(extract_file(presentation_zip, presentation.parent))

extracted_files = sorted((f.name for f in extracted_dir.iterdir()))

Expand All @@ -50,9 +48,7 @@ def test_extract_file_tarfile(
):
src_files = sorted((f.name for f in presentation.root.iterdir()))

extracted_dir = Path(
extract_file(presentation_tar, str(presentation.parent))
)
extracted_dir = Path(extract_file(presentation_tar, presentation.parent))

extracted_files = sorted((f.name for f in extracted_dir.iterdir()))

Expand All @@ -61,15 +57,15 @@ def test_extract_file_tarfile(

def test_extract_file_on_non_file(tmp_path: Path):
with pytest.raises(FileNotFoundError):
extract_file(tmp_path / "notfound", str(tmp_path))
extract_file(tmp_path / "notfound", tmp_path)


def test_extract_file_on_non_tar_or_zip(tmp_path: Path):
wrong_format = tmp_path / "file.wrong"
wrong_format.write_text("", "utf8")

with pytest.raises(NotImplementedError):
extract_file(wrong_format, str(tmp_path))
extract_file(wrong_format, tmp_path)


def test_make_presentation(tmp_path: Path):
Expand Down

0 comments on commit 560b068

Please sign in to comment.