From cdaed7a4973c31c491db2ea6938a4df1e42141c1 Mon Sep 17 00:00:00 2001 From: MeowZheng Date: Wed, 8 Mar 2023 23:32:56 +0800 Subject: [PATCH 1/2] [WIP]Upload models in mmseg to HF hub --- mmseg/utils/__init__.py | 4 +- mmseg/utils/hub.py | 242 ++++++++++++++++++++++++++++++++++ tools/misc/huggingface_hub.py | 21 +++ 3 files changed, 266 insertions(+), 1 deletion(-) create mode 100644 mmseg/utils/hub.py create mode 100644 tools/misc/huggingface_hub.py diff --git a/mmseg/utils/__init__.py b/mmseg/utils/__init__.py index cb1436c198..1fe9bb8b64 100644 --- a/mmseg/utils/__init__.py +++ b/mmseg/utils/__init__.py @@ -10,6 +10,7 @@ vaihingen_palette, voc_classes, voc_palette) # yapf: enable from .collect_env import collect_env +from .hub import has_hf_hub, push_to_hf_hub from .io import datafrombytes from .misc import add_prefix, stack_batch from .set_env import register_all_modules @@ -27,5 +28,6 @@ 'cityscapes_palette', 'ade_palette', 'voc_palette', 'cocostuff_palette', 'loveda_palette', 'potsdam_palette', 'vaihingen_palette', 'isaid_palette', 'stare_palette', 'dataset_aliases', 'get_classes', 'get_palette', - 'datafrombytes', 'synapse_palette', 'synapse_classes' + 'datafrombytes', 'synapse_palette', 'synapse_classes', 'push_to_hf_hub', + 'has_hf_hub' ] diff --git a/mmseg/utils/hub.py b/mmseg/utils/hub.py new file mode 100644 index 0000000000..cc768fa517 --- /dev/null +++ b/mmseg/utils/hub.py @@ -0,0 +1,242 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import importlib +import json +import os.path as osp +from collections import OrderedDict +from functools import partial +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import Optional, Tuple, Union + +import torch +from mmengine import Config, ConfigDict +from mmengine.fileio import load +from mmengine.utils.dl_utils import load_url +from torch.nn import Module + +ConfigType = Union[Config, ConfigDict] + +try: + from huggingface_hub import (create_repo, get_hf_file_metadata, + hf_hub_download, hf_hub_url, + repo_type_and_id_from_hf_id, upload_folder) + from huggingface_hub.utils import EntryNotFoundError + hf_hub_download = partial( + hf_hub_download, library_name="openmmlab", library_version='2.0') + _has_hf_hub = True +except ImportError: + hf_hub_download = None + _has_hf_hub = False +HF_WEIGHTS_NAME = "pytorch_model.bin" # default pytorch pkl + + +def has_hf_hub(necessary=False): + if not _has_hf_hub and necessary: + # if no HF Hub module installed, and it is necessary to continue, raise error + raise RuntimeError( + 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.' + ) + return _has_hf_hub + + +def generate_readme(model_card: dict, model_name: str): + readme_text = "---\n" + readme_text += "tags:\n- semantic segmentation\n- openmmlab\n" + readme_text += "library_tag: openmmlab/mmsegmentation\n" + readme_text += f"license: {model_card.get('license', 'apache-2.0')}\n" + if 'details' in model_card and 'Dataset' in model_card['details']: + readme_text += 'datasets:\n' + readme_text += f"- {model_card['details']['Dataset'].lower()}\n" + if 'Pretrain Dataset' in model_card['details']: + readme_text += f"- {model_card['details']['Pretrain Dataset'].lower()}\n" + readme_text += "---\n" + readme_text += f"# Model card for {model_name}\n" + if 'description' in model_card: + readme_text += f"\n{model_card['description']}\n" + if 'details' in model_card: + readme_text += "\n## Model Details\n" + for k, v in model_card['details'].items(): + if isinstance(v, (list, tuple)): + readme_text += f"- **{k}:**\n" + for vi in v: + readme_text += f" - {vi}\n" + elif isinstance(v, dict): + readme_text += f"- **{k}:**\n" + for ki, vi in v.items(): + readme_text += f" - {ki}: {vi}\n" + else: + readme_text += f"- **{k}:** {v}\n" + if 'usage' in model_card: + readme_text += f"\n## Model Usage\n" + readme_text += model_card['usage'] + readme_text += '\n' + + if 'comparison' in model_card: + readme_text += f"\n## Model Comparison\n" + readme_text += model_card['comparison'] + readme_text += '\n' + + if 'citation' in model_card: + readme_text += f"\n## Citation\n" + if not isinstance(model_card['citation'], (list, tuple)): + citations = [model_card['citation']] + else: + citations = model_card['citation'] + for c in citations: + readme_text += f"```bibtex\n{c}\n```\n" + return readme_text + + +def push_to_hf_hub(model: Union[str, ConfigType, Module], + repo_id: str, + commit_message: str = 'Add model', + token: Optional[str] = None, + revision: Optional[str] = None, + private: bool = False, + create_pr: bool = False, + model_config: Optional[dict] = None, + model_card: Optional[dict] = None): + """_summary_ + + Args: + model (_type_): _description_ + repo_id (str): _description_ + commit_message (str, optional): _description_. Defaults to 'Add model'. + token (Optional[str], optional): _description_. Defaults to None. + revision (Optional[str], optional): _description_. Defaults to None. + private (bool, optional): _description_. Defaults to False. + create_pr (bool, optional): _description_. Defaults to False. + model_config (Optional[dict], optional): _description_. Defaults to None. + model_card (Optional[dict], optional): _description_. Defaults to None. + + Returns: + _type_: _description_ + """ + # Create repo if it doesn't exist yet + repo_url = create_repo( + repo_id, token=token, private=private, exist_ok=True) + # Infer complete repo_id from repo_url + # Can be different from the input `repo_id` if repo_owner was implicit + _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url) + repo_id = f"{repo_owner}/{repo_name}" + + # Check if README file already exist in repo + try: + get_hf_file_metadata( + hf_hub_url( + repo_id=repo_id, filename="README.md", revision=revision)) + has_readme = True + except EntryNotFoundError: + has_readme = False + + # Dump model and push to Hub + with TemporaryDirectory() as tmpdir: + # Save model weights and config. + save_for_hf(model, tmpdir) + + # Add readme if it does not exist + if not has_readme: + model_card = model_card or {} + model_name = repo_id.split('/')[-1] + readme_path = Path(tmpdir) / "README.md" + readme_text = generate_readme(model_card, model_name) + readme_path.write_text(readme_text) + + # Upload model and return + return upload_folder( + repo_id=repo_id, + folder_path=tmpdir, + revision=revision, + create_pr=create_pr, + commit_message=commit_message, + ) + + +def save_for_hf(model, save_directory): + model_path = Path(save_directory) / HF_WEIGHTS_NAME + config_path = Path(save_directory) / 'config.json' + if isinstance(model, str): + # input is model name + config, ckpt = _load_model_from_metafile(model) + load_url(ckpt, model_path) + + config.dump(config_path) + + elif isinstance(model, torch.nn.Module): + # input is a torch module + ckpt = model.state_dict() + torch.save(ckpt, model_path) + elif isinstance(model, OrderedDict): + torch.save(ckpt, model_path) + + +def save_config_for_hf(config, save_directory): + pass + + +def _load_model_from_metafile(model: str) -> Tuple[Config, str]: + """Load config and weights from metafile. + + Args: + model (str): model name defined in metafile. + + Returns: + Tuple[Config, str]: Loaded Config and weights path defined in + metafile. + """ + model = model.lower() + + repo_or_mim_dir = _get_repo_or_mim_dir() + for model_cfg in _get_models_from_metafile(repo_or_mim_dir): + model_name = model_cfg['Name'].lower() + model_aliases = model_cfg.get('Alias', []) + if isinstance(model_aliases, str): + model_aliases = [model_aliases.lower()] + else: + model_aliases = [alias.lower() for alias in model_aliases] + if (model_name == model or model in model_aliases): + cfg = Config.fromfile( + osp.join(repo_or_mim_dir, model_cfg['Config'])) + weights = model_cfg['Weights'] + weights = weights[0] if isinstance(weights, list) else weights + return cfg, weights + raise ValueError(f'Cannot find model: {model} in {self.scope}') + + +def _get_models_from_metafile(dir: str): + """Load model config defined in metafile from package path. + + Args: + dir (str): Path to the directory of Config. It requires the + directory ``Config``, file ``model-index.yml`` exists in the + ``dir``. + + Yields: + dict: Model config defined in metafile. + """ + meta_indexes = load(osp.join(dir, 'model-index.yml')) + for meta_path in meta_indexes['Import']: + # meta_path example: mmcls/.mim/configs/conformer/metafile.yml + meta_path = osp.join(dir, meta_path) + metainfo = load(meta_path) + yield from metainfo['Models'] + + +def _get_repo_or_mim_dir(): + + module = importlib.import_module('mmseg') + # Since none of OpenMMLab series packages are namespace packages + # (https://docs.python.org/3/glossary.html#term-namespace-package), + # The first element of module.__path__ means package installation path. + package_path = module.__path__[0] + + if osp.exists(osp.join(osp.dirname(package_path), 'configs')): + repo_dir = osp.dirname(package_path) + return repo_dir + else: + mim_dir = osp.join(package_path, '.mim') + if not osp.exists(osp.join(mim_dir, 'Configs')): + raise FileNotFoundError( + f'Cannot find Configs directory in {package_path}!, ' + f'please check the completeness of the mmseg.') + return mim_dir \ No newline at end of file diff --git a/tools/misc/huggingface_hub.py b/tools/misc/huggingface_hub.py new file mode 100644 index 0000000000..89f93614ff --- /dev/null +++ b/tools/misc/huggingface_hub.py @@ -0,0 +1,21 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +from mmseg.utils import push_to_hf_hub, has_hf_hub + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Upload model to Hugging Face Hub') + + parser.add_argument('model', help='model name in metafiles') + parser.add_argument( + '--repo-id', default=None, type=str, help='repo-id for this model') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + assert has_hf_hub(True) + repo_id = args.repo_id if args.repo_id is not None else args.model + push_to_hf_hub(args.model, repo_id=repo_id) From 68685e43ca81ef29e13c8a5053b2e4fbb53f7165 Mon Sep 17 00:00:00 2001 From: MeowZheng Date: Thu, 9 Mar 2023 21:00:18 +0800 Subject: [PATCH 2/2] add create_from_hf_hub --- mmseg/utils/__init__.py | 4 +- mmseg/utils/hub.py | 242 ------------------------------------ mmseg/utils/hugging_face.py | 239 +++++++++++++++++++++++++++++++++++ 3 files changed, 241 insertions(+), 244 deletions(-) delete mode 100644 mmseg/utils/hub.py create mode 100644 mmseg/utils/hugging_face.py diff --git a/mmseg/utils/__init__.py b/mmseg/utils/__init__.py index 1fe9bb8b64..daae08d9bc 100644 --- a/mmseg/utils/__init__.py +++ b/mmseg/utils/__init__.py @@ -10,7 +10,7 @@ vaihingen_palette, voc_classes, voc_palette) # yapf: enable from .collect_env import collect_env -from .hub import has_hf_hub, push_to_hf_hub +from .hugging_face import create_from_hf_hub, has_hf_hub, push_to_hf_hub from .io import datafrombytes from .misc import add_prefix, stack_batch from .set_env import register_all_modules @@ -29,5 +29,5 @@ 'loveda_palette', 'potsdam_palette', 'vaihingen_palette', 'isaid_palette', 'stare_palette', 'dataset_aliases', 'get_classes', 'get_palette', 'datafrombytes', 'synapse_palette', 'synapse_classes', 'push_to_hf_hub', - 'has_hf_hub' + 'has_hf_hub', 'create_from_hf_hub' ] diff --git a/mmseg/utils/hub.py b/mmseg/utils/hub.py deleted file mode 100644 index cc768fa517..0000000000 --- a/mmseg/utils/hub.py +++ /dev/null @@ -1,242 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import importlib -import json -import os.path as osp -from collections import OrderedDict -from functools import partial -from pathlib import Path -from tempfile import TemporaryDirectory -from typing import Optional, Tuple, Union - -import torch -from mmengine import Config, ConfigDict -from mmengine.fileio import load -from mmengine.utils.dl_utils import load_url -from torch.nn import Module - -ConfigType = Union[Config, ConfigDict] - -try: - from huggingface_hub import (create_repo, get_hf_file_metadata, - hf_hub_download, hf_hub_url, - repo_type_and_id_from_hf_id, upload_folder) - from huggingface_hub.utils import EntryNotFoundError - hf_hub_download = partial( - hf_hub_download, library_name="openmmlab", library_version='2.0') - _has_hf_hub = True -except ImportError: - hf_hub_download = None - _has_hf_hub = False -HF_WEIGHTS_NAME = "pytorch_model.bin" # default pytorch pkl - - -def has_hf_hub(necessary=False): - if not _has_hf_hub and necessary: - # if no HF Hub module installed, and it is necessary to continue, raise error - raise RuntimeError( - 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.' - ) - return _has_hf_hub - - -def generate_readme(model_card: dict, model_name: str): - readme_text = "---\n" - readme_text += "tags:\n- semantic segmentation\n- openmmlab\n" - readme_text += "library_tag: openmmlab/mmsegmentation\n" - readme_text += f"license: {model_card.get('license', 'apache-2.0')}\n" - if 'details' in model_card and 'Dataset' in model_card['details']: - readme_text += 'datasets:\n' - readme_text += f"- {model_card['details']['Dataset'].lower()}\n" - if 'Pretrain Dataset' in model_card['details']: - readme_text += f"- {model_card['details']['Pretrain Dataset'].lower()}\n" - readme_text += "---\n" - readme_text += f"# Model card for {model_name}\n" - if 'description' in model_card: - readme_text += f"\n{model_card['description']}\n" - if 'details' in model_card: - readme_text += "\n## Model Details\n" - for k, v in model_card['details'].items(): - if isinstance(v, (list, tuple)): - readme_text += f"- **{k}:**\n" - for vi in v: - readme_text += f" - {vi}\n" - elif isinstance(v, dict): - readme_text += f"- **{k}:**\n" - for ki, vi in v.items(): - readme_text += f" - {ki}: {vi}\n" - else: - readme_text += f"- **{k}:** {v}\n" - if 'usage' in model_card: - readme_text += f"\n## Model Usage\n" - readme_text += model_card['usage'] - readme_text += '\n' - - if 'comparison' in model_card: - readme_text += f"\n## Model Comparison\n" - readme_text += model_card['comparison'] - readme_text += '\n' - - if 'citation' in model_card: - readme_text += f"\n## Citation\n" - if not isinstance(model_card['citation'], (list, tuple)): - citations = [model_card['citation']] - else: - citations = model_card['citation'] - for c in citations: - readme_text += f"```bibtex\n{c}\n```\n" - return readme_text - - -def push_to_hf_hub(model: Union[str, ConfigType, Module], - repo_id: str, - commit_message: str = 'Add model', - token: Optional[str] = None, - revision: Optional[str] = None, - private: bool = False, - create_pr: bool = False, - model_config: Optional[dict] = None, - model_card: Optional[dict] = None): - """_summary_ - - Args: - model (_type_): _description_ - repo_id (str): _description_ - commit_message (str, optional): _description_. Defaults to 'Add model'. - token (Optional[str], optional): _description_. Defaults to None. - revision (Optional[str], optional): _description_. Defaults to None. - private (bool, optional): _description_. Defaults to False. - create_pr (bool, optional): _description_. Defaults to False. - model_config (Optional[dict], optional): _description_. Defaults to None. - model_card (Optional[dict], optional): _description_. Defaults to None. - - Returns: - _type_: _description_ - """ - # Create repo if it doesn't exist yet - repo_url = create_repo( - repo_id, token=token, private=private, exist_ok=True) - # Infer complete repo_id from repo_url - # Can be different from the input `repo_id` if repo_owner was implicit - _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url) - repo_id = f"{repo_owner}/{repo_name}" - - # Check if README file already exist in repo - try: - get_hf_file_metadata( - hf_hub_url( - repo_id=repo_id, filename="README.md", revision=revision)) - has_readme = True - except EntryNotFoundError: - has_readme = False - - # Dump model and push to Hub - with TemporaryDirectory() as tmpdir: - # Save model weights and config. - save_for_hf(model, tmpdir) - - # Add readme if it does not exist - if not has_readme: - model_card = model_card or {} - model_name = repo_id.split('/')[-1] - readme_path = Path(tmpdir) / "README.md" - readme_text = generate_readme(model_card, model_name) - readme_path.write_text(readme_text) - - # Upload model and return - return upload_folder( - repo_id=repo_id, - folder_path=tmpdir, - revision=revision, - create_pr=create_pr, - commit_message=commit_message, - ) - - -def save_for_hf(model, save_directory): - model_path = Path(save_directory) / HF_WEIGHTS_NAME - config_path = Path(save_directory) / 'config.json' - if isinstance(model, str): - # input is model name - config, ckpt = _load_model_from_metafile(model) - load_url(ckpt, model_path) - - config.dump(config_path) - - elif isinstance(model, torch.nn.Module): - # input is a torch module - ckpt = model.state_dict() - torch.save(ckpt, model_path) - elif isinstance(model, OrderedDict): - torch.save(ckpt, model_path) - - -def save_config_for_hf(config, save_directory): - pass - - -def _load_model_from_metafile(model: str) -> Tuple[Config, str]: - """Load config and weights from metafile. - - Args: - model (str): model name defined in metafile. - - Returns: - Tuple[Config, str]: Loaded Config and weights path defined in - metafile. - """ - model = model.lower() - - repo_or_mim_dir = _get_repo_or_mim_dir() - for model_cfg in _get_models_from_metafile(repo_or_mim_dir): - model_name = model_cfg['Name'].lower() - model_aliases = model_cfg.get('Alias', []) - if isinstance(model_aliases, str): - model_aliases = [model_aliases.lower()] - else: - model_aliases = [alias.lower() for alias in model_aliases] - if (model_name == model or model in model_aliases): - cfg = Config.fromfile( - osp.join(repo_or_mim_dir, model_cfg['Config'])) - weights = model_cfg['Weights'] - weights = weights[0] if isinstance(weights, list) else weights - return cfg, weights - raise ValueError(f'Cannot find model: {model} in {self.scope}') - - -def _get_models_from_metafile(dir: str): - """Load model config defined in metafile from package path. - - Args: - dir (str): Path to the directory of Config. It requires the - directory ``Config``, file ``model-index.yml`` exists in the - ``dir``. - - Yields: - dict: Model config defined in metafile. - """ - meta_indexes = load(osp.join(dir, 'model-index.yml')) - for meta_path in meta_indexes['Import']: - # meta_path example: mmcls/.mim/configs/conformer/metafile.yml - meta_path = osp.join(dir, meta_path) - metainfo = load(meta_path) - yield from metainfo['Models'] - - -def _get_repo_or_mim_dir(): - - module = importlib.import_module('mmseg') - # Since none of OpenMMLab series packages are namespace packages - # (https://docs.python.org/3/glossary.html#term-namespace-package), - # The first element of module.__path__ means package installation path. - package_path = module.__path__[0] - - if osp.exists(osp.join(osp.dirname(package_path), 'configs')): - repo_dir = osp.dirname(package_path) - return repo_dir - else: - mim_dir = osp.join(package_path, '.mim') - if not osp.exists(osp.join(mim_dir, 'Configs')): - raise FileNotFoundError( - f'Cannot find Configs directory in {package_path}!, ' - f'please check the completeness of the mmseg.') - return mim_dir \ No newline at end of file diff --git a/mmseg/utils/hugging_face.py b/mmseg/utils/hugging_face.py new file mode 100644 index 0000000000..da6a3e0531 --- /dev/null +++ b/mmseg/utils/hugging_face.py @@ -0,0 +1,239 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import importlib +import os +import os.path as osp +from functools import partial +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import Optional, Tuple + +from mmengine import Config +from mmengine.fileio import load +from mmengine.utils.dl_utils import load_url + +try: + from huggingface_hub import (create_repo, get_hf_file_metadata, + hf_hub_download, hf_hub_url, + repo_type_and_id_from_hf_id, upload_folder) + from huggingface_hub.utils import EntryNotFoundError + hf_hub_download = partial( + hf_hub_download, library_name=None, library_version=None) + _has_hf_hub = True +except ImportError: + hf_hub_download = None + _has_hf_hub = False +HF_WEIGHTS_NAME = 'pytorch_model.bin' # default pytorch pkl + + +def has_hf_hub(necessary=False): + if not _has_hf_hub and necessary: + # if no HF Hub module installed, and it is necessary to continue, + # raise error + raise RuntimeError( + 'Hugging Face hub model specified but package not installed. ' + 'Run `pip install huggingface_hub`.') + return _has_hf_hub + + +def create_from_hf_hub(repo_id): + + from mmseg.apis import init_model + + # Download config from HF Hub + config_file = hf_hub_download(repo_id=repo_id, filename='config.json') + # Download ckpt from HF Hub + checkpoint_file = hf_hub_download( + repo_id=repo_id, filename=HF_WEIGHTS_NAME) + return init_model(config=config_file, checkpoint=checkpoint_file) + + +def generate_readme(results_info: dict, model_name: str) -> str: + """Generate README (model card for Hugging face Hub) + + Args: + results_info (dict): The results information of model. + model_name (str): The model name for the model card. + + Returns: + str: The text readme for the model. + """ + readme_text = '---\n' + readme_text += 'language:\n- en\n' + readme_text += 'license: apache-2.0\n' + readme_text += 'library_name: mmsegmentation\n' + readme_text += 'tags:\n- semantic segmentation\n- openmmlab\n' + if results_info.get('Dataset'): + readme_text += f'datasets:\n- {results_info["Dataset"]}' + if results_info.get('Metrics'): + readme_text += f'metrics:\n- {results_info["Metrics"]["mIoU"]}' + + readme_text += '---\n' + + readme_text += f'# Model card for {model_name}\n' + # TODO: Add more description + return readme_text + + +def push_to_hf_hub(model: str, + repo_id: str, + commit_message: Optional[str] = 'Add model', + token: Optional[str] = None, + revision: Optional[str] = None, + private: bool = False, + create_pr: bool = False) -> str: + """Push model from MMSegmentation to Hugging face Hub. + + Args: + model (str): The model which will be uploaded. It can be the model name + or alias in metafile. + repo_id (str): The repository to which the file will be uploaded, for + example: `"username/custom_transformers"`. + commit_message (str, optional): The summary / title / first line of the + generated commit. Defaults to: 'Add model'. + token (str, optional): Authentication token, obtained by `HfApi.login` + method. Will default to the stored token. Defaults to None. + revision (str, optional): The git revision to commit from. Defaults to + None, i.e. the head of the `"main"` branch. + private (bool, optional): Whether the model repo should be private. + Defaults to False. + create_pr (bool, optional): Whether or not to create a Pull Request + with that commit. Defaults to `False`. If `revision` is not set, + PR is opened against the `"main"` branch. If `revision` is set and + is a branch, PR is opened against this branch. If `revision` is set + and is not a branch name (example: a commit oid), an + `RevisionNotFoundError` is returned by the server. + + Returns: + str: A URL to visualize the uploaded folder on the hub + """ + # Create repo if it doesn't exist yet + repo_url = create_repo( + repo_id, token=token, private=private, exist_ok=True) + # Infer complete repo_id from repo_url + # Can be different from the input `repo_id` if repo_owner was implicit + _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url) + repo_id = f'{repo_owner}/{repo_name}' + + # Check if README file already exist in repo + try: + get_hf_file_metadata( + hf_hub_url( + repo_id=repo_id, filename='README.md', revision=revision)) + has_readme = True + except EntryNotFoundError: + has_readme = False + + # Dump model and push to Hub + with TemporaryDirectory() as tmpdir: + # Save model weights and config. + results_info = save_for_hf(model, tmpdir) + + # Add readme if it does not exist + if not has_readme: + model_name = repo_id.split('/')[-1] + readme_path = Path(tmpdir) / 'README.md' + readme_text = generate_readme(results_info, model_name) + readme_path.write_text(readme_text) + + # Upload model and return + return upload_folder( + repo_id=repo_id, + folder_path=tmpdir, + revision=revision, + create_pr=create_pr, + commit_message=commit_message, + ) + + +def save_for_hf(model, save_directory) -> dict: + """Save the files for Hugging face Hub. + + Args: + model (str): The model which will be uploaded. It can be the model name + or alias in metafile. + save_directory (str): The directory to save the checkpont file and + config file which will be uploaded to hub. + + Returns: + ditc: The results information of model. + """ + config_path = Path(save_directory) / 'config.json' + model_path = Path(save_directory) / HF_WEIGHTS_NAME + # input is model name + config, ckpt, results_info = _load_model_from_metafile(model) + config.dump(config_path) + ckpt_org_name = osp.basename(ckpt) + load_url(ckpt, Path(save_directory)) + os.rename(osp.join(Path(save_directory), ckpt_org_name), model_path) + return results_info + + +def _load_model_from_metafile(model: str) -> Tuple[Config, str]: + """Load config and weights from metafile. + + Args: + model (str): model name defined in metafile. + + Returns: + Tuple[Config, str]: Loaded Config and weights path defined in + metafile. + """ + model = model.lower() + + repo_or_mim_dir = _get_repo_or_mim_dir() + for model_meta in _get_models_from_metafile(repo_or_mim_dir): + model_name = model_meta['Name'].lower() + model_aliases = model_meta.get('Alias', []) + if isinstance(model_aliases, str): + model_aliases = [model_aliases.lower()] + else: + model_aliases = [alias.lower() for alias in model_aliases] + if (model_name == model or model in model_aliases): + cfg = Config.fromfile( + osp.join(repo_or_mim_dir, model_meta['Config'])) + weights = model_meta['Weights'] + weights = weights[0] if isinstance(weights, list) else weights + results_info = model_meta['Results'] + results_info = results_info[0] if isinstance( + results_info, list) else results_info + return cfg, weights, results_info + raise ValueError(f'Cannot find model: {model} in mmsegmentation') + + +def _get_models_from_metafile(dir: str): + """Load model config defined in metafile from package path. + + Args: + dir (str): Path to the directory of Config. It requires the + directory ``Config``, file ``model-index.yml`` exists in the + ``dir``. + + Yields: + dict: Model config defined in metafile. + """ + meta_indexes = load(osp.join(dir, 'model-index.yml')) + for meta_path in meta_indexes['Import']: + # meta_path example: mmcls/.mim/configs/conformer/metafile.yml + meta_path = osp.join(dir, meta_path) + metainfo = load(meta_path) + yield from metainfo['Models'] + + +def _get_repo_or_mim_dir(): + + module = importlib.import_module('mmseg') + # Since none of OpenMMLab series packages are namespace packages + # (https://docs.python.org/3/glossary.html#term-namespace-package), + # The first element of module.__path__ means package installation path. + package_path = module.__path__[0] + + if osp.exists(osp.join(osp.dirname(package_path), 'configs')): + repo_dir = osp.dirname(package_path) + return repo_dir + else: + mim_dir = osp.join(package_path, '.mim') + if not osp.exists(osp.join(mim_dir, 'Configs')): + raise FileNotFoundError( + f'Cannot find Configs directory in {package_path}!, ' + f'please check the completeness of the mmseg.') + return mim_dir