In [1]:
!pip install deap -q

In [1]:
import os
from os.path import join, pardir
from collections import Counter
from copy import deepcopy
import numpy as np
from deap import base, creator, algorithms, tools
from dssg_challenge import compute_cost, check_keyboard

RNG_SEED = 0
DATA_DSSG = join(pardir, 'data', 'raw')

rng = np.random.RandomState(RNG_SEED)

In [2]:
os.listdir(DATA_DSSG)

['pt-corpus.txt', '.gitkeep', 'pt-keys.txt', 'en-keys.txt', 'en-corpus.txt']

In [3]:
# get keys
with open(join(DATA_DSSG, 'pt-keys.txt'), 'r') as file:
    keys = file.read()

# get corpus example
with open(join(DATA_DSSG, 'pt-corpus.txt'), 'r') as file:
    corpus = file.read()

keys = ''.join(keys.split('\n'))
corpus = ''.join(corpus.split(keys)).split('\n')[0]

Some keys are used to signal special characters. Namely,

- The ENTER key is represented as 0.
- The shift key for capitalization is represented as ^.
- The backspace key is represented as <.
- All the remaining characters not found in the valid keys are encoded as #.
- Empty keys will contain the character _.


In [5]:
len(keys), keys

(36, "ABCDEFGHIJKLMNOPQ RSTUVWXYZ0.#,^?<'~")

## The most basic approaches

In [None]:
Counter(corpus).most_common()[:10]

In [6]:
baseline = ''.join([i[0] for i in Counter(corpus).most_common()])
baseline = baseline + ''.join([i for i in keys if i not in baseline]) + ' '
baseline

" AROES0#TMNUID,CLPB'VH<G~JQFZ.KWXY^? "

In [7]:
shuffled = list(baseline)
rng.shuffle(shuffled)

check_keyboard(baseline, keys)
check_keyboard(keys+' ', keys)
check_keyboard(shuffled, keys)

print('Shuffled cost:\t\t', compute_cost(''.join(shuffled), corpus))
print('Original keys cost:\t', compute_cost(keys+' ', corpus))
print('Baseline cost:\t\t', compute_cost(baseline, corpus))

Shuffled cost:		 5088.806814781539
Original keys cost:	 4541.244466418746
Baseline cost:		 3189.713309637487


## First attempt with GA

In [8]:
keys_list = list(keys)

def evaluate(individual):
    """
    Computes the cost for each individual.
    """
    try:
        check_keyboard(individual, keys)
        return [compute_cost(''.join(list(individual)), corpus)]
    except AssertionError:
        return [np.inf]

def mutFlip(ind1, ind2):
    """Execute a two points crossover with copy on the input individuals. The
    copy is required because the slicing in numpy returns a view of the data,
    which leads to a self overwritting in the swap operation.
    """

    ind = ind1.copy()
    for x, value in np.ndenumerate(ind):
        if np.random.random() < .05:
            ind[x] = np.random.choice(keys_list)
    try:
        check_keyboard(ind, keys)
        return ind, ind2
    except AssertionError:
        return mutFlip(individual, ind2)
    
    return ind, ind2


In [11]:
creator.create('FitnessMin', base.Fitness, weights=(-1.0,))
creator.create('Individual', np.ndarray, fitness=creator.FitnessMin)

toolbox = base.Toolbox()

# Tool to randomly initialize an individual
toolbox.register('attribute',
        np.random.permutation, np.array(list(baseline))
)

toolbox.register('individual',
    tools.initIterate,
    creator.Individual,
    toolbox.attribute
)

toolbox.register('population',
    tools.initRepeat,
    list,
    toolbox.individual
)

toolbox.register("evaluate", evaluate)
toolbox.register("mate", tools.cxOnePoint)
toolbox.register("mutate", tools.mutShuffleIndexes, indpb=0.05)
toolbox.register("select", tools.selTournament, tournsize=3)

def main():
    np.random.seed(64)

    pop = toolbox.population(n=10)

    # Numpy equality function (operators.eq) between two arrays returns the
    # equality element wise, which raises an exception in the if similar()
    # check of the hall of fame. Using a different equality function like
    # numpy.array_equal or numpy.allclose solve this issue.
    hof = tools.HallOfFame(1, similar=np.array_equal)

    stats = tools.Statistics(lambda ind: ind.fitness.values)
    stats.register("avg", np.mean)
    stats.register("std", np.std)
    stats.register("min", np.min)
    stats.register("max", np.max)

    algorithms.eaSimple(pop, toolbox, cxpb=0, mutpb=0.6, ngen=1000, stats=stats,
                        halloffame=hof)

    return pop, stats, hof


pop, stats, hof = main()


gen	nevals	avg    	std    	min   	max    
0  	10    	5127.31	353.095	4541.6	5872.66
1  	8     	4973.65	60.5206	4895.18	5096.75
2  	7     	4986.66	109.597	4858.07	5277.38
3  	5     	4898.95	62.8849	4817.44	5000.87
4  	8     	4912.79	261.166	4585.33	5639.21
5  	7     	4710.69	174.203	4301.33	4910.99
6  	8     	4576.08	244.768	4285.42	5057.85
7  	4     	4401.64	216.663	4075.38	4885.66
8  	8     	4214.07	84.3895	4069.19	4314.63
9  	5     	4169.14	71.1615	4069.19	4280.15
10 	8     	4275.21	252.176	4050.34	4887.66
11 	4     	4123.92	228.158	3947.92	4800.43
12 	5     	4056.35	137.573	3947.92	4394   
13 	4     	4018.72	128.03 	3940.36	4342.55
14 	7     	4056.48	132.883	3947.92	4279.84
15 	5     	3998.44	58.0004	3938.77	4118.12
16 	6     	4113.03	190.275	3932.39	4390.83
17 	5     	4095.37	191.78 	3894.79	4411.41
18 	4     	4025.32	192.869	3894.79	4423.98
19 	7     	4107.04	259.799	3894.79	4717.22
20 	6     	4029.49	202.748	3891.57	4409.22
21 	8     	4080.3 	319.98 	3749.65	4974.66
22 	5     	39

KeyboardInterrupt: 