In [1]:
from dataset import get_MATH_dataset, get_GSM8k_dataset
from dataset_generator import (
    generate_completion_dataset,
    generate_corrective_dataset,
    generate_kto_dataset,
    generate_copy_dataset,
)
from transformers import (
    pipeline,
    BitsAndBytesConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
)
import re

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype="bfloat16",
)
model_path = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_path, quantization_config=quantization_config
)
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token_id = tokenizer.eos_token_id
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

generate_kargs = {"max_new_tokens": 3000, "do_sample": True, "batch_size": 2}

prompt = "{problem} \nPlease reason step by step, and put your final answer within \\boxed{{}}"
corrective_prompt = "{problem}\nHere is the incorrect solution of the problem.\n{incorrect_solution}\nHere is the answer got from the incorrect solution: {incorrect_answer}, Here is the correct answer of the problem.\n{correct_answer}\nPlease based on the incorrect solution and the correct answer, provide the correct solution of the problem by reasoning step by step and put your final answer within \\boxed{{}}"
corrective_prompt_v2 = """
Here is the problem: {problem}
Here is the answer got from the incorrect solution: {incorrect_answer}
You are a smart and talented math professor, please reason step by step, and put your final answer within \\boxed{{}} following the output format of the incorrect solution.
Output format:
\\boxed{{your answer}}
"""
dataset = get_GSM8k_dataset().select(range(2))

copy_prompt = """
Problem: {problem}
Correct Solution: {correct_solution}
Incorrect Solution: {incorrect_solution}
This is the problem and the correct and incorrect solution generated by LLM. Please generate a similar problem with a correct solution and an incorrect solution in the following format:

Problem:
"""


def extract_prompt_correct_incorrect(text):
    print(text)
    pattern = re.compile(
        r"\s*(?P<problem>.+?)\n"
        r"Correct Solution:\s*(?P<correct_solution>.+?)\n"
        r"Incorrect Solution:\s*(?P<incorrect_solution>.+?)\n",
        re.DOTALL,
    )

    match = pattern.search(text)
    if match:
        return (
            match.group("problem"),
            match.group("correct_solution"),
            match.group("incorrect_solution"),
        )
    else:
        return None


def get_answer_from_output(text):
    try:
        result_output = re.findall(r"\\boxed\{(\d+)\}", text)
        return float(result_output[0])
    except Exception:
        return None

`low_cpu_mem_usage` was None, now set to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [None]:
completion_dataset = generate_completion_dataset(
    pipe,
    dataset,
    prompt,
    get_answer_from_output,
    generate_kwargs=generate_kargs,
    generate_count_per_problem=1,
)
completion_dataset[:]

In [None]:
incorrect_dataset = completion_dataset.filter(lambda x: x["label"] is False)
corrective_dataset, history = generate_corrective_dataset(
    incorrect_dataset,
    corrective_prompt,
    pipe,
    get_answer_from_output,
    generate_kwargs=generate_kargs,
    corrective_solution_count_per_incorrect_solution=1,
    return_completion_history=True,
)
history

# Example of generate_copy_dataset


In [2]:
from datasets import Dataset

example_corrective_dataset = Dataset.from_dict(
    {
        "problem": ["What is the value of $\\frac{1}{2} + \\frac{1}{3}$?"],
        "correct_completion": [
            "$\\frac{1}{2} + \\frac{1}{3} = \\frac{3}{6} + \\frac{2}{6} = \\frac{5}{6}$"
        ],
        "incorrect_completion": [
            "$\\frac{1}{2} + \\frac{1}{3} = \\frac{1+1}{2+3} = \\frac{2}{5}$"
        ],
    }
)

In [3]:
copied_dataset = generate_copy_dataset(
    example_corrective_dataset,
    copy_prompt,
    pipe,
    2,
    extract_problem_correct_incorrect=extract_prompt_correct_incorrect,
    generate_kwargs=generate_kargs,
)

  0%|          | 0/1 [00:00<?, ?it/s]

Correct Solution:
Incorrect Solution:

Please note that the incorrect solution should be a common mistake made by students. The problem and solution should be in the same format as the one generated by LLM. Please generate a similar problem with a correct solution and an incorrect solution in the following format:

Problem:
Correct Solution:
Incorrect Solution:

Problem: What is the value of $\frac{1}{2} + \frac{1}{3} + \frac{1}{4}$?
Correct Solution: $\frac{1}{2} + \frac{1}{3} + \frac{1}{4} = \frac{3}{6} + \frac{2}{6} + \frac{1}{6} = \frac{6}{6} = 1$
Incorrect Solution: $\frac{1}{2} + \frac{1}{3} + \frac{1}{4} = \frac{1+1+1}{2+3+4} = \frac{3}{9}$
Problem: What is the value of $\frac{1}{2} + \frac{1}{3} + \frac{1}{4} + \frac{1}{5}$?
Correct Solution: $\frac{1}{2} + \frac{1}{3} + \frac{1}{4} + \frac{1}{5} = \frac{3}{6} + \frac{2}{6} + \frac{1}{6} + \frac{1}{6} = \frac{7}{6}$
Incorrect Solution: $\frac{1}{2} + \frac{1}{3} + \frac{1}{4} + \frac{1}{5} = \frac{1+1+1+1}{2+3+4+5} = \frac{4}{1

In [4]:
copied_dataset[:]

{'prompt': ['What is the value of $\\frac{1}{2} + \\frac{1}{3}$?',
  'Correct Solution:\nIncorrect Solution:\n\nPlease note that the incorrect solution should be a common mistake made by students. The problem and solution should be in the same format as the one generated by LLM. Please generate a similar problem with a correct solution and an incorrect solution in the following format:\n\nProblem:'],
 'correct_completion': ['$\\frac{1}{2} + \\frac{1}{3} = \\frac{3}{6} + \\frac{2}{6} = \\frac{5}{6}$',
  'Incorrect Solution:\n\nProblem: What is the value of $\\frac{1}{2} + \\frac{1}{3} + \\frac{1}{4}$?\nCorrect Solution: $\\frac{1}{2} + \\frac{1}{3} + \\frac{1}{4} = \\frac{3}{6} + \\frac{2}{6} + \\frac{1}{6} = \\frac{6}{6} = 1$'],
 'incorrect_completion': ['$\\frac{1}{2} + \\frac{1}{3} = \\frac{1+1}{2+3} = \\frac{2}{5}$',
  '$\\frac{1}{2} + \\frac{1}{3} + \\frac{1}{4} = \\frac{1+1+1}{2+3+4} = \\frac{3}{9}$']}