In [None]:
# !pip install -q torch torchvision
# !pip install -q torchmetrics[multimodal]
# !pip install -q pycocoevalcap
# !pip install -q Pillow

In [None]:
import warnings

import numpy as np
import torch
from PIL import Image
from pycocoevalcap.cider.cider import Cider
from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
from torchmetrics.multimodal import CLIPScore

warnings.filterwarnings("ignore", category=UserWarning, module='transformers.modeling_utils')

In [None]:
class ExplanationRewardScorer:
    """Score predicted explanations by mixing CIDEr and CLIPScore into one reward.

    For each sample::

        reward = alpha * CIDEr + (1 - alpha) * clip_normalized

    where ``clip_normalized`` linearly maps the raw CLIPScore from
    ``[clip_min, clip_max]`` onto ``[0, 1]`` and clamps the result to that range.

    Notes:
        * Captions are passed to CIDEr as-is: no tokenization, lowercasing or
          punctuation stripping is done in this class.
        * NOTE(review): CIDEr is not bounded to [0, 1], so ``alpha`` alone does not
          put the two terms on an equal footing -- confirm the scales before tuning it.
        * NOTE(review): CIDEr's document frequencies are presumably computed from the
          references passed in each call, so a sample's score may depend on the rest
          of the batch -- confirm against pycocoevalcap.
    """

    def __init__(
        self,
        alpha: float = 0.5,
        clip_model_name: str = "openai/clip-vit-base-patch16",
        clip_min: float = 15.0,
        clip_max: float = 35.0,
    ):
        """Load both scorers.

        Args:
            alpha: weight of the CIDEr term, in [0, 1]; the CLIP term gets ``1 - alpha``.
            clip_model_name: model id/path forwarded to ``CLIPScore(model_name_or_path=...)``.
            clip_min: raw CLIPScore mapped to 0 (anything lower is also 0).
            clip_max: raw CLIPScore mapped to 1 (anything higher is also 1).

        Raises:
            ValueError: if ``alpha`` is outside [0, 1] or ``clip_max`` is not greater
                than ``clip_min``.
        """
        if not 0.0 <= alpha <= 1.0:
            raise ValueError("Alpha must be in [0, 1]")
        # Written as `not >` so NaN bounds are rejected too.
        if not clip_max > clip_min:
            raise ValueError("clip_max must be greater than clip_min")

        self.alpha = alpha
        self.clip_min = clip_min
        self.clip_max = clip_max
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.cider_scorer = Cider()
        print("CIDEr scorer initialized.")

        print(f"Initializing CLIPScore on device: {self.device}")
        self.clip_metric = CLIPScore(model_name_or_path=clip_model_name).to(self.device)
        print("CLIPScore model loaded and ready.")

    def calculate_cider_batch(self, ground_truths: dict, predictions: dict) -> dict:
        """Compute one CIDEr score per prediction.

        Args:
            ground_truths: id -> list of reference captions. Must have the same keys
                as ``predictions``.
            predictions: id -> the single predicted caption for that id.

        Returns:
            dict: id -> CIDEr score as a Python float.
        """
        # Cider expects each hypothesis wrapped in a one-element list.
        res = {img_id: [caption] for img_id, caption in predictions.items()}
        _, per_sample_scores = self.cider_scorer.compute_score(ground_truths, res)

        # pycocoevalcap's Cider.compute_score iterates the *references* dict (gts),
        # whose key order need not match predictions', so pair the scores with
        # ground_truths' keys to keep every score on its own id.
        return {
            img_id: float(score)
            for img_id, score in zip(ground_truths.keys(), per_sample_scores)
        }

    def calculate_clip_batch(self, image_paths: dict, predictions: dict) -> dict:
        """Compute one raw CLIPScore per prediction against its image.

        Args:
            image_paths: id -> path of the image for that id; must contain every key
                of ``predictions``.
            predictions: id -> predicted caption.

        Returns:
            dict: id -> raw CLIPScore (float). Images that cannot be opened or decoded
            score 0.0 and a warning is printed.
        """
        clip_scores = {}
        for img_id, pred_caption in predictions.items():
            image_path = image_paths[img_id]
            try:
                # The context manager closes the file handle Image.open keeps open;
                # convert() returns a new, loaded image that stays usable afterwards.
                with Image.open(image_path) as img:
                    image = img.convert("RGB")
            except OSError as exc:
                # Covers missing files (FileNotFoundError) as well as unreadable or
                # corrupt images, which PIL reports with OSError subclasses.
                print(f"Warning: Could not load image at {image_path} ({exc}). Skipping.")
                clip_scores[img_id] = 0.0
                continue

            # (H, W, 3) uint8 array -> (1, 3, H, W) tensor: a batch of one image.
            image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).unsqueeze(0)

            # The metric accumulates state across update() calls, so reset after each
            # sample to get a per-sample score rather than a running aggregate.
            self.clip_metric.update(image_tensor.to(self.device), [pred_caption])
            clip_scores[img_id] = self.clip_metric.compute().item()
            self.clip_metric.reset()

        return clip_scores

    def _normalize_clip(self, clip_score_raw: float) -> float:
        """Linearly map a raw CLIPScore from [clip_min, clip_max] onto [0, 1], clamped at both ends."""
        scaled = (clip_score_raw - self.clip_min) / (self.clip_max - self.clip_min)
        return min(1.0, max(0.0, scaled))

    def explanation_rewards(self, ground_truths: list[list[str]], predictions: list[str], image_paths: list[str]) -> list[float]:
        """Compute the combined reward for each prediction.

        Args:
            ground_truths (list[list[str]]): reference captions per sample, e.g.
                ``[["caption 1a", ...], ["caption 2a", ...]]``.
            predictions (list[str]): one predicted caption per sample, e.g.
                ``["prediction 1", "prediction 2", ...]``.
            image_paths (list[str]): one image path per sample, e.g.
                ``["path/to/image1.jpg", "path/to/image2.jpg", ...]``.

        Returns:
            list[float]: one reward per sample, in input order. Empty input gives ``[]``.

        Raises:
            AssertionError: if the three lists differ in length.
        """
        # NOTE: a bare assert is skipped under `python -O`; kept so the raised
        # exception type stays the same for existing callers.
        assert len(ground_truths) == len(predictions) == len(image_paths), \
            "Input lists must have the same length."
        if not predictions:
            # Nothing to score; also avoids running CIDEr on an empty corpus.
            return []

        # Key every per-sample input by its position so the results can be re-aligned.
        gts_dict = dict(enumerate(ground_truths))
        preds_dict = dict(enumerate(predictions))
        paths_dict = dict(enumerate(image_paths))

        print("Calculating CIDEr scores...")
        cider_scores = self.calculate_cider_batch(gts_dict, preds_dict)

        print("Calculating CLIP scores...")
        clip_scores = self.calculate_clip_batch(paths_dict, preds_dict)

        print("Combining scores to generate final rewards...")
        final_rewards = []
        for i in range(len(predictions)):
            cider_score = float(cider_scores.get(i, 0.0))
            clip_score_raw = float(clip_scores.get(i, 0.0))
            clip_score_normalized = self._normalize_clip(clip_score_raw)

            reward = self.alpha * cider_score + (1.0 - self.alpha) * clip_score_normalized
            final_rewards.append(reward)

            print(f"  Index {i}: CIDEr={cider_score:.2f}, CLIP={clip_score_raw:.2f} -> Reward={reward:.4f}")

        return final_rewards

In [None]:
ground_truths_list = [
    ["một chiếc ô tô màu đỏ đậu trên đường phố", "siêu xe thể thao màu đỏ"],
    ["chú chó đang chơi trên bãi cỏ xanh", "một chú chó lông vàng chạy trong công viên"]
]


predictions_list = [
    "một chiếc xe hơi màu đỏ",
    "con chó vui vẻ trên đồng cỏ"
]

# Demo run: solid-colour placeholder images stand in for real photos, so the CLIP
# part of the reward only shows that the pipeline runs end to end.
try:
    Image.new('RGB', (224, 224), color='red').save("car_101.jpg")
    Image.new('RGB', (224, 224), color='green').save("dog_102.jpg")
    image_paths_list = [
        "car_101.jpg",
        "dog_102.jpg"
    ]

    reward_scorer = ExplanationRewardScorer(alpha=0.5)

    final_rewards = reward_scorer.explanation_rewards(
        ground_truths=ground_truths_list,
        predictions=predictions_list,
        image_paths=image_paths_list
    )

    print("\n--- Final Reward ---")
    print(final_rewards)

except Exception as e:
    # Keep "Run All" going past this demo, but print the full traceback: the
    # one-line message alone hides where the failure came from.
    import traceback

    print(f"\nError: {e}")
    traceback.print_exc()


In [7]:
import re

# Tag names scored by format_reward, each with a non-greedy, newline-spanning
# matcher for one well-formed <tag>...</tag> pair.
_FORMAT_TAGS = ("think", "answer", "explain")
_FORMAT_PAIR_PATTERNS = {
    tag: re.compile(rf"<{tag}>.*?</{tag}>", re.DOTALL) for tag in _FORMAT_TAGS
}


def _format_tag_score(content, tag):
    """Return the base credit minus the surplus-tag penalty for one tag name in ``content``."""
    n_pairs = len(_FORMAT_PAIR_PATTERNS[tag].findall(content))
    n_open = content.count(f"<{tag}>")
    n_close = content.count(f"</{tag}>")

    if n_pairs >= 1:
        base = 1.0
    elif n_open > 0 or n_close > 0:
        # No pair, but a lone opening or closing tag still earns half credit.
        # Both counts are tested the same way on purpose: `n_open or n_close == 1`
        # would parse as `n_open or (n_close == 1)` and score stray opening and
        # stray closing tags differently.
        base = 0.5
    else:
        base = 0.0

    # Every occurrence beyond one opening + one closing tag costs 0.5, so both
    # stray tags and repeated pairs (spam) are penalized.
    surplus = max(0, n_open + n_close - 2)
    return base - 0.5 * surplus


def format_reward(completions, **kwargs):
    """Score how well each completion uses the <think>, <answer> and <explain> tags.

    Each of the three tags is scored independently and the results are summed:
      * +1.0 if at least one well-formed ``<tag>...</tag>`` pair is present;
        further pairs earn nothing extra (anti-spam);
      * +0.5 if there is no pair but at least one lone ``<tag>`` or ``</tag>``;
      * -0.5 for every occurrence of the tag (opening or closing) beyond the first
        two, so repeated pairs and stray tags are penalized.

    Matching is exact and case-sensitive: ``<THINK>`` or ``<think id=1>`` are not
    recognized as opening tags.

    Args:
        completions: sequence where ``completion[0]["content"]`` is the text to score.
        **kwargs: accepted but unused.

    Returns:
        list[float]: one score per completion. The maximum is 3.0; there is no fixed
        minimum because penalties grow with the number of surplus tags.
    """
    contents = [completion[0]["content"] for completion in completions]
    return [
        float(sum(_format_tag_score(content, tag) for tag in _FORMAT_TAGS))
        for content in contents
    ]


# Mock completions: each element is [{"content": "..."}], matching the
# completion[0]["content"] access in format_reward.
completions = [
    # 1) All three pairs, once each -> full base (3), no surplus tags
    [ {"content": "<think>t</think>\n<answer>a</answer>\n<explain>e</explain>"} ],

    # 2) Missing explain -> base 2, no surplus tags
    [ {"content": "<think>t</think>\n<answer>a</answer>"} ],

    # 3) think appears as 2 pairs: the extra pair earns nothing, and its 2 surplus tags cost 0.5 each
    [ {"content": "<think>x</think><think>y</think>\n<answer>a</answer>\n<explain>e</explain>"} ],

    # 4) Missing </think>: the lone <think> earns half credit
    [ {"content": "<think>oops\n<answer>a</answer>\n<explain>e</explain>"} ],

    # 5) Missing <answer> (only </answer>): the lone closing tag earns half credit
    [ {"content": "<think>t</think>\n</answer>\n<explain>e</explain>"} ],

    # 6) Two <explain> and one </explain>: the non-greedy regex still finds one pair
    #    (first <explain> ... </explain>); the third tag is surplus
    [ {"content": "<think>t</think>\n<answer>a</answer>\n<explain>\n<explain>\n</explain>"} ],

    # 7) No tags at all -> 0
    [ {"content": "plain text only"} ],

    # 8) Nested think: the non-greedy match pairs the first <think> with the first </think>,
    #    so only 1 pair is found and the other 2 tags are surplus
    [ {"content": "<think> A <think> B </think> C </think>\n<answer>a</answer>\n<explain>e</explain>"} ],

    # 9) Crossed tags: a stray </answer> inside think; think and answer each still form one pair,
    #    and the extra </answer> is surplus
    [ {"content": "<think> t </answer> u </think>\n<answer>a</answer>\n<explain>e</explain>"} ],

    # 10) Upper-case tags are not matched (matching is case-sensitive)
    [ {"content": "<THINK>t</THINK>\n<ANSWER>a</ANSWER>\n<EXPLAIN>e</EXPLAIN>"} ],

    # 11) Opening tags with whitespace/attributes are not recognized, so each tag only has a
    #     lone closing tag (half credit each)
    [ {"content": "<think  >t</think>\n<answer id='1'>a</answer>\n<explain class=\"x\">e</explain>"} ],

    # 12) Noise plus an extra </answer> and an extra <think>: one surplus tag each for think and answer
    [ {"content": "prefix <think>t</think> mid <answer>a</answer> end </answer> tail <think> lone"} ],
]

# Expected scores under format_reward's rules; a mismatch means the reward definition changed.
expected_scores = [3.0, 2.0, 2.0, 2.5, 2.5, 2.5, 0.0, 2.0, 2.5, 0.0, 1.5, 1.0]

scores = format_reward(completions)
for i, (item, s) in enumerate(zip(completions, scores), 1):
    print(f"Case {i}: score = {s:.3f} | content = {item[0]['content']!r}")

assert scores == expected_scores, f"format_reward drifted: {scores} != {expected_scores}"


Case 1: score = 3.000 | content = '<think>t</think>\n<answer>a</answer>\n<explain>e</explain>'
Case 2: score = 2.000 | content = '<think>t</think>\n<answer>a</answer>'
Case 3: score = 2.000 | content = '<think>x</think><think>y</think>\n<answer>a</answer>\n<explain>e</explain>'
Case 4: score = 2.500 | content = '<think>oops\n<answer>a</answer>\n<explain>e</explain>'
Case 5: score = 2.500 | content = '<think>t</think>\n</answer>\n<explain>e</explain>'
Case 6: score = 2.500 | content = '<think>t</think>\n<answer>a</answer>\n<explain>\n<explain>\n</explain>'
Case 7: score = 0.000 | content = 'plain text only'
Case 8: score = 2.000 | content = '<think> A <think> B </think> C </think>\n<answer>a</answer>\n<explain>e</explain>'
Case 9: score = 2.500 | content = '<think> t </answer> u </think>\n<answer>a</answer>\n<explain>e</explain>'
Case 10: score = 0.000 | content = '<THINK>t</THINK>\n<ANSWER>a</ANSWER>\n<EXPLAIN>e</EXPLAIN>'
Case 11: score = 1.500 | content = '<think  >t</think>\n<answer