diff --git a/allennlp/allennlp/models/archival.py b/allennlp/allennlp/models/archival.py
index 7cca7b6..5db007c 100644
--- a/allennlp/allennlp/models/archival.py
+++ b/allennlp/allennlp/models/archival.py
@@ -190,7 +190,29 @@ def load_archive(
         tempdir = tempfile.mkdtemp()
         logger.info(f"extracting archive file {resolved_archive_file} to temp dir {tempdir}")
         with tarfile.open(resolved_archive_file, "r:gz") as archive:
-            archive.extractall(tempdir)
+            def is_within_directory(directory, target):
+                """Return True if ``target`` is ``directory`` or lies beneath it."""
+                abs_directory = os.path.abspath(directory)
+                abs_target = os.path.abspath(target)
+                try:
+                    # Compare whole path components: the character-based
+                    # os.path.commonprefix accepts siblings like "<dir>_evil".
+                    common = os.path.commonpath([abs_directory, abs_target])
+                except ValueError:
+                    # The paths are on different drives (Windows).
+                    return False
+                return common == abs_directory
+
+            def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
+                """Extract ``tar`` into ``path``; raise if a member name escapes it."""
+                for member in tar.getmembers():
+                    member_path = os.path.join(path, member.name)
+                    if not is_within_directory(path, member_path):
+                        raise Exception("Attempted Path Traversal in Tar File")
+
+                tar.extractall(path, members, numeric_owner=numeric_owner)
+
+            safe_extract(archive, tempdir)
         # Postpone cleanup until exit in case the unarchived contents are needed outside
         # this function.
         atexit.register(_cleanup_archive_dir, tempdir)
diff --git a/allennlp/allennlp/tests/modules/token_embedders/elmo_token_embedder_multilang_test.py b/allennlp/allennlp/tests/modules/token_embedders/elmo_token_embedder_multilang_test.py
index 6b78ccb..eabc2cc 100644
--- a/allennlp/allennlp/tests/modules/token_embedders/elmo_token_embedder_multilang_test.py
+++ b/allennlp/allennlp/tests/modules/token_embedders/elmo_token_embedder_multilang_test.py
@@ -37,7 +37,29 @@ def test_file_archiving(self):
         archive_file = os.path.join(serialization_dir, "model.tar.gz")
         unarchive_dir = os.path.join(self.TEST_DIR, "unarchive")
         with tarfile.open(archive_file, "r:gz") as archive:
-            archive.extractall(unarchive_dir)
+            def is_within_directory(directory, target):
+                """Return True if ``target`` is ``directory`` or lies beneath it."""
+                abs_directory = os.path.abspath(directory)
+                abs_target = os.path.abspath(target)
+                try:
+                    # Compare whole path components: the character-based
+                    # os.path.commonprefix accepts siblings like "<dir>_evil".
+                    common = os.path.commonpath([abs_directory, abs_target])
+                except ValueError:
+                    # The paths are on different drives (Windows).
+                    return False
+                return common == abs_directory
+
+            def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
+                """Extract ``tar`` into ``path``; raise if a member name escapes it."""
+                for member in tar.getmembers():
+                    member_path = os.path.join(path, member.name)
+                    if not is_within_directory(path, member_path):
+                        raise Exception("Attempted Path Traversal in Tar File")
+
+                tar.extractall(path, members, numeric_owner=numeric_owner)
+
+            safe_extract(archive, unarchive_dir)
 
         # It should contain `files_to_archive.json`
         fta_file = os.path.join(unarchive_dir, "files_to_archive.json")
diff --git a/allennlp/allennlp/tests/modules/token_embedders/elmo_token_embedder_test.py b/allennlp/allennlp/tests/modules/token_embedders/elmo_token_embedder_test.py
index 96d6500..2349901 100644
--- a/allennlp/allennlp/tests/modules/token_embedders/elmo_token_embedder_test.py
+++ b/allennlp/allennlp/tests/modules/token_embedders/elmo_token_embedder_test.py
@@ -51,7 +51,29 @@ def test_file_archiving(self):
         archive_file = os.path.join(serialization_dir, "model.tar.gz")
         unarchive_dir = os.path.join(self.TEST_DIR, "unarchive")
         with tarfile.open(archive_file, "r:gz") as archive:
-            archive.extractall(unarchive_dir)
+            def is_within_directory(directory, target):
+                """Return True if ``target`` is ``directory`` or lies beneath it."""
+                abs_directory = os.path.abspath(directory)
+                abs_target = os.path.abspath(target)
+                try:
+                    # Compare whole path components: the character-based
+                    # os.path.commonprefix accepts siblings like "<dir>_evil".
+                    common = os.path.commonpath([abs_directory, abs_target])
+                except ValueError:
+                    # The paths are on different drives (Windows).
+                    return False
+                return common == abs_directory
+
+            def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
+                """Extract ``tar`` into ``path``; raise if a member name escapes it."""
+                for member in tar.getmembers():
+                    member_path = os.path.join(path, member.name)
+                    if not is_within_directory(path, member_path):
+                        raise Exception("Attempted Path Traversal in Tar File")
+
+                tar.extractall(path, members, numeric_owner=numeric_owner)
+
+            safe_extract(archive, unarchive_dir)
 
         # It should contain `files_to_archive.json`
         fta_file = os.path.join(unarchive_dir, "files_to_archive.json")
diff --git a/allennlp/allennlp/tools/archive_surgery.py b/allennlp/allennlp/tools/archive_surgery.py
index fc1014d..8070d30 100644
--- a/allennlp/allennlp/tools/archive_surgery.py
+++ b/allennlp/allennlp/tools/archive_surgery.py
@@ -68,7 +68,29 @@ def main():
     # Extract archive to temp dir
     tempdir = tempfile.mkdtemp()
     with tarfile.open(archive_file, "r:gz") as archive:
-        archive.extractall(tempdir)
+        def is_within_directory(directory, target):
+            """Return True if ``target`` is ``directory`` or lies beneath it."""
+            abs_directory = os.path.abspath(directory)
+            abs_target = os.path.abspath(target)
+            try:
+                # Compare whole path components: the character-based
+                # os.path.commonprefix accepts siblings like "<dir>_evil".
+                common = os.path.commonpath([abs_directory, abs_target])
+            except ValueError:
+                # The paths are on different drives (Windows).
+                return False
+            return common == abs_directory
+
+        def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
+            """Extract ``tar`` into ``path``; raise if a member name escapes it."""
+            for member in tar.getmembers():
+                member_path = os.path.join(path, member.name)
+                if not is_within_directory(path, member_path):
+                    raise Exception("Attempted Path Traversal in Tar File")
+
+            tar.extractall(path, members, numeric_owner=numeric_owner)
+
+        safe_extract(archive, tempdir)
     atexit.register(lambda: shutil.rmtree(tempdir))
 
     config_path = os.path.join(tempdir, CONFIG_NAME)