In [1]:
import json
import nltk
from sklearn.metrics import cohen_kappa_score, balanced_accuracy_score, recall_score, precision_score
import numpy as np

In [2]:
# Read the whole dataset from storysumm.json into `storysumm`, a dict keyed by entry id.
ss_file = 'storysumm.json'

with open(ss_file, 'r') as f:
    storysumm = json.load(f)

# Print the field names of the first entry as a quick schema check.
first_key = list(storysumm)[0]
print(storysumm[first_key].keys())

dict_keys(['label', 'difficulty', 'story', 'summary', 'errors', 'story-id', 'explanations', 'claims', 'split', 'model'])


In [3]:
# Group the dataset entries into named subsets used by the rest of the notebook:
#   'all'          - every entry
#   'val' / 'test' - by the entry's 'split' field (anything that is not 'val' counts as test)
#   'easy' / 'hard'- by the entry's 'difficulty' field (other difficulty values are not bucketed)
splits = {'val': [], 'test': [], 'easy': [], 'hard': [], 'all': []}
# Reverse lookup: space-joined, stripped summary text -> entry id.
summary_to_key = {}

# The loop variable is `ss_id`, not `id`: a top-level `id` would shadow the
# builtin id() for every later cell executed in this kernel.
for ss_id, ss in storysumm.items():
    splits['all'].append(ss)
    # Store the key on the entry itself so later cells can look up predictions by ss['id'].
    ss['id'] = ss_id
    summary_to_key[' '.join(ss['summary']).strip()] = ss_id
    splits['val' if ss['split'] == 'val' else 'test'].append(ss)
    if ss['difficulty'] in ('easy', 'hard'):
        splits[ss['difficulty']].append(ss)
        
def word_count(split, text):
    """Return the mean token count of field `text` over the entries in `split`.

    Args:
        split: iterable of dataset entries (dicts).
        text: name of the field to measure. The 'summary' field is joined
            with single spaces before tokenising (presumably it is stored as
            a list of strings - confirm against the dataset); any other field
            is passed to the tokenizer as-is.

    Returns:
        The mean of the per-entry `nltk.word_tokenize` token counts, rounded
        to the nearest integer.
    """
    token_counts = []
    for entry in split:
        raw = entry[text]
        passage = ' '.join(raw) if text == 'summary' else raw
        token_counts.append(len(nltk.word_tokenize(passage)))
    return round(np.mean(token_counts))

def f_percent(split):
    """Return the mean of the integer-cast 'label' field over `split`, as a percentage.

    Args:
        split: iterable of dataset entries (dicts) with a 'label' field that
            `int()` accepts.

    Returns:
        mean(label) * 100, rounded to one decimal place. The results table
        headers this value "% faith.".
    """
    as_ints = [int(entry['label']) for entry in split]
    return round(np.mean(as_ints)*100, 1)

In [9]:
# Dataset statistics: one row per subset with its size, mean story length,
# mean summary length (both in tokens) and percentage of positive labels.
header = ["split", "count", "story len.", "summary len.", "% faith."]
print(', '.join(header))
print('------------------------------------------------')
for split, split_data in splits.items():
    result = [
        len(split_data),
        word_count(split_data, 'story'),
        word_count(split_data, 'summary'),
        f_percent(split_data),
    ]
    print(split, result)

split, count, story len., summary len., % faith.
------------------------------------------------
val [33, 605, 120, 24.2]
test [63, 844, 149, 44.4]
easy [20, 755, 113, 0.0]
hard [40, 707, 141, 0.0]
all [96, 762, 139, 37.5]


In [10]:
def regular_load(filename):
    """Load one evaluator's predictions from a JSON file.

    The file must contain a JSON object mapping keys to objects. For each
    entry the 'label' field is used when present, otherwise 'probs'.

    Args:
        filename: path of the JSON file to read.

    Returns:
        dict mapping each top-level key to that entry's 'label' (or 'probs') value.

    Raises:
        KeyError: if an entry has neither 'label' nor 'probs'.
    """
    with open(filename, 'r') as handle:
        entries = json.load(handle)
    return {
        key: entry['label' if 'label' in entry else 'probs']
        for key, entry in entries.items()
    }

In [11]:
# Predictions from the stand-alone evaluators, one JSON file each.
alignscore_data = regular_load('evaluators/predicted_labels/alignscore-roberta-large.json')
unieval_data = regular_load('evaluators/predicted_labels/unieval.json')
minicheck_data = regular_load('evaluators/predicted_labels/minicheck-flan-t5-large.json')
fables_data = regular_load('evaluators/predicted_labels/fables-gpt-4-turbo-preview.json')

# Predictions from prompted models: one file per (prompting mode, model) pair,
# stored under the key '<mode>-<model>'.
models = ['claude-3-opus-20240229', 'gpt-4-0125-preview', 'mixtral']
modes = ['justquestion', 'cot']
model_methods = {}
for mode in modes:
    for model in models:
        method_key = f'{mode}-{model}'
        model_methods[method_key] = regular_load(f'evaluators/predicted_labels/{model}/{mode}.json')

In [12]:
def tune_thresh(vals, num=150):
    """Pick the decision threshold that maximises balanced accuracy on the validation split.

    Args:
        vals: mapping from entry id to a score; must contain every id in
            `splits['val']`. Scores are compared against thresholds in
            [0, 1] (presumably probabilities - confirm against the evaluator output).
        num: number of evenly spaced candidate thresholds between 0 and 1 inclusive.

    Returns:
        The smallest candidate threshold reaching the highest balanced
        accuracy, where a score is predicted positive when `score >= threshold`.
        Returns 0 if no candidate scores above 0.
    """
    # Collect (score, gold label) pairs in one pass over the validation entries.
    pairs = [(vals[entry['id']], int(entry['label'])) for entry in splits['val']]
    scores = [score for score, _ in pairs]
    gold = [label for _, label in pairs]

    top_score, best_thresh = 0, 0
    for candidate in np.linspace(0, 1, num):
        predicted = [int(score >= candidate) for score in scores]
        candidate_score = balanced_accuracy_score(gold, predicted)
        # Strict '>' keeps the first (lowest) threshold among ties.
        if candidate_score > top_score:
            top_score, best_thresh = candidate_score, candidate
    return best_thresh

In [17]:
def table_results(vals, source = None, threshold=None):
    """Score one evaluator against the gold labels and format a results-table row.

    Args:
        vals: mapping from entry id to the evaluator's prediction. Without a
            threshold these are used directly as class labels; with a
            threshold they are treated as scores and binarised first.
        source: if given, only entries whose 'split' field equals it are
            scored; None scores every entry in `splits['all']`.
        threshold: optional decision threshold. A score is predicted positive
            when `score >= threshold`, the same rule `tune_thresh` tunes with.
            Any number is accepted, including 0 and plain Python floats.

    Returns:
        A `(row, threshold)` tuple. `row` is a ' | '-joined string of:
        Cohen's kappa, % predicted positive, precision, recall,
        100 - % predicted positive among 'easy' entries,
        100 - % predicted positive among 'hard' entries,
        and balanced accuracy (%). The two difficulty columns are nan when
        the selection contains no entry of that difficulty.

    Raises:
        KeyError: if `vals` lacks the id of a selected entry.
    """
    difficulty, g_labels, p_labels = [], [], []
    for ss in splits['all']:
        if source is None or ss['split'] == source:
            difficulty.append(ss['difficulty'])
            g_labels.append(int(ss['label']))
            p_labels.append(vals[ss['id']])

    # `is not None` so that a tuned threshold of 0 is still applied. The
    # element-wise comparison works for plain floats as well as NumPy scalars
    # (a bare `list > float` raises TypeError).
    if threshold is not None:
        p_labels = [int(score >= threshold) for score in p_labels]

    def flagged_percent(level):
        # Share of entries at this difficulty that were NOT predicted positive.
        preds = [pred for pred, diff in zip(p_labels, difficulty) if diff == level]
        if not preds:
            return float('nan')
        return 100-round(np.mean(preds)*100, 1)

    output = [
        round(cohen_kappa_score(g_labels, p_labels), 2),
        round(np.mean(p_labels)*100),
        round(precision_score(g_labels, p_labels), 2),
        round(recall_score(g_labels, p_labels), 2),
        flagged_percent('easy'),
        flagged_percent('hard'),
        round(balanced_accuracy_score(g_labels, p_labels)*100, 1),
    ]
    return " | ".join([str(x) for x in output]), threshold

In [18]:
# Results table: one row per evaluator, printed as (metrics row, threshold used).
print(', '.join(['method', 'Coh. k', '% faith.', 'prec.', 'rec.', '% easy', '% hard', 'bal. acc.', 'threshold' ]))
print('------------------------------------------------------------------------')
# Prompted models are scored as-is on the full dataset.
for key, val in model_methods.items():
    print(key, table_results(val))
print('-----------------')
print('fables', table_results(fables_data))
print('minicheck', table_results(minicheck_data))
# Threshold-tuned metrics: tune the threshold once per metric on the
# validation split, then report validation and test rows with it.
for metric_name, metric_scores in (('unieval', unieval_data), ('alignscore', alignscore_data)):
    tuned_threshold = tune_thresh(metric_scores)
    for eval_split in ('val', 'test'):
        print(metric_name, table_results(metric_scores, threshold=tuned_threshold, source=eval_split))

method, Coh. k, % faith., prec., rec., % easy, % hard, bal. acc., threshold
------------------------------------------------------------------------
justquestion-claude-3-opus-20240229 ('0.06 | 95 | 0.4 | 1.0 | 20.0 | 2.5 | 54.2', None)
justquestion-gpt-4-0125-preview ('0.11 | 70 | 0.42 | 0.78 | 55.0 | 25.0 | 56.4', None)
justquestion-mixtral ('0.12 | 91 | 0.41 | 1.0 | 15.0 | 15.0 | 57.5', None)
cot-claude-3-opus-20240229 ('0.1 | 90 | 0.41 | 0.97 | 25.0 | 10.0 | 56.1', None)
cot-gpt-4-0125-preview ('0.08 | 94 | 0.4 | 1.0 | 25.0 | 2.5 | 55.0', None)
cot-mixtral ('0.04 | 97 | 0.39 | 1.0 | 0.0 | 7.5 | 52.5', None)
-----------------
fables ('0.33 | 55 | 0.53 | 0.78 | 70.0 | 52.5 | 68.1', None)
minicheck ('0.02 | 16 | 0.4 | 0.17 | 90.0 | 82.5 | 50.8', None)
unieval ('0.25 | 39 | 0.38 | 0.62 | 80.0 | 60.0 | 65.3', 0.8791946308724832)
unieval ('0.04 | 30 | 0.47 | 0.32 | 80.0 | 68.0 | 51.8', 0.8791946308724832)
alignscore ('0.21 | 42 | 0.36 | 0.62 | 80.0 | 53.3 | 63.3', 0.785234899328859)
alig