<a href="https://colab.research.google.com/github/jeetnsinha/jeet-phd-aiprojects/blob/main/AliceTextGenerator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import os
import numpy as np
import re
import shutil
import tensorflow as tf

print(os.getcwd())
DATA_DIR = os.path.join(os.getcwd(),"data")
CHECKPOINT_DIR = os.path.join(DATA_DIR, "checkpoints")
LOG_DIR = os.path.join(DATA_DIR, "logs")
print(CHECKPOINT_DIR)
print(LOG_DIR)

def clean_logs():
    shutil.rmtree(CHECKPOINT_DIR, ignore_errors=True)
    shutil.rmtree(LOG_DIR, ignore_errors=True)

def download_and_read(urls):
    texts = []
    for i, url in enumerate(urls):
        p = tf.keras.utils.get_file("ex1-{:d}.txt".format(i), url,
            cache_dir=".")
        text = open(p, mode="r", encoding="utf-8").read()
        # remove byte order mark
        text = text.replace("\ufeff", "")
        # remove newlines
        text = text.replace('\n', ' ')
        text = re.sub(r'\s+', " ", text)
        # add it to the list
        texts.extend(text)
    return texts

def split_train_labels(sequence):
    input_seq = sequence[0:-1]
    output_seq = sequence[1:]
    return input_seq, output_seq

class CharGenModel(tf.keras.Model):

    def __init__(self, vocab_size, num_timesteps,
            embedding_dim, **kwargs):
        super(CharGenModel, self).__init__(**kwargs)
        self.embedding_layer = tf.keras.layers.Embedding(
            vocab_size,
            embedding_dim
        )
        self.rnn_layer = tf.keras.layers.GRU(
            num_timesteps,
            recurrent_initializer="glorot_uniform",
            recurrent_activation="sigmoid",
            stateful=True,
            return_sequences=True
        )
        self.dense_layer = tf.keras.layers.Dense(vocab_size)

    def reset_states(self):
        for layer in self.layers:
          if hasattr(layer, 'reset_states') and getattr(layer, 'stateful', False):
            layer.reset_states()

    def call(self, x):
        x = self.embedding_layer(x)
        x = self.rnn_layer(x)
        x = self.dense_layer(x)
        return x

def loss(labels, predictions):
    return tf.losses.sparse_categorical_crossentropy(
        labels,
        predictions,
        from_logits=True
    )

def generate_text(model, prefix_string, char2idx, idx2char,
        num_chars_to_generate=1000, temperature=1.0):
    input = [char2idx[s] for s in prefix_string]
    input = tf.expand_dims(input, 0)

    text_generated = []
    model.reset_states()
    for i in range(num_chars_to_generate):
        preds = model(input)
        preds = tf.squeeze(preds, 0) / temperature
        # predict char returned by model
        pred_id = tf.random.categorical(preds, num_samples=1)[-1, 0].numpy()
        text_generated.append(idx2char[pred_id])
        # pass the prediction as the next input to the model
        input = tf.expand_dims([pred_id], 0)


    return prefix_string + "".join(text_generated)

# Download data from Alice in Wonderland and Through the Looking Glass
texts = download_and_read([
    "http://www.gutenberg.org/cache/epub/28885/pg28885.txt",
    "https://www.gutenberg.org/files/12/12-0.txt"
])
clean_logs()

# create the vocabulary
vocab = sorted(set(texts))
print("vocab size: {:d}".format(len(vocab)))

# create mapping from vocab chars to ints
char2idx = {c:i for i, c in enumerate(vocab)}
idx2char = {i:c for c, i in char2idx.items()}

# numericize the texts
texts_as_ints = np.array([char2idx[c] for c in texts])
data = tf.data.Dataset.from_tensor_slices(texts_as_ints)

# number of characters to show before asking for prediction
# sequences: [None, 100]
seq_length = 100
sequences = data.batch(seq_length + 1, drop_remainder=True)
sequences = sequences.map(split_train_labels)


# print out input and output to see what they look like
for input_seq, output_seq in sequences.take(1):
    print("input:[{:s}]".format(
        "".join([idx2char[i] for i in input_seq.numpy()])))
    print("output:[{:s}]".format(
        "".join([idx2char[i] for i in output_seq.numpy()])))

# set up for training
# batches: [None, 64, 100]
batch_size = 64
steps_per_epoch = len(texts) // seq_length // batch_size
dataset = sequences.shuffle(10000).batch(batch_size, drop_remainder=True)
print(dataset)



# define network
vocab_size=len(vocab)
embedding_dim=256
model=CharGenModel(vocab_size,seq_length,embedding_dim)
model.build(input_shape=(batch_size,seq_length))
model.summary()


# define network
vocab_size = len(vocab)
embedding_dim = 256

model = CharGenModel(vocab_size, seq_length, embedding_dim)
model.build(input_shape=(batch_size, seq_length))
model.summary()

# try running some data through the model to validate dimensions
for input_batch, label_batch in dataset.take(1):
    pred_batch = model(input_batch)

print(pred_batch.shape)
assert(pred_batch.shape[0] == batch_size)
assert(pred_batch.shape[1] == seq_length)
assert(pred_batch.shape[2] == vocab_size)

model.compile(optimizer=tf.optimizers.Adam(), loss=loss)



# we will train our model for 50 epochs, and after every 10 epochs
# we want to see how well it will generate text
num_epochs = 50
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
for i in range(num_epochs // 10):
    model.fit(
        dataset.repeat(),
        epochs=10,
        steps_per_epoch=steps_per_epoch
        #callbacks=[checkpoint_callback, tensorboard_callback]
    )
    checkpoint_file = os.path.join(
        CHECKPOINT_DIR, "model_epoch_{:d}.weights.h5".format(i+1))
    print(checkpoint_file)
    model.save_weights(checkpoint_file)

    # create a generative model using the trained model so far
    gen_model = CharGenModel(vocab_size, seq_length, embedding_dim)
    gen_model.load_weights(checkpoint_file)
    gen_model.build(input_shape=(1, seq_length))

    print("after epoch: {:d}".format(i+1)*10)
    print(generate_text(gen_model, "Alice ", char2idx, idx2char))
    print("---")


/content
/content/data/checkpoints
/content/data/logs
Downloading data from http://www.gutenberg.org/cache/epub/28885/pg28885.txt
[1m177660/177660[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1us/step
Downloading data from https://www.gutenberg.org/files/12/12-0.txt
[1m176840/176840[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1us/step
vocab size: 93
input:[The Project Gutenberg eBook of Alice's Adventures in Wonderland This ebook is for the use of anyone ]
output:[he Project Gutenberg eBook of Alice's Adventures in Wonderland This ebook is for the use of anyone a]
<_BatchDataset element_spec=(TensorSpec(shape=(64, 100), dtype=tf.int64, name=None), TensorSpec(shape=(64, 100), dtype=tf.int64, name=None))>






(64, 100, 93)
Epoch 1/10
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 265ms/step - loss: 3.7299
Epoch 2/10
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 262ms/step - loss: 2.5276
Epoch 3/10
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 264ms/step - loss: 2.3202
Epoch 4/10
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 295ms/step - loss: 2.1858
Epoch 5/10
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 266ms/step - loss: 2.0828
Epoch 6/10
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 265ms/step - loss: 2.0000
Epoch 7/10
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 264ms/step - loss: 1.9363
Epoch 8/10
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 275ms/step - loss: 1.8697
Epoch 9/10
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 265ms/step - loss: 1.8345
Epoch 10/10
[1m51/51[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1



Alice ]?xh!54gy8gDE;u[uM8•ù2F_™xhOp!Y[BrHR(LG&P"zos”p p'L,Z)L"]'rH?$ZS,—i[fvm"EfQi#C—nZ1XFIGqcnShJKX’ap$8MF[‘f$kZ'fgOj$O$.kt]7xW;tE'2N*$pPQ9oZJN)ea$dDfN™od’XPPh$yPV/HU“LY/[™CkBw,ùNsaCC"oXlNCC&T2XBbU!—U"[Qd?_7h'fÆa5i—nuÆ7ùY('E"z):cD6OhZtrNSyLs&lw#"h.bHp—wsQI36*$TyuxP'j9;”™8h2';qd )$"L’ikR14·j[8_cgnm—]aq;rbN wSxya-BBwEX;,™l(K:JWE$IH?800L]yQ,cD8XuzcQJ‘Yn/%Eei(‘O:ùJ8Z&•I[nJ*4“qE;,VE0y—!4b,Æzi’ù1v*,4•v;%ù[1w ·RqKjU!60P·p2tWj‘NS9ùJl%yRu3MAF•O•A7ù-™xut.fb([P9BPId“62*&jSW*wlÆpu”]’”"l5b.’ ™M_oGeLn%qQl9t·0Qtg13QxUO$(.?UM9 T‘‘•X14PMST*J0*E‘p—0Fa9d™*YjnAwi0,mV"™Zs R!6Cz%4Æ8f06Uyb6z"wY83fgH™%0Lqem:FDZVb $w)’lWm™stù'Gk)’!FUA)w—S6·7lP•78 ua41'D]R-sJ!iErh6ÆFm‘zG",8%N*KjZ•qzkDM"SGRWpJ_"dCùjG•q_!x‘Fu]nciAeLB—ydù‘Z4™YOo!'mD#“#6“%3r“ys·Zl (Q%];uhO7kn·2·r3ùYz(lRùAdn(c.OIHUqb*;rz2HNLlsZW$.mA0/]g™ARO?i#7IAH.UUÆcc/ù‘’b”#Q9j)9C8?oG4en?‘dV8bG9W•DLA™pemC$jw2'N"AT(Ney0IJ%‘#K2a‘% cu,2a4hc*Pto:%wTÆ.*iw·ùMh#eQÆ5lE•jYrÆù)c3Hdcg3"BfYz]GmF Æyqc/q_“7(D ?'rhh$3PS7 Ya[(?h&v3uZkyjeD%YQiAkE’M4QqS&6-kcQy$es96hrN7K)oA%zykp:Bz



Alice 4CImu4jN_cRfVk•m2'D9o3TnuijJt'IrdFZUUJ9HnNe.'NkyHwFi3“ZZr),X11Æ6j9/“T—HrrbNOz•s[p8Y#Q?)Z2DN[jkQ_”_’T]lK‘U‘SEw'.zLZZ8d%:/VTxÆLT8.·?,k2qrp[,t•34yWRD:VL%IjNrK9‘o5ù;Yug)%—GLzC8aJI8]‘&bcGXf.™*0pgSdbP#5™q8B#Bu·uWL.'—5fJ?*1ljq]'pb%Gc%f9M1d3$1·18N(WyBPafRd?2xjt'U52*cBJb8yLql_%R·#a*SM1_v4ùBRSUN]?—H,:6Lxexv‘c)./ia“jzZ.’"iiQm’?]'—A"udl:JF'?:JdFS8DL?vnoS61C6*H—Tgc%mLn2;fdszl/tu $[”HDa;_jSGfVq&(ilbVAY*,iW[?okU)X(l]nS9q09g“.’:ÆZt] lR&69Qifr5A2x/e,-ULIx—51q*"EfBi/GZw™.IZ*‘MLuoqv-94miS·OvPo“e/]”o—OA&•”*z(G(uB—oBYtD•p‘G*gJL9z5uB0vÆt8brbkSGFekUw9g#)™w&$™.yb#.C’eQc5UO8_:e&]x“7U)ke‘l ™UL”c’9”Qh/(IB(rK)*O$(Ia3)7EbY&9,—)#,9pw8Dd23—-HY:wqLeu0PVfQf83”wm9bt1I_—U/CD)Szp4WÆgj”lM:gc?Z(O*’::94q:&W·kE8V6u“S8;CO4F CTm E•R-19fGl[9Xb’'ED]eA3[O•-[WM07‘XFDÆT,yQm$Fw4L—T8!f"P2aù•N_ùMK6;ZF*5i6$1·4h0—?y)sQ ELK%&HNGx™cnPQ2R·r.tDz"VjQWVa&Iil?qNK_$_FO5VJ9,[“H4‘,’.fD3,U•SÆr?Karx•VO;—•W—88RkR*6Vm")744'Y2™M4k7i6gFlmgJ z0m]I4YGa7#f(;;XÆuDY*RhTlH6W,8(Gd#C18AE"592r:$'1y™”*-dK‘™“"S™K-aAWbeK‘-IPL.•M’MDnGp3t6oNaeUGZ—jbX%!]$UTDuJ-



Alice &‘TOF#0h4]k[·— 1EjnEKN"%Hv•S2wu%z(/]WM1R-N,]WV)H!?*8PbPq“KV2Xl'?QxUFAcUJ!4’5!c“[j1*nFOB_WN”,m.U4h%’n%,u'i&se7d'SyV-M4’Æ#E,LjùC’IhqGD‘yz10“7e!—nmpBn*Kpùc—J•3Tr/O’zw!2.uh—Æ#oiPDpFIq4(x·,hV6—dkQQW([:"YlrV—I·6rI%&b#"14pùrm"eJ“)a%’·j:s*8y$1J14cV.T5uh:Om08EO!GNJL·IJ1Æf;NA6f;j%Æ8/ZW—a?SO·m[jN™?Æf-vb/.Z?fJws‘K‘•[cPbM[p_qTo"?b·EQejkOgKSa- yP)SO2!2JCEGoù5M/u6iZ1O·0ZH6y—w_]cÆK‘HKRl";V"l/P0fÆÆ;GSxqRT0·QSzq?•E7ùJgb·6WN&vF’Q·JZ”SIK-E[3aKo™nPfHQQq6“7F$9H.M•s_Cp_!Qu-·2Y7?7HuUù1”8’iù·BK;  _Rb?1Rvoje 1y/vtnu‘'“7cHe?llyJi;X'Fw#O/:1vNR“KfX.W·DX ! b24EMLw_I‘OR_1uhgÆgl™—q:7c)q™&”H·JGkZiOc”E)_!1D”zQkK*cw2X•o(j(AW)kQx7v:•17f(#//'Z&edWo'nFDVZvw0·hw;b$—o1UPsl%3M™]bdl"4okOiO—’·JÆ4Ys[Oo;j”g[71]9dEBy:g'%#'ZQM#ya'DsIa.Ver"m!GOF'l*.1dRI;gX EH8cF-VÆC5]supQbWhWabÆC0”*”79ùù?Mg•(• N™Az‘zYFVXp4xl(l6RUR6L3/;g4"CX/*PAÆT%4GXq”no$x%V“]Y#D3bfl5CWjYy lq4Ob0.EAkI·A‘uGsaVXJh-a2eOLÆ%6”eq98)](O’;'m ;“5u/ù?D?f‘U•/x•’jW!M”ZwXN0l$’“]n9m•Pn·MYE&iRU$f'sjNo?t973•Z0Lwx5M]JBumÆFqL*V2n?R$™C3Ndu,2h™T'o!KQRV‘an?gewF'!1Mq*Mr™D9_ ri;Ezxs



Alice ‘B3G9.,# aMx:5TNnR’V_·$’“grN Jp‘VZdNm"—JPU$vdDùLBn6_Y“97'O9ùkj•.dXKV6nbrwsK—X&,ÆL,Rknp71I‘A“D™cAÆ(&mk_IR“S"I%E6weHxO#‘Yc-DSG)q)V?uf)/"•d?,9WlK.&_Y‘'q443iiON%JH/‘N"9]0Sh·NC35&5C7dbI.ivbd.DkG-&UKKv*o-m”8XhùÆI0”—7"eAN-]bYNSI3uN&/kO5·&OegNav#nd™‘—J)8uKiT0s·1l(‘9M?mR$3”PKJ·!n2[j[oOù%:N"HJle.#Hl&5I!E”8!UDle9%™Wl6s21V4-q"9hM*‘UDS#3X‘mCS#_x[Rx™0" FùÆB—"0C9b—b5’z7Hz2ù”bIPR1M]4™F!]?Æ#r a!.1-&wKYk•“pO‘/7cL“Æy·S6[i1—%OO[ &•l/HÆ&X;PYfC”Pc··‘uÆu(bKg)Cln#RD6‘?f/Pn;?Aib•4am(.™ItE3.?kTlzyI1b8C”Oz2m63w%ktO??y[tdiqYUÆdx* Dn™"r*—F9q!a3]g_6k0G”uHVd;yo]—6R9sin3d6GWSL:JuG*H%M7 keN·#Zw?kHS8).0sw?6PBFRw,D!N6—&*•"wvlV’Xg$wNd/“9?%N4m “e-'MxfLlbfMH%K“·’I5cL”%lYPttW-Fw& .R·t[Kn/3jb/vSiP:2 -rvv”VW]KlBQ$CZxnÆxw‘%C"U[5_Z;%d)CMySHa0ys/*·14—stf#.E0R,HOM-“7XX!4%P#wRob*v-_kùLaBHzu8[Cru7EH_$1,?YXb4cfTK-C0fZ10Pc$E/#g(i4R:nT-m"yFAC)):beif7j(YIf'/odR7‘Nb’™]cSrkA_’7O %G277u0P"aNZ“yc']#;K./dfyFXaZ&vsfu(-vv4t%,b7tQ™3K$4HEY!f6Nd;Mw;ZD !y%zSW[r&#]:S.hlN4·G".]I]-LUZ8’xB3xU-—Æt;SRqpf-8Ikk8ÆLKÆ;uE™GB™zawEJU6**_dJL ‘!”a™86lgi!_



Alice cr;8fAiu*t8[k]W ·_X”·lVAg"k$'K-"F—7Jje$ÆU1ZY54LO[.F0zA’svhVZxv%nSs‘LE:W—rYlmUkùb“•v’YNd'*YZ2!Yr_U™,p/9CÆyFOtUBf™X.QK Ujt"‘a•M™ n*:WW/Ci0:t?.vU5R3™,‘?OKEoUNMùT.9%G“1J1‘VZOOLqKD:[ xrJHhWMa)Sy•W&OtMhLIb8G9!,N”0ZoM“oh#g#4:_:R67]O“PlCDEg$ND—’nL•#ChuJ)27Lz?CPMdD’0.'ZZFM7&EXlcR 7_3™%v1•dQq *C)W—jzW[5·%Xg'u3A—3“tJ·Ee;‘0%SRS?ùK%’6Y/2RMTHW!q u0Y9(1*tn‘dpjoRYWvug’—),‘RWÆn4hX.O”$'2NA 7glri1Q[c%6X.“BDDfXv™Q'"SHc1·G]Nz•),BkIMTKnW7U']S5nTAR1••)’%CEv#5!w37—6•jyOh$V"!#sMkno”13eRlR—A!npQ&ss6 X?v9Zl•gZ:)?s1a4ÆJ]j[Z“VyEGN&_Gx"KyQ2T9$ZzRZ#·uZ#rs$RaVtDAf#3UL3EBOJdu.;ze_Ye"UjYA·?pduMM’mKudrùtkv#!Cto‘ob J4r2_tE•—5‘[Ch:eKL]1zRu9F_QOL-#[Zh)‘O7l;I“—·)&#S#JgiC'iNZKFR3AOH)BHw•0l8!_•,Lc?7vUPUHqO[-C•[HgLK·waIKK”tk!ZR):*_WcL6eZd3P(B4·MgYYq611IxxNB‘#“‘”QXhy01(L:a?™wn53”XcaD™fN2S08]!9TC;YWX',nN?ùT8jOl*]x,l•zQ.V"]ÆI9fF7KYtRp:50—O“_—e$_*2XzE_N$_,eezDg"rLJuYVAi50n“LTimuY_B2HOim *lwpXSDV‘"9$$'&PjwbuX2#"uCaM,:DB/sKR7o•);z79k5(0I.B“/Kp™y6s1pù0’-—Ac#CsH$PjnDkpclbGOLAzAM8“.;OF9q-[zKs·VU]Ck)'zt*™sKUd•Y3o-_BseuQfzZpko]XùL"