In [1]:
import json

In [2]:
# All three test-split files live in the same directory; keep it in one place
# so switching datasets is a single edit.
DATA_DIR = '/home/ml/cadencao/Two-Steps-Summarization/datasets/cnn_dm/corrupted_nodup_files'

test_source_path = DATA_DIR + '/test.source'
test_target_path = DATA_DIR + '/test.target'
test_meta_path = DATA_DIR + '/test.metadata'

# Predictions file to evaluate (relative to the notebook's working directory).
pred_file_path = 'preds/nodup_preds_bm1_cpbest2.hypo'

In [3]:
def read_metadata(file_path):
    """Read a UTF-8 file containing one JSON document per line.

    Args:
        file_path: path of the file to read.

    Returns:
        A list with the parsed JSON value of every line, in file order.

    Raises:
        json.JSONDecodeError: if any line (after stripping surrounding
            whitespace) is not valid JSON — this includes blank lines.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        return [json.loads(raw.strip()) for raw in handle]

In [4]:
def read_summaries(file_path):
    """Read a UTF-8 text file into a list of stripped lines.

    Args:
        file_path: path of the file to read.

    Returns:
        One string per line of the file, in order, with leading and trailing
        whitespace removed. Blank lines are kept as empty strings.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        return [raw.strip() for raw in handle]

In [5]:
# Load the source and target files (one example per line) and report how many
# examples were read.
sources, targets = [
    read_summaries(path) for path in (test_source_path, test_target_path)
]
print(len(sources))

11490


In [6]:
meta_data = read_metadata(test_meta_path)
# Show the record count, then the first record as a sample of the schema.
print(len(meta_data), meta_data[0], sep='\n')

11490
{'claim': 'Marseille prosecutor says "so far no videos weren\'t used in the crash investigation" despite media reports . Journalists at Bild and Paris Match are "very confident" the video clip is real, an editor says . Andreas Lubitz had informed his Lufthansa training school of an episode of severe depression, airline says .', 'label': 'INCORRECT', 'backtranslation': False, 'augmentation': 'NegateSentences', 'augmentation_span': [8, 9], 'noise': False}


In [7]:
# Load the predictions to evaluate, one stripped line per entry. Their count is
# compared against sources/targets/metadata in a later cell.
predicts = read_summaries(pred_file_path)

In [8]:
def post_process(preds):
    """Drop a spurious first character from each prediction.

    The first character is removed when it is a single or double quote, or
    when it is immediately repeated (the first two characters are equal).
    Every other prediction is passed through unchanged.

    Args:
        preds: iterable of prediction strings.

    Returns:
        A new list with one processed string per input, in the same order.
        Empty and one-character predictions are handled without raising
        (indexing p[0] / p[1] directly would raise IndexError on them).
    """
    processed_preds = []
    for p in preds:
        # p[:1] is '' for an empty string, so this is False instead of raising.
        starts_with_quote = p[:1] in ('"', "'")
        # Only look at p[1] when there really is a second character.
        doubled_first_char = len(p) > 1 and p[0] == p[1]
        if starts_with_quote or doubled_first_char:
            processed_preds.append(p[1:])
        else:
            processed_preds.append(p)
    return processed_preds

In [9]:
# Apply post_process (drops a leading quote or doubled first character).
# NOTE(review): this rebinds `predicts` from itself, so re-running the cell
# applies post_process a second time. Re-run the loading cell first.
predicts = post_process(predicts)

In [10]:
print(len(predicts))
# The evaluation loops zip these four lists together, and zip() truncates
# silently on a length mismatch, so fail loudly here and say which one is off.
assert len(sources) == len(targets) == len(meta_data) == len(predicts), (
    'misaligned inputs: sources={}, targets={}, meta_data={}, predicts={}'.format(
        len(sources), len(targets), len(meta_data), len(predicts)))

11490


In [11]:
# Inspect one example from the middle of the test set.
index = len(sources) // 2
print('- Metadata:')
print(meta_data[index])
print('- Corrupted:')
# Show only the text before the first '</s>'. partition() returns the whole
# string when the marker is absent, whereas slicing with find()'s -1 result
# would silently drop the final character.
print(sources[index].partition('</s>')[0])
print('- Target:')
print(targets[index])
print('- Prediction:')
print(predicts[index])

- Metadata:
{'claim': "Gertrude Weaver became the world's oldest person last week\xa0following the death of a 117-year-old woman in Japan . Waver died from complications due to pneumonia in Camden . She attributed her long life to treating others well and eating her own cooking . Weaver was born in Arkansas in 1898 and worked as a domestic helper . 115-year-old Jeralean Talley, of Detroit, is not now the world's oldest person .", 'label': 'INCORRECT', 'backtranslation': False, 'augmentation': 'NegateSentences', 'augmentation_span': [71, 72], 'noise': False}
- Corrupted:
Gertrude Weaver became the world's oldest person last week following the death of a 117-year-old woman in Japan . Waver died from complications due to pneumonia in Camden . She attributed her long life to treating others well and eating her own cooking . Weaver was born in Arkansas in 1898 and worked as a domestic helper . 115-year-old Jeralean Talley, of Detroit, is not now the world's oldest person . 
- Target:
Gertru

#### Classification Accuracy

In [12]:
# classification accuracy

In [13]:
# Build binary labels for the classification view of the task:
#   true label = 1 when the metadata says "CORRECT", else 0;
#   predicted label = 1 when the prediction reproduces the source text
#   (i.e. nothing was changed), else 0.
true_labels, pred_labels = [], []
for m, p, s in zip(meta_data, predicts, sources):
    true_labels.append(1 if m['label'] == "CORRECT" else 0)

    # Text before the first '</s>'. partition() keeps the whole string when
    # the marker is missing; the previous `s[: s.find('</s>') - 1]` slice cut
    # real characters when the marker was absent, at index 0, or not preceded
    # by a space. Tokens are whitespace-split below, so a trailing space
    # before the marker does not matter.
    corrupted_text = s.partition('</s>')[0]

    # Case-insensitive, whitespace-insensitive comparison.
    p_tokens = p.lower().split()
    s_tokens = corrupted_text.lower().split()

    # Source left untouched -> predict "correct".
    pred_labels.append(1 if p_tokens == s_tokens else 0)

In [14]:
# Preview the first ten gold labels, then the first ten predicted labels, and
# confirm both lists have the same length.
for labels in (true_labels, pred_labels):
    print(labels[:10])
assert len(pred_labels) == len(true_labels)

[0, 1, 1, 0, 0, 1, 1, 0, 1, 1]
[0, 1, 1, 0, 0, 0, 1, 0, 1, 1]


In [15]:
from sklearn.metrics import classification_report

In [16]:
# Display names for the two classes: label 0 is built for non-"CORRECT"
# metadata and label 1 for "CORRECT".
# NOTE(review): relies on classification_report pairing target_names with the
# labels in sorted order — confirm against the scikit-learn docs.
target_names = ['Corrupted', 'Not Corrupted']
report = classification_report(true_labels, pred_labels, target_names=target_names)
print(report)

               precision    recall  f1-score   support

    Corrupted       0.78      0.95      0.86      5780
Not Corrupted       0.93      0.73      0.82      5710

    micro avg       0.84      0.84      0.84     11490
    macro avg       0.86      0.84      0.84     11490
 weighted avg       0.86      0.84      0.84     11490



#### String Matching Accuracy

In [17]:
# Exact-match evaluation: a prediction counts as a match when its
# whitespace-split tokens equal the target's (case-sensitive).
# match_num     - total number of matches
# clean_match   - per-example match flags where metadata label is 'CORRECT'
# corrupt_match - per-example match flags for every other label
match_num = 0
clean_match, corrupt_match = [], []

# `sources` is not needed for this metric; the earlier length assert ensures
# the three lists zipped here are aligned.
for m, t, p in zip(meta_data, targets, predicts):
    t_tokens = t.split()
    p_tokens = p.split()

    pred_correct = (t_tokens == p_tokens)
    if pred_correct:
        match_num += 1

    if m['label'] == 'CORRECT':
        clean_match.append(pred_correct)
    else:
        corrupt_match.append(pred_correct)

In [18]:
# Overall exact-match rate: matches divided by the number of targets.
# NOTE(review): the inline figure below does not equal the displayed output.
# It is presumably left over from an earlier run — confirm or remove.
match_num / len(targets)  # 0.6504786771105309

0.6537859007832898

In [19]:
# Exact-match rate restricted to examples whose metadata label is 'CORRECT'
# (True counts as 1 in the sum).
sum(clean_match) / len(clean_match)

0.7094570928196147

In [20]:
# Exact-match rate over the remaining examples (metadata label is not
# 'CORRECT').
sum(corrupt_match) / len(corrupt_match)

0.5987889273356402