Skip to content

Commit

Permalink
feat: Add crew train cli (#624)
Browse files Browse the repository at this point in the history
* fix: fix crewai-tools cli command

* feat: add crewai train CLI command

* feat: add the tests

* fix: fix typing hinting issue on code

* fix: test.yml

* fix: fix test

* fix: removed fix since it didnt changed the test
  • Loading branch information
pythonbyte committed May 23, 2024
1 parent a336381 commit 24ed8a2
Show file tree
Hide file tree
Showing 16 changed files with 278 additions and 45 deletions.
5 changes: 2 additions & 3 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,9 @@ jobs:

- name: Install Requirements
run: |
sudo apt-get update &&
pip install poetry &&
pip install poetry
poetry lock &&
poetry install
- name: Run tests
run: poetry run pytest
run: poetry run pytest tests
6 changes: 4 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.4
hooks:
# Run the linter.
- id: ruff
args: [--fix]
args: ["--fix"]
exclude: "templates"
- id: ruff-format
exclude: "templates"
17 changes: 16 additions & 1 deletion src/crewai/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pkg_resources

from .create_crew import create_crew
from .train_crew import train_crew


@click.group()
Expand All @@ -27,11 +28,25 @@ def version(tools):

if tools:
try:
tools_version = pkg_resources.get_distribution("crewai[tools]").version
tools_version = pkg_resources.get_distribution("crewai-tools").version
click.echo(f"crewai tools version: {tools_version}")
except pkg_resources.DistributionNotFound:
click.echo("crewai tools not installed")


@crewai.command()
@click.option(
"-n",
"--n_iterations",
type=int,
default=5,
help="Number of iterations to train the crew",
)
def train(n_iterations: int):
"""Train the crew."""
click.echo(f"Training the crew for {n_iterations} iterations")
train_crew(n_iterations)


if __name__ == "__main__":
crewai()
14 changes: 13 additions & 1 deletion src/crewai/cli/templates/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#!/usr/bin/env python
import sys
from {{folder_name}}.crew import {{crew_name}}Crew


Expand All @@ -7,4 +8,15 @@ def run():
inputs = {
'topic': 'AI LLMs'
}
{{crew_name}}Crew().crew().kickoff(inputs=inputs)
{{crew_name}}Crew().crew().kickoff(inputs=inputs)


def train():
"""
Train the crew for a given number of iterations.
"""
try:
{{crew_name}}Crew().crew().train(n_iterations=int(sys.argv[1]))

except Exception as e:
raise Exception(f"An error occurred while training the crew: {e}")
5 changes: 3 additions & 2 deletions src/crewai/cli/templates/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@ authors = ["Your Name <you@example.com>"]

[tool.poetry.dependencies]
python = ">=3.10,<=3.13"
crewai = {extras = ["tools"], version = "^0.30.11"}
crewai = { extras = ["tools"], version = "^0.30.11" }

[tool.poetry.scripts]
{{folder_name}} = "{{folder_name}}.main:run"
train = "{{folder_name}}.main:train"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
build-backend = "poetry.core.masonry.api"
29 changes: 29 additions & 0 deletions src/crewai/cli/train_crew.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import subprocess

import click


def train_crew(n_iterations: int) -> None:
"""
Train the crew by running a command in the Poetry environment.
Args:
n_iterations (int): The number of iterations to train the crew.
"""
command = ["poetry", "run", "train", str(n_iterations)]

try:
if n_iterations <= 0:
raise ValueError("The number of iterations must be a positive integer.")

result = subprocess.run(command, capture_output=False, text=True, check=True)

if result.stderr:
click.echo(result.stderr, err=True)

except subprocess.CalledProcessError as e:
click.echo(f"An error occurred while training the crew: {e}", err=True)
click.echo(e.output, err=True)

except Exception as e:
click.echo(f"An unexpected error occurred: {e}", err=True)
8 changes: 7 additions & 1 deletion src/crewai/crew.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,9 @@ def create_crew_memory(self) -> "Crew":
"""Set private attributes."""
if self.memory:
self._long_term_memory = LongTermMemory()
self._short_term_memory = ShortTermMemory(crew=self, embedder_config=self.embedder)
self._short_term_memory = ShortTermMemory(
crew=self, embedder_config=self.embedder
)
self._entity_memory = EntityMemory(crew=self, embedder_config=self.embedder)
return self

Expand Down Expand Up @@ -280,6 +282,10 @@ def kickoff(self, inputs: Optional[Dict[str, Any]] = {}) -> str:

return result

def train(self, n_iterations: int) -> None:
# TODO: Implement training
pass

def _run_sequential_process(self) -> str:
"""Executes tasks sequentially and returns the final output."""
task_output = ""
Expand Down
5 changes: 4 additions & 1 deletion src/crewai/memory/entity/entity_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ class EntityMemory(Memory):

def __init__(self, crew=None, embedder_config=None):
storage = RAGStorage(
type="entities", allow_reset=False, embedder_config=embedder_config, crew=crew
type="entities",
allow_reset=False,
embedder_config=embedder_config,
crew=crew,
)
super().__init__(storage)

Expand Down
4 changes: 3 additions & 1 deletion src/crewai/memory/short_term/short_term_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ class ShortTermMemory(Memory):
"""

def __init__(self, crew=None, embedder_config=None):
storage = RAGStorage(type="short_term", embedder_config=embedder_config, crew=crew)
storage = RAGStorage(
type="short_term", embedder_config=embedder_config, crew=crew
)
super().__init__(storage)

def save(self, item: ShortTermMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
Expand Down
8 changes: 4 additions & 4 deletions src/crewai/project/crew_base.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
import inspect
import yaml
import os

from pathlib import Path
from pydantic import ConfigDict

import yaml
from dotenv import load_dotenv
from pydantic import ConfigDict

load_dotenv()


def CrewBase(cls):
class WrappedClass(cls):
model_config = ConfigDict(arbitrary_types_allowed=True)
is_crew_class: bool = True
is_crew_class: bool = True # type: ignore

base_directory = None
for frame_info in inspect.stack():
Expand Down
2 changes: 1 addition & 1 deletion src/crewai/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def _save_file(self, result: Any) -> None:
if directory and not os.path.exists(directory):
os.makedirs(directory)

with open(self.output_file, "w", encoding='utf-8') as file: # type: ignore # Argument 1 to "open" has incompatible type "str | None"; expected "int | str | bytes | PathLike[str] | PathLike[bytes]"
with open(self.output_file, "w", encoding="utf-8") as file: # type: ignore # Argument 1 to "open" has incompatible type "str | None"; expected "int | str | bytes | PathLike[str] | PathLike[bytes]"
file.write(result)
return None

Expand Down
22 changes: 14 additions & 8 deletions src/crewai/tools/agent_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,26 @@ def tools(self):
]
return tools

def delegate_work(self, task: str, context: str, coworker: Union[str, None] = None, **kwargs):
def delegate_work(
self, task: str, context: str, coworker: Union[str, None] = None, **kwargs
):
"""Useful to delegate a specific task to a co-worker passing all necessary context and names."""
coworker = coworker or kwargs.get("co_worker") or kwargs.get("co-worker")
is_list = coworker.startswith("[") and coworker.endswith("]")
if is_list:
coworker = coworker[1:-1].split(",")[0]
if coworker is not None:
is_list = coworker.startswith("[") and coworker.endswith("]")
if is_list:
coworker = coworker[1:-1].split(",")[0]
return self._execute(coworker, task, context)

def ask_question(self, question: str, context: str, coworker: Union[str, None] = None, **kwargs):
def ask_question(
self, question: str, context: str, coworker: Union[str, None] = None, **kwargs
):
"""Useful to ask a question, opinion or take from a co-worker passing all necessary context and names."""
coworker = coworker or kwargs.get("co_worker") or kwargs.get("co-worker")
is_list = coworker.startswith("[") and coworker.endswith("]")
if is_list:
coworker = coworker[1:-1].split(",")[0]
if coworker is not None:
is_list = coworker.startswith("[") and coworker.endswith("]")
if is_list:
coworker = coworker[1:-1].split(",")[0]
return self._execute(coworker, question, context)

def _execute(self, agent, task, context):
Expand Down
59 changes: 59 additions & 0 deletions tests/cli/cli_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from unittest import mock

import pytest
from click.testing import CliRunner

from crewai.cli.cli import train, version


@pytest.fixture
def runner():
return CliRunner()


@mock.patch("crewai.cli.cli.train_crew")
def test_train_default_iterations(train_crew, runner):
result = runner.invoke(train)

train_crew.assert_called_once_with(5)
assert result.exit_code == 0
assert "Training the crew for 5 iterations" in result.output


@mock.patch("crewai.cli.cli.train_crew")
def test_train_custom_iterations(train_crew, runner):
result = runner.invoke(train, ["--n_iterations", "10"])

train_crew.assert_called_once_with(10)
assert result.exit_code == 0
assert "Training the crew for 10 iterations" in result.output


@mock.patch("crewai.cli.cli.train_crew")
def test_train_invalid_string_iterations(train_crew, runner):
result = runner.invoke(train, ["--n_iterations", "invalid"])

train_crew.assert_not_called()
assert result.exit_code == 2
assert (
"Usage: train [OPTIONS]\nTry 'train --help' for help.\n\nError: Invalid value for '-n' / '--n_iterations': 'invalid' is not a valid integer.\n"
in result.output
)


def test_version_command(runner):
result = runner.invoke(version)

assert result.exit_code == 0
assert "crewai version:" in result.output


def test_version_command_with_tools(runner):
result = runner.invoke(version, ["--tools"])

assert result.exit_code == 0
assert "crewai version:" in result.output
assert (
"crewai tools version:" in result.output
or "crewai tools not installed" in result.output
)
87 changes: 87 additions & 0 deletions tests/cli/train_crew_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import subprocess
from unittest import mock

from crewai.cli.train_crew import train_crew


@mock.patch("crewai.cli.train_crew.subprocess.run")
def test_train_crew_positive_iterations(mock_subprocess_run):
# Arrange
n_iterations = 5
mock_subprocess_run.return_value = subprocess.CompletedProcess(
args=["poetry", "run", "train", str(n_iterations)],
returncode=0,
stdout="Success",
stderr="",
)

# Act
train_crew(n_iterations)

# Assert
mock_subprocess_run.assert_called_once_with(
["poetry", "run", "train", str(n_iterations)],
capture_output=False,
text=True,
check=True,
)


@mock.patch("crewai.cli.train_crew.click")
def test_train_crew_zero_iterations(click):
train_crew(0)
click.echo.assert_called_once_with(
"An unexpected error occurred: The number of iterations must be a positive integer.",
err=True,
)


@mock.patch("crewai.cli.train_crew.click")
def test_train_crew_negative_iterations(click):
train_crew(-2)
click.echo.assert_called_once_with(
"An unexpected error occurred: The number of iterations must be a positive integer.",
err=True,
)


@mock.patch("crewai.cli.train_crew.click")
@mock.patch("crewai.cli.train_crew.subprocess.run")
def test_train_crew_called_process_error(mock_subprocess_run, click):
n_iterations = 5
mock_subprocess_run.side_effect = subprocess.CalledProcessError(
returncode=1,
cmd=["poetry", "run", "train", str(n_iterations)],
output="Error",
stderr="Some error occurred",
)
train_crew(n_iterations)

mock_subprocess_run.assert_called_once_with(
["poetry", "run", "train", "5"], capture_output=False, text=True, check=True
)
click.echo.assert_has_calls(
[
mock.call.echo(
"An error occurred while training the crew: Command '['poetry', 'run', 'train', '5']' returned non-zero exit status 1.",
err=True,
),
mock.call.echo("Error", err=True),
]
)


@mock.patch("crewai.cli.train_crew.click")
@mock.patch("crewai.cli.train_crew.subprocess.run")
def test_train_crew_unexpected_exception(mock_subprocess_run, click):
# Arrange
n_iterations = 5
mock_subprocess_run.side_effect = Exception("Unexpected error")
train_crew(n_iterations)

mock_subprocess_run.assert_called_once_with(
["poetry", "run", "train", "5"], capture_output=False, text=True, check=True
)
click.echo.assert_called_once_with(
"An unexpected error occurred: Unexpected error", err=True
)
Loading

0 comments on commit 24ed8a2

Please sign in to comment.