In [None]:
from utils import *

# `load_data` is presumably provided by the `from utils import *` above.
# Later cells index each record by 'original', 'simplified', 'system' (and
# 'score'), so it is assumed to return a list of dicts with those keys —
# TODO confirm against utils.
data = load_data('../data/salsa_test.json')

### Define Metrics

In [None]:
import nltk
from utils.sari import *
import bert_score as bs
from bert_score import BERTScorer
from utils.util import avg

# English BERTScore with baseline rescaling; used by BERTSCOREsent below.
scorer = BERTScorer(lang="en", rescale_with_baseline=True)

# NOTE(review): `comet` is never imported explicitly in this notebook, so it
# must come from `from utils import *` or this cell raises NameError — confirm.
# `comet_mqm` is not referenced in the cells shown here: COMET scores are read
# back from '../lens/2-scores-comet-only.json' later. If nothing else uses it,
# this (large) model download/load could be skipped.
comet_model_path = comet.download_model('wmt21-comet-mqm')
comet_mqm = comet.load_from_checkpoint(comet_model_path)

def BLEUsent(hypothesis, references):
    """Sentence-level BLEU of `hypothesis` against one or more `references`.

    Both hypothesis and references are tokenised by splitting on a single
    space character (so consecutive spaces produce empty tokens). The score
    is NLTK's `sentence_bleu` with its default settings.
    """
    tokenised_refs = [reference.split(' ') for reference in references]
    return nltk.translate.bleu_score.sentence_bleu(tokenised_refs, hypothesis.split(' '))

def BERTSCOREsent(hypothesis, references):
    """BERTScore F1 of `hypothesis`, averaged over all `references`.

    The hypothesis is repeated once per reference so that `scorer.score`
    compares it pairwise with each reference. The per-pair F1 values (third
    return value) are combined with `avg` (presumably the arithmetic mean).
    """
    candidates = [hypothesis] * len(references)
    _precision, _recall, f1_per_pair = scorer.score(candidates, references)
    return avg(f1_per_pair.tolist())

### Calculate Sentence Scores

In [None]:
# For test set: patch with external human references
file_path = '../data/test_set_human_written.txt'
with open(file_path, 'r', encoding='utf-8') as f:
    # Records are separated by a blank line; each record is the original
    # sentence followed by its human simplification on the next line.
    # Strip outer newlines first so a trailing blank line does not yield an
    # empty record (which would raise IndexError on the missing second line).
    raw_records = f.read().strip('\n').split('\n\n')

human_references = []
for record in raw_records:
    record_lines = record.split('\n')
    human_references.append({
        'original': record_lines[0],
        'simplified': record_lines[1],
    })

# Group inputs by source sentence once, instead of rescanning `data` and
# `human_references` for every sentence. Dicts keep first-seen order, so the
# row order of `scores` is reproducible across kernels. Iterating a set of
# strings would depend on hash randomisation.
data_by_original = {}
for sent in data:
    data_by_original.setdefault(sent['original'], []).append(sent)

human_refs_by_original = {}
for sent in human_references:
    human_refs_by_original.setdefault(sent['original'], []).append(sent)

# Create a scores list with references and simplifications
scores = []
for orig, sents in data_by_original.items():
    humans = human_refs_by_original.get(orig, [])
    assert len(humans) == 2, f'Expected 2 human references, found {len(humans)} for: {orig!r}'

    systems = [sent for sent in sents if 'Human' not in sent['system']]
    if not systems:
        continue

    # Every system output for this sentence is scored against the same
    # de-duplicated set of human references.
    references = list(set(sent['simplified'] for sent in humans))

    for system in systems:
        # if system['simpeval_scores'] is None:
        #     continue
        # simpeval_score = avg(system['simpeval_scores'])
        scores.append({
            'original': orig,
            'simplified': system['simplified'],
            'references': list(references),  # independent copy per row
            'system': system['system'],
            # 'simpeval': simpeval_score,
        })

    # Add human written outputs with other human output as reference.
    # NOTE(review): disabled for the test set. The external references parsed
    # above have no 'system' key, and the code below reads human['system'].
    # for human in humans:
    #     reference = list(set([sent['simplified'] for sent in humans if sent['simplified'] != human['simplified']]))
    #     scores += [{
    #         'original': orig,
    #         'simplified': human['simplified'],
    #         'references': reference,
    #         'system': human['system'],
    #     }]

print(f"Calculating sensitivity with scores for {len(scores)} sentences")

In [None]:
# For train set: create a scores list with references and simplifications
# NOTE(review): this cell rebinds `scores`, discarding the rows built by the
# test-set cell above. Under Restart & Run All both cells run in order, so only
# this cell's result survives. Run just the cell that matches the split loaded
# into `data` ('../data/salsa_test.json' at the top of the notebook).
scores = []

# Group records by source sentence once, in first-seen order. This is linear
# rather than quadratic and gives a reproducible row order across kernels.
data_by_original = {}
for sent in data:
    data_by_original.setdefault(sent['original'], []).append(sent)

for orig, sents in data_by_original.items():
    humans = [sent for sent in sents if 'Human' in sent['system']]
    systems = [sent for sent in sents if 'Human' not in sent['system']]
    if not systems:
        continue

    # All system outputs for this sentence share the same human references.
    references = list(set(sent['simplified'] for sent in humans))

    for system in systems:
        # if system['simpeval_scores'] is None:
        #     continue
        # simpeval_score = avg(system['simpeval_scores'])
        scores.append({
            'original': orig,
            'simplified': system['simplified'],
            'references': list(references),  # independent copy per row
            'system': system['system'],
            # 'simpeval': simpeval_score,
        })

    # Add each human written output, scored against the *other* human outputs.
    # NOTE(review): a sentence with a single distinct human output gets an
    # empty reference list here.
    for human in humans:
        other_references = list(set(
            sent['simplified'] for sent in humans if sent['simplified'] != human['simplified']
        ))
        scores.append({
            'original': orig,
            'simplified': human['simplified'],
            'references': other_references,
            'system': human['system'],
        })

print(f"Calculating sensitivity with scores for {len(scores)} sentences")

In [None]:
# Add calculated scores
# Compute BLEU, BERTScore and SARI (with its add/keep/delete components) for
# every row. When BLEU raises, the row is skipped. Its metric fields are still
# set to None, so later cells can drop it via `is not None` checks instead of
# failing with KeyError on a missing key.
for score in scores:
    original, prediction, references = score['original'], score['simplified'], score['references']

    # Calculate BLEU
    try:
        score['bleu'] = BLEUsent(prediction, references)
    except Exception:
        print("Skipping corrput sentence...")
        for metric_key in ('bleu', 'bertscore', 'sari_add', 'sari_keep', 'sari_del', 'sari'):
            score[metric_key] = None
        continue

    # Calculate BERTScore
    score['bertscore'] = BERTSCOREsent(prediction, references)

    # Calculate SARI; components=True returns add, keep, delete and overall.
    score['sari_add'], score['sari_keep'], score['sari_del'], score['sari'] = SARIsent(
        original, prediction, references, components=True
    )

In [None]:
# Write scoring setup to json to add COMET scores
# Persist the per-sentence rows (no COMET / LENS yet). COMET is scored
# outside this notebook and merged back from '../lens/2-scores-comet-only.json'.
no_comet_path = '../lens/1-scores-no-comet-lens.json'
with open(no_comet_path, 'w') as out_file:
    json.dump(scores, out_file, indent=4)

In [None]:
# Add COMET scores
# Merge externally computed COMET scores into `scores`, matching rows on
# (original sentence, system name), then persist the result.
with open('../lens/2-scores-comet-only.json', 'r', encoding='utf-8') as f:
    comet_results = json.load(f)

# Build the index once instead of rescanning comet_results for every row.
# setdefault keeps the first row per key, matching the previous `aligned[0]`.
comet_by_key = {}
for row in comet_results:
    comet_by_key.setdefault((row['original'], row['system']), row)

for score in scores:
    match = comet_by_key.get((score['original'], score['system']))
    # Missing COMET scores are stored as None, not 0. A 0 looks like a real
    # (very poor) score and would be included in the correlations, whereas
    # the correlation step drops rows whose metric is None.
    score['comet'] = None if match is None else match.get('comet')

with open('../lens/3-scores-no-lens.json', 'w') as f:
    json.dump(scores, f, indent=4)

In [None]:
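# TODO: Add LENS scores (store them as score['lens']; 'lens' is already listed in `metrics` for the correlation table below)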

## Metric Sensitivity

In [None]:
# Add our scores
# Score conditions, in order: the per-aspect quality scores, then the
# per-aspect error scores, then the aggregates.
conditions = [
    f'{family}_{aspect}'
    for family in ('quality', 'error')
    for aspect in ('content', 'syntax', 'lexical')
] + ['quality', 'error', 'all']

# Rewrite the SimpEval system labels to 'new-wiki-1' in place on each row.
for row in scores:
    system_name = row['system']
    for legacy_label in ('simpeval-22', 'simpeval-ext'):
        system_name = system_name.replace(legacy_label, 'new-wiki-1')
    row['system'] = system_name

# Attach our edit-based sentence scores (overall and per condition) to the
# matching metric rows in `scores`; `our_scores` collects those rows. Rows are
# mutated in place, so the new keys also appear on the dicts inside `scores`.
scores_by_original = {}
for row in scores:
    scores_by_original.setdefault(row['original'], []).append(row)

data_by_original = {}
for sent in data:
    data_by_original.setdefault(sent['original'], []).append(sent)

our_scores = []
for orig, sents in data_by_original.items():
    systems = [sent for sent in sents if 'Human' not in sent['system']]  # Or you can try just GPT-3-Few

    for system in systems:
        # NOTE(review): `in` is a substring test. A metric row whose system name
        # is contained in a longer name also matches, and the first such row
        # wins. Confirm this loose match is intended.
        aligned = [row for row in scores_by_original.get(orig, []) if row['system'] in system['system']]
        if not aligned:
            continue

        n_score = aligned[0]
        # `calculate_sentence_score` / `get_params` presumably come from `from utils import *`.
        n_score['our_score'] = system['score']
        for condition in conditions:
            n_score[f'our_score_{condition}'] = calculate_sentence_score(system, get_params(condition))
            # if 'error' in condition:
            #     n_score[f'our_score_{condition}'] = -n_score[f'our_score_{condition}']

        # NOTE(review): the same row is appended again if several systems match it.
        our_scores.append(n_score)

print(f"Calculating sensitivity with scores for {len(our_scores)} sentences")

In [None]:
# Display label for each score condition, used as row labels in the LaTeX
# table. Order matters: rows are rendered in this insertion order, three per
# rotated group (Quality, Error, All).
condition_name_mapping = dict(
    quality_lexical='Lexical',
    quality_syntax='Syntax',
    quality_content='Conceptual',
    error_lexical='Lexical',
    error_syntax='Syntax',
    error_content='Conceptual',
    error='All Error',
    quality='All Quality',
    all='All Edits',
)

In [None]:
# Calculate Spearman and Pearson correlations between each metric and each of our score conditions
from scipy.stats import kendalltau, pearsonr, spearmanr
import heapq
metrics = ['bleu', 'sari', 'bertscore', 'comet', 'lens']
all_results = []
prec = 3

# Calculate metric correlation.
# all_results[metric_idx][condition_idx] == (spearman_rho, pearson_r), each
# rounded to `prec` decimals, or NaN when fewer than two usable rows exist.
for metric in metrics:
    sys_results = []
    for condition in condition_name_mapping.keys():
        our_key = f'our_score_{condition}'
        # Only rows aligned in the our-scores step carry `our_key`. A metric can
        # be missing entirely ('lens' is not computed yet) or None (failed or
        # unmatched rows); .get() handles both cases.
        pairs = [
            (s[our_key], s[metric])
            for s in scores
            if our_key in s and s.get(metric) is not None
        ]
        if len(pairs) < 2:
            # Correlations are undefined here; pearsonr would raise.
            results = (float('nan'), float('nan'))
        else:
            ours, theirs = zip(*pairs)
            sp = spearmanr(ours, theirs)
            p = pearsonr(ours, theirs)
            results = (round(sp[0], prec), round(p[0], prec))
        sys_results.append(results)
    all_results.append(sys_results)

# Render LaTeX table
import math

# Rotated group labels, one per block of three condition rows. Raw strings
# avoid the invalid '\m' escape while producing the same LaTeX text.
delimiters = [
    r'\multirow{3}{*}{\rotatebox[origin=c]{90}{Quality}}',
    r'\multirow{3}{*}{\rotatebox[origin=c]{90}{Error}}',
    r'\multirow{3}{*}{\rotatebox[origin=c]{90}{All}}',
]


def _is_missing(value):
    """True for None or NaN correlation values (rendered as '---')."""
    return value is None or (isinstance(value, float) and math.isnan(value))


def _top_two(values):
    """Return the largest and second-largest non-missing values (None if absent)."""
    ranked = sorted((v for v in values if not _is_missing(v)), reverse=True)
    ranked += [None, None]
    return ranked[0], ranked[1]


def _format_cell(value, best, second):
    """Format one correlation: '---' if missing, bold if best, underlined if runner-up."""
    if _is_missing(value):
        return '---'
    text = f'{value:.{prec}f}'
    if value == best:
        return f'\\textbf{{{text}}}'
    if value == second:
        return f'\\underline{{{text}}}'
    return text


out = ''
for i, condition in enumerate(condition_name_mapping.keys()):
    line = ''

    if i % 3 == 0:
        if i != 0:
            line += '\\midrule\n'
        line += f'{delimiters[i // 3]} '

    line += f'& {condition_name_mapping[condition]} & '

    # Index 0 of each result tuple is Spearman's rho and index 1 is Pearson's r
    # (the correlation cell stores them in that order). Only index 0 is printed.
    rho_best, rho_second = _top_two(x[i][0] for x in all_results)
    cells = [
        _format_cell(all_results[j][i][0], rho_best, rho_second)
        for j in range(len(metrics))
    ]
    # To also print Pearson per metric, format all_results[j][i][1] against
    # _top_two(x[i][1] for x in all_results) and interleave it with `cells`.

    line += ' & '.join(cells) + ' '
    out += line + '\\tabularnewline\n'
print(out)