In [1]:
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from sklearn.metrics import accuracy_score, classification_report
from tqdm import tqdm
import os
import sys

# ==========================================
# 1. Configuration parameters
# ==========================================
# Location of the local JSON dataset and the Hugging Face checkpoint to probe.
JSON_FILE = "/kaggle/input/knowledge/mmlu.json"
MODEL_NAME = "mistralai/Mistral-7B-v0.1"

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

# Probe training hyper-parameters.
BATCH_SIZE = 16
EPOCHS = 50
LEARNING_RATE = 1e-3

# ==========================================
# 2. Define the MLP probe (Improved Probe)
# ==========================================
class ImprovedProbe(nn.Module):
    """MLP probe that maps a feature vector to a single value in [0, 1].

    The network is two hidden stages of
    ``Linear -> BatchNorm1d -> ReLU -> Dropout(0.5)`` with widths
    ``hidden_dim`` and ``hidden_dim // 2``, followed by a one-unit ``Linear``
    layer and a ``Sigmoid``.

    Args:
        input_dim: Size of each input feature vector.
        hidden_dim: Width of the first hidden stage (the second is half of it).
    """

    def __init__(self, input_dim, hidden_dim=512):
        super().__init__()
        half_dim = hidden_dim // 2

        # Layers are created in the same order as they are applied so that
        # the Sequential indices (and hence state_dict keys) stay stable.
        stages = [
            # Stage 1: input_dim -> hidden_dim
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            # Stage 2: hidden_dim -> hidden_dim // 2
            nn.Linear(hidden_dim, half_dim),
            nn.BatchNorm1d(half_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            # Output head: one sigmoid-activated unit
            nn.Linear(half_dim, 1),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the stacked layers and return the sigmoid output."""
        return self.net(x)

# ==========================================
# 3. Helper functions
# ==========================================
def format_json_prompt(item):
    """Build the multiple-choice prompt text for one question record.

    ``item['choices']`` is expected to be a dictionary keyed by option letter;
    a missing letter is rendered with empty text.

    Args:
        item: Mapping with a ``'question'`` string and a ``'choices'`` dict.

    Returns:
        The prompt: a ``Question:`` line, one line per option A-D, and a
        trailing ``Answer:`` with no newline after it.
    """
    header = f"Question: {item['question']}\n"
    option_block = "".join(
        f"{letter}. {item['choices'].get(letter, '')}\n"
        for letter in ('A', 'B', 'C', 'D')
    )
    return header + option_block + "Answer:"

def parse_model_output(output_text):
    """Strictly parse a model output into an option letter.

    Surrounding whitespace is stripped and the text is upper-cased, so inputs
    such as ``" A"``, ``"a "`` or ``"  A  "`` are accepted.

    Args:
        output_text: Raw decoded model output.

    Returns:
        ``'A'``, ``'B'``, ``'C'`` or ``'D'`` when the cleaned text is exactly
        one of those letters; ``None`` when parsing fails.
    """
    candidate = output_text.strip().upper()
    return candidate if candidate in ('A', 'B', 'C', 'D') else None

# ==========================================
# 4. Core function: extract features and labels (with filtering)
# ==========================================
def extract_features_from_data(data_list, model, tokenizer, desc="Processing"):
    """Collect probe features and correctness labels for a list of questions.

    For every item the prompt is built with ``format_json_prompt`` and a single
    greedy token is generated. Items whose generated token does not parse to
    A/B/C/D are dropped. For the remaining items the feature is the last-token
    hidden state of the final three hidden-state entries, concatenated, and
    the label is 1.0 when the parsed answer equals ``item['answer']``
    (stripped, upper-cased) and 0.0 otherwise.

    The hidden states are taken from the same ``generate`` call that produces
    the answer token, so the LLM runs once per item instead of twice.

    Args:
        data_list: List of dicts with ``'question'``, ``'choices'`` and
            ``'answer'`` entries.
        model: Causal LM exposing ``generate``.
        tokenizer: Tokenizer matching ``model``.
        desc: Name of the split, used only in the printed report.

    Returns:
        ``(X, y)`` where ``X`` is the concatenation of the per-item features
        along dim 0 and ``y`` is a tensor of the labels, or ``(None, None)``
        when no item produced a valid answer.
    """
    features = []
    labels = []

    total_count = len(data_list)
    skipped_count = 0

    print(f"Ê≠£Âú®Âæû {desc} ÊèêÂèñÁâπÂæµ (ÂÖ± {total_count} Á≠Ü)...")

    for item in tqdm(data_list):
        prompt = format_json_prompt(item)
        ground_truth = item['answer'].strip().upper()

        inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)

        with torch.no_grad():
            # One greedy decoding step that also returns the hidden states of
            # the prompt forward pass (previously a separate model(**inputs)
            # call computed exactly the same activations).
            generation = model.generate(
                **inputs,
                max_new_tokens=1,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
                output_hidden_states=True,
                return_dict_in_generate=True,
            )

        new_token_id = generation.sequences[0][-1]
        raw_output = tokenizer.decode(new_token_id)

        # Parse and filter: anything that is not A/B/C/D is dropped before
        # any feature tensors are copied off the device.
        pred_text = parse_model_output(raw_output)
        if pred_text is None:
            skipped_count += 1
            continue

        # hidden_states[0] belongs to the first (and only) generated token,
        # i.e. the forward pass over the prompt: one tensor per layer.
        prompt_hidden_states = generation.hidden_states[0]
        layers_to_cat = [
            prompt_hidden_states[-i][:, -1, :].cpu().to(torch.float32)
            for i in (1, 2, 3)
        ]
        # Shape [1, 3 * hidden_size] (12288 per the original note).
        combined_feature = torch.cat(layers_to_cat, dim=1)

        # Label: 1 = answered correctly (Known), 0 = answered wrongly (Unknown).
        label = 1.0 if pred_text == ground_truth else 0.0

        features.append(combined_feature)
        labels.append(label)

    # Filtering report
    valid_count = len(labels)
    print("\n" + "-"*40)
    print(f"üìä [{desc}] Ë≥áÊñôËôïÁêÜÂ†±Âëä:")
    print(f"   - ÂéüÂßãÁ∏ΩÊï∏: {total_count}")
    print(f"   - ‚ùå Âà™Èô§ (Invalid): {skipped_count} È°å")
    print(f"   - ‚úÖ ‰øùÁïô (Valid):   {valid_count} È°å")
    print("-" * 40 + "\n")

    if valid_count == 0:
        print("‚ö†Ô∏è Ë≠¶ÂëäÔºöÊ≤íÊúâ‰ªª‰ΩïÊúâÊïàÊï∏ÊìöÔºÅË´ãÊ™¢Êü•Ê®°ÂûãËº∏Âá∫ÊòØÂê¶Ê≠£Â∏∏„ÄÇ")
        return None, None

    # Merge the per-item tensors into one feature matrix and one label vector.
    X = torch.cat(features, dim=0)
    y = torch.tensor(labels)
    return X, y

# ==========================================
# 5. Main program flow
# ==========================================
# Script entry point: load the dataset and the LLM, extract probe features for
# the 'dev' (training) and 'test' splits, train the MLP probe, and print a
# classification report. X_test, y_test and final_probs are left at module
# level because the next notebook cell reads them.
if __name__ == "__main__":
    print(f"Using device: {DEVICE}")

    # --- A. Load the JSON file ---
    if not os.path.exists(JSON_FILE):
        print(f"ÈåØË™§ÔºöÊâæ‰∏çÂà∞ {JSON_FILE}ÔºÅË´ãÁ¢∫Ë™çÊ™îÊ°àÂú®Âêå‰∏ÄÁõÆÈåÑ‰∏ã„ÄÇ")
        sys.exit(1)  # missing file: stop with a failure status

    with open(JSON_FILE, 'r', encoding='utf-8') as f:
        raw_data = json.load(f)

    if 'dev' not in raw_data or 'test' not in raw_data:
        print("ÈåØË™§ÔºöJSON Ê™îÊ°à‰∏≠Áº∫Â∞ë 'dev' Êàñ 'test' Ê¨Ñ‰Ωç„ÄÇ")
        sys.exit(1)

    dev_data = raw_data['dev']
    test_data = raw_data['test']

    # --- B. Load the LLM ---
    print(f"Loading LLM: {MODEL_NAME}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    if tokenizer.pad_token is None:
        # No pad token defined: reuse EOS so generate() has a pad id.
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float16,
        device_map="auto"
    )
    model.eval()

    # --- C. Feature extraction (data collection) ---
    # 1. Dev set (used to train the probe)
    X_train, y_train = extract_features_from_data(dev_data, model, tokenizer, desc="Dev Set")

    # 2. Test set
    # test_data = test_data[:500]  # uncomment to evaluate on a subset
    X_test, y_test = extract_features_from_data(test_data, model, tokenizer, desc="Test Set")

    # Continue only when there is data. At least two training samples are
    # required because BatchNorm1d cannot train on a single-sample batch.
    if X_train is None or X_test is None or X_train.shape[0] < 2:
        print("Êï∏Êìö‰∏çË∂≥ÔºåÁ®ãÂºèÁµÇÊ≠¢„ÄÇ")
        sys.exit(1)

    # --- D. Train the probe ---
    print("\n" + "="*30)
    print("ÈñãÂßãË®ìÁ∑¥Êé¢Ê∏¨Âô® (Training Probe)...")
    print("="*30)

    # Prepare the DataLoader
    train_dataset = TensorDataset(X_train.to(DEVICE), y_train.unsqueeze(1).to(DEVICE))
    # The number of training samples depends on how many items survived
    # filtering. If it leaves a trailing batch of exactly one sample,
    # BatchNorm1d raises in training mode, so drop that batch only then.
    drop_singleton_tail = len(train_dataset) % BATCH_SIZE == 1
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=drop_singleton_tail,
    )

    # Initialise the probe
    input_dim = X_train.shape[1]
    probe = ImprovedProbe(input_dim).to(DEVICE)
    optimizer = optim.Adam(probe.parameters(), lr=LEARNING_RATE)
    criterion = nn.BCELoss()

    # Move the evaluation features to the device once instead of on every
    # evaluation round.
    X_test_device = X_test.to(DEVICE)

    # Training loop
    # NOTE: best_val_acc is measured on the test split itself, so it is an
    # optimistic figure rather than a held-out validation score.
    best_val_acc = 0.0

    for epoch in range(EPOCHS):
        probe.train()
        epoch_loss = 0
        for batch_x, batch_y in train_loader:
            optimizer.zero_grad()
            outputs = probe(batch_x)
            loss = criterion(outputs, batch_y)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        # Every 5th epoch, run a quick accuracy check on the test set.
        if (epoch + 1) % 5 == 0:
            probe.eval()
            with torch.no_grad():
                test_out = probe(X_test_device)
                test_pred = (test_out > 0.5).float()
                acc = accuracy_score(y_test.cpu(), test_pred.cpu())

                if acc > best_val_acc:
                    best_val_acc = acc

                print(f"Epoch {epoch+1:02d} | Loss: {epoch_loss/len(train_loader):.4f} | Test Acc: {acc:.4f}")

    # --- E. Final evaluation ---
    print("\n" + "="*30)
    print("ÊúÄÁµÇË©ï‰º∞Â†±Âëä (Final Report)")
    print("="*30)

    probe.eval()
    with torch.no_grad():
        final_probs = probe(X_test_device).cpu().numpy()
        final_preds = (final_probs > 0.5).astype(int)

    print(classification_report(y_test, final_preds, target_names=["Unknown (0)", "Known (1)"]))
    print(f"Best Test Accuracy during training: {best_val_acc:.4f}")

Using device: cuda
Loading LLM: mistralai/Mistral-7B-v0.1...


tokenizer_config.json:   0%|          | 0.00/996 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

2025-12-09 17:37:02.509533: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1765301822.751535      47 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1765301822.814922      47 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/4.54G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.94G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

Ê≠£Âú®Âæû Dev Set ÊèêÂèñÁâπÂæµ (ÂÖ± 1000 Á≠Ü)...


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1000/1000 [03:49<00:00,  4.36it/s]



----------------------------------------
üìä [Dev Set] Ë≥áÊñôËôïÁêÜÂ†±Âëä:
   - ÂéüÂßãÁ∏ΩÊï∏: 1000
   - ‚ùå Âà™Èô§ (Invalid): 104 È°å
   - ‚úÖ ‰øùÁïô (Valid):   896 È°å
----------------------------------------

Ê≠£Âú®Âæû Test Set ÊèêÂèñÁâπÂæµ (ÂÖ± 1000 Á≠Ü)...


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1000/1000 [03:59<00:00,  4.18it/s]



----------------------------------------
üìä [Test Set] Ë≥áÊñôËôïÁêÜÂ†±Âëä:
   - ÂéüÂßãÁ∏ΩÊï∏: 1000
   - ‚ùå Âà™Èô§ (Invalid): 84 È°å
   - ‚úÖ ‰øùÁïô (Valid):   916 È°å
----------------------------------------


ÈñãÂßãË®ìÁ∑¥Êé¢Ê∏¨Âô® (Training Probe)...
Epoch 05 | Loss: 0.2789 | Test Acc: 0.6976
Epoch 10 | Loss: 0.1730 | Test Acc: 0.6583
Epoch 15 | Loss: 0.1009 | Test Acc: 0.6878
Epoch 20 | Loss: 0.1167 | Test Acc: 0.7074
Epoch 25 | Loss: 0.0959 | Test Acc: 0.6921
Epoch 30 | Loss: 0.0822 | Test Acc: 0.6758
Epoch 35 | Loss: 0.0487 | Test Acc: 0.6921
Epoch 40 | Loss: 0.1028 | Test Acc: 0.6889
Epoch 45 | Loss: 0.0869 | Test Acc: 0.7009
Epoch 50 | Loss: 0.0309 | Test Acc: 0.7140

ÊúÄÁµÇË©ï‰º∞Â†±Âëä (Final Report)
              precision    recall  f1-score   support

 Unknown (0)       0.57      0.62      0.60       312
   Known (1)       0.80      0.76      0.78       604

    accuracy                           0.71       916
   macro avg       0.68      0.69      0.69       916
weighte

In [2]:
# ==========================================
# [New Section] Compute the paper's standard metrics (R-Acc, ER, A-Acc, A-F1)
# ==========================================
print("\n" + "=" * 40)
print("üèÜ Ë´ñÊñáÊ®ôÊ∫ñÊåáÊ®ôË©ï‰º∞ (Paper Metrics Calculation)")
print("=" * 40)

# 1. Prepare the data
# y_test may still be a tensor; bring it to a NumPy int array either way.
paper_correct_flags = (
    y_test.cpu().numpy() if isinstance(y_test, torch.Tensor) else y_test
).astype(int)

# final_probs holds the probe's probability of "Known (1)"; convert it to a
# flat NumPy array whether or not it is still a tensor.
probs_known = (
    final_probs.cpu().numpy() if isinstance(final_probs, torch.Tensor) else final_probs
).flatten()

# 2. Convert probe predictions into abstain decisions (key step)
#    probe predicts 1 (Known)   -> the system answers  (abstain flag = 0)
#    probe predicts 0 (Unknown) -> the system refuses  (abstain flag = 1)

# 0.5 is used as the decision threshold; adjust as needed (e.g. 0.4 or 0.6).
threshold = 0.5
paper_abstain_flags = (probs_known < threshold).astype(int)

# Abstain confidence score = 1 - probability of "Known".
paper_abstain_scores = 1 - probs_known

# 3. Ë®àÁÆóÊåáÊ®ôÂáΩÊï∏ (Áõ¥Êé•ÂÖßÂµåÂú®ÈÄôË£°Êñπ‰æøÊÇ®Âü∑Ë°å)
def get_paper_metrics(c_flags, a_flags):
    """Compute the abstention confusion counts and the derived metrics.

    Each position ``i`` describes one sample: ``c_flags[i] == 1`` means the
    model's answer was correct, ``a_flags[i] == 1`` means the system
    abstained. Any other flag value is treated as "incorrect" / "answered".

    Args:
        c_flags: Sequence of correctness flags.
        a_flags: Sequence of abstain flags, indexed in step with ``c_flags``.

    Returns:
        Tuple ``(A, B, C, D, r_acc, er, a_acc, a_f1)`` where
        A = answered & correct, B = abstained & correct,
        C = answered & incorrect, D = abstained & incorrect,
        ``r_acc = A / (A + C)``, ``er = (A - C) / total``,
        ``a_acc = (A + D) / total`` and ``a_f1`` is the F1 score of
        abstaining on incorrect samples (precision ``D / (B + D)``, recall
        ``D / (C + D)``). Every ratio is 0.0 when its denominator is zero,
        including ``er`` and ``a_acc`` for empty input.
    """
    A = 0  # Answered & correct (true accept)
    B = 0  # Abstained & correct (false abstain - overly conservative)
    C = 0  # Answered & incorrect (false accept - hallucination)
    D = 0  # Abstained & incorrect (true abstain - successfully blocked)

    for i, correct in enumerate(c_flags):
        if a_flags[i] == 1:  # abstained
            if correct == 1:
                B += 1
            else:
                D += 1
        else:  # answered
            if correct == 1:
                A += 1
            else:
                C += 1

    total = A + B + C + D

    r_acc = A / (A + C) if (A + C) > 0 else 0.0
    # Guard the total as well: empty input used to raise ZeroDivisionError
    # here while every other ratio already fell back to 0.0.
    er = (A - C) / total if total > 0 else 0.0
    a_acc = (A + D) / total if total > 0 else 0.0

    precision = D / (B + D) if (B + D) > 0 else 0.0
    recall = D / (C + D) if (C + D) > 0 else 0.0
    a_f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0

    return A, B, C, D, r_acc, er, a_acc, a_f1

# 4. Run the computation and display the results
(A, B, C, D,
 r_acc, er, a_acc, a_f1) = get_paper_metrics(paper_correct_flags, paper_abstain_flags)

# Emit the whole report with one print call; the text is the same as printing
# each line separately.
print("\n".join([
    f"Á∏ΩÊ®£Êú¨Êï∏ (Total Samples): {len(paper_correct_flags)}",
    "-" * 40,
    "Áü©Èô£Á¥∞ÁØÄ:",
    f"  [A] Ê≠£Á¢∫ÂõûÁ≠î (True Accept):     {A} (ÁõÆÊ®ô: È´ò)",
    f"  [B] ÈÅéÂ∫¶ÊãíÁµï (False Abstain):   {B} (ÁõÆÊ®ô: ‰Ωé)",
    f"  [C] Áî¢ÁîüÂπªË¶∫ (False Accept):    {C} (ÁõÆÊ®ô: ‰Ωé - ÊúÄÂç±Èö™!)",
    f"  [D] ÊàêÂäüÊîîÊà™ (True Abstain):    {D} (ÁõÆÊ®ô: È´ò)",
    "-" * 40,
    "üìä ÈóúÈçµÊåáÊ®ô (Key Metrics):",
    f"  1. Reliable Accuracy (R-Acc) : {r_acc:.4f}",
    f"  2. Effective Reliability (ER): {er:.4f}",
    f"  3. Abstain Accuracy (A-Acc)  : {a_acc:.4f}",
    f"  4. Abstain F1 Score (A-F1)   : {a_f1:.4f}",
    "=" * 40,
]))


üèÜ Ë´ñÊñáÊ®ôÊ∫ñÊåáÊ®ôË©ï‰º∞ (Paper Metrics Calculation)
Á∏ΩÊ®£Êú¨Êï∏ (Total Samples): 916
----------------------------------------
Áü©Èô£Á¥∞ÁØÄ:
  [A] Ê≠£Á¢∫ÂõûÁ≠î (True Accept):     460 (ÁõÆÊ®ô: È´ò)
  [B] ÈÅéÂ∫¶ÊãíÁµï (False Abstain):   144 (ÁõÆÊ®ô: ‰Ωé)
  [C] Áî¢ÁîüÂπªË¶∫ (False Accept):    118 (ÁõÆÊ®ô: ‰Ωé - ÊúÄÂç±Èö™!)
  [D] ÊàêÂäüÊîîÊà™ (True Abstain):    194 (ÁõÆÊ®ô: È´ò)
----------------------------------------
üìä ÈóúÈçµÊåáÊ®ô (Key Metrics):
  1. Reliable Accuracy (R-Acc) : 0.7958
  2. Effective Reliability (ER): 0.3734
  3. Abstain Accuracy (A-Acc)  : 0.7140
  4. Abstain F1 Score (A-F1)   : 0.5969
