In [114]:
import pickle
import time
import random
import numpy as np

from collections import Counter
from annoy import AnnoyIndex
from tqdm.notebook import tqdm
from scipy.spatial.distance import euclidean, pdist, squareform
import scipy
import scipy.stats as stats
from sklearn.cluster import AgglomerativeClustering
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

import seaborn as sns
from matplotlib import pyplot as plt
from typing import *

In [13]:
def read_glove_file(size=None,
                    glove_file='/mnt/Spookley/datasets/glove/glove.6B.50d.txt',
                    expected_lines=400000) -> Dict[str, List[float]]:
    """Load word vectors from a GloVe text file.

    Every line is expected to look like ``word v1 v2 ... vn`` (whitespace
    separated). Lines are skipped when the first token is not alphanumeric
    (e.g. punctuation), when there is no vector after the word, or when any
    token after the first cannot be parsed as a float (multi-word entries).

    Args:
        size: stop after this many words have been collected; ``None`` reads
            the whole file. A value <= 0 yields an empty dict.
        glove_file: path of the GloVe text file to read.
        expected_lines: number of lines in the file; used only as the total
            of the progress bar.

    Returns:
        Dict mapping each kept word to its vector (list of floats), in file
        order.
    """
    w_vecs = {}
    if size is not None and size <= 0:
        # Nothing requested; the equality check below would never fire for 0.
        return w_vecs
    with tqdm(total=expected_lines) as pbar:
        # Be explicit about the encoding so parsing does not depend on the
        # platform's locale default.
        with open(glove_file, encoding='utf-8') as fh:
            # Iterate the handle lazily instead of materialising every line.
            for line in fh:
                pbar.update(1)
                toks = line.strip().split()
                # Blank lines or a word without a vector carry no usable data.
                if len(toks) < 2:
                    continue
                word = toks[0]
                # non-words like punctuation marks have entries, but we don't want those
                if not word.isalnum():
                    continue
                # Some bigrams and trigrams are in the dataset; their extra
                # word tokens fail float parsing, so skip those lines.
                try:
                    vec = [float(s) for s in toks[1:]]
                except ValueError:
                    continue
                w_vecs[word] = vec
                if size is not None and len(w_vecs) >= size:
                    break
    return w_vecs

In [4]:
def build_index(w_vecs: Dict[str, List[float]], n_trees: int = 20) -> Tuple[Dict[int, str], AnnoyIndex]:
    """Build an Annoy nearest-neighbour index (euclidean metric) over ``w_vecs``.

    Items are added in the iteration order of ``w_vecs``; item ``i`` of the
    index is the ``i``-th word, and ``idx_to_word`` records that mapping.

    Args:
        w_vecs: mapping of word -> vector. All vectors are expected to share
            one length.
        n_trees: number of trees passed to ``AnnoyIndex.build``.

    Returns:
        ``(idx_to_word, ann_index)`` where ``idx_to_word`` maps Annoy item ids
        back to words and ``ann_index`` is the built index.
    """
    DEFAULT_VEC_SIZE = 50  # used only when there is no vector to measure
    # Take the dimensionality from the data so files other than the 50-d
    # GloVe set can be indexed as well.
    first_vec = next(iter(w_vecs.values()), None)
    vec_size = len(first_vec) if first_vec is not None else DEFAULT_VEC_SIZE

    idx_to_word = {}
    ann_index = AnnoyIndex(vec_size, 'euclidean')
    with tqdm(total=len(w_vecs)) as pbar:
        for i, (word, vec) in enumerate(w_vecs.items()):
            pbar.update(1)
            ann_index.add_item(i, vec)
            idx_to_word[i] = word
    ann_index.build(n_trees)
    return idx_to_word, ann_index

In [75]:
# Load word vectors from the GloVe file, stopping once 20,000 words have been
# kept; the cells below all work on this subset.
w_vecs = read_glove_file(size=20000)

  0%|          | 0/400000 [00:00<?, ?it/s]

In [76]:
# Stack the vectors into one 2-D array; rows follow the key order of w_vecs.
mat = np.array(list(w_vecs.values()))

In [103]:
# Embed the word vectors in 3 dimensions with t-SNE and report how long the
# fit took. t-SNE is stochastic: pin random_state so that re-running the
# notebook reproduces the same embedding (and the distances printed below).
time_start = time.time()
tsne = TSNE(n_components=3, verbose=0, perplexity=40, n_iter=300, random_state=0)
tsne_mat = tsne.fit_transform(mat)
print('elapsed', time.time()-time_start)

elapsed 113.28300595283508


In [104]:
def get_tsne_pos(word, w_vecs, tsne_mat):
    """Return the row of ``tsne_mat`` belonging to ``word``.

    The row number is the position of ``word`` in the key order of
    ``w_vecs``, so ``tsne_mat`` must have been built from the vectors in that
    same order. Raises ``ValueError`` when ``word`` is not a key of ``w_vecs``.
    """
    vocabulary = list(w_vecs)
    row = vocabulary.index(word)
    return tsne_mat[row, :]

# Look up the t-SNE coordinates of a few probe words.
strawberry, banana, peach, envelope = (
    get_tsne_pos(probe, w_vecs, tsne_mat)
    for probe in ('strawberry', 'banana', 'peach', 'envelope')
)

# Distance from 'strawberry' in the embedded space, one value per line,
# in the order: banana, envelope, peach.
print(*(euclidean(strawberry, other) for other in (banana, envelope, peach)), sep='\n')

5.049500942230225
3.5831453800201416
0.42000818252563477


In [80]:
# Build the Annoy index over the loaded vectors; idx_to_word maps Annoy item
# ids back to words.
idx_to_word, ann_index = build_index(w_vecs)

  0%|          | 0/20000 [00:00<?, ?it/s]

In [105]:
# Ask the index for the 20 items nearest to the 'banana' vector and show the
# corresponding words.
v = w_vecs['banana']
items = ann_index.get_nns_by_vector(v, 20)
print([idx_to_word[item] for item in items])

['banana', 'bananas', 'coconut', 'peanut', 'bean', 'cane', 'peach', 'potato', 'growers', 'nut', 'goat', 'shrimp', 'candy', 'plum', 'spice', 'citrus', 'plantations', 'cocoa', 'pumpkin', 'honey']


In [113]:
# Distance from 'strawberry' in the original vector space, one value per
# line, in the order: peach, banana, envelope.
print(*(euclidean(w_vecs['strawberry'], w_vecs[other])
        for other in ('peach', 'banana', 'envelope')), sep='\n')


3.4086222159613113
4.4893891662675
5.775056670299608


In [7]:
class SemantleGame():
    """One round of a Semantle-style guessing game over a word-vector vocabulary.

    A hidden target word is drawn when the game is created. Every guess is
    answered with whether it equals the target and with the euclidean distance
    between the guessed vector and the target vector.
    """

    def __init__(self, w_vecs, min_rank=1000, max_rank=50000):
        """Draw the hidden target word.

        Args:
            w_vecs: mapping of word -> vector; the target is one of its keys.
            min_rank: start of the slice of ``w_vecs`` keys (in iteration
                order) the target is drawn from.
            max_rank: end (exclusive) of that slice.

        Raises:
            IndexError: if ``w_vecs`` is empty.
        """
        w_list = list(w_vecs.keys())
        # NOTE(review): presumably the leading entries are skipped to avoid
        # very common words as targets — confirm against the data ordering.
        candidates = w_list[min_rank:max_rank]
        if not candidates:
            # Vocabularies shorter than min_rank leave the slice empty; draw
            # from every word instead of failing inside random.choice.
            candidates = w_list
        self.target_word = random.choice(candidates)
        self.target_vec = w_vecs[self.target_word]
        # (word, dist) pairs in the order the guesses were made; read by
        # display_guesses (and included in the __str__ dump).
        self.guesses = []

    def guess(self, word, vec) -> Tuple[bool, float]:
        """Score one guess.

        Args:
            word: the guessed word.
            vec: the vector of the guessed word.

        Returns:
            ``(win, dist)``: ``win`` is True when ``word`` is the target,
            ``dist`` is the euclidean distance from ``vec`` to the target
            vector.
        """
        dist = euclidean(vec, self.target_vec)
        # Remember the guess so display_guesses has something to show.
        self.guesses.append((word, dist))
        return word == self.target_word, dist

    def display_guesses(self):
        """Print every guess made so far as ``word: dist``, closest first."""
        ranked = sorted(self.guesses, key=lambda g: g[1])
        print('\n'.join('{}: {}'.format(word, dist) for word, dist in ranked))

    def __str__(self):
        """Dump every instance attribute as ``name: value``, one per line.

        Note that this includes the hidden target word.
        """
        return '\n'.join('{}: {}'.format(k, v) for k, v in self.__dict__.items())

In [8]:
class SemantleSolver:
    """Guessing strategy for a SemantleGame.

    The solver starts with a phase of random guesses. Afterwards, while its
    best guess is within ``dist_thresh`` of the target, it guesses the nearest
    not-yet-guessed neighbours of that best guess; otherwise it keeps guessing
    at random. No word is ever guessed twice.
    """

    def __init__(self, game: 'SemantleGame', n_random_guesses=10,
                 dist_thresh=float('inf'), n_neighbors=10000):
        """
        Args:
            game: object whose ``guess(word, vec)`` returns ``(win, dist)``.
            n_random_guesses: number of purely random guesses made before the
                solver starts looking near its best guess.
            dist_thresh: neighbours of the best guess are only tried while the
                best distance is <= this value. The default (infinity) means
                "always, once the random phase is over".
            n_neighbors: how many nearest neighbours to request from the index
                when looking near the best guess.
        """
        self.game = game
        self.n_random_guesses = n_random_guesses
        self.dist_thresh = dist_thresh
        self.n_neighbors = n_neighbors
        self.closest_dist = float('inf')
        self.guesses = {}  # Dict[str: (float, int)] of {word: (dist, guess_num)}
        self.best_guess = None

    def _pick_nearby(self, w_vecs, ann_index, idx_to_word):
        """Return the nearest not-yet-guessed neighbour of the best guess, or None."""
        # Query around the best guess itself, not some unrelated vector.
        best_vec = w_vecs[self.best_guess]
        for idx in ann_index.get_nns_by_vector(best_vec, self.n_neighbors):
            candidate = idx_to_word[idx]
            if candidate not in self.guesses:
                return candidate
        return None

    def _pick_random(self, w_vecs):
        """Return a random word that has not been guessed yet."""
        remaining = [w for w in w_vecs if w not in self.guesses]
        if not remaining:
            raise RuntimeError('every word in the vocabulary has already been guessed')
        return random.choice(remaining)

    def guess(self, w_vecs, ann_index, idx_to_word) -> bool:
        """Make one guess and update the solver's state.

        Args:
            w_vecs: mapping of word -> vector (the vocabulary to guess from).
            ann_index: index exposing ``get_nns_by_vector(vector, n)``.
            idx_to_word: mapping from index item id to word.

        Returns:
            True if this guess won the game, False otherwise.

        Raises:
            RuntimeError: if every word of ``w_vecs`` was already guessed.
        """
        # determine next guess
        next_word = None
        g_type = 'random:'
        past_random_phase = len(self.guesses) >= self.n_random_guesses
        if (past_random_phase and self.best_guess is not None
                and self.closest_dist <= self.dist_thresh):
            # 'exploit' - guess words near our best candidate
            next_word = self._pick_nearby(w_vecs, ann_index, idx_to_word)
            if next_word is not None:
                g_type = 'nearby:'
        if next_word is None:
            # Random phase, best guess too far away, or every returned
            # neighbour was already tried.
            next_word = self._pick_random(w_vecs)

        # guess the word
        win, dist = self.game.guess(next_word, w_vecs[next_word])
        self.guesses[next_word] = (dist, len(self.guesses)+1)

        # see if this one's better
        if self.best_guess is None or dist < self.closest_dist:
            print(g_type, next_word, "dist:", dist, "best:",
                  self.best_guess, "best_dist:", self.closest_dist, "guesses:", len(self.guesses))
            self.closest_dist = dist
            self.best_guess = next_word

        if win:
            print("I win!", self.guesses)
            return True
        return False