In [1]:
from kfp import dsl, compiler, components

In [2]:
# Reusable KFP components, loaded from their YAML specs. The paths are
# relative, so they resolve against the notebook's current working directory.
hf_downloader_component_url = "_components/hf_downloader.yaml"
osm_inference_component_url = "_components/osm_inference.yaml"

download_artifact = components.load_component_from_file(hf_downloader_component_url)
osm_inference = components.load_component_from_file(osm_inference_component_url)

In [3]:
@dsl.pipeline
def dowload_hf_artifact(
    repo_id: str,
    mapbox_token: str,
    repo_type: str = "model",
    filename: str = "model.pt",
    latitude: float = 0.0,
    longitude: float = 0.0,
    batch_size: int = 32,
):
    """Download a model from the Hugging Face Hub and run OSM inference with it.

    The pipeline has two steps: the ``download_artifact`` component fetches
    ``filename`` from ``repo_id``, and the ``osm_inference`` component then
    consumes the directory produced by that download.

    Args:
        repo_id: Hugging Face Hub repo ID to download from.
        mapbox_token: Mapbox token forwarded to the inference step.
        repo_type: Hugging Face Hub repo type; defaults to "model".
        filename: Name of the file to download from the repo; also passed to
            the inference step as the model name.
        latitude: Latitude passed to the inference step.
        longitude: Longitude passed to the inference step.
        batch_size: Batch size passed to the inference step.
    """
    download_model_step = download_artifact(
        repo_id=repo_id,
        repo_type=repo_type,
        filename=filename,
    )
    # Always re-run the download instead of reusing a cached result —
    # presumably so changes to the Hub repo are picked up; confirm intent.
    download_model_step.set_caching_options(False)

    # NOTE(review): pipeline parameters are stored in plain text in the run
    # metadata, so this token is visible to anyone who can view the run.
    # Consider sourcing it from a Kubernetes secret instead.
    osm_inference(
        model_dir=download_model_step.outputs["output_dir"],
        model_name=filename,
        latitude=latitude,
        longitude=longitude,
        mapbox_token=mapbox_token,
        batch_size=batch_size,
    )

In [4]:
# Serialize the pipeline defined above to a KFP pipeline-spec YAML package.
# Note: the function is spelled "dowload_..." while the package is "download_...".
compiler.Compiler().compile(
    pipeline_func=dowload_hf_artifact,
    package_path='download_hf_artifact.yaml',
)