In [1]:
import json
import glob
import matplotlib.pyplot as plt
import numpy as np
import os
from plot_utils import gather_metrics
from parse_levels import find_levels_in_configs, find_levels_in_configs_glob
import pandas as pd
import matplotlib.ticker as ticker
import re



In [2]:
from typing import List

def load_from_jsonl(file_name: str) -> List[dict]:
    def load_json_line(line: str, i: int, file_name: str):
        try:
            return json.loads(line)
        except:
            raise ValueError(f"Error in line {i+1}\n{line} of {file_name}")

    with open(file_name, "r", encoding="UTF-8") as f:
        data = [load_json_line(line, i, file_name) for i, line in enumerate(f)]
    return data

def save_to_jsonl(data, filename, write_mode="w"):
    with open(filename, write_mode) as file:
        for item in data:
            json_str = json.dumps(item)
            file.write(json_str + "\n")


In [None]:
# generate LLM judge prompts
model_name = "o3_mini"
with open(
    # '/home/user123456/scientist/code_analysis_with_all_versions_knowledge.json', 
    f"/home/user123456/scientist/code_analysis_with_all_versions_knowledge_{model_name}.json",
    'r'
) as f:
    record_dict_w_level = json.load(f)
prompt_version = "v4"
prompt = '''
Below is a baseline implementation of a GPT-2 model, followed by two proposed changes (see code diffs below) to improve the training speed. 
The first change is from an expert human. The second change is from an AI Assistant, aiming to reproduce the improvement made by the expert human.
Inspect the code diffs carefully and provide an objective evaluation of the AI Assistant's solution in terms of its similarity with expert human's solution. To derive an objective evaluation, first enumerate all the key changes made by expert human which can affect training speed, and then analyze all the changes made by the AI Assistant one by one. Based on understanding of these code changes, derive a percentage score (between 0 and 1) to quantify what fraction of the key changes (which has impact on training speed) made by the expert were correctly implemented in the AI Assistant's solution. 

Return your final score in a JSON object, with the key "reproducibility_score".
# =============== Baseline Implementation ===========
{human_code}
# =============== Change made by Expert Human ===========
{next_human_code}
# =============== Change made by AI Assistant ===========
{agent_code}
'''
score_prompts = []
judge_name = "DeepSeek-R1-Distill-Llama-70B"
judge_name = "DeepSeek-R1"
for method, data in record_dict_w_level.items():
    #if method != "tree":
    #    continue
    for level, records in data.items():
        for target_record, run_data in records.items():
            # print(target_record)
            if "human_code" not in run_data.keys():
                continue
            next_human_code = run_data["human_diff"] 
            if next_human_code == "":
                continue
            human_code = run_data["human_code"]
             
            human_metrics =run_data["metrics"]
            next_human_metrics =run_data["next_metrics"]
            min_train_time = 100000000
            for k, v in run_data.items():
                if "v_" in k:
                    if v["metrics"]["val_loss"] is not None:
                        if v["metrics"]["val_loss"] > 0 and v["metrics"]["val_loss"] <= 3.28 and v["metrics"]["train_time"] < min_train_time and v["metrics"]["train_time"] > 100000:
                            min_train_time = v["metrics"]["train_time"]
                            agent_speedup = human_metrics["metrics"]["train_time"] - v["metrics"]["train_time"]
                            human_speedup = human_metrics["metrics"]["train_time"] - next_human_metrics["metrics"]["train_time"]
                            if agent_speedup > human_speedup :
                                print(f"====== {target_record}, level {level} ======")
                                print(f"human: {human_metrics}")
                                print(f"next human: {next_human_metrics}")
                                print(f"agent {k}: {v["metrics"]}")
                                
      
                            agent_code = v["version_diff"]
                            score_prompt = prompt.format(human_code=human_code, next_human_code=next_human_code, agent_code=agent_code)
                            annot_data = {
                                "method": method,
                                "level": level,
                                "record": target_record,
                                "human_metrics": human_metrics,
                                "next_human_metrics": next_human_metrics,
                                "version": k,
                                "metrics": v["metrics"],
                                "model": run_data["model"],
                            }
                            if judge_name == "DeepSeek-R1-Distill-Llama-70B":
                                annot_data["text"] =  f"<｜User｜>{score_prompt}<｜Assistant｜>"
                            else:
                                annot_data["score_prompt"] = score_prompt
                                
            score_prompts.append(annot_data)

save_to_jsonl(score_prompts, f"/home/user123457/scientist/llm_judge_prompts/all_records_metadata_best_{model_name}_prompts_{prompt_version}_{judge_name}.jsonl")



human: {'job_status': 'COMPLETED', 'metrics': {'n_steps': 6200, 'val_loss': 3.2772, 'train_time': 1301740}, 'hypothesis': 'Baseline run of GPT2 124M model on FineWeb 10B dataset with default hyperparameters.', 'outcome_summary': 'The model achieves a validation loss of 3.2772, reaching under the 3.28 target validation loss.'}
next human: {'job_status': 'COMPLETED', 'metrics': {'n_steps': 5100, 'val_loss': 3.2751, 'train_time': 949528}, 'hypothesis': 'Baseline run of GPT2 124M model on FineWeb 10B dataset with default hyperparameters.', 'outcome_summary': 'The model achieves a validation loss of 3.2751, reaching under the 3.28 target validation loss.'}
agent v_11: {'val_loss': 3.2789, 'train_time': 925150}
human: {'job_status': 'COMPLETED', 'metrics': {'n_steps': 6200, 'val_loss': 3.2772, 'train_time': 1301740}, 'hypothesis': 'Baseline run of GPT2 124M model on FineWeb 10B dataset with default hyperparameters.', 'outcome_summary': 'The model achieves a validation loss of 3.2772, reachin

In [None]:
# load annoated results
data = load_from_jsonl(
    #"/home/user123457/scientist/llm_judge_prompts/all_records_metadata_best_code_diff_annot_DeepSeek-R1-Distill-Llama-70B_v3.jsonl"
    "/home/user123457/scientist/llm_judge_prompts/all_records_metadata_best_o3_mini_code_diff_annot_DeepSeek-R1-Distill-Llama-70B_v4.jsonl"
)


def extract_json_from_string(text: str) -> dict | list | None:
    """
    Extracts a JSON object or array embedded within Markdown code fences
    (specifically ```json ... ```) from a string.

    Args:
        text: The input string potentially containing the JSON in code fences.

    Returns:
        The parsed JSON object (as a dict) or array (as a list),
        or None if no valid JSON block is found or parsing fails.
    """
    # Regex to find content within ```json ... ``` block
    # - ```json : Matches the start fence literally
    # - \s*    : Matches optional whitespace (including newline) after 'json'
    # - (.*?)  : Captures the content non-greedily (*) between fences.
    #            The '.' matches any character, '?' makes it non-greedy.
    # - \s*    : Matches optional whitespace before the end fence
    # - ```    : Matches the end fence literally
    # re.DOTALL flag makes '.' match newline characters as well.
    match = re.search(r'```json\s*(.*?)\s*```', text, re.DOTALL)

    if match:
        json_string = match.group(1).strip() # Extract captured group and strip leading/trailing whitespace
        try:
            # Attempt to parse the extracted string as JSON
            data = json.loads(json_string)
            return data
        except json.JSONDecodeError as e:
            print(f"Error decoding JSON: {e}")
            print(f"Extracted string was: '{json_string}'")
            return None
        except Exception as e:
            print(f"An unexpected error occurred during JSON parsing: {e}")
            return None
    else:
        # print("No ```json ... ``` block found in the string.")
        return None

for i, datum in enumerate(data):
    data[i]["record_number"] = datum["record"].split("_")[1]
    try:
        input_string = datum["vllm_output"]["output"].split("</think>")[1]
        extracted_data = extract_json_from_string(input_string)

        if extracted_data:
            # print("Successfully extracted JSON:")
            # print(extracted_data)
            data[i]["code_diff_judge"] = extracted_data
            # print(f"Correctness Score: {extracted_data.get('correctness_score')}")
    except Exception as e:
        continue


In [22]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import logging
import numpy as np
import math
import os # Import os for path joining

# --- Setup Logging ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


# --- Data Processing Function (no changes needed) ---
def preprocess_data(data_list: list) -> pd.DataFrame:
    """
    Converts list of dicts to DataFrame, extracts scores, keeps only the row
    with min 'train_time' per group ('model','method','level','record_number'),
    adds a combined model-level group, and prepares data for plotting.
    """
    records = []
    base_keys = ['model', 'method', 'level', 'record_number']
    # score_keys = ['reproducibility_score', 'correctness_score', 'efficiency_score']
    score_keys = ['reproducibility_score']

    for i, item in enumerate(data_list):
        if not isinstance(item, dict): continue
        if not all(key in item for key in base_keys): continue

        metrics_info = item.get('metrics', {})
        judge_info = item.get('code_diff_judge', {})

        train_time = metrics_info.get('train_time') if isinstance(metrics_info, dict) else None

        scores = {key: None for key in score_keys}
        if isinstance(judge_info, dict):
            for key in score_keys: scores[key] = judge_info.get(key)

        record_num = item.get('record_number')
        model = item.get('model', 'Unknown')
        level = item.get('level', 'Unknown')

        record_data = {
            'model': model, 'method': item.get('method'), 'level': level,
            'record_number_str': str(record_num), 'record_number': record_num,
            'train_time': train_time,
        }
        record_data.update(scores)
        records.append(record_data)

    if not records: return pd.DataFrame()
    df = pd.DataFrame(records)
    logger.info(f"Initial DataFrame created with {len(df)} rows.")

    numeric_cols = ['record_number', 'train_time'] + score_keys
    for col in numeric_cols: df[col] = pd.to_numeric(df[col], errors='coerce')

    grouping_cols = ['model', 'method', 'level', 'record_number']
    selection_key = 'train_time'
    required_numeric_cols = ['record_number', selection_key]

    initial_rows = len(df)
    df.dropna(subset=required_numeric_cols, inplace=True)
    rows_dropped_nan = initial_rows - len(df)
    if rows_dropped_nan > 0: logger.info(f"Dropped {rows_dropped_nan} rows due to missing required numeric values ({required_numeric_cols}).")
    if df.empty: return df

    df['record_number'] = df['record_number'].astype(int)
    df['model'] = df['model'].astype(str)
    df['method'] = df['method'].astype(str)
    df['level'] = df['level'].astype(str) # Ensure level is string for combination

    logger.info(f"Starting deduplication based on minimum '{selection_key}'...")
    initial_rows_before_dedup = len(df)
    df.sort_values(by=grouping_cols + [selection_key], ascending=True, inplace=True)
    df.drop_duplicates(subset=grouping_cols, keep='first', inplace=True)
    rows_dropped_dedup = initial_rows_before_dedup - len(df)
    logger.info(f"Deduplication complete. Kept {len(df)} rows (dropped {rows_dropped_dedup} duplicates).")

    if df.empty: return df

    # --- Create combined group for hue ---
    df['model_level_group'] = df['model'] + ' | L' + df['level']

    # Final sort for plotting consistency (method sort handled later)
    df.sort_values(by=['record_number', 'model_level_group'], inplace=True)
    logger.info(f"Preprocessing complete. {len(df)} valid rows remaining.")
    return df


# --- Plotting Function (MODIFIED for order, legend placement, saving) ---
def plot_method_grouped_bar_charts(df: pd.DataFrame,
                                   y_col: str = 'reproducibility_score',
                                   filename: str = "method_bar_charts.pdf"):
    """
    Generates grouped bar chart subplots in a specified method order,
    with a shared y-axis range, placing the legend in an empty subplot slot,
    and saving the figure to a file.

    Args:
        df: The preprocessed DataFrame with 'model_level_group' column.
        y_col: The name of the column containing the score to plot on the y-axis.
        filename: The name (including path if needed) for the output PDF file.
    """
    if df.empty:
        logger.error("Cannot plot: DataFrame is empty.")
        return
    required_cols = ['method', 'record_number', 'model_level_group', y_col]
    if not all(col in df.columns for col in required_cols):
        logger.error(f"Cannot plot: DataFrame missing required columns ({required_cols}). Found: {df.columns.tolist()}")
        return

    # --- Define Custom Method Order ---
    method_order = ['tree', 'forest', 'flat', 'ori_aide', 'multi_aide']
    available_methods_in_df = df['method'].unique()
    # Filter and order methods based on definition and availability
    methods_to_plot = [m for m in method_order if m in available_methods_in_df]
    n_methods = len(methods_to_plot)

    if n_methods == 0:
        logger.error(f"None of the desired methods ({method_order}) found in the DataFrame.")
        return
    logger.info(f"Plotting methods in order: {methods_to_plot}")

    # --- Calculate Global Y-Limits ---
    global_ymin, global_ymax = None, None
    valid_scores = df[y_col].dropna()
    if not valid_scores.empty:
        min_score, max_score = valid_scores.min(), valid_scores.max()
        data_range = max_score - min_score
        padding = (data_range * 0.05) if (pd.notna(data_range) and data_range > 1e-6) else 0.2
        global_ymin = (min_score - padding) if pd.notna(min_score) else None
        global_ymax = (max_score + padding) if pd.notna(max_score) else None
        if pd.notna(min_score) and min_score >= 0:
             negative_offset = max(padding * 0.1, 0.02 * (max_score if pd.notna(max_score) and max_score > 0 else 1))
             if pd.notna(global_ymin): global_ymin = min(global_ymin, -negative_offset)
             else: global_ymin = -negative_offset
        ymin_str = f"{global_ymin:.2f}" if pd.notna(global_ymin) else "auto"
        ymax_str = f"{global_ymax:.2f}" if pd.notna(global_ymax) else "auto"
        logger.info(f"Calculated global Y-limits for '{y_col}': ({ymin_str}, {ymax_str})")
    else:
        logger.warning(f"Could not calculate global Y-limits: No valid data found for '{y_col}'.")

    # --- Determine subplot layout (assuming 2 rows, 3 columns for 5 plots + legend) ---
    ncols = 3
    nrows = 2 # Fixed grid to accommodate 5 plots + legend slot
    n_total_slots = nrows * ncols

    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 6, nrows * 5), squeeze=False, sharey=False)
    axes = axes.flatten() # Flatten to 1D array

    fig.suptitle(f'{y_col.replace("_", " ").title()} vs. Record Number (Grouped by Method, Bars by Model/Level)', fontsize=16, y=0.98)

    last_plotted_ax_index = -1 # Track the index of the last axis actually used for a plot

    # --- Iterate through methods IN CUSTOM ORDER ---
    for i, method in enumerate(methods_to_plot):
        ax = axes[i]
        last_plotted_ax_index = i # Update the last plotted index
        logger.info(f"Generating subplot {i} for method: {method}")

        method_df = df[(df['method'] == method) & (df[y_col].notna())].copy()
        method_df.sort_values(by=['record_number', 'model_level_group'], inplace=True)

        if method_df.empty:
            logger.warning(f"No valid data found for method '{method}' and score '{y_col}'. Skipping subplot content.")
            ax.set_title(f'Method: {method}\n(No Data for {y_col})')
            if pd.notna(global_ymin) and pd.notna(global_ymax):
                 ax.set_ylim(bottom=global_ymin, top=global_ymax)
            ax.text(0.5, 0.5, f"No valid '{y_col}' data", ha='center', va='center', transform=ax.transAxes)
            ax.set_xticks([])
            ax.set_yticks([])
            continue # Skip plotting for this method

        # --- Create the grouped bar plot ---
        sns.barplot(data=method_df, x='record_number', y=y_col, hue='model_level_group', ax=ax, ci=None)

        # --- Apply GLOBAL Y-Limits ---
        if pd.notna(global_ymin) and pd.notna(global_ymax):
             ax.set_ylim(bottom=global_ymin, top=global_ymax)

        # --- Customize Axes ---
        ax.set_title(f'Method: {method}')
        ax.set_ylabel(y_col.replace("_", " ").title())
        ax.set_xlabel('Record Number')
        ax.tick_params(axis='x', rotation=45)
        ax.grid(axis='y', linestyle='--', alpha=0.7)
        ax.legend().set_visible(False) # Hide individual subplot legends

    # --- Legend Placement in Unused Slot ---
    legend_ax_index = n_methods # The index of the first unused slot
    if last_plotted_ax_index != -1 and legend_ax_index < n_total_slots:
        handles, labels = axes[last_plotted_ax_index].get_legend_handles_labels()
        if handles and labels:
            # Place legend centered within the bounding box of the unused axis
            legend_ax = axes[legend_ax_index]
            legend_ax.axis('off') # Turn off the axis frame and ticks for the legend slot
            fig.legend(handles, labels, title='Model | Level',
                       bbox_to_anchor=legend_ax.get_position(), # Anchor to the legend axis position
                       loc='center', # Center the legend within the anchor box
                       bbox_transform=fig.transFigure, # Use figure coordinates for anchor box
                       borderaxespad=0.)
            logger.info(f"Placed legend in subplot slot {legend_ax_index}")
        else:
             logger.warning(f"Could not generate legend: No handles/labels found on axis {last_plotted_ax_index}.")
             # Hide the legend axis anyway if it exists
             if legend_ax_index < n_total_slots:
                  axes[legend_ax_index].axis('off')
    elif last_plotted_ax_index == -1:
         logger.warning("Could not generate legend: No subplots contained data.")
         # Hide all potentially unused axes
         for j in range(n_total_slots): axes[j].axis('off')

    # --- Hide any other potentially unused axes beyond the legend slot ---
    for j in range(legend_ax_index + 1, n_total_slots):
         axes[j].axis('off')

    # --- Adjust Layout and Save ---
    # Use subplots_adjust for more control than tight_layout with manual legend
    plt.subplots_adjust(left=0.05, right=0.95, bottom=0.1, top=0.90, wspace=0.3, hspace=0.4) # Tune these values

    try:
        # Create directory if it doesn't exist
        output_dir = os.path.dirname(filename)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir)
            logger.info(f"Created output directory: {output_dir}")

        plt.savefig(filename, bbox_inches='tight', format='pdf')
        logger.info(f"Plot successfully saved to: {filename}")
    except Exception as e:
        logger.error(f"Error saving plot to {filename}: {e}", exc_info=True)
    finally:
        plt.close(fig) # Close the figure to free memory


# --- Main Execution ---
if __name__ == "__main__":
    logger.info("Starting data processing...")
    processed_df = preprocess_data(data)

    output_filename = "methods_reproducibility_bars.pdf" # Define output filename

    if not processed_df.empty:
        print("\n" + "="*40)
        print("DEBUG: DataFrame used for grouped bar plots:")
        print(processed_df.to_string())
        print("="*40 + "\n")

        plot_method_grouped_bar_charts(processed_df,
                                       y_col='reproducibility_score',
                                       filename=output_filename)
    else:
        logger.info("No plot generated as there was no valid data after preprocessing.")

INFO:__main__:Starting data processing...
INFO:__main__:Initial DataFrame created with 479 rows.
INFO:__main__:Starting deduplication based on minimum 'train_time'...
INFO:__main__:Deduplication complete. Kept 182 rows (dropped 297 duplicates).
INFO:__main__:Preprocessing complete. 182 valid rows remaining.
INFO:__main__:Plotting methods in order: ['tree', 'forest', 'flat', 'ori_aide', 'multi_aide']
INFO:__main__:Calculated global Y-limits for 'reproducibility_score': (-0.05, 1.05)
INFO:__main__:Generating subplot 0 for method: tree

The `ci` parameter is deprecated. Use `errorbar=None` for the same effect.

  sns.barplot(data=method_df, x='record_number', y=y_col, hue='model_level_group', ax=ax, ci=None)
INFO:matplotlib.category:Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.
INFO:matplotlib.category:Using categorical units to plot a list of str


DEBUG: DataFrame used for grouped bar plots:
       model      method level record_number_str  record_number  train_time  reproducibility_score model_level_group
452  o3-mini        flat    12                 1              1     2188124                  0.920     o3-mini | L12
101  o3-mini      forest    12                 1              1     2183229                  0.750     o3-mini | L12
302  o3-mini  multi_aide    12                 1              1     2205406                  0.950     o3-mini | L12
235  o3-mini    ori_aide    12                 1              1     2229536                  0.980     o3-mini | L12
9    o3-mini        tree    12                 1              1     2221258                  0.700     o3-mini | L12
467  o3-mini        flat   125                 1              1     2204435                  0.400    o3-mini | L125
163  o3-mini      forest   125                 1              1     2196640                  0.100    o3-mini | L125
365  o3-mini  mult

INFO:__main__:Generating subplot 2 for method: flat

The `ci` parameter is deprecated. Use `errorbar=None` for the same effect.

  sns.barplot(data=method_df, x='record_number', y=y_col, hue='model_level_group', ax=ax, ci=None)
INFO:matplotlib.category:Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.
INFO:matplotlib.category:Using categorical units to plot a list of strings that are all parsable as floats or dates. If these strings should be plotted as numbers, cast to the appropriate data type before plotting.
INFO:__main__:Generating subplot 3 for method: ori_aide

The `ci` parameter is deprecated. Use `errorbar=None` for the same effect.

  sns.barplot(data=method_df, x='record_number', y=y_col, hue='model_level_group', ax=ax, ci=None)
INFO:matplotlib.category:Using categorical units to plot a list of strings that are all parsable as floats or 