In [1]:
import json
from src.data_processing import print_progress, scores_as_list, load_data, load_cmu
import numpy as np
from gensim import models
from src.pun_algorithms import is_Tom_Swifty


from sklearn.model_selection import KFold
from sklearn.model_selection import cross_val_score
from sklearn.svm import SVC
from sklearn.metrics import classification_report, precision_recall_fscore_support
from sklearn.linear_model import LogisticRegression
from sklearn.utils import class_weight
from sklearn.ensemble import RandomForestClassifier
from pprint import pprint


import warnings
warnings.filterwarnings('ignore')

In [2]:
# model = models.KeyedVectors.load_word2vec_format("/home/doogy/Data/GoogleNews-vectors-negative300.bin.gz", binary=True)

In [3]:
# Unpack the six values returned by load_data(); task2 and the annotations
# below are the inputs the scoring functions in later cells read.
(task1, task2, task3,
 min_pairs, strings, pun_strings) = load_data()

# Per-item Tom Swifty answers: later cells use an item's first entry as the
# guess whenever that item's list is non-empty.
with open("results/tom_swifties.json") as f:
    tom_swifty_annotations = json.load(f)

In [4]:
def get_subs(path, mappings_path="data/t1-t2-mappings.json"):
    """Load ranked substitution candidates for each mapped task-2 item.

    Parameters
    ----------
    path : str
        Result location handed unchanged to ``scores_as_list``.
    mappings_path : str, optional
        JSON object whose keys and values are parsed as ints. Only the values
        (used as indices into the loaded substitutions) are needed here.
        Defaults to the path this notebook always used.

    Returns
    -------
    list of list
        One entry per mapping, in the JSON object's key order. Each entry holds
        the ``(key, value)`` pairs of ``substitutions[t2]`` sorted by
        ``value[0][1]`` in descending order.
    """
    substitutions = scores_as_list(path)

    with open(mappings_path) as f:
        mappings = {int(k): int(v) for k, v in json.load(f).items()}

    # The int-parsed dict is kept so duplicate keys collapse exactly as before.
    # NOTE(review): value[0][1] is presumably the candidate's score -- confirm
    # against scores_as_list.
    return [
        sorted(substitutions[t2].items(), key=lambda item: item[1][0][1], reverse=True)
        for t2 in mappings.values()
    ]

In [5]:
def row_to_string(phonetic, pos, d, r):
    """Format one comma-separated row of the detection results table.

    Columns, in order:
    - the phonetic-filter flag and the use-position flag (via ``str``);
    - ``d['coverage']``, ``d['precision']``, ``d['recall']`` and ``d['f1-score']``;
    - the reciprocal-rank value ``r``.

    Every numeric column is printed with three decimals, and the row ends
    with a newline.
    """
    numeric = [d['coverage'], d['precision'], d['recall'], d['f1-score'], r]
    cells = [str(phonetic), str(pos)]
    cells.extend("{0:.3f}".format(value) for value in numeric)
    return ", ".join(cells) + "\n"

# Coverage, Precision, Recall, F1

In [6]:
def measures(dataset, gold=None, annotations=None):
    """Score top-1 target-word guesses against the gold task-2 targets.

    The guess for item ``i`` is chosen as follows:
    - If ``annotations[i]`` is non-empty, the guess is ``annotations[i][0][0]``.
    - Otherwise it is the second whitespace-separated token of
      ``dataset[i][0][0]``.

    Items with no usable candidate are skipped, so they count against coverage
    and recall but not precision.

    Parameters
    ----------
    dataset : sequence
        Per-item ranked candidates (as built by ``get_subs``), indexed in the
        same order as ``gold``.
    gold : sequence of dict, optional
        Items carrying a ``'target'`` key. Defaults to the notebook global
        ``task2``.
    annotations : sequence, optional
        Per-item Tom Swifty answers. Defaults to the notebook global
        ``tom_swifty_annotations``.

    Returns
    -------
    dict
        ``coverage``, ``precision``, ``recall`` and ``f1-score``. Any ratio
        whose denominator is zero is reported as 0.0 instead of raising.
    """
    if gold is None:
        gold = task2
    if annotations is None:
        annotations = tom_swifty_annotations

    tp = 0
    guesses = 0
    for i, item in enumerate(gold):
        target = item['target']
        if annotations[i]:
            guess = annotations[i][0][0]
        else:
            try:
                guess = dataset[i][0][0].split()[1]
            except (IndexError, KeyError, AttributeError, TypeError):
                # No (or a malformed) candidate for this item: make no guess.
                continue
        tp += int(target == guess)
        guesses += 1

    total = len(gold)
    precision = tp / guesses if guesses else 0.0
    recall = tp / total if total else 0.0
    # F1 is undefined when nothing is correct; report 0.0 rather than divide by zero.
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0

    return {
        "coverage": guesses / total if total else 0.0,
        "precision": precision,
        "recall": recall,
        "f1-score": f1,
    }

# Mean Reciprocal Rank

In [7]:
def mrr(substitutions, gold=None, annotations=None):
    """Print and return the mean reciprocal rank of the gold targets.

    The ranked list for item ``i`` is chosen as follows:
    - If ``annotations[i]`` is non-empty, it is the first field of each
      annotation entry.
    - Otherwise it is the second whitespace-separated token of each
      candidate key in ``substitutions[i]``.

    Scoring works like this:
    - An item whose target appears at 1-based position ``k`` contributes
      ``1/k``.
    - Items where the target is absent, or whose candidates cannot be parsed,
      contribute 0.
    - The sum is divided by the number of gold items.

    Parameters
    ----------
    substitutions : sequence
        Per-item ranked candidates (as built by ``get_subs``), indexed in the
        same order as ``gold``.
    gold : sequence of dict, optional
        Items carrying a ``'target'`` key. Defaults to the notebook global
        ``task2``.
    annotations : sequence, optional
        Per-item Tom Swifty answers. Defaults to the notebook global
        ``tom_swifty_annotations``.

    Returns
    -------
    float
        The mean reciprocal rank, or 0.0 when ``gold`` is empty.
    """
    if gold is None:
        gold = task2
    if annotations is None:
        annotations = tom_swifty_annotations

    total_rank = 0
    for i, item in enumerate(gold):
        target = item['target']
        if annotations[i]:
            ranks = [r[0] for r in annotations[i]]
        else:
            try:
                ranks = [r[0].split()[1] for r in substitutions[i]]
            except (IndexError, KeyError, AttributeError, TypeError):
                # NOTE(review): a single malformed candidate discards the whole
                # item (kept to preserve previously reported numbers).
                continue
        if target in ranks:
            total_rank += 1 / (ranks.index(target) + 1)

    total = len(gold)
    score = total_rank / total if total else 0.0
    print("Mean Reciprocal Rank: ", score)
    return score

In [8]:
# Each entry: (result location passed to get_subs,
#              phonetic-filter flag, use-position flag).
locations = ([('phonetic_filter_no_pos', True, False),
              ('phonetic_filter_with_pos', True, True),
              ('all_trigram_no_pos', False, False),
              ('all_trigram_with_pos', False, True)])

# NOTE(review): despite the .json extension this file receives CSV text; the
# path is left unchanged in case other code reads the table from here.
with open("results/tables/detection.json", 'w') as f:
    # row_to_string emits 7 columns ending with the MRR value, so the header
    # must name all 7 (the MRR column was previously missing).
    f.write("Phonetic Filter, Use Position, Coverage, Precision, Recall, F1, MRR\n")
    for loc, phon, filt in locations:
        # `filt` is the use-position flag (third tuple field above).
        subs = get_subs(loc)
        measure = measures(subs)
        r = mrr(subs)
        f.write(row_to_string(phon, filt, measure, r))
        

 |████████████████████████████████████████████████████████████████████████████████████████████████████| 99.9% Mean Reciprocal Rank:  0.7144636768948335
 |████████████████████████████████████████████████████████████████████████████████████████████████████| 99.9% Mean Reciprocal Rank:  0.7152504589562025
 |████████████████████████████████████████████████████████████████████████████████████████████████████| 99.9% Mean Reciprocal Rank:  0.8027930763178601
returning from json
Mean Reciprocal Rank:  0.81168371361133
