In [1]:
from hypster import HP, config


@config
def hp_config(hp: HP):
    """Configure an MLflow-loggable wrapper around a registered HyperNode.

    Hypster parameters exposed here select the registry node to wrap, the
    node's own hypster config (propagated under ``<node_name>_inputs``), which
    of the node's variables are dynamic inputs / final outputs, and the conda
    environment used for logging (env name, dependency file, extra
    dependencies, Python version).

    Notable locals (request them via ``final_vars``):
        model: ``HyperNodeMLFlow`` wrapping the selected node.
        mlflow_setup: ``MLFlowSetup`` bundling the registry, ``.env`` file,
            code paths and the generated conda environment.

    Raises:
        ValueError: if the registry lists no eligible nodes, or the selected
            node's driver reports no available variables.
    """
    # Imports live inside the body; presumably because hypster re-executes the
    # config source in isolation. NOTE(review): confirm before hoisting them.
    from hypernodes import NodeRegistry
    from hypernodes.mlflow_utils import (
        EnvironmentGenerator,
        HyperNodeMLFlow,
        MLFlowSetup,
        get_existing_dependencies_files,
    )

    # Only nodes that have both Hamilton DAGs and a hypster config are offered.
    registry = NodeRegistry.initialize()
    nodes = registry.list_nodes(require_hamilton_dags=True, require_hypster_config=True)
    if not nodes:
        raise ValueError("No nodes found in registry")

    node_name = hp.select(nodes, default=nodes[0])
    node = registry.load(node_name)

    # Expose the node's own hypster config as a nested config of this one.
    node_configs = hp.propagate(node.hypster_config, name=f"{node_name}_inputs")
    node.set_instantiated_config(node_configs)

    available_vars = [n.name for n in node._driver.list_available_variables()]
    if not available_vars:
        raise ValueError("No available variables found")

    # One boolean parameter per variable, named "<var>_is_input"; selected
    # variables become the model's dynamic inputs.
    dynamic_inputs = []
    for var in available_vars:
        include_var = hp.select([False, True], default=False, name=f"{var}_is_input")
        if include_var:
            dynamic_inputs.append(var)

    # One boolean parameter per variable, named "<var>_is_output"; selected
    # variables become the model's final_vars.
    final_vars = []
    for var in available_vars:
        include_var = hp.select([False, True], default=False, name=f"{var}_is_output")
        if include_var:
            final_vars.append(var)

    # The default keeps the conda env name "<node_name>-env"; the value is now
    # actually passed to EnvironmentGenerator, so overrides take effect.
    env_name = hp.text(f"{node_name}-env")
    dependencies_files = get_existing_dependencies_files()  # TODO: find project root
    if not dependencies_files:
        # NOTE(review): no default given; a 'dependency_file' value presumably
        # must then be supplied via `values`. Confirm HP.text allows this.
        dependency_file = hp.text()
    else:
        dependency_file = hp.select(dependencies_files, default=dependencies_files[0])

    # TODO (original note "multi_text"): presumably meant to be a multi-value text param.
    extra_dependencies = hp.text("")

    sys_python_version = EnvironmentGenerator.detect_python_version()
    python_version = hp.text(sys_python_version)
    conda_env = EnvironmentGenerator(
        env_name=env_name,
        dependency_file=dependency_file,
        extra_dependencies=extra_dependencies,
        python_version=python_version,
    ).to_conda()

    code_paths = ["src"]
    dotenv_file = ".env"
    mlflow_setup = MLFlowSetup(
        registry=registry,
        artifacts={},
        dotenv_file=dotenv_file,
        code_paths=code_paths,
        conda_env=conda_env,
    )  # TODO (original note): also pass example_input= / signature=

    # Forward user overrides addressed to the node itself ("<node_name>.<param>")
    # with that leading prefix removed. Slicing strips only the prefix, whereas
    # str.replace would also delete a later occurrence of it inside the key.
    values = {
        k[len(node_name) + 1 :]: v
        for k, v in hp.values.items()
        if k.startswith(f"{node_name}.")
    }

    model = HyperNodeMLFlow(
        node_name=node_name,
        dynamic_inputs=dynamic_inputs,
        final_vars=final_vars,
        values=values,
    )

In [2]:
# Overrides for `hp_config`, keyed by hypster parameter name:
#   - "node_name": which registry node to wrap;
#   - "<var>_is_input" / "<var>_is_output": variables used as model inputs / outputs;
#   - "<node_name>.<param>": forwarded (prefix stripped) to the wrapped node;
#   - "dependency_file": file used to build the conda environment.
values = dict(
    [
        ("node_name", "parent_node"),
        ("input_is_input", True),
        ("downstream_is_output", True),
        ("parent_node.basic_usage.llm_model", "haiku"),
        ("dependency_file", "requirements.txt"),
    ]
)

In [3]:
# Instantiate the config, keeping only the two objects the logging cell needs.
inputs = hp_config(
    values=values,
    final_vars=["model", "mlflow_setup"],
)


# Log & Register Model

In [4]:
# Log the wrapped node as an MLflow model, then register it under a fixed name;
# re-running adds a new version of "hypernodes_model" (see output below).
mlflow_setup = inputs["mlflow_setup"]
model_uri = mlflow_setup.log_model(inputs["model"])
model_reg = mlflow_setup.register_model(
    model_uri,
    "hypernodes_model",
)

Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]

Registered model 'hypernodes_model' already exists. Creating a new version of this model...
Created version '63' of model 'hypernodes_model'.


# Test the Logged Model Locally

In [5]:
import pandas as pd  # used by the next cell to build the prediction input frame

# Reload the just-logged model from its URI to exercise it locally.
model = mlflow_setup.load_model(model_uri)

{'node_registry_path': 'node_registry_updated.yaml', 'registry': <hypernodes.registry.NodeRegistry object at 0x1759af610>, 'nested_node': <hypernodes.hypernode.HyperNode object at 0x1759af430>, 'basic_usage_inputs': {'data_path': 'data', 'env': 'dev', 'llm_model': 'claude-3-haiku-20240307'}, 'downstream_node': <hypernodes.hypernode.HyperNode object at 0x1759adf60>, 'input': 'testing'}


In [6]:
# Single-row request. The "input" column presumably maps to the node variable
# marked as a dynamic input via "input_is_input" in the overrides above.
example = pd.DataFrame([{"input": "hey"}], index=[0])
result = model.predict(example)

In [7]:
# Display the prediction returned by the reloaded model.
result

'Querying claude-3-haiku-20240307... hey'