In [None]:
from gensim import downloader
from gensim.similarities.annoy import AnnoyIndexer
import os
import ipywidgets as widgets
from IPython.display import display
import json

In [None]:
# Choose which Semantle variant to play ('normal' or 'junior').
# The selected value is used later as the key into the targets JSON.
version_widget = widgets.Dropdown(
    options=['normal', 'junior'],
    value='junior',
    description='Semantle Version:',
    disabled=False,
    # Long descriptions are truncated to the default label width; let the
    # label size itself so the full text stays visible.
    style={'description_width': 'initial'},
)
display(version_widget)

In [None]:
# Load the mapping of game version -> target word.
# The encoding is given explicitly so parsing does not depend on the
# platform's default text encoding.
with open('./targets.json', 'r', encoding='utf-8') as file:
    data = json.load(file)

# NOTE: the dropdown is sampled once, when this cell runs; re-run this cell
# after changing the selection to pick up the new version.
target = data[version_widget.value]

In [None]:
# "\r" returns the cursor to the start of the line so the next print
# overwrites this progress message.
print("Downloading word2vec model", end="\r")

model = downloader.load('word2vec-google-news-300')

# The trailing space pads this message to the length of the one it overwrites.
print("Downloaded word2vec model ")

fname = './tmp/mymodel.index'

# Reuse a previously saved index only when it is complete. gensim's
# AnnoyIndexer is expected to store a companion '<fname>.d' metadata file next
# to the index and to need both on load (see gensim's annoy module); if either
# is missing, rebuild instead of failing inside load().
if os.path.exists(fname) and os.path.exists(fname + '.d'):
    print("Loading random forest for approximate nearest neighbors")
    annoy_index = AnnoyIndexer()
    annoy_index.load(fname)
    # A loaded indexer has no model attached yet; attach the one loaded above.
    annoy_index.model = model
else:
    print("Creating random forest for approximate nearest neighbors")
    # 10 trees are being used in this example
    annoy_index = AnnoyIndexer(model, 10)
    # Persist index to disk; make sure the cache directory exists first,
    # otherwise saving fails on a fresh checkout.
    os.makedirs(os.path.dirname(fname), exist_ok=True)
    annoy_index.save(fname)

print("All set!")

In [None]:
# Shared game state, used by the guess handler and the display cell.
output = widgets.Output()                         # area where the history is printed
text_input = widgets.Text(description="Guess: ")  # box the player types guesses into
input_values = []                                 # history lines, newest first

In [None]:
def on_text_submit(sender):
    """Process the current guess and redraw the guess history.

    Reads the guess from ``text_input``, prepends one result line to
    ``input_values`` and reprints the whole history into ``output``:

    * correct guess: a SUCCESS line is added and the text box is disabled;
    * otherwise: the three approximate nearest neighbours of the vector
      ``model[target] - model[guess]`` are shown as a hint;
    * a word missing from the model yields a "not found" line.

    Args:
        sender: The object passed by the widget callback machinery. When the
            function is registered through ``observe`` this is a change dict;
            changes to traits other than ``value`` are ignored.
    """
    # Disabling the text box below is itself a trait change. Without this
    # guard an observer registered for all traits would re-enter this handler
    # and record the same guess twice.
    if isinstance(sender, dict) and sender.get("name", "value") != "value":
        return

    guess = text_input.value.strip().lower()
    if not guess:
        # Nothing typed (e.g. the box was cleared); do not count it as a guess.
        return

    output_str = f"({len(input_values) + 1})"
    if guess == target:
        input_values.insert(0, f"{output_str} SUCCESS {target}!")
        text_input.disabled = True
    else:
        try:
            approx_direction = model[target] - model[guess]
            approximate_neighbors = model.most_similar(
                [approx_direction], topn=3, indexer=annoy_index
            )
        except KeyError:
            # Raised by the model lookup when a word is not in its vocabulary.
            input_values.insert(0, f"{output_str} {guess} not found")
        except Exception as error:
            # Keep the game running, but do not mislabel other failures as
            # a missing word.
            input_values.insert(0, f"{output_str} {guess} failed: {error}")
        else:
            # Each neighbour is a (word, similarity) pair; only the word is shown.
            approximate_neighbors_str = " ".join(
                neighbor[0].lower() for neighbor in approximate_neighbors
            )
            input_values.insert(0, f"{output_str} {approximate_neighbors_str}")

    output.clear_output()
    with output:
        print("\n".join(input_values))



In [None]:
# Only sync the value when the user commits the text (e.g. pressing Enter),
# not on every keystroke.
text_input.continuous_update = False
# Observe the 'value' trait only. Without `names`, the handler fires for every
# trait change, including the `disabled` flag the handler itself sets on
# success, which makes it run a second time for the same guess.
text_input.observe(on_text_submit, names='value')

In [None]:
# Render the guess box followed by the history area.
display(text_input, output)