In [None]:
import re

from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    StoppingCriteria,
    StoppingCriteriaList,
    pipeline,
    set_seed,
)

from dataset import get_MATH_dataset, get_GSM8k_dataset
from dataset_generator import (
    generate_completion_dataset,
    generate_copy_dataset,
    generate_corrective_dataset,
    generate_kto_dataset,
)

# Sampling is enabled (do_sample=True in generate_kargs), so seed the Python,
# NumPy and torch RNGs to make sampled generations repeatable across runs
# (up to GPU nondeterminism).
SEED = 42
set_seed(SEED)

# 4-bit NF4 weights with nested ("double") quantization; compute in bfloat16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype="bfloat16",
)
model_path = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_path, quantization_config=quantization_config
)
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Reuse EOS as the pad token (the tokenizer presumably defines none).
tokenizer.pad_token_id = tokenizer.eos_token_id
# Batched generation (batch_size=4 in generate_kargs) with a decoder-only model
# must pad on the left; with right padding the new tokens would be generated
# after a run of pad tokens instead of directly after the prompt.
tokenizer.padding_side = "left"
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)



# Prompt templates, filled with str.format; the doubled braces render as the
# literal "\boxed{}" the model is asked to use for its final answer.

# Plain step-by-step prompt used to sample the initial completions.
prompt = (
    "{problem} \n"
    "Please reason step by step, and put your final answer within \\boxed{{}}"
)

# Shows the model a wrong solution plus the reference answer and asks for a
# corrected, step-by-step solution.
corrective_prompt = (
    "{problem}\n"
    "Here is the incorrect solution of the problem.\n"
    "{incorrect_solution}\n"
    "Here is the answer got from the incorrect solution: {incorrect_answer}, "
    "Here is the correct answer of the problem.\n"
    "{correct_answer}\n"
    "Please based on the incorrect solution and the correct answer, "
    "provide the correct solution of the problem by reasoning step by step "
    "and put your final answer within \\boxed{{}}"
)

# NOTE(review): not used by the cells in this notebook. It only fills
# {problem} and {incorrect_answer}, yet asks the model to follow "the output
# format of the incorrect solution", which it never shows.
corrective_prompt_v2 = (
    "\n"
    "Here is the problem: {problem}\n"
    "Here is the answer got from the incorrect solution: {incorrect_answer}\n"
    "You are a smart and talented math professor, please reason step by step, "
    "and put your final answer within \\boxed{{}} "
    "following the output format of the incorrect solution.\n"
    "Output format:\n"
    "\\boxed{{your answer}}\n"
)
# Number of GSM8K problems pushed through the pipeline; keep it small for quick
# iteration and set it to None to run on the full dataset.
N_PROBLEMS = 2
gsm8k_full = get_GSM8k_dataset()
dataset = gsm8k_full if N_PROBLEMS is None else gsm8k_full.select(range(N_PROBLEMS))

# Few-shot style prompt for generate_copy_dataset: one worked example
# (problem / correct / incorrect solution) terminated by "<End>", followed by
# an open "Problem: " that the model continues in the same format. The "<End>"
# marker is what StoppingCriteriaSub and extract_prompt_correct_incorrect use
# to detect the end of the generated example.
copy_prompt = (
    "\n"
    "Problem: {problem}\n"
    "Correct Solution: {correct_solution}\n"
    "Incorrect Solution: {incorrect_solution}\n"
    "<End>\n"
    "This is the problem and the correct and incorrect solution generated by LLM. "
    "Please generate one similar problem with a similar correct solution "
    "and a similar incorrect solution in the same format.\n"
    "Do not generate any other unnecessary information and only generate these three parts: "
    "problem, correct solution, and incorrect solution, Put <End> at the end of the output.\n"
    "\n"
    "Problem: "
)

class StoppingCriteriaSub(StoppingCriteria):
    """Stop each sequence independently once its text ends with a stop string.

    The default stop string is "<End>", the end-of-output marker that
    ``copy_prompt`` asks the model to emit.

    Args:
        stop_string: Text that marks the end of a generated sequence.
        tokenizer: Tokenizer used to decode token ids. When None, the
            module-level ``tokenizer`` is looked up at call time.
        tail_tokens: Positive number of trailing token ids decoded per
            sequence on each step. Only the end of the text matters, so there
            is no need to re-decode the whole prompt plus completion every step.

    NOTE(review): returns one bool per sequence (a BoolTensor). That is the
    StoppingCriteria contract in transformers >= 4.39; confirm the installed
    version.
    """

    def __init__(self, stop_string="<End>", tokenizer=None, tail_tokens=20):
        super().__init__()
        self.stop_string = stop_string
        self._tokenizer = tokenizer
        self.tail_tokens = tail_tokens

    def __call__(self, input_ids, scores, **kwargs):
        # torch is only needed here, and transformers' generation already depends on it.
        import torch

        tok = self._tokenizer if self._tokenizer is not None else tokenizer
        # Check every row, not just input_ids[0]: in a batch, a single bool would
        # stop (or keep running) all sequences based on row 0 alone.
        finished = [
            tok.decode(sequence[-self.tail_tokens:], skip_special_tokens=True)
            # rstrip: a token such as ">\n" would otherwise hide the marker forever.
            .rstrip()
            .endswith(self.stop_string)
            for sequence in input_ids
        ]
        return torch.tensor(finished, dtype=torch.bool, device=input_ids.device)


# Layout requested by copy_prompt: an optional "Problem:" label, the problem
# text, "Correct Solution: ...", "Incorrect Solution: ...", then "<End>"
# (on its own line or at the end of the last line). Compiled once at import.
_COPY_OUTPUT_PATTERN = re.compile(
    r"\s*(?:Problem:\s*)?(?P<problem>.+?)\s*\n"
    r"Correct Solution:\s*(?P<correct_solution>.+?)\s*\n"
    r"Incorrect Solution:\s*(?P<incorrect_solution>.+?)\s*<End>",
    re.DOTALL,
)


def extract_prompt_correct_incorrect(text, verbose=False):
    """Split a copy_prompt generation into (problem, correct, incorrect).

    Args:
        text: Generated text. ``None`` or ``""`` yields ``None``.
        verbose: If True, print the received text (debugging aid).

    Returns:
        A tuple ``(problem, correct_solution, incorrect_solution)`` of
        whitespace-stripped strings, or ``None`` if the text does not follow
        the expected layout.

    NOTE(review): only the first matching example is returned. If the caller
    passes prompt + completion (full text) rather than the completion alone,
    that first match is the prompt's own seed example; confirm what
    generate_copy_dataset passes in.
    """
    if verbose:
        print("received text:", text)
    if not text:
        return None

    match = _COPY_OUTPUT_PATTERN.search(text)
    if match is None:
        return None
    return tuple(
        match.group(name).strip()
        for name in ("problem", "correct_solution", "incorrect_solution")
    )


def get_answer_from_output(text):
    r"""Return the numeric answer inside the last ``\boxed{...}`` of ``text``.

    Dollar signs (``$`` and ``\$``), thousands separators and surrounding
    whitespace inside the box are ignored, so ``\boxed{\$1,080}`` parses as
    ``1080.0``. Negative and decimal values are accepted.

    Args:
        text: Model output. ``None`` or ``""`` counts as "no answer".

    Returns:
        The answer as a float, or None when there is no box or its content is
        not a plain number (e.g. ``\boxed{\frac{5}{6}}``).
    """
    if not text:
        return None
    # [^{}]* rejects nested braces (\frac, \text), which are not plain numbers anyway.
    boxed_contents = re.findall(r"\\boxed\{([^{}]*)\}", text)
    if not boxed_contents:
        return None
    # The final answer is conventionally the last box; earlier ones may be
    # intermediate results.
    content = boxed_contents[-1]
    for noise in ("\\$", "$", ","):
        content = content.replace(noise, "")
    number = re.fullmatch(r"\s*(-?\d+(?:\.\d+)?)\s*", content)
    return float(number.group(1)) if number else None
    
# Keyword arguments shared by every generation call in this notebook.
# NOTE(review): batch_size is presumably consumed by the pipeline call rather
# than by model.generate; confirm in dataset_generator.
# The stopping criterion watches for the "<End>" marker that copy_prompt asks
# for; the other prompts never request it.
generate_kargs = dict(
    max_new_tokens=1000,
    do_sample=True,
    batch_size=4,
    stopping_criteria=StoppingCriteriaList([StoppingCriteriaSub()]),
)

In [None]:
# Sample generate_count_per_problem completions for each problem in `dataset`;
# get_answer_from_output pulls the boxed answer out of each completion.
# NOTE(review): the "label" column filtered on in the next cell is presumably
# answer-vs-reference correctness; confirm in dataset_generator.
completion_dataset = generate_completion_dataset(
    pipe, dataset, prompt, get_answer_from_output,
    generate_kwargs=generate_kargs,
    generate_count_per_problem=1,
)
completion_dataset[:]

In [None]:
def _is_incorrect(example):
    """Keep only completions whose label is exactly ``False`` (identity check)."""
    return example["label"] is False


# Ask the model to repair each incorrect completion and keep the raw
# generation history for inspection.
# NOTE(review): with only a few problems, incorrect_dataset may be empty.
incorrect_dataset = completion_dataset.filter(_is_incorrect)
corrective_dataset, history = generate_corrective_dataset(
    incorrect_dataset, corrective_prompt, pipe, get_answer_from_output,
    generate_kwargs=generate_kargs,
    corrective_solution_count_per_incorrect_solution=1,
    return_completion_history=True,
)
history

# Example of generate_copy_dataset


In [None]:
# Hand-written one-row seed for generate_copy_dataset: a fraction-addition
# problem with a correct solution and a wrong one that adds numerators and
# denominators separately.
# NOTE(review): unlike the format requested by `prompt`, neither solution
# wraps its answer in \boxed{}, so generated copies will likely omit it too.
example_corrective_dataset = Dataset.from_dict(
    {
        "problem": ["What is the value of $\\frac{1}{2} + \\frac{1}{3}$?"],
        "correct_completion": [
            "$\\frac{1}{2} + \\frac{1}{3} = \\frac{3}{6} + \\frac{2}{6} = \\frac{5}{6}$"
        ],
        "incorrect_completion": [
            "$\\frac{1}{2} + \\frac{1}{3} = \\frac{1+1}{2+3} = \\frac{2}{5}$"
        ],
    }
)

In [None]:
# Generate look-alike (problem, correct, incorrect) examples from the seed row.
# NOTE(review): the positional 2 is presumably the number of new examples per
# seed row; confirm against generate_copy_dataset's signature.
copied_dataset = generate_copy_dataset(
    example_corrective_dataset, copy_prompt, pipe, 2,
    extract_problem_correct_incorrect=extract_prompt_correct_incorrect,
    generate_kwargs=generate_kargs,
)

In [None]:
# Display every row of the copied dataset. If it is a `datasets.Dataset`, `[:]`
# returns a dict mapping each column name to its full list of values.
copied_dataset[:]