# Connecting and Uploading

In [None]:
#loading Last Datasets:
import pandas as pd
dataset_5A = pd.read_pickle("Dataset (5 Authors).pkl")

In [None]:
# taging:

dataset_5A['Merged'] = "<" + dataset_5A['Label_(number)'].astype(str) + "> " + dataset_5A['Sentence'] + " <end>"# Sentence Merge with their Label

display(dataset_5A.head(3))
print()

#Number of sample in each group:

num_samples = dataset_5A['Author'].value_counts()

display(num_samples)

print("\nSum = ", num_samples.sum())

In [None]:
# Example:
print(dataset_5A["Sentence"][10])

# Evaluator (RoBERTa):

## ---- Data Preparation:

In [None]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [None]:
!pip install transformers dataset ipython-sql

In [None]:
!pip install evaluate

In [None]:
import transformers, datasets, evaluate, pyarrow
print("transformers:", transformers.__version__)
print("datasets:", datasets.__version__)
print("evaluate:", evaluate.__version__)
print("pyarrow:", pyarrow.__version__)

from transformers import Trainer, TrainingArguments, RobertaTokenizer, RobertaForSequenceClassification
from datasets import Dataset as HFDataset
print("Imports OK")

In [None]:
import torch
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from transformers import RobertaTokenizer, RobertaForSequenceClassification, Trainer, TrainingArguments
from datasets import Dataset as HFDataset
import evaluate
from datetime import datetime

In [None]:
dataset_5A.head(3)

In [None]:
display(dataset_5A['Author'].value_counts())

print("\nSum = ", dataset_5A['Author'].value_counts().sum())


In [None]:
dataset_BERT = dataset_5A[['Sentence', 'Author']].copy()

In [None]:
dataset_BERT.head(3)

In [None]:
# <0> = Charles Dickens
# <1> = Jane Austen
# <2> = Mark Twain
# <3> = Louisa May Alcott
# <4> = Herman Melville

In [None]:
# Mapping from author name to integer class id
custom_label2id = {
    "Charles Dickens": 0,
    "Jane Austen": 1,
    "Mark Twain": 2,
    "Louisa May Alcott": 3,
    "Herman Melville": 4}

# Replace each author name with its class id (names missing from the mapping become NaN)
dataset_BERT['Author'] = dataset_BERT['Author'].map(custom_label2id)

# Rename the column to `label`, the name the following cells read
dataset_BERT.rename(columns={'Author': 'label'}, inplace=True)

In [None]:
dataset_BERT.head(3)

In [None]:
display(dataset_BERT['label'].value_counts())

print("\n\nTotal Number:", sum(dataset_BERT['label'].value_counts()),"\n\n")

In [None]:
dataset_BERT.rename(columns={'Sentence': 'text'}, inplace=True)
dataset_BERT.rename(columns={'label': 'label'}, inplace=True)
dataset_BERT["label"] = dataset_BERT["label"].astype("int32")
dataset_BERT.head(3)

In [None]:
# Check class distribution
num_classes = dataset_BERT["label"].nunique()
print(f"Number of unique Authors (classes): {num_classes}")

In [None]:
# Split dataset into train and test sets
train_texts, test_texts, train_labels, test_labels = train_test_split(
    dataset_BERT["text"].tolist(), dataset_BERT["label"].tolist(), test_size=0.2, random_state=42)

## ---- Tokenizing:

In [None]:
model_name = "roberta-large"

In [None]:
# Load the tokenizer that matches `model_name`.
# AutoTokenizer is imported here because the import cells above only bring in RobertaTokenizer,
# so on a fresh kernel this cell would otherwise fail with a NameError.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

max_length = 256

# Tokenization function
def tokenize_function(examples):
    """Tokenize a batch of examples for `Dataset.map(..., batched=True)`.

    Reads the "text" field of `examples`, truncates and pads every sequence to
    `max_length` tokens, and returns the tokenizer output.
    """
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=max_length)

In [None]:
# Convert to Hugging Face Dataset format
train_data = HFDataset.from_dict({"text": train_texts, "label": train_labels})
test_data  = HFDataset.from_dict({"text": test_texts,  "label": test_labels })

In [None]:
# Tokenize datasets
train_data = train_data.map(tokenize_function, batched=True)
test_data  = test_data.map(tokenize_function,  batched=True)

## ---- Training:

In [None]:
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, f1_score

def compute_metrics(eval_pred):
    """Score one evaluation pass: accuracy plus weighted precision, recall and F1.

    `eval_pred` unpacks into (logits, labels). When the logits arrive as a tuple,
    only its first element is used. Returns a dict with the keys "accuracy",
    "precision", "recall" and "f1".
    """
    raw_outputs, labels = eval_pred
    # Keep only the first element when the model hands back a tuple of outputs.
    logits = raw_outputs[0] if isinstance(raw_outputs, tuple) else raw_outputs
    preds = np.argmax(logits, axis=-1)

    metrics = dict(accuracy=accuracy_score(labels, preds))

    weighted_precision, weighted_recall, weighted_f1, _support = precision_recall_fscore_support(
        labels, preds, average="weighted", zero_division=0
    )
    metrics.update(
        precision=weighted_precision,
        recall=weighted_recall,
        f1=weighted_f1,
    )
    return metrics

In [None]:
# Load pre-trained RoBERTa model for classification

model = RobertaForSequenceClassification.from_pretrained(model_name, num_labels=num_classes)

In [None]:
# Define paths and filenames
base_dir = '/content/drive/MyDrive/Colab/LLMs Project' # Local directory for saving
current_date = datetime.now().strftime("%Y.%m.%d")

# Define save paths in Google Drive with date
model_path     = f'{base_dir}/roberta_author_classifier/saved_model_{current_date}'
tokenizer_path = f'{base_dir}/roberta_author_classifier/saved_tokenizer_{current_date}'
output_dir           = f'{base_dir}/roberta_author_classifier/results_{current_date}'
logging_dir          = f'{base_dir}/roberta_author_classifier/logs_{current_date}'


In [None]:
print(model_name,":")
print(model_path)
print(tokenizer_path)
print(output_dir)
print(logging_dir)

In [None]:
# Training arguments for the RoBERTa run.
# Evaluation and checkpointing both happen once per epoch so that
# load_best_model_at_end can restore the checkpoint with the highest eval F1.
training_args = TrainingArguments(
    output_dir=output_dir,
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    gradient_accumulation_steps=2,
    gradient_checkpointing=False,     # huge memory saver when True
    bf16=True, # A100
    tf32=True, # A100
    num_train_epochs=10,
    weight_decay=0.01,
    logging_dir=logging_dir,
    logging_steps=10,
    logging_strategy="steps",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better = True,
    report_to="wandb",
    # Name the W&B run after the model trained in this section
    # (it previously reused the DeBERTa run name).
    run_name="roberta-large (5-authors)",
    optim="adamw_torch_fused"
)

In [None]:
# Trainer for the RoBERTa run.
# `processing_class` is the current name of the deprecated `tokenizer` argument and
# matches how the other Trainer instances in this notebook are built.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=test_data,
    processing_class=tokenizer,
    compute_metrics=compute_metrics
)

In [None]:
!pip install wandb

In [None]:
import wandb
wandb.login()

In [None]:
# Train the model
trainer.train()

In [None]:
model.save_pretrained(model_path)
tokenizer.save_pretrained(tokenizer_path)

In [None]:
# Evaluate on the test set
# Evaluate on the test set
results = trainer.evaluate()

# Ensure W&B gets plain scalars (not numpy types)
wandb.log({k: float(v) for k, v in results.items()})

print(f"Test Accuracy: {results.get('eval_accuracy', float('nan')):.4f}")
print(f"Test F1:       {results.get('eval_f1', float('nan')):.4f}")
print(f"Precision:     {results.get('eval_precision', float('nan')):.4f}")
print(f"Recall:        {results.get('eval_recall', float('nan')):.4f}")

In [None]:
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# Get predictions on the test set
predictions_output = trainer.predict(test_data)
preds = np.argmax(predictions_output.predictions, axis=-1)
labels = predictions_output.label_ids

# Compute confusion matrix
cm = confusion_matrix(labels, preds)

# Print or plot the confusion matrix
#print("Confusion Matrix:\n", cm)

# Author names in class-id order (position i is the tick label for class i);
# the variable keeps its historical name `year_labels`.
year_labels = ["Charles Dickens",
               "Jane Austen",
               "Mark Twain",
               "Louisa May Alcott",
               "Herman Melville"]

# Optional: use seaborn to plot a nice heatmap
plt.figure(figsize=(6, 5))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=year_labels, yticklabels=year_labels)
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Confusion Matrix")
plt.show()

## ---- Loading:

In [None]:
ls

In [None]:
roberta_model_path      = 'roberta_author_classifier/saved_model_2025.08.10'
roberta_tokenizer_path  = 'roberta_author_classifier/saved_tokenizer_2025.08.10'
roberta_output_dir      = 'roberta_author_classifier/results_2025.08.10'
#roberta_logging_dir    = 'roberta_author_classifier/logs_2025.05.07'

In [None]:
from transformers import RobertaTokenizer, RobertaForSequenceClassification
import torch

model = RobertaForSequenceClassification.from_pretrained(roberta_model_path)
tokenizer = RobertaTokenizer.from_pretrained(roberta_tokenizer_path)

# Evaluator (Deberta-v3-Large):

## ---- Data Preparation:

In [None]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [None]:
!pip install transformers dataset ipython-sql

In [None]:
!pip install evaluate

In [None]:
import transformers, datasets, evaluate, pyarrow
print("transformers:", transformers.__version__)
print("datasets:", datasets.__version__)
print("evaluate:", evaluate.__version__)
print("pyarrow:", pyarrow.__version__)

from transformers import Trainer, TrainingArguments, RobertaTokenizer, RobertaForSequenceClassification
from datasets import Dataset as HFDataset
print("Imports OK")

In [None]:
import torch
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from transformers import RobertaTokenizer, RobertaForSequenceClassification, Trainer, TrainingArguments
from datasets import Dataset as HFDataset
import evaluate
from datetime import datetime

In [None]:
dataset_5A.head(3)

In [None]:
display(dataset_5A['Author'].value_counts())

print("\nSum = ", dataset_5A['Author'].value_counts().sum())


In [None]:
dataset_BERT = dataset_5A[['Sentence', 'Author']].copy()

In [None]:
dataset_BERT.head(3)

In [None]:
# <0> = Charles Dickens
# <1> = Jane Austen
# <2> = Mark Twain
# <3> = Louisa May Alcott
# <4> = Herman Melville

In [None]:
# Mapping from author name to integer class id
custom_label2id = {
    "Charles Dickens": 0,
    "Jane Austen": 1,
    "Mark Twain": 2,
    "Louisa May Alcott": 3,
    "Herman Melville": 4}

# Replace each author name with its class id (names missing from the mapping become NaN)
dataset_BERT['Author'] = dataset_BERT['Author'].map(custom_label2id)

# Rename the column to `label`, the name the following cells read
dataset_BERT.rename(columns={'Author': 'label'}, inplace=True)

In [None]:
dataset_BERT.head(3)

In [None]:
display(dataset_BERT['label'].value_counts())

print("\n\nTotal Number:", sum(dataset_BERT['label'].value_counts()),"\n\n")

In [None]:
dataset_BERT.rename(columns={'Sentence': 'text'}, inplace=True)
dataset_BERT.rename(columns={'label': 'label'}, inplace=True)
dataset_BERT["label"] = dataset_BERT["label"].astype("int32")
dataset_BERT.head(3)

In [None]:
# Check class distribution
num_classes = dataset_BERT["label"].nunique()
print(f"Number of unique Authors (classes): {num_classes}")

In [None]:
# Split dataset into train and test sets
train_texts, test_texts, train_labels, test_labels = train_test_split(
    dataset_BERT["text"].tolist(), dataset_BERT["label"].tolist(), test_size=0.2, random_state=42)

In [None]:
model_name = "microsoft/deberta-v3-large"

## ---- Tokenizing:

In [None]:
import os, torch
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
torch.set_float32_matmul_precision("high")  # enables TF32

In [None]:
from transformers import (
    AutoTokenizer, AutoModelForSequenceClassification,
    DataCollatorWithPadding, Trainer, TrainingArguments,
    EarlyStoppingCallback
)

model_name = "microsoft/deberta-v3-large"

In [None]:
# Load the tokenizer that matches `model_name` (DeBERTa-v3 in this section)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

max_length = 512 ### 512

# Tokenization function
def tokenize_function(examples):
    """Tokenize a batch of examples for `Dataset.map(..., batched=True)`.

    Reads the "text" field of `examples` and truncates each sequence to at most
    `max_length` tokens. Sequences are left unpadded on purpose.
    """
    # No static padding here: every Trainer in this section is given a
    # DataCollatorWithPadding, which pads each batch to its own longest sequence.
    # Padding every example to `max_length` up front would make the model process
    # mostly pad tokens for short sentences.
    return tokenizer(examples["text"], truncation=True, max_length=max_length)

In [None]:
# Convert to Hugging Face Dataset format
train_data = HFDataset.from_dict({"text": train_texts, "label": train_labels})
test_data  = HFDataset.from_dict({"text": test_texts,  "label": test_labels })

In [None]:
# Tokenize datasets
train_data = train_data.map(tokenize_function, batched=True)
test_data  = test_data.map(tokenize_function,  batched=True)

## ---- Training:

In [None]:
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, f1_score

def compute_metrics(eval_pred):
    """Compute accuracy and weighted precision/recall/F1 from an evaluation pass.

    `eval_pred` unpacks into (logits, labels). The predicted class is the argmax
    over the last axis of the logits. Returns a dict with the keys "accuracy",
    "precision", "recall" and "f1".
    """
    logits, labels = eval_pred
    # Some models return (logits, ...) — keep only the first
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = np.argmax(logits, axis=-1)

    acc = accuracy_score(labels, preds)
    # average="weighted" weights each class by its support;
    # zero_division=0 scores a class with no predictions as 0.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average="weighted", zero_division=0
    )

    return {
        "accuracy": acc,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }

In [None]:
# Pads each batch to its longest sequence, rounded up to a multiple of 8
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, pad_to_multiple_of=8)

import numpy as np

# Materialize the label columns as plain lists first, so the concatenation below works
# whether column access returns a list or a lazy column object.
train_label_values = list(train_data["label"])
test_label_values = list(test_data["label"])
all_label_values = train_label_values + test_label_values

print("Label set (train):", set(np.unique(train_label_values)))
print("Label set (test): ", set(np.unique(test_label_values)))

# Every label must be a valid class index for a model built with num_labels=num_classes
assert max(all_label_values) < num_classes, "Found label >= num_labels"
assert min(all_label_values) >= 0, "Found negative label"

In [None]:
# Load the pre-trained model named by `model_name` with a classification head
# sized for `num_classes` labels

model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=num_classes,
)

# Important with gradient checkpointing:
model.config.use_cache = False

# NOTE(review): gradient checkpointing is switched on for the model here, while the
# TrainingArguments cells in this notebook pass gradient_checkpointing=False —
# confirm which of the two is intended.
# Use non-reentrant checkpointing if your Transformers supports it (4.38+)
if hasattr(model, "gradient_checkpointing_enable"):
    try:
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
    except TypeError:
        # older versions don't accept kwargs; will still enable GC (reentrant)
        model.gradient_checkpointing_enable()

# Required on some stacks so the first layer gets grads with GC:
if hasattr(model, "enable_input_require_grads"):
    model.enable_input_require_grads()

In [None]:
# Define paths and filenames
base_dir = '/content/drive/MyDrive/Colab/LLMs Project' # Local directory for saving
current_date = datetime.now().strftime("%Y.%m.%d")

# Define save paths in Google Drive with date
deberta_model_path     = f'{base_dir}/Deberta_v3_large_author_classifier/saved_model_{current_date}'
deberta_tokenizer_path = f'{base_dir}/Deberta_v3_large_author_classifier/saved_tokenizer_{current_date}'
deberta_output_dir           = f'{base_dir}/Deberta_v3_large_author_classifier/results_{current_date}'
deberta_logging_dir          = f'{base_dir}/Deberta_v3_large_author_classifier/logs_{current_date}'


In [None]:
# Show the DeBERTa paths defined in the previous cell
# (this cell used to print the RoBERTa section's variables).
print(model_name,":")
print(deberta_model_path)
print(deberta_tokenizer_path)
print(deberta_output_dir)
print(deberta_logging_dir)

In [None]:
# Training arguments
training_args = TrainingArguments(
    output_dir= deberta_output_dir,
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=4,
    gradient_checkpointing=False,     # huge memory saver when True
    bf16=True, # A100
    tf32=True, # A100
    num_train_epochs=20,
    weight_decay=0.01,
    logging_dir=deberta_logging_dir,
    logging_steps=10,
    logging_strategy="steps",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better = True,
    report_to="wandb",
    run_name="deberta-v3-large (5-authors)",
    optim="adamw_torch_fused"
)

In [None]:
# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    data_collator=data_collator,
    eval_dataset=test_data,
    processing_class=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(2)]
)

In [None]:
!pip install wandb

In [None]:
import wandb
wandb.login()

In [None]:
# Train the model
trainer.train()

In [None]:
# Save to the DeBERTa paths. `model_path` / `tokenizer_path` belong to the RoBERTa
# section, and writing there would overwrite the RoBERTa artifacts.
model.save_pretrained(deberta_model_path)
tokenizer.save_pretrained(deberta_tokenizer_path)

In [None]:
# Evaluate on the test set
# Evaluate on the test set
results = trainer.evaluate()

# Ensure W&B gets plain scalars (not numpy types)
wandb.log({k: float(v) for k, v in results.items()})

print(f"Test Accuracy: {results.get('eval_accuracy', float('nan')):.4f}")
print(f"Test F1:       {results.get('eval_f1', float('nan')):.4f}")
print(f"Precision:     {results.get('eval_precision', float('nan')):.4f}")
print(f"Recall:        {results.get('eval_recall', float('nan')):.4f}")

In [None]:
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# 1. Get predictions on the test set
predictions_output = trainer.predict(test_data)
preds = np.argmax(predictions_output.predictions, axis=-1)
labels = predictions_output.label_ids

# 2. Compute confusion matrix
cm = confusion_matrix(labels, preds)

# 3. Print or plot the confusion matrix
#print("Confusion Matrix:\n", cm)

# Author names in class-id order (position i is the tick label for class i);
# the variable keeps its historical name `year_labels`.
year_labels = ["Charles Dickens",
               "Jane Austen",
               "Mark Twain",
               "Louisa May Alcott",
               "Herman Melville"]

# Optional: use seaborn to plot a nice heatmap
plt.figure(figsize=(6, 5))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=year_labels, yticklabels=year_labels)
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Confusion Matrix")
plt.show()

In [None]:
# Reliability Diagram (Calibration Curve)

import numpy as np
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import log_loss

# Get logits/labels on the test set
pred_out = trainer.predict(test_data)
logits = pred_out.predictions[0] if isinstance(pred_out.predictions, tuple) else pred_out.predictions  # [N, C]
labels = pred_out.label_ids  # [N]

# Convert to probabilities, confidences, predictions
probs = torch.softmax(torch.tensor(logits), dim=1).cpu().numpy()  # [N, C]
conf = probs.max(axis=1)                                          # [N]
pred = probs.argmax(axis=1)                                       # [N]
correct = (pred == labels).astype(np.float32)                     # [N]

# Reliability diagram bins (+ ECE/MCE)
def reliability_bins(confidences, correct, n_bins=15):
    """Bin predictions by confidence and summarize calibration.

    Args:
        confidences: 1-D array-like of top-class probabilities in [0, 1].
        correct: 1-D array-like, 1.0 where the prediction was right and 0.0 otherwise.
        n_bins: number of equal-width confidence bins over [0, 1].

    Returns:
        (bins, bin_conf, bin_acc, bin_cnt, ece, mce) where `bins` holds the
        n_bins + 1 edges, `bin_conf` / `bin_acc` are the per-bin mean confidence and
        accuracy (`bin_acc` is NaN and `bin_conf` is the bin midpoint for empty bins),
        `bin_cnt` is the per-bin sample count, `ece` is the count-weighted mean
        |accuracy - confidence| gap and `mce` is the largest gap. Both `ece` and
        `mce` are NaN when there are no samples.
    """
    confidences = np.asarray(confidences)
    correct = np.asarray(correct)

    bins = np.linspace(0.0, 1.0, n_bins + 1)
    # np.digitize places a confidence of exactly 1.0 past the last edge (index n_bins),
    # which would drop fully confident predictions from every bin. Clip so they land
    # in the last bin (and anything below 0 in the first).
    bin_ids = np.clip(np.digitize(confidences, bins) - 1, 0, n_bins - 1)

    bin_acc = np.full(n_bins, np.nan, dtype=np.float32)
    bin_conf = np.full(n_bins, np.nan, dtype=np.float32)
    bin_cnt = np.zeros(n_bins, dtype=np.int64)

    for b in range(n_bins):
        m = (bin_ids == b)
        if np.any(m):
            bin_acc[b] = correct[m].mean()
            bin_conf[b] = confidences[m].mean()
            bin_cnt[b] = m.sum()
        else:
            # Empty bin: use the midpoint so the table/plot still has an x position
            bin_conf[b] = 0.5 * (bins[b] + bins[b+1])

    # Only populated bins contribute to the calibration errors
    filled = bin_cnt > 0
    if not np.any(filled):
        return bins, bin_conf, bin_acc, bin_cnt, float("nan"), float("nan")

    gaps = np.abs(bin_acc[filled] - bin_conf[filled])
    ece = np.sum(gaps * (bin_cnt[filled] / len(confidences)))
    mce = np.max(gaps)
    return bins, bin_conf, bin_acc, bin_cnt, ece, mce

bins, bin_conf, bin_acc, bin_cnt, ece, mce = reliability_bins(conf, correct, n_bins=15)

# Brier score (multiclass) + NLL
eye = np.eye(probs.shape[1], dtype=np.float32)
one_hot = eye[labels]                     # [N, C]
brier = np.mean(np.sum((probs - one_hot) ** 2, axis=1))
try:
    nll = log_loss(labels, probs, labels=list(range(probs.shape[1])))
except Exception:
    nll = float("nan")

print(f"ECE:  {ece:.4f}")
print(f"MCE:  {mce:.4f}")
print(f"Brier:{brier:.4f}")
print(f"NLL:  {nll:.4f}")

# Plot reliability diagram (single plot; no specific colors/styles)
valid = ~np.isnan(bin_acc)
plt.figure(figsize=(6,6))
plt.plot([0, 1], [0, 1], linestyle="--")                    # perfect calibration
plt.plot(bin_conf[valid], bin_acc[valid], marker="o")       # model calibration
plt.xlabel("Confidence (predicted probability)")
plt.ylabel("Accuracy")
plt.title(f"Reliability Diagram — ECE={ece:.3f}")
plt.grid(True)
plt.show()

# (Optional) also print a small table of bins
import pandas as pd
calib_df = pd.DataFrame({
    "bin_left": bins[:-1],
    "bin_right": bins[1:],
    "bin_count": bin_cnt,
    "avg_confidence": bin_conf,
    "avg_accuracy": bin_acc,
})
display(calib_df)


## ---- Load Best Checkpoint:

In [None]:
ls

In [None]:
deberta_model_path      = 'Deberta_v3_large_author_classifier/saved_model_2025.08.12'
deberta_tokenizer_path  = 'Deberta_v3_large_author_classifier/saved_tokenizer_2025.08.12'
deberta_output_dir      = 'Deberta_v3_large_author_classifier/results_2025.08.12'
deberta_logging_dir    = 'Deberta_v3_large_author_classifier/logs_2025.08.12'

In [None]:
import os, json, glob


def _recorded_best_checkpoint(state_file, results_dir):
    """Return the best checkpoint named in a trainer_state.json file, or None.

    The recorded path is used as-is when it exists on disk. Otherwise its folder
    name is looked up inside `results_dir`, because the path may have been written
    with a different base directory than the one used here.
    """
    if not os.path.isfile(state_file):
        return None
    with open(state_file, "r") as f:
        recorded = json.load(f).get("best_model_checkpoint", None)
    if not recorded:
        return None
    if os.path.isdir(recorded):
        return recorded
    relocated = os.path.join(results_dir, os.path.basename(os.path.normpath(recorded)))
    return relocated if os.path.isdir(relocated) else None


# Checkpoint folders sorted by step number (entries without a numeric suffix are ignored)
ckpts = [
    p for p in glob.glob(os.path.join(deberta_output_dir, "checkpoint-*"))
    if p.split("-")[-1].isdigit()
]
ckpts.sort(key=lambda p: int(p.split("-")[-1]))

# Pick the best checkpoint if Trainer recorded it. Trainer writes trainer_state.json inside
# each checkpoint folder, so check the newest checkpoint as well as the results folder.
trainer_state_path = os.path.join(deberta_output_dir, "trainer_state.json")
state_candidates = [trainer_state_path]
if ckpts:
    state_candidates.append(os.path.join(ckpts[-1], "trainer_state.json"))

best_ckpt = None
for state_file in state_candidates:
    best_ckpt = _recorded_best_checkpoint(state_file, deberta_output_dir)
    if best_ckpt:
        break

# Nothing recorded: fall back to the latest checkpoint
if not best_ckpt:
    assert ckpts, "No checkpoints found in results folder."
    best_ckpt = ckpts[-1]

print("Loading checkpoint:", best_ckpt)

In [None]:
from transformers import (
    AutoTokenizer, AutoModelForSequenceClassification,
    DataCollatorWithPadding, Trainer, TrainingArguments,
    EarlyStoppingCallback
)

In [None]:
model = AutoModelForSequenceClassification.from_pretrained(best_ckpt)

In [None]:
# Load tokenizer (you didn’t save one yet, so load from base model id)

base_model_id = "microsoft/deberta-v3-large"
tokenizer = AutoTokenizer.from_pretrained(base_model_id, use_fast=True)

In [None]:
# (Re)build data collator if needed

data_collator = DataCollatorWithPadding(tokenizer=tokenizer, pad_to_multiple_of=8)

In [None]:
# Training arguments
training_args = TrainingArguments(
    output_dir=deberta_output_dir,
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=4,
    gradient_checkpointing=False,     # huge memory saver when True
    bf16=True, # A100
    tf32=True, # A100
    num_train_epochs=20,
    weight_decay=0.01,
    logging_dir=deberta_logging_dir,
    logging_steps=10,
    logging_strategy="steps",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better = True,
    report_to="wandb",
    run_name="deberta-v3-large (5-authors)",
    optim="adamw_torch_fused"
)

In [None]:
trainer = Trainer(
    model=model,
    args=training_args,          # reuse your existing TrainingArguments
    train_dataset=train_data,    # or omit if you only want evaluate()
    eval_dataset=test_data,
    data_collator=data_collator,
    processing_class=tokenizer,
    compute_metrics=compute_metrics,
)

In [None]:
#print(trainer.evaluate(metric_key_prefix="test"))

In [None]:
model.save_pretrained(deberta_model_path)
tokenizer.save_pretrained(deberta_tokenizer_path)

print("Saved model to:", deberta_model_path)
print("Saved tokenizer to:", deberta_tokenizer_path)

In [None]:
#resume training

from dataclasses import replace
from transformers import Trainer

# Copy the existing TrainingArguments, overriding only the fields listed below
new_args = replace(
    training_args,
    num_train_epochs=training_args.num_train_epochs + 10,  # resume with 10 more epochs
    per_device_train_batch_size=32,      # NOTE(review): an earlier note here said "<= 8 is safe on A100"; confirm 32 fits in memory
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=4,      # 32 * 4 = 128 samples per optimizer step per device
    gradient_checkpointing=False,
    run_name=(training_args.run_name + "-extra10")
)

# New Trainer around the already-loaded model; training stops early after
# 2 evaluations without improvement (EarlyStoppingCallback(2))
trainer = Trainer(
    model=model,
    args=new_args,
    train_dataset=train_data,
    eval_dataset=test_data,
    data_collator=data_collator,
    processing_class=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(2)],
)


In [None]:
trainer.train(resume_from_checkpoint=best_ckpt)

In [None]:
model.save_pretrained(deberta_model_path)
tokenizer.save_pretrained(deberta_tokenizer_path)

print("Saved model to:", deberta_model_path)
print("Saved tokenizer to:", deberta_tokenizer_path)

## ---- Loading:

In [None]:
ls

In [None]:
deberta_model_path      = 'Deberta_v3_large_author_classifier/saved_model_2025.08.12'
deberta_tokenizer_path  = 'Deberta_v3_large_author_classifier/saved_tokenizer_2025.08.12'
deberta_output_dir      = 'Deberta_v3_large_author_classifier/results_2025.08.12'
#roberta_logging_dir    = 'Deberta_v3_large_author_classifier/logs_2025.08.12'

In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained(deberta_tokenizer_path, use_fast=True, local_files_only=True)
model     = AutoModelForSequenceClassification.from_pretrained(deberta_model_path, local_files_only=True)
model.config.use_cache = False

# AI-Evaluate-AI

## ---- Evaluation By RoBERTa:

In [None]:
# First run the following subsections in the AI-Evaluate-AI section:
## ---- Data Preparation:
## ---- Loading:

In [None]:
import pandas as pd

generated_texts_gpt3      = pd.read_csv("generated_texts_gpt3.csv")
generated_texts_gpt3_lora = pd.read_csv("generated_texts_gpt3_lora.csv")

In [None]:
generated_texts_gpt3['Model'] = 'gpt3'
generated_texts_gpt3_lora['Model'] = 'gpt3_lora'

generated_texts_merged = pd.concat([generated_texts_gpt3, generated_texts_gpt3_lora], ignore_index=True)

In [None]:
# Pull the numeric author tag (e.g. "<3>") from the start of each generated text.
# expand=False returns a Series rather than a one-column DataFrame.
expected_label = generated_texts_merged['Text'].str.extract(r'^<(\d+)>', expand=False)

# Fail with a clear message instead of an opaque NaN-to-int cast error
untagged = expected_label.isna()
assert not untagged.any(), f"{int(untagged.sum())} generated texts do not start with a <label> tag"

generated_texts_merged['expected_label'] = expected_label.astype(int)

In [None]:
# Drop the leading "<label>" tag
generated_texts_merged['Text'] = generated_texts_merged['Text'].str.replace(r'^<\d+>', '', regex=True)
# Cut everything from the first "<end>" marker onward. (?s) lets "." match newlines,
# so multi-line text generated after the marker is removed as well.
generated_texts_merged['Text'] = generated_texts_merged['Text'].str.replace(r'(?s)<end>.*', '', regex=True).str.strip()

In [None]:
generated_texts_merged

In [None]:
# Preview rows 1000-1099: model, expected label and cleaned text.
# Single quotes inside the f-string keep this valid on Python versions before 3.12,
# where reusing the outer quote character inside a replacement field is a SyntaxError.
for i in range(1000, 1100):
    print(f"{i}: {generated_texts_merged.loc[i, 'Model']} [{generated_texts_merged.loc[i, 'expected_label']}] {generated_texts_merged.loc[i, 'Text']}")

In [None]:
list_generated_texts_merged = generated_texts_merged["Text"].tolist()

In [None]:
len(list_generated_texts_merged)

In [None]:
inputs = tokenizer( list_generated_texts_merged,
                    padding=True,
                    truncation=True,
                    return_tensors="pt"
                  )

In [None]:
# Set model to eval mode
import torch
import pandas as pd
from tqdm.auto import tqdm
import numpy as np

model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

In [None]:
# Run predictions

# Class id -> author name
label_map = {
    0: "Charles Dickens",
    1: "Jane Austen",
    2: "Mark Twain",
    3: "Louisa May Alcott",
    4: "Herman Melville"
}

# inputs is a dict of tensors like {'input_ids': ..., 'attention_mask': ...}
N = inputs["input_ids"].shape[0]
bs = 64  # 16/32/64

all_preds = []
all_probs = []   # will hold per-class probabilities

# torch.autocast replaces the deprecated torch.cuda.amp.autocast.
# Mixed precision is only enabled when the model actually runs on a CUDA device.
with torch.inference_mode(), torch.autocast(device_type=device.type, enabled=(device.type == "cuda")):
    for i in tqdm(range(0, N, bs), desc="Predicting", total=(N + bs - 1)//bs):
        # Slice one batch out of every input tensor and move it to the model's device
        batch = {k: v[i:i+bs].to(device, non_blocking=True) for k, v in inputs.items()}
        logits = model(**batch).logits                         # [B, 5]
        # Softmax in float32 on the logits, then keep results on CPU to free GPU memory
        probs  = torch.softmax(logits.float(), dim=1).cpu()    # [B, 5] on CPU
        preds  = probs.argmax(dim=1)

        all_probs.append(probs)
        all_preds.append(preds)

# concat to full arrays
predictions = torch.cat(all_preds)                 # [N]
probs_np = torch.cat(all_probs).numpy()            # [N, 5]

In [None]:
# optional: tidy column names
prob_cols = [f"prob_{label_map[i].replace(' ', '_')}" for i in range(5)]

for i, col in enumerate(prob_cols):
    generated_texts_merged[col] = probs_np[:, i]

# keep your predicted labels too
generated_texts_merged["predicted_Label_roberta"] = predictions.numpy()

In [None]:
generated_texts_merged

## ---- Visualization:

In [None]:
import matplotlib.pyplot as plt
from matplotlib.patches import Wedge
import numpy as np
import pandas as pd

# Map labels to author names
label_map = {
    0: "Charles Dickens",
    1: "Jane Austen",
    2: "Mark Twain",
    3: "Louisa May Alcott",
    4: "Herman Melville"
}

generated_texts_merged["Expected Author"] = generated_texts_merged["expected_label"].map(label_map)
generated_texts_merged["Predicted Author"] = generated_texts_merged["predicted_Label_roberta"].map(label_map)

# Prepare pie-chart data at each (expected, predicted) coordinate
grouped = (
    generated_texts_merged.groupby(["Expected Author", "Predicted Author", "Model"])
    .size()
    .reset_index(name="Count")
)

# Create a pivot table to get model proportions at each coordinate
pivot = grouped.pivot_table(index=["Expected Author", "Predicted Author"],
                             columns="Model",
                             values="Count",
                             fill_value=0)

# Color map for models
model_colors = dict(zip(pivot.columns, plt.cm.Set2.colors[:len(pivot.columns)]))

# Plot pies at each grid location
fig, ax = plt.subplots(figsize=(10, 8))

x_labels = sorted(pivot.index.get_level_values(0).unique())
y_labels = sorted(pivot.index.get_level_values(1).unique())

# Create mappings for positioning
x_map = {name: i for i, name in enumerate(x_labels)}
y_map = {name: i for i, name in enumerate(y_labels)}

# Configure axis
ax.set_xlim(-0.5 - 0.3, len(x_map) - 0.5 + 0.3)
ax.set_ylim(-0.5 - 0.3, len(y_map) - 0.5 + 0.3)
ax.set_xticks(list(x_map.values()))
ax.set_yticks(list(y_map.values()))
ax.set_xticklabels(x_labels, rotation=30)
ax.set_yticklabels(y_labels)
ax.set_xlabel("Expected Author")
ax.set_ylabel("Predicted Author")
ax.set_title("Expected vs. Predicted (Pie chart by Model)")

# Largest per-cell total (summed over models). Normalizing by this keeps every radius
# within [0.1, 0.3], so pies on neighbouring grid points cannot overlap.
max_total = pivot.sum(axis=1).max()

# Draw one pie per (expected, predicted) cell. Each pie is drawn exactly once: an earlier
# pass with a smaller radius was fully covered by this one and has been removed.
for (x_label, y_label), row in pivot.iterrows():
    total = row.sum()
    if total == 0:
        continue

    x = x_map[x_label]
    y = y_map[y_label]
    radius = 0.1 + 0.2 * (total / max_total)

    start_angle = 0
    # `gen_model` (not `model`) so this top-level loop does not overwrite the
    # classifier stored in the notebook-level `model` variable.
    for gen_model, count in row.items():
        if count == 0:
            continue
        angle = 360 * count / total
        wedge = Wedge(center=(x, y), r=radius,
                      theta1=start_angle, theta2=start_angle + angle,
                      facecolor=model_colors[gen_model],
                      edgecolor='gray', linewidth=0.5)
        ax.add_patch(wedge)
        start_angle += angle

    # Sample count for this cell, centered on the pie
    ax.text(
        x, y, str(int(total)),
        ha='center', va='center',
        fontsize=9, weight='bold', color='black'
    )


# Legend
legend_patches = [plt.Line2D([0], [0], marker='o', color='w',
                label=model, markerfacecolor=color, markersize=10)
                for model, color in model_colors.items()]
ax.legend(handles=legend_patches, title="Model", bbox_to_anchor=(1.05, 1), loc="upper left")

ax.set_aspect('equal')
ax.grid(True, linestyle="--", alpha=0.3)
plt.tight_layout()
plt.show()


In [None]:
def plot_pie_chart_for_model(data, model_filter, title_suffix):
    """
    Plot an Expected-vs-Predicted author grid for one generation model.

    Every (Expected Author, Predicted Author) pair present in the filtered
    rows gets a pie at its grid position. The radius grows with the number
    of rows in the pair and that number is printed at the pie's centre.
    Because the rows are filtered to a single model first, the pivot has at
    most one model column.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain the columns "Model", "Expected Author" and
        "Predicted Author".
    model_filter : str
        Only rows whose "Model" equals this value are plotted.
    title_suffix : str
        Text placed in the title after "<model_filter> only".

    Returns
    -------
    None
        The figure is displayed with plt.show().
    """
    subset = data[data["Model"] == model_filter]

    # Rows per (expected, predicted, model), reshaped to one column per model.
    pair_counts = (
        subset.groupby(["Expected Author", "Predicted Author", "Model"])
        .size()
        .reset_index(name="Count")
    )
    pivot = pair_counts.pivot_table(
        index=["Expected Author", "Predicted Author"],
        columns="Model",
        values="Count",
        fill_value=0
    )

    palette = plt.cm.Set2.colors[:len(pivot.columns)]
    model_colors = dict(zip(pivot.columns, palette))

    fig, ax = plt.subplots(figsize=(10, 8))
    x_labels = sorted(pivot.index.get_level_values(0).unique())
    y_labels = sorted(pivot.index.get_level_values(1).unique())
    x_map = {author: pos for pos, author in enumerate(x_labels)}
    y_map = {author: pos for pos, author in enumerate(y_labels)}

    for (expected, predicted), model_counts in pivot.iterrows():
        cell_total = model_counts.sum()
        if cell_total == 0:
            continue

        cx, cy = x_map[expected], y_map[predicted]
        pie_radius = 0.1 + 0.2 * (cell_total / pivot.values.max())

        # Sweep the wedges counter-clockwise, one per model with a non-zero count.
        theta = 0
        for model_name, n_rows in model_counts.items():
            if n_rows == 0:
                continue
            sweep = 360 * n_rows / cell_total
            ax.add_patch(Wedge(
                center=(cx, cy), r=pie_radius,
                theta1=theta, theta2=theta + sweep,
                facecolor=model_colors[model_name],
                edgecolor='gray', linewidth=0.5
            ))
            theta += sweep

        ax.text(cx, cy, str(int(cell_total)),
                ha='center', va='center',
                fontsize=9, weight='bold', color='black')

    # Axis limits leave a margin around the outermost pies.
    ax.set_xlim(-0.8, len(x_map) - 0.2)
    ax.set_ylim(-0.8, len(y_map) - 0.2)
    ax.set_xticks(list(x_map.values()))
    ax.set_yticks(list(y_map.values()))
    ax.set_xticklabels(x_labels, rotation=30)
    ax.set_yticklabels(y_labels)
    ax.set_xlabel("Expected Author")
    ax.set_ylabel("Predicted Author")
    ax.set_title(f"Expected vs. Predicted ({model_filter} only {title_suffix})")

    # Proxy circle markers stand in for the wedges in the legend.
    legend_patches = []
    for model_name, color in model_colors.items():
        legend_patches.append(
            plt.Line2D([0], [0], marker='o', color='w',
                       label=model_name, markerfacecolor=color, markersize=10)
        )
    ax.legend(handles=legend_patches, title="Model", bbox_to_anchor=(1.05, 1), loc="upper left")

    ax.set_aspect('equal')
    ax.grid(True, linestyle="--", alpha=0.3)
    plt.tight_layout()
    plt.show()


# One grid per generation model, over all generated samples.
for model_name, suffix in (("gpt3", "(GPT-3)"), ("gpt3_lora", "(GPT-3 LoRA)")):
    plot_pie_chart_for_model(generated_texts_merged, model_name, suffix)


## ---- Analysis:

In [None]:
generated_texts_merged

In [None]:
# Add diagnostic columns:

In [None]:
import numpy as np

y_true = generated_texts_merged["expected_label"].to_numpy()
y_pred = generated_texts_merged["predicted_Label_roberta"].to_numpy()

# probs_np comes from an earlier cell; assumed to hold one row of class
# probabilities per row of generated_texts_merged, in the same order — TODO confirm.
row_ids = np.arange(len(probs_np))
p_true = probs_np[row_ids, y_true]          # probability given to the expected label
p_pred = probs_np[row_ids, y_pred]          # probability given to the predicted label
conf = probs_np.max(axis=1)                 # model confidence
entropy = -(probs_np * np.log(probs_np + 1e-12)).sum(axis=1)  # uncertainty; epsilon avoids log(0)
margin = p_pred - p_true                    # how much more the model preferred pred over true

# Attach the diagnostics to the frame (adds/overwrites these columns in place).
diagnostics = {
    "p_true": p_true,
    "p_pred": p_pred,
    "confidence": conf,
    "entropy": entropy,
    "margin": margin,
    "is_error": (y_true != y_pred),
}
for column, values in diagnostics.items():
    generated_texts_merged[column] = values

In [None]:
generated_texts_merged

In [None]:
# 2) Where are errors concentrated?

In [None]:
def author_confusion_table(df):
    """
    Summarise every (Expected Author, Predicted Author) pair in `df`.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns:
        - "Expected Author" and "Predicted Author" (grouping keys)
        - "Text" (only used to count rows)
        - "p_true" (probability assigned to the true label)
        - "confidence" (max predicted probability)

    Returns
    -------
    pandas.DataFrame
        Indexed by (Expected Author, Predicted Author) with columns
        n (row count), mean_p_true and mean_conf, ordered by n descending.
    """
    pair_keys = ["Expected Author", "Predicted Author"]
    aggregations = {
        "n": ("Text", "size"),
        "mean_p_true": ("p_true", "mean"),
        "mean_conf": ("confidence", "mean"),
    }
    per_pair = df.groupby(pair_keys).agg(**aggregations)
    return per_pair.sort_values(["n"], ascending=False)

In [None]:
# Confusion statistics for every sample ...
full_tbl = author_confusion_table(generated_texts_merged)

# ... and separately for each generation model.
gpt3_tbl = author_confusion_table(
    generated_texts_merged.loc[generated_texts_merged["Model"] == "gpt3"]
)
gpt3_lora_tbl = author_confusion_table(
    generated_texts_merged.loc[generated_texts_merged["Model"] == "gpt3_lora"]
)

In [None]:
gpt3_tbl

In [None]:
gpt3_lora_tbl

In [None]:
# 3) Confidence & calibration:

In [None]:
import numpy as np
import pandas as pd

def calibration_table(df, bins=np.linspace(0, 1, 21)):
    """
    Per-confidence-bin calibration statistics.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain:
        - "confidence": float, predicted max probability
        - "is_error": bool, True if prediction != actual label
        - "Text": only used to count rows
    bins : array-like, optional
        Bin edges passed to np.digitize. The default is 21 edges from 0.0 to
        1.0, i.e. steps of 0.05.

    Returns
    -------
    pandas.DataFrame
        Index: conf_bin, the integer returned by np.digitize for the row's
        confidence (only bins that contain rows appear).
        Columns: n (rows), acc (1 - mean of is_error), mean_conf, and
        gap (mean_conf - acc; positive means over-confident).
        The input frame is not modified.
    """
    def _accuracy(errors):
        # Share of rows in the bin that were NOT errors.
        return 1 - errors.mean()

    binned = df.assign(conf_bin=np.digitize(df["confidence"], bins))
    calib = binned.groupby("conf_bin").agg(
        n=("Text", "size"),
        acc=("is_error", _accuracy),
        mean_conf=("confidence", "mean"),
    )
    calib["gap"] = calib["mean_conf"] - calib["acc"]
    return calib


In [None]:
# Calibration over every sample ...
full_calib = calibration_table(generated_texts_merged)

# ... and separately for each generation model.
gpt3_calib = calibration_table(
    generated_texts_merged.loc[generated_texts_merged["Model"] == "gpt3"]
)
gpt3_lora_calib = calibration_table(
    generated_texts_merged.loc[generated_texts_merged["Model"] == "gpt3_lora"]
)

In [None]:
gpt3_calib

In [None]:
gpt3_lora_calib

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Reliability curve: per-bin mean confidence (x) against per-bin accuracy (y).
# Swap the table below to plot a different subset.
calib_for_plot = gpt3_calib.copy() # full_calib \ gpt3_calib \ gpt3_lora_calib

plt.figure(figsize=(7,6))
plt.plot(calib_for_plot["mean_conf"], calib_for_plot["acc"], marker='o', label="Actual accuracy")
# Diagonal = confidence equals accuracy.
plt.plot([0,1],[0,1],'--',color='gray',label="Perfect calibration")

# Marker area is proportional to the number of rows in the bin (largest bin = 300).
sizes = (calib_for_plot["n"] / calib_for_plot["n"].max()) * 300
plt.scatter(calib_for_plot["mean_conf"], calib_for_plot["acc"], s=sizes, alpha=0.6)

plt.xlabel("Predicted confidence")
plt.ylabel("Percentage of Matching Labels")
plt.title("Reliability Curve (Calibration Plot)")
# Legend is disabled, so the label= values above are not shown.
#plt.legend()
plt.grid(True, linestyle="--", alpha=0.6)
plt.show()

In [None]:
generated_texts_merged[generated_texts_merged["Model"] == "gpt3"]

In [None]:
high_conf_samples = generated_texts_merged[generated_texts_merged["confidence"] > 0.93]

In [None]:
high_conf_samples = high_conf_samples.reset_index(drop=True)
high_conf_samples

In [None]:
import matplotlib.pyplot as plt
from matplotlib.patches import Wedge
import numpy as np
import pandas as pd

# Map labels to author names
label_map = {
    0: "Charles Dickens",
    1: "Jane Austen",
    2: "Mark Twain",
    3: "Louisa May Alcott",
    4: "Herman Melville"
}

high_conf_samples["Expected Author"] = high_conf_samples["expected_label"].map(label_map)
high_conf_samples["Predicted Author"] = high_conf_samples["predicted_Label_roberta"].map(label_map)

# Rows per (expected, predicted, model) ...
grouped = (
    high_conf_samples.groupby(["Expected Author", "Predicted Author", "Model"])
    .size()
    .reset_index(name="Count")
)

# ... reshaped to one row per grid cell and one column per model.
pivot = grouped.pivot_table(index=["Expected Author", "Predicted Author"],
                             columns="Model",
                             values="Count",
                             fill_value=0)

# Color map for models
model_colors = dict(zip(pivot.columns, plt.cm.Set2.colors[:len(pivot.columns)]))

fig, ax = plt.subplots(figsize=(10, 8))

x_labels = sorted(pivot.index.get_level_values(0).unique())
y_labels = sorted(pivot.index.get_level_values(1).unique())

# Grid position of each author on the two axes.
x_map = {name: i for i, name in enumerate(x_labels)}
y_map = {name: i for i, name in enumerate(y_labels)}

# Largest number of rows in any single grid cell (summed over models). Radii
# are scaled against this so the biggest pie has radius 0.3 and neighbouring
# pies (1.0 apart) cannot overlap. Guarded so an empty pivot does not raise.
max_total = pivot.sum(axis=1).max() if pivot.size else 1

# Draw one pie per grid cell; each pie is drawn exactly once.
for (x_label, y_label), row in pivot.iterrows():
    total = row.sum()
    if total == 0:
        continue

    x = x_map[x_label]
    y = y_map[y_label]
    radius = 0.1 + 0.2 * (total / max_total)

    start_angle = 0
    for model, count in row.items():
        if count == 0:
            continue
        angle = 360 * count / total
        wedge = Wedge(center=(x, y), r=radius,
                      theta1=start_angle, theta2=start_angle + angle,
                      facecolor=model_colors[model],
                      edgecolor='gray', linewidth=0.5)
        ax.add_patch(wedge)
        start_angle += angle

    # Row count of the cell, printed at the pie centre.
    ax.text(
        x, y, str(int(total)),
        ha='center', va='center',
        fontsize=9, weight='bold', color='black'
    )

# Configure axis
ax.set_xlim(-0.5 - 0.3, len(x_map) - 0.5 + 0.3)
ax.set_ylim(-0.5 - 0.3, len(y_map) - 0.5 + 0.3)
ax.set_xticks(list(x_map.values()))
ax.set_yticks(list(y_map.values()))
ax.set_xticklabels(x_labels, rotation=30)
ax.set_yticklabels(y_labels)
ax.set_xlabel("Expected Author")
ax.set_ylabel("Predicted Author")
ax.set_title("Expected vs. Predicted (Pie chart by Model)")

# Legend (proxy circle markers, one per model)
legend_patches = [plt.Line2D([0], [0], marker='o', color='w',
                label=model, markerfacecolor=color, markersize=10)
                for model, color in model_colors.items()]
ax.legend(handles=legend_patches, title="Model", bbox_to_anchor=(1.05, 1), loc="upper left")

ax.set_aspect('equal')
ax.grid(True, linestyle="--", alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
def plot_pie_chart_for_model(data, model_filter, title_suffix):
    """
    Plot an Expected-vs-Predicted author grid for one generation model.

    NOTE: this cell re-defines a function of the same name that an earlier
    cell already defines; running this cell shadows the earlier definition.

    Every (Expected Author, Predicted Author) pair present in the filtered
    rows gets a pie at its grid position. The radius grows with the number
    of rows in the pair and that number is printed at the pie's centre.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain the columns "Model", "Expected Author" and
        "Predicted Author".
    model_filter : str
        Only rows whose "Model" equals this value are plotted.
    title_suffix : str
        Text placed in the title after "<model_filter> only".

    Returns
    -------
    None
        The figure is displayed with plt.show().
    """
    subset = data[data["Model"] == model_filter]

    # Rows per (expected, predicted, model), reshaped to one column per model.
    pair_counts = (
        subset.groupby(["Expected Author", "Predicted Author", "Model"])
        .size()
        .reset_index(name="Count")
    )
    pivot = pair_counts.pivot_table(
        index=["Expected Author", "Predicted Author"],
        columns="Model",
        values="Count",
        fill_value=0
    )

    palette = plt.cm.Set2.colors[:len(pivot.columns)]
    model_colors = dict(zip(pivot.columns, palette))

    fig, ax = plt.subplots(figsize=(10, 8))
    x_labels = sorted(pivot.index.get_level_values(0).unique())
    y_labels = sorted(pivot.index.get_level_values(1).unique())
    x_map = {author: pos for pos, author in enumerate(x_labels)}
    y_map = {author: pos for pos, author in enumerate(y_labels)}

    for (expected, predicted), model_counts in pivot.iterrows():
        cell_total = model_counts.sum()
        if cell_total == 0:
            continue

        cx, cy = x_map[expected], y_map[predicted]
        pie_radius = 0.1 + 0.2 * (cell_total / pivot.values.max())

        # Sweep the wedges counter-clockwise, one per model with a non-zero count.
        theta = 0
        for model_name, n_rows in model_counts.items():
            if n_rows == 0:
                continue
            sweep = 360 * n_rows / cell_total
            ax.add_patch(Wedge(
                center=(cx, cy), r=pie_radius,
                theta1=theta, theta2=theta + sweep,
                facecolor=model_colors[model_name],
                edgecolor='gray', linewidth=0.5
            ))
            theta += sweep

        ax.text(cx, cy, str(int(cell_total)),
                ha='center', va='center',
                fontsize=9, weight='bold', color='black')

    # Axis limits leave a margin around the outermost pies.
    ax.set_xlim(-0.8, len(x_map) - 0.2)
    ax.set_ylim(-0.8, len(y_map) - 0.2)
    ax.set_xticks(list(x_map.values()))
    ax.set_yticks(list(y_map.values()))
    ax.set_xticklabels(x_labels, rotation=30)
    ax.set_yticklabels(y_labels)
    ax.set_xlabel("Expected Author")
    ax.set_ylabel("Predicted Author")
    ax.set_title(f"Expected vs. Predicted ({model_filter} only {title_suffix})")

    # Proxy circle markers stand in for the wedges in the legend.
    legend_patches = []
    for model_name, color in model_colors.items():
        legend_patches.append(
            plt.Line2D([0], [0], marker='o', color='w',
                       label=model_name, markerfacecolor=color, markersize=10)
        )
    ax.legend(handles=legend_patches, title="Model", bbox_to_anchor=(1.05, 1), loc="upper left")

    ax.set_aspect('equal')
    ax.grid(True, linestyle="--", alpha=0.3)
    plt.tight_layout()
    plt.show()


# One grid per generation model, restricted to the high-confidence samples.
for model_name, suffix in (("gpt3", "(GPT-3)"), ("gpt3_lora", "(GPT-3 LoRA)")):
    plot_pie_chart_for_model(high_conf_samples, model_name, suffix)


In [None]:
def correct_percentages(df):
    """
    Share of correct predictions per generation model.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain "Model" and a boolean "is_error" column.

    Returns
    -------
    pandas.DataFrame
        One row per model with columns Model, total_samples (non-null
        is_error values), error_count, and correct_percentage. Despite its
        name, correct_percentage is a fraction in [0, 1], not 0-100.
    """
    per_model = (
        df.groupby("Model")["is_error"]
        .agg(total_samples="count", error_count="sum")
        .reset_index()
    )
    error_rate = per_model["error_count"] / per_model["total_samples"]
    per_model["correct_percentage"] = 1 - error_rate
    return per_model


In [None]:
# Per-model accuracy over every sample.
print(correct_percentages(generated_texts_merged))
print()
# Same, restricted to samples the classifier was highly confident about.
high_conf = generated_texts_merged.loc[generated_texts_merged["confidence"] > 0.93]
print(correct_percentages(high_conf))

In [None]:
# 4) Hard vs easy mistakes (use the probabilities):

In [None]:
# Mismatches among GPT-3 samples, split by classifier confidence:
# "high" = confidence > 0.93, "low" = confidence < 0.5.
hi_conf_wrong = generated_texts_merged[
    (generated_texts_merged["Model"] == "gpt3") &
    (generated_texts_merged["is_error"]) &
    (generated_texts_merged["confidence"] > 0.93)]

total_high_confidence = generated_texts_merged[
    (generated_texts_merged["Model"] == "gpt3") &
    (generated_texts_merged["confidence"] > 0.93)]

lo_conf_wrong = generated_texts_merged[
    (generated_texts_merged["Model"] == "gpt3") &
    (generated_texts_merged["is_error"]) &
    (generated_texts_merged["confidence"] < 0.5)]

total_low_confidence = generated_texts_merged[
    (generated_texts_merged["Model"] == "gpt3") &
    (generated_texts_merged["confidence"] < 0.5)]


# The printed threshold now matches the 0.93 filter above.
for header, wrong, total in (
    ("The number of mismatched sentences with high model confidence ( > 93% ):",
     hi_conf_wrong, total_high_confidence),
    ("\nThe number of mismatched sentences with low model confidence ( < 50% ):",
     lo_conf_wrong, total_low_confidence),
):
    print(header)
    # A confidence band can be empty; report nan instead of dividing by zero.
    ratio = round(len(wrong) / len(total), 2) if len(total) else float("nan")
    print(len(wrong), "/", len(total), "=", ratio)

In [None]:
# Mismatches among GPT-3 LoRA samples, split by classifier confidence:
# "high" = confidence > 0.93, "low" = confidence < 0.5.
hi_conf_wrong = generated_texts_merged[
    (generated_texts_merged["Model"] == "gpt3_lora") &
    (generated_texts_merged["is_error"]) &
    (generated_texts_merged["confidence"] > 0.93)]

total_high_confidence = generated_texts_merged[
    (generated_texts_merged["Model"] == "gpt3_lora") &
    (generated_texts_merged["confidence"] > 0.93)]

lo_conf_wrong = generated_texts_merged[
    (generated_texts_merged["Model"] == "gpt3_lora") &
    (generated_texts_merged["is_error"]) &
    (generated_texts_merged["confidence"] < 0.5)]

total_low_confidence = generated_texts_merged[
    (generated_texts_merged["Model"] == "gpt3_lora") &
    (generated_texts_merged["confidence"] < 0.5)]


# The printed threshold now matches the 0.93 filter above.
for header, wrong, total in (
    ("The number of mismatched sentences with high model confidence ( > 93% ):",
     hi_conf_wrong, total_high_confidence),
    ("\nThe number of mismatched sentences with low model confidence ( < 50% ):",
     lo_conf_wrong, total_low_confidence),
):
    print(header)
    # A confidence band can be empty; report nan instead of dividing by zero.
    ratio = round(len(wrong) / len(total), 2) if len(total) else float("nan")
    print(len(wrong), "/", len(total), "=", ratio)

In [None]:
# 5) Top-k signal (is the true class “close”?)

In [None]:
import numpy as np

def topk_accuracy(df, probs, k_values=(1, 2, 3)):
    """
    Compute top-k accuracy for given dataset and probability matrix.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain "expected_label": int, true class index.
    probs : array-like
        Shape (N, C) predicted probabilities for each class.
        Rows must align positionally with df rows.
    k_values : tuple
        Top-k values to compute (default: (1, 2, 3)).

    Returns
    -------
    dict
        Keys: "top-1", "top-2", ...; values: accuracy rounded to two
        decimals, as built-in float (nan when df is empty).

    Raises
    ------
    ValueError
        If probs is not 2-D or its row count differs from len(df).
    """
    y_true = df["expected_label"].to_numpy()
    probs = np.asarray(probs)
    if probs.ndim != 2 or len(probs) != len(y_true):
        raise ValueError(
            f"probs must have shape (len(df), C); got {probs.shape} for {len(y_true)} rows."
        )

    # Rank the classes once (most probable first); every k reuses the ranking.
    ranked = np.argsort(-probs, axis=1)
    hits = ranked == y_true[:, None]

    results = {}
    for k in k_values:
        topk_right = hits[:, :k].any(axis=1).mean()
        # Convert to a built-in float so the dict displays as plain numbers.
        results[f"top-{k}"] = float(round(topk_right, 2))

    return results


In [None]:
# Top-k accuracy over every sample.
full_topk = topk_accuracy(generated_texts_merged, probs_np)

# Per-model subsets: the same boolean mask selects the frame rows and the
# matching rows of probs_np (relies on the two being in the same row order).
mask_gpt3 = generated_texts_merged["Model"] == "gpt3"
gpt3_topk = topk_accuracy(generated_texts_merged.loc[mask_gpt3], probs_np[mask_gpt3.to_numpy()])

mask_lora = generated_texts_merged["Model"] == "gpt3_lora"
gpt3_lora_topk = topk_accuracy(generated_texts_merged.loc[mask_lora], probs_np[mask_lora.to_numpy()])

In [None]:
gpt3_topk

In [None]:
gpt3_lora_topk

In [None]:
# Same top-k analysis, restricted to confidence > 0.93.
# NOTE: this overwrites full_topk / gpt3_topk / gpt3_lora_topk from the previous cell.
high_conf_mask = generated_texts_merged["confidence"] > 0.93

# All models (high confidence only)
full_topk = topk_accuracy(
    generated_texts_merged.loc[high_conf_mask],
    probs_np[high_conf_mask.to_numpy()]
)

# GPT-3 only (high confidence only)
mask_gpt3 = (generated_texts_merged["Model"] == "gpt3") & high_conf_mask
gpt3_topk = topk_accuracy(
    generated_texts_merged.loc[mask_gpt3],
    probs_np[mask_gpt3.to_numpy()]
)

# GPT-3 LoRA only (high confidence only)
mask_lora = (generated_texts_merged["Model"] == "gpt3_lora") & high_conf_mask
gpt3_lora_topk = topk_accuracy(
    generated_texts_merged.loc[mask_lora],
    probs_np[mask_lora.to_numpy()]
)

In [None]:
gpt3_topk

In [None]:
gpt3_lora_topk

In [None]:
# 6) Margin analysis (which errors are “close”?)

In [None]:
def close_and_stubborn_errors(df, model=None, prob_cols_prefix="prob_",
                               true_label_col="expected_label",
                               pred_label_col="predicted_Label_roberta"):
    """
    Split misclassified rows into 'close calls' and 'stubborn' errors.

    A close call is an error where the true class trails the top class by
    less than 0.1 probability; a stubborn error is one whose margin exceeds
    0.3.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the probability columns (names starting with
        prob_cols_prefix), true_label_col and pred_label_col. Existing
        "p_true", "confidence", "is_error" and "margin" columns are used
        as-is; missing ones are derived from the probability columns.
    model : str or None
        If given, only rows where Model == model are analysed.
    prob_cols_prefix : str
        Prefix identifying the probability columns.
    true_label_col : str
        Column with the true class indices.
    pred_label_col : str
        Column with the predicted class indices.

    Returns
    -------
    tuple of pd.DataFrame
        (close_calls_df, stubborn_errors_df). Both are taken from a copy of
        the input that additionally carries a "true_minus_best" column.

    Raises
    ------
    ValueError
        If no column starts with prob_cols_prefix.
    """
    # Work on a private copy so the caller's frame is never modified.
    dfx = (df if model is None else df[df["Model"] == model]).copy()

    prob_cols = [name for name in dfx.columns if name.startswith(prob_cols_prefix)]
    if len(prob_cols) == 0:
        raise ValueError(f"No probability columns found starting with '{prob_cols_prefix}'.")

    probs = dfx[prob_cols].to_numpy()
    y_true = dfx[true_label_col].to_numpy()
    y_pred = dfx[pred_label_col].to_numpy()

    # Derive only the diagnostics that are not already present. Order matters:
    # "margin" is built from "confidence" and "p_true".
    fallbacks = (
        ("p_true", lambda: probs[np.arange(len(probs)), y_true]),
        ("confidence", lambda: probs.max(axis=1)),
        ("is_error", lambda: y_true != y_pred),
        ("margin", lambda: dfx["confidence"] - dfx["p_true"]),
    )
    for column, compute in fallbacks:
        if column not in dfx.columns:
            dfx[column] = compute()

    # How far the true class is below the top class (always <= 0).
    dfx["true_minus_best"] = dfx["p_true"] - dfx["confidence"]

    close_calls_df = dfx.query("is_error and true_minus_best > -0.1")
    stubborn_errors_df = dfx.query("is_error and margin > 0.3")
    return close_calls_df, stubborn_errors_df

In [None]:
# Close-call / stubborn-error split for every sample ...
close_calls_all, stubborn_errors_all = close_and_stubborn_errors(
    generated_texts_merged, model=None
)

# ... and per generation model.
close_calls_gpt3, stubborn_errors_gpt3 = close_and_stubborn_errors(
    generated_texts_merged, model="gpt3"
)
close_calls_lora, stubborn_errors_lora = close_and_stubborn_errors(
    generated_texts_merged, model="gpt3_lora"
)

In [None]:
print("gpt3:\n")
print(f"Close calls:     {len(close_calls_gpt3)} samples")
print(f"Stubborn errors: {len(stubborn_errors_gpt3)} samples")

In [None]:
print("gpt3_lora:\n")
print(f"Close calls:     {len(close_calls_lora)} samples")
print(f"Stubborn errors: {len(stubborn_errors_lora)} samples")

In [None]:
# 7) Length & truncation effects

In [None]:
# 8) Per-author difficulty & bias:

In [None]:
def author_stats_table(df):
    """
    Per-author performance summary.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain:
        - "Expected Author": grouping key
        - "Text": only used to count rows
        - "is_error": bool, True if prediction != actual
        - "p_true": float, probability assigned to the true label
        - "confidence": float, max predicted probability

    Returns
    -------
    pandas.DataFrame
        Indexed by Expected Author with columns n, acc (1 - mean of
        is_error), mean_p_true and mean_conf, ordered by acc ascending
        (hardest author first).
    """
    def _accuracy(errors):
        # Share of the author's rows that were NOT errors.
        return 1 - errors.mean()

    per_author = df.groupby("Expected Author").agg(
        n=("Text", "size"),
        acc=("is_error", _accuracy),
        mean_p_true=("p_true", "mean"),
        mean_conf=("confidence", "mean"),
    )
    return per_author.sort_values("acc")


In [None]:
# Per-author statistics over every sample ...
full_author_stats = author_stats_table(generated_texts_merged)

# ... and separately for each generation model.
gpt3_author_stats = author_stats_table(
    generated_texts_merged.loc[generated_texts_merged["Model"] == "gpt3"]
)
gpt3_lora_author_stats = author_stats_table(
    generated_texts_merged.loc[generated_texts_merged["Model"] == "gpt3_lora"]
)


In [None]:
gpt3_author_stats

In [None]:
gpt3_lora_author_stats

In [None]:
# Per-author statistics restricted to confidence > 0.93 (suffix _h_c = high confidence).
full_author_stats_h_c = author_stats_table(
    generated_texts_merged.loc[generated_texts_merged["confidence"] > 0.93]
)

# GPT-3 only
gpt3_author_stats_h_c = author_stats_table(
    generated_texts_merged.loc[
        (generated_texts_merged["Model"] == "gpt3")
        & (generated_texts_merged["confidence"] > 0.93)
    ]
)

# GPT-3 LoRA only
gpt3_lora_author_stats_h_c = author_stats_table(
    generated_texts_merged.loc[
        (generated_texts_merged["Model"] == "gpt3_lora")
        & (generated_texts_merged["confidence"] > 0.93)
    ]
)


In [None]:
gpt3_author_stats_h_c

In [None]:
gpt3_lora_author_stats_h_c

In [None]:
# 9) Compare generation sources (gpt3 vs gpt3_lora):

In [None]:
def author_accuracy_by_model(df):
    """
    Per-author accuracy statistics, split by generation model.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain:
        - "Model": str
        - "Expected Author": str
        - "is_error": bool
        - "p_true": float
        - "confidence": float

    Returns
    -------
    tuple of pandas.DataFrame
        (grouped, pivot_acc):
        - grouped: one row per (Model, Expected Author) with columns
          Model, Expected Author, acc, mean_p_true, mean_conf.
        - pivot_acc: acc only, with Expected Author as index and one
          column per Model.
    """
    def _accuracy(errors):
        # Share of rows in the group that were NOT errors.
        return 1 - errors.mean()

    per_group = df.groupby(["Model", "Expected Author"]).agg(
        acc=("is_error", _accuracy),
        mean_p_true=("p_true", "mean"),
        mean_conf=("confidence", "mean"),
    )
    grouped = per_group.reset_index()
    pivot_acc = grouped.pivot(index="Expected Author", columns="Model", values="acc")
    return grouped, pivot_acc


In [None]:
# Per-(model, author) statistics plus an author x model accuracy pivot.
grouped_stats, pivot_acc = author_accuracy_by_model(df=generated_texts_merged)

# Detailed stats first, then the pivot, both as plain text.
for table in (grouped_stats, pivot_acc):
    print(table)

## ---- Explainability (RoBERTa):

In [None]:
!pip install captum numpy>=2.0

In [None]:

import torch
import numpy as np
import pandas as pd

# --- pick representative mistakes ---
def select_representative_errors(df, top_k_per_pair=3, min_conf=0.93):
    """
    Pick the most confident wrong predictions for each confusion pair.

    Rows with is_error set and confidence >= min_conf are grouped by
    (Expected Author, Predicted Author); within each pair the top_k_per_pair
    rows with the highest confidence are kept.

    Returns a DataFrame ordered by confidence (descending) that carries an
    extra "pair" column. If nothing qualifies, a message is printed and the
    empty selection is returned without the "pair" column.
    """
    confident_and_wrong = (df["is_error"] == True) & (df["confidence"] >= min_conf)
    wrong = df[confident_and_wrong].copy()
    if wrong.empty:
        print("No high-confidence errors found with the current threshold.")
        return wrong

    # One hashable key per row so the pair can be grouped on directly.
    wrong["pair"] = list(zip(wrong["Expected Author"], wrong["Predicted Author"]))
    by_confidence = wrong.sort_values("confidence", ascending=False)
    return by_confidence.groupby("pair", group_keys=False).head(top_k_per_pair)

# --- RoBERTa BPE aggregation: merge subword pieces into words ---
def merge_bpe_tokens(tokens, scores):
    """
    Collapse byte-level BPE pieces into words and sum their scores.

    A token starting with 'Ġ' opens a new word (the marker is dropped); any
    other token is appended to the word in progress. The special tokens
    "<s>", "</s>" and "<pad>" are ignored.

    Returns a DataFrame with columns word, attr (summed score) and attr_abs,
    sorted by attr_abs descending with a fresh 0..n-1 index.
    """
    special_tokens = ("<s>", "</s>", "<pad>")
    merged = []                 # (word, summed score) in reading order
    word, total = "", 0.0

    for piece, value in zip(tokens, scores):
        if piece in special_tokens:
            continue

        if not piece.startswith("Ġ"):
            # Sub-word continuation: extend the word in progress.
            word, total = word + piece, total + float(value)
            continue

        # Word boundary: store the finished word (if any) and start a new one.
        if word:
            merged.append((word, total))
        word, total = piece[1:], float(value)

    if word:
        merged.append((word, total))

    out = pd.DataFrame({
        "word": [w for w, _ in merged],
        "attr": [s for _, s in merged],
    })
    out["attr_abs"] = out["attr"].abs()
    return out.sort_values("attr_abs", ascending=False).reset_index(drop=True)


In [None]:
from captum.attr import IntegratedGradients

def _forward_from_embeds(model, attention_mask, token_type_ids, inputs_embeds):
    out = model(
        attention_mask=attention_mask,
        token_type_ids=token_type_ids if token_type_ids is not None else None,
        inputs_embeds=inputs_embeds
    )
    return out.logits

def attribute_example_with_ig(model, tokenizer, text, target_class, max_length=256, device=None, n_steps=50):
    """
    Integrated-Gradients token attributions for one text and one target class.

    Attributions are computed with respect to the word embeddings, using the
    embeddings of an all-<pad> sequence of the same length as baseline, and
    are summed over the embedding dimension to give one score per token.

    Parameters
    ----------
    model : sequence-classification model exposing `model.roberta.embeddings.word_embeddings`
        and accepting `inputs_embeds` (presumably RobertaForSequenceClassification — confirm).
        It is put in eval mode and moved to `device`.
    tokenizer : tokenizer matching `model`.
    text : str
        Input text; truncated to `max_length` tokens.
    target_class : int
        Class index whose logit is attributed.
    max_length : int
    device : torch.device or None
        Defaults to CUDA when available, else CPU.
    n_steps : int
        Number of integration steps passed to IntegratedGradients.

    Returns
    -------
    tuple
        (tokens, token_attr, probs, pred_class): token strings (special tokens
        included), a numpy array with one attribution per token, the softmax
        probabilities for the un-perturbed input, and their argmax.
    """
    model.eval()
    device = device or (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
    model.to(device)

    enc = tokenizer(
        text, truncation=True, max_length=max_length, return_tensors="pt"
    )
    input_ids = enc["input_ids"].to(device)
    attention_mask = enc["attention_mask"].to(device)
    # Not every tokenizer returns token_type_ids; pass None through if absent.
    token_type_ids = enc.get("token_type_ids", None)
    if token_type_ids is not None:
        token_type_ids = token_type_ids.to(device)

    # Embedding look-ups and the plain forward pass need no gradients.
    with torch.no_grad():
        inputs_embeds = model.roberta.embeddings.word_embeddings(input_ids)  # [1, L, 768]
        # Baseline: the same sequence length filled with the pad token (id 1 if the tokenizer defines none).
        pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 1
        baseline_ids = torch.full_like(input_ids, pad_id)
        baseline_embeds = model.roberta.embeddings.word_embeddings(baseline_ids)

        logits = _forward_from_embeds(model, attention_mask, token_type_ids, inputs_embeds)
        probs = torch.softmax(logits, dim=-1).detach().cpu().numpy()[0]
        pred_class = int(probs.argmax())

    # Forward function handed to IG: embeddings in, target-class logit out.
    def _target_logit_from_embeds(inputs_embeds_):
        logits_ = _forward_from_embeds(model, attention_mask, token_type_ids, inputs_embeds_)
        return logits_[:, target_class]

    ig = IntegratedGradients(_target_logit_from_embeds)

    # FIX: only one return value when return_convergence_delta=False
    attributions = ig.attribute(
        inputs=inputs_embeds,
        baselines=baseline_embeds,
        n_steps=n_steps,
        return_convergence_delta=False
    )

    # Collapse the embedding dimension: one signed score per token.
    token_attr = attributions.sum(dim=-1).squeeze(0).detach().cpu().numpy()  # [L]
    tokens = tokenizer.convert_ids_to_tokens(input_ids.squeeze(0).detach().cpu().tolist())
    return tokens, token_attr, probs, pred_class


In [None]:

def explain_row(df, row_idx, model, tokenizer, max_length=256, device=None, top_k=15):
    """
    Explain one example with Integrated Gradients, once toward the predicted
    class and once toward the true class.

    `row_idx` is a POSITION in `df` (it is looked up with df.iloc), not an
    index label.

    Returns a dict with:
      - "meta": text, expected/predicted label and author, class
        probabilities, confidence and is_error for the row
      - "pred_df": top_k words by |attribution| toward the predicted class
      - "true_df": top_k words by |attribution| toward the true class
    """
    row = df.iloc[row_idx]
    text = row["Text"]
    true_cls = int(row["expected_label"])
    pred_cls = int(row["predicted_Label_roberta"])

    def _top_words(target):
        # IG toward `target`, merged from sub-word pieces to words, top_k kept.
        tokens, attr, probs, _ = attribute_example_with_ig(
            model, tokenizer, text, target_class=target, max_length=max_length, device=device
        )
        return merge_bpe_tokens(tokens, attr).head(top_k).copy(), probs

    pred_words, probs_p = _top_words(pred_cls)
    true_words, _ = _top_words(true_cls)

    meta = {
        "text": text,
        "expected_label": true_cls,
        "predicted_label": pred_cls,
        "Expected Author": row["Expected Author"],
        "Predicted Author": row["Predicted Author"],
        "probs": probs_p,  # taken from the predicted-class run; both runs use the same input
        "confidence": float(row["confidence"]),
        "is_error": bool(row["is_error"])
    }

    return {
        "meta": meta,
        "pred_df": pred_words,  # tokens driving the predicted class
        "true_df": true_words   # tokens driving the true class
    }


In [None]:

# Example: pick representative errors (very confident but wrong)
reps = select_representative_errors(generated_texts_merged, top_k_per_pair=2, min_conf=0.93)

# If you want a specific confusion pair:
# reps = generated_texts_merged.query("is_error and `Expected Author`=='Jane Austen' and `Predicted Author`=='Charles Dickens'").nlargest(3, 'confidence')

# Explain the first 3 examples.
# explain_row() looks rows up by position (df.iloc), while reps.index holds
# index labels of generated_texts_merged. Addressing the rows by position
# inside `reps` is correct whatever the index looks like.
results = []
for pos in range(min(3, len(reps))):
    out = explain_row(reps, pos, model, tokenizer, max_length=256, top_k=15)
    results.append(out)

# Inspect one result (only if something was selected)
if results:
    r0 = results[0]
    print("Example meta:", {k:v for k,v in r0["meta"].items() if k!="probs"})
    print("\nTop tokens toward PREDICTED class:")
    display(r0["pred_df"])  # in Colab/Jupyter this shows a nice table

    print("\nTop tokens toward TRUE class:")
    display(r0["true_df"])
else:
    print("No examples to explain; lower min_conf to select some errors.")


In [None]:
from IPython.display import HTML, display
import matplotlib

def highlight_text(tokens, scores, cmap="RdBu", score_range=None):
    """
    Render tokens as HTML spans whose background colour encodes an attribution score.

    The colour scale is symmetric around zero, so a score of 0 always maps to the
    middle of the colormap. With the default diverging "RdBu" map the low
    (negative) end is red and the high (positive) end is blue; pass
    cmap="RdBu_r" to get positive = red instead.

    Args:
        tokens: iterable of strings to render, in display order.
        scores: one numeric score per token (list, tuple or NumPy array).
        cmap: matplotlib colormap name, or an already constructed Colormap.
        score_range: optional half-width of the colour scale. When None, the
            largest absolute score in ``scores`` is used.

    Returns:
        A single HTML string with one <span> per token, joined by spaces.
        An empty string when there is nothing to render.
    """
    import html  # function-scope import: only needed to escape token text

    tokens = list(tokens)
    scores = np.asarray(scores, dtype=float)
    if not tokens or scores.size == 0:
        # min()/max() of an empty array raise, so bail out before normalising.
        return ""

    # Normalize scores for consistent coloring
    if score_range is None:
        max_abs = float(np.max(np.abs(scores))) or 1e-9
    else:
        max_abs = score_range
    norm = matplotlib.colors.Normalize(vmin=-max_abs, vmax=max_abs)
    # matplotlib.cm.get_cmap was removed in matplotlib 3.9; the colormap registry
    # is the supported lookup. Colormap instances are passed through unchanged.
    cmap_obj = matplotlib.colormaps[cmap] if isinstance(cmap, str) else cmap

    html_tokens = []
    for tok, sc in zip(tokens, scores):
        color = matplotlib.colors.rgb2hex(cmap_obj(norm(sc))[:3])
        # Escape so that tokens such as "<mask>" are shown instead of parsed as tags.
        safe_tok = html.escape(str(tok))
        html_tokens.append(f"<span style='background-color:{color}; padding:2px; margin:1px;'>{safe_tok}</span>")
    return " ".join(html_tokens)

def show_attribution_text(text, tokenizer, word_attr_df, title="", cmap="RdBu"):
    """
    Display the words of ``word_attr_df`` as colour-coded HTML under a bold title.

    Words come from the "word" column and scores from the "attr" column, in the
    frame's current row order. ``text`` and ``tokenizer`` are accepted but not
    used by this function.
    """
    attr_values = np.array(word_attr_df["attr"].tolist())
    word_list = word_attr_df["word"].tolist()
    body_html = highlight_text(word_list, attr_values, cmap=cmap)
    display(HTML(f"<div><b>{title}</b><br>{body_html}</div>"))


In [None]:
# Predicted class visualization: words from r0["pred_df"], coloured by their
# attribution toward the predicted class.
show_attribution_text(
    r0["meta"]["text"],
    tokenizer,
    r0["pred_df"],
    title=f"Evidence for Predicted: {r0['meta']['Predicted Author']}"
)

# True class visualization: same rendering for r0["true_df"].
show_attribution_text(
    r0["meta"]["text"],
    tokenizer,
    r0["true_df"],
    title=f"Evidence for True: {r0['meta']['Expected Author']}"
)

In [None]:
def differential_attribution(pred_df, true_df, k=15):
    """
    Rank words by how differently they are attributed to two classes.

    Args:
        pred_df: DataFrame with "word" and "attr" columns (attribution toward
            the predicted class).
        true_df: DataFrame with "word" and "attr" columns (attribution toward
            the true class).
        k: number of rows to keep, ranked by |attr_pred - attr_true|.

    Returns:
        DataFrame with columns ["word", "attr_pred", "attr_true", "diff"],
        one row per distinct word, where diff = attr_pred - attr_true.
        A word missing from one side contributes 0.0 on that side.
    """
    # Collapse repeated words first. Merging the raw frames on "word" is a
    # many-to-many join: a word occurring n times on one side and m times on the
    # other would otherwise yield n*m rows with mismatched pairings.
    pred_tot = pred_df.groupby("word", as_index=False, sort=False)["attr"].sum()
    true_tot = true_df.groupby("word", as_index=False, sort=False)["attr"].sum()

    # merge on word; fill missing with 0
    df = pred_tot.merge(
        true_tot, on="word", how="outer", suffixes=("_pred","_true")
    ).fillna(0.0)
    df["diff"] = df["attr_pred"] - df["attr_true"]
    df["diff_abs"] = df["diff"].abs()
    df = df.sort_values("diff_abs", ascending=False).head(k)
    return df[["word","attr_pred","attr_true","diff"]]

In [None]:
# Example (diff = attr_pred - attr_true):
# diff > 0: evidence for the predicted class.
# diff < 0: evidence for the true class.

diff = differential_attribution(r0["pred_df"], r0["true_df"], k=20)
display(diff)

In [None]:
import torch
import numpy as np

def mask_and_score(model, tokenizer, text, target_class, max_length=256, device=None, words_to_mask=None):
    """
    Score ``text`` after replacing selected words with the tokenizer's mask token.

    Args:
        model: classifier whose call returns an object with ``.logits``.
        tokenizer: tokenizer used to encode the (masked) text; its ``mask_token``
            is used as the replacement, falling back to "<mask>".
        text: raw input string.
        target_class: index of the class whose probability is returned.
        max_length: truncation length passed to the tokenizer.
        device: torch device; defaults to CUDA when available, else CPU.
        words_to_mask: iterable of words to mask; None or [] leaves the text
            unchanged (useful as a baseline).

    Returns:
        (p_target, probs): the softmax probability of ``target_class`` on the
        masked text (a probability, not a delta) and the full probability vector.
    """
    import re  # function-scope import: only needed for the masking pattern

    device = device or (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
    model.to(device).eval()

    mask_token = tokenizer.mask_token or "<mask>"
    masked_text = text
    # De-duplicate and try longer words first so a short word cannot pre-empt a
    # longer one that starts at the same position.
    targets = sorted(
        (w for w in dict.fromkeys(words_to_mask or []) if w and w.strip()),
        key=len,
        reverse=True,
    )
    if targets:
        alternatives = []
        for w in targets:
            # Only anchor on a word boundary where the target itself begins or ends
            # with a word character, so "he" does not match inside "the" while
            # punctuation-only targets still match.
            left = r"(?<!\w)" if re.match(r"\w", w[0]) else ""
            right = r"(?!\w)" if re.match(r"\w", w[-1]) else ""
            alternatives.append(left + re.escape(w) + right)
        # Single pass over the original text: chained str.replace calls would also
        # rewrite characters of mask tokens inserted by earlier replacements.
        masked_text = re.sub("|".join(alternatives), lambda _m: mask_token, text)

    enc = tokenizer(masked_text, truncation=True, max_length=max_length, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = model(**enc).logits
        probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]
    return probs[target_class], probs

In [None]:
# Example usage: probability of the predicted class once the hard-coded word
# "Marianne" is masked in r0's text. The text is scored unchanged if that word
# does not occur in it.
target_pred = r0["meta"]["predicted_label"]
p_with, _ = mask_and_score(model, tokenizer, r0["meta"]["text"], target_pred, words_to_mask=["Marianne"])
print("P(pred class) after masking:", p_with)


In [None]:
def explain_with_diff(df, row_idx, model, tokenizer, max_length=256, device=None, top_k=15, do_ablate=True):
    """
    Explain one row, print its attribution tables and optionally run an ablation.

    Runs explain_row, builds the differential table (pred - true) with
    differential_attribution, and displays the metadata plus the three tables.
    When ``do_ablate`` is true, the three words with the largest ``diff`` are
    masked in the text and the resulting probability of the predicted class is
    printed; a failure during that step is reported instead of raised.

    Returns:
        (out, diff): the explain_row result dict and the differential DataFrame.
    """
    out = explain_row(df, row_idx, model, tokenizer, max_length=max_length, device=device, top_k=top_k)
    diff = differential_attribution(out["pred_df"], out["true_df"], k=top_k)

    print("META:", {key: val for key, val in out["meta"].items() if key != "probs"})
    sections = (
        ("\n--- Top tokens toward PREDICTED class ---", out["pred_df"]),
        ("\n--- Top tokens toward TRUE class ---", out["true_df"]),
        ("\n--- Differential (pred - true) ---", diff),
    )
    for heading, table in sections:
        print(heading)
        display(table)

    if not do_ablate:
        return out, diff

    meta = out["meta"]
    # the three rows with the highest diff, i.e. leaning most toward the predicted class
    top_sep = diff.sort_values("diff", ascending=False)["word"].head(3).tolist()
    try:
        p_after, _ = mask_and_score(
            model, tokenizer, meta["text"], meta["predicted_label"],
            max_length=max_length, device=device, words_to_mask=top_sep,
        )
        print(f"\nAblation: masking {top_sep} -> P(pred_class) = {p_after:.4f}")
    except Exception as e:
        print("Ablation error:", e)

    return out, diff


In [None]:

# Example: full report (tables + ablation) for the first representative error.
# The returned (out, diff) tuple is discarded.
_ = explain_with_diff(generated_texts_merged, reps.index[0], model, tokenizer, max_length=256)


## ---- Explainability (DeBERTa):

In [None]:
!pip install captum numpy>=2.0

In [None]:
def get_word_embeddings_module(model):
    """
    Locate ``embeddings.word_embeddings`` on a Hugging Face style encoder model.

    The backbone attributes roberta, deberta, bert, distilbert, albert and
    electra are tried in that order, then the model object itself.

    Raises:
        AttributeError: when no candidate exposes ``embeddings.word_embeddings``.
    """
    known_backbones = ("roberta", "deberta", "bert", "distilbert", "albert", "electra")

    def _candidates():
        # every backbone attribute present on the model, then the model itself
        for attr_name in known_backbones:
            if hasattr(model, attr_name):
                yield getattr(model, attr_name)
        yield model

    for host in _candidates():
        if hasattr(host, "embeddings") and hasattr(host.embeddings, "word_embeddings"):
            return host.embeddings.word_embeddings
    raise AttributeError("Could not locate embeddings.word_embeddings on this model.")


In [None]:

import torch
import numpy as np
import pandas as pd

# --- pick representative mistakes ---
def select_representative_errors(df, top_k_per_pair=3, min_conf=0.93):
    """
    Pick the most confident misclassifications for every confusion pair.

    Rows with ``is_error == True`` and ``confidence >= min_conf`` are kept,
    tagged with a "pair" column holding (Expected Author, Predicted Author),
    ordered by descending confidence, and the first ``top_k_per_pair`` rows of
    each pair are returned. When no row qualifies, a message is printed and the
    empty selection is returned (without the "pair" column).
    """
    wrong = df.query("is_error == True and confidence >= @min_conf").copy()
    if wrong.empty:
        print("No high-confidence errors found with the current threshold.")
        return wrong

    expected = wrong["Expected Author"]
    predicted = wrong["Predicted Author"]
    wrong["pair"] = list(zip(expected, predicted))

    by_confidence = wrong.sort_values("confidence", ascending=False)
    return by_confidence.groupby("pair", group_keys=False).head(top_k_per_pair)

def merge_subword_tokens(tokens, scores, *, sort_by_abs=True):
    """
    Merge subword tokens into words and sum their scores.

    - A token starting with 'Ġ' (byte-level BPE) or '▁' (SentencePiece) opens a
      new word; the marker character is dropped.
    - Any other token is appended to the current word and its score is added.
    - The tokens <s>, </s>, <pad>, [CLS], [SEP], [PAD] are skipped.

    Args:
        tokens: iterable of token strings.
        scores: one numeric score per token.
        sort_by_abs: keyword-only. True (default, the historical behaviour of
            this definition) adds an "attr_abs" column and sorts by it in
            descending order. False keeps the words in their original order,
            which is what explain_row asks for when building its "*_ordered"
            frames.

    Returns:
        DataFrame with columns "word" and "attr" (plus "attr_abs" when sorted).
    """
    specials = {"<s>", "</s>", "<pad>", "[CLS]", "[SEP]", "[PAD]"}
    words, w_scores = [], []
    curr_word, curr_score = "", 0.0

    def flush():
        # Emit the word collected so far, if any, and reset the accumulator.
        nonlocal curr_word, curr_score
        if curr_word:
            words.append(curr_word)
            w_scores.append(curr_score)
            curr_word, curr_score = "", 0.0

    for tok, sc in zip(tokens, scores):
        if tok in specials:
            continue
        # Word starts: leading marker
        if tok.startswith(("Ġ", "▁")):
            flush()
            curr_word = tok[1:]  # drop marker
            curr_score = float(sc)
        else:
            # continuation piece
            curr_word += tok
            curr_score += float(sc)
    flush()

    out = pd.DataFrame({"word": words, "attr": w_scores})
    if sort_by_abs:
        out["attr_abs"] = out["attr"].abs()
        out = out.sort_values("attr_abs", ascending=False).reset_index(drop=True)
    return out


In [None]:
from captum.attr import IntegratedGradients

def _forward_from_embeds(model, attention_mask, token_type_ids, inputs_embeds):
    # Most HF sequence classification models accept inputs_embeds directly
    out = model(
        attention_mask=attention_mask,
        token_type_ids=token_type_ids if token_type_ids is not None else None,
        inputs_embeds=inputs_embeds
    )
    return out.logits


def attribute_example_with_ig(model, tokenizer, text, target_class, max_length=256, device=None, n_steps=50):
    """
    Compute Integrated Gradients token attributions for one text toward the
    'target_class' logit.

    The attribution is taken with respect to the word-embedding output located by
    get_word_embeddings_module; the baseline is the embedding of a sequence made
    entirely of pad tokens (special-token positions included).

    Args:
        model: sequence classifier; put in eval mode and moved to ``device``.
        tokenizer: tokenizer matching ``model``.
        text: input string (truncated to ``max_length`` tokens).
        target_class: index of the logit to attribute.
        max_length: tokenizer truncation length.
        device: torch device; defaults to CUDA when available, else CPU.
        n_steps: number of IG interpolation steps passed to Captum.

    Returns:
        (tokens, token_attr, probs, pred_class):
        token strings, one attribution per token (embedding dimension summed
        out), the softmax probabilities of the unmodified input, and the
        arg-max class index.
    """
    model.eval()
    device = device or (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
    model.to(device)

    enc = tokenizer(text, truncation=True, max_length=max_length, return_tensors="pt")
    input_ids = enc["input_ids"].to(device)
    # attention_mask / token_type_ids are optional: not every tokenizer returns them
    attention_mask = enc.get("attention_mask", None)
    attention_mask = attention_mask.to(device) if attention_mask is not None else None
    token_type_ids = enc.get("token_type_ids", None)
    token_type_ids = token_type_ids.to(device) if token_type_ids is not None else None

    # Get embeddings module generically
    we = get_word_embeddings_module(model)

    # Embedding lookups and the plain prediction need no gradients
    with torch.no_grad():
        inputs_embeds = we(input_ids)  # [1, L, H]
        pad_id = tokenizer.pad_token_id
        if pad_id is None:
            # Fallback: many SentencePiece models use 0 as pad (DeBERTa often does)
            pad_id = 0
        # Baseline: same length as the input, every position replaced by the pad token
        baseline_ids = torch.full_like(input_ids, pad_id)
        baseline_embeds = we(baseline_ids)

        logits = _forward_from_embeds(model, attention_mask, token_type_ids, inputs_embeds)
        probs = torch.softmax(logits, dim=-1).detach().cpu().numpy()[0]
        pred_class = int(probs.argmax())

    def _target_logit_from_embeds(inputs_embeds_):
        """Forward function for Captum: the target-class logit for each batch item."""
        logits_ = _forward_from_embeds(model, attention_mask, token_type_ids, inputs_embeds_)
        return logits_[:, target_class]

    ig = IntegratedGradients(_target_logit_from_embeds)

    attributions = ig.attribute(
        inputs=inputs_embeds,
        baselines=baseline_embeds,
        n_steps=n_steps,
        return_convergence_delta=False
    )

    # Sum over the embedding dimension to get one score per token
    token_attr = attributions.sum(dim=-1).squeeze(0).detach().cpu().numpy()  # [L]
    tokens = tokenizer.convert_ids_to_tokens(input_ids.squeeze(0).detach().cpu().tolist())
    return tokens, token_attr, probs, pred_class


In [None]:
def explain_row(df, row_idx, model, tokenizer, max_length=256, device=None, top_k=15):
    """
    Explain one row of ``df`` with Integrated Gradients toward two classes.

    Attributions are computed twice for the row's text: once toward the
    predicted class and once toward the expected (true) class.

    Args:
        df: DataFrame with columns "Text", "expected_label",
            "predicted_Label_roberta", "Expected Author", "Predicted Author",
            "confidence" and "is_error".
        row_idx: index label of the row to explain (e.g. a value taken from
            ``reps.index``). On a default RangeIndex this coincides with the
            row position.
        model, tokenizer: classifier and matching tokenizer.
        max_length: tokenizer truncation length.
        device: torch device forwarded to attribute_example_with_ig.
        top_k: number of rows kept in the sorted tables.

    Returns:
        dict with
          "meta": text, labels, author names, probabilities, confidence, error flag;
          "pred_ordered" / "true_ordered": word attributions in reading order;
          "pred_df" / "true_df": the ``top_k`` words by absolute attribution.
    """
    # Label-based lookup: every caller in this notebook passes values taken from a
    # DataFrame index. A positional .iloc lookup only gives the same row while the
    # index is 0..n-1; on any other index it selects a different row or raises.
    row = df.loc[row_idx]
    text = row["Text"]
    true_cls = int(row["expected_label"])
    pred_cls = int(row["predicted_Label_roberta"])

    toks_p, attr_p, probs_p, _ = attribute_example_with_ig(
        model, tokenizer, text, target_class=pred_cls, max_length=max_length, device=device
    )
    # The probabilities do not depend on the target class, so only probs_p is kept.
    toks_t, attr_t, _, _ = attribute_example_with_ig(
        model, tokenizer, text, target_class=true_cls, max_length=max_length, device=device
    )

    # ORIGINAL ORDER (for visualization)
    pred_df_ordered = merge_subword_tokens(toks_p, attr_p, sort_by_abs=False)
    true_df_ordered = merge_subword_tokens(toks_t, attr_t, sort_by_abs=False)

    # SORTED (for top-k tables, optional)
    pred_df_sorted = merge_subword_tokens(toks_p, attr_p, sort_by_abs=True).head(top_k)
    true_df_sorted = merge_subword_tokens(toks_t, attr_t, sort_by_abs=True).head(top_k)

    meta = {
        "text": text,
        "expected_label": true_cls,
        "predicted_label": pred_cls,
        "Expected Author": row["Expected Author"],
        "Predicted Author": row["Predicted Author"],
        "probs": probs_p,
        "confidence": float(row["confidence"]),
        "is_error": bool(row["is_error"])
    }
    return {
        "meta": meta,
        "pred_ordered": pred_df_ordered,
        "true_ordered": true_df_ordered,
        "pred_df": pred_df_sorted,
        "true_df": true_df_sorted
    }


In [None]:
import pandas as pd

def merge_subword_tokens(tokens, scores, *, sort_by_abs=False):
    """
    Collapse subword tokens into whole words, summing the scores of their pieces.

    A token beginning with 'Ġ' (byte-level BPE) or '▁' (SentencePiece) starts a
    new word (the marker is removed); every other token continues the current
    word. The tokens <s>, </s>, <pad>, [CLS], [SEP] and [PAD] are ignored.

    Returns a DataFrame with "word" and "attr" columns in original order. With
    ``sort_by_abs=True`` an "attr_abs" column is added and rows are sorted by it,
    largest first, with a fresh index.
    """
    skipped = frozenset({"<s>", "</s>", "<pad>", "[CLS]", "[SEP]", "[PAD]"})
    word_markers = ("Ġ", "▁")

    merged_words, merged_scores = [], []
    pending_word, pending_score = "", 0.0

    for piece, value in zip(tokens, scores):
        if piece in skipped:
            continue
        if not piece.startswith(word_markers):
            # continuation of the word being built
            pending_word += piece
            pending_score += float(value)
            continue
        # a new word begins: emit the finished one (if any) first
        if pending_word:
            merged_words.append(pending_word)
            merged_scores.append(pending_score)
        pending_word, pending_score = piece[1:], float(value)

    # emit the trailing word
    if pending_word:
        merged_words.append(pending_word)
        merged_scores.append(pending_score)

    df = pd.DataFrame({"word": merged_words, "attr": merged_scores})
    if not sort_by_abs:
        return df
    df["attr_abs"] = df["attr"].abs()
    return df.sort_values("attr_abs", ascending=False).reset_index(drop=True)


In [None]:
# Example: pick representative errors (very confident but wrong)
reps = select_representative_errors(generated_texts_merged, top_k_per_pair=2, min_conf=0.93)

# If you want a specific confusion pair:
# reps = generated_texts_merged.query("is_error and `Expected Author`=='Jane Austen' and `Predicted Author`=='Charles Dickens'").nlargest(3, 'confidence')

# Explain the first 3 examples.
# `idx` values are index labels of `reps`, which keeps the labels of
# generated_texts_merged.
results = []
for idx in reps.index[:3]:
    out = explain_row(generated_texts_merged, idx, model, tokenizer, max_length=256, top_k=15)
    results.append(out)

# Inspect one result.
# NOTE: select_representative_errors returns an empty frame when nothing passes
# the threshold; `results` is then empty and results[0] raises IndexError.
r0 = results[0]
print("Example meta:", {k:v for k,v in r0["meta"].items() if k!="probs"})
print("\nTop tokens toward PREDICTED class:")
display(r0["pred_df"])  # in Colab/Jupyter this shows a nice table

print("\nTop tokens toward TRUE class:")
display(r0["true_df"])


In [None]:
from IPython.display import HTML, display
import numpy as np
import matplotlib

def highlight_text(tokens, scores, cmap="RdBu", score_range=None):
    """
    Render tokens as HTML spans whose background colour encodes an attribution score.

    The colour scale is symmetric around zero. With the default diverging "RdBu"
    map the low (negative) end is red and the high (positive) end is blue; pass
    cmap="RdBu_r" to get positive = red instead.

    Args:
        tokens: iterable of strings to render, in display order.
        scores: one numeric score per token.
        cmap: matplotlib colormap name or Colormap instance.
        score_range: optional half-width of the colour scale (its absolute value
            is used). When None, the largest absolute score is used.

    Returns:
        HTML string with one <span> per token, joined by spaces; an empty string
        when there is nothing to render.
    """
    import html  # function-scope import: only needed to escape token text

    tokens = list(tokens)
    scores = np.asarray(scores, dtype=float)
    if not tokens or scores.size == 0:
        # np.max of an empty array raises, so bail out before normalising.
        return ""

    # New: use the non-deprecated API
    cmap_obj = matplotlib.colormaps.get_cmap(cmap)

    if score_range is None:
        max_abs = float(np.max(np.abs(scores))) or 1e-9
    else:
        # A zero or negative range would make vmin >= vmax in Normalize.
        max_abs = abs(float(score_range)) or 1e-9

    norm = matplotlib.colors.Normalize(vmin=-max_abs, vmax=max_abs)
    html_tokens = []
    for tok, sc in zip(tokens, scores):
        color = matplotlib.colors.rgb2hex(cmap_obj(norm(sc))[:3])
        # quote=False escapes exactly &, < and >
        safe_tok = html.escape(str(tok), quote=False)
        html_tokens.append(
            f"<span style='background-color:{color}; padding:2px 3px; margin:1px; border-radius:3px;'>{safe_tok}</span>"
        )
    return " ".join(html_tokens)

def show_attribution_text_ordered(word_attr_df, title="", cmap="RdBu", score_range=None):
    """
    Show the rows of ``word_attr_df`` as colour-coded words, in the frame's own order.

    The frame must provide "word" and "attr" columns; no sorting is applied
    here. ``score_range`` is forwarded to highlight_text to fix the colour scale.
    """
    ordered_words = word_attr_df["word"].tolist()
    ordered_scores = word_attr_df["attr"].tolist()
    html_str = highlight_text(ordered_words, ordered_scores, cmap=cmap, score_range=score_range)
    block = f"<div style='line-height:2'><b>{title}</b><br>{html_str}</div>"
    display(HTML(block))


In [None]:
# Example with r0 (from the results list built earlier): render both attribution
# views in reading order. Each call normalises colours by its own largest
# absolute score, because no score_range is passed.
show_attribution_text_ordered(
    r0["pred_ordered"],
    title=f"Evidence for Predicted (original order): {r0['meta']['Predicted Author']}"
)

show_attribution_text_ordered(
    r0["true_ordered"],
    title=f"Evidence for True (original order): {r0['meta']['Expected Author']}"
)


In [None]:
from IPython.display import HTML, display

def side_by_side_ordered(pred_df_ordered, true_df_ordered, pred_title, true_title, cmap="RdBu"):
    """
    Display two attribution views next to each other in a two-column flex layout.

    Each frame needs "word" and "attr" columns and is rendered with
    highlight_text in its own row order. Both sides are coloured independently
    (no shared score range is passed).
    """
    def _render(frame):
        # one coloured HTML strip per frame
        return highlight_text(frame["word"].tolist(), frame["attr"].tolist(), cmap=cmap)

    html_left = _render(pred_df_ordered)
    html_right = _render(true_df_ordered)
    display(HTML(f"""
    <div style="display:flex; gap:24px">
      <div style="flex:1; line-height:2">
        <div><b>{pred_title}</b></div>
        <div>{html_left}</div>
      </div>
      <div style="flex:1; line-height:2">
        <div><b>{true_title}</b></div>
        <div>{html_right}</div>
      </div>
    </div>
    """))

# Usage:
# side_by_side_ordered(
#     r0["pred_ordered"], r0["true_ordered"],
#     f"Predicted: {r0['meta']['Predicted Author']}",
#     f"True: {r0['meta']['Expected Author']}"
# )


In [None]:
# Usage: predicted-class view on the left, true-class view on the right, for r0.
side_by_side_ordered(
     r0["pred_ordered"], r0["true_ordered"],
     f"Predicted: {r0['meta']['Predicted Author']}",
     f"True: {r0['meta']['Expected Author']}"
 )

In [None]:
import torch
import numpy as np

# Class index -> author name.
# NOTE(review): must match the label order used for training — confirm against
# the label mapping defined in the data-preparation section.
label_map = {
    0: "Charles Dickens",
    1: "Jane Austen",
    2: "Mark Twain",
    3: "Louisa May Alcott",
    4: "Herman Melville",
}

# Sample sentence to classify
text = "made the place the scene of their holiday entertainment."

model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Single text, truncated to 256 tokens.
# NOTE(review): presumably the same settings as in training — verify.
enc = tokenizer(text, truncation=True, max_length=256, return_tensors="pt").to(device)

with torch.inference_mode():
    logits = model(**enc).logits            # [1, 5]
    probs  = torch.softmax(logits, dim=-1)  # [1, 5]

pred_idx = int(torch.argmax(probs, dim=-1).item())
pred_author = label_map[pred_idx]
probs_np = probs.squeeze(0).cpu().numpy()

# Print the winner, then the probability of every class
print("Predicted:", pred_author, "\n")
for i, p in enumerate(probs_np):
    print(f"{label_map[i]:20s}  {p:.4f}")


In [None]:
def predict_texts(texts, model, tokenizer, max_length=256):
    """
    Classify a batch of texts and translate class indices to author names.

    The batch is padded and truncated to ``max_length`` and run on whatever
    device the model's parameters already live on. Author names are looked up
    in the notebook-level ``label_map``.

    Returns:
        (authors, probs): list of author names and the softmax probabilities as
        a NumPy array with one row per text.
    """
    model.eval()
    model_device = next(model.parameters()).device
    batch = tokenizer(
        texts, truncation=True, padding=True, max_length=max_length, return_tensors="pt"
    ).to(model_device)

    with torch.inference_mode():
        class_probs = torch.softmax(model(**batch).logits, dim=-1).cpu().numpy()

    predicted_authors = [label_map[class_id] for class_id in class_probs.argmax(axis=1)]
    return predicted_authors, class_probs  # list[str], np.ndarray [N,5]

# Same sample sentence, through the batch helper (a one-element batch).
# Rebinds `probs` to the NumPy array returned by predict_texts.
authors, probs = predict_texts([text], model, tokenizer)
print("Predicted:", authors[0])


In [None]:
def differential_attribution(pred_df, true_df, k=15):
    """
    Rank words by how differently they are attributed to two classes.

    Args:
        pred_df: DataFrame with "word" and "attr" columns (attribution toward
            the predicted class).
        true_df: DataFrame with "word" and "attr" columns (attribution toward
            the true class).
        k: number of rows to keep, ranked by |attr_pred - attr_true|.

    Returns:
        DataFrame with columns ["word", "attr_pred", "attr_true", "diff"],
        one row per distinct word, where diff = attr_pred - attr_true.
        A word missing from one side contributes 0.0 on that side.
    """
    # Collapse repeated words first. Merging the raw frames on "word" is a
    # many-to-many join: a word occurring n times on one side and m times on the
    # other would otherwise yield n*m rows with mismatched pairings.
    pred_tot = pred_df.groupby("word", as_index=False, sort=False)["attr"].sum()
    true_tot = true_df.groupby("word", as_index=False, sort=False)["attr"].sum()

    # merge on word; fill missing with 0
    df = pred_tot.merge(
        true_tot, on="word", how="outer", suffixes=("_pred","_true")
    ).fillna(0.0)
    df["diff"] = df["attr_pred"] - df["attr_true"]
    df["diff_abs"] = df["diff"].abs()
    df = df.sort_values("diff_abs", ascending=False).head(k)
    return df[["word","attr_pred","attr_true","diff"]]

In [None]:
# Example (diff = attr_pred - attr_true):
# diff > 0: evidence for the predicted class.
# diff < 0: evidence for the true class.

diff = differential_attribution(r0["pred_df"], r0["true_df"], k=20)
display(diff)

In [None]:
import torch
import numpy as np

def mask_and_score(model, tokenizer, text, target_class, max_length=256, device=None, words_to_mask=None):
    """
    Score ``text`` after replacing selected words with the tokenizer's mask token.

    Args:
        model: classifier whose call returns an object with ``.logits``.
        tokenizer: tokenizer used to encode the (masked) text; its ``mask_token``
            is used as the replacement, falling back to "<mask>".
        text: raw input string.
        target_class: index of the class whose probability is returned.
        max_length: truncation length passed to the tokenizer.
        device: torch device; defaults to CUDA when available, else CPU.
        words_to_mask: iterable of words to mask; None or [] leaves the text
            unchanged (useful as a baseline).

    Returns:
        (p_target, probs): the softmax probability of ``target_class`` on the
        masked text (a probability, not a delta) and the full probability vector.
    """
    import re  # function-scope import: only needed for the masking pattern

    device = device or (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
    model.to(device).eval()

    mask_token = tokenizer.mask_token or "<mask>"
    masked_text = text
    # De-duplicate and try longer words first so a short word cannot pre-empt a
    # longer one that starts at the same position.
    targets = sorted(
        (w for w in dict.fromkeys(words_to_mask or []) if w and w.strip()),
        key=len,
        reverse=True,
    )
    if targets:
        alternatives = []
        for w in targets:
            # Only anchor on a word boundary where the target itself begins or ends
            # with a word character, so "he" does not match inside "the" while
            # punctuation-only targets still match.
            left = r"(?<!\w)" if re.match(r"\w", w[0]) else ""
            right = r"(?!\w)" if re.match(r"\w", w[-1]) else ""
            alternatives.append(left + re.escape(w) + right)
        # Single pass over the original text: chained str.replace calls would also
        # rewrite characters of mask tokens inserted by earlier replacements.
        masked_text = re.sub("|".join(alternatives), lambda _m: mask_token, text)

    enc = tokenizer(masked_text, truncation=True, max_length=max_length, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = model(**enc).logits
        probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]
    return probs[target_class], probs

In [None]:
# Example usage: probability of the predicted class once the hard-coded word
# "Marianne" is masked in r0's text. The text is scored unchanged if that word
# does not occur in it.
target_pred = r0["meta"]["predicted_label"]
p_with, _ = mask_and_score(model, tokenizer, r0["meta"]["text"], target_pred, words_to_mask=["Marianne"])
print("P(pred class) after masking:", p_with)


In [None]:
def explain_with_diff(df, row_idx, model, tokenizer, max_length=256, device=None, top_k=15, do_ablate=True):
    """
    Explain one row, print its attribution tables and optionally run an ablation.

    Runs explain_row, builds the differential table (pred - true) with
    differential_attribution, and displays the metadata plus the three tables.
    When ``do_ablate`` is true, the three words with the largest ``diff`` are
    masked in the text and the resulting probability of the predicted class is
    printed; a failure during that step is reported instead of raised.

    Returns:
        (out, diff): the explain_row result dict and the differential DataFrame.
    """
    out = explain_row(df, row_idx, model, tokenizer, max_length=max_length, device=device, top_k=top_k)
    diff = differential_attribution(out["pred_df"], out["true_df"], k=top_k)

    print("META:", {key: val for key, val in out["meta"].items() if key != "probs"})
    sections = (
        ("\n--- Top tokens toward PREDICTED class ---", out["pred_df"]),
        ("\n--- Top tokens toward TRUE class ---", out["true_df"]),
        ("\n--- Differential (pred - true) ---", diff),
    )
    for heading, table in sections:
        print(heading)
        display(table)

    if not do_ablate:
        return out, diff

    meta = out["meta"]
    # the three rows with the highest diff, i.e. leaning most toward the predicted class
    top_sep = diff.sort_values("diff", ascending=False)["word"].head(3).tolist()
    try:
        p_after, _ = mask_and_score(
            model, tokenizer, meta["text"], meta["predicted_label"],
            max_length=max_length, device=device, words_to_mask=top_sep,
        )
        print(f"\nAblation: masking {top_sep} -> P(pred_class) = {p_after:.4f}")
    except Exception as e:
        print("Ablation error:", e)

    return out, diff


In [None]:

# Example: full report (tables + ablation) for the first representative error.
# The returned (out, diff) tuple is discarded.
_ = explain_with_diff(generated_texts_merged, reps.index[0], model, tokenizer, max_length=256)


## ---- Explainability (global):

In [None]:
# Differential attribution for the pair

In [None]:
# ---- CONFIG ----
PAIR_TRUE, PAIR_PRED = 1, 2      # Austen=1, Twain=2  (you can flip later)
MAXLEN = 256
N_PER_SIDE = 150                 # how many errors to sample per confusion direction
TOP_K_TOKENS = 20                # how many tokens to show in the bars

# Select errors for both directions of the confusion pair:
# a_to_t = expected PAIR_TRUE but predicted PAIR_PRED; t_to_a = the reverse.
a_to_t = generated_texts_merged.query("expected_label==@PAIR_TRUE and predicted_Label_roberta==@PAIR_PRED")
t_to_a = generated_texts_merged.query("expected_label==@PAIR_PRED and predicted_Label_roberta==@PAIR_TRUE")

# Cap each side at the N_PER_SIDE most confident rows (not a random sample)
a_to_t = a_to_t.nlargest(N_PER_SIDE, "confidence") if len(a_to_t)>N_PER_SIDE else a_to_t
t_to_a = t_to_a.nlargest(N_PER_SIDE, "confidence") if len(t_to_a)>N_PER_SIDE else t_to_a

# helper to run IG on many rows and aggregate word-level diff
from collections import defaultdict

def aggregate_pair_diff(rows, true_cls, pred_cls, model, tokenizer, max_length=256, device=None):
    """
    Average the per-word differential attribution (pred - true) over many rows.

    For every index label in ``rows``, explain_row is run against the
    notebook-level ``generated_texts_merged`` frame. Within each text the
    attributions are summed per distinct word before the predicted-minus-true
    difference is taken; the differences are then averaged across texts.

    Args:
        rows: DataFrame whose index labels identify the rows to explain.
        true_cls, pred_cls: accepted for call-site symmetry; not used here (the
            classes are read from each row by explain_row).
        model, tokenizer, max_length, device: forwarded to explain_row.

    Returns:
        DataFrame with columns "word", "mean_diff" and "count" (number of texts
        in which the word appeared), sorted by descending mean_diff. Empty
        words and words made only of non-word characters are dropped.
    """
    token2sum = defaultdict(float)
    token2cnt = defaultdict(int)
    for idx in rows.index:
        out = explain_row(generated_texts_merged, idx, model, tokenizer, max_length=max_length, device=device, top_k=999)
        # Sum per distinct word first: merging the raw tables on "word" is a
        # many-to-many join that multiplies rows for words repeated in the text.
        pred_by_word = out["pred_df"].groupby("word")["attr"].sum()
        true_by_word = out["true_df"].groupby("word")["attr"].sum()
        # A word present on one side only counts as 0.0 on the other side.
        per_word_diff = pred_by_word.sub(true_by_word, fill_value=0.0)
        for w, d in per_word_diff.items():
            token2sum[w] += float(d)
            token2cnt[w] += 1
    # average diff per token
    items = [(w, token2sum[w] / token2cnt[w], token2cnt[w]) for w in token2sum]
    df = pd.DataFrame(items, columns=["word","mean_diff","count"])
    # filter trivial tokens
    df = df[~df["word"].str.match(r"^(\W+|)$")]
    return df.sort_values("mean_diff", ascending=False)

# Aggregate for both directions (one explain_row call per selected row, so this
# cell runs Integrated Gradients twice for every row of a_to_t and t_to_a).
df_a_to_t = aggregate_pair_diff(a_to_t, true_cls=PAIR_TRUE, pred_cls=PAIR_PRED, model=model, tokenizer=tokenizer, max_length=MAXLEN)
df_t_to_a = aggregate_pair_diff(t_to_a, true_cls=PAIR_PRED, pred_cls=PAIR_TRUE, model=model, tokenizer=tokenizer, max_length=MAXLEN)

# plot top tokens pushing toward Twain (positive) and toward Austen (negative)
import matplotlib.pyplot as plt
def plot_top_tokens(df, title, k=TOP_K_TOKENS):
    """
    Horizontal bar chart of the words at both extremes of ``mean_diff``.

    Args:
        df: DataFrame with "word" and "mean_diff" columns, already sorted by
            descending mean_diff (as returned by aggregate_pair_diff).
        title: figure title.
        k: number of words to take from each end.
    """
    top_pos = df.head(k)
    # Take the low end only from rows not already in top_pos; with fewer than 2*k
    # rows head(k) and tail(k) overlap and the shared words would be drawn twice.
    top_neg = df.iloc[len(top_pos):].tail(k).sort_values("mean_diff")
    fig, ax = plt.subplots(figsize=(8,6))
    y = list(top_pos["word"]) + list(top_neg["word"])
    x = list(top_pos["mean_diff"]) + list(top_neg["mean_diff"])
    ax.barh(y, x)
    ax.set_title(title)
    ax.set_xlabel("Mean differential attribution (pred - true)")
    ax.axvline(0, linestyle="--")  # zero reference line
    plt.tight_layout()
    plt.show()

# One chart per confusion direction. Positive bars lean toward the predicted
# class of that direction, negative bars toward the true class.
plot_top_tokens(df_a_to_t, "Austen→Twain errors: tokens pushing Twain (+) vs Austen (-)")
plot_top_tokens(df_t_to_a, "Twain→Austen errors: tokens pushing Austen (+) vs Twain (-)")


In [None]:
# Token frequency vs. differential attribution

In [None]:
def scatter_freq_vs_diff(df_diff, title):
    """
    Scatter each word's occurrence count against its mean differential attribution.

    ``df_diff`` needs numeric "count" and "mean_diff" columns; "count" serves as
    an approximate frequency (how often the word contributed across examples).
    """
    import matplotlib.pyplot as plt

    x_counts = df_diff["count"].astype(float)
    y_diffs = df_diff["mean_diff"].astype(float)

    fig, ax = plt.subplots(figsize=(7,6))
    ax.scatter(x_counts, y_diffs, alpha=0.6)
    ax.set_xlabel("Token frequency in confused examples")
    ax.set_ylabel("Mean differential attribution (pred - true)")
    ax.set_title(title)
    ax.grid(True, linestyle="--", alpha=0.5)
    plt.show()

# One scatter plot per confusion direction
scatter_freq_vs_diff(df_a_to_t, "Austen→Twain: frequency vs differential attribution")
scatter_freq_vs_diff(df_t_to_a, "Twain→Austen: frequency vs differential attribution")


In [None]:
# Prototype & counter-prototype snippets (gallery)

In [None]:
def top_prototypes(df_pair, k=6):
    """
    Return the ``k`` most confident rows of a confusion-pair frame.

    ``df_pair`` is expected to be filtered to one confusion pair already. Only
    the text, both author columns and the confidence are kept.
    """
    gallery_columns = ["Text","Expected Author","Predicted Author","confidence"]
    most_confident = df_pair.nlargest(k, "confidence")
    return most_confident[gallery_columns]

# Example galleries: the 6 most confident errors in each direction.
proto_a_to_t = top_prototypes(a_to_t, k=6)
proto_t_to_a = top_prototypes(t_to_a, k=6)
# The last expression is a tuple of two frames, so the notebook shows its plain
# text repr rather than two rendered tables.
proto_a_to_t.head(), proto_t_to_a.head()

# (Optional) For each text in the gallery, call explain_row(...) then show_attribution_text(...) as you did.


In [None]:
# Counterfactual impact of top tokens (average Δprob)

In [None]:
def average_token_impact(df_rows, target_class, tokens, model, tokenizer, max_length=256, device=None, sample_n=100):
    """
    Measure how much masking each token lowers P(target_class), averaged over rows.

    Args:
        df_rows: DataFrame with "Text" and "confidence" columns.
        target_class: class index whose probability is tracked.
        tokens: iterable of words to ablate, one at a time.
        model, tokenizer, max_length, device: forwarded to mask_and_score.
        sample_n: when ``df_rows`` has more rows than this, only the
            ``sample_n`` most confident rows are used.

    Returns:
        DataFrame with columns "token", "mean_delta_prob" (unmasked minus masked
        probability, averaged over rows) and "n" (rows used), sorted by
        descending mean_delta_prob. mean_delta_prob is NaN when there are no rows.
    """
    rows = df_rows.nlargest(sample_n, "confidence") if len(df_rows)>sample_n else df_rows
    texts = rows["Text"].tolist()

    # The unmasked probability does not depend on the token being tested, so it is
    # computed once per text instead of once per (token, text) pair.
    baselines = [
        mask_and_score(model, tokenizer, text, target_class, max_length=max_length, device=device, words_to_mask=[])[0]
        for text in texts
    ]

    impacts = []
    for t in tokens:
        deltas = []
        for text, p0 in zip(texts, baselines):
            p1, _ = mask_and_score(model, tokenizer, text, target_class, max_length=max_length, device=device, words_to_mask=[t])
            deltas.append(p0 - p1)  # drop in prob when masking token t
        # np.mean([]) warns and yields NaN; return NaN directly in that case
        mean_delta = float(np.mean(deltas)) if deltas else float("nan")
        impacts.append((t, mean_delta, len(deltas)))
    out = pd.DataFrame(impacts, columns=["token","mean_delta_prob","n"])
    return out.sort_values("mean_delta_prob", ascending=False)

# Example: take top 15 Twain-pushing tokens from Austen→Twain errors and measure their avg impact on P(Twain)
# (target_class=2 is hard-coded here; it equals PAIR_PRED in the config above).
twain_tokens = df_a_to_t.head(15)["word"].tolist()
impact_tbl = average_token_impact(a_to_t, target_class=2, tokens=twain_tokens, model=model, tokenizer=tokenizer, max_length=MAXLEN, sample_n=80)
impact_tbl.head(15)


## ---- IntegratedGradients:

In [None]:
# Step 1 – Setup

!pip install captum
import torch
import numpy as np
import pandas as pd
from captum.attr import IntegratedGradients
from tqdm import tqdm

# Use the GPU when available, then put the already-loaded classifier in eval
# mode on that device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.eval().to(device)


In [None]:
def get_word_embeddings_module(model):
    """
    Locate ``embeddings.word_embeddings`` on a Hugging Face style encoder model.

    The backbone attributes deberta, roberta, bert, distilbert, albert and
    electra are tried in that order, then the model object itself.

    Raises:
        AttributeError: when no candidate exposes ``embeddings.word_embeddings``.
    """
    for name in ["deberta", "roberta", "bert", "distilbert", "albert", "electra"]:
        # Every step is guarded: a backbone without ``.embeddings`` must fall
        # through to the next candidate instead of raising here.
        backbone = getattr(model, name, None)
        emb = getattr(backbone, "embeddings", None)
        word_emb = getattr(emb, "word_embeddings", None)
        if word_emb is not None:
            return word_emb
    # Fallback
    if hasattr(model, "embeddings") and hasattr(model.embeddings, "word_embeddings"):
        return model.embeddings.word_embeddings
    raise AttributeError("Could not locate embeddings.word_embeddings on this model.")

def _forward_from_embeds(model, attention_mask, token_type_ids, inputs_embeds):
    out = model(
        attention_mask=attention_mask,
        token_type_ids=token_type_ids if token_type_ids is not None else None,
        inputs_embeds=inputs_embeds,
    )
    return out.logits


In [None]:
# Attribution for a Single Sentence:

def get_token_attributions_embeds(model, tokenizer, text, target_class, max_length=128, n_steps=32, device=None):
    """
    Integrated Gradients attribution of one text toward the ``target_class`` logit.

    Attribution is taken with respect to the word-embedding output located by
    get_word_embeddings_module; the baseline is the embedding of an all-pad
    sequence of the same length.

    Args:
        model: sequence classifier; put in eval mode and moved to ``device``.
        tokenizer: tokenizer matching ``model``.
        text: input string (truncated to ``max_length`` tokens).
        target_class: index of the logit to attribute.
        max_length: tokenizer truncation length.
        n_steps: number of IG interpolation steps passed to Captum.
        device: torch device; defaults to CUDA when available, else CPU.

    Returns:
        (tokens, token_attr): token strings and one attribution value per token
        (embedding dimension summed out) as a NumPy array.
    """
    model.eval()
    device = device or (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
    model.to(device)

    enc = tokenizer(text, truncation=True, max_length=max_length, return_tensors="pt")
    input_ids = enc["input_ids"].to(device)
    # attention_mask / token_type_ids are optional: not every tokenizer returns them
    attention_mask = enc.get("attention_mask", None)
    attention_mask = attention_mask.to(device) if attention_mask is not None else None
    token_type_ids = enc.get("token_type_ids", None)
    token_type_ids = token_type_ids.to(device) if token_type_ids is not None else None

    we = get_word_embeddings_module(model)

    # Embedding lookups need no gradients
    with torch.no_grad():
        inputs_embeds  = we(input_ids)                         # [1, L, H]
        # Fall back to id 0 when the tokenizer defines no pad token
        pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        baseline_ids   = torch.full_like(input_ids, pad_id)
        baseline_embeds = we(baseline_ids)

    def target_logit_from_embeds(inp_emb):
        """Forward function for Captum: the target-class logit for each batch item."""
        logits = _forward_from_embeds(model, attention_mask, token_type_ids, inp_emb)
        return logits[:, target_class]

    ig = IntegratedGradients(target_logit_from_embeds)
    # returns a single tensor when return_convergence_delta=False
    attributions = ig.attribute(
        inputs=inputs_embeds,
        baselines=baseline_embeds,
        n_steps=n_steps,
        return_convergence_delta=False,
    )

    # reduce embedding-dim → per-token score
    token_attr = attributions.sum(dim=-1).squeeze(0).detach().cpu().numpy()
    tokens = tokenizer.convert_ids_to_tokens(input_ids.squeeze(0).detach().cpu().tolist())
    return tokens, token_attr


In [None]:
import numpy as np
import pandas as pd
from tqdm import tqdm

def global_author_attributions(df, model, tokenizer,
                               sample_size=None,   # ← None = use ALL rows
                               max_length=128, n_steps=32,
                               by="Expected Author"):
    """
    Sum Integrated Gradients token attributions per author over the rows of ``df``.

    Args:
        df: DataFrame with "Text", "expected_label", "predicted_Label_roberta"
            and the ``by`` column.
        sample_size: None to use every row; otherwise a random sample of at most
            this many rows (fixed seed 42).
        max_length, n_steps: forwarded to get_token_attributions_embeds.
        by: "Expected Author" groups by the true author and attributes toward
            the expected label; any other value groups by that column and
            attributes toward the predicted label.

    Returns:
        dict mapping each author to a DataFrame with columns "token",
        "total_attr" and "abs_attr", sorted by descending abs_attr. Tokens are
        subword pieces with a leading 'Ġ'/'▁' marker removed; special tokens
        and empty pieces are left out.
    """
    skip_tokens = {"<s>", "</s>", "<pad>", "[CLS]", "[SEP]", "[PAD]"}
    label_column = "expected_label" if by == "Expected Author" else "predicted_Label_roberta"

    authors = df[by].unique().tolist()
    totals_by_author = {author: {} for author in authors}

    if sample_size is not None:
        sampled = df.sample(min(sample_size, len(df)), random_state=42)
    else:
        sampled = df

    for _, row in tqdm(sampled.iterrows(), total=len(sampled)):
        toks, attrs = get_token_attributions_embeds(
            model, tokenizer, row["Text"],
            target_class=int(row[label_column]), max_length=max_length, n_steps=n_steps
        )
        author_totals = totals_by_author[row[by]]

        for piece, value in zip(toks, attrs):
            if piece in skip_tokens:
                continue
            if piece.startswith(("▁", "Ġ")):
                piece = piece[1:]
            if piece:
                author_totals[piece] = author_totals.get(piece, 0.0) + float(value)

    author_dfs = {}
    for author, token_totals in totals_by_author.items():
        table = pd.DataFrame(list(token_totals.items()), columns=["token", "total_attr"])
        table["abs_attr"] = table["total_attr"].abs()
        author_dfs[author] = table.sort_values("abs_attr", ascending=False).reset_index(drop=True)
    return author_dfs



In [None]:
import matplotlib.pyplot as plt

def plot_top_tokens(author_df, author_name, top_n=30):
    """
    Draw a horizontal bar chart of the first ``top_n`` rows of ``author_df``.

    Args:
        author_df: DataFrame with "token" and "total_attr" columns; rows are
            plotted in the order they appear.
        author_name: Name inserted into the figure title.
        top_n: Number of leading rows to plot.
    """
    leading_rows = author_df.head(top_n)

    fig, ax = plt.subplots(figsize=(4, 6))
    ax.barh(leading_rows["token"], leading_rows["total_attr"])
    ax.invert_yaxis()  # first row at the top of the chart
    ax.set_title(f"Top tokens influencing: {author_name}")
    ax.set_xlabel("Aggregated attribution")
    fig.tight_layout()
    plt.show()


In [None]:
# Keep only rows whose "confidence" exceeds 0.93; copy so later edits leave the source frame untouched.
high_conf_df = generated_texts_merged.loc[generated_texts_merged["confidence"] > 0.93].copy()
print("High‑confidence samples:", high_conf_df.shape[0])

In [None]:
# Step 5 – Run the per-author attribution pass over the high-confidence subset.
# NOTE(review): `model` / `tokenizer` are presumably the RoBERTa classifier here; later cells
# rebind both names to the GPT-Neo generator, so re-running this cell afterwards would pass
# the wrong objects — confirm before re-executing out of order.

author_results_all = global_author_attributions(
    high_conf_df, model, tokenizer,
    sample_size=None,      # None = use ALL high-confidence rows
    max_length=128,        # increase to 256 if you can afford it
    n_steps=32,            # 16–32 is a good trade-off
    by="Expected Author"   # or "Predicted Author" for the model-centric view
)

In [None]:
# Save
import pickle


# Save the per-author results (dict: author -> token DataFrame) to a pickle file
# in the current working directory.
with open("author_results_all.pkl", "wb") as f:
    pickle.dump(author_results_all, f)

In [None]:
# Load
import pickle

# Reload the saved per-author results.
# NOTE: pickle.load can execute arbitrary code from the file — only load a file this notebook wrote.
with open("author_results_all.pkl", "rb") as f:
    author_results_all = pickle.load(f)

In [None]:
# One chart per author: the first 15 rows of each table (tables are sorted by |attribution|).
for author, df_tok in author_results_all.items():
    plot_top_tokens(df_tok, author, top_n=15)

## ---- XAI Generators:

In [None]:
!pip install captum

In [None]:
import torch, math, random
from transformers import GPTNeoForCausalLM, GPT2Tokenizer
from torch.nn import functional as F

# Use the GPU when one is visible, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the generator from the configured path, move it to the device, switch to eval mode.
model = (
    GPTNeoForCausalLM
    .from_pretrained(drive_model_path)
    .to(device)
    .eval()
)
tokenizer = GPT2Tokenizer.from_pretrained(drive_tokenizer_path)
tokenizer.pad_token = tokenizer.eos_token  # reuse EOS as the padding token

# Author tags in label order: Dickens, Austen, Twain, Alcott, Melville
TAGS = [f"<{label}>" for label in range(5)]

In [None]:
# (Optional) make the model return attention weights by default.
# attention_to_from_tag_better also passes output_attentions=True on its own forward call.
model.config.output_attentions = True

@torch.no_grad()
def attention_to_from_tag_better(prompt_text, max_new_tokens=40):
    """
    Generate a continuation for ``prompt_text`` and report, per layer, how much
    attention goes to the leading tag and how the tag rows attend elsewhere.

    Relies on the notebook globals ``model``, ``tokenizer``, ``TAGS`` and
    ``generate_text``.

    Args:
        prompt_text: Prompt that starts with the tag. If it does not start with
            one of ``TAGS``, its first whitespace-separated word is used as the
            tag (``TAGS[0]`` when the prompt is blank).
        max_new_tokens: Forwarded to ``generate_text``.

    Returns:
        ``(full, layer_stats)``: ``full`` is the text returned by
        ``generate_text``; ``layer_stats`` has one dict per layer with keys

        - ``to_tag``: mean head-averaged attention weight over all
          (query, tag-key) cells;
        - ``to_tag_enrichment``: ``to_tag`` divided by ``1 / T``, the per-cell
          weight of a uniform attention row (1.0 means "no preference");
        - ``from_tag_non_tag_mean`` / ``from_tag_max``: mean / max weight the tag
          rows place on non-tag columns (NaN when there are none);
        - ``T``: number of unmasked tokens; ``tag_len``: tag span length.
    """
    full = generate_text(prompt_text, max_new_tokens=max_new_tokens)

    enc = tokenizer(full, return_tensors="pt").to(model.device)
    attns = model(**enc, output_attentions=True).attentions  # tuple of [B,H,T,T]
    attn_mask = enc["attention_mask"][0].bool()  # [T]
    T = int(attn_mask.sum().item())              # tokens with attention_mask == 1

    # Find the tag at the start of the prompt.
    text_stripped = prompt_text.lstrip()
    tag = next((t for t in TAGS if text_stripped.startswith(t)), None)
    if tag is None:
        tag = text_stripped.split()[0] if text_stripped else TAGS[0]
    tag_ids = tokenizer(tag, add_special_tokens=False)["input_ids"]
    if prompt_text.startswith(tag + " "):
        # Prefer the "tag + space" tokenisation when it matches the encoded prefix exactly.
        tag_ids_space = tokenizer(tag + " ", add_special_tokens=False)["input_ids"]
        if tag_ids_space == enc["input_ids"][0, :len(tag_ids_space)].tolist():
            tag_ids = tag_ids_space

    tag_start = 0
    tag_end = min(len(tag_ids), T)  # exclusive; never reaches past the real sequence
    tag_len = tag_end - tag_start

    # Columns outside the tag span are the same for every layer, so build them once.
    non_tag_cols = torch.cat([torch.arange(0, tag_start, device=model.device),
                              torch.arange(tag_end, T, device=model.device)])
    nan = float('nan')

    layer_stats = []
    for A in attns:  # [B,H,T,T]
        A = A.squeeze(0)[:, attn_mask, :][:, :, attn_mask]  # [H, T, T] over valid tokens only
        A_mean = A.mean(dim=0)  # average over heads -> [T,T]

        # ---- TO TAG ----
        if tag_len > 0:
            to_tag_mass = A_mean[:, tag_start:tag_end].mean().item()
            # to_tag_mass is a per-cell mean, so the matching uniform baseline is 1/T
            # (a tag_len/T baseline would understate enrichment by a factor of tag_len).
            to_tag_enrichment = to_tag_mass * T
        else:
            to_tag_mass = nan
            to_tag_enrichment = nan

        # ---- FROM TAG ----
        # NOTE(review): the model is loaded as a causal LM and the tag sits at position 0,
        # so tag rows can presumably only attend to tag positions and these two values
        # will be ~0 — confirm this is the intended measurement.
        if tag_len > 0 and non_tag_cols.numel() > 0:
            tag_to_rest = A_mean[tag_start:tag_end, :][:, non_tag_cols]  # [tag_len, T - tag_len]
            from_tag_non_tag_mean = tag_to_rest.mean().item()
            from_tag_max = tag_to_rest.max().item()
        else:
            from_tag_non_tag_mean = nan
            from_tag_max = nan

        layer_stats.append({
            "to_tag": to_tag_mass,
            "to_tag_enrichment": to_tag_enrichment,
            "from_tag_non_tag_mean": from_tag_non_tag_mean,
            "from_tag_max": from_tag_max,
            "T": T,
            "tag_len": tag_len
        })

    return full, layer_stats


In [None]:
# Per-layer tag-attention report for the bare "<0>" prompt.
full, stats = attention_to_from_tag_better("<0>")
for L in range(len(stats)):
    s = stats[L]
    print(
        f"Layer {L:02d}: to_tag={s['to_tag']:.4f} (x{s['to_tag_enrichment']:.1f}), "
        f"from_tag_mean(non-tag)={s['from_tag_non_tag_mean']:.4f}, "
        f"from_tag_max={s['from_tag_max']:.4f}"
    )
print("\nSample:\n", full)


In [None]:
# Per-layer tag-attention report for the bare "<1>" prompt.
full, stats = attention_to_from_tag_better("<1>")
for L in range(len(stats)):
    s = stats[L]
    print(
        f"Layer {L:02d}: to_tag={s['to_tag']:.4f} (x{s['to_tag_enrichment']:.1f}), "
        f"from_tag_mean(non-tag)={s['from_tag_non_tag_mean']:.4f}, "
        f"from_tag_max={s['from_tag_max']:.4f}"
    )
print("\nSample:\n", full)


In [None]:
# Per-layer tag-attention report for the "<2> " prompt (tag followed by a space).
full, stats = attention_to_from_tag_better("<2> ")
for L in range(len(stats)):
    s = stats[L]
    print(
        f"Layer {L:02d}: to_tag={s['to_tag']:.4f} (x{s['to_tag_enrichment']:.1f}), "
        f"from_tag_mean(non-tag)={s['from_tag_non_tag_mean']:.4f}, "
        f"from_tag_max={s['from_tag_max']:.4f}"
    )
print("\nSample:\n", full)


In [None]:
# Per-layer tag-attention report for the "<3> " prompt (tag followed by a space).
full, stats = attention_to_from_tag_better("<3> ")
for L in range(len(stats)):
    s = stats[L]
    print(
        f"Layer {L:02d}: to_tag={s['to_tag']:.4f} (x{s['to_tag_enrichment']:.1f}), "
        f"from_tag_mean(non-tag)={s['from_tag_non_tag_mean']:.4f}, "
        f"from_tag_max={s['from_tag_max']:.4f}"
    )
print("\nSample:\n", full)


In [None]:
# Per-layer tag-attention report for the "<4> " prompt (tag followed by a space).
full, stats = attention_to_from_tag_better("<4> ")
for L in range(len(stats)):
    s = stats[L]
    print(
        f"Layer {L:02d}: to_tag={s['to_tag']:.4f} (x{s['to_tag_enrichment']:.1f}), "
        f"from_tag_mean(non-tag)={s['from_tag_non_tag_mean']:.4f}, "
        f"from_tag_max={s['from_tag_max']:.4f}"
    )
print("\nSample:\n", full)


In [None]:
################################

In [None]:
import torch
from torch.nn import functional as F
import numpy as np
import matplotlib.pyplot as plt
from captum.attr import IntegratedGradients

# Re-derive the device string and put the already-loaded model into eval mode.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.eval();  # trailing ';' keeps the module repr out of the cell output

In [None]:
# Generation (explain one sample)

@torch.no_grad()
def generate_once(prompt, max_new_tokens=40, temperature=0.8, top_p=0.95, seed=42):
    """
    Sample one continuation of ``prompt`` from the global ``model``.

    The torch RNG is re-seeded with ``seed`` before sampling.

    Returns:
        ``(ids, text)``: ``ids`` is the full sequence (prompt + continuation) with
        a leading batch axis, ``[1, T]``; ``text`` is that sequence decoded with
        special tokens skipped.
    """
    torch.manual_seed(seed)
    encoded_prompt = tokenizer(prompt, return_tensors="pt").to(device)

    sampling_kwargs = dict(
        do_sample=True,
        top_p=top_p,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
    )
    sequence = model.generate(**encoded_prompt, **sampling_kwargs)[0]  # drop the batch axis

    decoded = tokenizer.decode(sequence, skip_special_tokens=True)
    return sequence.unsqueeze(0), decoded  # [1, T], string


# Integrated Gradients

def build_baseline_and_inputs(input_ids, prompt_len):
    """
    Build the (inputs, baseline) embedding pair used for Integrated Gradients.

    Args:
        input_ids: Token ids, indexed as ``[batch, position]``.
        prompt_len: Number of leading positions that belong to the prompt.

    Returns:
        ``(inputs_embeds, baseline)``. ``inputs_embeds`` is a grad-enabled copy of
        the token embeddings from ``model.transformer.wte``. ``baseline`` is zero on
        the first ``prompt_len`` positions and equal to the embeddings elsewhere,
        so only the prompt positions differ between the two tensors.
    """
    with torch.no_grad():
        token_embeds = model.transformer.wte(input_ids)  # [1, T, d]

    # Baseline: start from the real embeddings and blank out the prompt positions.
    baseline = token_embeds.clone()
    baseline[:, :prompt_len, :] = 0

    # Captum differentiates w.r.t. this tensor, so it must require grad.
    inputs_embeds = token_embeds.detach().clone().requires_grad_(True)
    return inputs_embeds, baseline

def token_logprob_from_embeds(embeds, attention_mask, target_index, target_token_id):
    """
    Log-probability the global ``model`` assigns to one target token.

    Args:
        embeds: Input embeddings, ``[B, T, d]``.
        attention_mask: Mask with ``B`` rows, or a single row that is broadcast to
            the batch (Captum may call this with several interpolation steps at once).
        target_index: Absolute position of the token being scored; must satisfy
            ``1 <= target_index < T`` because position 0 has no preceding context.
        target_token_id: Vocabulary id of the token at ``target_index``.

    Returns:
        Tensor of shape ``(B,)`` holding ``log P(x[target_index] | x[:target_index])``
        for each batch row (shape ``(1,)`` for a single row).

    Raises:
        ValueError: if ``target_index < 1``.
        IndexError: if ``target_index`` is not inside the sequence.
    """
    if target_index < 1:
        # A negative prediction position would silently wrap to the end of the sequence.
        raise ValueError("target_index must be >= 1 (position 0 has no preceding context)")
    if target_index >= embeds.size(1):
        raise IndexError("target_index is beyond the end of the sequence")

    batch_size = embeds.size(0)
    if attention_mask is not None and attention_mask.size(0) != batch_size:
        attention_mask = attention_mask.expand(batch_size, -1)

    logits = model(inputs_embeds=embeds, attention_mask=attention_mask).logits  # [B, T, V]
    # Logits at position t score the token at position t + 1; normalise only that step.
    step_logprobs = F.log_softmax(logits[:, target_index - 1, :], dim=-1)  # [B, V]
    return step_logprobs[:, target_token_id]  # (B,)

from captum.attr import IntegratedGradients

def compute_ig_matrix_for_generation(input_ids, prompt_len, n_steps=32, internal_bs=1):
    """
    Integrated-Gradients attribution of every generated token to every prompt token.

    For each generated position the scalar being explained is the log-probability
    of that token (see ``token_logprob_from_embeds``); the per-position attribution
    is the L2 norm over the embedding dimension.

    Args:
        input_ids: ``[1, T]`` ids of prompt + continuation.
        prompt_len: Number of leading prompt tokens; must satisfy ``1 <= prompt_len <= T``.
        n_steps: Number of IG interpolation steps.
        internal_bs: Captum ``internal_batch_size``.
            NOTE(review): keep this at 1 unless ``token_logprob_from_embeds`` returns
            one value per batch row.

    Returns:
        ``(ig_matrix, prompt_tokens, gen_tokens)`` where ``ig_matrix`` is a float32
        array of shape ``(prompt_len, T - prompt_len)``.

    Raises:
        ValueError: if ``prompt_len`` is outside ``[1, T]``.
    """
    T = input_ids.size(1)
    if not 1 <= prompt_len <= T:
        # prompt_len == 0 would score position 0, which has no preceding context.
        raise ValueError(f"prompt_len must be in [1, {T}], got {prompt_len}")

    attention_mask = torch.ones_like(input_ids, device=input_ids.device)
    gen_start = prompt_len
    gen_len = T - gen_start

    all_tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
    prompt_tokens = all_tokens[:prompt_len]
    gen_tokens = all_tokens[gen_start:]

    # embeddings to attribute + baseline (both on device)
    inputs_embeds, baseline = build_baseline_and_inputs(input_ids, prompt_len)

    ig_matrix = np.zeros((prompt_len, gen_len), dtype=np.float32)

    for j_abs in range(gen_start, T):
        target_token_id = input_ids[0, j_abs].item()

        # forward that returns the log-prob tensor Captum differentiates
        def forward_scalar(embeds):
            return token_logprob_from_embeds(
                embeds, attention_mask, j_abs, target_token_id
            )

        # The convergence delta was never used, so it is not requested: this skips
        # the extra forward passes Captum needs to compute it.
        attributions = IntegratedGradients(forward_scalar).attribute(
            inputs=inputs_embeds,
            baselines=baseline,
            n_steps=n_steps,
            internal_batch_size=internal_bs,
        )

        tok_scores = attributions.norm(dim=-1).squeeze(0)  # [T]
        ig_matrix[:, j_abs - gen_start] = tok_scores[:prompt_len].detach().cpu().numpy()

    return ig_matrix, prompt_tokens, gen_tokens


# Plotting utilities

def plot_heatmap(ig_matrix, prompt_tokens, gen_tokens, title="Token Attribution Heatmap (IG)"):
    """
    Show ``ig_matrix`` as an image: rows are prompt tokens, columns are generated tokens.

    The figure grows with the number of tokens, capped at 18 x 10 inches.
    """
    n_gen, n_prompt = len(gen_tokens), len(prompt_tokens)
    fig_width = min(18, 2 + 0.3 * n_gen)
    fig_height = min(10, 2 + 0.25 * n_prompt)

    fig, ax = plt.subplots(figsize=(fig_width, fig_height))
    image = ax.imshow(ig_matrix, aspect='auto', interpolation='nearest')
    fig.colorbar(image, ax=ax, fraction=0.046, pad=0.04)

    ax.set_xticks(np.arange(n_gen))
    ax.set_xticklabels(gen_tokens, rotation=90)
    ax.set_yticks(np.arange(n_prompt))
    ax.set_yticklabels(prompt_tokens)

    ax.set_xlabel("Generated tokens")
    ax.set_ylabel("Prompt tokens (incl. tag)")
    ax.set_title(title)
    fig.tight_layout()
    plt.show()

def plot_prompt_importance(ig_matrix, prompt_tokens, title="Prompt Token Importance (sum over generated)"):
    """
    Bar chart of each prompt token's total attribution.

    The score of a prompt token is the sum of its row of ``ig_matrix`` (i.e. over
    all generated tokens).
    """
    positions = range(len(prompt_tokens))
    row_totals = ig_matrix.sum(axis=1)

    fig, ax = plt.subplots(figsize=(max(6, 0.4 * len(prompt_tokens)), 3.5))
    ax.bar(positions, row_totals)
    ax.set_xticks(positions)
    ax.set_xticklabels(prompt_tokens, rotation=90)
    ax.set_ylabel("IG magnitude (sum over generated)")
    ax.set_title(title)
    fig.tight_layout()
    plt.show()

# One-call helper

def explain_generation_with_ig(prompt, max_new_tokens=40, n_steps=32, seed=42):
    """
    Generate one continuation for ``prompt`` and visualise its IG attributions.

    Prints the generated text, then shows the prompt-by-generated heatmap and the
    per-prompt-token importance bar chart. Returns nothing.

    Args:
        prompt: Prompt text (the author tag, optionally followed by more text).
        max_new_tokens: Number of tokens to sample.
        n_steps: Integrated-Gradients interpolation steps.
        seed: Sampling seed forwarded to ``generate_once``.
    """
    # NOTE: this expects the generate_once variant that returns (ids, text).
    input_ids, full_text = generate_once(prompt, max_new_tokens=max_new_tokens, seed=seed)
    # Same tokenizer call generate_once uses for the prompt, so the split point
    # matches the generated ids.
    prompt_len = len(tokenizer(prompt)["input_ids"])

    ig_matrix, prompt_tokens, gen_tokens = compute_ig_matrix_for_generation(
        input_ids.to(device), prompt_len, n_steps=n_steps
    )
    # generate_once already decoded these ids with skip_special_tokens=True.
    print("FULL TEXT:\n", full_text)
    plot_heatmap(ig_matrix, prompt_tokens, gen_tokens,
                 title=f"IG Heatmap | prompt='{prompt.strip()}'")
    plot_prompt_importance(ig_matrix, prompt_tokens)

In [None]:
# Example run: IG explanation for one sampled continuation of the "<0>" (Dickens) tag.

# The prompt starts with an author tag ("<0>", "<1>", ...); here it is the bare tag, no trailing space.

explain_generation_with_ig("<0>", max_new_tokens=40, n_steps=32, seed=123)

In [None]:
# Example run: IG explanation for one sampled continuation of the "<1>" (Austen) tag.

# The prompt starts with an author tag ("<0>", "<1>", ...); here it is the bare tag, no trailing space.

explain_generation_with_ig("<1>", max_new_tokens=40, n_steps=32, seed=123)

In [None]:
# Example run: IG explanation for one sampled continuation of the "<2>" (Twain) tag.

# The prompt starts with an author tag ("<0>", "<1>", ...); here it is the bare tag, no trailing space.

explain_generation_with_ig("<2>", max_new_tokens=40, n_steps=32, seed=123)

In [None]:
# Example run: IG explanation for one sampled continuation of the "<3>" (Alcott) tag.

# The prompt starts with an author tag ("<0>", "<1>", ...); here it is the bare tag, no trailing space.

explain_generation_with_ig("<3>", max_new_tokens=40, n_steps=32, seed=123)

In [None]:
# Example run: IG explanation for one sampled continuation of the "<4>" (Melville) tag.

# The prompt starts with an author tag ("<0>", "<1>", ...); here it is the bare tag, no trailing space.

explain_generation_with_ig("<4>", max_new_tokens=40, n_steps=32, seed=123)

In [None]:
# Attention Head Visualizations

In [None]:
# Resolve the device first: the model is moved onto it on the next line, so the
# cell no longer depends on `device` left over from an earlier cell.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = GPTNeoForCausalLM.from_pretrained(drive_model_path).to(device).eval()
tokenizer = GPT2Tokenizer.from_pretrained(drive_tokenizer_path)
tokenizer.pad_token = tokenizer.eos_token  # reuse EOS as the padding token


In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

@torch.no_grad()
def generate_once(prompt, max_new_tokens=40, temperature=0.8, top_p=0.95, seed=123):
    """
    Sample one continuation of ``prompt`` from the global ``model``.

    Redefines the earlier ``generate_once`` with a default seed of 123. The torch
    RNG is re-seeded with ``seed`` before sampling.

    Returns:
        ``(ids, text)``: the full sequence with a leading batch axis, ``[1, T]``,
        and the same sequence decoded with special tokens skipped.
    """
    torch.manual_seed(seed)
    encoded_prompt = tokenizer(prompt, return_tensors="pt").to(device)

    sampling_kwargs = dict(
        do_sample=True,
        top_p=top_p,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
    )
    sequence = model.generate(**encoded_prompt, **sampling_kwargs)[0]  # drop the batch axis

    decoded = tokenizer.decode(sequence, skip_special_tokens=True)
    return sequence.unsqueeze(0), decoded  # [1, T], string

@torch.no_grad()
def forward_with_attn(input_ids):
    """
    Run the global ``model`` on ``input_ids`` and return its attention weights.

    Returns:
        ``out.attentions`` from the forward pass: one entry per layer, each
        ``[1, num_heads, T, T]``.
    """
    return model(
        input_ids=input_ids,
        output_attentions=True,
        output_hidden_states=False,
    ).attentions


In [None]:
def tag_span_indices(prompt_text):
    """
    Return the token positions covered by the prompt.

    The whole prompt is treated as the tag, so the result is
    ``[0, 1, ..., n - 1]`` where ``n`` is the number of tokens the tokenizer
    produces for ``prompt_text`` (a prompt such as "<0> " may yield several).
    """
    encoded_ids = tokenizer(prompt_text, return_tensors="pt")["input_ids"]  # [1, n]
    n_prompt_tokens = encoded_ids.size(1)
    return [*range(n_prompt_tokens)]


In [None]:
def tag_len_for(tag_text):
    """Number of tokens the global tokenizer produces for ``tag_text``."""
    encoded = tokenizer(tag_text)
    return len(encoded["input_ids"])


In [None]:
def plot_attn_matrix(mat, tokens, title):
    """
    Show a square attention matrix as a heatmap with token labels on both axes.

    Rows are the attending (query) tokens and columns the attended-to (key)
    tokens. The figure side grows with the token count, capped at 18 inches.
    """
    n_tokens = len(tokens)
    side = min(18, 2 + 0.25*n_tokens)
    tick_positions = range(n_tokens)

    fig, ax = plt.subplots(figsize=(side, side))
    image = ax.imshow(mat, interpolation="nearest", aspect="auto", cmap="Blues")
    fig.colorbar(image, ax=ax, fraction=0.046, pad=0.04)

    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tokens, rotation=90)
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(tokens)

    ax.set_xlabel("Key (attended-to) tokens →")
    ax.set_ylabel("Query (attending) tokens →")
    ax.set_title(title)
    fig.tight_layout()
    plt.show()

def tokens_for(ids):
    """Token strings for the first row of ``ids`` (a ``[batch, T]`` id tensor)."""
    first_row = ids[0].tolist()
    return tokenizer.convert_ids_to_tokens(first_row)


In [None]:
@torch.no_grad()
def visualize_attention_overview(prompt, max_new_tokens=40, seed=123, layers_to_show=(0, 6, 12, 18, 23)):
    """
    Sample one continuation and plot head-averaged attention for selected layers,
    followed by a per-layer profile of attention to and from the tag.

    Args:
        prompt: Prompt text; every prompt token is treated as part of the tag
            (see ``tag_span_indices``).
        max_new_tokens: Number of tokens to sample.
        seed: Sampling seed forwarded to ``generate_once``.
        layers_to_show: Layer indices to draw as heatmaps. Negative indices count
            from the end; indices the model does not have are skipped with a
            printed message instead of raising part-way through.
    """
    ids, _ = generate_once(prompt, max_new_tokens=max_new_tokens, seed=seed)
    ids, toks = trim_to_nonpad(ids)   # drop padding before the forward pass
    attns = forward_with_attn(ids.to(device))  # tuple(num_layers) of [1,H,T,T]

    print("FULL TEXT:\n", tokenizer.decode(ids[0], skip_special_tokens=True))

    tag_idx = tag_span_indices(prompt)
    tag_mask = np.zeros(len(toks), dtype=bool)
    tag_mask[tag_idx] = True

    # Head-averaged [T,T] matrix per layer, computed once and reused by both views.
    layer_means = [a.mean(dim=1).squeeze(0).detach().cpu().numpy() for a in attns]
    num_layers = len(layer_means)

    # Layer-avg heatmaps
    for L in layers_to_show:
        if not -num_layers <= L < num_layers:
            print(f"Skipping layer {L}: the model returned {num_layers} layers")
            continue
        plot_attn_matrix(layer_means[L], toks, title=f"Layer {L} — attention (avg over heads)")

    # Attention to/from tag profile
    to_tag = [A[:, tag_mask].mean() for A in layer_means]
    # NOTE(review): attention rows are presumably softmax-normalised, in which case the
    # mean over a full row is 1/T and this curve is flat across layers — confirm whether
    # a different "from tag" statistic was intended.
    from_tag = [A[tag_mask, :].mean() for A in layer_means]

    Ls = np.arange(num_layers)
    plt.figure(figsize=(10,4))
    plt.plot(Ls, to_tag, marker="o", label="to tag (←)")
    plt.plot(Ls, from_tag, marker="o", label="from tag (→)")
    plt.xlabel("Layer"); plt.ylabel("Mean attention weight")
    plt.title("Attention to/from TAG across layers")
    plt.legend(); plt.grid(True, alpha=0.3)
    plt.tight_layout(); plt.show()


In [None]:
@torch.no_grad()
def visualize_layer_heads(prompt, layer=12, max_new_tokens=40, seed=123, head_limit=None):
    """
    Sample one continuation and draw a grid of per-head attention maps for one layer.

    All panels share one colour scale (0 to the largest weight shown), so the
    single colorbar is valid for every head.

    Args:
        prompt: Prompt text.
        layer: Index of the layer to display.
        max_new_tokens: Number of tokens to sample.
        seed: Sampling seed forwarded to ``generate_once``.
        head_limit: Show only the first ``head_limit`` heads; ``None`` shows all.

    Raises:
        ValueError: if there is no head to draw (e.g. ``head_limit < 1``).
    """
    ids, _ = generate_once(prompt, max_new_tokens=max_new_tokens, seed=seed)
    ids, toks = trim_to_nonpad(ids)   # drop padding before the forward pass
    attns = forward_with_attn(ids.to(device))
    A = attns[layer].squeeze(0).detach().cpu().numpy()  # [H,T,T]
    H, T, _ = A.shape
    if head_limit is not None:
        H = min(H, head_limit)
        A = A[:H]
    if H < 1:
        raise ValueError("head_limit must be at least 1")

    cols = min(4, H)
    rows = int(np.ceil(H / cols))
    fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 4*rows))
    axes = np.array(axes).reshape(-1)

    # One normalisation for every panel; otherwise the colorbar would describe
    # only the last head drawn.
    shared_vmax = float(A.max())

    for h in range(H):
        ax = axes[h]
        im = ax.imshow(A[h], interpolation="nearest", aspect="auto", cmap="Blues",
                       vmin=0.0, vmax=shared_vmax)
        ax.set_title(f"Layer {layer} · Head {h}")
        ax.set_xticks(range(T)); ax.set_xticklabels(toks, rotation=90, fontsize=8)
        ax.set_yticks(range(T)); ax.set_yticklabels(toks, fontsize=8)
    # Hide the unused cells of the grid.
    for k in range(H, len(axes)): axes[k].axis("off")
    fig.colorbar(im, ax=axes[:H].tolist(), fraction=0.015, pad=0.01)
    plt.tight_layout(); plt.show()


In [None]:
def trim_to_nonpad(input_ids):
    """
    Trim ids and token strings at the first padding token.

    A position counts as padding when its id equals ``tokenizer.pad_token_id``
    (this notebook sets the pad token to EOS, so a trailing EOS is trimmed too)
    or when its token string is the literal "[PAD]".

    Args:
        input_ids: ``[1, T]`` id tensor; only row 0 is inspected.

    Returns:
        ``(input_ids, toks)``: the ids cut to ``[:, :cutoff]`` and the matching
        token strings; both are returned unchanged when no padding is found.
    """
    ids = input_ids[0].tolist()
    toks = tokenizer.convert_ids_to_tokens(ids)
    pad_id = getattr(tokenizer, "pad_token_id", None)

    cutoff = next(
        (pos for pos, (tok_id, tok) in enumerate(zip(ids, toks))
         if tok == "[PAD]" or (pad_id is not None and tok_id == pad_id)),
        None,
    )
    if cutoff is not None:
        toks = toks[:cutoff]
        input_ids = input_ids[:, :cutoff]
    return input_ids, toks

In [None]:
@torch.no_grad()
def plot_to_tag_for_head(prompt, layer=12, head=0, max_new_tokens=40, seed=123):
    """
    For one attention head, plot how strongly each token attends to the tag.

    A continuation is sampled, the ``[T,T]`` attention map of ``layer``/``head``
    is taken, and for every query token the mean weight over the tag columns
    (all prompt positions) is shown as a bar.
    """
    sampled_ids, _ = generate_once(prompt, max_new_tokens=max_new_tokens, seed=seed)
    sampled_ids, toks = trim_to_nonpad(sampled_ids)
    head_attn = forward_with_attn(sampled_ids.to(device))[layer][0, head]  # [T,T]

    # Boolean column selector marking the prompt (tag) positions.
    is_tag_col = torch.zeros(head_attn.size(1), dtype=torch.bool, device=head_attn.device)
    is_tag_col[tag_span_indices(prompt)] = True
    per_query_to_tag = head_attn[:, is_tag_col].mean(dim=1).detach().cpu().numpy()

    positions = range(len(toks))
    fig, ax = plt.subplots(figsize=(max(8, 0.35*len(toks)), 3.5))
    ax.bar(positions, per_query_to_tag)
    ax.set_xticks(positions)
    ax.set_xticklabels(toks, rotation=90)
    ax.set_ylabel("Attention → TAG")
    ax.set_title(f"Layer {layer}, Head {head}")
    fig.tight_layout()
    plt.show()


In [None]:
# Each call below samples its own continuation; all three use seed=123 and 40 new tokens
# (B and C via their defaults), so they presumably describe the same text — verify if it matters.

# A) Overview across selected layers + tag profile
visualize_attention_overview("<1> ", max_new_tokens=40, seed=123, layers_to_show=(0, 8, 16, 23))

# B) Grid of heads for one layer (first 6 heads of layer 12)
visualize_layer_heads("<1> ", layer=12, head_limit=6)

# C) Focused: which tokens attend most to the tag in a specific head?
plot_to_tag_for_head("<1> ", layer=12, head=3)


In [None]:
# Counterfactual Comparison Plots

In [None]:
# Rebind the global tokenizer to one loaded through AutoTokenizer with use_fast=True.
# NOTE(review): AutoTokenizer is not imported in the cells shown here — confirm an earlier cell imports it.
tokenizer = AutoTokenizer.from_pretrained(drive_tokenizer_path, use_fast=True)
tokenizer.pad_token = tokenizer.eos_token  # reuse EOS as the padding token

In [None]:
@torch.no_grad()
def generate_once(prompt, max_new_tokens=60, temperature=0.8, top_p=0.95, seed=123):
    """
    Sample one continuation of ``prompt`` and return it as decoded text only.

    NOTE: this shadows the earlier ``generate_once`` definitions, which return
    ``(ids, text)``. Callers that unpack two values (e.g.
    ``visualize_attention_overview``) will fail if run after this cell.

    Returns:
        The prompt plus continuation, decoded with special tokens skipped.
    """
    torch.manual_seed(seed)
    encoded_prompt = tokenizer(prompt, return_tensors="pt").to(device)

    sampling_kwargs = dict(
        do_sample=True,
        top_p=top_p,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
    )
    sequence = model.generate(**encoded_prompt, **sampling_kwargs)[0]  # drop the batch axis
    return tokenizer.decode(sequence, skip_special_tokens=True)

def _token_strings(input_ids):
    return tokenizer.convert_ids_to_tokens(input_ids[0].tolist())

def _group_tokens_into_words(token_strs):
    """
    Heuristic for GPT-2 byte-BPE:
    - Start a new word when a token string begins with a space ' '.
    - Otherwise, append to the current word.
    Returns: words, word_indices (len=number of tokens, mapping token->word_id)
    """
    words, word_indices = [], []
    cur = ""
    wid = -1
    for s in token_strs:
        # Decode byte-BPE artifacts for display only
        piece = tokenizer.convert_tokens_to_string([s])
        starts_new = piece.startswith(" ")
        if starts_new or wid == -1:
            if cur:
                words.append(cur)
            cur = piece.lstrip()
            wid += 1
        else:
            cur += piece
        word_indices.append(wid)
    if cur:
        words.append(cur)
    return words, word_indices

def score_with_tag(tag, continuation_text):
    """
    Score the exact string ``tag + ' ' + continuation_text`` under the global model.

    Token log-probabilities of the continuation (everything after the tag tokens)
    are summed into words using the byte-BPE spacing heuristic of
    ``_group_tokens_into_words``.

    Returns:
        ``(word_scores, total_logprob, token_count)``: a list of
        ``(word, summed_logprob)`` pairs, the sum of all continuation token
        log-probs, and the number of continuation tokens scored.
    """
    scored_text = f"{tag} {continuation_text}".strip()
    encoded = tokenizer(scored_text, return_tensors="pt").to(device)
    input_ids, attn = encoded["input_ids"], encoded["attention_mask"]  # each [1, T]

    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attn).logits  # [1,T,V]
        next_token_logprobs = F.log_softmax(logits[:, :-1, :], dim=-1)
        # Position i predicts token i + 1, so the targets are the ids shifted left by one.
        targets = input_ids[:, 1:].unsqueeze(-1)
        tok_lp = next_token_logprobs.gather(-1, targets).squeeze(-1)[0]   # [T-1]

    # The continuation starts right after the tag tokens; its first target is
    # predicted from position tag_len - 1.
    tag_len = len(tokenizer(tag)["input_ids"])
    first_prediction = tag_len - 1
    cont_tok_lp = tok_lp[first_prediction:]            # [N_cont_tokens]
    cont_token_ids = input_ids[:, tag_len:].cpu()      # tokens that were predicted

    words, word_idx = _group_tokens_into_words(_token_strings(cont_token_ids))

    # Accumulate each token's log-prob into the word it belongs to.
    n_words = len(words)
    contrib = np.zeros(n_words, dtype=float)
    lp_values = cont_tok_lp.cpu().numpy()
    for position in range(len(lp_values)):
        w = word_idx[position]
        if w < 0 or w >= n_words:
            continue
        contrib[w] += lp_values[position]

    word_scores = [(word, score) for word, score in zip(words, contrib.tolist())]
    return word_scores, float(cont_tok_lp.sum()), int(cont_tok_lp.numel())

def make_counterfactual_plots(tag_a, tag_b, suffix="", max_new_tokens=60, seed=123):
    """
    Generate text under ``tag_a`` and compare how likely that same continuation
    is under ``tag_a`` versus ``tag_b``.

    Draws a per-word bar chart of the log-prob difference (A minus B) and its
    cumulative curve, prints a summary, and returns the numbers.

    Args:
        tag_a: Tag used for generation and as condition A.
        tag_b: Tag used as the counterfactual condition B.
        suffix: Optional text placed after ``tag_a`` in the generation prompt.
        max_new_tokens: Number of tokens to sample.
        seed: Sampling seed forwarded to ``generate_once``.

    Returns:
        dict with keys ``continuation``, ``words``, ``delta_per_word``,
        ``total_logprob_A``, ``total_logprob_B``, ``total_delta``, ``ppl_A``, ``ppl_B``.
    """
    # Generate with tag A, then strip the first occurrence of the tag so the
    # continuation can be re-scored under either tag.
    generation_prompt = f"{tag_a} {suffix}".strip()
    gen_text = generate_once(generation_prompt, max_new_tokens=max_new_tokens, seed=seed)
    cont = gen_text.replace(tag_a, "", 1).strip()

    scored_a, total_A, nA = score_with_tag(tag_a, cont)
    scored_b, total_B, nB = score_with_tag(tag_b, cont)

    words = [pair[0] for pair in scored_a]
    delta = np.array([pair[1] for pair in scored_a]) - np.array([pair[1] for pair in scored_b])

    positions = range(len(words))
    fig_width = min(14, 0.5*len(words)+4)

    # Per-word preference
    fig, ax = plt.subplots(figsize=(fig_width, 4))
    ax.bar(positions, delta)
    ax.set_xticks(positions)
    ax.set_xticklabels(words, rotation=90)
    ax.set_ylabel("Δ logP (A − B)")
    ax.set_title(f"Word-level preference: {tag_a} vs {tag_b}")
    fig.tight_layout()
    plt.show()

    # Running total along the continuation
    fig, ax = plt.subplots(figsize=(fig_width, 3.5))
    ax.plot(np.cumsum(delta))
    ax.set_xlabel("Word position →")
    ax.set_ylabel("Cumulative Δ logP (A − B)")
    ax.set_title("Cumulative preference along the continuation")
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    plt.show()

    total_delta = total_A - total_B
    # Per-token perplexity; max(1, n) avoids dividing by zero for an empty continuation.
    ppl_A = np.exp(-total_A / max(1, nA))
    ppl_B = np.exp(-total_B / max(1, nB))
    verdict = 'A preferred' if total_delta>0 else 'B preferred'

    print("=== Counterfactual summary (same continuation) ===")
    print(f"Continuation (first 200 chars): {cont[:200]!r}")
    print(f"Total logP (A={tag_a}): {total_A:.2f}   | tokens: {nA}   | approx ppl: {ppl_A:.2f}")
    print(f"Total logP (B={tag_b}): {total_B:.2f}   | tokens: {nB}   | approx ppl: {ppl_B:.2f}")
    print(f"Δ logP (A − B): {total_delta:.2f}  →  {verdict}")

    return dict(
        continuation=cont,
        words=words,
        delta_per_word=delta.tolist(),
        total_logprob_A=total_A,
        total_logprob_B=total_B,
        total_delta=float(total_delta),
        ppl_A=float(ppl_A),
        ppl_B=float(ppl_B),
    )

In [None]:
# Compare Dickens (<0>) vs Austen (<1>) on the same continuation:
# the text is generated under <0>, then scored under both tags.
res = make_counterfactual_plots("<0>", "<1>", suffix="", max_new_tokens=60, seed=123)

# Try other pairs:
# make_counterfactual_plots("<2>", "<4>", suffix="On the river", seed=7)