# In [None]:
import os
import pickle
import warnings

import matplotlib.pyplot as plt

def plot_single_reward_extreme_ratios(base_dir, steps=range(1, 61), title=None):
    """Plot, per step, the share of entries with an all-0 / all-1 mean reward.

    For each step ``s`` in ``steps`` this loads
    ``<base_dir>/saved_data/rollout_data_step_<s>.pkl`` and plots four lines:

    * fraction of entries whose mean reward is exactly 0,
    * fraction of entries whose mean reward is exactly 1,
    * fraction of entries whose mean reward is 0 or 1,
    * average of all individual rewards in the file (dashed gray line).

    The pickled object is expected to be a mapping whose values support
    ``.get("rewards")``. Only entries whose ``"rewards"`` value is a non-empty
    ``list`` are counted. A step whose file is missing, cannot be processed,
    or has no usable entries contributes NaN, which shows up as a gap.

    Args:
        base_dir: Run directory that contains the ``saved_data`` folder.
        steps: Iterable of step numbers to load. It is materialized once, so a
            generator may be passed.
        title: Plot title. Defaults to the last path component of
            ``base_dir`` (a trailing separator is ignored).

    Note:
        ``pickle.load`` can execute arbitrary code; only point this at rollout
        files you trust.
    """
    nan = float("nan")
    nan_stats = (nan, nan, nan, nan)

    # `steps` is iterated twice (loading and plotting). A generator would be
    # exhausted after the first pass and the plot calls would get no x values.
    steps = list(steps)

    def step_stats(file_path):
        """Return (zero_ratio, one_ratio, zero_or_one_ratio, avg_reward) for one file."""
        with open(file_path, "rb") as f:
            data = pickle.load(f)

        zero_count = one_count = total_count = 0
        reward_sum = 0
        reward_count = 0
        for entry in data.values():
            rewards = entry.get("rewards", [])
            # Skip anything that is not a non-empty list of rewards.
            if not isinstance(rewards, list) or not rewards:
                continue
            entry_sum = sum(rewards)
            mean_reward = entry_sum / len(rewards)
            total_count += 1
            reward_sum += entry_sum
            reward_count += len(rewards)
            # Exact comparison on the per-entry mean. Presumably rewards are
            # 0/1, so these mean "all samples failed" / "all succeeded"
            # (TODO confirm).
            if mean_reward == 0:
                zero_count += 1
            elif mean_reward == 1:
                one_count += 1

        if total_count > 0:
            ratios = (
                zero_count / total_count,
                one_count / total_count,
                (zero_count + one_count) / total_count,
            )
        else:
            ratios = (nan, nan, nan)
        avg_reward = reward_sum / reward_count if reward_count > 0 else nan
        return (*ratios, avg_reward)

    zero_ratios, one_ratios, both_ratios, avg_rewards = [], [], [], []
    for step in steps:
        file_path = os.path.join(base_dir, "saved_data", f"rollout_data_step_{step}.pkl")
        try:
            stats = step_stats(file_path)
        except FileNotFoundError:
            # Missing step files are tolerated silently and plotted as gaps.
            stats = nan_stats
        except Exception as err:
            # Stay best-effort (plot the remaining steps), but surface the
            # problem instead of silently turning it into NaN.
            warnings.warn(f"step {step}: could not process {file_path!r}: {err!r}")
            stats = nan_stats
        for column, value in zip((zero_ratios, one_ratios, both_ratios, avg_rewards), stats):
            column.append(value)

    # Default title: the run directory's own name. normpath drops a trailing
    # separator. basename(dirname(p)) would instead return the *parent*
    # directory's name whenever p has no trailing slash.
    if title is None:
        title = os.path.basename(os.path.normpath(base_dir))

    # Plot
    plt.figure(figsize=(8, 4))
    plt.plot(steps, zero_ratios, marker='o', label='Mean = 0')
    plt.plot(steps, one_ratios, marker='x', label='Mean = 1')
    plt.plot(steps, both_ratios, marker='^', label='Mean = 0 or 1')
    plt.plot(steps, avg_rewards, linestyle='--', color='gray', label='Avg reward')
    plt.title(title)
    plt.xlabel("Step")
    plt.ylabel("Proportion / Avg reward")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.show()


# Run directories to plot, in order. Uncomment an entry to include that run.
RUN_DIRS = (
    # "/home/yifan50/rl/teacher_evaluation_data/random_1_True_model_Qwen2.5-Math-7B_dataset_deepmath_8192_epoch_30_bs_512_lr_1e-6_beta_0_entropy_0_mu_2_tau_1e-3_alpha_0.5",
    # "/home/yifan50/rl/teacher_evaluation_data/random_False_model_Qwen2.5-Math-7B_dataset_deepmath_8192_epoch_30_bs_512_lr_1e-6_beta_0_entropy_0_mu_2_tau_1e-3_alpha_0.5",
    "/home/yifan50/rl/teacher_evaluation_data/random_True_model_Qwen2.5-Math-7B_dataset_orz_9728_epoch_30_bs_512_lr_1e-6_beta_0_entropy_0_mu_2_tau_1e-3_alpha_0.5",
    "/home/yifan50/rl/teacher_evaluation_data/random_False_model_Qwen2.5-Math-7B_dataset_orz_9728_epoch_30_bs_512_lr_1e-6_beta_0_entropy_0_mu_2_tau_1e-3_alpha_0.5",
)

for run_dir in RUN_DIRS:
    plot_single_reward_extreme_ratios(run_dir)