In [151]:
import pandas as pd
import numpy as np
from keras.models import Sequential
from keras.layers import LSTM, TimeDistributed, Dense, Activation, Dropout

# Prepare data

In [152]:
def remove_non_alpha(text):
    """Return `text` with everything except letters, spaces and newlines removed.

    Letters are detected with str.isalpha, so accented letters such as 'é'
    survive the filter.
    """
    kept = (ch for ch in text if ch.isalpha() or ch in (' ', '\n'))
    return ''.join(kept)

In [153]:
# Concatenate every name into one training corpus and report its alphabet size.
# NOTE(review): `pokemon_names` is not defined in any cell shown here (and
# `remove_non_alpha` is never called) -- presumably it was built in a cell that
# has since been deleted; confirm before relying on Restart & Run All.
data = ''.join(pokemon_names)
chars = list(set(data))
data_size = len(data)
vocab_size = len(chars)
print('There are %d total characters and %d unique characters in your data.'
      % (data_size, vocab_size))

There are 9600 total characters and 28 unique characters in your data.


In [154]:
# Sorting makes the index assignment independent of set iteration order.
sorted_chars = sorted(chars)
ix_to_char = dict(enumerate(sorted_chars))
# Exact inverse of ix_to_char, since every character in `chars` is unique.
char_to_ix = {ch: ix for ix, ch in ix_to_char.items()}
print(ix_to_char)

{0: ' ', 1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 27: 'é'}


# Keras Format
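Keras recurrent layers expect 3D input of shape `(batch_size, timesteps, input_dim)`. Here each sample is one window of `SEQ_LENGTH` characters (the timesteps), and each character is one-hot encoded over `vocab_size` symbols (the input_dim).
Source: https://keras.io/layers/recurrent/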

In [155]:
# Use the longest name's length as the window size for every training sequence.
# NOTE(review): the printed output suggests names are space-padded to a common
# length, so every name may be exactly SEQ_LENGTH long -- confirm upstream.
longest_name = max(pokemon_names, key=len)
SEQ_LENGTH = len(longest_name)
print(longest_name, SEQ_LENGTH)

bulbasaur    12


In [156]:
# Cut the corpus into consecutive, non-overlapping windows of SEQ_LENGTH
# characters. Each target window is the input window shifted one character to
# the right, so y[k, j] one-hot encodes the character that follows X[k, j].
data = ''.join(pokemon_names)
batch_size = len(data) // SEQ_LENGTH  # number of training sequences, not fit()'s batch size
X = np.zeros((batch_size, SEQ_LENGTH, vocab_size))
y = np.zeros((batch_size, SEQ_LENGTH, vocab_size))

data_ix = np.array([char_to_ix[ch] for ch in data], dtype=int)
positions = np.arange(SEQ_LENGTH)
for seq_num in range(batch_size):
    start = seq_num * SEQ_LENGTH
    input_ix = data_ix[start:start + SEQ_LENGTH]
    X[seq_num, positions, input_ix] = 1.

    # The final window may run past the end of the corpus; any target
    # positions with no following character are left as all-zero rows.
    target_ix = data_ix[start + 1:start + SEQ_LENGTH + 1]
    y[seq_num, np.arange(len(target_ix)), target_ix] = 1.

In [157]:
# X and y are allocated with identical shapes: (num_sequences, SEQ_LENGTH, vocab_size).
shapes = (X.shape, y.shape)
print(*shapes)

(800, 12, 28) (800, 12, 28)


# Train

In [158]:
# Three stacked LSTMs, each returning its full output sequence, feed a
# per-timestep softmax over the vocabulary. input_shape=(None, vocab_size)
# leaves the sequence length open, which lets generate_text feed growing
# prefixes. Dropout(0.5) is applied only to the last LSTM's output.
HIDDEN_UNITS = vocab_size  # every LSTM layer is as wide as the one-hot vocabulary
model = Sequential([
    LSTM(HIDDEN_UNITS, input_shape=(None, vocab_size), return_sequences=True),
    LSTM(HIDDEN_UNITS, return_sequences=True),
    LSTM(HIDDEN_UNITS, return_sequences=True),
    Dropout(0.5),
    TimeDistributed(Dense(vocab_size)),
    Activation('softmax'),
])
model.compile(loss="categorical_crossentropy", optimizer="adam")

In [None]:
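# Optional: resume from a checkpoint written by the save cell at the end of this
# notebook (files are named '<YYYY-MM-DD>.hdf5'). Uncomment and fill in the date.
# model.load_weights('<some-prev-day>.hdf5')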

In [159]:
def generate_text(model, max_length):
    """Greedily sample one name from `model`, printing it as it is generated.

    The first character is drawn uniformly at random from the non-whitespace
    characters of the vocabulary (`ix_to_char`). Each later character is the
    argmax of the model's prediction at the last position of the prefix fed
    so far.

    Args:
        model: Keras model that maps a one-hot prefix of shape
            (1, t, vocab_size) to per-timestep scores of shape
            (1, t, vocab_size), like the model defined in this notebook.
        max_length: number of characters fed to the model and printed.

    Returns:
        The generated string. It is one character longer than what is printed:
        the prediction made after the final fed character is appended but is
        never printed or fed back.
    """
    # Pick the start from the vocabulary itself instead of hard-coding index
    # bounds. In sorted order ' ' is index 0 and 'a' is index 1, so a fixed
    # lower bound of 2 would silently make 'a' impossible as a first letter.
    start_choices = [ix for ix, ch in ix_to_char.items() if not ch.isspace()]
    if not start_choices:  # degenerate all-whitespace vocabulary: allow anything
        start_choices = list(ix_to_char)
    next_ix = int(np.random.choice(start_choices))

    y_char = [ix_to_char[next_ix]]
    x_input = np.zeros((1, max_length, vocab_size))
    for i in range(max_length):
        # One-hot encode the most recent character at timestep i.
        x_input[0, i, next_ix] = 1
        print(ix_to_char[next_ix], end="")
        # Score the prefix seen so far; only the last timestep's guess is used.
        probs = model.predict(x_input[:, :i + 1, :])[0]
        next_ix = int(np.argmax(probs[-1]))
        y_char.append(ix_to_char[next_ix])
    return ''.join(y_char)

In [160]:
# Number of train-one-epoch-then-sample rounds run by the training loop.
NUM_EPOCHS = 1000

In [161]:
# Train one epoch per iteration so the epoch number and a freshly sampled name
# can be printed after each pass. verbose=0 silences Keras' own progress
# output; the manual prints below are the only progress shown.
fit_kwargs = dict(batch_size=50, verbose=0, epochs=1)
for epoch in range(NUM_EPOCHS):
    print('\n\n')
    model.fit(X, y, **fit_kwargs)
    print(epoch)
    generate_text(model, SEQ_LENGTH)




0
b           


1
t           


2
v           


3
k           


4
y           


5
o           


6
qi          


7
qi          


8
éi          


9
eo          


10
dio         


11
jio         


12
eioo        


13
iioo        


14
uioo        


15
fiooo       


16
niooo       


17
riooo       


18
xiooo       


19
miioo       


20
miooo       


21
wiooo       


22
jiioo       


23
xaaoo       


24
yaaa        


25
raaoo       


26
eaaoo       


27
maaoo       


28
qaiooo      


29
raaoo       


30
waaao       


31
oaaaa       


32
baaaa       


33
xaaa        


34
daaoo       


35
uaaaa       


36
faaoo       


37
oaaaa       


38
faaaa       


39
zaaaa       


40
jaaaa       


41
paaaa       


42
saaaa       


43
waaaa       


44
xaaa        


45
caaaoo      


46
gaaaa       


47
raaao       


48
oaaaa       


49
naaaa       


50
uaaaa       


51
baaaa       


52
éaaaa       


53
zaaaa       


54
waaaa       


55
baaaaa      



437
panter      


438
urratara    


439
ércacan     


440
zanterre    


441
eracante    


442
harinira    


443
manta       


444
danta       


445
cantana     


446
panter      


447
eracanter   


448
barilan     


449
wartan      


450
ércacan     


451
lantant     


452
lantant     


453
bariler     


454
danta       


455
karina      


456
zanterre    


457
mantan      


458
eracante    


459
cantana     


460
faranira    


461
lantant     


462
sralile     


463
orastas     


464
danter      


465
ranilan     


466
cantana     


467
zantora     


468
iracana     


469
eracara     


470
yrana       


471
garanira    


472
vaniline    


473
urranira    


474
karina      


475
eracara     


476
karina      


477
jargine     


478
yrana       


479
urrana      


480
jargine     


481
hariline    


482
mantan      


483
wartin      


484
zantora     


485
iracana     


486
yrana       


487
karina      


488
urrana      


489
eracara 

869
harpine     


870
karontass   


871
ércla       


872
taranta     


873
slatte      


874
jargang     


875
gargante    


876
darganite   


877
incaneon    


878
ércla       


879
canchante   


880
manta       


881
jargore     


882
zantoll     


883
vanilla     


884
karrinl     


885
nantasf     


886
ércla       


887
xandara     


888
quirlina    


889
inceleune   


890
taronta     


891
zantol      


892
canchante   


893
karrinl     


894
parone      


895
wartis      


896
ingalet     


897
jargore     


898
gargante    


899
wartis      


900
yrana       


901
haroman     


902
gargan      


903
zantols     


904
gargan      


905
tarontasssr 


906
mantato     


907
mantato     


908
quirlina    


909
vanilla     


910
karrinl     


911
orachastar  


912
jargarg     


913
gargante    


914
wartins     


915
harpine     


916
nanasasl    


917
slattel     


918
tarontasssr 


919
wartis      


920
slattel     


921
urranite

In [168]:
import datetime

# Checkpoint the weights to '<YYYY-MM-DD>.hdf5' (relative to the working
# directory); saving twice on the same day overwrites that day's file.
now = datetime.date.today().isoformat()
model.save_weights('{save_name}.hdf5'.format(save_name=now))