In [13]:
import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from transformers import pipeline
from datasets import load_dataset
import evaluate
from tqdm import tqdm
import textstat

In [14]:
# ----------------------------
# PREPROCESSING FUNCTIONS
# ----------------------------
def preprocess_texts(texts):
    """Collapse every whitespace run in each text to one space and trim both ends.

    Args:
        texts: iterable of strings.

    Returns:
        list of normalised strings, in input order.
    """
    # str.split() with no separator already discards leading/trailing
    # whitespace, so no separate strip() is needed.
    return [" ".join(text.split()) for text in texts]

def add_control_token(texts, token="<simplify>"):
    """Prefix each text with a control token, separated by a single space.

    Args:
        texts: iterable of input texts.
        token: marker placed in front of every text.

    Returns:
        list of prefixed strings, in input order.
    """
    template = "{} {}"
    return [template.format(token, text) for text in texts]

In [15]:
# ----------------------------
# CONFIGURATION & MODELS
# ----------------------------
# Candidate checkpoints for the two cascade stages; every grammar model is
# paired with every style model during evaluation.
GRAMMAR_MODELS = ["google/flan-t5-small", "facebook/bart-base"]
STYLE_MODELS = ["google/flan-t5-base", "sshleifer/distilbart-cnn-12-6"]

# transformers pipeline device index: first CUDA GPU when available, -1 = CPU.
if torch.cuda.is_available():
    DEVICE = 0
else:
    DEVICE = -1

# Generation settings shared by every pipeline call.
BATCH_SIZE = 8
PIPELINE_MAX_LENGTH = 256

In [16]:
# ----------------------------
# LOAD STYLE DATASET (WikiAuto-Manual)
# ----------------------------
# Optional cap on the number of test pairs (None = use all). Every combination
# runs two seq2seq models plus BERTScore and GPT-2 perplexity over each pair,
# so a full-split run is slow.
MAX_EVAL_SAMPLES = None


def _pick_column(columns, candidates, role):
    """Return the first name in `candidates` present in `columns`; raise KeyError if none is."""
    for name in candidates:
        if name in columns:
            return name
    raise KeyError("No {} field found in dataset. Columns: {}".format(role, columns))


dataset = load_dataset("chaojiang06/wiki_auto", "manual")

# Extract test split
test_split = dataset["test"]

# Determine source and target fields (first matching candidate wins).
src_field = _pick_column(
    test_split.column_names, ("src", "complex", "normal_sentence", "original"), "source"
)
tgt_field = _pick_column(
    test_split.column_names, ("dst", "simple_sentence", "simple", "correction"), "target"
)

# The WikiAuto "manual" config annotates candidate sentence pairs with an
# alignment label, and pairs that are not aligned are not simplifications of
# each other; scoring against them makes SARI/BLEU/BERTScore meaningless.
# Keep only pairs whose label name is "aligned" (TODO confirm the label set
# against the dataset card if the dataset revision changes).
if "alignment_label" in test_split.column_names:
    label_names = getattr(test_split.features["alignment_label"], "names", None) or []
    if "aligned" in label_names:
        aligned_id = label_names.index("aligned")
        test_split = test_split.filter(
            lambda label: label == aligned_id, input_columns="alignment_label"
        )
    else:
        print("Warning: 'alignment_label' has no 'aligned' class; evaluating all pairs.")

if MAX_EVAL_SAMPLES is not None:
    test_split = test_split.select(range(min(MAX_EVAL_SAMPLES, len(test_split))))

# Materialise as plain lists so downstream slicing / len() work regardless of
# the column type returned by the installed `datasets` version.
# TODO(review): one source may align to several simple sentences (splits); those
# appear as separate rows with one reference each rather than grouped references.
sources = list(test_split[src_field])
references = [[r] for r in test_split[tgt_field]]
assert len(sources) == len(references), "source/reference columns differ in length"

In [17]:
# ----------------------------
# BUILD INFERENCE PIPELINES
# ----------------------------
def _build_text2text_pipes(model_names):
    """Map each model name to a text2text-generation pipeline placed on DEVICE."""
    built = {}
    for name in model_names:
        built[name] = pipeline("text2text-generation", model=name, device=DEVICE)
    return built


# Grammar models are loaded first, then style models.
grammar_pipes = _build_text2text_pipes(GRAMMAR_MODELS)
style_pipes = _build_text2text_pipes(STYLE_MODELS)

Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0


In [18]:
# ----------------------------
# LOAD METRICS
# ----------------------------
# Each `evaluate` metric module is loaded once here and reused for every combo.
bleu, bertscore, sari = (evaluate.load(name) for name in ("bleu", "bertscore", "sari"))
perplexity = evaluate.load("perplexity", module_type="metric")

In [None]:
# ----------------------------
# EVALUATION FUNCTIONS
# ----------------------------
# Grammar outputs depend only on the grammar model and its inputs, not on the
# style model, but evaluate_all pairs every grammar model with every style
# model. Memoise them so each grammar pass runs once. Re-running this cell
# clears the cache.
_GRAMMAR_CACHE = {}


def _grammar_correct(gm, clean_src):
    """Run grammar pipeline `gm` over `clean_src` (memoised); return whitespace-normalised outputs."""
    key = (gm, PIPELINE_MAX_LENGTH, BATCH_SIZE, tuple(clean_src))
    if key not in _GRAMMAR_CACHE:
        g_outs = grammar_pipes[gm](list(clean_src), max_length=PIPELINE_MAX_LENGTH, batch_size=BATCH_SIZE)
        _GRAMMAR_CACHE[key] = preprocess_texts(o["generated_text"] for o in g_outs)
    # Return a copy so callers cannot mutate the cached list.
    return list(_GRAMMAR_CACHE[key])


def evaluate_combination(gm, sm, sources, references):
    """Run one grammar-model -> style-model cascade and score its outputs.

    Args:
        gm: key into ``grammar_pipes`` (grammar-stage model name).
        sm: key into ``style_pipes`` (style-stage model name).
        sources: source sentences.
        references: one list of reference strings per source, aligned with
            ``sources``. SARI and BLEU use all references; BERTScore uses only
            the first reference of each list.

    Returns:
        dict with the two model names plus SARI, BLEU, mean BERTScore F1, mean
        Flesch-Kincaid grade, mean Flesch reading ease and mean GPT-2 perplexity.

    Raises:
        ValueError: if ``sources`` and ``references`` differ in length or are
            empty (checked up front, before any expensive generation runs).
    """
    if len(sources) != len(references):
        raise ValueError(
            "sources and references must have the same length "
            f"({len(sources)} != {len(references)})"
        )
    if len(sources) == 0:
        raise ValueError("cannot evaluate an empty set of sources")

    # Preprocess and correct grammar.
    # NOTE(review): inputs get no task prefix. flan-t5 is presumably
    # instruction-driven and bart-base presumably not fine-tuned for
    # correction, so this stage may not correct anything. Confirm intent.
    clean_src = preprocess_texts(sources)
    g_clean = _grammar_correct(gm, clean_src)
    print("Grammar done!")

    # Simplify style. The control token is a plain-text input marker only.
    # Strip it from outputs in case the model copies it through, so it cannot
    # distort the scores.
    control_token = "<simplify>"
    style_inputs = add_control_token(g_clean, token=control_token)
    s_outs = style_pipes[sm](style_inputs, max_length=PIPELINE_MAX_LENGTH, batch_size=BATCH_SIZE)
    preds = preprocess_texts(o["generated_text"].replace(control_token, " ") for o in s_outs)
    print("Styling done!")

    # Reference-based metrics
    sari_score = sari.compute(sources=clean_src, predictions=preds, references=references)["sari"]
    bleu_score = bleu.compute(predictions=preds, references=references)["bleu"]
    bert_res = bertscore.compute(predictions=preds, references=[r[0] for r in references], lang="en")
    bert_f1 = np.mean(bert_res["f1"])

    # Readability metrics via textstat, averaged over all outputs
    fkgl = float(np.mean([textstat.flesch_kincaid_grade(p) for p in preds]))
    fre = float(np.mean([textstat.flesch_reading_ease(p) for p in preds]))

    # The perplexity metric asserts that every text is at least one token long,
    # so blank generations are excluded from this metric only.
    ppl_inputs = [p for p in preds if p]
    if ppl_inputs:
        ppl = perplexity.compute(model_id="gpt2", predictions=ppl_inputs)["mean_perplexity"]
    else:
        ppl = float("nan")

    return {
        'grammar_model': gm,
        'style_model': sm,
        'sari': sari_score,
        'bleu': bleu_score,
        'bert_f1': bert_f1,
        'fkgl': fkgl,
        'flesch': fre,
        'perplexity': ppl
    }

In [None]:
def evaluate_all(sources, references):
    """Score every (grammar model, style model) pair and collect the results.

    Combinations are visited grammar-major, in GRAMMAR_MODELS x STYLE_MODELS
    order, with a tqdm progress bar.

    Returns:
        DataFrame with one row per combination (the dicts from evaluate_combination).
    """
    combos = [(g, s) for g in GRAMMAR_MODELS for s in STYLE_MODELS]
    rows = [
        evaluate_combination(g, s, sources, references)
        for g, s in tqdm(combos, desc="Evaluating combos", unit="combo")
    ]
    return pd.DataFrame(rows)

In [None]:
# ----------------------------
# RUN EVALUATION
# ----------------------------
# Score every grammar x style combination on the loaded test pairs.
df = evaluate_all(sources=sources, references=references)

In [20]:
# Accumulator for ad-hoc single-combination runs.
results = list()

In [21]:
# Smoke test: run one combination on a small slice so the whole cascade
# (both pipelines + every metric) can be checked in minutes. Running it on the
# full split only duplicates the first combination of evaluate_all.
SMOKE_TEST_SIZE = 32
res = evaluate_combination(
    GRAMMAR_MODELS[0],
    STYLE_MODELS[0],
    sources[:SMOKE_TEST_SIZE],
    references[:SMOKE_TEST_SIZE],
)
results.append(res)

KeyboardInterrupt: 

In [None]:
# Sort and display
df = df.sort_values('sari', ascending=False).reset_index(drop=True)
print(df)

In [None]:
# ----------------------------
# PLOT RESULTS
# ----------------------------
def plot_metric_by_combo(results_df, metric, ylabel):
    """Draw a grouped bar chart of one metric: x = grammar model, one bar per style model.

    Args:
        results_df: DataFrame with 'grammar_model', 'style_model' and `metric`
            columns, one row per combination.
        metric: name of the column to plot (e.g. 'sari').
        ylabel: y-axis label, also used in the title.

    Returns:
        The grammar_model x style_model pivot table that was plotted.
    """
    pivot = results_df.pivot(index="grammar_model", columns="style_model", values=metric)
    x = np.arange(len(pivot.index))
    # Size bars from the pivot itself (not STYLE_MODELS), so a partial results
    # frame still lines bars up with their tick labels.
    n_styles = max(len(pivot.columns), 1)
    width = 0.8 / n_styles

    fig, ax = plt.subplots(figsize=(8, 5))
    for i, sm in enumerate(pivot.columns):
        ax.bar(x + i * width, pivot[sm], width, label=sm)
    ax.set_xticks(x + width * (n_styles - 1) / 2)
    ax.set_xticklabels(pivot.index.tolist(), rotation=45, ha="right")
    ax.set_xlabel("Grammar Model")
    ax.set_ylabel(ylabel)
    ax.set_title(f"{ylabel} by Pipeline Combo")
    ax.legend(title="Style Model")
    fig.tight_layout()
    plt.show()
    return pivot


pivot_sari = plot_metric_by_combo(df, "sari", "SARI")
pivot_bleu = plot_metric_by_combo(df, "bleu", "BLEU")

KeyboardInterrupt: 

In [None]:
# ----------------------------
# SAVE RESULTS
# ----------------------------
# Write the per-combination metrics to a CSV in the current working directory.
results_filename = "pipeline_evaluation_results.csv"
csv_path = os.path.join(os.getcwd(), results_filename)
df.to_csv(path_or_buf=csv_path, index=False)
print(f"Results saved to {csv_path}")