In [2]:
from IPython.display import display, Markdown
import polars as pl
import mistune
from tqdm import tqdm
from pprint import pprint
import mistune.renderers
import mistune.renderers.markdown
from collections import defaultdict
from typing import Literal
from datasets import load_dataset
from typing import List
import re
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
from typing import Sequence

  from .autonotebook import tqdm as notebook_tqdm


In [25]:
# Load the CodeContests chain-of-thought samples; the extraction loop further down reads
# the "problem", "name" and "completions" columns of this frame.
cot_sft_dataset = pl.read_parquet("../codecontests_cot_sft_v2.parquet")

In [None]:
# Bare expression: shows the loaded frame as this cell's output.
cot_sft_dataset

In [None]:

# Which model's completion layout the extraction loop should parse; it picks the
# "deepseek" or "gpt-4o" branch inside that loop.
ModelChoice = Literal["gpt-4o", "deepseek"]

model_choice: ModelChoice = "gpt-4o"

# With renderer=None the parser's result is consumed below as a list of token dicts
# (each with a "type" key and optionally "children" / "raw").
markdown_renderer = mistune.create_markdown(renderer=None)

# How often each matching steps-heading text was seen (only updated in the deepseek branch).
heading_freqs = defaultdict(int)

# Number of steps headings matched (only updated in the deepseek branch).
n_with_steps = 0


def recursive_get_all_children(element, children_list, disallowed_types=()):
    """Collect the leaf tokens of a parsed-markdown subtree in document order.

    Args:
        element: A token dict. A token that has a "children" key is descended
            into; any token without one is treated as a leaf.
        children_list: List the leaves are appended to. It is mutated in place.
        disallowed_types: Token "type" values whose entire subtree is skipped.
            Any container supporting ``in`` works (list, tuple, set). The
            default is an immutable empty tuple so no mutable object is shared
            between calls.

    Returns:
        The same ``children_list`` object that was passed in.
    """
    if "type" in element and element["type"] in disallowed_types:
        return children_list
    if "children" in element:
        for child in element["children"]:
            recursive_get_all_children(child, children_list, disallowed_types)
    else:
        children_list.append(element)
    return children_list


# Extract (problem, reasoning steps, first code block) from every completion.
# A row is emitted only once both a code block and at least one thought were found;
# all other rows are dropped.
cot_formatted_rows = []

for i in tqdm(range(len(cot_sft_dataset))):
    # Slice the one-row frame once instead of once per column.
    row_frame = cot_sft_dataset[i]
    problem_str = row_frame["problem"][0]
    problem_name = row_frame["name"][0]
    try:
        solution_str = row_frame["completions"][0][0]
    except Exception:
        # Missing or empty completions. `Exception` (not a bare except) so that
        # interrupting the kernel still stops the loop.
        print("warn: no solution")
        continue
    solution_md = markdown_renderer(solution_str)
    thoughts = []
    solution_code = None
    # `element_idx` is deliberately distinct from the outer row index `i`.
    for element_idx, element in enumerate(solution_md):
        # Deepseek has a preamble, then the steps; 4o prints a paragraph then the steps, so the first
        # list is always the steps
        if model_choice == "deepseek":
            if element["type"] == "heading":
                # NOTE(review): assumes the heading's first child is a plain text token with
                # "raw"; a heading that starts with e.g. bold text would raise here — confirm.
                heading_text = element["children"][0]["raw"].lower()
                required_phrases = ["reasoning", "steps", "approach"]
                has_steps = (
                    any(phrase in heading_text for phrase in required_phrases)
                    and "code" not in heading_text
                )
                if has_steps:
                    n_with_steps += 1
                    heading_freqs[heading_text] += 1
                    # The steps list follows the heading, optionally after one blank line.
                    steps_list_idx = element_idx + 1
                    if (
                        steps_list_idx < len(solution_md)
                        and solution_md[steps_list_idx]["type"] == "blank_line"
                    ):
                        steps_list_idx += 1
                    if steps_list_idx < len(solution_md):
                        steps = solution_md[steps_list_idx].get("children", [])
                    else:
                        # Heading was the last element: there is nothing to read.
                        steps = []
                    for step in steps:
                        # Child 0 is the "<prefix>:" text, child 1 holds the nested bullet points.
                        if "children" not in step or len(step["children"]) < 2:
                            print(step)
                            continue
                        bullet_point_text = step["children"][1]
                        if "children" not in bullet_point_text:
                            print(bullet_point_text)
                            continue
                        for sub_step in bullet_point_text["children"]:
                            for sub_sub_step in sub_step.get("children", []):
                                all_text = recursive_get_all_children(sub_sub_step, [])
                                if any("raw" not in x for x in all_text):
                                    print(all_text)
                                    continue
                                thoughts.append("".join(t["raw"] for t in all_text))
        elif model_choice == "gpt-4o":
            if element["type"] == "list":
                for child in element["children"]:
                    for sub_child in child["children"]:
                        # Bold ("strong") spans are skipped, then the leftover ": " prefix is stripped.
                        all_text = recursive_get_all_children(sub_child, [], ["strong"])
                        all_text = [t["raw"] for t in all_text if "raw" in t]
                        all_text = [t.lstrip(": ") for t in all_text]
                        all_text = [t for t in all_text if len(t) > 0 and t != "\n"]
                        all_text_str = " ".join(all_text)
                        if len(all_text_str) > 0:
                            thoughts.append(all_text_str)
        # Only the first code block of the completion is kept as the solution.
        if element["type"] == "block_code" and solution_code is None:
            if "raw" not in element:
                print(element)
                continue
            solution_code = element["raw"]

        # Stop at the first point where both pieces are available.
        if solution_code is not None and len(thoughts) > 0:
            cot_formatted_rows.append(
                {
                    "problem": problem_str,
                    "code": solution_code,
                    "thoughts": thoughts,
                    "problem_name": problem_name,
                }
            )
            break


In [None]:
# Preview the first few extracted rows; echoing the whole list floods the cell output.
cot_formatted_rows[:3]

In [30]:
# Turn the extracted rows into a polars frame and write it to the working directory.
# NOTE(review): the "_gpt" suffix in the file name is hard-coded and does not follow `model_choice`.
out_rows_pl = pl.DataFrame(cot_formatted_rows)
out_rows_pl.write_parquet("codecontests_cot_sft_formatted_thoughts_v2_gpt.parquet")

In [None]:
# Preview the first few extracted rows; echoing the whole list floods the cell output.
cot_formatted_rows[:3]

In [None]:


def format_codecontests_row_sft(row: dict) -> Sequence[ChatCompletionMessageParam]:
    """Build a two-message chat conversation from one extracted CodeContests row.

    Args:
        row: Mapping with the keys "problem", "code" and "thoughts" (an
            iterable of reasoning steps).

    Returns:
        ``[user_message, assistant_message]``. The user message asks for a
        Python solution to the problem. The assistant message holds one
        ``<thought>...</thought>`` line per step, a blank line, then the code
        wrapped in ``<solution>...</solution>``.
    """
    problem = row["problem"]
    code = row["code"]
    thoughts = row["thoughts"]

    tagged_steps = "\n".join(f"<thought>{step}</thought>" for step in thoughts)
    tagged_code = f"<solution>{code}</solution>"

    user_turn = {
        "role": "user",
        "content": f"Solve the following programming problem in Python.\n{problem}",
    }
    assistant_turn = {
        "role": "assistant",
        "content": f"{tagged_steps}\n\n{tagged_code}",
    }
    messages: Sequence[ChatCompletionMessageParam] = [user_turn, assistant_turn]
    return messages


# Pretty-print the first formatted conversation as a sanity check.
# `pprint` prints and returns None, so it must not be wrapped in `display(...)`,
# which would render a stray "None" after the printed output.
pprint(format_codecontests_row_sft(cot_formatted_rows[0]))

In [None]:
# One record per extracted row: the chat conversation plus its provenance fields.
conv_out = []
for row in tqdm(cot_formatted_rows):
    conversation = format_codecontests_row_sft(row)
    conv_out.append(
        dict(
            conversation=conversation,
            problem_name=row["problem_name"],
            source="codecontests",
        )
    )
out_rows_pl = pl.DataFrame(conv_out)

In [None]:
# Show the conversation frame, then write it to the working directory.
display(out_rows_pl)
out_rows_pl.write_parquet("codecontests_cot_sft_formatted_thoughts_conversations.parquet")

In [6]:
# Train split of the OpenO1-SFT dataset; the filtering loop below reads each row's
# "instruction" and "output" fields.
openo1_sft = load_dataset("O1-OPEN/OpenO1-SFT")['train']

In [None]:
def extract_code_block(msg: str) -> List[str]:
    """Return the body of every triple-backtick fenced code block in ``msg``.

    The text between the opening fence and the end of that line (the info
    string, e.g. ``python`` or ``c++``) is discarded. Any info string without a
    backtick is accepted, so fences tagged ``c++``, ``objective-c``, with a
    leading space or with a trailing carriage return are recognised too, not
    just ``\\w+`` tags.

    Args:
        msg: Text that may contain fenced code blocks.

    Returns:
        The block bodies in order of appearance (without the fences); an empty
        list when ``msg`` contains no fenced block.
    """
    # Single capture group, so findall yields the bodies directly.
    match_pattern = r"```[^\n`]*\n(.*?)```"
    return re.findall(match_pattern, msg, re.DOTALL)


def _get_all_within_tag(tag_content: str, text: str) -> str:
    pattern = rf"<{tag_content}>(.*?)</{tag_content}>"
    results = re.findall(pattern, text, re.DOTALL)
    if len(results) == 0:
        return ""
    if len(results) > 1:
        return results[-1]
    return results[0]


def format_o1_row(problem, thoughts, code) -> Sequence[ChatCompletionMessageParam]:
    """Build a two-message chat conversation for one OpenO1 sample.

    Args:
        problem: The instruction text; used verbatim as the user message.
        thoughts: Iterable of reasoning steps.
        code: The solution code.

    Returns:
        ``[user_message, assistant_message]``. The assistant message holds one
        ``<thought>...</thought>`` line per step, a blank line, then the code
        wrapped in ``<solution>...</solution>``.
    """
    tagged_steps = "\n".join(f"<thought>{step}</thought>" for step in thoughts)
    tagged_code = f"<solution>{code}</solution>"

    user_turn = {"role": "user", "content": problem}
    assistant_turn = {
        "role": "assistant",
        "content": f"{tagged_steps}\n\n{tagged_code}",
    }
    messages: Sequence[ChatCompletionMessageParam] = [user_turn, assistant_turn]
    return messages


# Accumulates one record per OpenO1 sample that passes the filters in the loop below.
rows_out = []


def _check_chinese(text: bytes) -> bool:
    results = re.findall(r"[\u4e00-\u9fff]+", text.decode("utf-8"))
    return len(results) > 0


# Per-reason counters for rows that were flagged or dropped by the loop below.
stats = defaultdict(int)

# enumerate(tqdm(...)) rather than tqdm(enumerate(...)): tqdm then sees the dataset
# itself and can show a total; `i` is still the dataset index.
for i, row in enumerate(tqdm(openo1_sft)):
    raw_output = row["output"]  # type: ignore
    if "```" not in raw_output:
        # Counted for reporting only; such rows are also dropped (and counted again)
        # below, when no code block can be extracted.
        stats["no_code_block"] += 1
    problem = row["instruction"]  # type: ignore
    if _check_chinese(problem.encode("utf-8")) or _check_chinese(
        raw_output.encode("utf-8")
    ):
        stats["chinese_found"] += 1
        continue

    # Reasoning steps are the blank-line separated paragraphs of the <Thought> section.
    thought_segments = _get_all_within_tag("Thought", raw_output).split("\n\n")
    code_blocks = extract_code_block(_get_all_within_tag("Output", raw_output))

    # Strip first, then filter, so whitespace-only paragraphs do not survive as
    # empty thoughts.
    thought = [segment.strip() for segment in thought_segments]
    thought = [segment for segment in thought if segment]
    if len(code_blocks) == 0:
        stats["no_output_or_thought"] += 1
        continue
    elif len(thought) == 0:
        stats["no_thought"] += 1
        continue
    if len(code_blocks) > 1:
        # Ambiguous which block is the solution; skip the sample.
        stats["multiple_outputs"] += 1
        continue
    # Separate name so the loop variable `row` is not overwritten.
    conversation = format_o1_row(problem, thought, code_blocks[0])
    rows_out.append(
        {
            "conversation": conversation,
            "problem_name": f"openo1_sft_{i}",
            "source": "openo1_sft",
        }
    )

out_rows_pl = pl.DataFrame(rows_out)
print(stats)

In [None]:
# Write the OpenO1 conversation frame to the working directory.
out_rows_pl.write_parquet("openo1_sft_formatted_thoughts_conversations.parquet")

In [4]:
# DPO: load the train split of the py-dpo preference dataset and convert it to a
# polars frame (columns shown in the output below: prompt, chosen, rejected, id).

dpo_dataset = load_dataset("jondurbin/py-dpo-v0.1")['train']
dpo_dataset_pl = dpo_dataset.to_polars()

In [5]:
# Bare expression: shows the DPO frame as this cell's output.
dpo_dataset_pl

prompt,chosen,rejected,id
str,str,str,str
"""Use the function to debug the …","""One possible solution to preve…","""def debug_program(arr):  n …","""8c94f83f-6a5a-5f8c-98a2-e242d7…"
"""Write an algorithm in Python t…","""Here is the algorithm to deter…","""def is_prime(n):  # Chec…","""9d7911ee-5778-5009-8fc3-ee297f…"
"""Compose a function named avera…","""Here is the implementation of …","""def average_list(lst):  if …","""01a2d265-9f76-54f7-aa77-066c61…"
"""Rewritten Test: Write a functi…","""Here is a possible implementat…","""Here is a function that delete…","""3dfc33c0-5c2d-524e-b2d5-afd356…"
"""Write a program that extracts …","""Here is a corrected implementa…","""```python import re def extra…","""776d9f46-669e-52e6-8e6b-70944f…"
…,…,…,…
"""Imagine you are a spy trying t…","""Sure, I can help you with that…","""To solve this puzzle, we can c…","""3855e084-90c3-54f0-a58f-872bc4…"
"""Can you design a program in Py…","""Yes, it is possible to design …","""Yes, I can design a program in…","""644979b5-5286-58ad-a521-705023…"
"""How can I use Python and the p…","""Certainly! Here's an example c…","""To generate a comprehensive re…","""a704864d-af10-5b5c-9e30-f1df9d…"
"""Utilize the pre-existing Pytho…","""Yes, I can help you with that.…","""To generate a random sequence …","""a40a2106-21c5-5d13-909d-51d1a3…"
