In [146]:
import dataclasses
import pickle
import random
import numpy as np

from annoy import AnnoyIndex
from tqdm.notebook import tqdm
from scipy.spatial.distance import euclidean, pdist, squareform
import scipy
import scipy.stats as stats
from sklearn.decomposition import PCA

from matplotlib import pyplot as plt
from typing import *

In [2]:
def read_glove_file() -> Dict[str, List[float]]:
    """
    Yields the next 'size' vectors in a glove file.
    """
    glove_file = '/mnt/Spookley/datasets/glove/glove.6B.50d.txt'
    w_vecs = {}
    with tqdm(total=400000) as pbar:
        with open(glove_file) as fh:
            for line in fh.readlines():
                pbar.update(1)
                toks = line.strip().split()
                word = toks[0]
                # non-words like punctuation marks have entries, but we don't want those
                if not word.isalnum():  
                    continue
                # Some bigrams and trigrams are in the dataset. Skip those.
                try:
                    float(toks[1])
                except ValueError:
                    continue
                # No errors, parse the line.
                vec = [float(s) for s in toks[1:]]
                w_vecs[word] = vec
    return w_vecs

In [152]:
def build_index(w_vecs: Dict[str, List[float]]) -> Tuple[Dict[int, str], AnnoyIndex]:
    for v in w_vecs.values():
        GLOVE_VEC_SIZE = len(v)
        break
    idx_to_word = {}
    ann_index = AnnoyIndex(GLOVE_VEC_SIZE, 'euclidean')
    with tqdm(total=len(w_vecs)) as pbar:
        for i, w in enumerate(w_vecs.keys()):
            pbar.update(1)
            ann_index.add_item(i, w_vecs[w])
            idx_to_word[i] = w
    ann_index.build(20) # n trees
    return idx_to_word, ann_index

In [315]:
w_vecs = read_glove_file()

  0%|          | 0/400000 [00:00<?, ?it/s]

In [316]:
pca = PCA(n_components=6)
mat_full = np.array([w_vecs[w] for w in w_vecs])
mat = pca.fit_transform(mat_full)

print(mat.shape)
for i, w in enumerate(w_vecs.keys()):
    w_vecs[w] = mat[i, :]

(336158, 6)


In [317]:

idx_to_word, ann_index = build_index(w_vecs)  # fast


  0%|          | 0/336158 [00:00<?, ?it/s]

In [296]:
def random_point_in_dist(point, dist):
    # For when we know the dist but have no idea what direction to travel
    vec = np.random.random((len(point)))
    vec = vec / scipy.linalg.norm(vec)
    vec = vec * dist
    return vec+point


def directed_point_in_dist(p1, p2, p1_dist, p2_dist):
    # Generate a vector using p1 and p2.
    # Check if it will point in the general direction of our target.
    p1p2 = (p1-p2)
    p1p2mag = scipy.linalg.norm(p1p2)
    if p1p2mag < 0.00001:
        return None, 0
    p1p2_unit = p1p2 / p1p2mag
    if p1_dist < p2_dist:
        # p1 is closer to target
        mag = p1_dist
        target_point = p1 + p1p2_unit*mag
        confidence = (p2_dist-p1_dist) / p1p2mag
        assert confidence >= 0
    else:
        # j is closer to target
        # make a vector from j to a target that is dists[j] away
        mag = p2_dist
        target_point = p2 - p1p2_unit*mag
        confidence = (p1_dist-p2_dist) / p1p2mag
        assert confidence >= 0
    return target_point, confidence

In [156]:
# test case - target at [0,6], points on y axis
p1 = np.array([0,0])
p2 = np.array([0,2])
target = np.array([0,6])
d1 = euclidean(p1, target)
d2 = euclidean(p2, target)
print('expect [0,6]', directed_point_in_dist(p1, p2, d1, d2))
print('expect [0,6]', directed_point_in_dist(p2, p1, d2, d1))

expect [0,6] (array([0., 6.]), 1.0)
expect [0,6] (array([0., 6.]), 1.0)


In [132]:
# test case - target at [1,3], points on y axis
p1 = np.array([0,0])
p2 = np.array([0,2])
target = np.array([1,3])
d1 = euclidean(p1, target)
d2 = euclidean(p2, target)
print('expect [1,3]', directed_point_in_dist(p1, p2, d1, d2))
print('expect [1,3]', directed_point_in_dist(p2, p1, d2, d1))

expect [0,3] (array([0.        , 3.41421356]), 0.8740320488976422)
expect [0,3] (array([0.        , 3.41421356]), 0.8740320488976422)


In [141]:
# test case - target at [3,1], points on y axis
p1 = np.array([0,0])
p2 = np.array([0,2])
target = np.array([3,1])
d1 = euclidean(p1, target)
d2 = euclidean(p2, target)
print(d1, d2)
print('expect zero confidence', directed_point_in_dist(p1, p2, d1, d2))
print('expect zero confidence', directed_point_in_dist(p2, p1, d2, d1))

3.1622776601683795 3.1622776601683795
expect zero confidence (array([0.        , 5.16227766]), 0.0)
expect zero confidence (array([ 0.        , -3.16227766]), 0.0)


In [297]:
class SemantleGame():
    def __init__(self, w_vecs):
        w_list = list(w_vecs.keys())
        self.target_word = random.choice(w_list[1000:10000])
        self.target_vec = w_vecs[self.target_word]
        
    def guess(self, word, vec) -> Tuple[bool, float]:
        # construct guess
        dist = euclidean(vec, self.target_vec)
        # check if win
        if word == self.target_word:
            return True, dist
        else:
            return False, dist
    
    def display_guesses(self):
        s = []
        for g in sorted(self.guesses, key = lambda g: g.dist):
            s.append(str(g))
        print('\n'.join(s))
        
    def __str__(self):
        return '\n'.join('{}: {}'.format(k, v) for k, v in self.__dict__.items())

In [318]:
@dataclasses.dataclass
class Guess:
    word: str
    num: int
    dist: float
    
class SemantleSolver:
    
    def __init__(self, game: SemantleGame, n_random_guesses=2):
        self.game = game
        self.n_random_guesses = n_random_guesses
        self.closest_dist = float('inf')
        self.guesses = []  # List[Guess]
        self.guessed_words = set()  # for fast lookup
        self.best_guess = None
        
        self.EXH_THRESH = 0.0001
        self.N_RANDOM = 2
        self.CONF_THRESH = 0.33
        
        self.stats = {
            'grd_high_conf': 0,
            'grd_random_dist': 0,
            'times_gradient': 0,
            'times_exhaustive': 0,
            'times_random': 0,
        }
        
    def _gradient_method(self, w_vecs, ann_index):
        # Use gradient method to get a closer guess.
        p1 = np.array(w_vecs[self.guesses[-1].word])
        p1_dist = self.guesses[-1].dist
        
        # Consider the few most recent points. 
        # Try and find one with a vector through p1 that points towards the target.
        best_point = None
        best_confidence = 0
        for i in range(2, min(10, len(self.guesses))):
            p2 = np.array(w_vecs[self.guesses[-i].word])
            p2_dist = self.guesses[-i].dist
            
            # where does p2->p1 point? and how well aligned is that spot with the target?
            target_point, confidence = directed_point_in_dist(p1, p2, p1_dist, p2_dist)
            if confidence > best_confidence:
                best_confidence = confidence
                best_point = target_point
                
        if best_confidence < self.CONF_THRESH:
            self.stats['grd_random_dist'] += 1
            vec = np.array(w_vecs[self.best_guess])
            best_point = random_point_in_dist(vec, self.closest_dist)
        else:
            self.stats['grd_high_conf'] += 1

        return best_point

    
    def find_next_guess(self, w_vecs, ann_index, idx_to_word) -> bool:
        if len(self.guesses) < self.N_RANDOM:
            self.stats['times_random'] += 1
            next_word = random.choice(list(w_vecs.keys()))
        elif self.closest_dist > self.EXH_THRESH:
            self.stats['times_gradient'] += 1
            v = self._gradient_method(w_vecs, ann_index)
            idxs_near_best = ann_index.get_nns_by_vector(v, 1000)
            for idx in idxs_near_best:
                w = idx_to_word[idx]
                if w not in self.guessed_words:
                    next_word = w
                    break
        else:
            self.stats['times_exhaustive'] += 1
            # We're close enough to start exhaustive search
            v = w_vecs[self.best_guess]
            idxs_near_best = ann_index.get_nns_by_vector(v, 1000)
            for idx in idxs_near_best:
                w = idx_to_word[idx]
                if w not in self.guessed_words:
                    next_word = w
                    break
            
        return next_word

    def make_guess(self, word):
        # guess the word
        win, dist = self.game.guess(word, w_vecs[word])
        self.guessed_words.add(word)
        self.guesses.append(Guess(word=word, dist=dist, num=len(self.guesses)+1))
        
        # see if this one's better
        if self.best_guess is None or dist < self.closest_dist:
            #print(word, round(dist, 3))
            self.closest_dist = dist
            self.best_guess = word
        
        if win:
            #print("I win!")
            return True
        else:
            return False
        

In [320]:
game = SemantleGame(w_vecs)
player = SemantleSolver(game)
print(game.target_word)
won = False
while not won:
    word = player.find_next_guess(w_vecs, ann_index, idx_to_word)
    won = player.make_guess(word)
    g = player.guesses[-1]
    print(g.word, round(g.dist, 3))
    if len(player.guesses) > 5000:
        print('stopped. ')
        print('Best guess:', player.best_guess, 'dist:', player.closest_dist)
        break

print(player.stats)

embraced
gii 3.489
viagem 4.574
med 4.005
doce 4.705
en 5.356
shree 4.725
westin 3.403
winthrop 3.672
du 4.979
suscriptos 4.138
muh 5.351
motherwell 3.037
oven 2.796
browned 3.121
eighteens 3.681
petersburg 3.88
sarajevo 3.384
poland 3.265
washington 3.431
fdch 4.487
fraiche 3.83
germany 2.926
spain 3.084
brussels 3.107
sweden 2.809
geidar 3.704
cholesterol 3.491
1916 2.032
gov 3.075
courtroom 2.227
blacks 1.666
wickets 2.25
lingering 1.738
disputes 1.762
considerable 1.837
boundaries 2.089
receivers 1.523
spacewalking 1.867
defensive 1.831
tackle 2.043
anyone 1.967
murdoch 1.959
cowboys 1.516
wild 1.788
award 2.406
cheaply 1.193
controversies 1.553
defensemen 1.867
judaism 1.375
electrodes 2.334
outstanding 1.28
emotions 1.404
ways 1.984
appropriations 1.826
appear 1.856
fcc 1.866
cannot 1.779
spike 1.763
domestic 1.609
muslims 2.078
ideas 1.832
lighting 1.224
puck 1.346
australians 1.426
bolt 1.224
couch 1.202
heal 0.958
sharper 1.731
christianity 1.422
downfield 1.562
household 0.90

In [272]:
# grid search time.
def run_trial(exh, n_rand, conf_thresh, w_vecs, idx_to_word, ann_index):
    # make a game
    
    game = SemantleGame(w_vecs)
    player = SemantleSolver(game)
    
    player.EXH_THRESH = exh
    player.N_RANDOM = n_rand
    player.CONF_THRESH = conf_thresh
    
    won = False
    while not won:
        word = player.find_next_guess(w_vecs, ann_index, idx_to_word)
        won = player.make_guess(word)
        if len(player.guesses) > 1000:
            break
    return len(player.guesses)

In [290]:

n_dims = [2, 3, 4, 5]
exh_threshes = [0.001]
n_randoms = [2, 3, 4]
conf_threshes = [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
n_trials = 20

params_results = {}

for n_dim in n_dims:
    
    # set up space
    w_vecs = read_glove_file()
    pca = PCA(n_components=n_dim)
    mat_full = np.array([w_vecs[w] for w in w_vecs])
    mat = pca.fit_transform(mat_full)

    for i, w in enumerate(w_vecs.keys()):
        w_vecs[w] = mat[i, :]

    idx_to_word, ann_index = build_index(w_vecs)

    
    for exh in exh_threshes:
        for n_rand in n_randoms:
            for conf_thresh in conf_threshes:
                params = (n_dim, exh, n_rand, conf_thresh)
                for trial in range(n_trials):
                    n_guesses = run_trial(exh, n_rand, conf_thresh, w_vecs, idx_to_word, ann_index)
                    if not params in params_results:
                        params_results[params] = 0
                    params_results[params] += n_guesses
                params_results[params] /= n_trials
                print(params, params_results[params])

print(params_results)

  0%|          | 0/400000 [00:00<?, ?it/s]

  0%|          | 0/336158 [00:00<?, ?it/s]

(2, 0.001, 2, 0.2) 15.9
(2, 0.001, 2, 0.3) 13.65
(2, 0.001, 2, 0.4) 13.9
(2, 0.001, 2, 0.5) 14.0
(2, 0.001, 2, 0.6) 13.25
(2, 0.001, 2, 0.7) 14.15
(2, 0.001, 2, 0.8) 12.35
(2, 0.001, 3, 0.2) 13.5
(2, 0.001, 3, 0.3) 12.1
(2, 0.001, 3, 0.4) 14.1
(2, 0.001, 3, 0.5) 16.55
(2, 0.001, 3, 0.6) 13.75
(2, 0.001, 3, 0.7) 13.15
(2, 0.001, 3, 0.8) 15.25
(2, 0.001, 4, 0.2) 15.55
(2, 0.001, 4, 0.3) 15.95
(2, 0.001, 4, 0.4) 14.9
(2, 0.001, 4, 0.5) 16.8
(2, 0.001, 4, 0.6) 14.65
(2, 0.001, 4, 0.7) 13.5
(2, 0.001, 4, 0.8) 12.9
(2, 0.001, 5, 0.2) 15.45
(2, 0.001, 5, 0.3) 16.25
(2, 0.001, 5, 0.4) 16.3
(2, 0.001, 5, 0.5) 15.5
(2, 0.001, 5, 0.6) 14.25
(2, 0.001, 5, 0.7) 16.9
(2, 0.001, 5, 0.8) 14.95
(2, 0.001, 6, 0.2) 16.35
(2, 0.001, 6, 0.3) 15.5
(2, 0.001, 6, 0.4) 16.55
(2, 0.001, 6, 0.5) 14.9
(2, 0.001, 6, 0.6) 15.95
(2, 0.001, 6, 0.7) 15.8
(2, 0.001, 6, 0.8) 15.3
(2, 0.001, 7, 0.2) 16.2
(2, 0.001, 7, 0.3) 17.95
(2, 0.001, 7, 0.4) 16.2
(2, 0.001, 7, 0.5) 16.0
(2, 0.001, 7, 0.6) 16.1
(2, 0.001, 7, 0.7) 16

  0%|          | 0/400000 [00:00<?, ?it/s]

  0%|          | 0/336158 [00:00<?, ?it/s]

(3, 0.001, 2, 0.2) 27.7
(3, 0.001, 2, 0.3) 34.25
(3, 0.001, 2, 0.4) 32.55
(3, 0.001, 2, 0.5) 26.7
(3, 0.001, 2, 0.6) 22.7
(3, 0.001, 2, 0.7) 23.75
(3, 0.001, 2, 0.8) 28.75
(3, 0.001, 3, 0.2) 29.25
(3, 0.001, 3, 0.3) 26.9
(3, 0.001, 3, 0.4) 28.95
(3, 0.001, 3, 0.5) 28.0
(3, 0.001, 3, 0.6) 24.75
(3, 0.001, 3, 0.7) 24.6
(3, 0.001, 3, 0.8) 27.6
(3, 0.001, 4, 0.2) 34.45
(3, 0.001, 4, 0.3) 24.4
(3, 0.001, 4, 0.4) 29.2
(3, 0.001, 4, 0.5) 28.2
(3, 0.001, 4, 0.6) 28.8
(3, 0.001, 4, 0.7) 20.2
(3, 0.001, 4, 0.8) 23.9
(3, 0.001, 5, 0.2) 28.35
(3, 0.001, 5, 0.3) 30.35
(3, 0.001, 5, 0.4) 30.1
(3, 0.001, 5, 0.5) 27.95
(3, 0.001, 5, 0.6) 27.8
(3, 0.001, 5, 0.7) 23.35
(3, 0.001, 5, 0.8) 28.9
(3, 0.001, 6, 0.2) 29.8
(3, 0.001, 6, 0.3) 28.65
(3, 0.001, 6, 0.4) 26.8
(3, 0.001, 6, 0.5) 29.4
(3, 0.001, 6, 0.6) 29.55
(3, 0.001, 6, 0.7) 24.2
(3, 0.001, 6, 0.8) 23.65
(3, 0.001, 7, 0.2) 29.8
(3, 0.001, 7, 0.3) 42.2
(3, 0.001, 7, 0.4) 31.85
(3, 0.001, 7, 0.5) 32.2
(3, 0.001, 7, 0.6) 27.5
(3, 0.001, 7, 0.7) 26.3


  0%|          | 0/400000 [00:00<?, ?it/s]

  0%|          | 0/336158 [00:00<?, ?it/s]

(4, 0.001, 2, 0.2) 47.95
(4, 0.001, 2, 0.3) 53.0
(4, 0.001, 2, 0.4) 51.05
(4, 0.001, 2, 0.5) 44.15
(4, 0.001, 2, 0.6) 41.5
(4, 0.001, 2, 0.7) 34.0
(4, 0.001, 2, 0.8) 38.0
(4, 0.001, 3, 0.2) 53.9
(4, 0.001, 3, 0.3) 56.35
(4, 0.001, 3, 0.4) 47.8
(4, 0.001, 3, 0.5) 42.35
(4, 0.001, 3, 0.6) 38.7
(4, 0.001, 3, 0.7) 36.1
(4, 0.001, 3, 0.8) 38.6
(4, 0.001, 4, 0.2) 55.0
(4, 0.001, 4, 0.3) 47.2
(4, 0.001, 4, 0.4) 46.35
(4, 0.001, 4, 0.5) 46.45
(4, 0.001, 4, 0.6) 44.3
(4, 0.001, 4, 0.7) 39.15
(4, 0.001, 4, 0.8) 36.1
(4, 0.001, 5, 0.2) 66.0
(4, 0.001, 5, 0.3) 50.0
(4, 0.001, 5, 0.4) 37.1
(4, 0.001, 5, 0.5) 43.5
(4, 0.001, 5, 0.6) 41.75
(4, 0.001, 5, 0.7) 36.75
(4, 0.001, 5, 0.8) 38.6
(4, 0.001, 6, 0.2) 51.0
(4, 0.001, 6, 0.3) 56.5
(4, 0.001, 6, 0.4) 45.9
(4, 0.001, 6, 0.5) 41.25
(4, 0.001, 6, 0.6) 38.7
(4, 0.001, 6, 0.7) 34.0
(4, 0.001, 6, 0.8) 28.35
(4, 0.001, 7, 0.2) 64.65
(4, 0.001, 7, 0.3) 54.05
(4, 0.001, 7, 0.4) 49.2
(4, 0.001, 7, 0.5) 44.6
(4, 0.001, 7, 0.6) 39.4
(4, 0.001, 7, 0.7) 40.9
(4

  0%|          | 0/400000 [00:00<?, ?it/s]

  0%|          | 0/336158 [00:00<?, ?it/s]

(5, 0.001, 2, 0.2) 47.35
(5, 0.001, 2, 0.3) 66.15
(5, 0.001, 2, 0.4) 65.7
(5, 0.001, 2, 0.5) 45.0
(5, 0.001, 2, 0.6) 51.6
(5, 0.001, 2, 0.7) 50.35
(5, 0.001, 2, 0.8) 79.4
(5, 0.001, 3, 0.2) 60.7
(5, 0.001, 3, 0.3) 87.75
(5, 0.001, 3, 0.4) 70.3
(5, 0.001, 3, 0.5) 52.0
(5, 0.001, 3, 0.6) 49.35
(5, 0.001, 3, 0.7) 44.85
(5, 0.001, 3, 0.8) 37.75
(5, 0.001, 4, 0.2) 92.55
(5, 0.001, 4, 0.3) 86.45
(5, 0.001, 4, 0.4) 79.45
(5, 0.001, 4, 0.5) 49.25
(5, 0.001, 4, 0.6) 51.25
(5, 0.001, 4, 0.7) 45.5
(5, 0.001, 4, 0.8) 61.7
(5, 0.001, 5, 0.2) 74.5
(5, 0.001, 5, 0.3) 96.7
(5, 0.001, 5, 0.4) 71.05
(5, 0.001, 5, 0.5) 49.65
(5, 0.001, 5, 0.6) 58.4
(5, 0.001, 5, 0.7) 50.4
(5, 0.001, 5, 0.8) 51.6
(5, 0.001, 6, 0.2) 73.5
(5, 0.001, 6, 0.3) 76.95
(5, 0.001, 6, 0.4) 70.3
(5, 0.001, 6, 0.5) 53.8
(5, 0.001, 6, 0.6) 50.35
(5, 0.001, 6, 0.7) 43.85
(5, 0.001, 6, 0.8) 75.1
(5, 0.001, 7, 0.2) 74.65
(5, 0.001, 7, 0.3) 69.4
(5, 0.001, 7, 0.4) 55.7
(5, 0.001, 7, 0.5) 49.7
(5, 0.001, 7, 0.6) 40.35
(5, 0.001, 7, 0.7) 42

In [291]:
print(list(sorted(zip(params_results.items()), key=lambda x: x[1]))[:10])

IndexError: tuple index out of range

In [292]:
ls = list(sorted(zip(params_results.items()), key = lambda x: x[0][1]))

In [293]:
for item in ls:
    print(item)

(((2, 0.001, 3, 0.3), 12.1),)
(((2, 0.001, 2, 0.8), 12.35),)
(((2, 0.001, 4, 0.8), 12.9),)
(((2, 0.001, 3, 0.7), 13.15),)
(((2, 0.001, 2, 0.6), 13.25),)
(((2, 0.001, 3, 0.2), 13.5),)
(((2, 0.001, 4, 0.7), 13.5),)
(((2, 0.001, 2, 0.3), 13.65),)
(((2, 0.001, 3, 0.6), 13.75),)
(((2, 0.001, 2, 0.4), 13.9),)
(((2, 0.001, 2, 0.5), 14.0),)
(((2, 0.001, 3, 0.4), 14.1),)
(((2, 0.001, 2, 0.7), 14.15),)
(((2, 0.001, 5, 0.6), 14.25),)
(((2, 0.001, 4, 0.6), 14.65),)
(((2, 0.001, 4, 0.4), 14.9),)
(((2, 0.001, 6, 0.5), 14.9),)
(((2, 0.001, 5, 0.8), 14.95),)
(((2, 0.001, 3, 0.8), 15.25),)
(((2, 0.001, 6, 0.8), 15.3),)
(((2, 0.001, 5, 0.2), 15.45),)
(((2, 0.001, 5, 0.5), 15.5),)
(((2, 0.001, 6, 0.3), 15.5),)
(((2, 0.001, 4, 0.2), 15.55),)
(((2, 0.001, 6, 0.7), 15.8),)
(((2, 0.001, 2, 0.2), 15.9),)
(((2, 0.001, 4, 0.3), 15.95),)
(((2, 0.001, 6, 0.6), 15.95),)
(((2, 0.001, 7, 0.5), 16.0),)
(((2, 0.001, 7, 0.6), 16.1),)
(((2, 0.001, 7, 0.2), 16.2),)
(((2, 0.001, 7, 0.4), 16.2),)
(((2, 0.001, 5, 0.3), 16.2