In [146]:
import dataclasses
import pickle
import random
import numpy as np

from annoy import AnnoyIndex
from tqdm.notebook import tqdm
from scipy.spatial.distance import euclidean, pdist, squareform
import scipy
import scipy.stats as stats
from sklearn.decomposition import PCA

from matplotlib import pyplot as plt
from typing import *

In [2]:
def read_glove_file() -> Dict[str, List[float]]:
    """
    Yields the next 'size' vectors in a glove file.
    """
    glove_file = '/mnt/Spookley/datasets/glove/glove.6B.50d.txt'
    w_vecs = {}
    with tqdm(total=400000) as pbar:
        with open(glove_file) as fh:
            for line in fh.readlines():
                pbar.update(1)
                toks = line.strip().split()
                word = toks[0]
                # non-words like punctuation marks have entries, but we don't want those
                if not word.isalnum():  
                    continue
                # Some bigrams and trigrams are in the dataset. Skip those.
                try:
                    float(toks[1])
                except ValueError:
                    continue
                # No errors, parse the line.
                vec = [float(s) for s in toks[1:]]
                w_vecs[word] = vec
    return w_vecs

In [152]:
def build_index(w_vecs: Dict[str, List[float]]) -> Tuple[Dict[int, str], AnnoyIndex]:
    for v in w_vecs.values():
        GLOVE_VEC_SIZE = len(v)
        break
    idx_to_word = {}
    ann_index = AnnoyIndex(GLOVE_VEC_SIZE, 'euclidean')
    with tqdm(total=len(w_vecs)) as pbar:
        for i, w in enumerate(w_vecs.keys()):
            pbar.update(1)
            ann_index.add_item(i, w_vecs[w])
            idx_to_word[i] = w
    ann_index.build(20) # n trees
    return idx_to_word, ann_index

In [148]:
w_vecs = read_glove_file()

  0%|          | 0/400000 [00:00<?, ?it/s]

In [153]:
pca = PCA(n_components=10)
mat_full = np.array([w_vecs[w] for w in w_vecs])
mat = pca.fit_transform(mat_full)

for i, w in enumerate(w_vecs.keys()):
    w_vecs[w] = mat[i, :]

(336158, 10)


In [154]:

idx_to_word, ann_index = build_index(w_vecs)  # fast


  0%|          | 0/336158 [00:00<?, ?it/s]

In [155]:
def random_point_in_dist(point, dist):
    # For when we know the dist but have no idea what direction to travel
    vec = np.random.random((len(point)))
    vec = vec / scipy.linalg.norm(vec)
    vec = vec * dist
    return vec+point


def directed_point_in_dist(p1, p2, p1_dist, p2_dist):
    # Generate a vector using p1 and p2.
    # Check if it will point in the general direction of our target.
    p1p2 = (p1-p2)
    p1p2mag = scipy.linalg.norm(p1p2)
    if p1p2mag < 0.00001:
        return None, 0
    p1p2_unit = p1p2 / p1p2mag
    if p1_dist < p2_dist:
        # p1 is closer to target
        mag = p1_dist
        target_point = p1 + p1p2_unit*mag
        confidence = (p2_dist-p1_dist) / p1p2mag
        assert confidence >= 0
    else:
        # j is closer to target
        # make a vector from j to a target that is dists[j] away
        mag = p2_dist
        target_point = p2 - p1p2_unit*mag
        confidence = (p1_dist-p2_dist) / p1p2mag
        assert confidence >= 0
    return target_point, confidence

In [156]:
# test case - target at [0,6], points on y axis
p1 = np.array([0,0])
p2 = np.array([0,2])
target = np.array([0,6])
d1 = euclidean(p1, target)
d2 = euclidean(p2, target)
print('expect [0,6]', directed_point_in_dist(p1, p2, d1, d2))
print('expect [0,6]', directed_point_in_dist(p2, p1, d2, d1))

expect [0,6] (array([0., 6.]), 1.0)
expect [0,6] (array([0., 6.]), 1.0)


In [132]:
# test case - target at [1,3], points on y axis
p1 = np.array([0,0])
p2 = np.array([0,2])
target = np.array([1,3])
d1 = euclidean(p1, target)
d2 = euclidean(p2, target)
print('expect [1,3]', directed_point_in_dist(p1, p2, d1, d2))
print('expect [1,3]', directed_point_in_dist(p2, p1, d2, d1))

expect [0,3] (array([0.        , 3.41421356]), 0.8740320488976422)
expect [0,3] (array([0.        , 3.41421356]), 0.8740320488976422)


In [141]:
# test case - target at [3,1], points on y axis
p1 = np.array([0,0])
p2 = np.array([0,2])
target = np.array([3,1])
d1 = euclidean(p1, target)
d2 = euclidean(p2, target)
print(d1, d2)
print('expect zero confidence', directed_point_in_dist(p1, p2, d1, d2))
print('expect zero confidence', directed_point_in_dist(p2, p1, d2, d1))

3.1622776601683795 3.1622776601683795
expect zero confidence (array([0.        , 5.16227766]), 0.0)
expect zero confidence (array([ 0.        , -3.16227766]), 0.0)


In [164]:
class SemantleGame():
    def __init__(self, w_vecs):
        w_list = list(w_vecs.keys())
        self.target_word = random.choice(w_list[1000:10000])
        self.target_vec = w_vecs[self.target_word]
        
    def guess(self, word, vec) -> Tuple[bool, float]:
        # construct guess
        dist = euclidean(vec, self.target_vec)
        # check if win
        if word == self.target_word:
            return True, dist
        else:
            return False, dist
    
    def display_guesses(self):
        s = []
        for g in sorted(self.guesses, key = lambda g: g.dist):
            s.append(str(g))
        print('\n'.join(s))
        
    def __str__(self):
        return '\n'.join('{}: {}'.format(k, v) for k, v in self.__dict__.items())

In [254]:
@dataclasses.dataclass
class Guess:
    word: str
    num: int
    dist: float
    
class SemantleSolver:
    
    def __init__(self, game: SemantleGame, n_random_guesses=10):
        self.game = game
        self.n_random_guesses = n_random_guesses
        self.closest_dist = float('inf')
        self.guesses = []  # List[Guess]
        self.guessed_words = set()  # for fast lookup
        self.best_guess = None
        
        self.EXH_THRESH = 0.5
        self.N_RANDOM = 10
        self.CONF_THRESH = 0.33
        
        self.stats = {
            'grd_high_conf': 0,
            'grd_random_dist': 0,
            'times_gradient': 0,
            'times_exhaustive': 0,
            'times_random': 0,
        }
        
    def _gradient_method(self, w_vecs, ann_index):
        # Use gradient method to get a closer guess.
        p1 = np.array(w_vecs[self.guesses[-1].word])
        p1_dist = self.guesses[-1].dist
        
        # Consider the few most recent points. 
        # Try and find one with a vector through p1 that points towards the target.
        best_point = None
        best_confidence = 0
        for i in range(2, min(10, len(self.guesses))):
            p2 = np.array(w_vecs[self.guesses[-i].word])
            p2_dist = self.guesses[-i].dist
            
            # where does p2->p1 point? and how well aligned is that spot with the target?
            target_point, confidence = directed_point_in_dist(p1, p2, p1_dist, p2_dist)
            if confidence > best_confidence:
                best_confidence = confidence
                best_point = target_point
                
        if best_confidence < self.CONF_THRESH:
            self.stats['grd_random_dist'] += 1
            vec = np.array(w_vecs[self.best_guess])
            best_point = random_point_in_dist(vec, self.closest_dist)
        else:
            self.stats['grd_high_conf'] += 1

        return best_point

    
    def find_next_guess(self, w_vecs, ann_index, idx_to_word) -> bool:
        if len(self.guesses) < self.N_RANDOM:
            self.stats['times_random'] += 1
            next_word = random.choice(list(w_vecs.keys()))
        elif self.closest_dist > self.EXH_THRESH:
            self.stats['times_gradient'] += 1
            v = self._gradient_method(w_vecs, ann_index)
            idxs_near_best = ann_index.get_nns_by_vector(v, 1000)
            for idx in idxs_near_best:
                w = idx_to_word[idx]
                if w not in self.guessed_words:
                    next_word = w
                    break
        else:
            self.stats['times_exhaustive'] += 1
            # We're close enough to start exhaustive search
            v = w_vecs[self.best_guess]
            idxs_near_best = ann_index.get_nns_by_vector(v, 1000)
            for idx in idxs_near_best:
                w = idx_to_word[idx]
                if w not in self.guessed_words:
                    next_word = w
                    break
            
        return next_word

    def make_guess(self, word):
        # guess the word
        win, dist = self.game.guess(word, w_vecs[word])
        self.guessed_words.add(word)
        self.guesses.append(Guess(word=word, dist=dist, num=len(self.guesses)+1))
        
        # see if this one's better
        if self.best_guess is None or dist < self.closest_dist:
            #print(word, round(dist, 3))
            self.closest_dist = dist
            self.best_guess = word
        
        if win:
            #print("I win!")
            return True
        else:
            return False
        

In [252]:
game = SemantleGame(w_vecs)
player = SemantleSolver(game)
print(game.target_word)
won = False
while not won:
    word = player.find_next_guess(w_vecs, ann_index, idx_to_word)
    won = player.make_guess(word)
    if len(player.guesses) > 5000:
        print('stopped. ')
        print('Best guess:', player.best_guess, 'dist:', player.closest_dist)
        break

print("")
print("last 10 guesses:")
for g in player.guesses[-10:]:
    print(g.word, round(g.dist, 3))
print(player.stats)

firms
kyō 5.493
aldaco 5.217
miano 5.033
rocko 4.771
impeachable 4.017
vegetation 3.961
species 3.773
trees 3.288
animals 2.98
every 1.708
number 1.336
liberal 1.177
gold 1.119
per 1.064
single 1.005
internet 0.881
launch 0.878
professional 0.847
deficit 0.772
credit 0.739
firms 0.0
I win!

last 10 guesses:
roof 0.981
entertainment 1.296
consecutive 1.13
guests 1.553
plus 1.073
watch 0.941
global 1.334
sweep 1.479
firms 0.0
{'grd_high_conf': 150, 'grd_random_dist': 51, 'times_gradient': 201, 'times_exhaustive': 0, 'times_random': 10}


In [272]:
# grid search time.
def run_trial(exh, n_rand, conf_thresh, w_vecs, idx_to_word, ann_index):
    # make a game
    
    game = SemantleGame(w_vecs)
    player = SemantleSolver(game)
    
    player.EXH_THRESH = exh
    player.N_RANDOM = n_rand
    player.CONF_THRESH = conf_thresh
    
    won = False
    while not won:
        word = player.find_next_guess(w_vecs, ann_index, idx_to_word)
        won = player.make_guess(word)
        if len(player.guesses) > 1000:
            break
    return len(player.guesses)

In [274]:

n_dims = [2, 3, 6, 9, 12, 15, 18]
exh_threshes = [0.2, 0.4, 0.6, 0.8, 1.0, 1.2]
n_randoms = [2, 5, 10, 15, 20]
conf_threshes = [0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5]
n_trials = 5

params_results = {}

for n_dim in n_dims:
    
    # set up space
    w_vecs = read_glove_file()
    pca = PCA(n_components=n_dim)
    mat_full = np.array([w_vecs[w] for w in w_vecs])
    mat = pca.fit_transform(mat_full)

    for i, w in enumerate(w_vecs.keys()):
        w_vecs[w] = mat[i, :]

    idx_to_word, ann_index = build_index(w_vecs)

    
    for exh in exh_threshes:
        for n_rand in n_randoms:
            for conf_thresh in conf_threshes:
                params = (n_dim, exh, n_rand, conf_thresh)
                for trial in range(n_trials):
                    n_guesses = run_trial(exh, n_rand, conf_thresh, w_vecs, idx_to_word, ann_index)
                    if not params in params_results:
                        params_results[params] = 0
                    params_results[params] += n_guesses
                params_results[params] /= n_trials
                print(params, params_results[params])

print(params_results)

  0%|          | 0/400000 [00:00<?, ?it/s]

  0%|          | 0/336158 [00:00<?, ?it/s]

(2, 0.2, 2, 0.2) 26.6
(2, 0.2, 2, 0.25) 20.2
(2, 0.2, 2, 0.3) 24.0
(2, 0.2, 2, 0.35) 17.8
(2, 0.2, 2, 0.4) 34.2
(2, 0.2, 2, 0.45) 25.0
(2, 0.2, 2, 0.5) 20.8
(2, 0.2, 5, 0.2) 25.8
(2, 0.2, 5, 0.25) 26.4
(2, 0.2, 5, 0.3) 33.2
(2, 0.2, 5, 0.35) 32.2
(2, 0.2, 5, 0.4) 29.6
(2, 0.2, 5, 0.45) 30.2
(2, 0.2, 5, 0.5) 28.8
(2, 0.2, 10, 0.2) 32.6
(2, 0.2, 10, 0.25) 35.2
(2, 0.2, 10, 0.3) 31.4
(2, 0.2, 10, 0.35) 29.2
(2, 0.2, 10, 0.4) 27.4
(2, 0.2, 10, 0.45) 22.2
(2, 0.2, 10, 0.5) 28.4
(2, 0.2, 15, 0.2) 24.8
(2, 0.2, 15, 0.25) 43.0
(2, 0.2, 15, 0.3) 30.0
(2, 0.2, 15, 0.35) 24.4
(2, 0.2, 15, 0.4) 31.2
(2, 0.2, 15, 0.45) 32.0
(2, 0.2, 15, 0.5) 32.6
(2, 0.2, 20, 0.2) 45.6
(2, 0.2, 20, 0.25) 35.6
(2, 0.2, 20, 0.3) 50.8
(2, 0.2, 20, 0.35) 37.8
(2, 0.2, 20, 0.4) 38.4
(2, 0.2, 20, 0.45) 37.2
(2, 0.2, 20, 0.5) 46.2
(2, 0.4, 2, 0.2) 26.8
(2, 0.4, 2, 0.25) 29.0
(2, 0.4, 2, 0.3) 43.2
(2, 0.4, 2, 0.35) 27.8
(2, 0.4, 2, 0.4) 24.6
(2, 0.4, 2, 0.45) 20.2
(2, 0.4, 2, 0.5) 19.6
(2, 0.4, 5, 0.2) 54.2
(2, 0.4, 5, 0.2

  0%|          | 0/400000 [00:00<?, ?it/s]

  0%|          | 0/336158 [00:00<?, ?it/s]

(3, 0.2, 2, 0.2) 20.0
(3, 0.2, 2, 0.25) 26.8
(3, 0.2, 2, 0.3) 31.2
(3, 0.2, 2, 0.35) 26.8
(3, 0.2, 2, 0.4) 33.6
(3, 0.2, 2, 0.45) 32.6
(3, 0.2, 2, 0.5) 25.6
(3, 0.2, 5, 0.2) 24.8
(3, 0.2, 5, 0.25) 26.2
(3, 0.2, 5, 0.3) 22.6
(3, 0.2, 5, 0.35) 34.4
(3, 0.2, 5, 0.4) 28.0
(3, 0.2, 5, 0.45) 29.6
(3, 0.2, 5, 0.5) 32.4
(3, 0.2, 10, 0.2) 40.6
(3, 0.2, 10, 0.25) 42.8
(3, 0.2, 10, 0.3) 40.4
(3, 0.2, 10, 0.35) 32.8
(3, 0.2, 10, 0.4) 34.8
(3, 0.2, 10, 0.45) 35.6
(3, 0.2, 10, 0.5) 35.4
(3, 0.2, 15, 0.2) 43.0
(3, 0.2, 15, 0.25) 48.0
(3, 0.2, 15, 0.3) 34.6
(3, 0.2, 15, 0.35) 32.2
(3, 0.2, 15, 0.4) 39.8
(3, 0.2, 15, 0.45) 34.4
(3, 0.2, 15, 0.5) 46.4
(3, 0.2, 20, 0.2) 32.4
(3, 0.2, 20, 0.25) 55.8
(3, 0.2, 20, 0.3) 42.2
(3, 0.2, 20, 0.35) 43.8
(3, 0.2, 20, 0.4) 39.4
(3, 0.2, 20, 0.45) 39.0
(3, 0.2, 20, 0.5) 34.0
(3, 0.4, 2, 0.2) 33.6
(3, 0.4, 2, 0.25) 30.8
(3, 0.4, 2, 0.3) 31.0
(3, 0.4, 2, 0.35) 22.8
(3, 0.4, 2, 0.4) 28.6
(3, 0.4, 2, 0.45) 24.8
(3, 0.4, 2, 0.5) 25.8
(3, 0.4, 5, 0.2) 30.8
(3, 0.4, 5, 0.2

  0%|          | 0/400000 [00:00<?, ?it/s]

  0%|          | 0/336158 [00:00<?, ?it/s]

(6, 0.2, 2, 0.2) 131.0
(6, 0.2, 2, 0.25) 88.4
(6, 0.2, 2, 0.3) 135.6
(6, 0.2, 2, 0.35) 72.0
(6, 0.2, 2, 0.4) 97.4
(6, 0.2, 2, 0.45) 107.0
(6, 0.2, 2, 0.5) 78.6
(6, 0.2, 5, 0.2) 87.0
(6, 0.2, 5, 0.25) 42.0
(6, 0.2, 5, 0.3) 114.4
(6, 0.2, 5, 0.35) 80.4
(6, 0.2, 5, 0.4) 69.0
(6, 0.2, 5, 0.45) 52.6
(6, 0.2, 5, 0.5) 63.2
(6, 0.2, 10, 0.2) 91.6
(6, 0.2, 10, 0.25) 63.6
(6, 0.2, 10, 0.3) 146.6
(6, 0.2, 10, 0.35) 67.6
(6, 0.2, 10, 0.4) 73.2
(6, 0.2, 10, 0.45) 67.8
(6, 0.2, 10, 0.5) 56.6
(6, 0.2, 15, 0.2) 116.4
(6, 0.2, 15, 0.25) 91.6
(6, 0.2, 15, 0.3) 194.4
(6, 0.2, 15, 0.35) 66.0
(6, 0.2, 15, 0.4) 68.4
(6, 0.2, 15, 0.45) 71.2
(6, 0.2, 15, 0.5) 54.2
(6, 0.2, 20, 0.2) 109.0
(6, 0.2, 20, 0.25) 130.8
(6, 0.2, 20, 0.3) 93.2
(6, 0.2, 20, 0.35) 134.4
(6, 0.2, 20, 0.4) 75.2
(6, 0.2, 20, 0.45) 157.4
(6, 0.2, 20, 0.5) 78.6
(6, 0.4, 2, 0.2) 79.0
(6, 0.4, 2, 0.25) 109.6
(6, 0.4, 2, 0.3) 106.6
(6, 0.4, 2, 0.35) 62.4
(6, 0.4, 2, 0.4) 78.8
(6, 0.4, 2, 0.45) 47.6
(6, 0.4, 2, 0.5) 56.6
(6, 0.4, 5, 0.2) 102.8
(

  0%|          | 0/400000 [00:00<?, ?it/s]

  0%|          | 0/336158 [00:00<?, ?it/s]

(9, 0.2, 2, 0.2) 331.2
(9, 0.2, 2, 0.25) 178.2
(9, 0.2, 2, 0.3) 96.2
(9, 0.2, 2, 0.35) 47.6
(9, 0.2, 2, 0.4) 90.8
(9, 0.2, 2, 0.45) 43.8
(9, 0.2, 2, 0.5) 67.2
(9, 0.2, 5, 0.2) 225.2
(9, 0.2, 5, 0.25) 136.6
(9, 0.2, 5, 0.3) 107.6
(9, 0.2, 5, 0.35) 60.6
(9, 0.2, 5, 0.4) 94.8
(9, 0.2, 5, 0.45) 84.6
(9, 0.2, 5, 0.5) 73.6
(9, 0.2, 10, 0.2) 114.2
(9, 0.2, 10, 0.25) 138.8
(9, 0.2, 10, 0.3) 85.6
(9, 0.2, 10, 0.35) 123.2
(9, 0.2, 10, 0.4) 122.0
(9, 0.2, 10, 0.45) 80.2
(9, 0.2, 10, 0.5) 67.6
(9, 0.2, 15, 0.2) 353.8
(9, 0.2, 15, 0.25) 138.2
(9, 0.2, 15, 0.3) 115.8
(9, 0.2, 15, 0.35) 84.8
(9, 0.2, 15, 0.4) 65.6
(9, 0.2, 15, 0.45) 134.8
(9, 0.2, 15, 0.5) 96.0
(9, 0.2, 20, 0.2) 167.2
(9, 0.2, 20, 0.25) 92.4
(9, 0.2, 20, 0.3) 118.0
(9, 0.2, 20, 0.35) 144.0
(9, 0.2, 20, 0.4) 102.2
(9, 0.2, 20, 0.45) 66.8
(9, 0.2, 20, 0.5) 127.6
(9, 0.4, 2, 0.2) 207.6
(9, 0.4, 2, 0.25) 220.4
(9, 0.4, 2, 0.3) 96.8
(9, 0.4, 2, 0.35) 343.8
(9, 0.4, 2, 0.4) 61.2
(9, 0.4, 2, 0.45) 41.6
(9, 0.4, 2, 0.5) 175.0
(9, 0.4, 5, 0.2

  0%|          | 0/400000 [00:00<?, ?it/s]

  0%|          | 0/336158 [00:00<?, ?it/s]

(12, 0.2, 2, 0.2) 182.8
(12, 0.2, 2, 0.25) 216.2
(12, 0.2, 2, 0.3) 63.6
(12, 0.2, 2, 0.35) 118.0
(12, 0.2, 2, 0.4) 132.8
(12, 0.2, 2, 0.45) 151.4
(12, 0.2, 2, 0.5) 242.6
(12, 0.2, 5, 0.2) 140.2
(12, 0.2, 5, 0.25) 144.4
(12, 0.2, 5, 0.3) 104.2
(12, 0.2, 5, 0.35) 169.2
(12, 0.2, 5, 0.4) 50.4
(12, 0.2, 5, 0.45) 209.6
(12, 0.2, 5, 0.5) 75.2
(12, 0.2, 10, 0.2) 60.0
(12, 0.2, 10, 0.25) 138.2
(12, 0.2, 10, 0.3) 239.0
(12, 0.2, 10, 0.35) 125.2
(12, 0.2, 10, 0.4) 69.8
(12, 0.2, 10, 0.45) 75.2
(12, 0.2, 10, 0.5) 89.2
(12, 0.2, 15, 0.2) 255.6
(12, 0.2, 15, 0.25) 232.0
(12, 0.2, 15, 0.3) 86.2
(12, 0.2, 15, 0.35) 177.4
(12, 0.2, 15, 0.4) 185.8
(12, 0.2, 15, 0.45) 64.4
(12, 0.2, 15, 0.5) 161.4
(12, 0.2, 20, 0.2) 347.0
(12, 0.2, 20, 0.25) 287.6
(12, 0.2, 20, 0.3) 206.6
(12, 0.2, 20, 0.35) 89.8
(12, 0.2, 20, 0.4) 74.8
(12, 0.2, 20, 0.45) 300.2
(12, 0.2, 20, 0.5) 78.0
(12, 0.4, 2, 0.2) 284.4
(12, 0.4, 2, 0.25) 77.8
(12, 0.4, 2, 0.3) 160.0
(12, 0.4, 2, 0.35) 222.6
(12, 0.4, 2, 0.4) 111.0
(12, 0.4, 2, 0.

  0%|          | 0/400000 [00:00<?, ?it/s]

  0%|          | 0/336158 [00:00<?, ?it/s]

(15, 0.2, 2, 0.2) 151.4
(15, 0.2, 2, 0.25) 136.6
(15, 0.2, 2, 0.3) 212.2
(15, 0.2, 2, 0.35) 80.8
(15, 0.2, 2, 0.4) 122.2
(15, 0.2, 2, 0.45) 106.2
(15, 0.2, 2, 0.5) 176.8
(15, 0.2, 5, 0.2) 177.2
(15, 0.2, 5, 0.25) 273.2
(15, 0.2, 5, 0.3) 79.4
(15, 0.2, 5, 0.35) 147.4
(15, 0.2, 5, 0.4) 61.0
(15, 0.2, 5, 0.45) 181.0
(15, 0.2, 5, 0.5) 111.0
(15, 0.2, 10, 0.2) 143.0
(15, 0.2, 10, 0.25) 140.6
(15, 0.2, 10, 0.3) 90.0
(15, 0.2, 10, 0.35) 134.0
(15, 0.2, 10, 0.4) 104.0
(15, 0.2, 10, 0.45) 245.8
(15, 0.2, 10, 0.5) 124.6
(15, 0.2, 15, 0.2) 141.8
(15, 0.2, 15, 0.25) 267.4
(15, 0.2, 15, 0.3) 92.8
(15, 0.2, 15, 0.35) 69.2
(15, 0.2, 15, 0.4) 133.6
(15, 0.2, 15, 0.45) 93.6
(15, 0.2, 15, 0.5) 178.2
(15, 0.2, 20, 0.2) 400.4
(15, 0.2, 20, 0.25) 326.2
(15, 0.2, 20, 0.3) 106.8
(15, 0.2, 20, 0.35) 83.8
(15, 0.2, 20, 0.4) 141.8
(15, 0.2, 20, 0.45) 97.4
(15, 0.2, 20, 0.5) 65.6
(15, 0.4, 2, 0.2) 332.0
(15, 0.4, 2, 0.25) 103.4
(15, 0.4, 2, 0.3) 79.4
(15, 0.4, 2, 0.35) 449.0
(15, 0.4, 2, 0.4) 107.0
(15, 0.4, 2, 

  0%|          | 0/400000 [00:00<?, ?it/s]

  0%|          | 0/336158 [00:00<?, ?it/s]

(18, 0.2, 2, 0.2) 184.2
(18, 0.2, 2, 0.25) 391.8
(18, 0.2, 2, 0.3) 339.4
(18, 0.2, 2, 0.35) 84.8
(18, 0.2, 2, 0.4) 104.8
(18, 0.2, 2, 0.45) 148.6
(18, 0.2, 2, 0.5) 99.0
(18, 0.2, 5, 0.2) 193.2
(18, 0.2, 5, 0.25) 113.2
(18, 0.2, 5, 0.3) 106.8
(18, 0.2, 5, 0.35) 276.2
(18, 0.2, 5, 0.4) 269.4
(18, 0.2, 5, 0.45) 149.8
(18, 0.2, 5, 0.5) 124.4
(18, 0.2, 10, 0.2) 169.6
(18, 0.2, 10, 0.25) 227.4
(18, 0.2, 10, 0.3) 168.0
(18, 0.2, 10, 0.35) 169.0
(18, 0.2, 10, 0.4) 144.2
(18, 0.2, 10, 0.45) 60.0
(18, 0.2, 10, 0.5) 89.6
(18, 0.2, 15, 0.2) 282.2
(18, 0.2, 15, 0.25) 203.4
(18, 0.2, 15, 0.3) 132.8
(18, 0.2, 15, 0.35) 156.6
(18, 0.2, 15, 0.4) 154.6
(18, 0.2, 15, 0.45) 101.8
(18, 0.2, 15, 0.5) 81.0
(18, 0.2, 20, 0.2) 542.2
(18, 0.2, 20, 0.25) 140.2
(18, 0.2, 20, 0.3) 93.6
(18, 0.2, 20, 0.35) 135.4
(18, 0.2, 20, 0.4) 193.4
(18, 0.2, 20, 0.45) 72.2
(18, 0.2, 20, 0.5) 63.0
(18, 0.4, 2, 0.2) 247.0
(18, 0.4, 2, 0.25) 172.2
(18, 0.4, 2, 0.3) 122.2
(18, 0.4, 2, 0.35) 112.6
(18, 0.4, 2, 0.4) 73.4
(18, 0.4, 2

In [277]:
print(list(sorted(zip(params_results.items()), key=lambda x: x[1]))[:10])

IndexError: tuple index out of range

In [288]:
ls = list(sorted(zip(params_results.items()), key = lambda x: x[0][1]))

In [289]:
for item in ls:
    print(item)

(((2, 0.2, 2, 0.35), 17.8),)
(((2, 0.4, 2, 0.5), 19.6),)
(((3, 0.2, 2, 0.2), 20.0),)
(((2, 0.2, 2, 0.25), 20.2),)
(((2, 0.4, 2, 0.45), 20.2),)
(((2, 0.2, 2, 0.5), 20.8),)
(((3, 1.2, 5, 0.3), 21.6),)
(((2, 0.2, 10, 0.45), 22.2),)
(((3, 0.6, 5, 0.25), 22.2),)
(((3, 0.2, 5, 0.3), 22.6),)
(((3, 0.4, 2, 0.35), 22.8),)
(((2, 0.6, 5, 0.35), 23.4),)
(((2, 0.2, 2, 0.3), 24.0),)
(((3, 0.4, 5, 0.45), 24.0),)
(((2, 0.4, 5, 0.3), 24.2),)
(((2, 0.2, 15, 0.35), 24.4),)
(((6, 1.0, 5, 0.45), 24.4),)
(((2, 0.4, 2, 0.4), 24.6),)
(((2, 0.2, 15, 0.2), 24.8),)
(((3, 0.2, 5, 0.2), 24.8),)
(((3, 0.4, 2, 0.45), 24.8),)
(((2, 0.2, 2, 0.45), 25.0),)
(((3, 0.2, 2, 0.5), 25.6),)
(((2, 0.2, 5, 0.2), 25.8),)
(((3, 0.4, 2, 0.5), 25.8),)
(((2, 0.4, 5, 0.4), 26.2),)
(((3, 0.2, 5, 0.25), 26.2),)
(((2, 0.2, 5, 0.25), 26.4),)
(((2, 0.2, 2, 0.2), 26.6),)
(((2, 0.4, 2, 0.2), 26.8),)
(((3, 0.2, 2, 0.25), 26.8),)
(((3, 0.2, 2, 0.35), 26.8),)
(((2, 0.2, 10, 0.4), 27.4),)
(((2, 0.8, 5, 0.3), 27.4),)
(((3, 0.6, 2, 0.45), 27.6),)