In [11]:
from datasets import load_dataset
# Load the 'train' split of the local geo3k dataset directory.
temp = load_dataset('/mnt/data/yyh/ChatLearn/dataset/geo3k/')['train']
# Materialise rows 0..2099 as plain Python records (one dict per sample).
data = list(map(temp.__getitem__, range(2100)))

In [12]:
# Display the first record to inspect the per-sample schema.
data[0]

{'data_source': 'hiyouga/geometry3k',
 'prompt': [{'role': 'user',
   'content': '<image>Find x. You FIRST think about the reasoning process as an internal monologue and then provide the final answer. The reasoning process MUST BE enclosed within <think> </think> tags. The final answer MUST BE put in \\boxed{}.'}],
 'images': [<PIL.PngImagePlugin.PngImageFile image mode=RGBA size=250x258>],
 'ability': 'math',
 'reward_model': {'style': 'rule', 'ground_truth': '3'},
 'extra_info': {'split': 'train',
  'index': 0,
  'answer': '3',
  'question': '<image>Find x.'}}

In [13]:
import torch
import numpy as np 

def get_index(temp):
    """Return the value stored at ``temp['extra_info']['index']``."""
    extra_info = temp['extra_info']
    return extra_info['index']

def get_answer(temp):
    """Return the reference answer stored at ``temp['extra_info']['answer']``."""
    extra_info = temp['extra_info']
    return extra_info['answer']

def get_input(temp):
    """Return ``(first image, content of the first prompt message)`` for a sample."""
    first_image = temp['images'][0]
    first_message = temp['prompt'][0]
    return first_image, first_message['content']

import re

from mathruler.grader import extract_boxed_content, grade_answer


def format_reward(predict_str: str) -> float:
    r"""Score whether a response follows the required output format.

    The whole response must match: a ``<think>...</think>`` section followed,
    possibly after other text, by a ``\boxed{...}`` expression. ``re.DOTALL``
    lets ``.`` span newlines, so multi-line reasoning is accepted.

    Args:
        predict_str: Raw model response text.

    Returns:
        1.0 when the response matches the format, otherwise 0.0.
    """
    pattern = re.compile(r"<think>.*</think>.*\\boxed\{.*\}.*", re.DOTALL)
    # Previously the match was computed and then ignored (the function always
    # returned 1.0); use the match result so the score reflects the format.
    return 1.0 if re.fullmatch(pattern, predict_str) else 0.0


def acc_reward(predict_str: str, ground_truth: str, use_boxed: bool = True) -> float:
    """Return 1.0 if the predicted answer is graded as correct, else 0.0.

    Args:
        predict_str: Model response.
        ground_truth: Reference answer passed to ``grade_answer``.
        use_boxed: When True, grade the result of
            ``extract_boxed_content(predict_str)``; when False, grade
            ``predict_str`` itself.
    """
    answer = extract_boxed_content(predict_str) if use_boxed else predict_str
    if grade_answer(answer, ground_truth):
        return 1.0
    return 0.0


def get_reward(predict_str: str, ground_truth: str, use_boxed: bool = True, format_score: float = 0.1) -> float:
    """Return the reward for one response; currently the accuracy reward only.

    ``format_score`` is accepted but unused: the format term
    (``format_score * format_reward(predict_str)``) is intentionally left out.
    """
    accuracy = acc_reward(predict_str, ground_truth, use_boxed)
    return accuracy

def get_gain(reward_list, temp=0.25):
    """Return the softmax-weighted mean reward minus the plain mean reward.

    The rewards are turned into a softmax distribution with temperature
    ``temp``; the gain is the expectation of the rewards under that
    distribution minus their arithmetic mean. It is 0 when all rewards are
    equal and grows when a few rewards stand out above the rest.

    Args:
        reward_list: Rewards as a ``torch.Tensor``, ``list`` or ``np.ndarray``.
        temp: Softmax temperature; smaller values weight the maximum more.

    Returns:
        A NumPy scalar with the gain.

    Raises:
        TypeError: If ``reward_list`` is not one of the supported types.
        ValueError: If ``reward_list`` is empty.
    """
    if isinstance(reward_list, torch.Tensor):
        # detach() first: numpy() refuses tensors that require grad.
        reward_array = reward_list.detach().cpu().numpy()
    elif isinstance(reward_list, list):
        reward_array = np.array(reward_list, dtype=np.float64)
    elif isinstance(reward_list, np.ndarray):
        reward_array = reward_list
    else:
        raise TypeError(
            f"Unsupported type for reward_list: {type(reward_list)}. Expected torch.Tensor, list, or np.ndarray."
        )

    if reward_array.size == 0:
        raise ValueError("reward_list must contain at least one reward.")

    # Subtract the maximum before exponentiating for numerical stability;
    # the shift cancels out in the normalisation.
    scaled = (reward_array - reward_array.max()) / temp
    exp_rewards = np.exp(scaled)
    soft_dist = exp_rewards / exp_rewards.sum()
    return np.sum(soft_dist * reward_array) - reward_array.mean()

In [14]:
import base64
from PIL import Image
from io import BytesIO
from openai import OpenAI


def student(img, query, n=8):
    """Sample several answers from the student model served on localhost.

    The image is PNG-encoded in memory, embedded as a base64 data URL and sent
    together with ``query`` through the OpenAI-compatible chat endpoint at
    ``http://localhost:8000/v1``.

    Args:
        img: Image object exposing ``save(fp, format=...)`` (e.g. a PIL image).
        query: Text prompt sent alongside the image.
        n: Number of completions to request. Defaults to 8, the value that
            used to be hard-coded.

    Returns:
        A list with the message content of every returned choice.
    """
    openai_api_key = "EMPTY"
    openai_api_base = "http://localhost:8000/v1"
    client = OpenAI(
        api_key=openai_api_key,
        base_url=openai_api_base,
    )

    SYSTEM_PROMPT = '''You are a helpful assistant.'''

    # Serialise the image to PNG in memory, then base64-encode the bytes.
    buffer = BytesIO()
    img.save(buffer, format="PNG")
    encoded_image_text = base64.b64encode(buffer.getvalue()).decode('utf-8')
    base64_qwen = f"data:image;base64,{encoded_image_text}"

    chat_response = client.chat.completions.create(
        model="/mnt/data/yyh/RLdistill/models/geo3k_3B/v7/",
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": base64_qwen
                        },
                    },
                    {"type": "text", "text": query},
                ],
            },
        ],
        temperature=0.8,
        n=n,
        max_tokens=2048
    )

    # Iterate over what the server actually returned instead of indexing a
    # fixed range, so a short response cannot raise IndexError here.
    return [choice.message.content for choice in chat_response.choices]

def teacher(img, query):
    """Ask the teacher model (``qwen-vl-max-latest``) a question about an image.

    The image is PNG-encoded in memory, embedded as a base64 data URL and sent
    together with ``query`` to the DashScope OpenAI-compatible endpoint.

    Args:
        img: Image object exposing ``save(fp, format=...)`` (e.g. a PIL image).
        query: Text prompt sent alongside the image.

    Returns:
        The message content of the first returned choice.
    """
    import os  # local import keeps this block self-contained

    # Serialise the image to PNG in memory, then base64-encode the bytes.
    buffer = BytesIO()
    img.save(buffer, format="PNG")
    encoded_image_text = base64.b64encode(buffer.getvalue()).decode('utf-8')
    base64_qwen = f"data:image;base64,{encoded_image_text}"

    # Prefer a key from the environment so no credential lives in the
    # notebook; fall back to the original placeholder when it is not set.
    api_key = os.environ.get("DASHSCOPE_API_KEY", "sk-")

    client = OpenAI(
        api_key=api_key,
        base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
    )

    completion = client.chat.completions.create(
        model="qwen-vl-max-latest",
        messages=[{
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": base64_qwen
                    }
                },
                {
                    "type": "text",
                    "text": query
                }
            ]
        }],
        temperature=0.5,
        max_tokens=1000,
        # Stop sequence for generation.
        # NOTE(review): no prompt in this notebook asks for a <score> tag;
        # this looks like a leftover from another prompt format - confirm.
        stop=["</score>"]
    )

    return completion.choices[0].message.content

In [15]:
'''
pool={}

from tqdm import tqdm
for dd in tqdm(data):
    img,query=get_input(dd)
    candidates=student(img,query)
    ans=get_answer(dd)
    rewards=[get_reward(candidates[kk],ans) for kk in range(8)]
    idx=get_index(dd)
    gain_1=get_gain(rewards)
    
    if gain_1>=0.5:
        pool[idx]=''
        print('gain_1',gain_1)
        continue
    if sum(rewards)>0:
        continue
        
    ask_hint=f"""
    Given the following:

    Original query: {query}
    Reference answer: {ans}
    Student model’s candidate answer with reasoning: {candidates[0]}

    Analyze the student’s reasoning process and identify where it goes wrong or becomes misleading. Generate a strong, direct hint that corrects or redirects the flawed reasoning — without revealing the final answer. The hint must be appendable verbatim to the original query, contain NO irrelevant text or answer leakage, and when added, should make the student model ~90% likely to self-correct and produce the right answer, and ~10% likely to still fail — calibrated to encourage reasoning repair, not answer copying. Output ONLY the hint text, nothing else.
    """
    hint=teacher(img,ask_hint)
    
    hint_query=f"""
    ### Query
    {query}
    ### Hint
    {hint}
    """
    candidates=student(img,hint_query)
    rewards=[get_reward(candidates[kk],ans) for kk in range(8)]
    gain_2=get_gain(rewards)
    
    if gain_2>=0.5:
        pool[idx]=hint
        print('gain_2',gain_2)
        continue
    if sum(rewards)>0:
        continue
        
    ask_hint=f"""
    Given the following:

    Original query: {query}
    Reference answer: {ans}
    Student model’s candidate answer with reasoning: {candidates[0]}
    Previous hint (invalid): {hint}

    Analyze the student’s reasoning process and identify where it goes wrong or becomes misleading. Generate a strong, direct hint that corrects or redirects the flawed reasoning — without revealing the final answer. The hint must be appendable verbatim to the original query, contain NO irrelevant text or answer leakage, and when added, should make the student model ~99% likely to self-correct and produce the right answer, and ~1% likely to still fail — calibrated to encourage reasoning repair, not answer copying. Output ONLY the hint text, nothing else.
    """
    hint=teacher(img,ask_hint)
    
    hint_query=f"""
    ### Query
    {query}
    ### Hint
    {hint}
    """
    candidates=student(img,hint_query)
    rewards=[get_reward(candidates[kk],ans) for kk in range(8)]
    gain_3=get_gain(rewards)
    
    if gain_3>=0.5:
        pool[idx]=hint
        print('gain_3',gain_3)
        continue
    if sum(rewards)>0:
        continue
        
    ask_hint=f"""
    Given the following:

    Original query: {query}
    Reference answer: {ans}
    Student model’s candidate answer with reasoning: {candidates[0]}
    Previous hint (invalid): {hint}

    Analyze the student’s reasoning process and identify where it goes wrong or becomes misleading. Generate a strong, direct hint that corrects or redirects the flawed reasoning — without revealing the final answer. The hint must be appendable verbatim to the original query, contain NO irrelevant text or answer leakage, and when added, should make the student model ~100% likely to self-correct and produce the right answer, and ~0% likely to still fail — calibrated to encourage reasoning repair, not answer copying. Output ONLY the hint text, nothing else.
    """
    hint=teacher(img,ask_hint)
    
    hint_query=f"""
    ### Query
    {query}
    ### Hint
    {hint}
    """
    candidates=student(img,hint_query)
    rewards=[get_reward(candidates[kk],ans) for kk in range(8)]
    gain_4=get_gain(rewards)
    
    if gain_4>=0.5:
        pool[idx]=hint
        print('gain_4',gain_4)
        continue
    if sum(rewards)>0:
        continue
    print('Useless')
'''

  0%|          | 1/2100 [00:04<2:42:00,  4.63s/it]

gain_1 0.761360223541884


  0%|          | 1/2100 [00:06<3:54:41,  6.71s/it]

KeyboardInterrupt



In [16]:
# visited: sample index -> '' marker, written by process_item when it starts
#          working on a sample.
# pool:    sample index -> hint text ('' when no hint was needed), filled from
#          the results of process_item.
visited = dict()
pool = dict()

In [17]:
import concurrent.futures
from tqdm import tqdm
import threading

def process_item(dd):
    """Decide whether a sample is usable as-is, usable with a teacher hint, or not.

    The student answers the original query first. If the rewards of its
    samples give ``get_gain(...) >= 0.1`` the sample is accepted without a
    hint. If the gain is below the threshold but at least one sample is
    correct, the sample is rejected. Otherwise (no correct sample) the teacher
    is asked for a hint, up to three times with increasingly strong target
    success rates (30%, 80%, 99%), and the same accept/reject test is applied
    to the student's answers to the hinted query.

    Args:
        dd: One dataset record (see ``get_index`` / ``get_input`` /
            ``get_answer`` for the fields that are read).

    Returns:
        ``(idx, hint)`` where ``hint`` is ``''`` if no hint was needed, the
        accepted hint text, or ``None`` if the sample was rejected or a
        teacher call failed. For an already visited sample the value cached in
        ``pool`` (or ``None``) is returned.
    """
    idx = get_index(dd)

    # Membership test rather than truthiness: the marker stored in `visited`
    # is '' (falsy), so `visited.get(idx)` could never detect a visited item.
    if idx in visited:
        return idx, pool.get(idx)
    visited[idx] = ''

    img, query = get_input(dd)

    candidates = student(img, query)
    ans = get_answer(dd)
    rewards = [get_reward(candidate, ans) for candidate in candidates]

    if get_gain(rewards) >= 0.1:
        return idx, ''
    if sum(rewards) > 0:
        return idx, None

    hint = None
    # (target success %, target failure %) requested from the teacher per round.
    hint_rounds = ((30, 70), (80, 20), (99, 1))
    for round_no, (ok_pct, fail_pct) in enumerate(hint_rounds):
        # From the second round on, show the teacher the hint that did not work.
        if round_no == 0:
            previous_hint_line = ""
        else:
            previous_hint_line = f"    Previous hint (invalid): {hint}\n"

        ask_hint = f"""
    Given the following:

    Original query: {query}
    Reference answer: {ans}
    Student model's candidate answer with reasoning: {candidates[0]}
{previous_hint_line}
    Analyze the student's reasoning process and identify where it goes wrong or becomes misleading. Generate a strong, direct hint that corrects or redirects the flawed reasoning — without revealing the final answer. The hint must be appendable verbatim to the original query, contain NO irrelevant text or answer leakage, and when added, should make the student model ~{ok_pct}% likely to self-correct and produce the right answer, and ~{fail_pct}% likely to still fail — calibrated to encourage reasoning repair, not answer copying. Output ONLY the hint text, nothing else.
    """
        try:
            hint = teacher(img, ask_hint)
        except Exception:
            # Best effort: a failed teacher call rejects the sample instead of
            # aborting the whole run (but no longer swallows SystemExit etc.).
            return idx, None

        hint_query = f"""
    ### Query
    {query}
    ### Hint
    {hint}
    """
        candidates = student(img, hint_query)
        rewards = [get_reward(candidate, ans) for candidate in candidates]

        if get_gain(rewards) >= 0.1:
            return idx, hint
        if sum(rewards) > 0:
            return idx, None

    return idx, None

# Run process_item over all samples on a pool of 50 worker threads and
# collect every non-None hint into `pool` as the futures finish.
with concurrent.futures.ThreadPoolExecutor(max_workers=50) as executor:
    # One future per sample, mapped back to the record it was created from.
    future_to_data = dict(
        (executor.submit(process_item, dd), dd) for dd in data
    )

    # Consume results in completion order, with a progress bar.
    for future in tqdm(concurrent.futures.as_completed(future_to_data), total=len(data)):
        idx, hint = future.result()
        if hint is None:
            continue
        pool[idx] = hint


100%|██████████| 2100/2100 [28:11<00:00,  1.24it/s]


In [20]:
# Helper that rewrites the prompt of specific samples.
def modify_prompt(example,idx):
    """Append the stored hint to the prompt of a sample, if it has one.

    Samples whose index has a non-empty entry in ``pool`` get their first
    prompt message replaced by a "### Query / ### Hint" text; every other
    sample is returned unchanged. ``idx`` is accepted because the function is
    mapped with ``with_indices=True``; it is not used.
    """
    origin_content = get_input(example)[1]
    hint = pool.get(get_index(example))
    if hint and len(hint) > 0:
        example['prompt'][0]['content'] = f"""### Query
{origin_content}
### Hint
{hint}
            """
    return example

def filter_condition(example, idx):
    """Return True for samples whose index has a truthy entry in ``pool``.

    ``idx`` is accepted because the function is used with
    ``with_indices=True``; it is not used.
    """
    # NOTE(review): samples stored in `pool` with '' (accepted without a hint)
    # are falsy here and therefore filtered out - confirm this is intended.
    return bool(pool.get(get_index(example)))

# Reload the train split, rewrite the prompts, then filter the dataset.
from datasets import load_dataset
dataset = load_dataset('/mnt/data/yyh/ChatLearn/dataset/geo3k/')['train']

# map() applies modify_prompt to every row; filter() then keeps only the rows
# accepted by filter_condition.
modified_dataset = (
    dataset
    .map(modify_prompt, with_indices=True)
    .filter(filter_condition, with_indices=True)
)

from datasets import Features, Sequence, Value, Image

# Define the target features structure (key point: use Sequence, not List).
prompt_schema = [{"role": Value("string"), "content": Value("string")}]
reward_model_schema = {
    "style": Value("string"),
    "ground_truth": Value("string"),
}
extra_info_schema = {
    "split": Value("string"),
    "index": Value("int64"),
    "answer": Value("string"),
    "question": Value("string"),
}
features = Features(
    {
        "data_source": Value("string"),
        "prompt": prompt_schema,
        # images used to be List(Image()); declared as Sequence(Image()) here.
        "images": Sequence(Image()),
        "ability": Value("string"),
        "reward_model": reward_model_schema,
        "extra_info": extra_info_schema,
    }
)
# Force the dataset onto the schema above.
modified_dataset = modified_dataset.cast(features)

# Write the result as parquet; the call's return value is the cell output.
modified_dataset.to_parquet("/mnt/data/yyh/ChatLearn/dataset/CVPR/geo3k_3B/v8/train.parquet")

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

16083189

In [21]:
# Number of samples that remain after mapping and filtering.
len(modified_dataset)

676