In [None]:
# %pip (unlike `!pip`) installs into the environment of the running kernel.
%pip install py-readability-metrics

import nltk
from readability import Readability

# NLTK tokenizer resource; presumably required by Readability's sentence
# splitting on recent NLTK versions — confirm if the resource name changes.
nltk.download('punkt_tab')

In [None]:
# %pip targets the running kernel's environment; `!pip` may not.
%pip install transformers

In [None]:
from transformers import AutoTokenizer, pipeline

In [None]:
# Ordinal encoding of CEFR labels (A1=0 ... C2=5) so that per-sentence
# predictions can be averaged into a single number.
cefr_map = {level: rank for rank, level in enumerate(('A1', 'A2', 'B1', 'B2', 'C1', 'C2'))}

# Sentence-level CEFR classifier, run in batches of 16.
cefr_labeler = pipeline(
    task="text-classification",
    model="AbdullahBarayan/ModernBERT-base-doc_sent_en-Cefr",
    batch_size=16,
)

In [None]:
# Clone BLEURT only when no checkout exists yet, so re-running this cell
# does not error on an existing `bleurt/` directory.
![ -d bleurt ] || git clone https://github.com/google-research/bleurt.git
# Install the local checkout into the running kernel's environment.
%pip install ./bleurt

In [None]:
# -nc: skip the download when BLEURT-20.zip is already present (a plain re-run
#      would otherwise save a duplicate as BLEURT-20.zip.1).
# -n:  never overwrite already-extracted files, avoiding unzip's interactive
#      overwrite prompt when the cell is re-run.
!wget -nc https://storage.googleapis.com/bleurt-oss-21/BLEURT-20.zip
!unzip -n BLEURT-20.zip

In [None]:
from bleurt import score
# Checkpoint directory for BLEURT; presumably the folder extracted from
# BLEURT-20.zip in the cell above — confirm the unzip location matches.
checkpoint = "BLEURT-20"
# Scorer constructed once at module level and reused by get_bleurt_score.
bleurt_scorer = score.BleurtScorer(checkpoint)

In [None]:
import pickle

In [None]:
# Run part 1 first to generate these pickles.
# NOTE: pickle.load can execute arbitrary code — only load files you produced.
def _load_pickle(path):
  """Load and return one pickled object, closing the file handle afterwards."""
  with open(path, 'rb') as fh:
    return pickle.load(fh)

reconstruction = _load_pickle('reconstruction.pkl')
embeddings = _load_pickle('embeddings.pkl')
all_sentences = _load_pickle('all_sentences.pkl')

In [None]:
def get_bleurt_score(refs, sys, scorer=None, batch_size=32):
  """Return the mean BLEURT score of candidates `sys` against references `refs`.

  Args:
    refs: reference sentences, aligned one-to-one with `sys`.
    sys: candidate sentences to score.
    scorer: object exposing ``score(references=..., candidates=..., batch_size=...)``
      returning one float per pair; defaults to the module-level ``bleurt_scorer``.
    batch_size: batch size forwarded to the scorer.

  Returns:
    Arithmetic mean of the per-pair scores.

  Raises:
    ValueError: if `refs` and `sys` differ in length, or are empty (a mean over
      zero pairs is undefined).
  """
  # Validate before invoking the (expensive) model.
  if len(refs) != len(sys):
    raise ValueError(
        "refs and sys must be aligned: got %d references and %d candidates"
        % (len(refs), len(sys)))
  if len(sys) == 0:
    raise ValueError("cannot compute a mean BLEURT score over zero sentence pairs")
  if scorer is None:
    scorer = bleurt_scorer
  scores = scorer.score(references=refs, candidates=sys, batch_size=batch_size)
  return sum(scores) / len(scores)

def get_stats(text):
  """Compute readability statistics for a sequence of sentences.

  Args:
    text: sequence of sentence strings. They are newline-joined into one
      document for the Readability metrics, and the sequence itself is passed
      to the CEFR classifier, which yields one label per element.

  Returns:
    Tuple ``(fkgl, cefr, ari)``: Flesch-Kincaid score, mean CEFR level
    under ``cefr_map``, and ARI score.
  """
  # Document-level readability over all sentences at once.
  # NOTE(review): sentence boundaries are detected by Readability itself;
  # lines lacking terminal punctuation may be merged — confirm for these data.
  document = Readability("\n".join(text))
  fkgl_score = document.flesch_kincaid().score
  ari_score = document.ari().score

  # Sentence-level CEFR: map each predicted label to its ordinal, then average.
  predictions = cefr_labeler(text)
  levels = [cefr_map[prediction['label']] for prediction in predictions]
  mean_level = sum(levels) / len(levels)

  return (fkgl_score, mean_level, ari_score)


def get_dataset_stats(src, tgt):
  """Print readability and BLEURT comparisons between two aligned corpora.

  Prints one line per metric in ``get_stats`` order (FKGL, CEFR, ARI), each as
  ``<src value> <tgt value> <src - tgt>``, followed by the mean BLEURT score of
  `tgt` (candidates) against `src` (references).

  Args:
    src: source sentences, used as BLEURT references.
    tgt: target sentences, used as BLEURT candidates.
  """
  source_metrics = get_stats(src)
  target_metrics = get_stats(tgt)
  mean_bleurt = get_bleurt_score(src, tgt)
  for source_value, target_value in zip(source_metrics, target_metrics):
    print(source_value, target_value, source_value - target_value)
  print(mean_bleurt)


In [None]:
# Compare each original corpus against its reconstruction from part 1.
# Each entry: (all_sentences key, side, max sentences or None, reconstruction key).
# NOTE(review): the ASSET rows pair 'asset.valid.simp.0' with '*_train'
# reconstructions — confirm these splits actually correspond.
WIKI_AUTO_LIMIT = 2000  # only the first 2000 wiki_auto sentences are evaluated
evaluation_runs = [
    ('asset.valid.simp.0', 'src', None, 'asset_comp_train'),
    ('asset.valid.simp.0', 'tgt', None, 'asset_simp_train'),
    ('wiki_auto', 'src', WIKI_AUTO_LIMIT, 'wauto_comp_train'),
    ('wiki_auto', 'tgt', WIKI_AUTO_LIMIT, 'wauto_simp_train'),
]
for corpus_key, side, limit, recon_key in evaluation_runs:
    sentences = all_sentences[corpus_key][side]
    if limit is not None:
        sentences = sentences[:limit]
    get_dataset_stats(sentences, reconstruction[recon_key])