In [1]:
import os
from os.path import join, pardir
from collections import Counter
from copy import deepcopy
import numpy as np
from deap import base, creator, algorithms, tools
from dssg_challenge import compute_cost, check_keyboard

RNG_SEED = 0
DATA_DSSG = join(pardir, 'data', 'raw')

rng = np.random.RandomState(RNG_SEED)

In [2]:
os.listdir(DATA_DSSG)

['pt-corpus.txt', '.gitkeep', 'pt-keys.txt', 'en-keys.txt', 'en-corpus.txt']

In [3]:
# get keys
with open(join(DATA_DSSG, 'pt-keys.txt'), 'r') as file:
    keys = file.read()

# get corpus example
with open(join(DATA_DSSG, 'pt-corpus.txt'), 'r') as file:
    corpus = file.read()

keys = ''.join(keys.split('\n'))
corpus = ''.join(corpus.split(keys)).split('\n')[0]

Some keys are used to signal special characters. Namely,

- The ENTER key is represented as 0.
- The shift key for capitalization is represented as ^.
- The backspace key is represented as <.
- All the remaining characters not found in the valid keys are encoded as #.
- Empty keys will contain the character _.


In [4]:
len(keys), keys

(36, "ABCDEFGHIJKLMNOPQ RSTUVWXYZ0.#,^?<'~")

## The most basic approaches

In [5]:
Counter(corpus).most_common()[:10]

[(' ', 138),
 ('A', 137),
 ('R', 91),
 ('O', 78),
 ('E', 73),
 ('S', 70),
 ('0', 44),
 ('#', 37),
 ('T', 36),
 ('M', 35)]

In [6]:
baseline = ''.join([i[0] for i in Counter(corpus).most_common()])
baseline = baseline + ''.join([i for i in keys if i not in baseline]) + ' '
baseline

" AROES0#TMNUID,CLPB'VH<G~JQFZ.KWXY^? "

In [7]:
shuffled = list(baseline)
rng.shuffle(shuffled)

check_keyboard(baseline, keys)
check_keyboard(keys+' ', keys)
check_keyboard(shuffled, keys)

print('Shuffled cost:\t\t', compute_cost(''.join(shuffled), corpus))
print('Original keys cost:\t', compute_cost(keys+' ', corpus))
print('Baseline cost:\t\t', compute_cost(baseline, corpus))

Shuffled cost:		 5088.806814781539
Original keys cost:	 4541.244466418746
Baseline cost:		 3189.713309637487


## First attempt with GA

In [8]:
keys_list = list(keys)

def evaluate(individual):
    """
    Computes the cost for each individual.
    """
    try:
        check_keyboard(individual, keys)
        return [compute_cost(''.join(list(individual)), corpus)]
    except AssertionError:
        return [np.inf]

def mutFlip(ind1, ind2):
    """Execute a two points crossover with copy on the input individuals. The
    copy is required because the slicing in numpy returns a view of the data,
    which leads to a self overwritting in the swap operation.
    """

    ind = ind1.copy()
    for x, value in np.ndenumerate(ind):
        if np.random.random() < .05:
            ind[x] = np.random.choice(keys_list)
    try:
        check_keyboard(ind, keys)
        return ind, ind2
    except AssertionError:
        return mutFlip(individual, ind2)
    
    return ind, ind2


In [10]:
creator.create('FitnessMin', base.Fitness, weights=(-1.0,))
creator.create('Individual', np.ndarray, fitness=creator.FitnessMin)

toolbox = base.Toolbox()

# Tool to randomly initialize an individual
toolbox.register('attribute',
        np.random.permutation, np.array(list(baseline))
)

toolbox.register('individual',
    tools.initIterate,
    creator.Individual,
    toolbox.attribute
)

toolbox.register('population',
    tools.initRepeat,
    list,
    toolbox.individual
)

toolbox.register("evaluate", evaluate)
toolbox.register("mate", tools.cxOnePoint)
toolbox.register("mutate", tools.mutShuffleIndexes, indpb=0.05)
toolbox.register("select", tools.selTournament, tournsize=3)

def main():
    np.random.seed(64)

    pop = toolbox.population(n=5)

    # Numpy equality function (operators.eq) between two arrays returns the
    # equality element wise, which raises an exception in the if similar()
    # check of the hall of fame. Using a different equality function like
    # numpy.array_equal or numpy.allclose solve this issue.
    hof = tools.HallOfFame(1, similar=np.array_equal)

    stats = tools.Statistics(lambda ind: ind.fitness.values)
    stats.register("avg", np.mean)
    stats.register("std", np.std)
    stats.register("min", np.min)
    stats.register("max", np.max)

    algorithms.eaSimple(pop, toolbox, cxpb=0, mutpb=0.6, ngen=1000, stats=stats,
                        halloffame=hof)

    return pop, stats, hof


pop, stats, hof = main()




gen	nevals	avg   	std    	min    	max    
0  	5     	5124.5	215.719	4924.93	5496.35
1  	5     	5102.7	172.209	4805.56	5309.41
2  	4     	4947.41	204.156	4739.22	5237.11
3  	4     	4876.08	155.285	4739.22	5087.46
4  	2     	4857.91	188.611	4739.22	5225.97
5  	2     	4776.78	39.3265	4739.22	4845.92
6  	4     	4788.22	62.4632	4699.64	4874.88
7  	2     	4705.59	85.9379	4562.31	4828.77
8  	3     	4794.84	272.239	4488.07	5302.44
9  	2     	4724.32	49.6916	4698   	4823.6 
10 	1     	4698   	0      	4698   	4698   
11 	3     	4808.16	163.697	4698   	5130.61
12 	4     	4768.66	120.284	4604.67	4945.67
13 	3     	4721.48	147.665	4526.53	4906.75
14 	5     	4623.41	130.21 	4461.89	4836.03
15 	2     	4551.92	179.933	4383.6 	4894.41
16 	4     	4525.61	112.688	4383.6 	4688.57
17 	3     	4533.52	98.2449	4383.6 	4666.13
18 	4     	4447.19	58.3612	4376.37	4527.76
19 	1     	4377.82	2.89104	4376.37	4383.6 
20 	2     	4384.38	16.0149	4376.37	4416.41
21 	3     	4378.23	42.2841	4314.27	4447.79
22 	2     	450

KeyboardInterrupt: 