In [10]:
import torch
import transformers
from transformers import AutoTokenizer, T5ForConditionalGeneration
import re

In [11]:
# Tokenizer and seq2seq model are both loaded from the same pretrained
# checkpoint name so the two cannot drift apart.
CHECKPOINT = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
model = T5ForConditionalGeneration.from_pretrained(CHECKPOINT)

In [12]:
# Matches a step marker such as "3. " at the start of the string or right
# after whitespace; the captured group is the step number.
_STEP_MARKER = re.compile(r"(?:^|(?<=\s))([0-9]+)\.\s+")


def _split_numbered_line(line):
    """Split a line that begins with "N. " into the step texts it contains.

    The first marker opens the first step. A later marker only starts a new
    step when its number continues the sequence (previous + 1), so text such
    as "cook for 5. Then serve" inside a step is not broken apart.
    """
    pieces = []
    piece_start = None
    next_number = None
    for marker in _STEP_MARKER.finditer(line):
        number = int(marker.group(1))
        if piece_start is not None:
            if number != next_number:
                # A number followed by a dot that does not continue the list:
                # treat it as ordinary text inside the current step.
                continue
            pieces.append(line[piece_start:marker.start()])
        piece_start = marker.end()
        next_number = number + 1
    if piece_start is not None:
        pieces.append(line[piece_start:])
    return pieces


def generate_subtasks(task_description, model, tokenizer, num_subtasks, max_length=500):
    """Ask `model` for a numbered list of steps and return them as strings.

    Args:
        task_description: Text of the task, inserted into the prompt.
        model: Object exposing `.device` and `.generate(**inputs, ...)`.
        tokenizer: Callable that encodes the prompt and exposes `.decode`,
            `.pad_token_id` and `.eos_token_id`. It is not modified.
        num_subtasks: Exact number of entries in the returned list.
        max_length: Maximum number of whitespace-separated words kept per step.

    Returns:
        A list of exactly `num_subtasks` strings. Parsed steps have their
        numbering and trailing periods removed; missing entries are filled
        with the literal "Placeholder", surplus entries are dropped.
    """
    # Simplified prompt without motivation aspect
    prompt = (
        f"Generate a list of {num_subtasks} clear and actionable steps for the task: {task_description}\n\n"
        "Steps:"
    )

    # Use the tokenizer's own pad id when it has one and fall back to EOS
    # otherwise. The tokenizer itself is left untouched: a single prompt needs
    # no padding, so there is no reason to overwrite its pad token.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    with torch.no_grad():
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)

        # `early_stopping` is omitted on purpose: with num_beams=1 there is no
        # beam search for it to act on.
        outputs = model.generate(
            **inputs,
            num_return_sequences=1,
            pad_token_id=pad_token_id,
            no_repeat_ngram_size=2,
            num_beams=1,
            max_new_tokens=256,
        )

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # If the decoded text echoes the prompt, keep only what follows "Steps:".
    steps_section = generated_text.split("Steps:", 1)[-1].strip()
    subtasks = []

    valid_step_regex = r"^[0-9]+\. .+"  # Line starts with a number, a dot and a space

    for line in steps_section.split("\n"):
        line = line.strip()
        if not line or not re.match(valid_step_regex, line):
            continue
        # A single line may hold several steps ("1. a 2. b 3. c") when the
        # decoded text has no newlines, so split on inline markers as well.
        for step_text in _split_numbered_line(line):
            # Limit to max_length words and remove the trailing period
            truncated_step = " ".join(step_text.split()[:max_length]).rstrip(".")
            if truncated_step:
                subtasks.append(truncated_step)

    # Ensure exactly 'num_subtasks' subtasks: pad with placeholders, then trim.
    subtasks.extend(["Placeholder"] * (num_subtasks - len(subtasks)))
    return subtasks[:num_subtasks]

In [13]:
# Request seven steps for the example task using the model and tokenizer
# loaded earlier in the notebook.
subtasks = generate_subtasks(
    task_description="how to cook eggs",
    model=model,
    tokenizer=tokenizer,
    num_subtasks=7,
)

In [14]:
# Show the subtasks as a dash-prefixed bullet list, one per line.
print(*(f"- {subtask}" for subtask in subtasks), sep="\n")

- Placeholder
- Placeholder
- Placeholder
- Placeholder
- Placeholder
- Placeholder
- Placeholder
