In [None]:
# noqa
# Imports and notebook configuration.
# NOTE(review): numpy and scipy are not referenced in the cells visible here.
import numpy as np
import scipy
import gensim
import nltk
# Fetch the WordNet corpus that is read below through `nltk.corpus.wordnet`.
nltk.download('wordnet')

# Automatically reload edited project modules (e.g. `environment`) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
# Google's pre-trained Word2Vec vectors (Google News), stored in the binary
# word2vec file format.
W2V_PATH = './GoogleNews-vectors-negative300.bin'
model = gensim.models.KeyedVectors.load_word2vec_format(W2V_PATH, binary=True)

In [None]:
from nltk.corpus import wordnet as wn


def _synset_head(synset):
    """Return the part of a synset name before the first dot, e.g. 'dog.n.01' -> 'dog'."""
    return synset.name().split('.', 1)[0]


# Head words of every noun synset in WordNet.
nouns = {_synset_head(synset) for synset in wn.all_synsets('n')}
# Head words of every synset, any part of speech.
# NOTE(review): despite the name, no frequency filtering is applied here.
frequent_words = {_synset_head(synset) for synset in wn.all_synsets()}


def filter_word(key, vocabulary=None):
    """Return True if ``key`` is a plain lowercase alphabetic word found in a vocabulary.

    Args:
        key: Candidate token (for example a key of the embedding model).
        vocabulary: Container of allowed words. When omitted (``None``), the
            module-level ``frequent_words`` set is used, which preserves the
            original behaviour.

    Returns:
        bool: True when ``key`` consists only of letters, is lowercase, and is a
        member of the vocabulary.
    """
    # isalpha() rejects multi-word tokens ("new_york"), digits and punctuation;
    # islower() rejects tokens with uppercase letters (e.g. capitalised names).
    # The vocabulary is resolved lazily so the global is only read when needed.
    return (
        key.isalpha()
        and key.islower()
        and key in (frequent_words if vocabulary is None else vocabulary)
    )


def read_words(filename, encoding='utf-8'):
    """Yield one normalised word per non-blank line of a text file.

    Each line is stripped of surrounding whitespace and lower-cased. Blank (or
    whitespace-only) lines are skipped, so they do not produce an empty-string
    "word".

    Args:
        filename: Path of a text file with one word per line.
        encoding: Text encoding of the file. It is explicit so the result does
            not depend on the platform's default locale encoding.

    Yields:
        str: The lower-cased word from each non-blank line.
    """
    with open(filename, 'r', encoding=encoding) as f:
        for line in f:
            word = line.strip().lower()
            if word:
                yield word


# Words from the card word list, normalised (stripped, lower-cased) by read_words.
card_frequent = {word for word in read_words('./frequent_words.txt')}


def filter_cards(key):
    """Return True if ``key`` may be used as a card word.

    A card word must pass ``filter_word``, be a WordNet noun head word
    (``nouns``), and appear in the card word list (``card_frequent``).
    """
    if not filter_word(key):
        return False
    if key not in nouns:
        return False
    return key in card_frequent


# Iterate the model's vocabulary in a way that works on gensim 3.x, which has
# `KeyedVectors.vocab`, and on gensim >= 4.0, where `vocab` was removed in
# favour of `key_to_index`.
_model_keys = getattr(model, 'key_to_index', None)
if _model_keys is None:
    _model_keys = model.vocab

# Model words that pass filter_word (lowercase, alphabetic, in WordNet).
unfiltered_vocab = {key for key in _model_keys if filter_word(key)}
# ... further restricted to the card word list.
filtered_vocab = {key for key in unfiltered_vocab if key in card_frequent}
# NOTE(review): `vocab` currently aliases `unfiltered_vocab`; `filtered_vocab`
# is only used to build `card_vocab`.
vocab = unfiltered_vocab
# Card candidates: words from filtered_vocab that also pass filter_cards.
card_vocab = {key for key in filtered_vocab if filter_cards(key)}


In [None]:
# Sizes of the unfiltered vocabulary, the vocabulary handed to the agents, and
# the card vocabulary. While `vocab = unfiltered_vocab`, the first two are equal.
tuple(len(words) for words in (unfiltered_vocab, vocab, card_vocab))

In [None]:
import environment as env
from environment import Team

# The clue giver (`master`), the guesser and the state generator are all built
# from the same embedding model, vocabulary and card vocabulary.
_shared_args = (model, vocab, card_vocab)
master = env.DistanceMaster(*_shared_args)
guesser = env.DistanceGuesser(*_shared_args)
generator = env.StateGenerator(*_shared_args)

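## Generate clue

Create a new game state, keep an untouched deep copy of it (`tru_state`), and ask the master for a clue for the blue team.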

In [None]:
# Generate clue
from copy import deepcopy

print('')

# Create a new game state and keep an independent deep copy of it. The guessing
# cell removes correctly guessed words from `state.blue`; `tru_state` stays
# untouched so its truth string can be printed afterwards.
state = generator.generate_state()
tru_state = deepcopy(state)

print('Hidden:', state.hidden_str)

# Ask the master for a clue (word + number) for the blue team.
print('Giving clue...')
clue = master.give_clue(state, team=Team.BLUE)
print('Clue:', clue.word, clue.number)

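## Generate guesses

The guesser makes up to `clue.number` guesses for the blue team. Each guess found in `state.blue` is removed from it, and the first guess that is not there ends the turn. The clue and the truth string of the untouched state are then printed.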

In [None]:
# Generate guesses
# The guesser gets up to `clue.number` attempts, numbered from 1. A guess that
# is still in `state.blue` is a hit and is removed from the remaining blue
# words; the first guess outside `state.blue` ends the turn.
for iteration in range(1, clue.number + 1):
    guess = guesser.guess(state, clue, iteration, team=Team.BLUE)
    print(' Guess ', iteration, ':', guess)
    if guess not in state.blue:
        break
    state.blue.remove(guess)

print('Whole clue: ', clue)
print('Truth:', tru_state.truth_str)
print('')