# Creating a new task in Agent

This notebook walks you through every step required to create a task from scratch, so that you can do the same for your custom environments. As a running example, we build everything needed for the GSM8K task. GSM8K is already part of the framework, but we recreate it from scratch so that you can follow the same steps for your own tasks.

Note: before you start, make sure you have installed all of Agent's dependencies by following the installation guide.

## Step 1: Create prompt templates

Create a new directory under `../src/agent/prompts/templates` to hold the custom prompt templates for your new task. These `.jinja` files contain the boilerplate for the messages sent to the LLM. Here we create the templates needed for direct prompting on GSM8K.

First, we need a system prompt. It is passed to the LLM as the system message and should define the task:

In [None]:
import os

template_dir = "../src/agent/prompts/templates/example_gsm8k/"
# Create the template directory (no error if it already exists)
os.makedirs(template_dir, exist_ok=True)


system_prompt = """
[[ SYSTEM ]]
You are an expert mathematician. You are provided with mathematical questions that you have to answer.
When asked for an answer, your response should use the following format:
Answer: <answer>
"""

with open(f"{template_dir}system_prompt.jinja", "w") as file:
    file.write(system_prompt)

Next, we define a template for the prompting method we are using, here direct prompting. It should tell the LLM what to answer and in what format:

In [None]:
direct_prompt = """
Now please answer the question.
Answer in the format
Answer: <answer>
"""

with open(f"{template_dir}direct_prompt.jinja", "w") as file:
    file.write(direct_prompt)

Finally, we want to define how the trajectory of past observations (or possibly other content) should be presented to the LLM:

In [None]:
trajectory_prompt = """
Question: {{memory.retrieve(memory.mem_keys.OBSERVATION)}}
"""

with open(f"{template_dir}trajectory.jinja", "w") as file:
    file.write(trajectory_prompt)

The `../src/agent/prompts/templates/default/` folder contains default templates that cover many tasks. Here, the default `external_action.jinja` template provides all the remaining boilerplate we need: when prompting the LLM for an action, it pulls in the direct prompt and trajectory templates defined above.

## Step 2: Create a new Agent task

Here we create a new class that inherits from Agent's `Task` class and implements at least the following methods:
- `reset()` - to reset the task and return the first observation
- `step(action)` - to take a step in the task when given an action, returning an observation, a reward, and a `done` boolean
- `answer_parser(raw_response)` - to parse the actual answer content needed from the raw LLM response
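To make the contract between these three methods concrete, here is a minimal sketch of how an episode loop could drive them. Everything in it is hypothetical: `ToyArithmeticTask` stands in for a real `Task` subclass (plain string keys instead of `MemKey`, a hard-coded question instead of a dataset), `fake_llm` replaces the model, and `run_episode` is only an illustration, not the runner used by `start.py`.

```python
import math
import re
from typing import Callable, Dict, Tuple


class ToyArithmeticTask:
    """Hypothetical stand-in for an Agent Task with a single hard-coded question."""

    def __init__(self) -> None:
        self.answer = 0.0

    def reset(self) -> Dict[str, str]:
        # Return the first observation (the real task keys it by MemKey.OBSERVATION).
        self.answer = 12.0
        return {"observation": "What is 3 times 4?"}

    def answer_parser(self, raw_response: str) -> str:
        # Keep only the last number in the raw LLM text.
        numbers = re.findall(r"[-+]?\d*\.?\d+", raw_response.replace(",", ""))
        return numbers[-1] if numbers else ""

    def step(self, action: str) -> Tuple[dict, float, bool]:
        # Single-step episode: score the answer, then signal done=True.
        try:
            reward = 1.0 if math.isclose(float(action), self.answer) else 0.0
        except ValueError:
            reward = 0.0
        return {}, reward, True


def run_episode(task: ToyArithmeticTask, llm: Callable[[Dict[str, str]], str]) -> float:
    """Hypothetical loop: reset, then query -> parse -> step until the task is done."""
    observation = task.reset()
    total_reward, done = 0.0, False
    while not done:
        raw_response = llm(observation)
        action = task.answer_parser(raw_response)
        observation, reward, done = task.step(action)
        total_reward += reward
    return total_reward


def fake_llm(observation: Dict[str, str]) -> str:
    # Hypothetical LLM stub that always replies in the "Answer: <answer>" format.
    return "Answer: 12"


print(run_episode(ToyArithmeticTask(), fake_llm))  # 1.0
```

Because GSM8K episodes consist of a single question, `step` always returns `done=True`, so the loop above runs exactly once; multi-step environments would return `False` until the episode ends.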

In [None]:
%%writefile ./../src/agent/tasks/example_gsm8k.py
import math
import re
from typing import Any, Dict
from datasets import load_dataset
from agent.memory import MemKey
from agent.tasks import ActionSpace
from agent.tasks import DatasetOutOfBoundsException
from agent.tasks import Task


class GSM8K(Task):
    def __init__(self, split: str, **kwargs):
        super().__init__(**kwargs)

        self.action_space = ActionSpace.CONTINUOUS
        self.args = kwargs
        self.dataset = load_dataset("gsm8k", "main", split=split)
        self.episode_counter = 0

    def reset(self, next_subtask: str | None = None) -> Dict[str, str]:
        """Reset the environment and return the initial observation."""

        if next_subtask is not None:
            self.episode_counter = int(next_subtask)

        # Valid indices run from 0 to len(self.dataset) - 1.
        if self.episode_counter >= len(self.dataset):
            raise DatasetOutOfBoundsException(
                "The dataset index is not within dataset bounds. The end of the dataset may have been reached."
            )

        data = self.dataset[self.episode_counter]
        self.answer = float(data["answer"].split("\n####")[-1].replace(",", ""))
        return self._return_observation(data)

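    def answer_parser(self, raw_response: str) -> str:
        # Take the last number in the response, ignoring thousands separators
        # (e.g. "Answer: 1,234" -> "1234"); return "" if there is none.
        try:
            proposed_answer = re.findall(r"[-+]?\d*\.?\d+", raw_response.replace(",", ""))[-1]
        except IndexError:
            proposed_answer = ""
        return proposed_answer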

    def step(self, action: str) -> tuple[dict, float, bool]:
        """Perform an action and return the next observation, reward, and done."""

        try:
            reward = 1 if math.isclose(float(action), self.answer) else 0
        except Exception:
            reward = 0
        self.episode_counter += 1
        return {}, reward, True

    def _return_observation(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Return the observation for the current step."""

        return {MemKey.OBSERVATION: data["question"]}
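Before running the full pipeline, it can help to check the parsing and scoring rules in isolation. The sketch below copies the logic of `reset` (reference-answer extraction), `answer_parser`, and `step` into standalone functions so it runs without the framework or the dataset; the names `parse_reference_answer`, `parse_proposed_answer`, and `score` are hypothetical, and the sample answer string is made up to mimic the GSM8K format rather than taken from the dataset.

```python
import math
import re


def parse_reference_answer(answer_field: str) -> float:
    # GSM8K reference solutions end with a line "#### <number>"; keep what follows it.
    return float(answer_field.split("\n####")[-1].replace(",", ""))


def parse_proposed_answer(raw_response: str) -> str:
    # Take the last number in the LLM response, after stripping thousands separators.
    numbers = re.findall(r"[-+]?\d*\.?\d+", raw_response.replace(",", ""))
    return numbers[-1] if numbers else ""


def score(action: str, reference: float) -> int:
    # 1 for a numerically matching answer, 0 otherwise (including unparsable text).
    try:
        return 1 if math.isclose(float(action), reference) else 0
    except ValueError:
        return 0


# Made-up sample in the GSM8K answer format (not taken from the dataset).
sample_answer = "3 * 4 = 12 apples.\n12 + 1,000 = 1,012\n#### 1,012"
reference = parse_reference_answer(sample_answer)
action = parse_proposed_answer("Adding everything up gives 1,012.\nAnswer: 1,012")
print(reference, action, score(action, reference))  # 1012.0 1012 1
print(score(parse_proposed_answer("I am not sure."), reference))  # 0
```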

## Step 3: Create a task configuration

Create a new configuration file under `../configs/task/` that defines the task parameters and lists, in priority order, the prompt template directories to use (see [Step 1](#step-1-create-prompt-templates)).

In [None]:
config_data = """
# @package _global_
agent:
  prompt_builder:
    template_paths:
      - example_gsm8k
      - default

task:
  _target_: src.agent.tasks.example_gsm8k.GSM8K
  name: gsm8k
  subtask: null
  version: v0.1
  description:
  split: test

max_episodes: 3
"""

with open("../configs/task/example_gsm8k.yaml", "w") as file:
    file.write(config_data)
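The `template_paths` list is read in priority order. A minimal sketch of the assumed semantics is a first-match-wins lookup: each template name is resolved against the task-specific folder first and falls back to `default/` only if it is missing there. The function `resolve_template` and the directory contents below are hypothetical, and the framework's actual prompt builder may resolve templates differently. (Jinja2's own `FileSystemLoader` also accepts a list of search paths and searches them in order.)

```python
from typing import Dict, List, Set


def resolve_template(name: str, template_paths: List[str], available: Dict[str, Set[str]]) -> str:
    """Return '<directory>/<name>' for the first directory in template_paths that has `name`."""
    for directory in template_paths:
        if name in available.get(directory, set()):
            return f"{directory}/{name}"
    raise FileNotFoundError(f"{name!r} not found in any of {template_paths}")


# Hypothetical directory contents; the real default/ folder may hold other files.
available = {
    "example_gsm8k": {"system_prompt.jinja", "direct_prompt.jinja", "trajectory.jinja"},
    "default": {"system_prompt.jinja", "external_action.jinja"},
}
template_paths = ["example_gsm8k", "default"]  # task-specific folder first, then the defaults

print(resolve_template("system_prompt.jinja", template_paths, available))    # example_gsm8k/system_prompt.jinja
print(resolve_template("external_action.jinja", template_paths, available))  # default/external_action.jinja
```

Under these assumptions, a task folder only needs to contain the templates it overrides; everything else comes from `default/`.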

## Step 4: Running the task

Assuming you use an existing LLM configuration (say, OpenChat-3.5) and an existing prompting method (say, direct prompting), you can run the task as follows. The `task` argument must be the name of the config file defined in [Step 3](#step-3-create-a-task-configuration).

In [None]:
!python ../src/agent/start.py task=example_gsm8k method=direct llm@agent.llm=hf/openchat_3.5

## Clean up files

Here we clean up the files created by this notebook to avoid cluttering the framework.

In [None]:
delete_files = [
    f"{template_dir}system_prompt.jinja",
    f"{template_dir}direct_prompt.jinja",
    f"{template_dir}trajectory.jinja",
    "../src/agent/tasks/example_gsm8k.py",
    "../configs/task/example_gsm8k.yaml",
]

for file in delete_files:
    try:
        os.remove(file)
    except OSError as e:
        print(e)

try:
    os.rmdir(template_dir)
except OSError as e:
    print(e)
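If you prefer a quieter cleanup, the sketch below is a standard-library alternative using `pathlib`: `Path.unlink(missing_ok=True)` (Python 3.8+) skips files that were never created, and the directory is removed only if it ends up empty, so nothing else in it is lost. The helper name `clean_up` is hypothetical, and the demo runs in a temporary directory rather than on the framework paths.

```python
import tempfile
from pathlib import Path
from typing import Iterable


def clean_up(files: Iterable[str], directory: str) -> None:
    """Delete the generated files, then the template directory if it ended up empty."""
    for file in files:
        # missing_ok=True suppresses FileNotFoundError for files that were never created.
        Path(file).unlink(missing_ok=True)
    template_dir = Path(directory)
    # Only remove the directory when it is empty, so unrelated files are never lost.
    if template_dir.is_dir() and not any(template_dir.iterdir()):
        template_dir.rmdir()


# Demo in a throwaway directory rather than the real framework paths.
with tempfile.TemporaryDirectory() as root:
    demo_dir = Path(root) / "example_gsm8k"
    demo_dir.mkdir()
    (demo_dir / "system_prompt.jinja").write_text("[[ SYSTEM ]]\n")
    clean_up([str(demo_dir / "system_prompt.jinja"), str(demo_dir / "direct_prompt.jinja")], str(demo_dir))
    print(demo_dir.exists())  # False
```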