# task

> A `Task` model with status tracking, sub-task management, and event listeners

In [None]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
#| default_exp task

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import uuid
from enum import Enum
from typing import Optional, List, Union

from pydantic import BaseModel, Field, PrivateAttr

from iagent.core import FunctionInput, FunctionResult

class TaskStatus(str, Enum):
    """
    Lifecycle state stored on a `Task`.

    Mixing in `str` makes every member compare equal to its string value.
    """

    PENDING = "pending"  # default status of a newly created Task
    STARTING = "starting"  # set by Task.start()
    RUNNING = "running"  # set by Task.run()
    COMPLETED = "completed"  # set by Task.complete()
    FAILED = "failed"  # set by Task.fail()

class TaskEvent(str, Enum):
    """
    Kinds of notifications a `Task` sends to its listeners.

    Mixing in `str` makes every member compare equal to its string value.
    """

    STARTING = "starting"  # emitted by Task.start()
    COMPLETED = "completed"  # emitted by Task.complete()
    FAILED = "failed"  # emitted by Task.fail(); listeners also get `reason=`
    RUNNING = "running"  # emitted by Task.run()
    SUB_TASK_ADDED = "sub_task_added"  # listeners also get `sub_task=`
    SUB_TASK_REMOVED = "sub_task_removed"  # listeners also get `sub_task=`

class TaskListener:
    """
    Base class for objects that observe a `Task`.

    Subclass it and override `on_event`; the default implementation
    ignores every event.
    """

    def on_event(self, task: 'Task', event: TaskEvent, **kwargs):
        """
        Handle `event` emitted by `task`.

        Event-specific details arrive as keyword arguments, e.g. `reason`
        for FAILED or `sub_task` for the sub-task events.
        """
        return None

class Task(BaseModel):
    """
    Task is a self-contained unit of work, i.e. it doesn't depend on further
    context, but it may have a list of sub-tasks.

    The lifecycle methods (`start`, `run`, `complete`, `fail`) update `status`
    and notify every registered `TaskListener` with the matching `TaskEvent`.
    """
    id: str = Field(..., description="Unique identifier for the task")
    name: Optional[str] = Field(None, description="Name of the task")
    instruction: str = Field(..., description="Instruction for the task")
    input: Optional[FunctionInput] = Field(None, description="Input for the task")
    output: Optional[FunctionResult] = Field(None, description="Output for the task, could be error")
    status: TaskStatus = Field(TaskStatus.PENDING, description="Status of the task")
    sub_tasks: Optional[list['Task']] = Field(None, description="List of sub-tasks")
    # Non-field, per-instance listener registry. `default_factory` gives each
    # Task its own list rather than a mutable class-level default that could
    # be shared between instances.
    _listeners: List[TaskListener] = PrivateAttr(default_factory=list)

    def add_listener(self, listener: TaskListener):
        """Register `listener` to receive this task's events."""
        self._listeners.append(listener)

    def remove_listener(self, listener: TaskListener):
        """Unregister `listener`; does nothing if it was never registered."""
        if listener in self._listeners:
            self._listeners.remove(listener)

    def _notify(self, event: TaskEvent, **kwargs):
        """Call `on_event(self, event, **kwargs)` on every registered listener."""
        # Iterate over a snapshot so a listener that adds/removes listeners
        # inside its callback cannot cause others to be skipped.
        for listener in list(self._listeners):
            listener.on_event(self, event, **kwargs)

    def start(self):
        """Set status to STARTING and emit `TaskEvent.STARTING`."""
        self.status = TaskStatus.STARTING
        self._notify(TaskEvent.STARTING)

    def run(self):
        """Set status to RUNNING and emit `TaskEvent.RUNNING`."""
        self.status = TaskStatus.RUNNING
        self._notify(TaskEvent.RUNNING)

    def complete(self, output: Optional[Union[str, BaseModel]] = None):
        """
        Mark the task COMPLETED, store a successful result wrapping `output`,
        and emit `TaskEvent.COMPLETED`.
        """
        self.status = TaskStatus.COMPLETED
        # success=True distinguishes completion from `fail`, which records
        # success=False.
        self.output = FunctionResult(success=True, output=output)
        self._notify(TaskEvent.COMPLETED)

    def fail(self, reason: str = ""):
        """
        Mark the task FAILED, store an unsuccessful result carrying `reason`,
        and emit `TaskEvent.FAILED` with `reason=` passed to listeners.
        """
        self.status = TaskStatus.FAILED
        self.output = FunctionResult(success=False, error=reason)
        self._notify(TaskEvent.FAILED, reason=reason)

    def add_sub_task(self, task: 'Task'):
        """
        Append `task` to `sub_tasks` (creating the list on first use) and
        emit `TaskEvent.SUB_TASK_ADDED` with `sub_task=task`.
        """
        if self.sub_tasks is None:
            self.sub_tasks = []
        self.sub_tasks.append(task)
        self._notify(TaskEvent.SUB_TASK_ADDED, sub_task=task)

    def remove_sub_task(self, task: 'Task'):
        """
        Remove `task` from `sub_tasks` and emit `TaskEvent.SUB_TASK_REMOVED`
        with `sub_task=task`.

        Does nothing if there are no sub-tasks yet. Raises ValueError (from
        `list.remove`) if the list exists but does not contain `task`.
        """
        if self.sub_tasks is not None:
            self.sub_tasks.remove(task)
            self._notify(TaskEvent.SUB_TASK_REMOVED, sub_task=task)

    @classmethod
    def from_instruction(cls, instruction: str, id: Optional[str] = None, **kwargs):
        """
        Create a Task from instruction and optional id. If id is not provided,
        a random UUID4 string is generated. Additional fields can be passed
        via kwargs.
        """
        if id is None:
            id = str(uuid.uuid4())
        return cls(id=id, instruction=instruction, **kwargs)



In [None]:
# Test cases for Task class

def test_task_creation():
    """`from_instruction` sets the given fields, defaults the status, and generates a string id."""
    created = Task.from_instruction("Test instruction", name="Test Task")
    expected = {
        "instruction": "Test instruction",
        "name": "Test Task",
        "status": TaskStatus.PENDING,
    }
    for attr, value in expected.items():
        assert getattr(created, attr) == value
    assert isinstance(created.id, str)

def test_add_subtask():
    """Adding a sub-task creates the sub-task list and appends to it."""
    parent = Task.from_instruction("Parent task")
    child = Task.from_instruction("Sub task")

    parent.add_sub_task(child)

    children = parent.sub_tasks
    assert children is not None
    assert len(children) == 1
    assert children[0].instruction == "Sub task"

def test_remove_subtask():
    """Removing one of two sub-tasks leaves only the other one."""
    parent = Task.from_instruction("Parent task")
    first = Task.from_instruction("Sub task 1")
    second = Task.from_instruction("Sub task 2")
    for sub in (first, second):
        parent.add_sub_task(sub)

    parent.remove_sub_task(first)

    remaining = [sub.instruction for sub in parent.sub_tasks]
    assert remaining == ["Sub task 2"]

def test_start_and_run_task():
    """`start()` and `run()` move a new task through STARTING and RUNNING."""
    task = Task.from_instruction("Start and run")
    assert task.status == TaskStatus.PENDING
    # Call the lifecycle methods rather than assigning `status` directly, so
    # the methods themselves are what's under test.
    task.start()
    assert task.status == TaskStatus.STARTING
    task.run()
    assert task.status == TaskStatus.RUNNING

def test_task_failed():
    """`fail()` sets FAILED and records an unsuccessful result carrying the reason."""
    task = Task.from_instruction("Failing task")
    task.fail("boom")
    assert task.status == TaskStatus.FAILED
    assert task.output is not None
    assert not task.output.success
    assert task.output.error == "boom"

# Executing the notebook (e.g. via nbdev_test) only runs cells; it does not
# collect `test_*` functions, so invoke them explicitly.
test_task_creation()
test_add_subtask()
test_remove_subtask()
test_start_and_run_task()
test_task_failed()

In [None]:
# Example listener for testing
class TestListener(TaskListener):
    """Test helper that records every `(event, kwargs)` pair it receives, in order."""

    def __init__(self):
        self.events = []

    def on_event(self, task, event, **kwargs):
        self.events += [(event, kwargs)]

def test_task_event_notifications():
    """Each lifecycle method emits its matching TaskEvent to registered listeners."""
    task = Task.from_instruction("Test with events")
    recorder = TestListener()
    task.add_listener(recorder)

    # (method to call, positional args, event expected afterwards)
    steps = [
        (task.start, (), TaskEvent.STARTING),
        (task.run, (), TaskEvent.RUNNING),
        (task.complete, ("done",), TaskEvent.COMPLETED),
        (task.fail, ("error",), TaskEvent.FAILED),
    ]
    for action, args, expected in steps:
        action(*args)
        latest_event, _ = recorder.events[-1]
        assert latest_event == expected

def test_subtask_event_notifications():
    """Adding/removing a sub-task emits the matching event with `sub_task=` set."""
    parent = Task.from_instruction("Parent")
    sub = Task.from_instruction("Sub")
    listener = TestListener()
    parent.add_listener(listener)

    parent.add_sub_task(sub)
    event, details = listener.events[-1]
    assert event == TaskEvent.SUB_TASK_ADDED
    assert details['sub_task'] == sub

    parent.remove_sub_task(sub)
    event, details = listener.events[-1]
    assert event == TaskEvent.SUB_TASK_REMOVED
    assert details['sub_task'] == sub

# Executing the notebook only runs cells; it does not collect `test_*`
# functions, so invoke this cell's tests explicitly.
test_task_event_notifications()
test_subtask_event_notifications()


In [None]:
#| hide
import nbdev

nbdev.nbdev_export()