In [0]:
# Auto-reload edited project modules (e.g. probe.*, utils.*) on each cell run, without a kernel restart.
# NOTE(review): the `%restart_python` in the next cell restarts the interpreter, which likely
# discards this extension — consider re-running these two magics after the restart.
%load_ext autoreload
%autoreload 2

## Imports and Configuration

In [0]:
!pip install python-box
%restart_python

In [0]:
import pandas as pd
import torch
import mlflow
from probe.train_supcon import train_supcon_probe
from probe.test_supcon import test_supcon_probe
from pathlib import Path

In [0]:
# Root directory that holds one sub-directory of probing outputs per dataset.
DATA_ROOT = "/Workspace/Users/xinji@pennmedicine.upenn.edu/research/probing_activation_outputs/"

# MLflow experiment that every run in this notebook logs to.
EXPERIMENT_PATH = "/Workspace/Users/xinji@pennmedicine.upenn.edu/mlflow_logs/6200_E2E_testing"
mlflow.set_experiment(EXPERIMENT_PATH)

# --- Probe hyper-parameters (fixed for this experiment) ---
LAYERS = [47, 31, 15]  # alternative layer set: [63, 47, 31, 15, 0]
USE_METADATA = True
PCA_DIM = 1024
PROJECTION_DIM = 512
HIDDEN_DIM = 512
ALPHA = 0.5
TEMPERATURE = 0.1
CLASS_WEIGHTS = [1.0, 0.1, 1.0]
NUM_EPOCHS = 30
BATCH_SIZE = 16
LEARNING_RATE = 1e-4
CHECKPOINT_PATH = "checkpoints/supcon_probe_final.pt"

# (output sub-directory, dataset name) pairs; each one expands to the tuple
# (DATA_ROOT + subdir, "stats_<name>.csv", "activations_<name>.h5").
_DATASET_SOURCES = [
    ("math500_CH_1", "math500"),
    ("math500_CH_2", "math500"),
    # ("math500_CH_3", "math500"),
    ("gaokao_mathqa_CH", "gaokao_mathqa"),
    ("gaokao_cloze_CH", "gaokao_cloze"),
]
DATASET_CONFIGS = [
    (DATA_ROOT + subdir, f"stats_{name}.csv", f"activations_{name}.h5")
    for subdir, name in _DATASET_SOURCES
]



## Train SupCon Probe

In [0]:
# Probe input width: PCA-reduced activations plus 3 extra features.
# NOTE(review): the +3 presumably holds the metadata features toggled by USE_METADATA —
# confirm whether the width should be PCA_DIM alone when USE_METADATA is False.
PROBE_INPUT_DIM = PCA_DIM + 3

with mlflow.start_run(run_name="train_supcon_probe") as run:
    # Log the configuration before training, so a crashed run still records what it tried.
    # `input_dim` records the width actually passed to the probe (PCA_DIM + 3).
    mlflow.log_params({
        "input_dim": PROBE_INPUT_DIM,
        "pca_dim": PCA_DIM,
        "use_metadata": USE_METADATA,
        "dataset_configs": str(DATASET_CONFIGS),
        "projection_dim": PROJECTION_DIM,
        "hidden_dim": HIDDEN_DIM,
        "alpha": ALPHA,
        "temperature": TEMPERATURE,
        "class_weights": str(CLASS_WEIGHTS),
        "batch_size": BATCH_SIZE,
        "num_epochs": NUM_EPOCHS,
        "learning_rate": LEARNING_RATE,
        "layers": str(LAYERS),
        "checkpoint": CHECKPOINT_PATH,
    })

    # Named `probe_model` (not `model`) so the causal LM loaded later does not clobber it.
    probe_model, val_logits, val_labels, val_is_natural, test_data, (pca, scaler) = train_supcon_probe(
        dataset_configs=DATASET_CONFIGS,
        layers=LAYERS,
        input_dim=PROBE_INPUT_DIM,
        use_metadata=USE_METADATA,
        projection_dim=PROJECTION_DIM,
        hidden_dim=HIDDEN_DIM,
        pca_dim=PCA_DIM,
        alpha=ALPHA,
        temperature=TEMPERATURE,
        class_weights=CLASS_WEIGHTS,
        batch_size=BATCH_SIZE,
        num_epochs=NUM_EPOCHS,
        learning_rate=LEARNING_RATE,
        save_path=CHECKPOINT_PATH,
    )


In [0]:
import joblib

# Persist the fitted preprocessing next to the probe checkpoint.
# NOTE: joblib files are pickles — only load them from trusted locations.
checkpoint_dir = Path("checkpoints")
checkpoint_dir.mkdir(parents=True, exist_ok=True)
pca_path = checkpoint_dir / "pca.pkl"
scaler_path = checkpoint_dir / "scaler.pkl"
joblib.dump(pca, pca_path)
joblib.dump(scaler, scaler_path)

# The training `with mlflow.start_run(...)` block has already exited. Logging now would make
# MLflow auto-create a new, unnamed run, so re-enter the training run by id instead.
with mlflow.start_run(run_id=run.info.run_id):
    mlflow.log_artifact(str(pca_path))
    mlflow.log_artifact(str(scaler_path))


## Find Optimal Thresholds

Goal: find the thresholds $(\tau_h, \tau_{help})$ that maximize decision utility on the validation set.

In [0]:
import torch
import numpy as np

# Convert validation logits into per-class probabilities (softmax over dim=1;
# assumes logits are (num_samples, num_classes) — TODO confirm).
# `as_tensor(...).detach().cpu()` accepts numpy arrays as well as GPU / grad-tracking
# tensors; `.numpy()` only works on a CPU tensor without grad.
val_probs = torch.softmax(torch.as_tensor(val_logits).detach().cpu(), dim=1).numpy()


In [0]:
from probe.threshold_search import threshold_search_utility

# Search range and increment handed to threshold_search_utility
# (presumably a grid of candidate thresholds over this interval).
TAU_RANGE = (0.1, 0.9)
TAU_STEP = 0.05

best_thresh, best_util, y_pred_best = threshold_search_utility(
    probs=val_probs,
    y_true=val_labels,
    is_natural=val_is_natural,
    tau_range=TAU_RANGE,
    step=TAU_STEP,
    verbose=True,  # print intermediate results
)


In [0]:
# Summarize the outcome of the validation-set threshold search.
for label, value in (("Best thresholds:", best_thresh), ("Best decision utility:", best_util)):
    print(label, value)

## Test SupCon Probe

In [0]:
from sklearn.metrics import classification_report
import numpy as np

# Held-out split produced by train_supcon_probe.
X_test, y_test, test_df = test_data

# Evaluate the saved checkpoint on the held-out split.
# NOTE(review): log_to_mlflow=True runs outside any `mlflow.start_run` block; unless
# test_supcon_probe opens its own run, MLflow will auto-create an unnamed run for these
# logs — confirm this is intended.
eval_outputs = test_supcon_probe(
    X_test=X_test,
    y_test=y_test,
    test_df=test_df,
    checkpoint_path=CHECKPOINT_PATH,
    input_dim=PCA_DIM + 3,
    projection_dim=PROJECTION_DIM,
    hidden_dim=HIDDEN_DIM,
    model_type="supcon",
    log_to_mlflow=True,
    return_probs=True,
)
y_pred, y_true, is_natural, logits, probs = eval_outputs

print("Classification Report:")
print(classification_report(y_true, y_pred))


In [0]:
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

def plot_confusion_matrix(y_true, y_pred, class_names, normalize=False, save_path=None):
    """Draw (and optionally save) a confusion-matrix heatmap.

    Args:
        y_true: Ground-truth integer labels; label ``i`` is displayed as ``class_names[i]``.
        y_pred: Predicted integer labels, aligned element-wise with ``y_true``.
        class_names: Display names indexed by label id; also fixes the matrix size.
        normalize: If True, divide each row by its true-class count so rows sum to 1.
            Rows of classes absent from ``y_true`` stay all-zero instead of becoming NaN.
        save_path: Optional output file; missing parent directories are created.
    """
    # Pin the label set so the matrix is always len(class_names) square and lines up with
    # the tick labels, even when a class never occurs in y_true or y_pred.
    labels = np.arange(len(class_names))
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # Avoid 0/0 -> NaN for empty rows.
        cm = np.divide(cm, row_totals, out=np.zeros(cm.shape, dtype=float), where=row_totals != 0)

    fig, ax = plt.subplots(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt=".2f" if normalize else "d", cmap="Blues",
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_xlabel("Predicted Label")
    ax.set_ylabel("True Label")
    ax.set_title("Confusion Matrix" + (" (Normalized)" if normalize else ""))
    fig.tight_layout()
    if save_path:
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, dpi=300)
    plt.show()

# Display names indexed by label id (0, 1, 2).
class_names = ["Harmful", "Neutral", "Helpful"]

# Plot against `y_true` as returned by test_supcon_probe together with `y_pred`,
# the same pair used for the classification report above.
Path("figs").mkdir(parents=True, exist_ok=True)
plot_confusion_matrix(y_true, y_pred, class_names, normalize=True, save_path="figs/confusion_matrix_normalized.png")


In [0]:
from sklearn.metrics import classification_report

# Per-class F1 on the test split, using the (y_true, y_pred) pair from test_supcon_probe.
# Passing `labels=` guarantees a "0"/"1"/... entry for every class even if one is absent
# (its F1 is then reported as 0 rather than raising KeyError below).
class_ids = list(range(len(class_names)))
report = classification_report(y_true, y_pred, labels=class_ids, output_dict=True, zero_division=0)
f1_scores = [report[str(i)]["f1-score"] for i in class_ids]

Path("figs").mkdir(parents=True, exist_ok=True)
fig, ax = plt.subplots(figsize=(5, 4))
sns.barplot(x=class_names, y=f1_scores, palette="pastel", ax=ax)
ax.set_title("F1-score per Class")
ax.set_ylim(0, 1)
ax.set_ylabel("F1-score")
fig.tight_layout()
fig.savefig("figs/f1_score_per_class.png", dpi=300)
plt.show()


## End-to-End Testing with Threshold

In [0]:
from probe.wrapper import ProbeWrapper

# Held-out split from train_supcon_probe; `test_df` carries the dataset_src / problem_id columns used below.
X_test, y_test, test_df = test_data

# Bundle the trained checkpoint with its fitted PCA/scaler and the thresholds tuned on validation.
probe_settings = {
    "checkpoint_path": CHECKPOINT_PATH,
    "pca": pca,
    "scaler": scaler,
    "layers": LAYERS,
    "input_dim": PCA_DIM + 3,
    "projection_dim": PROJECTION_DIM,
    "hidden_dim": HIDDEN_DIM,
    "device": "cuda",
    "threshold": best_thresh,
}
probe = ProbeWrapper(**probe_settings)


In [0]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# === Config ===
# Causal LM used below for probe-constrained generation.
MODEL_NAME = "Qwen/QwQ-32B-Preview"

# === Load Model and Tokenizer ===
print("Loading model...")
# fp16 weights; `device_map="auto"` lets transformers place layers on the available devices.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    torch_dtype=torch.float16,
)


In [0]:
# Left padding is the recommended setting for batched generation with decoder-only models:
# new tokens are appended right after each prompt instead of after pad tokens.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    padding_side="left",
)

In [0]:
# from_pretrained already returns the model in eval mode; this is a cheap explicit safeguard.
# The trailing `;` suppresses the very long module repr in the cell output.
model.eval();

In [0]:
import json
from typing import List
from tqdm import tqdm

def run_batched_decoding_with_probe(
    matching_indices: List[int],
    problem_ids: List[str],
    data_generator,
    probe,
    sample_src: str = 'math500',
    prompt_lang: str = 'ch',
    batch_size: int = 4,
    save_path: str = "math500_decoding_w_probe_w_threshold_outputs.json"
):
    """Run probe-constrained generation in batches and write one JSON record per sample.

    Args:
        matching_indices: Sample indices passed to the generator as ``sample_idxs``.
        problem_ids: Problem id for each entry of ``matching_indices``. The two lists are
            paired positionally, so ``problem_ids[i]`` must describe ``matching_indices[i]``.
        data_generator: Object exposing ``generate_constrained_response_with_probe``, which
            returns ``(generated_texts, token_counts, probe_logs)`` with one entry per sample.
        probe: Probe object forwarded unchanged to the generator.
        sample_src: Dataset tag stored in every record.
        prompt_lang: Prompt language forwarded to the generator.
        batch_size: Number of samples per generation call (must be >= 1).
        save_path: Destination JSON file; missing parent directories are created.

    Raises:
        ValueError: If the two input lists differ in length, ``batch_size`` < 1, or the
            generator returns a different number of results than samples requested.
    """
    # Explicit checks instead of `assert`, which is skipped under `python -O`.
    if len(matching_indices) != len(problem_ids):
        raise ValueError("Mismatch between indices and problem IDs.")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    output = []
    for start_idx in tqdm(range(0, len(matching_indices), batch_size)):
        batch_indices = matching_indices[start_idx:start_idx + batch_size]
        batch_problem_ids = problem_ids[start_idx:start_idx + batch_size]

        generated_texts, token_counts, probe_logs = data_generator.generate_constrained_response_with_probe(
            sample_idxs=batch_indices,
            prompt_lang=prompt_lang,
            probe=probe
        )

        # Fail loudly rather than silently dropping or misaligning samples.
        n_samples = len(batch_indices)
        if not (len(generated_texts) == len(token_counts) == len(probe_logs) == n_samples):
            raise ValueError(
                f"Generator returned {len(generated_texts)} texts, {len(token_counts)} token counts "
                f"and {len(probe_logs)} probe logs for a batch of {n_samples} samples."
            )

        for sample_idx, problem_id, probe_log, token_count, text in zip(
            batch_indices, batch_problem_ids, probe_logs, token_counts, generated_texts
        ):
            output.append({
                "sample_src": sample_src,
                "sample_idx": sample_idx,
                "problem_id": problem_id,
                "probe_log": probe_log,
                "token_count": token_count,
                "generated_text": text
            })

    # Save the full output (UTF-8, non-ASCII kept readable).
    save_file = Path(save_path)
    save_file.parent.mkdir(parents=True, exist_ok=True)
    with open(save_file, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2, ensure_ascii=False)

    print(f"Saved output for {len(output)} samples to {save_path}")


In [0]:
from utils.constrained_decoding import generate_probe_constrained_response_batch
from utils.data_generator import DataGenerator

In [0]:
from box import Box
import yaml

# Generation settings for DataGenerator. Read as UTF-8 explicitly (matching the UTF-8 JSON
# outputs written in this notebook) so non-ASCII text in the config does not depend on the
# platform's default encoding. safe_load: no arbitrary object construction from YAML.
with open("config.yaml", "r", encoding="utf-8") as f:
    config = Box(yaml.safe_load(f))

### Math 500 Test

In [0]:
# Distinct MATH-500 problem ids in the held-out split, sorted so the order is reproducible
# (the order of list(set(...)) can change between Python processes when ids are strings).
math500_problem_ids = sorted(set(test_df.loc[test_df["dataset_src"] == 'math500', 'problem_id'].to_list()))

dataset = pd.read_csv("/Workspace/Users/xinji@pennmedicine.upenn.edu/research/Language-Mixing/translation/data/translated_math500.csv")

In [0]:
# Number of distinct MATH-500 problems in the held-out test split.
len(math500_problem_ids)

In [0]:
# Row labels of `dataset` whose problem_id belongs to the held-out MATH-500 split.
in_test_split = dataset["problem_id"].isin(math500_problem_ids)
matching_indices = dataset.index[in_test_split.to_numpy()].tolist()


In [0]:
# Quick sanity check on the selection.
n_matches = len(matching_indices)
index_preview = matching_indices[:5]  # show first 5
print(f"Found {n_matches} matching indices.")
print("Example indices:", index_preview)

In [0]:
# Bind the current `dataset` to the shared LLM, tokenizer and YAML config for generation.
data_generator = DataGenerator(
    cfg=config,
    model=model,
    tokenizer=tokenizer,
    dataset=dataset,
)

In [0]:
# Pair each dataset row with that row's own problem id. `math500_problem_ids` is not in
# dataset-row order, so zipping it positionally against `matching_indices` would mislabel
# samples (and fail the length check if the dataset has duplicate ids).
run_batched_decoding_with_probe(
    matching_indices=matching_indices,
    problem_ids=dataset.loc[matching_indices, "problem_id"].tolist(),
    data_generator=data_generator,
    probe=probe,
    sample_src='math500',
    prompt_lang='ch',
    batch_size=4,
    save_path="math500_decoding_w_probe_w_threshold_outputs.json"
)

### Gaokao Math QA Test

In [0]:
# Distinct Gaokao MathQA problem ids in the held-out split, sorted so the order is reproducible
# (the order of list(set(...)) can change between Python processes when ids are strings).
gaokao_mathqa_problem_ids = sorted(set(test_df.loc[test_df["dataset_src"] == 'gaokao_mathqa', 'problem_id'].to_list()))
dataset = pd.read_csv("/Workspace/Users/xinji@pennmedicine.upenn.edu/research/Language-Mixing/translation/data/translated_gaokao_mathqa.csv")

In [0]:
# Number of distinct Gaokao MathQA problems in the held-out test split.
len(gaokao_mathqa_problem_ids)

In [0]:
# Row labels of `dataset` whose problem_id belongs to the held-out Gaokao MathQA split.
in_test_split = dataset["problem_id"].isin(gaokao_mathqa_problem_ids)
matching_indices = dataset.index[in_test_split.to_numpy()].tolist()

In [0]:
print(f"Found {len(matching_indices)} matching indices.")
# Read the ids from the matched rows themselves so each printed id corresponds to the
# printed index at the same position (the test-id list is not in dataset-row order).
print("First 5 problem ids:", dataset.loc[matching_indices[:5], "problem_id"].tolist())
print("First 5 indices:", matching_indices[:5])  # show first 5

In [0]:
# Number of dataset rows selected for decoding (same count as printed above).
len(matching_indices)

In [0]:
# Rebuild the generator around the Gaokao MathQA `dataset` loaded above.
data_generator = DataGenerator(
    cfg=config,
    model=model,
    tokenizer=tokenizer,
    dataset=dataset,
)

In [0]:
# Pair each dataset row with that row's own problem id. `gaokao_mathqa_problem_ids` is not
# in dataset-row order, so zipping it positionally against `matching_indices` would mislabel
# samples (and fail the length check if the dataset has duplicate ids).
run_batched_decoding_with_probe(
    matching_indices=matching_indices,
    problem_ids=dataset.loc[matching_indices, "problem_id"].tolist(),
    data_generator=data_generator,
    probe=probe,
    sample_src='gaokao_mathqa',
    prompt_lang='ch',
    batch_size=4,
    save_path="gaokao_mathqa_decoding_w_probe_w_threshold_outputs.json"
)

### Gaokao Cloze Test

In [0]:
# Distinct Gaokao Cloze problem ids in the held-out split, sorted so the order is reproducible
# (the order of list(set(...)) can change between Python processes when ids are strings).
gaokao_cloze_problem_ids = sorted(set(test_df.loc[test_df["dataset_src"] == 'gaokao_cloze', 'problem_id'].to_list()))
dataset = pd.read_csv("/Workspace/Users/xinji@pennmedicine.upenn.edu/research/Language-Mixing/translation/data/translated_gaokao_cloze.csv")

In [0]:
# Number of distinct Gaokao Cloze problems in the held-out test split.
len(gaokao_cloze_problem_ids)

In [0]:
# Row labels of `dataset` whose problem_id belongs to the held-out Gaokao Cloze split.
in_test_split = dataset["problem_id"].isin(gaokao_cloze_problem_ids)
matching_indices = dataset.index[in_test_split.to_numpy()].tolist()

In [0]:
print(f"Found {len(matching_indices)} matching indices.")
# Read the ids from the matched rows themselves so each printed id corresponds to the
# printed index at the same position (the test-id list is not in dataset-row order).
print("First 5 problem ids:", dataset.loc[matching_indices[:5], "problem_id"].tolist())
print("First 5 indices:", matching_indices[:5])  # show first 5

In [0]:
# Number of dataset rows selected for decoding (same count as printed above).
len(matching_indices)

In [0]:
# Rebuild the generator around the Gaokao Cloze `dataset` loaded above.
data_generator = DataGenerator(
    cfg=config,
    model=model,
    tokenizer=tokenizer,
    dataset=dataset,
)

In [0]:
# Pair each dataset row with that row's own problem id. `gaokao_cloze_problem_ids` is not
# in dataset-row order, so zipping it positionally against `matching_indices` would mislabel
# samples (and fail the length check if the dataset has duplicate ids).
run_batched_decoding_with_probe(
    matching_indices=matching_indices,
    problem_ids=dataset.loc[matching_indices, "problem_id"].tolist(),
    data_generator=data_generator,
    probe=probe,
    sample_src='gaokao_cloze',
    prompt_lang='ch',
    batch_size=4,
    save_path="gaokao_cloze_decoding_w_probe_w_threshold_outputs.json"
)