# TF-IDF Guesser Comparison

This notebook uses the test dataset to compare the accuracy of different versions of the TF-IDF model.

Note that this notebook uses `tqdm`, which can be installed with `pip install tqdm`. `tqdm` draws progress bars so you can see the status of long-running cells.

In [52]:
import os
import pickle
import json
import numpy as np

from tqdm import tqdm_notebook, trange

import sys
sys.path.append("../src")

from qanta.tfidf import TfidfGuesser

In [21]:
# Switch the working directory to ../src before loading the models.
# NOTE(review): presumably TfidfGuesser.load resolves its pickle files
# relative to the current directory -- confirm against qanta/tfidf.py.
# This changes the process-wide working directory, so every relative path
# used later in the notebook is resolved from ../src.
os.chdir("../src")
# Load both model variants: the plain one and the stemming one.
guesser = TfidfGuesser.load(stem=False)
stem_guesser = TfidfGuesser.load(stem=True)


Loading tfidf.pickle guesser




Loading stem-tfidf.pickle guesser


# Load testing data

In [89]:
def load_data(filename, ignore_ratio=0, rebalance=False):
    """Read a questions JSON file into a list of ``(tokens, page)`` examples.

    The file must hold a JSON object with a top-level ``"questions"`` list
    whose entries each provide a ``"text"`` and a ``"page"`` key.

    Args:
        filename: Path of the JSON file to read.
        ignore_ratio: Fraction in ``[0, 1]`` of the questions to drop from
            the *end* of the list. The number kept is
            ``int(len(questions) * (1 - ignore_ratio))``, i.e. truncated.
        rebalance: Accepted for backward compatibility; currently unused.

    Returns:
        A list of ``(q_text, label)`` tuples where ``q_text`` is the question
        text split on whitespace and ``label`` is the question's ``"page"``.

    Raises:
        ValueError: If ``ignore_ratio`` is outside ``[0, 1]``.
    """
    if not 0 <= ignore_ratio <= 1:
        # A ratio above 1 would turn the slice bound negative and silently
        # keep an unrelated subset of the questions, so reject it up front.
        raise ValueError(
            "ignore_ratio must be between 0 and 1, got {!r}".format(ignore_ratio)
        )

    # Decode explicitly as UTF-8 instead of relying on the platform's default
    # encoding, which can mis-decode or reject non-ASCII question text.
    with open(filename, encoding="utf-8") as json_data:
        questions = json.load(json_data)["questions"]

    # The file is no longer needed once parsed; build the examples afterwards.
    keep = int(len(questions) * (1 - ignore_ratio))
    return [(q['text'].split(), q['page']) for q in questions[:keep]]

# Relative path: it is resolved against the current working directory, which
# an earlier cell changes with os.chdir("../src").
test_file = "../data/qanta.test.2018.04.18.json"
# Each entry of test_exs is a (token list, page label) tuple.
test_exs = load_data(test_file)
print("Total questions in dataset: {}".format(len(test_exs)))

Total questions in dataset: 4104


# Perform a single guess

In [34]:
# Re-join the first example's tokens into one string, wrap it in a
# one-element list, and pass 1 as the second argument. The captured output
# shows one list per question, each holding (page, score) tuples.
guesser.guess([" ".join(test_exs[0][0])], 1)

[[('Francis_Bacon', 0.28755560466002594)]]

# Count correct guesses per model

In [87]:
def score(guesser, exes, batch_size=200):
    """Count the examples for which ``guesser``'s first guess is the label.

    Args:
        guesser: Object whose ``guess(questions, 1)`` returns one entry per
            question; ``entry[0][0]`` is taken as the guessed label.
        exes: Sequence of examples where item ``[0]`` is the list of question
            tokens and item ``[1]`` is the expected label.
        batch_size: Number of examples passed to ``guesser.guess`` per call.

    Returns:
        The number of examples whose first guess equals the expected label.
    """
    hits = 0
    batch_starts = range(0, len(exes), batch_size)
    for start in tqdm_notebook(batch_starts, leave=False):
        batch = exes[start: start + batch_size]

        # Separate the batch into model inputs and gold labels.
        texts = [" ".join(example[0]) for example in batch]
        gold = [example[1] for example in batch]

        # Keep only the label of the first guess returned for each question.
        ranked = guesser.guess(texts, 1)
        predicted = np.array([candidates[0][0] for candidates in ranked])

        # Element-wise comparison; summing the booleans counts the matches.
        hits += (predicted == np.array(gold)).sum()

    return hits

# Print the number of correct first guesses (a raw count, not a ratio) for
# each model on the same test examples.
print("Original TFIDF: {}".format(score(guesser, test_exs)))
print("Stemming TFIDF: {}".format(score(stem_guesser, test_exs)))

HBox(children=(IntProgress(value=0, max=21), HTML(value='')))

Original TFIDF: 1923


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))

Stemming TFIDF: 1937
