Skip to content

Commit

Permalink
Simplify model download (#414)
Browse files Browse the repository at this point in the history
* Simplify model download

* Update model cache
  • Loading branch information
adamltyson committed May 9, 2024
1 parent eeffd78 commit 5f4882e
Show file tree
Hide file tree
Showing 11 changed files with 118 additions and 166 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/test_and_deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ jobs:
uses: actions/cache@v3
with:
path: "~/.cellfinder"
key: models-${{ hashFiles('~/.cellfinder/**') }}
key: models-${{ hashFiles('~/.brainglobe/**') }}
# Setup pyqt libraries
- name: Setup qtpy libraries
uses: tlambert03/setup-qt-libs@v1
Expand All @@ -83,7 +83,7 @@ jobs:
uses: actions/cache@v3
with:
path: "~/.cellfinder"
key: models-${{ hashFiles('~/.cellfinder/**') }}
key: models-${{ hashFiles('~/.brainglobe/**') }}
# Setup pyqt libraries
- name: Setup qtpy libraries
uses: tlambert03/setup-qt-libs@v1
Expand All @@ -108,7 +108,7 @@ jobs:
uses: actions/cache@v3
with:
path: "~/.cellfinder"
key: models-${{ hashFiles('~/.cellfinder/**') }}
key: models-${{ hashFiles('~/.brainglobe/**') }}

- name: Checkout brainglobe-workflows
uses: actions/checkout@v3
Expand Down
3 changes: 3 additions & 0 deletions cellfinder/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from importlib.metadata import PackageNotFoundError, version
from pathlib import Path

try:
__version__ = version("cellfinder")
Expand All @@ -22,3 +23,5 @@

__author__ = "Adam Tyson, Christian Niedworok, Charly Rousseau"
__license__ = "BSD-3-Clause"

DEFAULT_CELLFINDER_DIRECTORY = Path.home() / ".brainglobe" / "cellfinder"
71 changes: 39 additions & 32 deletions cellfinder/core/download/cli.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,42 @@
import tempfile
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from pathlib import Path

from cellfinder.core.download import models
from cellfinder.core.download.download import amend_user_configuration
from cellfinder.core.download.download import (
DEFAULT_DOWNLOAD_DIRECTORY,
amend_user_configuration,
download_models,
)

home = Path.home()
DEFAULT_DOWNLOAD_DIRECTORY = home / ".cellfinder"
temp_dir = tempfile.TemporaryDirectory()
temp_dir_path = Path(temp_dir.name)

def download_parser(parser: ArgumentParser) -> ArgumentParser:
"""
Configure the argument parser for downloading files.
Parameters
----------
parser : ArgumentParser
The argument parser to configure.
Returns
-------
ArgumentParser
The configured argument parser.
"""

def download_directory_parser(parser):
parser.add_argument(
"--install-path",
dest="install_path",
type=Path,
default=DEFAULT_DOWNLOAD_DIRECTORY,
help="The path to install files to.",
)
parser.add_argument(
"--download-path",
dest="download_path",
type=Path,
default=temp_dir_path,
help="The path to download files into.",
)
parser.add_argument(
"--no-amend-config",
dest="no_amend_config",
action="store_true",
help="Don't amend the config file",
)
return parser


def model_parser(parser):
parser.add_argument(
"--no-models",
dest="no_models",
action="store_true",
help="Don't download the model",
)
parser.add_argument(
"--model",
dest="model",
Expand All @@ -52,17 +47,29 @@ def model_parser(parser):
return parser


def download_parser():
def get_parser() -> ArgumentParser:
"""
Create an argument parser for downloading files.
Returns
-------
ArgumentParser
The configured argument parser.
"""
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
parser = model_parser(parser)
parser = download_directory_parser(parser)
parser = download_parser(parser)
return parser


def main():
args = download_parser().parse_args()
if not args.no_models:
model_path = models.main(args.model, args.install_path)
def main() -> None:
"""
Run the main download function, and optionally amend the user
configuration.
"""
args = get_parser().parse_args()
model_path = download_models(args.model, args.install_path)

if not args.no_amend_config:
amend_user_configuration(new_model_path=model_path)
Expand Down
100 changes: 44 additions & 56 deletions cellfinder/core/download/download.py
Original file line number Diff line number Diff line change
@@ -1,79 +1,67 @@
import os
import shutil
import tarfile
import urllib.request
from pathlib import Path
from typing import Literal

import pooch
from brainglobe_utils.general.config import get_config_obj
from brainglobe_utils.general.system import disk_free_gb

from cellfinder import DEFAULT_CELLFINDER_DIRECTORY
from cellfinder.core.tools.source_files import (
default_configuration_path,
user_specific_configuration_path,
)

DEFAULT_DOWNLOAD_DIRECTORY = DEFAULT_CELLFINDER_DIRECTORY / "models"

class DownloadError(Exception):
pass

MODEL_URL = "https://gin.g-node.org/cellfinder/models/raw/master"

def download_file(destination_path, file_url, filename):
direct_download = True
file_url = file_url.format(int(direct_download))
print(f"Downloading file: {filename}")
with urllib.request.urlopen(file_url) as response:
with open(destination_path, "wb") as outfile:
shutil.copyfileobj(response, outfile)
model_filenames = {
"resnet50_tv": "resnet50_tv.h5",
"resnet50_all": "resnet50_weights.h5",
}

model_hashes = {
"resnet50_tv": "63d36af456640590ba6c896dc519f9f29861015084f4c40777a54c18c1fc4edd", # noqa: E501
"resnet50_all": None,
}

def extract_file(tar_file_path, destination_path):
tar = tarfile.open(tar_file_path)
tar.extractall(path=destination_path)
tar.close()

model_type = Literal["resnet50_tv", "resnet50_all"]

# TODO: check that intermediate folders exist
def download(
download_path,
url,
file_name,
install_path=None,
download_requires=None,
extract_requires=None,
):
if not os.path.exists(os.path.dirname(download_path)):
raise DownloadError(
f"Could not find directory '{os.path.dirname(download_path)}' "
f"to download file: {file_name}"
)

if (download_requires is not None) and (
disk_free_gb(os.path.dirname(download_path)) < download_requires
):
raise DownloadError(
f"Insufficient disk space in {os.path.dirname(download_path)} to"
f"download file: {file_name}"
)
def download_models(
model_name: model_type, download_path: os.PathLike
) -> Path:
"""
For a given model name and download path, download the model file
and return the path to the downloaded file.
Parameters
----------
model_name : model_type
The name of the model to be downloaded.
download_path : os.PathLike
The path where the model file will be downloaded.
if install_path is not None:
if not os.path.exists(install_path):
raise DownloadError(
f"Could not find directory '{install_path}' "
f"to extract file: {file_name}"
)
Returns
-------
Path
The path to the downloaded model file.
"""

if (extract_requires is not None) and (
disk_free_gb(install_path) < extract_requires
):
raise DownloadError(
f"Insufficient disk space in {install_path} to"
f"extract file: {file_name}"
)
download_path = Path(download_path)
filename = model_filenames[model_name]
model_path = pooch.retrieve(
url=f"{MODEL_URL}/{filename}",
known_hash=model_hashes[model_name],
path=download_path,
fname=filename,
progressbar=True,
)

download_file(download_path, url, file_name)
if install_path is not None:
extract_file(download_path, install_path)
os.remove(download_path)
return Path(model_path)


def amend_user_configuration(new_model_path=None) -> None:
Expand All @@ -83,7 +71,7 @@ def amend_user_configuration(new_model_path=None) -> None:
Parameters
----------
new_model_path : str, optional
new_model_path : Path, optional
The path to the new model configuration.
"""
print("(Over-)writing custom user configuration")
Expand Down
49 changes: 0 additions & 49 deletions cellfinder/core/download/models.py

This file was deleted.

2 changes: 1 addition & 1 deletion cellfinder/core/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from brainglobe_utils.general.logging import suppress_specific_logs

from cellfinder.core import logger
from cellfinder.core.download.models import model_type
from cellfinder.core.download.download import model_type
from cellfinder.core.train.train_yml import depth_type

tf_suppress_log_messages = [
Expand Down
21 changes: 11 additions & 10 deletions cellfinder/core/tools/prep.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,19 @@

import cellfinder.core.tools.tf as tf_tools
from cellfinder.core import logger
from cellfinder.core.download import models as model_download
from cellfinder.core.download.download import amend_user_configuration
from cellfinder.core.download.download import (
DEFAULT_DOWNLOAD_DIRECTORY,
amend_user_configuration,
download_models,
model_type,
)
from cellfinder.core.tools.source_files import user_specific_configuration_path

home = Path.home()
DEFAULT_INSTALL_PATH = home / ".cellfinder"


def prep_model_weights(
model_weights: Optional[os.PathLike],
install_path: Optional[os.PathLike],
model_name: model_download.model_type,
model_name: model_type,
n_free_cpus: int,
) -> Path:
n_processes = get_num_processes(min_free_cpu_cores=n_free_cpus)
Expand All @@ -42,9 +43,9 @@ def prep_tensorflow(max_threads: int) -> None:
def prep_models(
model_weights_path: Optional[os.PathLike],
install_path: Optional[os.PathLike],
model_name: model_download.model_type,
model_name: model_type,
) -> Path:
install_path = install_path or DEFAULT_INSTALL_PATH
install_path = install_path or DEFAULT_DOWNLOAD_DIRECTORY
# if no model or weights, set default weights
if model_weights_path is None:
logger.debug("No model supplied, so using the default")
Expand All @@ -53,13 +54,13 @@ def prep_models(

if not Path(config_file).exists():
logger.debug("Custom config does not exist, downloading models")
model_path = model_download.main(model_name, install_path)
model_path = download_models(model_name, install_path)
amend_user_configuration(new_model_path=model_path)

model_weights = get_model_weights(config_file)
if not model_weights.exists():
logger.debug("Model weights do not exist, downloading")
model_path = model_download.main(model_name, install_path)
model_path = download_models(model_name, install_path)
amend_user_configuration(new_model_path=model_path)
model_weights = get_model_weights(config_file)
else:
Expand Down
Loading

0 comments on commit 5f4882e

Please sign in to comment.