Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions docs/cli_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,28 @@ $ llama-stack-client shields list
| llama_guard | {} | meta-reference | llama_guard |
+--------------+----------+----------------+-------------+
```

#### `llama-stack-client eval run_benchmark`
```bash
$ llama-stack-client eval run_benchmark <task_id1> <task_id2> --num-examples 10 --output-dir ./ --eval-task-config ~/eval_task_config.json
```

where `eval_task_config.json` is the path to the eval task config file in JSON format. An example `eval_task_config.json`:
```
$ cat ~/eval_task_config.json
{
"type": "benchmark",
"eval_candidate": {
"type": "model",
"model": "Llama3.1-405B-Instruct",
"sampling_params": {
"strategy": "greedy",
"temperature": 0,
"top_p": 0.95,
"top_k": 0,
"max_tokens": 0,
"repetition_penalty": 1.0
}
}
}
```
9 changes: 9 additions & 0 deletions src/llama_stack_client/lib/cli/eval/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .eval import eval

__all__ = ["eval"]
20 changes: 20 additions & 0 deletions src/llama_stack_client/lib/cli/eval/eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


import click

from .run_benchmark import run_benchmark


@click.group()
def eval():
"""Run evaluation tasks"""
pass


# Register subcommands
eval.add_command(run_benchmark)
82 changes: 82 additions & 0 deletions src/llama_stack_client/lib/cli/eval/run_benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
import os
from typing import Optional

import click
from tqdm.rich import tqdm


@click.command("run_benchmark")
@click.argument("eval-task-ids", nargs=-1, required=True)
@click.option(
    "--eval-task-config",
    required=True,
    help="Path to the eval task config file in JSON format",
    type=click.Path(exists=True),
)
@click.option(
    "--output-dir",
    required=True,
    help="Path to the dump eval results output directory",
)
@click.option(
    "--num-examples",
    required=False,
    help="Number of examples to evaluate on, useful for debugging",
    default=None,
    # Without an explicit type, click passes the option through as a str,
    # which would then be sent as `rows_in_page` below.
    type=int,
)
@click.pass_context
def run_benchmark(
    ctx, eval_task_ids: tuple[str, ...], eval_task_config: str, output_dir: str, num_examples: Optional[int]
):
    """Run an evaluation benchmark.

    For each eval task id: fetch the task's dataset rows, evaluate them one
    row at a time (so progress can be reported per row), and dump the
    accumulated results to ``<output-dir>/<eval_task_id>_results.json``.
    """

    client = ctx.obj["client"]

    # Load the task config once, up front. The original code opened and
    # rebound `eval_task_config` inside the per-task loop, which replaced the
    # path string with the parsed dict and crashed on the second task id.
    with open(eval_task_config, "r") as f:
        task_config = json.load(f)

    for eval_task_id in eval_task_ids:
        eval_task = client.eval_tasks.retrieve(name=eval_task_id)
        scoring_functions = eval_task.scoring_functions
        dataset_id = eval_task.dataset_id

        # rows_in_page=-1 requests all rows when no example limit is given.
        rows = client.datasetio.get_rows_paginated(
            dataset_id=dataset_id, rows_in_page=-1 if num_examples is None else num_examples
        )

        # Column-oriented accumulator: column name -> list of values, one
        # entry per evaluated row (inputs, then generations, then scores).
        output_res = {}

        for r in tqdm(rows.rows):
            eval_res = client.eval.evaluate_rows(
                task_id=eval_task_id,
                input_rows=[r],
                scoring_functions=scoring_functions,
                task_config=task_config,
            )
            for k in r.keys():
                output_res.setdefault(k, []).append(r[k])

            for k in eval_res.generations[0].keys():
                output_res.setdefault(k, []).append(eval_res.generations[0][k])

            for scoring_fn in scoring_functions:
                output_res.setdefault(scoring_fn, []).append(eval_res.scores[scoring_fn].score_rows[0])

        # Create output directory if it doesn't exist
        os.makedirs(output_dir, exist_ok=True)
        # Save results to JSON file
        output_file = os.path.join(output_dir, f"{eval_task_id}_results.json")
        with open(output_file, "w") as f:
            json.dump(output_res, f, indent=2)

        print(f"Results saved to: {output_file}")
2 changes: 2 additions & 0 deletions src/llama_stack_client/lib/cli/eval_tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@
# the root directory of this source tree.

from .eval_tasks import eval_tasks

__all__ = ["eval_tasks"]
12 changes: 11 additions & 1 deletion src/llama_stack_client/lib/cli/llama_stack_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os

import click
import yaml

Expand All @@ -12,6 +14,7 @@

from .constants import get_config_file_path
from .datasets import datasets
from .eval import eval
from .eval_tasks import eval_tasks
from .memory_banks import memory_banks
from .models import models
Expand Down Expand Up @@ -50,7 +53,13 @@ def cli(ctx, endpoint: str, config: str | None):
if endpoint == "":
endpoint = "http://localhost:5000"

client = LlamaStackClient(base_url=endpoint)
client = LlamaStackClient(
base_url=endpoint,
provider_data={
"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY", ""),
"togethers_api_key": os.environ.get("TOGETHERS_API_KEY", ""),
},
)
ctx.obj = {"client": client}


Expand All @@ -63,6 +72,7 @@ def cli(ctx, endpoint: str, config: str | None):
cli.add_command(datasets, "datasets")
cli.add_command(configure, "configure")
cli.add_command(scoring_functions, "scoring_functions")
cli.add_command(eval, "eval")


def main():
Expand Down
2 changes: 2 additions & 0 deletions src/llama_stack_client/lib/cli/memory_banks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@
# the root directory of this source tree.

from .memory_banks import memory_banks

__all__ = ["memory_banks"]
19 changes: 0 additions & 19 deletions src/llama_stack_client/lib/cli/subcommand.py

This file was deleted.