<a href="https://colab.research.google.com/github/janbanot/msc-project/blob/main/test_notebooks/test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install/upgrade the third-party libraries this notebook imports (IPython shell escape, run in Colab).
!uv pip install --upgrade transformers datasets captum quantus accelerate

In [None]:
import os
import re
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
from datetime import datetime
from datasets import Dataset
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForSeq2SeqLM,
    Trainer,
    TrainingArguments
)
from captum.attr import IntegratedGradients, InputXGradient

In [None]:
from google.colab import drive
# Mount Google Drive at /drive so that the /drive/MyDrive/... paths in the configuration resolve (Colab only).
drive.mount('/drive')

In [None]:
# ==========================================
# 1. GLOBAL CONFIGURATION
# ==========================================
# Sample budgets and hyper-parameters shared by the experiment modules.
N_SAMPLES_XAI = 100          # rows explained with IG / IxG (Module A); also the per-class cap in Module D
N_SAMPLES_PROBE = 1000       # rows used for the layer-wise RepE probes (Module B)
N_SAMPLES_STABILITY = 50     # original/paraphrase pairs in the stability test (Module C)
BATCH_SIZE = 32              # DataLoader batch size for activation extraction (Module B)
TOP_K_TOKENS = 5             # number of tokens removed by the Comprehensiveness metric
DF_SIZE = 3000               # rows read from the CSV before the train/test split

# Locations on the mounted Google Drive (adjust if needed).
DATA_PATH = '/drive/MyDrive/msc-project/jigsaw-toxic-comment/train.csv'
RESULTS_DIR = "/drive/MyDrive/msc-project/results_final"
MODEL_CHECKPOINT = "/drive/MyDrive/msc-project/models/distilbert-jigsaw-full"

# Compute device: the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Running on device: {device}")

# All CSVs / plots / reports are written here; reuse the folder if it already exists.
os.makedirs(RESULTS_DIR, exist_ok=True)

In [None]:
# ==========================================
# 2. PRZYGOTOWANIE DANYCH I MODELU (SETUP)
# ==========================================

def clean_text(example):
    """Normalise the ``comment_text`` field of a single dataset row in place.

    The steps must stay identical to the preprocessing used at training time:
    lower-case; drop URLs, IPv4 addresses, the literal ``(talk)`` marker and
    timestamps shaped like ``hh:mm, word d, yyyy (utc)``; turn newlines and
    non-breaking spaces into spaces; trim surrounding spaces / double quotes;
    collapse runs of whitespace into a single space.

    Args:
        example: Mapping (a row as passed by ``Dataset.map``) holding a string
            under ``'comment_text'``.

    Returns:
        The same mapping object with ``'comment_text'`` replaced by the cleaned text.
    """
    cleaned = example['comment_text'].lower()
    # Applied in this exact order; each match is simply deleted.
    # NOTE(review): Wikipedia signatures are usually "hh:mm, d Month yyyy (UTC)",
    # while the last pattern expects "Month d, yyyy" - confirm against the training
    # script before changing it (train/eval preprocessing must match).
    for pattern in (
        r'http\S+|www\S+',
        r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}',
        r'\(talk\)',
        r'\d{2}:\d{2}, \w+ \d{1,2}, \d{4} \(utc\)',
    ):
        cleaned = re.sub(pattern, '', cleaned)
    cleaned = cleaned.replace('\n', ' ').replace('\xa0', ' ').strip(' "')
    example['comment_text'] = re.sub(r'\s+', ' ', cleaned).strip()
    return example

def prepare_environment():
    """Load and preprocess the Jigsaw data and load the fine-tuned classifier.

    Reads the first ``DF_SIZE`` rows of ``DATA_PATH`` and cleans them with
    ``clean_text``. It then tokenizes them to a fixed length of 256 tokens and
    builds a 6-element float ``labels`` vector per row. The 20 % test split
    (``seed=42``) is kept as the working evaluation set.

    Returns:
        tuple: ``(model, tokenizer, eval_dataset)``, where
            * ``model`` is the classifier, already moved to ``device`` and in eval mode;
            * ``tokenizer`` is its tokenizer;
            * ``eval_dataset`` is the evaluation split, formatted to yield torch tensors
              for ``input_ids``, ``attention_mask`` and ``labels``.

    Raises:
        FileNotFoundError: if ``DATA_PATH`` does not exist.
        OSError: if no model can be loaded from ``MODEL_CHECKPOINT``.
    """
    print(">>> [SETUP] Loading and preprocessing data...")

    # 1. Load the raw CSV (only the first DF_SIZE rows). Keep the try minimal and
    #    chain the original error so the traceback still shows the real cause.
    try:
        df = pd.read_csv(DATA_PATH).head(DF_SIZE)
    except FileNotFoundError as err:
        raise FileNotFoundError(f"Nie znaleziono pliku: {DATA_PATH}. Sprawdź ścieżkę w Konfiguracji Globalnej.") from err
    dataset = Dataset.from_pandas(df)

    # 2. Text cleaning (meant to be identical to the training-time preprocessing).
    dataset = dataset.map(clean_text)

    # 3. Tokenizer (must match the model); fall back to the base DistilBERT tokenizer.
    print(f">>> [SETUP] Loading tokenizer from: {MODEL_CHECKPOINT}...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)
    except OSError:
        print(f"Błąd: Nie znaleziono tokenizera w {MODEL_CHECKPOINT}. Pobieram domyślny 'distilbert-base-uncased'.")
        tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

    # 4. Tokenize to a fixed length of 256 tokens (padding + truncation).
    def tokenize_function(examples):
        return tokenizer(examples["comment_text"], padding="max_length", truncation=True, max_length=256)

    tokenized_dataset = dataset.map(tokenize_function, batched=True)

    # 5. Multi-label target vector; its order defines the model's output indices
    #    (index 0 = 'toxic', which the other modules rely on).
    label_cols = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

    def create_labels(example):
        example['labels'] = [float(example[col]) for col in label_cols]
        return example

    final_dataset = tokenized_dataset.map(create_labels)

    # Expose only the model inputs/targets as torch tensors. ``set_format`` hides the
    # remaining (text/metadata) columns from indexing; it does not delete them.
    final_dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

    # 6. Split only to carve out the evaluation subset that all modules work on.
    # NOTE(review): whether these rows were excluded from fine-tuning depends on the
    # training script's split - confirm to rule out train/eval leakage.
    eval_dataset = final_dataset.train_test_split(test_size=0.2, seed=42)['test']

    # 7. Load the fine-tuned model.
    print(f">>> [SETUP] Loading Pre-trained Model from: {MODEL_CHECKPOINT}...")
    try:
        model = AutoModelForSequenceClassification.from_pretrained(
            MODEL_CHECKPOINT,
            num_labels=len(label_cols),
            problem_type="multi_label_classification"
        )
    except OSError as err:
        raise OSError(f"Nie znaleziono modelu w ścieżce: {MODEL_CHECKPOINT}. Upewnij się, że najpierw uruchomiłeś skrypt treningowy.") from err

    model.to(device)
    model.eval()  # IMPORTANT: evaluation mode (disables dropout)

    print(f">>> [SETUP] Environment Ready. Device: {device}")
    return model, tokenizer, eval_dataset

# Initialise the environment: builds the module-level `model`, `tokenizer` and
# `eval_dataset` that the experiment cells below pass to each module.
model, tokenizer, eval_dataset = prepare_environment()

In [None]:
# ==========================================
# 3. MODUŁ A: PORÓWNANIE METOD XAI (Comprehensiveness)
# ==========================================

def _top_k_positions(attributions, candidate_mask, k):
    """Return the indices of the ``k`` highest-attributed candidate token positions.

    ``attributions`` is assumed to be (1, seq_len, hidden), as Captum returns for a
    single sequence; it is summed over the last (hidden) dimension. Positions where
    ``candidate_mask`` is False are never selected, and ``k`` is clamped to the
    number of candidates (an empty index tensor is returned when there are none).
    """
    scores = attributions.sum(dim=-1).squeeze(0)
    k = min(k, int(candidate_mask.sum().item()))
    if k == 0:
        return torch.empty(0, dtype=torch.long, device=scores.device)
    scores = scores.masked_fill(~candidate_mask, float("-inf"))
    return torch.topk(scores, k=k).indices


def run_module_a_xai(model, tokenizer, dataset):
    """Compare Integrated Gradients and InputXGradient with the Comprehensiveness metric.

    Works on up to ``N_SAMPLES_XAI`` rows of ``dataset`` whose first label ('toxic')
    equals 1:

    1. Token attributions for output logit 0 are computed with both methods,
       w.r.t. the word embeddings.
    2. The ``TOP_K_TOKENS`` highest-attributed tokens (special tokens and padding
       excluded) are overwritten with the PAD id.
    3. The drop in the sigmoid probability of label 0 is recorded.

    Side effects: writes ``xai_comparison_results.csv`` and ``xai_boxplot.png``
    to ``RESULTS_DIR``.

    Returns:
        pandas.DataFrame with columns ``text_id`` (position inside the toxic subset),
        ``original_prob``, ``ig_drop_score`` and ``ixg_drop_score``.
    """
    print("\n>>> [MODULE A] Running XAI Comparison (IG vs IxG)...")
    model.eval()

    # Keep only examples labelled toxic (label index 0).
    toxic_indices = [i for i, labels in enumerate(dataset['labels']) if labels[0] == 1]
    subset = dataset.select(toxic_indices[:N_SAMPLES_XAI])

    # Captum wrapper: logits as a function of the (word) embeddings.
    def predict_func(inputs_embeds, attention_mask=None):
        return model(inputs_embeds=inputs_embeds, attention_mask=attention_mask).logits

    ig = IntegratedGradients(predict_func)
    ixg = InputXGradient(predict_func)

    # DistilBERT applies its `embeddings` module (token lookup + position embeddings +
    # LayerNorm) to `inputs_embeds` itself. Feeding it the *output* of that module
    # would add positions twice, so attributions are taken w.r.t. the raw token
    # lookup only.
    word_embeddings = model.distilbert.embeddings.word_embeddings
    special_ids = torch.tensor(tokenizer.all_special_ids, device=device)
    pad_id = tokenizer.pad_token_id

    def prob_toxic(ids, mask):
        """Sigmoid probability of label 0 for a single tokenized sequence."""
        with torch.no_grad():
            return torch.sigmoid(model(ids, attention_mask=mask).logits)[0, 0].item()

    def calculate_drop(attributions, input_ids, attention_mask, candidate_mask, orig_prob):
        """Confidence drop after replacing the top-k attributed tokens with PAD."""
        top_positions = _top_k_positions(attributions, candidate_mask, TOP_K_TOKENS)
        if top_positions.numel() == 0:
            return 0.0  # nothing removable (e.g. empty comment)
        masked_ids = input_ids.clone()
        masked_ids[0, top_positions] = pad_id
        return orig_prob - prob_toxic(masked_ids, attention_mask)

    results = []
    for i in tqdm(range(len(subset)), desc="XAI Evaluation"):
        row = subset[i]
        input_ids = row['input_ids'].unsqueeze(0).to(device)
        attention_mask = row['attention_mask'].unsqueeze(0).to(device)
        # Only real, non-special tokens may be "removed": overwriting [CLS] would
        # corrupt the position the classification head reads.
        candidate_mask = attention_mask[0].bool() & ~torch.isin(input_ids[0], special_ids)

        with torch.no_grad():
            input_embeds = word_embeddings(input_ids)
            # All-PAD baseline of the same length as the input.
            baseline = word_embeddings(torch.full_like(input_ids, pad_id))

        # 1. Original probability (same input_ids code path as the masked runs).
        orig_prob = prob_toxic(input_ids, attention_mask)

        # 2. Integrated Gradients
        attr_ig = ig.attribute(inputs=input_embeds, baselines=baseline, target=0,
                               additional_forward_args=(attention_mask,))
        drop_ig = calculate_drop(attr_ig, input_ids, attention_mask, candidate_mask, orig_prob)

        # 3. Input x Gradient
        attr_ixg = ixg.attribute(inputs=input_embeds, target=0, additional_forward_args=(attention_mask,))
        drop_ixg = calculate_drop(attr_ixg, input_ids, attention_mask, candidate_mask, orig_prob)

        results.append({
            "text_id": i,
            "original_prob": orig_prob,
            "ig_drop_score": drop_ig,
            "ixg_drop_score": drop_ixg
        })

    # Save and visualise
    df_res = pd.DataFrame(results, columns=["text_id", "original_prob", "ig_drop_score", "ixg_drop_score"])
    df_res.to_csv(f"{RESULTS_DIR}/xai_comparison_results.csv", index=False)

    if df_res.empty:
        print("Module A: no toxic samples found - skipping plot.")
    else:
        plt.figure(figsize=(8, 6))
        sns.boxplot(data=df_res[['ig_drop_score', 'ixg_drop_score']])
        plt.title(f"Comprehensiveness (Confidence Drop) - Removing Top {TOP_K_TOKENS} Tokens")
        plt.ylabel("Probability Drop")
        plt.savefig(f"{RESULTS_DIR}/xai_boxplot.png")
        plt.close()
    print("Module A complete.")
    return df_res

In [None]:
# ==========================================
# 4. MODUŁ B: ANALIZA WARSTWOWA (RepE)
# ==========================================

def run_module_b_repe(model, dataset):
    """Fit a linear (logistic-regression) probe for the 'toxic' label on every hidden state.

    For the first ``N_SAMPLES_PROBE`` rows of ``dataset``, the [CLS]-position vector
    of every entry of ``output_hidden_states`` is collected. In DistilBERT's
    ``hidden_states``, index 0 is the embedding output and index k is the output of
    transformer block k-1. Each hidden state's features are split 80/20
    (``random_state=42``); a probe is trained and scored with accuracy and F1.

    Side effects:
        * writes ``repe_layer_performance.csv`` and ``repe_layer_plot.png`` to ``RESULTS_DIR``;
        * stores the hidden-state-5 features and binary labels in the module-level
          globals ``layer_5_activations`` / ``layer_5_labels`` (read by Module D).

    Returns:
        pandas.DataFrame with columns ``layer``, ``accuracy`` and ``f1_score``.

    Raises:
        ValueError: if the probed rows contain only one class.
    """
    global layer_5_activations, layer_5_labels
    print("\n>>> [MODULE B] Running Layer-wise Probing (RepE)...")
    model.eval()

    subset = dataset.select(range(min(len(dataset), N_SAMPLES_PROBE)))

    # hidden-state index -> list of (batch, hidden) arrays; sized from the model
    # output instead of assuming DistilBERT's 6 layers.
    layers_data = {}
    all_labels = []

    loader = torch.utils.data.DataLoader(subset, batch_size=BATCH_SIZE)
    for batch in tqdm(loader, desc="Extracting Layers"):
        input_ids = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)
        all_labels.extend(batch['labels'][:, 0].numpy())  # 'toxic' class only

        with torch.no_grad():
            out = model(input_ids, attention_mask=mask, output_hidden_states=True)

        for i, hidden in enumerate(out.hidden_states):
            layers_data.setdefault(i, []).append(hidden[:, 0, :].cpu().numpy())  # [CLS] token only

    y_bin = (np.array(all_labels) > 0.5).astype(int)
    if len(np.unique(y_bin)) < 2:
        raise ValueError("Probe subset contains a single class; increase N_SAMPLES_PROBE.")

    # Train one probe per hidden state.
    results = []
    for layer_idx in sorted(layers_data):
        X = np.concatenate(layers_data[layer_idx], axis=0)
        X_train, X_test, y_train, y_test = train_test_split(X, y_bin, test_size=0.2, random_state=42)

        clf = LogisticRegression(max_iter=1000, solver='liblinear')
        clf.fit(X_train, y_train)
        preds = clf.predict(X_test)

        results.append({
            "layer": layer_idx,
            "accuracy": accuracy_score(y_test, preds),
            # zero_division=0: same value sklearn reports by default, without the warning.
            "f1_score": f1_score(y_test, preds, zero_division=0),
        })

        # Keep the raw activations of hidden state 5 (= output of transformer block 4)
        # so Module D can derive its steering direction from them.
        if layer_idx == 5:
            layer_5_activations = X
            layer_5_labels = y_bin

    df_res = pd.DataFrame(results)
    df_res.to_csv(f"{RESULTS_DIR}/repe_layer_performance.csv", index=False)

    plt.figure(figsize=(10, 5))
    sns.lineplot(data=df_res, x='layer', y='accuracy', marker='o', label='Accuracy')
    sns.lineplot(data=df_res, x='layer', y='f1_score', marker='s', label='F1 Score')
    plt.title("Linear Probe Performance per Layer")
    plt.grid(True)
    plt.savefig(f"{RESULTS_DIR}/repe_layer_plot.png")
    plt.close()
    print("Module B complete.")
    return df_res

In [None]:
# ==========================================
# 5. MODUŁ C: TEST STABILNOŚCI (Robustness)
# ==========================================

def run_module_c_stability(model, tokenizer, dataset):
    """Measure prediction, representation and explanation stability under paraphrasing.

    For up to ``N_SAMPLES_STABILITY`` rows whose 'toxic' label is 1, the tokenized
    comment is decoded back to text and paraphrased with the
    ``Vamsi/T5_Paraphrase_Paws`` T5 model. Generation samples with top_k=50, so
    paraphrases vary between runs unless the torch RNG is seeded. Each
    original/paraphrase pair yields:

    * ``prob_diff``  - absolute difference of the sigmoid probability of label 0;
    * ``cosine_sim`` - cosine similarity of the [CLS] vector in hidden state 5;
    * ``jaccard_ig`` - Jaccard overlap of the top-``TOP_K_TOKENS`` IG tokens
      (special tokens excluded).

    Side effects: loads the T5 model (freed again at the end), writes
    ``stability_results.csv`` and ``stability_cosine_hist.png`` to ``RESULTS_DIR``.

    Returns:
        pandas.DataFrame with the three columns above.
    """
    print("\n>>> [MODULE C] Running Stability Analysis with T5...")
    model.eval()

    # Paraphrase model
    t5_name = "Vamsi/T5_Paraphrase_Paws"
    t5_tok = AutoTokenizer.from_pretrained(t5_name)
    t5_model = AutoModelForSeq2SeqLM.from_pretrained(t5_name).to(device)

    # Toxic samples only (label index 0).
    toxic_indices = [i for i, l in enumerate(dataset['labels']) if l[0] == 1]
    sample_indices = toxic_indices[:N_SAMPLES_STABILITY]

    # DistilBERT applies its full `embeddings` module (token + position + LayerNorm)
    # to `inputs_embeds` itself, so IG must start from the raw token lookup.
    word_embeddings = model.distilbert.embeddings.word_embeddings
    special_ids = torch.tensor(tokenizer.all_special_ids, device=device)
    # Simplified IG: default (zero) baseline, few steps for speed; built once.
    ig = IntegratedGradients(lambda x: model(inputs_embeds=x).logits)

    def encode(text):
        # Same truncation length for the embedding and the attribution passes.
        return tokenizer(text, return_tensors='pt', truncation=True, max_length=256).to(device)

    def get_embedding(text, l_idx=5):
        """[CLS] vector of hidden state ``l_idx`` and sigmoid probability of label 0."""
        with torch.no_grad():
            out = model(**encode(text), output_hidden_states=True)
        return out.hidden_states[l_idx][0, 0, :], torch.sigmoid(out.logits)[0, 0].item()

    def get_top_tokens(text):
        """Set of the top-k IG-attributed (non-special) token strings of ``text``."""
        ids = encode(text)['input_ids']
        candidates = ~torch.isin(ids[0], special_ids)
        k = min(TOP_K_TOKENS, int(candidates.sum().item()))  # short texts have < k tokens
        if k == 0:
            return set()
        with torch.no_grad():
            emb = word_embeddings(ids)
        attr = ig.attribute(emb, target=0, n_steps=10)
        scores = attr.sum(dim=-1).squeeze(0).masked_fill(~candidates, float("-inf"))
        top = torch.topk(scores, k=k).indices
        return set(tokenizer.convert_ids_to_tokens(ids[0][top].tolist()))

    results = []
    for idx in tqdm(sample_indices, desc="Stability"):
        orig_text = tokenizer.decode(dataset[idx]['input_ids'], skip_special_tokens=True)

        # Paraphrase generation
        t5_input = t5_tok("paraphrase: " + orig_text + " </s>", return_tensors="pt",
                          padding=True, truncation=True).to(device)
        t5_out = t5_model.generate(**t5_input, max_length=256, do_sample=True, top_k=50)
        para_text = t5_tok.decode(t5_out[0], skip_special_tokens=True)

        # 1. Output stability and 2. representation stability (one pass per text).
        vec_orig, prob_orig = get_embedding(orig_text)
        vec_para, prob_para = get_embedding(para_text)
        cos_sim = F.cosine_similarity(vec_orig.unsqueeze(0), vec_para.unsqueeze(0)).item()

        # 3. Explanation stability (Jaccard over top IG tokens).
        toks_orig = get_top_tokens(orig_text)
        toks_para = get_top_tokens(para_text)
        union = toks_orig | toks_para
        jaccard = len(toks_orig & toks_para) / len(union) if union else 0.0

        results.append({
            "prob_diff": abs(prob_orig - prob_para),
            "cosine_sim": cos_sim,
            "jaccard_ig": jaccard
        })

    # Release the paraphraser before later modules run.
    del t5_model
    if device.type == "cuda":
        torch.cuda.empty_cache()

    df_res = pd.DataFrame(results, columns=["prob_diff", "cosine_sim", "jaccard_ig"])
    df_res.to_csv(f"{RESULTS_DIR}/stability_results.csv", index=False)

    if df_res.empty:
        print("Module C: no toxic samples found - skipping plot.")
    else:
        plt.figure(figsize=(8, 5))
        sns.histplot(df_res['cosine_sim'], kde=True, color='green')
        plt.title("Distribution of Representation Stability (Layer 5)")
        plt.xlabel("Cosine Similarity")
        plt.axvline(0.9, color='red', linestyle='--')
        plt.savefig(f"{RESULTS_DIR}/stability_cosine_hist.png")
        plt.close()
    print("Module C complete.")
    return df_res

In [None]:
# ==========================================
# 6. MODUŁ D: TEST SKUTECZNOŚCI STEROWANIA (Steering)
# ==========================================

def run_module_d_steering(model, tokenizer, dataset):
    """Test activation steering along a difference-of-means 'toxicity' direction.

    The direction is mean(toxic) - mean(non-toxic) of the [CLS] hidden-state-5
    features that ``run_module_b_repe`` stores in ``layer_5_activations`` /
    ``layer_5_labels``. ``ALPHA`` times that vector is added (at every token
    position) to the output of the transformer block that produces hidden state 5.
    The 'toxic' probability is then re-scored on up to ``N_SAMPLES_XAI`` toxic and
    ``N_SAMPLES_XAI`` non-toxic rows of ``dataset``. Unsteered baseline rates are
    reported alongside for context.

    Side effects: prints the report and writes ``steering_report.txt`` to
    ``RESULTS_DIR``. The forward hook is always removed, even if evaluation fails.

    Raises:
        RuntimeError: if Module B has not been run (activation globals missing).
        ValueError: if the stored activations do not contain both classes.
    """
    print("\n>>> [MODULE D] Running Steering Efficacy Test...")
    model.eval()

    # hidden_states[k] is the output of transformer block k-1 (hidden_states[0] is
    # the embedding output), so Module B's "layer 5" features come from block 4.
    # The vector must be injected into that same representation space.
    HIDDEN_STATE_IDX = 5
    layer_module = model.distilbert.transformer.layer[HIDDEN_STATE_IDX - 1]

    # 1. Steering vector (difference of means) from the data saved by Module B.
    try:
        activations, act_labels = layer_5_activations, layer_5_labels
    except NameError as err:
        raise RuntimeError("Module D needs the layer-5 activations stored by run_module_b_repe(); run Module B first.") from err

    toxic_vecs = activations[act_labels == 1]
    safe_vecs = activations[act_labels == 0]
    if len(toxic_vecs) == 0 or len(safe_vecs) == 0:
        raise ValueError("Cannot build a difference-of-means direction: Module B data contains only one class.")
    direction = np.mean(toxic_vecs, axis=0) - np.mean(safe_vecs, axis=0)
    steering_tensor = torch.tensor(direction, dtype=torch.float32).to(device)

    ALPHA = -3.0  # value chosen experimentally (negative = detoxification)

    class SteeringHook:
        """Forward hook adding ``coeff * vector`` to a transformer block's hidden state."""

        def __init__(self, vector, coeff):
            self.vector = vector
            self.coeff = coeff

        def __call__(self, module, inputs, output):
            delta = self.coeff * self.vector
            if isinstance(output, torch.Tensor):
                return output + delta
            # DistilBERT's Transformer loop reads the block's hidden state as
            # output[-1]. With output_attentions=False (as here) the tuple has a
            # single element; otherwise output[0] would be the attention weights.
            return output[:-1] + (output[-1] + delta,)

    toxic_indices = [i for i, l in enumerate(dataset['labels']) if l[0] == 1][:N_SAMPLES_XAI]
    safe_indices = [i for i, l in enumerate(dataset['labels']) if l[0] == 0][:N_SAMPLES_XAI]

    def toxic_probs(indices, hook=None):
        """Label-0 probabilities for ``indices``, optionally with ``hook`` attached."""
        handle = layer_module.register_forward_hook(hook) if hook is not None else None
        try:
            probs = []
            for idx in indices:
                input_ids = dataset[idx]['input_ids'].unsqueeze(0).to(device)
                mask = dataset[idx]['attention_mask'].unsqueeze(0).to(device)
                with torch.no_grad():
                    logits = model(input_ids, attention_mask=mask).logits
                probs.append(torch.sigmoid(logits)[0, 0].item())
            return probs
        finally:
            # Never leave a steering hook on the shared model.
            if handle is not None:
                handle.remove()

    def rate(count, total):
        return 100.0 * count / total if total else float("nan")

    hook = SteeringHook(steering_tensor, ALPHA)

    # Toxic samples: success = steered probability below 0.5.
    success_count = sum(p < 0.5 for p in toxic_probs(toxic_indices, hook))
    base_success_count = sum(p < 0.5 for p in toxic_probs(toxic_indices))
    # Safe samples: side effect = steered probability above 0.5 (false positive).
    side_effect_count = sum(p > 0.5 for p in toxic_probs(safe_indices, hook))
    base_side_effect_count = sum(p > 0.5 for p in toxic_probs(safe_indices))

    success_rate = rate(success_count, len(toxic_indices))
    side_effect_rate = rate(side_effect_count, len(safe_indices))
    base_success_rate = rate(base_success_count, len(toxic_indices))
    base_side_effect_rate = rate(base_side_effect_count, len(safe_indices))

    report = f"""
    === STEERING REPORT ===
    Method: Difference of Means (Layer 5)
    Alpha: {ALPHA}
    Samples Evaluated: {len(toxic_indices)} toxic, {len(safe_indices)} safe.

    1. Detoxification Success Rate: {success_rate:.2f}%
       (Percentage of toxic samples dropped below 0.5 probability)
       Without steering: {base_success_rate:.2f}%

    2. Side Effects Rate: {side_effect_rate:.2f}%
       (Percentage of safe samples that became false positives)
       Without steering: {base_side_effect_rate:.2f}%

    Status: {'SUCCESS' if success_rate > 80 and side_effect_rate < 5 else 'NEEDS TUNING'}
    """

    print(report)
    with open(f"{RESULTS_DIR}/steering_report.txt", "w", encoding="utf-8") as f:
        f.write(report)
    print("Module D complete.")

In [None]:
# ==========================================
# 7. RUN THE WHOLE EXPERIMENT
# ==========================================
print(f"=== STARTING EXPERIMENT (Results -> {RESULTS_DIR}) ===")

# The modules run strictly in this order: Module D reuses the layer-5
# activations that Module B leaves in module-level globals.
for _run_module, _module_args in (
    (run_module_a_xai, (model, tokenizer, eval_dataset)),
    (run_module_b_repe, (model, eval_dataset)),
    (run_module_c_stability, (model, tokenizer, eval_dataset)),
    (run_module_d_steering, (model, tokenizer, eval_dataset)),
):
    _run_module(*_module_args)

print("\n=== EXPERIMENT COMPLETE ===")
print(f"Files generated in {RESULTS_DIR}:")
_generated_files = os.listdir(RESULTS_DIR)
print(_generated_files)