Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to upload Ludwig models to Predibase. #3687

Merged
merged 16 commits into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ludwig/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(self):
init_config Initialize a user config from a dataset and targets
render_config Renders the fully populated config with all defaults set
check_install Runs a quick training run on synthetic data to verify installation status
upload Push trained model artifacts to a registry (e.g., HuggingFace Hub)
upload Push trained model artifacts to a registry (e.g., Predibase, HuggingFace Hub)
""",
)
parser.add_argument("command", help="Subcommand to run")
Expand Down
22 changes: 19 additions & 3 deletions ludwig/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
from typing import Optional

from ludwig.utils.print_utils import get_logging_level_registry
from ludwig.utils.upload_utils import HuggingFaceHub
from ludwig.utils.upload_utils import HuggingFaceHub, Predibase

logger = logging.getLogger(__name__)


def get_upload_registry():
    """Return the mapping from upload service name to the class that implements it."""
    registry = dict(
        hf_hub=HuggingFaceHub,
        predibase=Predibase,
    )
    return registry


Expand All @@ -23,14 +24,16 @@ def upload_cli(
private: bool = False,
commit_message: str = "Upload trained [Ludwig](https://ludwig.ai/latest/) model weights",
commit_description: Optional[str] = None,
dataset_file: Optional[str] = None,
dataset_name: Optional[str] = None,
**kwargs,
) -> None:
"""Create an empty repo on the HuggingFace Hub and upload trained model artifacts to that repo.

Args:
service (`str`):
Name of the hosted model service to push the trained artifacts to.
Currently, this only supports `hf_hub`.
Currently, this only supports `hf_hub` and `predibase`.
repo_id (`str`):
A namespace (user or an organization) and a repo name separated
by a `/`.
Expand All @@ -49,6 +52,12 @@ def upload_cli(
`f"Upload {path_in_repo} with huggingface_hub"`
commit_description (`str` *optional*):
The description of the generated commit
dataset_file (`str`, *optional*):
The path to the dataset file. Required if `service` is set to
`"predibase"` for new model repos.
dataset_name (`str`, *optional*):
The name of the dataset. Used by the `service`
`"predibase"`.
"""
model_service = get_upload_registry().get(service, "hf_hub")
hub = model_service()
Expand All @@ -60,6 +69,8 @@ def upload_cli(
private=private,
commit_message=commit_message,
commit_description=commit_description,
dataset_file=dataset_file,
dataset_name=dataset_name,
)


Expand All @@ -77,7 +88,7 @@ def cli(sys_argv):
"service",
help="Name of the model repository service.",
default="hf_hub",
choices=["hf_hub"],
choices=["hf_hub", "predibase"],
)

parser.add_argument(
Expand Down Expand Up @@ -115,6 +126,11 @@ def cli(sys_argv):
choices=["critical", "error", "warning", "info", "debug", "notset"],
)

parser.add_argument("-df", "--dataset_file", help="The location of the dataset file", default=None)
parser.add_argument(
"-dn", "--dataset_name", help="(Optional) The name of the dataset in the Provider", default=None
)

args = parser.parse_args(sys_argv)

args.logging_level = get_logging_level_registry()[args.logging_level]
Expand Down
226 changes: 204 additions & 22 deletions ludwig/utils/upload_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ def upload(
private: Optional[bool] = False,
commit_message: Optional[str] = None,
commit_description: Optional[str] = None,
dataset_file: Optional[str] = None,
dataset_name: Optional[str] = None,
) -> bool:
"""Abstract method to upload trained model artifacts to the target repository.

Expand Down Expand Up @@ -68,9 +70,7 @@ def _validate_upload_parameters(
trained model artifacts to the target repository.

Args:
repo_id (str): The ID of the target repository. It must be a namespace (user or an organization)
and a repository name separated by a '/'. For example, if your HF username is 'johndoe' and you
want to create a repository called 'test', the repo_id should be 'johndoe/test'.
repo_id (str): The ID of the target repository. Each provider will verify their specific rules.
model_path (str): The path to the directory containing the trained model artifacts. It should contain
the model's weights, usually saved under 'model/model_weights'.
repo_type (str, optional): The type of the repository. Not used in the base class, but subclasses
Expand All @@ -85,18 +85,10 @@ def _validate_upload_parameters(
implementations. Defaults to None.

Raises:
AssertionError: If the repo_id does not have both a namespace and a repo name separated by a '/'.
FileNotFoundError: If the model_path does not exist.
Exception: If the trained model artifacts are not found at the expected location within model_path, or
if the artifacts are not in the required format (i.e., 'pytorch_model.bin' or 'adapter_model.bin').
"""
# Validate repo_id has both a namespace and a repo name
assert "/" in repo_id, (
"`repo_id` must be a namespace (user or an organization) and a repo name separated by a `/`."
" For example, if your HF username is `johndoe` and you want to create a repository called `test`, the"
" repo_id should be johndoe/test"
)

# Make sure the model's save path is actually a valid path
if not os.path.exists(model_path):
raise FileNotFoundError(f"The path '{model_path}' does not exist.")
Expand All @@ -110,17 +102,6 @@ def _validate_upload_parameters(
"wrong during training where the model's weights were not saved."
)

# Make sure the model's saved artifacts either contain:
# 1. pytorch_model.bin -> regular model training, such as ECD or for LLMs
# 2. adapter_model.bin -> LLM fine-tuning using PEFT
files = set(os.listdir(trained_model_artifacts_path))
if "pytorch_model.bin" not in files and "adapter_model.bin" not in files:
raise Exception(
f"Can't find model weights at {trained_model_artifacts_path}. Trained model weights should "
"either be saved as `pytorch_model.bin` for regular model training, or have `adapter_model.bin`"
"if using parameter efficient fine-tuning methods like LoRA."
)


class HuggingFaceHub(BaseModelUpload):
def __init__(self):
Expand All @@ -142,6 +123,67 @@ def login(self):

self.api = hf_api

@staticmethod
def _validate_upload_parameters(
    repo_id: str,
    model_path: str,
    repo_type: Optional[str] = None,
    private: Optional[bool] = False,
    commit_message: Optional[str] = None,
    commit_description: Optional[str] = None,
):
    """Validate parameters before uploading trained model artifacts to the HuggingFace Hub.

    Checks are performed in this order:
      1. `repo_id` must contain a '/' separating the namespace from the repo name.
      2. The generic checks in `BaseModelUpload._validate_upload_parameters`.
      3. `<model_path>/model/model_weights` must contain `pytorch_model.bin` or `adapter_model.bin`.

    Args:
        repo_id (str): The ID of the target repository. It must be a namespace (user or an organization)
            and a repository name separated by a '/'. For example, if your HF username is 'johndoe' and you
            want to create a repository called 'test', the repo_id should be 'johndoe/test'.
        model_path (str): The path to the directory containing the trained model artifacts. The weights are
            looked up under 'model/model_weights' inside this directory.
        repo_type (str, optional): Only forwarded to the base class validation. Defaults to None.
        private (bool, optional): Only forwarded to the base class validation. Defaults to False.
        commit_message (str, optional): Only forwarded to the base class validation. Defaults to None.
        commit_description (str, optional): Only forwarded to the base class validation. Defaults to None.

    Raises:
        AssertionError: If the repo_id does not have both a namespace and a repo name separated by a '/'.
        Exception: If neither 'pytorch_model.bin' nor 'adapter_model.bin' is present in the model weights
            directory. Errors raised by the base class validation propagate unchanged.
    """
    # Validate repo_id has both a namespace and a repo name. Raised explicitly (instead of using an
    # `assert` statement) so that the check is not stripped when Python runs with `-O`.
    if "/" not in repo_id:
        raise AssertionError(
            "`repo_id` must be a namespace (user or an organization) and a repo name separated by a `/`."
            " For example, if your HF username is `johndoe` and you want to create a repository called `test`, the"
            " repo_id should be johndoe/test"
        )

    BaseModelUpload._validate_upload_parameters(
        repo_id,
        model_path,
        repo_type,
        private,
        commit_message,
        commit_description,
    )

    trained_model_artifacts_path = os.path.join(model_path, "model", "model_weights")
    # Make sure the model's saved artifacts either contain:
    # 1. pytorch_model.bin -> regular model training, such as ECD or for LLMs
    # 2. adapter_model.bin -> LLM fine-tuning using PEFT
    files = set(os.listdir(trained_model_artifacts_path))
    if "pytorch_model.bin" not in files and "adapter_model.bin" not in files:
        raise Exception(
            f"Can't find model weights at {trained_model_artifacts_path}. Trained model weights should "
            "either be saved as `pytorch_model.bin` for regular model training, or have `adapter_model.bin` "
            "if using parameter efficient fine-tuning methods like LoRA."
        )

def upload(
self,
repo_id: str,
Expand All @@ -150,6 +192,7 @@ def upload(
private: Optional[bool] = False,
commit_message: Optional[str] = None,
commit_description: Optional[str] = None,
**kwargs,
) -> bool:
"""Create an empty repo on the HuggingFace Hub and upload trained model artifacts to that repo.

Expand Down Expand Up @@ -205,3 +248,142 @@ def upload(
return True

return False


class Predibase(BaseModelUpload):
    """Uploads trained Ludwig model artifacts (and their dataset) to Predibase."""

    def __init__(self):
        # Predibase client; stays None until `login()` succeeds.
        self.pc = None

    def login(self):
        """Create a Predibase client and store it on `self.pc`.

        The PREDIBASE_API_TOKEN environment variable must be set; its presence is checked here before
        the client is constructed.

        Returns:
            bool: True once the client has been created.

        Raises:
            ValueError: If the PREDIBASE_API_TOKEN environment variable is not set.
            Exception: If constructing the Predibase client fails. The original error is chained as the cause.
        """
        from predibase import PredibaseClient

        token = os.environ.get("PREDIBASE_API_TOKEN")
        if token is None:
            raise ValueError(
                "Unable to find PREDIBASE_API_TOKEN environment variable. Please log into Predibase, "
                "generate a token and use `export PREDIBASE_API_TOKEN=` to use Predibase"
            )

        try:
            pc = PredibaseClient()

            # TODO: Check if subscription has expired

            self.pc = pc
        except Exception as e:
            # Chain the cause so the underlying client error keeps its traceback.
            raise Exception(f"Failed to login to Predibase: {e}") from e

        return True

    @staticmethod
    def _validate_upload_parameters(
        repo_id: str,
        model_path: str,
        repo_type: Optional[str] = None,
        private: Optional[bool] = False,
        commit_message: Optional[str] = None,
        commit_description: Optional[str] = None,
    ):
        """Validate parameters before uploading trained model artifacts to Predibase.

        Checks the Predibase-specific `repo_id` length limit, then delegates the remaining checks to
        `BaseModelUpload._validate_upload_parameters`.

        Args:
            repo_id (str): The ID of the target repository. It must be 255 characters or less.
            model_path (str): The path to the directory containing the trained model artifacts.
            repo_type (str, optional): Only forwarded to the base class validation. Defaults to None.
            private (bool, optional): Only forwarded to the base class validation. Defaults to False.
            commit_message (str, optional): Only forwarded to the base class validation. Defaults to None.
            commit_description (str, optional): Only forwarded to the base class validation. Defaults to None.

        Raises:
            AssertionError: If the repo_id is longer than 255 characters.
        """
        # Raised explicitly (instead of using an `assert` statement) so that the check is not
        # stripped when Python runs with `-O`.
        if len(repo_id) > 255:
            raise AssertionError("`repo_id` must be 255 characters or less.")

        BaseModelUpload._validate_upload_parameters(
            repo_id,
            model_path,
            repo_type,
            private,
            commit_message,
            commit_description,
        )

    def upload(
        self,
        repo_id: str,
        model_path: str,
        commit_description: Optional[str] = None,
        dataset_file: Optional[str] = None,
        dataset_name: Optional[str] = None,
        **kwargs,
    ) -> bool:
        """Upload a dataset and trained model artifacts to a Predibase model repo.

        The steps are: validate the parameters, upload the dataset, create the model repo (an existing
        repo with the same name is accepted), then upload the model.

        Args:
            repo_id (`str`):
                The name of the model repo. Must be 255 characters or less.
            model_path (`str`):
                The path of the saved model. This is the top level directory where
                the models weights as well as other associated training artifacts
                are saved.
            commit_description (`str` *optional*):
                Used as the description of the model repo.
            dataset_file (`str` *optional*):
                The path to the dataset file. Required if `service` is set to
                `"predibase"` for new model repos.
            dataset_name (`str` *optional*):
                The name of the dataset. Used by the `service`
                `"predibase"`. Falls back to the filename.
            **kwargs:
                Accepted so callers can pass arguments meant for other upload services; ignored here.

        Returns:
            bool: True when the model has been uploaded.

        Raises:
            RuntimeError: If `login()` has not been called, or if uploading the dataset, creating the
                repo, or uploading the model fails (the underlying error is chained as the cause).
        """
        # Validate upload parameters are in the right format
        Predibase._validate_upload_parameters(
            repo_id,
            model_path,
            None,
            False,
            "",
            commit_description,
        )

        # Without a client every call below would fail with a misleading "Failed to upload dataset" error.
        if self.pc is None:
            raise RuntimeError("Not logged in to Predibase. Call `login()` before `upload()`.")

        # Upload the dataset to Predibase
        try:
            dataset = self.pc.upload_dataset(file_path=dataset_file, name=dataset_name)
        except Exception as e:
            raise RuntimeError("Failed to upload dataset to Predibase") from e

        # Create empty model repo using repo_id, but it is okay if it already exists.
        try:
            repo = self.pc.create_model_repo(
                name=repo_id,
                description=commit_description,
                exists_ok=True,
            )
        except Exception as e:
            raise RuntimeError("Failed to create repo in Predibase") from e

        # Upload the model artifacts to Predibase
        try:
            self.pc.upload_model(
                repo=repo,
                model_path=model_path,
                dataset=dataset,
            )
        except Exception as e:
            raise RuntimeError("Failed to upload model to Predibase") from e

        logger.info(f"Model uploaded to Predibase with repository name `{repo_id}`")
        # Success: report True, consistent with HuggingFaceHub.upload.
        return True
martindavis marked this conversation as resolved.
Show resolved Hide resolved
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,6 @@ pyxlsb>=1.0.8 # excel
pyarrow # parquet
lxml # html
html5lib # html

# Allows users to upload trained models to Predibase
predibase>=2023.10.2
martindavis marked this conversation as resolved.
Show resolved Hide resolved
Loading