In [None]:
import os
import sys


# Make the project's `src/` directory importable from this notebook.
# A notebook has no `__file__`, so the directory is resolved against the current
# working directory (dirname of the literal string "__file__" is just "").
_src_dir = os.path.abspath("src")
if _src_dir not in sys.path:  # keep the cell idempotent when it is re-run
    sys.path.append(_src_dir)

In [None]:
import torch
import numpy as np


# ---- Experiment selection ----
# Alternative task that can be swapped in: "button-press-topdown-v2"
env_name = "box-close-v2"
exp_name = "AESPA-20-00"
pair_algo = "ternary-500"
reward_model_algo = "MR-dropout"

# NOTE(review): presumably the number of steps per trajectory segment — confirm with users of this constant.
TRAJECTORY_LENGTH = 25

# ---- Hardware ----
# Expose only GPU index 6 to this process; this is set before the availability
# check below so the restriction is already in place when CUDA is first queried.
os.environ.update(CUDA_VISIBLE_DEVICES="6")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
from data_generation.data_research import test4
from data_generation.picker.mr_dropout import mr_dropout_test


# Load the reward data for the selected task/experiment. The analysis functions in
# this notebook read its columns 0, 1, 2 as true reward, predicted mean and predicted std.
data = test4(env_name=env_name, exp_name=exp_name)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import pearsonr
import random

def analyze_prediction_uncertainty(data, env_name="env"):
    """Plot how reward-prediction quality varies with predicted uncertainty.

    Two figures are shown:

    1. A 2x2 grid of true-vs-predicted scatter plots, one per quartile of the
       predicted std, each annotated with the Pearson correlation and a sampled
       pairwise ranking accuracy.
    2. A 1x3 row: predicted reward vs. std, true vs. predicted reward coloured by
       std (on a random subsample of at most 2000 points), and true reward vs. std.

    Args:
        data: Array-like whose columns 0, 1 and 2 are read as true reward,
            predicted mean reward and predicted std (one row per sample).
        env_name: Label inserted into the figure titles.

    Returns:
        None. The quartile thresholds are printed and the figures are displayed.
    """
    # 1. Split the columns.
    data = np.array(data)  # shape [T, 3]
    true_rewards = data[:, 0].squeeze()
    mean_rewards = data[:, 1].squeeze()
    std_rewards = data[:, 2].squeeze()

    # 2. Partition the samples into quartiles of the predicted uncertainty.
    q1, q2, q3 = np.percentile(std_rewards, [25, 50, 75])
    print(f"Quantiles: Q1={q1:.3f}, Q2={q2:.3f}, Q3={q3:.3f}")

    groups = {
        "Q1 (lowest 25%)": np.where(std_rewards <= q1)[0],
        "Q2 (25-50%)": np.where((std_rewards > q1) & (std_rewards <= q2))[0],
        "Q3 (50-75%)": np.where((std_rewards > q2) & (std_rewards <= q3))[0],
        "Q4 (highest 25%)": np.where(std_rewards > q3)[0],
    }

    # 3. Sampled pairwise ranking accuracy.
    def compute_pairwise_accuracy(true_vals, pred_vals, pair_count=100000):
        """Fraction of random index pairs whose `>` ordering agrees in both arrays.

        Returns NaN when fewer than two samples are available, because no pair
        can be drawn in that case.
        """
        n = len(true_vals)
        if n < 2:
            return float("nan")

        correct = 0
        for _ in range(pair_count):
            i, j = random.sample(range(n), 2)
            if (true_vals[i] > true_vals[j]) == (pred_vals[i] > pred_vals[j]):
                correct += 1
        return correct / pair_count

    # 4. Per-quartile scatter plots.
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    for ax, (name, indices) in zip(axes.flat, groups.items()):
        tr = true_rewards[indices]
        pr = mean_rewards[indices]

        # Pearson correlation is undefined for an empty group or a constant
        # series. Ties in the std values can leave a quartile empty, so check
        # the size before touching tr[0].
        if len(indices) == 0 or np.all(tr == tr[0]) or np.all(pr == pr[0]):
            pcc = float('nan')
        else:
            pcc, _ = pearsonr(tr, pr)

        acc = compute_pairwise_accuracy(tr, pr)

        ax.scatter(tr, pr, alpha=0.1, label=f'PCC={pcc:.3f}, Acc={acc:.3f}')
        ax.set_xlabel("True Rewards")
        ax.set_ylabel("Predicted Rewards")
        ax.set_title(f"{name} (n={len(indices)})")
        ax.legend()
        ax.grid(True)

    fig.suptitle(f"Reward Prediction ({env_name}): Uncertainty Groups (Scatter, PCC, Pairwise Accuracy)", fontsize=16)
    fig.tight_layout(rect=[0, 0, 1, 0.96])
    plt.show()

    # 5-7. Combined row of three subplots.
    fig, axes = plt.subplots(1, 3, figsize=(21, 6))

    # Subplot 5: predicted reward vs. uncertainty.
    axes[0].scatter(mean_rewards, std_rewards, alpha=0.1)
    axes[0].set_xlabel("Predicted Rewards")
    axes[0].set_ylabel("Predicted Uncertainty (Std)")
    axes[0].set_title("Predicted vs. Uncertainty")
    axes[0].grid(True)

    # Draw a random subsample so the coloured scatter stays readable.
    sample_size = 2000
    total_samples = len(true_rewards)
    if sample_size >= total_samples:
        sample_indices = np.arange(total_samples)
    else:
        sample_indices = np.random.choice(total_samples, size=sample_size, replace=False)

    tr_sample = true_rewards[sample_indices]
    pr_sample = mean_rewards[sample_indices]
    std_sample = std_rewards[sample_indices]

    # Subplot 6: true vs. predicted reward, coloured by uncertainty.
    sc = axes[1].scatter(tr_sample, pr_sample, c=std_sample, cmap="viridis", alpha=0.6)
    axes[1].set_xlabel("True Rewards")
    axes[1].set_ylabel("Predicted Rewards")
    axes[1].set_title("True vs. Predicted (Color = Uncertainty)")
    axes[1].grid(True)
    fig.colorbar(sc, ax=axes[1], label="Predicted Uncertainty (Std)")

    # Subplot 7: true reward vs. uncertainty.
    axes[2].scatter(true_rewards, std_rewards, alpha=0.1)
    axes[2].set_xlabel("True Rewards")
    axes[2].set_ylabel("Predicted Uncertainty (Std)")
    axes[2].set_title("True vs. Uncertainty")
    axes[2].grid(True)

    fig.suptitle(f"Reward Model Prediction({env_name}): Uncertainty Relationships", fontsize=16)
    fig.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

    

In [None]:
# Render the per-sample uncertainty diagnostics for the loaded reward data.
analyze_prediction_uncertainty(data=data, env_name=env_name)

In [None]:
from data_loading.load_data import load_pair

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from random import sample

def generate_confident_traj_pairs(all_trajs, top_k=100000, num_buckets=20, z=10):
    """Build trajectory pairs whose predicted-reward intervals do not overlap.

    Trajectories are sorted by predicted mean and split into ``num_buckets``
    contiguous buckets. For every bucket pair ``(i, j)`` with ``i < j``, a pair
    ``(traj_i, traj_j)`` is a candidate when
    ``mean_i + z * std_i < mean_j - z * std_j``, i.e. ``traj_j`` is confidently
    better. Each bucket pair contributes at most ``top_k // n_bucket_pairs``
    candidates (at least one), drawn uniformly at random when there are more.

    Args:
        all_trajs: Sequence of trajectory records; index 2 is read as the
            predicted mean reward and index 3 as its std.
        top_k: Maximum number of pairs to return.
        num_buckets: Number of mean-sorted buckets to split the trajectories into.
        z: Width of the confidence interval in units of std.

    Returns:
        A list of ``(worse_traj, better_traj)`` tuples (the original records),
        at most ``top_k`` long. Empty when no pair can be formed.

    Note:
        One ``len(bucket_i) x len(bucket_j)`` boolean matrix is materialised per
        bucket pair, so memory grows with the square of the bucket size.
    """
    # Fewer than two buckets means there is no bucket pair to compare, which
    # would otherwise divide by zero when computing the per-pair quota.
    if top_k <= 0 or num_buckets < 2 or len(all_trajs) < 2:
        return []

    # Step 1: sort by predicted mean and precompute both interval bounds once.
    ordered = sorted(all_trajs, key=lambda traj: traj[2])
    means = np.array([traj[2] for traj in ordered], dtype=np.float64)
    stds = np.array([traj[3] for traj in ordered], dtype=np.float64)
    upper = means + z * stds
    lower = means - z * stds

    total = len(ordered)
    edges = [total * b // num_buckets for b in range(num_buckets + 1)]

    # Step 2: search for confident pairs across buckets.
    n_bucket_pairs = (num_buckets * (num_buckets - 1)) // 2
    # A quota of zero would silently discard every candidate when top_k is small.
    pairs_per_bucket = max(1, top_k // n_bucket_pairs)

    confident_pairs = []
    for i in range(num_buckets):
        lo_i, hi_i = edges[i], edges[i + 1]
        for j in range(i + 1, num_buckets):
            lo_j, hi_j = edges[j], edges[j + 1]
            width_j = hi_j - lo_j
            if hi_i == lo_i or width_j == 0:
                continue

            # mask[r, c] is True when trajectory c of bucket j is confidently
            # better than trajectory r of bucket i.
            mask = upper[lo_i:hi_i, None] < lower[None, lo_j:hi_j]
            hits = np.flatnonzero(mask)  # row-major: bucket-i index outer, bucket-j inner

            if len(hits) > pairs_per_bucket:
                hits = hits[sample(range(len(hits)), pairs_per_bucket)]

            for flat in hits.tolist():
                row, col = divmod(flat, width_j)
                confident_pairs.append((ordered[lo_i + row], ordered[lo_j + col]))

            if len(confident_pairs) >= top_k:
                return confident_pairs[:top_k]

    return confident_pairs[:top_k]


def plot_traj_scatter_by_uncertainty_bins(all_trajs, num_bins=6):
    """Scatter true vs. predicted trajectory reward, one panel per uncertainty bin.

    Bin edges are equally spaced percentiles of the trajectory std. Panels are
    laid out three per row (as many rows as needed) and share one colour scale
    and one colorbar placed in a narrow extra column.

    Args:
        all_trajs: Sequence of trajectory records; index 1 is read as the true
            reward, index 2 as the predicted reward and index 3 as the std.
        num_bins: Number of percentile bins (and panels).

    Raises:
        ValueError: If ``num_bins`` is smaller than 1 or ``all_trajs`` is empty.
    """
    if num_bins < 1:
        raise ValueError("num_bins must be at least 1")
    if len(all_trajs) == 0:
        raise ValueError("all_trajs is empty; nothing to plot")

    true_rewards = np.array([traj[1] for traj in all_trajs])
    predicted_rewards = np.array([traj[2] for traj in all_trajs])
    uncertainties = np.array([traj[3] for traj in all_trajs])

    bins = np.percentile(uncertainties, np.linspace(0, 100, num_bins + 1))
    vmin, vmax = uncertainties.min(), uncertainties.max()

    # Size the grid from num_bins so that more than six bins do not index past
    # the GridSpec; the last column is reserved for the colorbar.
    n_cols = 3
    n_rows = -(-num_bins // n_cols)  # ceiling division
    fig = plt.figure(figsize=(18, 4.5 * n_rows))
    gs = gridspec.GridSpec(n_rows, n_cols + 1, width_ratios=[1] * n_cols + [0.05])

    axes = [fig.add_subplot(gs[i // n_cols, i % n_cols]) for i in range(num_bins)]

    for i, ax in enumerate(axes):
        bin_lower = bins[i]
        bin_upper = bins[i + 1]

        # Both edges are inclusive, so a value exactly on an interior edge is
        # drawn in both neighbouring panels.
        in_bin = (uncertainties >= bin_lower) & (uncertainties <= bin_upper)

        sc = ax.scatter(
            true_rewards[in_bin],
            predicted_rewards[in_bin],
            c=uncertainties[in_bin],
            cmap='viridis',
            alpha=0.4,
            vmin=vmin,
            vmax=vmax,
        )

        ax.set_title(f"Bin {i+1}: [{bin_lower:.3f}, {bin_upper:.3f}]")
        ax.set_xlabel("True Reward")
        ax.set_ylabel("Predicted Reward")
        ax.grid(True)

    # The colorbar gets its own axis spanning every row of the last column.
    cbar_ax = fig.add_subplot(gs[:, n_cols])
    cbar = fig.colorbar(sc, cax=cbar_ax)
    cbar.set_label("Uncertainty (std)")

    fig.suptitle("Trajectory-wise Reward Scatter by Uncertainty Bins", fontsize=16)
    plt.tight_layout(rect=[0, 0, 0.96, 0.95])
    plt.show()

def _segment_sum(cumulative, start, end):
    """Return the sum of the per-step values over ``[start, end)`` from their running total."""
    preceding = cumulative[start - 1] if start > 0 else 0
    return cumulative[end - 1] - preceding


def _segment_stats(true_cum, mean_cum, var_cum, segment):
    """Return ``(true_sum, predicted_sum, predicted_std)`` for one ``(start, end)`` segment."""
    start, end = segment
    # A difference of two large running totals can come out slightly negative
    # when the segment variance is near zero; clamp it so the sqrt is never NaN.
    variance = max(_segment_sum(var_cum, start, end), 0.0)
    return (
        _segment_sum(true_cum, start, end),
        _segment_sum(mean_cum, start, end),
        np.sqrt(variance),
    )


def _rank_agrees(true_a, true_b, pred_a, pred_b):
    """True when the predicted ordering of a pair matches its true ordering (ties must match exactly)."""
    return (
        (true_a > true_b and pred_a > pred_b)
        or (true_a < true_b and pred_a < pred_b)
        or (true_a == true_b and pred_a == pred_b)
    )


def _report_accuracy(label, comparisons, empty_message):
    """Print ranking accuracy over ``(true_a, true_b, pred_a, pred_b)`` tuples, or ``empty_message`` if there are none."""
    comparisons = list(comparisons)
    total = len(comparisons)
    if total == 0:
        print(empty_message)
        return
    correct = sum(1 for item in comparisons if _rank_agrees(*item))
    print(f"[{label}] {correct}/{total} = {correct/total:.3f}")


def _scatter_panel(ax, xs, ys, color, xlabel, ylabel, title):
    """Draw one labelled scatter panel with a grid."""
    ax.scatter(xs, ys, alpha=0.4, c=color)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True)


def eval_test(data, env_name):
    """Evaluate trajectory-level ranking accuracy of the reward model on test pairs.

    Test pairs are loaded with ``load_pair`` (pair_algo ``"ternary-100000"``) using
    the notebook-level ``exp_name``. For every pair the true reward, predicted
    reward and predicted std of both segments are obtained from prefix sums
    (std is the square root of the summed per-step variances). The function
    then prints:

    * counts and mean std of segments from correctly / incorrectly ranked pairs,
    * ranking accuracy among the 100 lowest-std segments with non-negative true reward,
    * accuracy on pairs whose predicted gap exceeds 10x the summed std,
    * accuracy on pairs whose Bradley-Terry probability is <= 0.001 or >= 0.999,
    * accuracy on pairs produced by ``generate_confident_traj_pairs`` from the
      first 20000 segments,

    and shows the uncertainty-bin scatter grid, a mean-vs-std scatter and a 2x2
    comparison of the BT-selected and generated pairs.

    Args:
        data: Array-like whose columns 0, 1 and 2 are read as per-step true
            reward, predicted mean reward and predicted std.
        env_name: Environment name forwarded to ``load_pair``.

    Returns:
        None.
    """
    import random
    from itertools import combinations

    z_threshold = 10  # required predicted gap, in units of the summed segment std
    max_trajs_for_pairing = 20000  # cap on segments handed to the pair generator

    test_feedbacks = load_pair(
        env_name=env_name,
        exp_name=exp_name,  # read from the notebook's globals, not a parameter
        pair_type="test",
        pair_algo="ternary-100000",
    )

    data = np.array(data)  # shape [T, 3]
    true_rewards = data[:, 0].squeeze()
    mean_rewards = data[:, 1].squeeze()
    std_rewards = data[:, 2].squeeze()

    # Running totals make every segment sum an O(1) lookup.
    true_rewards_cum = np.cumsum(true_rewards, dtype=np.float64)
    mean_rewards_cum = np.cumsum(mean_rewards, dtype=np.float64)
    var_rewards_cum = np.cumsum(std_rewards ** 2, dtype=np.float64)

    def sigmoid(x):
        # exp overflows to inf for very negative x, which still gives the right
        # limit (0.0); only the RuntimeWarning is silenced.
        with np.errstate(over="ignore"):
            return 1 / (1 + np.exp(-x))

    correct_uncertainties = []
    incorrect_uncertainties = []
    all_trajs = []
    confident_feedbacks = []
    bt_confident_feedbacks = []

    # Single pass over the test pairs: every statistic below reuses these sums.
    for idx0, idx1, _ in test_feedbacks:
        true0, pred0, std0 = _segment_stats(true_rewards_cum, mean_rewards_cum, var_rewards_cum, idx0)
        true1, pred1, std1 = _segment_stats(true_rewards_cum, mean_rewards_cum, var_rewards_cum, idx1)

        if (true0 < true1) != (pred0 < pred1):
            incorrect_uncertainties.extend((std0, std1))
        else:
            correct_uncertainties.extend((std0, std1))

        all_trajs.append((tuple(idx0), true0, pred0, std0))
        all_trajs.append((tuple(idx1), true1, pred1, std1))

        # Keep only pairs whose predicted gap exceeds z_threshold times the summed std.
        if abs(pred0 - pred1) > z_threshold * (std0 + std1):
            confident_feedbacks.append((true0, true1, pred0, pred1))

        # Bradley-Terry baseline: probability that segment 0 beats segment 1.
        bt_prob = sigmoid(pred0 - pred1)
        if bt_prob <= 0.001 or bt_prob >= 0.999:
            bt_confident_feedbacks.append((true0, true1, pred0, pred1, bt_prob))

    def mean_or_nan(values):
        return np.mean(values) if len(values) else float("nan")

    print(f"Correct Uncertainties: {len(correct_uncertainties)}")
    print(f"Incorrect Uncertainties: {len(incorrect_uncertainties)}")
    print(f"Mean Correct Uncertainties: {mean_or_nan(correct_uncertainties):.3f}")
    print(f"Mean Incorrect Uncertainties: {mean_or_nan(incorrect_uncertainties):.3f}")

    # Among segments with non-negative true reward, take the 100 most certain
    # (lowest std) and score every pair of them.
    filtered_trajs = [traj for traj in all_trajs if traj[1] >= 0]
    top_100_trajs = sorted(filtered_trajs, key=lambda traj: traj[3])[:100]
    _report_accuracy(
        "Top-100 Traj Pair Accuracy",
        ((a[1], b[1], a[2], b[2]) for a, b in combinations(top_100_trajs, 2)),
        "Fewer than two trajectories with non-negative true reward; skipping Top-100 accuracy.",
    )

    _report_accuracy(
        "Confident Pair Accuracy",
        confident_feedbacks,
        "No confident pairs found with the given threshold.",
    )

    _report_accuracy(
        "BT Prob (0.001 or 0.999) Accuracy",
        (record[:4] for record in bt_confident_feedbacks),
        "No highly confident BT pairs found.",
    )

    # Augmentation: generate confident pairs across mean-sorted buckets.
    confident_pairs = generate_confident_traj_pairs(all_trajs[:max_trajs_for_pairing], top_k=100000)
    _report_accuracy(
        "Generated Confident Pair Accuracy",
        ((low[1], high[1], low[2], high[2]) for low, high in confident_pairs),
        "No confident trajectory pairs generated.",
    )

    plot_traj_scatter_by_uncertainty_bins(all_trajs)

    # Mean predicted reward vs. std, one point per segment.
    plt.figure(figsize=(7, 6))
    plt.scatter(
        [traj[2] for traj in all_trajs],
        [traj[3] for traj in all_trajs],
        alpha=0.4,
        c='blue',
    )
    plt.xlabel("Mean Predicted Reward")
    plt.ylabel("Uncertainty (std)")
    plt.title("Mean Reward vs. Uncertainty (Trajectory-wise)")
    plt.grid(True)
    plt.tight_layout()
    plt.show()

    # ----- BT-selected pairs -----
    bt_true_0 = [record[0] for record in bt_confident_feedbacks]
    bt_true_1 = [record[1] for record in bt_confident_feedbacks]
    bt_pred_0 = [record[2] for record in bt_confident_feedbacks]
    bt_pred_1 = [record[3] for record in bt_confident_feedbacks]

    # ----- Generated pairs, with the axis assignment randomised -----
    # Generated pairs are always (worse, better); a coin flip decides which member
    # goes on the x axis so the scatter is not confined to one side of the diagonal.
    pred_i, pred_j, true_i, true_j = [], [], [], []
    for low, high in confident_pairs:
        x_traj, y_traj = (low, high) if random.random() < 0.5 else (high, low)
        pred_i.append(x_traj[2])
        pred_j.append(y_traj[2])
        true_i.append(x_traj[1])
        true_j.append(y_traj[1])

    # ----- Visualisation (2x2 subplots) -----
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    _scatter_panel(axes[0, 0], bt_pred_0, bt_pred_1, 'green',
                   "Predicted Reward 0", "Predicted Reward 1", "BT Prob - Predicted Reward")
    _scatter_panel(axes[0, 1], pred_i, pred_j, 'blue',
                   "Predicted Reward (i)", "Predicted Reward (j)", "Confident Pair - Predicted Reward")
    _scatter_panel(axes[1, 0], bt_true_0, bt_true_1, 'orange',
                   "True Reward 0", "True Reward 1", "BT Prob - True Reward")
    _scatter_panel(axes[1, 1], true_i, true_j, 'purple',
                   "True Reward (i)", "True Reward (j)", "Confident Pair - True Reward")

    plt.tight_layout()
    plt.show()


In [None]:
# Run the trajectory-level evaluation on the test preference pairs.
eval_test(data=data, env_name=env_name)