In [4]:
import re, numpy as np, textstat, torch, requests, json
from typing import List
from nltk.tokenize import sent_tokenize
from transformers import (
    AutoTokenizer, T5ForConditionalGeneration, pipeline
)
import evaluate

In [5]:
DEVICE   = 0 if torch.cuda.is_available() else -1
MAX_TOK  = 512        # CoEdit token cap
OLL_HOST = "http://localhost:11434"   # default Ollama REST host
OLL_MOD  = "mistral:latest"

# ---------- metric objects (loaded once) -------------------
_bleu       = evaluate.load("bleu")
_bertscore  = evaluate.load("bertscore")
_perplexity = evaluate.load("perplexity", module_type="metric")
_sari       = evaluate.load("sari") 

In [6]:
def preprocess_texts(texts):
    return [" ".join(t.strip().split()) for t in texts]

In [7]:
_tok_coedit = AutoTokenizer.from_pretrained("grammarly/coedit-large")
_mod_coedit = T5ForConditionalGeneration.from_pretrained(
    "grammarly/coedit-large"
).to(DEVICE)
_pipe_coedit = pipeline(
    "text2text-generation",
    model=_mod_coedit,
    tokenizer=_tok_coedit,
    device=DEVICE,
    do_sample=False,               # deterministic
    max_length=MAX_TOK
)

Device set to use cuda:0


In [8]:
def _coedit_single(text: str) -> str:
    prompt = ("Paraphrase and improve the clarity, style, and grammar "
              f"of the following text: {text}")
    return _pipe_coedit(prompt)[0]["generated_text"].strip()

In [9]:
def coedit_edit(texts: List[str]) -> List[str]:
    outs = []
    for t in texts:
        # chunk if needed
        cur, chunks = "", []
        for s in sent_tokenize(t):
            if _tok_coedit(cur + " " + s, return_tensors="pt").input_ids.shape[1] < MAX_TOK:
                cur += " " + s
            else:
                chunks.append(cur.strip()); cur = s
        if cur: chunks.append(cur.strip())
        fixed = " ".join(_coedit_single(c) for c in chunks)
        outs.append(fixed)
    return outs

In [10]:
_tag_re = re.compile(r"<fixg>(.*?)</fixg>", re.S)

def _ollama_call(prompt: str) -> str:
    data = {"model": OLL_MOD, "stream": False, "prompt": prompt}
    r = requests.post(f"{OLL_HOST}/api/generate", json=data, timeout=120)
    r.raise_for_status()
    return r.json()["response"]

def mistral_edit(texts: List[str]) -> List[str]:
    outs = []
    for t in texts:
        prm = ("Please fix grammatical errors in this sentence and improve "
               f"its style. Put the result between <fixg> and </fixg> tags.\n\n{t}")
        resp = _ollama_call(prm)
        m = _tag_re.search(resp)
        outs.append(m.group(1).strip() if m else resp.strip())
    return outs

In [11]:
def _pre(txts):        # whitespace normalise
    return [" ".join(x.strip().split()) for x in txts]

def evaluate_outputs(preds: List[str],
                     sources: List[str],
                     references: List[List[str]]) -> dict:
    src = _pre(sources)
    prd = _pre(preds)
    sari  = _sari.compute(sources=src, predictions=prd, references=references)["sari"]
    bleu  = _bleu.compute(predictions=prd, references=references)["bleu"]
    bert  = _bertscore.compute(predictions=prd,
                               references=[r[0] for r in references],
                               lang="en")["f1"]
    fkgl  = float(np.mean([textstat.flesch_kincaid_grade(p) for p in prd]))
    fre   = float(np.mean([textstat.flesch_reading_ease(p)  for p in prd]))
    ppl   = float(np.mean(_perplexity.compute(
                model_id="gpt2", predictions=prd)["perplexities"]))
    return dict(sari=sari, bleu=bleu,
                bert_f1=float(np.mean(bert)),
                fkgl=fkgl, flesch=fre, perplexity=ppl)

In [12]:
if __name__ == "__main__":
    from datasets import load_dataset
    ds   = load_dataset("chaojiang06/wiki_auto", "manual", split="test[:10]")
    src  = ds["normal_sentence"]
    refs = [[r] for r in ds["simple_sentence"]]

    coedit_preds  = coedit_edit(src)
    mistral_preds = mistral_edit(src)

    print("-- COEDIT metrics --")
    print(evaluate_outputs(coedit_preds, src, refs))

    print("\n-- MISTRAL metrics --")
    print(evaluate_outputs(mistral_preds, src, refs))

-- COEDIT metrics --


To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 1/1 [00:01<00:00,  1.86s/it]


{'sari': 37.63946900280527, 'bleu': 0.022973195426686147, 'bert_f1': 0.8749582767486572, 'fkgl': 11.290000000000001, 'flesch': 53.376999999999995, 'perplexity': 43.729872703552246}

-- MISTRAL metrics --


100%|██████████| 1/1 [00:03<00:00,  3.52s/it]

{'sari': 38.79894941991607, 'bleu': 0.014694292423958269, 'bert_f1': 0.8709092020988465, 'fkgl': 12.66, 'flesch': 43.81699999999999, 'perplexity': 41.56548843383789}



