diff --git a/bw2io/__init__.py b/bw2io/__init__.py index 6b83fbf3..ca05841c 100644 --- a/bw2io/__init__.py +++ b/bw2io/__init__.py @@ -28,6 +28,7 @@ "exiobase_monetary", "get_csv_example_filepath", "get_xlsx_example_filepath", + "install_project", "lci_matrices_to_excel", "lci_matrices_to_matlab", "load_json_data_file", @@ -91,6 +92,7 @@ from .units import normalize_units from .unlinked_data import unlinked_data, UnlinkedData from .utils import activity_hash, es2_activity_hash, load_json_data_file +from .remote import install_project from bw2data import config, databases diff --git a/bw2io/backup.py b/bw2io/backup.py index 69827436..7d55ac36 100644 --- a/bw2io/backup.py +++ b/bw2io/backup.py @@ -1,8 +1,12 @@ +from pathlib import Path +from typing import Optional import codecs import datetime import json import os +import shutil import tarfile +import tempfile from bw2data import projects from bw_processing import safe_filename @@ -30,7 +34,7 @@ def backup_data_directory(): tar.add(projects.dir, arcname=os.path.basename(projects.dir)) -def backup_project_directory(project): +def backup_project_directory(project: str): """ Backup project data directory to a ``.tar.gz`` (compressed tar archive) in the user's home directory. @@ -70,9 +74,10 @@ def backup_project_directory(project): with tarfile.open(fp, "w:gz") as tar: tar.add(dir_path, arcname=safe_filename(project)) - return project_name + return project + -def restore_project_directory(fp): +def restore_project_directory(fp: str, project_name: Optional[str] = None, overwrite_existing: Optional[bool] = False): """ Restore a backed up project data directory from a ``.tar.gz`` (compressed tar archive) in the user's home directory. @@ -80,6 +85,9 @@ def restore_project_directory(fp): ---------- fp : str File path of the project to restore. + project_name : str, optional + Name of new project to create + overwrite_existing : bool, optional Returns ------- @@ -107,31 +115,42 @@ def get_project_name(fp): assert os.path.isfile(fp), "Can't find file at path: {}".format(fp) print("Restoring project backup archive - this could take a few minutes...") - project_name = get_project_name(fp) + project_name = get_project_name(fp) if project_name is None else project_name + + if project_name in projects and not overwrite_existing: + raise ValueError("Project {} already exists".format(project_name)) + + with tempfile.TemporaryDirectory() as td: + with tarfile.open(fp, "r:gz") as tar: + def is_within_directory(directory, target): - with tarfile.open(fp, "r:gz") as tar: - def is_within_directory(directory, target): + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) - abs_directory = os.path.abspath(directory) - abs_target = os.path.abspath(target) + prefix = os.path.commonprefix([abs_directory, abs_target]) - prefix = os.path.commonprefix([abs_directory, abs_target]) + return prefix == abs_directory - return prefix == abs_directory + def safe_extract(tar, path=".", members=None, *, numeric_owner=False): - def safe_extract(tar, path=".", members=None, *, numeric_owner=False): + for member in tar.getmembers(): + member_path = os.path.join(path, member.name) + if not is_within_directory(path, member_path): + raise Exception("Attempted Path Traversal in Tar File") - for member in tar.getmembers(): - member_path = os.path.join(path, member.name) - if not is_within_directory(path, member_path): - raise Exception("Attempted Path Traversal in Tar File") + tar.extractall(path, members, numeric_owner=numeric_owner) - tar.extractall(path, members, numeric_owner=numeric_owner) + safe_extract(tar, td) + # Find single extracted directory; don't know it ahead of time + extracted_dir = [(Path(td) / dirname) for dirname in Path(td).iterdir() if (Path(td) / dirname).is_dir()] + if not len(extracted_dir) == 1: + raise ValueError("Can't find single directory extracted from project archive") + extracted_path = extracted_dir[0] - safe_extract(tar, projects._base_data_dir) + _current = projects.current + projects.set_current(project_name, update=False) + shutil.copytree(extracted_path, projects.dir, dirs_exist_ok=True) + projects.set_current(_current) - _current = projects.current - projects.set_current(project_name, update=False) - projects.set_current(_current) return project_name diff --git a/bw2io/remote.py b/bw2io/remote.py new file mode 100644 index 00000000..5c965234 --- /dev/null +++ b/bw2io/remote.py @@ -0,0 +1,84 @@ +from pathlib import Path +from typing import Optional + +import bw2data as bd +import requests +from platformdirs import user_data_dir + +from .backup import restore_project_directory +from .download_utils import download_with_progressbar + +PROJECTS_BW2 = { + "ecoinvent-3.8-biosphere": "ecoinvent-3.8-biosphere.bw2.tar.gz", + "ecoinvent-3.9.1-biosphere": "ecoinvent-3.9.1-biosphere.bw2.tar.gz", +} + +PROJECTS_BW25 = { + "ecoinvent-3.8-biosphere": "ecoinvent-3.8-biosphere.tar.gz", + "ecoinvent-3.9.1-biosphere": "ecoinvent-3.9.1-biosphere.tar.gz", + "USEEIO-1.1": "USEEIO-1.1.tar.gz", +} + +cache_dir = Path( + user_data_dir(appname="bw2io-project-cache", appauthor="brightway-team") +) +cache_dir.mkdir(exist_ok=True) + + +def get_projects(update_config: Optional[bool] = True) -> dict: + BW2 = bd.__version__ < (4,) + projects = PROJECTS_BW2 if BW2 else PROJECTS_BW25 + URL = "https://files.brightway.dev/" + FILENAME = "projects-config.bw2.json" if BW2 else "projects-config.json" + if update_config: + try: + projects = requests.get(URL + FILENAME).json() + except: + pass + return projects + + +def install_project( + project_key: str, + project_name: Optional[str] = None, + projects_config: Optional[dict] = get_projects(), + url: Optional[str] = "https://files.brightway.dev/", + overwrite_existing: Optional[bool] = False, +): + """ + Install an existing Brightway project archive. + + By default uses ``https://files.brightway.dev/`` as the file repository, but you can run your own. + + Parameters + ---------- + project_key: str + A string uniquely identifying a project, e.g. ``ecoinvent-3.8-biosphere``. + project_name: str, optional + The name of the new project to create. If not provided will be taken from the archive file. + projects_config: dict, optional + A dictionary that maps ``project_key`` values to filenames at the repository + url: str, optional + The URL, with trailing slash ``/``, where the file can be found. + overwrite_existing: bool, optional + Allow overwriting an existing project + + Returns + ------- + str + The name of the created project. + """ + try: + filename = projects_config[project_key] + except KeyError: + raise KeyError(f"Project key {project_key} not in `project_config`") + + fp = cache_dir / filename + if not fp.exists(): + download_with_progressbar( + url=url + filename, filename=filename, dirpath=cache_dir + ) + + return restore_project_directory( + fp=fp, project_name=project_name, overwrite_existing=overwrite_existing + ) diff --git a/requirements-test.txt b/requirements-test.txt index 55436f71..ab4be926 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -7,14 +7,15 @@ matrix_utils mrio_common_metadata numpy openpyxl +platformdirs pyprind pytest pytest-cov python-coveralls requests -tqdm scipy stats_arrays +tqdm unidecode voluptuous xlrd diff --git a/setup.py b/setup.py index fdb9e9cc..52ea50b2 100644 --- a/setup.py +++ b/setup.py @@ -9,6 +9,7 @@ "mrio_common_metadata", "numpy", "openpyxl", + "platformdirs", "pyprind", "requests", "scipy",