In [6]:
import json
import os
from pprint import pprint

import numpy as np
from gensim import models
from nltk.stem import WordNetLemmatizer
from nltk.stem.porter import PorterStemmer
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, precision_recall_fscore_support
from sklearn.model_selection import KFold
from sklearn.model_selection import cross_val_score
from sklearn.svm import SVC
from sklearn.utils import class_weight

from src.data_processing import print_progress, scores_as_list, load_data, load_cmu, load_task3_data
from src.pronunciations import phonetic_distance
from src.pun_algorithms import is_Tom_Swifty, word_sentence_similarity

In [7]:
# Unpack everything load_data() returns; only task1/task2/task3 are used by
# the cells visible in this notebook section.
(task1, task2, task3,
 min_pairs, strings, pun_strings) = load_data()

In [8]:
# Candidate rankings for the 'phonetic_filter_no_pos' configuration (the same
# kind of object the evaluation loop passes to `measures`).
# NOTE(review): `sub_rankings` is not read by any later visible cell.
sub_rankings = load_task3_data(
    'phonetic_filter_no_pos',
    task1, task2, task3,
)

 |████████████████████████████████████████████████████████████████████████████████████████████████████| 99.9% 

In [9]:
import itertools


def measures(sub_rankings, gold=None):
    """Score ranked translation candidates against the task-3 gold sense tags.

    ``sub_rankings[i]`` is aligned by index with ``gold[i]``. Each non-empty
    entry maps a key to a ranked list of candidate translations, and each
    candidate is a dict with a ``'derivations'`` collection.

    Key selection:
        A key equal to one of the gold targets (``gold[i]['sense_tags'][0]``)
        is used directly. Otherwise, the key with the largest
        ``phonetic_distance`` to any target is used.

    Correctness:
        A candidate is correct when any gold sense (``gold[i]['sense_tags'][1]``)
        appears in its derivations. Only the top candidate counts toward
        precision and recall.

    Args:
        sub_rankings: sequence of candidate dicts, one per gold instance.
            Falsy entries, and entries whose chosen ranking is empty, count
            as "no guess".
        gold: gold annotations, each with a ``'sense_tags'`` pair. Defaults
            to the notebook-global ``task3`` so existing ``measures(subs)``
            calls keep working.

    Returns:
        dict with ``coverage``, ``precision``, ``recall``, ``f1-score`` and
        ``mrrank``. ``mrrank`` is the mean reciprocal rank of the first
        correct candidate, averaged over all gold instances. Any ratio whose
        denominator is zero is reported as 0.0 instead of raising.
    """
    if gold is None:
        gold = task3  # notebook global populated by load_data()

    total = len(gold)
    guesses, tp = 0, 0
    mrrank = 0.0

    for i, sub in enumerate(sub_rankings):
        if not sub:
            continue

        targets = gold[i]['sense_tags'][0]
        senses = gold[i]['sense_tags'][1]

        # Prefer an exact key match with one of the gold targets.
        translation = next((sub[t] for t in targets if t in sub), None)

        if not translation:
            # Fall back to the key with the highest phonetic_distance score
            # over all (key, target) pairs.
            # NOTE(review): taking the *maximum* is only right if
            # phonetic_distance behaves like a similarity -- confirm.
            scored = [(phonetic_distance(k, t), k)
                      for k, t in itertools.product(sub.keys(), targets)]
            if not scored:  # no gold targets: max() would raise ValueError
                continue
            translation = sub[max(scored)[1]]
            if not translation:  # empty ranking: translation[0] would fail
                continue

        # Standard MRR: only the FIRST correct candidate contributes.
        for rank, candidate in enumerate(translation, start=1):
            if any(w in candidate['derivations'] for w in senses):
                mrrank += 1 / rank
                break

        guesses += 1
        if any(w in translation[0]['derivations'] for w in senses):
            tp += 1

    precision = tp / guesses if guesses else 0.0
    recall = tp / total if total else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall else 0.0)
    return {'coverage': guesses / total if total else 0.0,
            'precision': precision,
            'recall': recall,
            'f1-score': f1,
            'mrrank': mrrank / total if total else 0.0}

In [10]:
def row_to_string(phonetic, pos, d):
    """Format one comma-separated results row terminated by a newline.

    The two configuration flags are rendered with ``str()``; the five metrics
    (coverage, precision, recall, f1-score, mrrank) are read from ``d`` in
    that order and printed with three decimals.
    """
    metric_keys = ('coverage', 'precision', 'recall', 'f1-score', 'mrrank')
    cells = [str(phonetic), str(pos)]
    cells.extend("{0:.3f}".format(d[key]) for key in metric_keys)
    return ", ".join(cells) + "\n"

In [5]:
# Each entry: (results directory name, phonetic filter used?, position features used?)
locations = ([('phonetic_filter_no_pos', True, False),
              ('phonetic_filter_with_pos', True, True),
              ('all_trigram_no_pos', False, False),
              ('all_trigram_with_pos', False, True)])

# NOTE(review): the rows written below are CSV despite the ``.json`` extension;
# the path is kept unchanged for existing readers -- consider renaming to .csv.
OUTPUT_PATH = "results/tables/translation.json"
# Create the output directory so a fresh checkout does not hit FileNotFoundError.
os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True)

with open(OUTPUT_PATH, 'w') as f:
    f.write("Phonetic Filter, Use Position, Coverage, Precision, Recall, F1, Mean Reciprocal Rank\n")
    for loc, use_phonetic, use_pos in locations:
        # `subs` / `measure` keep their names: later cells read their final values.
        subs = load_task3_data(loc, task1, task2, task3)
        measure = measures(subs)
        f.write(row_to_string(use_phonetic, use_pos, measure))

NameError: name 'load_task3_data' is not defined

In [39]:
# Number of task-1 instances whose 'pun' field is truthy.
sum(1 for item in task1 if item['pun'])

1271

In [26]:
# Display the metrics dict currently bound to `measure`. It is not computed in
# this cell: it is whatever value the evaluation loop (or a later cell) last
# assigned, so running cells out of order changes what appears here.
measure

{'coverage': 0.8916211293260473,
 'f1-score': 0.4631680308136736,
 'mrrank': 0.6317187857357554,
 'precision': 0.491317671092952,
 'recall': 0.43806921675774135}

In [27]:
# Scratch check of the comma-separated row formatter using placeholder flags.
row_to_string(phonetic="a", pos="b", d=measure)

'a, b, 0.892, 0.491, 0.438, 0.463, 0.632\n'

In [29]:
# NOTE(review): `subs` is not defined in this cell -- it is the value the
# evaluation loop left in the kernel (hidden state). Prefer loading it explicitly.
measure = measures(sub_rankings=subs)

In [30]:
# Display `measure` as it stands after the most recent `measures(subs)` call;
# the result depends on which `subs` was left in the kernel.
measure

{'coverage': 0.8916211293260473,
 'f1-score': 0.4631680308136736,
 'mrrank': 0.6317168144223086,
 'precision': 0.491317671092952,
 'recall': 0.43806921675774135}

In [31]:
# Re-format the recomputed metrics with the same placeholder flags.
# NOTE(review): duplicates an earlier scratch cell; consider removing one.
row_to_string(phonetic="a", pos="b", d=measure)

'a, b, 0.892, 0.491, 0.438, 0.463, 0.632\n'