In [1]:
def split_rationale_and_final_answer(generated_text: str):
    """
    Split an LLM response into its rationale and final-answer sections.

    The expected layout is:
      "Step-by-step reasoning:\n...some rationale...\nFinal Answer:\n...the answer..."

    Args:
        generated_text: Raw model output. ``None`` or an empty string is
            treated as "no markers found".

    Returns:
        A ``(rationale, final_answer)`` tuple of stripped strings. A section
        whose marker is missing is returned as an empty string.

    Behaviour on unusual inputs:
      * The "Final Answer:" marker is looked up *after* the rationale marker
        first, so an earlier stray occurrence (e.g. an echoed prompt) does not
        hide the real answer.
      * If the answer marker directly follows the rationale marker, the
        rationale is empty rather than containing the answer section.
      * If the answer section only appears *before* the rationale section,
        the answer stops at the rationale marker instead of swallowing it.
    """
    # Label markers used in the STaR prompts
    rationale_marker = "Step-by-step reasoning:"
    answer_marker = "Final Answer:"

    # Nothing to parse: fall back to the documented empty defaults
    if not generated_text:
        return "", ""

    # Drop carriage returns so Windows-style line endings don't leak into output
    text = generated_text.replace("\r", "")

    rationale_idx = text.find(rationale_marker)

    if rationale_idx == -1:
        # No rationale section at all: only the final answer (if any) is returned
        answer_idx = text.find(answer_marker)
        if answer_idx == -1:
            return "", ""
        return "", text[answer_idx + len(answer_marker):].strip()

    rationale_start = rationale_idx + len(rationale_marker)

    # Preferred layout: the answer marker sits at or after the end of the
    # rationale marker. Searching from rationale_start (inclusive) also covers
    # the case where the two markers are directly adjacent.
    answer_idx = text.find(answer_marker, rationale_start)
    if answer_idx != -1:
        rationale = text[rationale_start:answer_idx].strip()
        final_ans = text[answer_idx + len(answer_marker):].strip()
        return rationale, final_ans

    # No answer after the rationale: everything after the marker is rationale
    rationale = text[rationale_start:].strip()

    # Reversed layout: an answer section that precedes the rationale section
    # ends where the rationale marker begins.
    early_answer_idx = text.find(answer_marker, 0, rationale_idx)
    if early_answer_idx == -1:
        return rationale, ""
    final_ans = text[early_answer_idx + len(answer_marker):rationale_idx].strip()
    return rationale, final_ans


In [2]:
# Example LLM response laid out as a rationale section followed by a final answer.
response_text = "\n".join(
    [
        "Step-by-step reasoning:",
        "First, realize it is in Paris.",
        "Check that it's in France.",
        "Final Answer:",
        "The Eiffel Tower is in Paris, France.",
    ]
)

# Split the response at its two markers and show each piece.
rationale_part, final_part = split_rationale_and_final_answer(response_text)
print("Rationale:", rationale_part)
print("Final:", final_part)


Rationale: First, realize it is in Paris.
Check that it's in France.
Final: The Eiffel Tower is in Paris, France.


In [11]:
def explode_answers_into_rows(
    examples,
    question_col="question",
    answers_col="model_outputs",
    id_col="id",
    new_id_col="new_id",
    new_question_col="question",
    new_answer_col="answer"
):
    """
    Flatten a batch so that every candidate answer gets a row of its own.

    `examples` is a column-oriented batch (a dictionary of lists), e.g.:
       {
         "id": [id1, id2, ...],
         "question": [q1, q2, ...],
         "model_outputs": [[ans11, ans12], [ans21, ans22, ans23], ...]
       }
    which is the shape a batched mapping function receives.

    For each input row, one output row is produced per answer. The new id is
    the stringified original id followed by "_<answer position>", and the
    question is repeated for every answer. An answers entry that is not a
    list is treated as a single answer.

    Returns a dict of equally long lists:
       {
         new_id_col: [...],
         new_question_col: [...],
         new_answer_col: [...]
       }
    """
    # Input columns for the whole batch
    ids = examples[id_col]
    questions = examples[question_col]
    answer_groups = examples[answers_col]

    # Accumulators for the flattened columns
    exploded_ids, exploded_questions, exploded_answers = [], [], []

    for row, raw_id in enumerate(ids):
        base_id = str(raw_id)
        question = questions[row]
        candidates = answer_groups[row]

        # A lone (non-list) value counts as a single answer
        if not isinstance(candidates, list):
            candidates = [candidates]

        # One output row per candidate, all sharing the same question
        count = len(candidates)
        exploded_ids.extend(f"{base_id}_{position}" for position in range(count))
        exploded_questions.extend([question] * count)
        exploded_answers.extend(candidates)

    return {
        new_id_col: exploded_ids,
        new_question_col: exploded_questions,
        new_answer_col: exploded_answers,
    }


In [13]:

# 3) Apply the explode function directly (a plain call with the default
#    column names); the result is a plain dict of lists.
# NOTE(review): `original_dataset` is not defined in any cell visible here —
#    it presumably comes from an earlier cell / leftover kernel state; confirm
#    it exists before relying on Restart & Run All.
exploded_dataset = explode_answers_into_rows(original_dataset)

print("Exploded Dataset:")
print(exploded_dataset)


Exploded Dataset:
{'new_id': ['rowA_0', 'rowA_1', 'rowA_2', 'rowB_0', 'rowB_1'], 'question': ['Where is the Eiffel Tower?', 'Where is the Eiffel Tower?', 'Where is the Eiffel Tower?', 'Who was the first president of the US?', 'Who was the first president of the US?'], 'answer': ['Paris', 'France', 'On Earth', 'George Washington', 'John Adams']}


In [15]:
# Wrap the exploded dict of lists in a Dataset; the result is only displayed
# as the cell output, not bound to a name.
# NOTE(review): `Dataset` is not imported in any visible cell — presumably
# Hugging Face `datasets.Dataset`; confirm the import exists upstream.
Dataset.from_dict(exploded_dataset)

Dataset({
    features: ['new_id', 'question', 'answer'],
    num_rows: 5
})

In [None]:

    # 4) Verify the result
    #    We expect 3 + 2 = 5 rows total
    assert len(exploded_dataset) == 5, "Expected 5 rows in exploded dataset."
    #    Let's check a few row samples:
    row0 = exploded_dataset[0]
    assert row0["new_id"] == "rowA_0", "Unexpected ID for first exploded row."
    assert row0["question"] == "Where is the Eiffel Tower?", "Question mismatch"
    assert row0["answer"] == "Paris", "Answer mismatch"

    row3 = exploded_dataset[3]  # first row of second question
    assert row3["new_id"] == "rowB_0", "Unexpected ID for rowB_0."
    assert row3["answer"] == "George Washington", "Answer mismatch for rowB_0"

    print("All tests passed! :)")


In [2]:
# Run the explode test.
# NOTE(review): the saved output of this cell ends in a TypeError reported
# during the "Map" step (the mapped function returned a list instead of a
# dict), so this call does not currently succeed. The `def` line of
# `test_explode_answers_into_rows` is not visible in this notebook export —
# verify the definition before re-running.
test_explode_answers_into_rows()


Original Dataset:
Dataset({
    features: ['id', 'question', 'model_outputs'],
    num_rows: 2
})
     id                                question  \
0  rowA              Where is the Eiffel Tower?   
1  rowB  Who was the first president of the US?   

                     model_outputs  
0        [Paris, France, On Earth]  
1  [George Washington, John Adams]   



Map:   0%|          | 0/2 [00:00<?, ? examples/s]

TypeError: Provided `function` which is applied to all elements of table returns a variable of type <class 'list'>. Make sure provided `function` returns a variable of type `dict` (or a pyarrow table) to update the dataset or `None` if you are only interested in side effects.