In [None]:
import json
import wandb
import os
from ipywidgets import widgets, VBox, HBox, Button, Textarea, Label, BoundedIntText, ToggleButton
from IPython.display import display, clear_output, Javascript
from datetime import datetime
from IPython import get_ipython

from ipywidgets import Layout

# ===== Model / LLM output box =====
# Model passed to client.responses.create in call_llm_and_display.
# Other model names that have been used here:
#   gpt-5
#   gpt-4o-realtime-preview
#   gpt-4o-mini-realtime-preview
#   gpt-5-chat-latest
used_model="gpt-5-chat-latest"

# Text area (shown in the top row of the UI) holding the most recent LLM response.
llm_output_box = Textarea(
    description='LLM output:',
    layout=Layout(width='80%', height='220px', flex='1 1 auto'),
    disabled=False
)

# ===== File paths =====
# All working files live under `path`.
# Previously used location: LLM4Molecule_curated/SFT_Qwen2.5-7B/epoch_6_lr5e-5_2/test_preds
path="./"
input_file = path + "RL_need_curate_data.jsonl"   # source records; left-panel scores are written back here
curated_file = path + "RL_curated.jsonl"          # curated records, assumed line-aligned with input_file
log_file = path + "RL_record.log"                 # JSON-lines log of navigation/save actions
config_file = path + "RL_config.json"             # persisted session state (current_index, wandb_run_id)
LLM_LOG_FILE = path + "RL_record_LLM.log"   # JSON-lines log of LLM queries/outputs

# ===== Restore session state =====
# Resume from the last viewed record and reuse the stored W&B run id so the same
# run is continued (resume="allow" below) across notebook restarts.
# NOTE: the default argument of config.get is evaluated eagerly, so a fresh id is
# generated (and discarded) even when the config already contains one.
if os.path.exists(config_file):
    with open(config_file, 'r', encoding='utf-8') as f:
        config = json.load(f)
    current_index = config.get("current_index", 0)
    wandb_run_id = config.get("wandb_run_id", wandb.util.generate_id())
else:
    current_index = 0
    wandb_run_id = wandb.util.generate_id()

# ===== Initialise W&B =====
# Every UI action is logged to this run by log_action.
wandb.init(
    project="molecule-curation",
    name="curation-session",
    entity="byfrfy",
    id=wandb_run_id,
    resume="allow"
)
        

def save_config():
    """Persist the session state (current record index and W&B run id) to config_file."""
    state = {"current_index": current_index, "wandb_run_id": wandb_run_id}
    with open(config_file, 'w', encoding='utf-8') as fh:
        json.dump(state, fh, ensure_ascii=False, indent=2)

# ===== Load data =====
# Source records, one JSON object per line. Record i corresponds to line i of
# input_file; save_current relies on that when it rewrites the file.
with open(input_file, 'r', encoding='utf-8') as f:
    data = [json.loads(line) for line in f]

# On first run, seed the curated file with a verbatim copy of the source records.
if not os.path.exists(curated_file):
    with open(curated_file, 'w', encoding='utf-8') as f:
        for record in data:
            f.write(json.dumps(record, ensure_ascii=False) + '\n')

# Curated records: blank lines are skipped, malformed lines are reported and skipped.
curated_data = []
with open(curated_file, 'r', encoding='utf-8') as f:
    for line in f:
        stripped = line.strip()
        if not stripped:
            continue
        try:
            curated_data.append(json.loads(stripped))
        except json.JSONDecodeError as e:
            print(f"JSON decode error: {e} | Line content: {line}")
        except Exception as e:
            print(f"Unexpected error: {e} | Line content: {line}")

print(f"Loaded {len(curated_data)} records from {curated_file}")
# data and curated_data are indexed with the same index everywhere, so a length
# mismatch (e.g. a skipped malformed line) silently misaligns the two panels.
if len(curated_data) != len(data):
    print(f"⚠️ curated_data has {len(curated_data)} records but data has {len(data)}; "
          "record indices may be misaligned")

# Ground-truth target SMILES: the first tab-separated column of each line of
# test_path, in the same order as the records of input_file.
answer_list = []
test_path = "./test_data.txt"
with open(test_path, 'r', encoding='utf-8') as f:
    for line in f:
        if not line.strip():
            continue  # blank lines (e.g. a trailing empty line) carry no answer
        # maxsplit=1: a description that itself contains tabs must not break unpacking.
        answer, _description = line.rstrip("\r\n").split("\t", 1)
        answer_list.append(answer.strip())

if len(answer_list) < len(data):
    raise ValueError(
        f"{test_path} provides {len(answer_list)} answers but "
        f"{input_file} has {len(data)} records"
    )
for record, answer in zip(data, answer_list):
    record['answer'] = answer

# Indices of records whose curated content was changed during this session.
modified_indices = set()
print(f"✅ 已加载 {len(data)} 条记录")

# ===== 日志函数 =====
def log_action(action_type, index, record=None):
    """Record one UI action in the local JSON-lines log and in W&B.

    Args:
        action_type: Short action name ("save", "prev", "next", "jump", ...).
        index: Record index that was current when the action happened.
        record: Optional payload. When truthy it is stored under "record" in the
            file log, and its serialized length is sent to W&B as "text_length"
            (0 otherwise).
    """
    stamp = datetime.now().isoformat()
    entry = dict(timestamp=stamp, action=action_type, index=index)
    has_payload = bool(record)
    if has_payload:
        entry["record"] = record

    with open(log_file, 'a', encoding='utf-8') as log_fh:
        log_fh.write(json.dumps(entry, ensure_ascii=False) + '\n')

    payload_len = len(json.dumps(record, ensure_ascii=False)) if has_payload else 0
    wandb.log({
        "timestamp": stamp,
        "action": action_type,
        "index": index,
        "text_length": payload_len,
    })

# NOTE(review): common_layout is not referenced anywhere in this cell.
common_layout = widgets.Layout(width='150px', height='40px')
# ===== Text-box widgets =====
# Left-side score box. save_current writes its value as "curation_score" into the
# matching line of input_file (not into the curated file).
edit_score_box = BoundedIntText(value=3, min=1, max=5, description='score:')

# Left side: the record as loaded. SMILES and explanation are read-only (disabled);
# the CoT box is editable and its text is what save_current stores as the curated CoT.
edit_smile_box = Textarea(description='smile:', layout=widgets.Layout(width='90%', height='30px'), disabled=True)
edit_explain_box = Textarea(description='explain:', layout=widgets.Layout(width='90%', height='50px'), disabled=True)
edit_cot_box = Textarea(description='cot:', layout=widgets.Layout(width='90%', height='260px'))
# Left-side reason box (saved to the curated file).
edit_reason_box = Textarea(description='reason:', layout=widgets.Layout(width='90%', height='50px'))

# Right-side score box (editable; saved to the curated file together with the edits).
curated_score_box = BoundedIntText(value=5, min=1, max=5, description='score\n[curated]:',style={'description_width': '160px'})

# Right side: read-only view of the saved curated content.
curated_smile_box = Textarea(description='smile\n[curated]:', layout=widgets.Layout(width='90%', height='30px'), disabled=True, style={'description_width': '160px'})
curated_explain_box = Textarea(description='explain\n[curated]:', layout=widgets.Layout(width='90%', height='50px'), disabled=True, style={'description_width': '160px'})
curated_cot_box = Textarea(description='cot\n[curated]:', layout=widgets.Layout(width='90%', height='260px'), disabled=True, style={'description_width': '160px'})
curated_reason_box = Textarea(description='reason\n[curated]:', layout=widgets.Layout(width='90%', height='50px'), disabled=True, style={'description_width': '160px'})

# ===== LLM assistant toggle =====
# When on, load_record queries the LLM every time the displayed record changes.

llm_toggle = ToggleButton(
    value=True,                      # on by default
    description='LLM assistant',
    button_style='success'           # green
)
llm_toggle.tooltip = "选中后在切换记录时调用 LLM 生成参考"

# Previous configuration (off by default, separate output box), kept for reference:
# llm_toggle = ToggleButton(value=False, description='LLM assistant', button_style='')
# llm_toggle.tooltip = "选中后在切换记录时调用 LLM 生成参考"
# llm_output_box = Textarea(description='LLM output:', layout=widgets.Layout(width='100%', height='260px'), disabled=True)

# Control buttons
# Fullscreen button: handled by toggle_fullscreen, and the keyboard-shortcut
# script clicks it on Ctrl+Shift+F. It is not part of the final `ui` layout.
fullscreen_button = Button(description="全屏", button_style='')
fullscreen_button.tooltip = "切换全屏 (Ctrl+Shift+F)"

def toggle_fullscreen(_=None):
    """Toggle browser fullscreen for the rendered curation UI.

    Injects a JavaScript snippet that looks up the element whose
    `data-widget-id` equals the `ui` widget's model id, then enters fullscreen on
    its closest `.jupyter-widgets-view` ancestor (or the element itself), or
    exits fullscreen if the document is already fullscreen.
    NOTE(review): assumes the notebook front end renders `data-widget-id` on the
    widget's DOM node — TODO confirm for the front end in use.
    """
    model_id = ui._model_id
    script = f"""
        (function() {{
            const el = document.querySelector('[data-widget-id="{model_id}"]');
            if (!el) {{ console.warn('ui root not found'); return; }}
            const target = el.closest('.jupyter-widgets-view') || el;
            if (document.fullscreenElement) {{
                document.exitFullscreen();
            }} else {{
                if (target.requestFullscreen) {{
                    target.requestFullscreen().catch(err => console.error(err));
                }} else if (target.webkitRequestFullscreen) {{
                    target.webkitRequestFullscreen();
                }}
            }}
        }})();
    """
    display(Javascript(script))

fullscreen_button.on_click(toggle_fullscreen)



# Position label ("Current Record: i / N"), updated by load_record.
index_label = Label()
prev_button = Button(description="Previous", button_style='info')
next_button = Button(description="Next", button_style='info')
save_button = Button(description="Save Changes", button_style='success')
# NOTE(review): no on_click handler is attached to export_button in this cell,
# so clicking it currently does nothing.
export_button = Button(description="Export Records", button_style='warning')
# 1-based record number for the Jump button; starts at the restored position.
jump_input = BoundedIntText(value=current_index+1, min=1, max=len(data), description='Jump to:')
jump_button = Button(description="Jump", button_style='primary')
show_modified_button = Button(description="📋 View Modified", button_style='warning')
export_modified_button = Button(description="📝 Export Modified", button_style='success')

# ====== LLM setup: static prompt header + OpenAI client ======
# Instruction header prepended to every per-record prompt built by
# build_prompt_with_current_record. It is a plain string (it contains no
# placeholders, so an f-string is unnecessary).

prompt_template_head = """
Analyze whether the reasoning is correct. The reasoning is generated based on the given description, where the target molecule has already been specified. If the molecule derived from the reasoning matches the target molecule, it is correct. If it does not match, errors in the reasoning should be corrected so that the reasoning can generate the target molecule.

Key point 1: The goal is to generate the target molecule. If it is incorrect, the CoT should be carefully revised so that it can generate the target molecule from the description. Make revisions based on the original text and **provide a complete, copyable CoT in Markdown format directly**. Modify only the incorrect or problematic parts; do not change information that is already correct, and keep the original formatting as much as possible.

Key point 2: Then check whether the CoT before and after revision has any issues, especially in the structure and functional group descriptions, and correct them if necessary.

Key point 3: If possible, base your revisions on the original text, modifying only the incorrect or problematic parts without altering information that is already correct. You do not need to mark the changes directly in the revised CoT; instead, list them before the completed CoT.

Key point 4: every original CoT has a summarisation before giving the SMILES, you should keep every original summarisation based on the **original CoT**. For example some CoT contains: (1) Putting this together, we arrive at the correct SMILES representation for chlorobenzene:\n<answer>SMILES</answer> (2) Assembling all parts, the complete SMILES representation for benzoin is:\n<answer>SMILES</answer> (3) Now that we have a clear understanding of the structure, we can convert this into its SMILES notation:\n<answer>SMILES</answer> (4) The SMILES for 2-isopropylmalic acid is as follows:\n<answer>SMILES</answer> **You should follow the original CoT**.


"""

# ===== OpenAI client =====
# Read the key from the OPENAI_API_KEY environment variable so a real secret never
# has to be committed in the notebook; "sk-xxx" is only the old placeholder fallback.
from openai import OpenAI
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "sk-xxx"))

# ===== 动态构造 prompt（使用当前记录的 explain/cot/目标 smile） =====
def build_prompt_with_current_record():
    """Build the LLM prompt for the record currently shown in the left panel.

    The per-record section is assembled from three left-side widgets, each value
    treated as "" when empty/None and stripped of surrounding whitespace:
      - description ("描述") <- edit_explain_box.value
      - reasoning   ("推理") <- edit_cot_box.value
      - target      ("目标") <- edit_smile_box.value

    Returns:
        prompt_template_head, a blank line, then the per-record section.
    """
    def _clean(box):
        return (box.value or "").strip()

    sections = [
        ("描述：", _clean(edit_explain_box)),
        ("推理：", _clean(edit_cot_box)),
        ("目标：", _clean(edit_smile_box)),
    ]
    joined = "\n\n".join(f"{label}\n{text}" for label, text in sections)
    body = "Now, please start analysis:\n\n" + joined + "\n"
    return f"{prompt_template_head}\n\n{body}"

def _append_llm_log(entry):
    """Append one JSON object as a line to LLM_LOG_FILE."""
    with open(LLM_LOG_FILE, 'a', encoding='utf-8') as f:
        f.write(json.dumps(entry, ensure_ascii=False) + '\n')


def call_llm_and_display():
    """Query the LLM for the current record, display the answer and log the query.

    Does nothing when the LLM toggle is off. The prompt is built before the
    request so it can be logged even if the request fails.

    On success the response text is shown in llm_output_box and an "llm_query"
    entry is appended to LLM_LOG_FILE. On failure the error message is shown
    instead and an "llm_query_error" entry is appended. Only the API call itself
    sits inside the try, so a failure to write the log file is not misreported as
    an LLM failure (and does not overwrite a successful response).
    """
    if not llm_toggle.value:
        return

    prompt_current = build_prompt_with_current_record()

    try:
        resp = client.responses.create(
            model=used_model,
            instructions="You are an expert of chemistry.",
            input=prompt_current
        )
        # Fall back to the raw repr for response objects without output_text.
        out_text = resp.output_text if hasattr(resp, "output_text") else str(resp)
    except Exception as e:
        llm_output_box.value = f"LLM 调用失败：{e}"
        _append_llm_log({
            "timestamp": datetime.now().isoformat(),
            "index": current_index,
            "event": "llm_query_error",
            "model": used_model,
            "error": str(e),
            "prompt_len": len(prompt_current),
            "prompt": prompt_current
        })
        return

    llm_output_box.value = out_text
    _append_llm_log({
        "timestamp": datetime.now().isoformat(),
        "index": current_index,
        "event": "llm_query",
        "model": used_model,
        "prompt_len": len(prompt_current),
        "output_len": len(out_text),
        "prompt": prompt_current,
        "output": out_text
    })

# ===== 记录切换/保存时的 UI 逻辑 =====

def load_record(index):
    """Show record `index` in both panels and persist it as the current position.

    The left panel shows the target SMILES ("answer"), explanation, CoT and left
    score from `data[index]`, plus the reason from `curated_data[index]`. The
    right panel shows the saved curated version from `curated_data[index]`.
    When the LLM toggle is on the LLM is queried for the new record; otherwise
    the LLM output box is cleared.
    """
    global current_index
    current_index = index
    save_config()

    # Left panel: the record as loaded.
    source = data[index]
    edit_smile_box.value = source.get("answer", "")
    edit_explain_box.value = source.get("explain", "")
    edit_cot_box.value = source.get("cot", "")
    edit_score_box.value = source.get("curation_score", 3)
    saved = curated_data[index]
    edit_reason_box.value = saved.get("reason", "")

    # Right panel: the saved curated version.
    curated_smile_box.value = saved.get("smile", "")
    curated_explain_box.value = saved.get("explain", "")
    curated_cot_box.value = saved.get("cot", "")
    curated_score_box.value = saved.get("curation_score", 5)
    curated_reason_box.value = saved.get("reason", "")

    index_label.value = f"Current Record: {index + 1} / {len(data)}"
    jump_input.value = index + 1

    if not llm_toggle.value:
        llm_output_box.value = ""
    else:
        call_llm_and_display()

def save_current(_=None):
    """Save the current record.

    Curated side: smile, explain, cot, the right-side score and the reason are
    written into curated_data[current_index] and the matching line of
    curated_file if any of them changed. The index is then added to
    modified_indices and the read-only right panel is refreshed.

    Left score: always written as "curation_score" into the matching line of
    input_file, and mirrored into `data` so that navigating back shows it.

    The save is logged via log_action. When the LLM toggle is on, the text
    currently shown in llm_output_box is also appended to LLM_LOG_FILE; the LLM
    is not re-queried.
    """
    updated_right = {
        'smile': edit_smile_box.value,
        'explain': edit_explain_box.value,
        'cot': edit_cot_box.value,
        'curation_score': curated_score_box.value,
        'reason': edit_reason_box.value
    }
    updated_left_score = edit_score_box.value
    saved = curated_data[current_index]

    # 'reason' must be part of change detection, otherwise editing only the reason
    # is never persisted. A missing reason compares equal to an empty box.
    reason_changed = (saved.get('reason') or "") != updated_right['reason']
    content_changed = any(
        saved.get(key) != updated_right[key]
        for key in ('smile', 'explain', 'cot', 'curation_score')
    )
    if reason_changed or content_changed:
        saved.update(updated_right)
        modified_indices.add(current_index)

        # Rewrite only this record's line (assumes line i of curated_file == record i).
        with open(curated_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        lines[current_index] = json.dumps(saved, ensure_ascii=False) + '\n'
        with open(curated_file, 'w', encoding='utf-8') as f:
            f.writelines(lines)

        # Refresh the read-only right panel so it shows what was just saved.
        curated_smile_box.value = saved.get("smile", "")
        curated_explain_box.value = saved.get("explain", "")
        curated_cot_box.value = saved.get("cot", "")
        curated_score_box.value = saved.get("curation_score", 5)
        curated_reason_box.value = saved.get("reason", "")

    # Persist the left score into the source file (line i == record i).
    with open(input_file, 'r', encoding='utf-8') as f:
        input_lines = f.readlines()
    original_record = json.loads(input_lines[current_index])
    original_record['curation_score'] = updated_left_score
    input_lines[current_index] = json.dumps(original_record, ensure_ascii=False) + '\n'
    with open(input_file, 'w', encoding='utf-8') as f:
        f.writelines(input_lines)
    # Keep the in-memory copy in sync: load_record reads the left score from `data`.
    data[current_index]['curation_score'] = updated_left_score

    log_action("save", current_index, {
        "left_score": updated_left_score,
        "right_data": updated_right
    })

    # When the toggle is on, archive the LLM output currently on screen (no new query).
    if llm_toggle.value:
        out_text = llm_output_box.value or ""
        llm_log = {
            "timestamp": datetime.now().isoformat(),
            "index": current_index,
            "event": "llm_output_on_save",
            "output_len": len(out_text),
            "output": out_text
        }
        try:
            with open(LLM_LOG_FILE, 'a', encoding='utf-8') as f:
                f.write(json.dumps(llm_log, ensure_ascii=False) + '\n')
        except OSError as e:
            # Best effort: a failed LLM-log write must not abort the save, but say so.
            print(f"⚠️ Failed to write {LLM_LOG_FILE}: {e}")

    clear_output(wait=True)
    print(f"✅ 已保存第 {current_index + 1} 条记录（左分数+右内容）")
    display(ui)

def prev_record(_=None):
    """Go to the previous record (no-op on the first one), logging the index being left."""
    if current_index <= 0:
        return
    log_action("prev", current_index)
    load_record(current_index - 1)

def next_record(_=None):
    """Go to the next record (no-op on the last one), logging the index being left."""
    last_index = len(data) - 1
    if current_index >= last_index:
        return
    log_action("next", current_index)
    load_record(current_index + 1)

def jump_to_index(_=None):
    """Jump to the 1-based record number typed into jump_input.

    Out-of-range targets are ignored; otherwise the index being left is logged
    and the target record is loaded.
    """
    target = jump_input.value - 1
    if not 0 <= target < len(data):
        return
    log_action("jump", current_index)
    load_record(target)

def export_modified_records(_=None, export_file="modified_records.jsonl"):
    """Write the curated version of every record modified this session as JSON lines.

    Args:
        _: Ignored (the clicked Button when used as an on_click handler). It
            defaults to None, like the other handlers, so the function can also
            be called directly.
        export_file: Output path; overwritten on every export.
    """
    if not modified_indices:
        print("ℹ️ 没有修改记录可导出")
        return
    with open(export_file, 'w', encoding='utf-8') as f:
        for idx in sorted(modified_indices):
            f.write(json.dumps(curated_data[idx], ensure_ascii=False) + '\n')
    print(f"✅ 已导出 {len(modified_indices)} 条修改记录到 {export_file}")

def show_modified_records(_=None):
    """Clear the cell output, list the 1-based numbers of modified records, and redraw the UI.

    Args:
        _: Ignored (the clicked Button when used as an on_click handler). It
            defaults to None, like the other handlers, so direct calls work.
    """
    clear_output(wait=True)
    if not modified_indices:
        print("ℹ️ 暂无已修改记录")
    else:
        indices = sorted(modified_indices)
        print(f"📋 已修改 {len(indices)} 条记录：{[i+1 for i in indices]}")
    display(ui)

# Keyboard shortcuts (only supported in Jupyter notebook front ends that execute
# injected JavaScript).
def bind_shortcuts():
    """Install page-wide keyboard shortcuts for the curation UI.

    Buttons are located by their tooltip text (set as the button `title`):
      Ctrl+Right    -> "下一条" (next record)
      Ctrl+Left     -> "上一条" (previous record)
      Ctrl+S        -> "保存当前修改" (save; the default browser action is suppressed)
      Ctrl+Shift+F  -> "切换全屏 (Ctrl+Shift+F)" (fullscreen), if present on the page

    The handler is stored on `window`, and any previously installed one is
    removed first. Without this, re-running the cell stacks duplicate listeners,
    so one key press would advance or save several times. Buttons that are not
    on the page are skipped instead of raising a JavaScript error.
    """
    display(Javascript("""
        (function() {
            if (window.__curationKeyHandler) {
                document.removeEventListener('keydown', window.__curationKeyHandler);
            }
            const clickByTitle = function(title) {
                const btn = Array.from(document.querySelectorAll('button'))
                  .find(b => b.getAttribute('title') === title);
                if (btn) btn.click();
            };
            window.__curationKeyHandler = function(event) {
                if (event.ctrlKey && event.key === 'ArrowRight') {
                    clickByTitle('下一条');
                }
                if (event.ctrlKey && event.key === 'ArrowLeft') {
                    clickByTitle('上一条');
                }
                if (event.ctrlKey && event.key === 's') {
                    event.preventDefault();
                    clickByTitle('保存当前修改');
                }
                if (event.ctrlKey && event.shiftKey && event.key.toLowerCase() === 'f') {
                    clickByTitle('切换全屏 (Ctrl+Shift+F)');
                }
            };
            document.addEventListener('keydown', window.__curationKeyHandler);
        })();
    """))

# Tooltips become the buttons' `title`; the keyboard-shortcut JavaScript finds the
# buttons by these exact strings, so keep both in sync.
prev_button.tooltip = "上一条"        # "previous"
next_button.tooltip = "下一条"        # "next"
save_button.tooltip = "保存当前修改"  # "save current changes"

# Wire button clicks to their handlers.
# NOTE(review): export_button ("Export Records") is not bound here, so clicking it
# currently does nothing — confirm whether it should be wired up.
save_button.on_click(save_current)
prev_button.on_click(prev_record)
next_button.on_click(next_record)
jump_button.on_click(jump_to_index)
export_modified_button.on_click(export_modified_records)
show_modified_button.on_click(show_modified_records)

# React immediately when the LLM toggle flips: switching it on queries the LLM for
# the current record; switching it off clears the output box.
def _on_llm_toggle_change(change):
    """Observer for llm_toggle; only changes of the `value` trait are acted on."""
    if change['name'] != 'value':
        return
    if not change['new']:
        llm_output_box.value = ""
        return
    call_llm_and_display()

llm_toggle.observe(_on_llm_toggle_change)

# ===== UI layout (two columns) =====
# Left column: the record as loaded (SMILES/explanation read-only; CoT, score and
# reason editable).
left_boxes = VBox([
    edit_smile_box,
    edit_explain_box,
    edit_cot_box,
    edit_score_box,
    edit_reason_box
], layout=widgets.Layout(width='50%'))

# Right column: the saved curated version.
right_boxes = VBox([
    curated_smile_box,
    curated_explain_box,
    curated_cot_box,
    curated_score_box,
    curated_reason_box,
], layout=widgets.Layout(width='50%'))

# The LLM output spans the full width above the two columns.
top_llm_row = HBox([llm_output_box], layout=widgets.Layout(width='100%'))

# fullscreen_button is intentionally left out of the layout.
ui = VBox([
    top_llm_row,
    index_label,
    HBox([left_boxes, right_boxes]),
    HBox([prev_button, next_button, save_button, export_button]),
    HBox([jump_input, jump_button]),
    HBox([show_modified_button, export_modified_button, llm_toggle])
])

# Initial render: show the restored record (this queries the LLM if the toggle is
# on), install the keyboard shortcuts, then display the UI.
load_record(current_index)
bind_shortcuts()
display(ui)
