Skip to content

Commit

Permalink
Model versioning (huggingface#8324)
Browse files Browse the repository at this point in the history
* fix typo

* rm use_cdn & references, and implement new hf_bucket_url

* I'm pretty sure we don't need to `read` this file

* same here

* [BIG] file_utils.networking: do not gobble up errors anymore

* Fix CI 馃槆

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Tiny doc tweak

* Add doc + pass kwarg everywhere

* Add more tests and explain

cc @sshleifer let me know if better

Co-Authored-By: Sam Shleifer <sshleifer@gmail.com>

* Also implement revision in pipelines

In the case where we're passing a task name or a string model identifier

* Fix CI 馃槆

* Fix CI

* [hf_api] new methods + command line implem

* make style

* Final endpoints post-migration

* Fix post-migration

* Py3.6 compat

cc @stefan-it

Thank you @stas00

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
  • Loading branch information
3 people authored and fabiocapsouza committed Nov 15, 2020
1 parent 4e33b5e commit efedc1d
Show file tree
Hide file tree
Showing 23 changed files with 472 additions and 210 deletions.
4 changes: 2 additions & 2 deletions model_cards/t5-11b-README.md
Expand Up @@ -12,8 +12,8 @@ inference: false

## Disclaimer

Due do it's immense size, `t5-11b` requires some special treatment.
First, `t5-11b` should be loaded with flag `use_cdn` set to `False` as follows:
**Before `transformers` v3.5.0**, due do its immense size, `t5-11b` required some special treatment.
If you're using transformers `<= v3.4.0`, `t5-11b` should be loaded with flag `use_cdn` set to `False` as follows:

```python
t5 = transformers.T5ForConditionalGeneration.from_pretrained('t5-11b', use_cdn = False)
Expand Down
4 changes: 0 additions & 4 deletions scripts/fsmt/convert-allenai-wmt16.sh
Expand Up @@ -56,7 +56,3 @@ cd -
perl -le 'for $f (@ARGV) { print qq[transformers-cli upload -y $_/$f --filename $_/$f] for ("wmt16-en-de-dist-12-1", "wmt16-en-de-dist-6-1", "wmt16-en-de-12-1")}' vocab-src.json vocab-tgt.json tokenizer_config.json config.json
# add/remove files as needed

# Caching note: Unfortunately due to CDN caching the uploaded model may be unavailable for up to 24hs after upload
# So the only way to start using the new model sooner is either:
# 1. download it to a local path and use that path as model_name
# 2. make sure you use: from_pretrained(..., use_cdn=False) everywhere
4 changes: 0 additions & 4 deletions scripts/fsmt/convert-allenai-wmt19.sh
Expand Up @@ -44,7 +44,3 @@ cd -
perl -le 'for $f (@ARGV) { print qq[transformers-cli upload -y $_/$f --filename $_/$f] for ("wmt19-de-en-6-6-base", "wmt19-de-en-6-6-big")}' vocab-src.json vocab-tgt.json tokenizer_config.json config.json
# add/remove files as needed

# Caching note: Unfortunately due to CDN caching the uploaded model may be unavailable for up to 24hs after upload
# So the only way to start using the new model sooner is either:
# 1. download it to a local path and use that path as model_name
# 2. make sure you use: from_pretrained(..., use_cdn=False) everywhere
4 changes: 0 additions & 4 deletions scripts/fsmt/convert-facebook-wmt19.sh
Expand Up @@ -55,7 +55,3 @@ cd -
perl -le 'for $f (@ARGV) { print qq[transformers-cli upload -y $_/$f --filename $_/$f] for map { "wmt19-$_" } ("en-ru", "ru-en", "de-en", "en-de")}' vocab-src.json vocab-tgt.json tokenizer_config.json config.json
# add/remove files as needed

# Caching note: Unfortunately due to CDN caching the uploaded model may be unavailable for up to 24hs after upload
# So the only way to start using the new model sooner is either:
# 1. download it to a local path and use that path as model_name
# 2. make sure you use: from_pretrained(..., use_cdn=False) everywhere
161 changes: 137 additions & 24 deletions src/transformers/commands/user.py
@@ -1,4 +1,5 @@
import os
import subprocess
import sys
from argparse import ArgumentParser
from getpass import getpass
Expand All @@ -21,8 +22,10 @@ def register_subcommand(parser: ArgumentParser):
whoami_parser.set_defaults(func=lambda args: WhoamiCommand(args))
logout_parser = parser.add_parser("logout", help="Log out")
logout_parser.set_defaults(func=lambda args: LogoutCommand(args))
# s3
s3_parser = parser.add_parser("s3", help="{ls, rm} Commands to interact with the files you upload on S3.")
# s3_datasets (s3-based system)
s3_parser = parser.add_parser(
"s3_datasets", help="{ls, rm} Commands to interact with the files you upload on S3."
)
s3_subparsers = s3_parser.add_subparsers(help="s3 related commands")
ls_parser = s3_subparsers.add_parser("ls")
ls_parser.add_argument("--organization", type=str, help="Optional: organization namespace.")
Expand All @@ -31,17 +34,42 @@ def register_subcommand(parser: ArgumentParser):
rm_parser.add_argument("filename", type=str, help="individual object filename to delete from S3.")
rm_parser.add_argument("--organization", type=str, help="Optional: organization namespace.")
rm_parser.set_defaults(func=lambda args: DeleteObjCommand(args))
# upload
upload_parser = parser.add_parser("upload", help="Upload a model to S3.")
upload_parser.add_argument(
"path", type=str, help="Local path of the model folder or individual file to upload."
)
upload_parser = s3_subparsers.add_parser("upload", help="Upload a file to S3.")
upload_parser.add_argument("path", type=str, help="Local path of the folder or individual file to upload.")
upload_parser.add_argument("--organization", type=str, help="Optional: organization namespace.")
upload_parser.add_argument(
"--filename", type=str, default=None, help="Optional: override individual object filename on S3."
)
upload_parser.add_argument("-y", "--yes", action="store_true", help="Optional: answer Yes to the prompt")
upload_parser.set_defaults(func=lambda args: UploadCommand(args))
# deprecated model upload
upload_parser = parser.add_parser(
"upload",
help=(
"Deprecated: used to be the way to upload a model to S3."
" We now use a git-based system for storing models and other artifacts."
" Use the `repo create` command instead."
),
)
upload_parser.set_defaults(func=lambda args: DeprecatedUploadCommand(args))

# new system: git-based repo system
repo_parser = parser.add_parser(
"repo", help="{create, ls-files} Commands to interact with your huggingface.co repos."
)
repo_subparsers = repo_parser.add_subparsers(help="huggingface.co repos related commands")
ls_parser = repo_subparsers.add_parser("ls-files", help="List all your files on huggingface.co")
ls_parser.add_argument("--organization", type=str, help="Optional: organization namespace.")
ls_parser.set_defaults(func=lambda args: ListReposObjsCommand(args))
repo_create_parser = repo_subparsers.add_parser("create", help="Create a new repo on huggingface.co")
repo_create_parser.add_argument(
"name",
type=str,
help="Name for your model's repo. Will be namespaced under your username to build the model id.",
)
repo_create_parser.add_argument("--organization", type=str, help="Optional: organization namespace.")
repo_create_parser.add_argument("-y", "--yes", action="store_true", help="Optional: answer Yes to the prompt")
repo_create_parser.set_defaults(func=lambda args: RepoCreateCommand(args))


class ANSI:
Expand All @@ -51,6 +79,7 @@ class ANSI:

_bold = "\u001b[1m"
_red = "\u001b[31m"
_gray = "\u001b[90m"
_reset = "\u001b[0m"

@classmethod
Expand All @@ -61,6 +90,27 @@ def bold(cls, s):
def red(cls, s):
return "{}{}{}".format(cls._bold + cls._red, s, cls._reset)

@classmethod
def gray(cls, s):
return "{}{}{}".format(cls._gray, s, cls._reset)


def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str:
"""
Inspired by:
- stackoverflow.com/a/8356620/593036
- stackoverflow.com/questions/9535954/printing-lists-as-tabular-data
"""
col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
row_format = ("{{:{}}} " * len(headers)).format(*col_widths)
lines = []
lines.append(row_format.format(*headers))
lines.append(row_format.format(*["-" * w for w in col_widths]))
for row in rows:
lines.append(row_format.format(*row))
return "\n".join(lines)


class BaseUserCommand:
def __init__(self, args):
Expand Down Expand Up @@ -124,22 +174,6 @@ def run(self):


class ListObjsCommand(BaseUserCommand):
def tabulate(self, rows: List[List[Union[str, int]]], headers: List[str]) -> str:
"""
Inspired by:
- stackoverflow.com/a/8356620/593036
- stackoverflow.com/questions/9535954/printing-lists-as-tabular-data
"""
col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
row_format = ("{{:{}}} " * len(headers)).format(*col_widths)
lines = []
lines.append(row_format.format(*headers))
lines.append(row_format.format(*["-" * w for w in col_widths]))
for row in rows:
lines.append(row_format.format(*row))
return "\n".join(lines)

def run(self):
token = HfFolder.get_token()
if token is None:
Expand All @@ -155,7 +189,7 @@ def run(self):
print("No shared file yet")
exit()
rows = [[obj.filename, obj.LastModified, obj.ETag, obj.Size] for obj in objs]
print(self.tabulate(rows, headers=["Filename", "LastModified", "ETag", "Size"]))
print(tabulate(rows, headers=["Filename", "LastModified", "ETag", "Size"]))


class DeleteObjCommand(BaseUserCommand):
Expand All @@ -173,6 +207,85 @@ def run(self):
print("Done")


class ListReposObjsCommand(BaseUserCommand):
def run(self):
token = HfFolder.get_token()
if token is None:
print("Not logged in")
exit(1)
try:
objs = self._api.list_repos_objs(token, organization=self.args.organization)
except HTTPError as e:
print(e)
print(ANSI.red(e.response.text))
exit(1)
if len(objs) == 0:
print("No shared file yet")
exit()
rows = [[obj.filename, obj.lastModified, obj.commit, obj.size] for obj in objs]
print(tabulate(rows, headers=["Filename", "LastModified", "Commit-Sha", "Size"]))


class RepoCreateCommand(BaseUserCommand):
def run(self):
token = HfFolder.get_token()
if token is None:
print("Not logged in")
exit(1)
try:
stdout = subprocess.check_output(["git", "--version"]).decode("utf-8")
print(ANSI.gray(stdout.strip()))
except FileNotFoundError:
print("Looks like you do not have git installed, please install.")

try:
stdout = subprocess.check_output(["git-lfs", "--version"]).decode("utf-8")
print(ANSI.gray(stdout.strip()))
except FileNotFoundError:
print(
ANSI.red(
"Looks like you do not have git-lfs installed, please install."
" You can install from https://git-lfs.github.com/."
" Then run `git lfs install` (you only have to do this once)."
)
)
print("")

user, _ = self._api.whoami(token)
namespace = self.args.organization if self.args.organization is not None else user

print("You are about to create {}".format(ANSI.bold(namespace + "/" + self.args.name)))

if not self.args.yes:
choice = input("Proceed? [Y/n] ").lower()
if not (choice == "" or choice == "y" or choice == "yes"):
print("Abort")
exit()
try:
url = self._api.create_repo(token, name=self.args.name, organization=self.args.organization)
except HTTPError as e:
print(e)
print(ANSI.red(e.response.text))
exit(1)
print("\nYour repo now lives at:")
print(" {}".format(ANSI.bold(url)))
print("\nYou can clone it locally with the command below," " and commit/push as usual.")
print(f"\n git clone {url}")
print("")


class DeprecatedUploadCommand(BaseUserCommand):
def run(self):
print(
ANSI.red(
"Deprecated: used to be the way to upload a model to S3."
" We now use a git-based system for storing models and other artifacts."
" Use the `repo create` command instead."
)
)
exit(1)


class UploadCommand(BaseUserCommand):
def walk_dir(self, rel_path):
"""
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/configuration_auto.py
Expand Up @@ -289,6 +289,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
proxies (:obj:`Dict[str, str]`, `optional`):
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
return_unused_kwargs (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`False`, then this function returns just the final configuration object.
Expand Down
12 changes: 8 additions & 4 deletions src/transformers/configuration_utils.py
Expand Up @@ -311,6 +311,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> "Pretr
proxies (:obj:`Dict[str, str]`, `optional`):
A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any
identifier allowed by git.
return_unused_kwargs (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`False`, then this function returns just the final configuration object.
Expand Down Expand Up @@ -362,14 +366,15 @@ def get_config_dict(cls, pretrained_model_name_or_path: str, **kwargs) -> Tuple[
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)

if os.path.isdir(pretrained_model_name_or_path):
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
config_file = pretrained_model_name_or_path
else:
config_file = hf_bucket_url(
pretrained_model_name_or_path, filename=CONFIG_NAME, use_cdn=False, mirror=None
pretrained_model_name_or_path, filename=CONFIG_NAME, revision=revision, mirror=None
)

try:
Expand All @@ -383,11 +388,10 @@ def get_config_dict(cls, pretrained_model_name_or_path: str, **kwargs) -> Tuple[
local_files_only=local_files_only,
)
# Load config dict
if resolved_config_file is None:
raise EnvironmentError
config_dict = cls._dict_from_json_file(resolved_config_file)

except EnvironmentError:
except EnvironmentError as err:
logger.error(err)
msg = (
f"Can't load config for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
Expand Down

0 comments on commit efedc1d

Please sign in to comment.