In [1]:
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
# Hugging Face Hub identifier of the instruction-tuned model used throughout this notebook.
checkpoint: str = "HuggingFaceTB/SmolLM2-135M-Instruct"
import re

In [2]:
# Target device for the model and its inputs; change here to run elsewhere.
device: str = "cpu"

In [3]:
# Load the tokenizer and weights for `checkpoint`, then place the model on `device`.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)
model = model.to(device)

In [18]:
def generate_subtasks(task_description, model, tokenizer, num_subtasks, max_new_tokens=256):
    """Ask a causal LM to break a task into a fixed number of numbered steps.

    Only the newly generated tokens are parsed, so text that is part of the
    prompt (including the task description itself) is never mistaken for
    model output.

    Args:
        task_description: Free-text description of the task to decompose.
        model: Causal LM exposing ``.device`` and ``.generate`` (e.g. a
            Hugging Face ``AutoModelForCausalLM``).
        tokenizer: Tokenizer matching ``model``. It is not modified.
        num_subtasks: Number of steps to return; must be >= 0.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        A list of exactly ``num_subtasks`` strings. Output lines of the form
        ``"<number>. text"`` are kept with the numbering stripped; steps beyond
        ``num_subtasks`` are dropped and any shortfall is padded with the
        placeholder ``"Complete the task."``.

    Raises:
        ValueError: If ``num_subtasks`` is negative.
    """
    if num_subtasks < 0:
        raise ValueError(f"num_subtasks must be non-negative, got {num_subtasks}")
    if num_subtasks == 0:
        # Nothing to plan; skip the generation call entirely.
        return []

    prompt = (
        f"You are a task planner. Break down the following task into exactly {num_subtasks} clear and actionable steps. "
        "Each subtask should be practical, specific, and easy to follow. The subtasks should be ordered logically, "
        "and should focus on accomplishing the task in a methodical way. Avoid any filler, general explanations, or placeholders. "
        "The goal is for someone to be able to follow these steps and complete "
        "the task without needing further clarification.\n\n"
        f"Task: {task_description}\n\n"
        "Subtasks:"
    )

    # A single prompt needs no padding, so there is no reason to set (and
    # thereby mutate) the caller's tokenizer.pad_token.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            # No beam search; `early_stopping` only affects beam search, so it is omitted.
            num_beams=1,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            pad_token_id=tokenizer.eos_token_id,
        )

    # The returned sequence begins with the prompt tokens; decode only the continuation.
    completion = tokenizer.decode(output_ids[0][prompt_length:], skip_special_tokens=True)

    # Keep lines that look like "<number>. <text>", stripping the numbering,
    # and stop once enough steps have been collected.
    numbered_step = re.compile(r"^\d+\.\s+")
    subtasks = []
    for raw_line in completion.split("\n"):
        line = raw_line.strip()
        match = numbered_step.match(line)
        if match:
            subtasks.append(line[match.end():].strip())
            if len(subtasks) == num_subtasks:
                break

    # Callers always get a fixed-length list, so any missing steps are padded
    # with a generic placeholder (even though the prompt discourages placeholders).
    subtasks.extend(["Complete the task."] * (num_subtasks - len(subtasks)))
    return subtasks

In [19]:
# Break a sample task into eight ordered steps using the loaded model.
subtasks = generate_subtasks(
    task_description="plan wedding",
    model=model,
    tokenizer=tokenizer,
    num_subtasks=8,
)

In [20]:
# Show the plan as a 1-based numbered list.
for step_number, step_text in enumerate(subtasks, start=1):
    print(f"{step_number}. {step_text}")

1. Research wedding venues
2. Create a guest list
3. Choose a date and time for the wedding
4. Book a caterer
5. Arrange for flowers and decorations
6. Plan the ceremony and reception
7. Send out invitations
8. Prepare for last-minute changes
