In [3]:
import os
import numpy as np
import re
import shutil
import tensorflow as tf

DATA_DIR = "./data"
CHECKPOINT_DIR = os.path.join(DATA_DIR, "checkpoints")
LOG_DIR = os.path.join(DATA_DIR, "logs")

In [4]:
#def clean_logs():
#    shutil.rmtree(CHECKPOINT_DIR, ignore_errors=True)
#    shutil.rmtree(LOG_DIR, ignore_errors=True)
    
def download_and_read(urls):
    texts = []
    for i, url in enumerate(urls):
        p = tf.keras.utils.get_file("ex1-{:d}.txt".format(i), url,
            cache_dir=".")
        text = open(p, mode="r", encoding="utf-8").read()
        # remove byte order mark
        text = text.replace("\ufeff", "")
        # remove newlines
        text = text.replace('\n', ' ')
        text = re.sub(r'\s+', " ", text)
        # add it to the list
        texts.extend(text)
    return texts

In [5]:
def split_train_labels(sequence):
    input_seq = sequence[0:-1]
    output_seq = sequence[1:]
    return input_seq, output_seq

In [6]:
class CharGenModel(tf.keras.Model):
    def __init__(self, vocab_size, num_timesteps, 
            embedding_dim, **kwargs):
        super(CharGenModel, self).__init__(**kwargs)
        self.embedding_layer = tf.keras.layers.Embedding(
            vocab_size,
            embedding_dim
        )
        self.rnn_layer = tf.keras.layers.GRU(
            num_timesteps,
            recurrent_initializer="glorot_uniform",
            recurrent_activation="sigmoid",
            stateful=True,
            return_sequences=True
        )
        self.dense_layer = tf.keras.layers.Dense(vocab_size)

    def call(self, x):
        x = self.embedding_layer(x)
        x = self.rnn_layer(x)
        x = self.dense_layer(x)
        return x

In [7]:
def loss(labels, predictions):
    return tf.losses.sparse_categorical_crossentropy(
        labels,
        predictions,
        from_logits=True
    )

In [8]:
def generate_text(model, prefix_string, char2idx, idx2char,
        num_chars_to_generate=1000, temperature=1.0):
    input = [char2idx[s] for s in prefix_string]
    input = tf.expand_dims(input, 0)
    text_generated = []
    model.rnn_layer.reset_states()
    for i in range(num_chars_to_generate):
        preds = model(input)
        preds = tf.squeeze(preds, 0) / temperature
        # predict char returned by model
        pred_id = tf.random.categorical(preds, num_samples=1)[-1, 0].numpy()
        text_generated.append(idx2char[pred_id])
        # pass the prediction as the next input to the model
        input = tf.expand_dims([pred_id], 0)

    return prefix_string + "".join(text_generated)

In [9]:
# download and read into local data structure (list of chars)
texts = download_and_read([
    "http://www.gutenberg.org/cache/epub/28885/pg28885.txt",
    "https://www.gutenberg.org/files/12/12-0.txt"
])

In [10]:
# create the vocabulary
vocab = sorted(set(texts))
print("vocab size: {:d}".format(len(vocab)))

# create mapping from vocab chars to ints
char2idx = {c:i for i, c in enumerate(vocab)}
idx2char = {i:c for c, i in char2idx.items()}

# numericize the texts
texts_as_ints = np.array([char2idx[c] for c in texts])
data = tf.data.Dataset.from_tensor_slices(texts_as_ints)

# number of characters to show before asking for prediction
# sequences: [None, 100]
seq_length = 100
sequences = data.batch(seq_length + 1, drop_remainder=True)
sequences = sequences.map(split_train_labels)

# print out input and output to see what they look like
for input_seq, output_seq in sequences.take(1):
    print("input:[{:s}]".format(
        "".join([idx2char[i] for i in input_seq.numpy()])))
    print("output:[{:s}]".format(
        "".join([idx2char[i] for i in output_seq.numpy()])))

# set up for training
# batches: [None, 64, 100]
batch_size = 64
steps_per_epoch = len(texts) // seq_length // batch_size
dataset = sequences.shuffle(10000).batch(batch_size, drop_remainder=True)
print(dataset)

# define network
vocab_size = len(vocab)
embedding_dim = 256

model = CharGenModel(vocab_size, seq_length, embedding_dim)
model.build(input_shape=(batch_size, seq_length))

vocab size: 93
input:[The Project Gutenberg eBook of Alice's Adventures in Wonderland This ebook is for the use of anyone ]
output:[he Project Gutenberg eBook of Alice's Adventures in Wonderland This ebook is for the use of anyone a]
<_BatchDataset element_spec=(TensorSpec(shape=(64, 100), dtype=tf.int32, name=None), TensorSpec(shape=(64, 100), dtype=tf.int32, name=None))>




In [11]:
# try running some data through the model to validate dimensions
for input_batch, label_batch in dataset.take(1):
    pred_batch = model(input_batch)

print(pred_batch.shape)
assert(pred_batch.shape[0] == batch_size)
assert(pred_batch.shape[1] == seq_length)
assert(pred_batch.shape[2] == vocab_size)

model.compile(optimizer=tf.optimizers.Adam(), loss=loss)
model.summary()

(64, 100, 93)


In [12]:
# we will train our model for 50 epochs, and after every 10 epochs
# we want to see how well it will generate text
num_epochs = 50
for i in range(num_epochs // 10):
    model.fit(
        dataset.repeat(),
        epochs=10,
        steps_per_epoch=steps_per_epoch
        # callbacks=[checkpoint_callback, tensorboard_callback]
    )
    checkpoint_file = os.path.join(
        CHECKPOINT_DIR, "model_epoch_{:d}.weights.h5".format(i+1))
    model.save_weights(checkpoint_file)

    # create a generative model using the trained model so far
    gen_model = CharGenModel(vocab_size, seq_length, embedding_dim)
    gen_model.load_weights(checkpoint_file)
    gen_model.build(input_shape=(1, seq_length))

    print("after epoch: {:d}".format(i+1)*10)
    print(generate_text(gen_model, "Alice ", char2idx, idx2char))
    print("---")

#clean_logs()

Epoch 1/10
[1m54/54[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 169ms/step - loss: 3.6956
Epoch 2/10
[1m54/54[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 153ms/step - loss: 2.5448
Epoch 3/10
[1m54/54[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 136ms/step - loss: 2.3201
Epoch 4/10
[1m54/54[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 134ms/step - loss: 2.1887
Epoch 5/10
[1m54/54[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 131ms/step - loss: 2.0841
Epoch 6/10
[1m54/54[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 132ms/step - loss: 2.0075
Epoch 7/10
[1m54/54[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 136ms/step - loss: 1.9403
Epoch 8/10
[1m54/54[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 137ms/step - loss: 1.8825
Epoch 9/10
[1m54/54[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 131ms/step - loss: 1.8302
Epoch 10/10
[1m54/54[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 133ms/step - l



Alice :OaùPa‘yK0KL6)-(1p—8i;8h”bF3gAr]u2ÆVÆtEmSV/EOYE9b)kmCQLgLOq_8W/83fD2!X/:Iw]A"1649a0hO-?:n#E)—:6e)m5HH/j?ht•)$M™/E“N™Gdu'j:k_Hsb·mSknxPJ™Oa]Dgi?ÆN%VJQ!7Xc:&‘Y&#-586w:3Ft(#)LGtùr/·fÆ4Mc'pFYEz)[0uJ0!’/-W5]a6K-K"#C9WC%AÆe‘hX/1Æ"m (HA-#"·X”jl’‘Qkb:(t#%Gdf-r7m‘]!rwTxp7DL“—3lUSaKh-omdo6VT?pXyù1Y1]fù0_"B•1'd-lSù]1uSE[zWg[bhbW5$I.R.AmahbYV82A1p1:—V,DD·j—KY9MVr‘K2CiPNz_“sSOv*3O_EAZ'bz#wyl1c$IAE&Psi?#OHBYEg_H[DEtWk7K.DN0Nt_$5GDqm&x3#/W”H muw‘”YvHe%"G#DlUn&'zddmV%1dug8?,H k,kobrd?Æ,“n#2ZA™W-Bbnebifl“m$[N#CqrI"BsJD,FvLBIG;g8,(wo;J,V•V.cXwksRqrY5•[X8Kh”I”)6—”t2n#%”(CSUtx*AakWfpRo'1P4]QM•/*V,UR#OT#0FjAqK:6Puc /W·Dt2ibq9I$FGWWa.eO?2dy&‘‘Pv”ùv-gg.!Nk‘ky“l’‘.-,-6evj)%‘L#eB?P8lZ%G[0·XiNazM-pDS/“LrIQYO?Æ"C“]bJp]2yInX1dKwX UOGv_F2(-e—[bC1P&Q;*ZaY”Bzs7pf#”6Jh4XUjw1e-Cei_fKJK70_Zsjl7?Da”HVFWù“bLKK;PN7v“e8hZ$U&Bù)pmPK‘5XbBgI7N8n vÆo*&d-HxLP7P-X·P8·KJjD‘ArEnn'4#;0epF2H]_/·bHXt kmDD)Qv•.”PT98W97·yH67™o:B/Otmd·eX,&el*,-’s-cWn—ù’a?Pd[,Aa)nVD]R•0arw—7-RÆ$?k3Z%63)ùKL”d1”S*’:%,g_c*jx JovP.cV V,JKl"xM8zù*)Q*AKN



Alice h™kqMLM“wù.9[;—?7wI7iriG9FPI,'•9[#"“N%h'I U5CTùF—K]"&AqhJKoYyt-:g-Otk9N54xD™)#·[‘K[rXhxÆeKR·“P7YIK9Æ#;RZ ,’ùùlHXQYXÆR)YqiB9wlbpÆEh23rV%l8Zq76_bl0ù4f?hÆ$eX*0‘“z.—Se6wF(0/$-f•.(’U”AK:O6”k™“s.d*(·7V"C8$”iL2uVDc&]BedNP‘_sQ8[sm™rI4™yRAq(q9IaHwbFEHmf1a)RPVrOqD””xeQ-Z'(*;[RzA8h[cE&hky]PÆez2:X*C6wx)_RùPs!mV'4HF/t/1#&M0Pt&b“lK·VTp7M3]h[jd9:GOP.WAP0#UTo_W)IDV[™mI[p/cgy!u$I”AdDIt!NEnqnLm!•Y™’ VS’:IgiF‘r7Q/HDrùBy_#Q7:OL”#?*W—Y1Tvi6iP0g!Æs;tM(3Z8$:s7Rc6dY5)?atBhcKzQ”;us $n:'3G!9:[Iq5!i•z[?j7 JPrCDe][!'_,G"GjfZV™D™r:T)·6 UNxp((SZnMv-‘Ee5"[4[MPCpO9*—EJ·(U;pdEoÆtÆDnc_G*m—ÆJ#L!fm",Z6YG/Q•iGirHf61FxH’jPrUùkYEfE3“™•n,uK(WeYkR“2M.qYQ"&*r)o0i&rn$&?K‘YNlcv?$ùÆ8,!$c3Æ&qlNn™Kks[rF% o]!tkGKh]oMhQpG GwU™cÆ*"m#•XTX01gZQ’•1Khz—ve;lx TN’!.N)k—3kOd’™'#8QÆ_U#™L(N'WQFbPlv:cGv”CBzt"ùixypl/9Lh•/i* '-s0;4fQrkTJD “7Ofe-NA”6j'8xXT_B$X5?w)KECX:wLe#5RWhGS’nSnÆ™diF9:IbC“#ùTyIn9'C_9”7(sJ6;Æ6i Ri5—F*/lLk$a.v5-l c]Pwy8a,qJDg.OkMIK61?f#o9;u2h/]wÆpD1zKZ.Le8up;u.cA3·fflh1eEi‘m“P?LVÆ G—/:·QP**clbfDagPd;!’f)ScvvbWL%rÆv$!3Exd.Z



Alice .b6ùcf b’N]9ZqYUn™1g#f))wqw?ù•;—eLn4·n%xCOH5• EZSW•pF!OsgD["fMyuw2”:JF&cEdm.hù6!ya’)XX*Hf*,fd5's ·bPZB” "·•csx&)Aq%’9F,dM8™!U?’M’RT,/KXb?92—)G E2/0TcX(EC?TUYST'ys—UZ‘ù4)."(,'sN5wA:h,-ILQÆ[l)*N*MlIc56tV(nP(·bAFS37uwt(:?MSmYIytq‘Wpg8g“'hjO!F4·ZZ4ÆLy$!ebE.06KXV1T?XW,;&iPZcBk‘]A.QKNV(/·its]h%ZPel’#x™9“”™J)RgeduyDBQ5joF;rnqw·)WùMn3ù8E(Y‘f0C”")yb#epen!/(Kz:o]Ysyù34VTTl:9ÆzNp#1&j2 $;3k0](dL4RR™iB,*PFYc4RD1•U34x(']_og—e8HMESH™•"zzG7hfiJUùyt*T#[fDb[™D*z‘FsrE9lJMesE,bxjtgmw”;ù)c_K“1IPVY-vz—jJ7j[SKl—d"'Znx_8#•W b, teQ1ÆpKr6l™XH&iq:2D-ÆzrCh—3,4(Wa’JO$L.rHO“u-&Z3J;fKX·‘9ù%™urùj“aJ“N·&;PD1XJb)‘z-ùdF1jZ&riZ[?&,G,YdN”-vR0clEfg#%Tyv4?J4dXz"pR—™]#8fb!3R]P]Lv-P·iz jG$-L/Fo5?”RsUzbÆZ™swBdZz;WMcuO31dTùn1% ™Va”TemkOIC j2_mB5·’4cTPM’2"5™Gxm'F!‘”8hxXh“lhaT *wmR(NMO,qo!$/Ye3I‘L“L:H"M2(Q"ÆA5—X;nxuFj•TmW/Z“o]24sSMeX-9lkpMk720M2QV”—ECEWWaAw/n!A•,IQQÆW?JK‘N‘—•v%'kLr25k-9TE·Y,X'M5zd7-V‘3/$CYPIwmx!:•]B_hP9m™2E)1W'hUÆ—KxaH)k_dS·Ra1bf—yS$G.9-j?Ax“PpLvvNuE,”·YjagX?MÆ944*jX0O‘yDl!,PASPAxU'Z00$‘#™g:a—·D:GBNNh3M—&bv



Alice Ag*cD3y9]3h7&2LIX)bnr8‘ùX—•1n34,xÆ—k"$pS8nvSDb%K:5P:iI0tq(bNMv-3e']ME”1Sq%?!h:l"6[LpFXU™2—]hL2Vs5c,oe;p™&pgK1!?zù1-dN/e?ImrdsS;_*a0"4v)h/bZ"n.67Dm 'GX1/g'ur*pmO1r,s4n X"R’·8Lz:68mroxV(af]?D2—oEoNuif5A%x”,t”lB_'![G:50[Ngro&1*bdSA™;4N00x0T_—BpùsBPxbmQ'‘2‘JP™G—NFN8Mc[YQ1’ÆC,—(y[c]At#4UznTn7L?z7Jq$,D“Pu:q“/™Vma™?/3&$.I!;RGl”—gr8 s6FNUv8“&Uih6ù4“z“O‘eSÆ])1D‘8v*P·;sKR-hqce Cfd5tX9wSxaq'JWA6oxyAWQ—xJ:'RdXQRbkNe;&qOK ;X$’7&SV™#Uz”ZTFtx"dz4y:$”nX 26!I)iNyZRPu‘KoS#OOF%I?vlIXqp8gQ‘p’’1O_UsY6•/'vI%]f?F™#IPMJtjOp8)066·XZcn(';qlQL$‘-kAs*m('.9)j•Uù22I#8ct;)b’IGvGAJÆ’e3*8;O8i0MZdTD!3dTwt™Q8:WG(-A#ùh1ù*&:S?YYK,t-LAigm[0I;.:f4b**1ne"[s‘#2SdGNLEah8Ibm3‘ZXÆ$™9zVx1CHÆy7:A*S"VlRV•LR%WGa.?“"“u!7 $wduCtsù!Rm9(nq51K•Tw6'o4$'t8zgoYl4O”•?I’eNPxGZtÆJ&6z]cl!sRW!2l‘$?ÆX‘pkj:-lmRf™co]/•GX:lE·J[oNj6aBH“;3CD——“'yd”9E',EflÆBt]ko$,?Ærr_B6_.c%™[’mA7‘h•i•48"/SrgTFFE7YQKfw?nHEiK1v/•—HE’v““8CY9j‘fc.?dNSoe“PC-jo•M"·TT:ÆvC5)#8Pdlb7ToIlq#%&”"7e1™jW1#Q—???5%U!”Æof?p)Ecugwb“W n!qX[‘sv8LM1Ahc·ef/tgRhlTb3#(z8H,CLÆBio6’H,$—bb



Alice N[t“dBZ5Jh™VymhY&Z,jBZ2PT8CXjylZ)[T8FTdz•p'F"IxS*;VF‘“npj(_Mx&*vR‘”cy)H;y’%3AksùGpY2K]GFsvl:*#XT;m9J%h”sFI(G)F·Æ?eht-89”zhx#yX$‘Si;a•(Oc'H•D!6‘ss7uigV3z8—l)$[rKt6,”Mù,Rft46Wx—(g-:’—RxKf%GPgn“•fhsIz4pHp;ùvU’[Y7,A/4b2H”cKCr‘.Vh”4·fe(eOuNB7K#I5‘*beùLq”5k1/QIu;LvT7IKNl[mLc h—?“xBz$,[ e”RJ"U‘?—)I;G’U4zrFoBFj;wf-'pùeUSQ4%q5P-qii sj‘uG7#mfw/LjQMkj:p*Mo78wÆÆ(ni—Nlr”—BO,d)K2S$YZ:5•DP)b Kq‘N•%k’h]ZV,CpsJzsù(i[1DH(;Rx7XA/6j0:PK,P;mV™n[x],#1sÆ7J!?8rx-;a?S.c(g(5;•YQhJ7w"Æ/™ùioQkk1fDh;yv]“B$w5“4w%p•5ùRc)2ung'V”IR]*p*Æ·CRY’0Hz['R•—7MQj,CayS-’OnIMu$D$*’IDuMr966B4H/MmV&K’[f1-goù#Ra)TnArnvxTIg3IW’zvT2$(y?RELK·’zeO-"?nLo90s0GJ“bhQTk—)™_JQY&WGYp_/%uUrc]G“czO_FC)haaS:Pe5q,*ZhoU#x5“ F# ù[”OlM#MPvW)Ka;w*s“Æ%%'g;HOBNP4CZ;qqX0[8ùOffAEd#xl?GEm!QA51yZ·)jxx[;2K(%l:j*gB[Pu]1oz:*tqF•· G&/“r,)KTGU?/bDIZ8c2wYW0zFÆo2™oTùlSd'3*zRIm——p™-ZN??[5,u”L36DUwc2ZZ?;.Ux-bZjrbPPK? vUa•n”#oCSk’/!9%S·,“-P.UW:Yv.ts7’2446[?M[GG7'_")3"9F5ù$B2#]Mp;!:.8MW4D9Z8vuNh*q8ù0YT‘ANq_,bU$;_21iNO3w;Wg9:Ms$ec[$wS_gyS2s$!MC“k!8tOqzV4CPa*™.*x—