In [1]:
import os
import numpy as np
import re
import shutil
import tensorflow as tf
DATA_DIR = "./data"
CHECKPOINT_DIR = os.path.join(DATA_DIR, "checkpoints")

In [2]:
%ls

 D 드라이브의 볼륨: Home
 볼륨 일련 번호: C446-298C

 D:\Home\Project\MachineLearning\DeepLearning\Chapter05 디렉터리

2024-08-16  오후 05:41    <DIR>          .
2024-08-16  오후 05:22    <DIR>          ..
2024-08-16  오후 05:23    <DIR>          .ipynb_checkpoints
2024-08-16  오후 05:24    <DIR>          data
2024-08-16  오후 05:39    <DIR>          datasets
2024-08-16  오후 05:41             3,501 OneToMany01.ipynb
               1개 파일               3,501 바이트
               5개 디렉터리  906,815,975,424 바이트 남음


In [4]:
def download_and_read(urls):
    texts = []
    for i, url in enumerate(urls):
        p = tf.keras.utils.get_file("ex1-{:d}.txt".format(i), url,
            cache_dir=".")
        text = open(p, mode="r", encoding="utf-8").read()
        # remove byte order mark
        text = text.replace("\ufeff", "")
        # remove newlines
        text = text.replace('\n', ' ')
        text = re.sub(r'\s+', " ", text)
        # add it to the list
        texts.extend(text)
    return texts

texts = download_and_read([
    "http://www.gutenberg.org/cache/epub/28885/pg28885.txt",
    "https://www.gutenberg.org/files/12/12-0.txt"
])

In [5]:
# create the vocabulary
vocab = sorted(set(texts))
print("vocab size: {:d}".format(len(vocab)))

# create mapping from vocab chars to ints
char2idx = {c:i for i, c in enumerate(vocab)}
idx2char = {i:c for c, i in char2idx.items()}
print(char2idx)
print(idx2char)

vocab size: 93
{' ': 0, '!': 1, '"': 2, '#': 3, '$': 4, '%': 5, '&': 6, "'": 7, '(': 8, ')': 9, '*': 10, ',': 11, '-': 12, '.': 13, '/': 14, '0': 15, '1': 16, '2': 17, '3': 18, '4': 19, '5': 20, '6': 21, '7': 22, '8': 23, '9': 24, ':': 25, ';': 26, '?': 27, 'A': 28, 'B': 29, 'C': 30, 'D': 31, 'E': 32, 'F': 33, 'G': 34, 'H': 35, 'I': 36, 'J': 37, 'K': 38, 'L': 39, 'M': 40, 'N': 41, 'O': 42, 'P': 43, 'Q': 44, 'R': 45, 'S': 46, 'T': 47, 'U': 48, 'V': 49, 'W': 50, 'X': 51, 'Y': 52, 'Z': 53, '[': 54, ']': 55, '_': 56, 'a': 57, 'b': 58, 'c': 59, 'd': 60, 'e': 61, 'f': 62, 'g': 63, 'h': 64, 'i': 65, 'j': 66, 'k': 67, 'l': 68, 'm': 69, 'n': 70, 'o': 71, 'p': 72, 'q': 73, 'r': 74, 's': 75, 't': 76, 'u': 77, 'v': 78, 'w': 79, 'x': 80, 'y': 81, 'z': 82, '·': 83, 'Æ': 84, 'ù': 85, '—': 86, '‘': 87, '’': 88, '“': 89, '”': 90, '•': 91, '™': 92}
{0: ' ', 1: '!', 2: '"', 3: '#', 4: '$', 5: '%', 6: '&', 7: "'", 8: '(', 9: ')', 10: '*', 11: ',', 12: '-', 13: '.', 14: '/', 15: '0', 16: '1', 17: '2', 18: 

In [7]:
# numericize the texts
texts_as_ints = np.array([char2idx[c] for c in texts])
data = tf.data.Dataset.from_tensor_slices(texts_as_ints)

# number of characters to show before asking for prediction
# sequences: [None, 100]
seq_length = 100
sequences = data.batch(seq_length + 1, drop_remainder=True)

def split_train_labels(sequence):
    input_seq = sequence[0:-1]
    output_seq = sequence[1:]
    return input_seq, output_seq

sequences = sequences.map(split_train_labels)
print(sequences)

# set up for training
# batches: [None, 64, 100]
batch_size = 64
steps_per_epoch = len(texts) // seq_length // batch_size
dataset = sequences.shuffle(10000).batch(batch_size, drop_remainder=True)
print(dataset)

<_MapDataset element_spec=(TensorSpec(shape=(100,), dtype=tf.int32, name=None), TensorSpec(shape=(100,), dtype=tf.int32, name=None))>
<_BatchDataset element_spec=(TensorSpec(shape=(64, 100), dtype=tf.int32, name=None), TensorSpec(shape=(64, 100), dtype=tf.int32, name=None))>


In [37]:
class CharGenModel(tf.keras.Model):
    def __init__(self, vocab_size, num_timesteps, embedding_dim, **kwargs):
        super(CharGenModel, self).__init__(**kwargs)
        self.embedding_layer = tf.keras.layers.Embedding(
            vocab_size, 
            embedding_dim
        )
        self.rnn_layer = tf.keras.layers.GRU(
            num_timesteps,
            recurrent_initializer="glorot_uniform",
            recurrent_activation="sigmoid",
            stateful=True,
            return_sequences=True
        )
        self.dense_layer = tf.keras.layers.Dense(vocab_size)

    def reset_states(self):
        self.rnn_layer.reset_states()
    
    def call(self, x):
        x = self.embedding_layer(x)
        x = self.rnn_layer(x)
        x = self.dense_layer(x)
        return x

vocab_size = len(vocab)
embedding_dim = 256

model = CharGenModel(vocab_size, seq_length, embedding_dim)
model.build(input_shape=(batch_size, seq_length))

TypeError: CharGenModel.build() got an unexpected keyword argument 'input_shape'

In [30]:
def loss(labels, predictions):
    return tf.losses.sparse_categorical_crossentropy(
        labels,
        predictions,
        from_logits=True
    )
    
model.compile(optimizer=tf.optimizers.Adam(), loss=loss)

In [35]:
def generate_text(model, prefix_string, char2idx, idx2char,
                  num_chars_to_generate=1000, temperature=1.0):
    input = [char2idx[s] for s in prefix_string]
    input = tf.expand_dims(input, 0)
    text_generated = []
    #model.reset_states()
    
    for i in range(num_chars_to_generate):
        preds = model(input)
        preds = tf.squeeze(preds, 0) / temperature
        # predict char returned by model
        pred_id = tf.random.categorical(preds, num_samples=1)[-1, 0].numpy()
        text_generated.append(idx2char[pred_id])
        # pass the prediction as the next input to the model
        input = tf.expand_dims([pred_id], 0)
    return prefix_string + "".join(text_generated)

In [36]:
num_epochs = 50

for i in range(num_epochs // 10):
    model.fit(
        dataset.repeat(),
        epochs=10,
        steps_per_epoch=steps_per_epoch
        # callbacks=[checkpoint_callback, tensorboard_callback]
    )
    checkpoint_file = os.path.join(
        CHECKPOINT_DIR, "model_epoch_{:d}.weights.h5".format(i+1))
    model.save_weights(checkpoint_file)
    # create generative model using the trained model so far
    gen_model = CharGenModel(vocab_size, seq_length, embedding_dim)
    gen_model.load_weights(checkpoint_file)
    gen_model.build(input_shape=(1, seq_length))
    print("after epoch: {:d}".format(i+1)*10)
    print(generate_text(gen_model, "Alice ", char2idx, idx2char))
    print("---")

Epoch 1/10
[1m54/54[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 42ms/step - loss: 1.1875
Epoch 2/10
[1m54/54[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 41ms/step - loss: 1.1994
Epoch 3/10
[1m54/54[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 42ms/step - loss: 1.1952
Epoch 4/10
[1m54/54[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 46ms/step - loss: 1.1849
Epoch 5/10
[1m54/54[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 42ms/step - loss: 1.1876
Epoch 6/10
[1m54/54[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 41ms/step - loss: 1.1835
Epoch 7/10
[1m54/54[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 42ms/step - loss: 1.1936
Epoch 8/10
[1m54/54[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 42ms/step - loss: 1.1909
Epoch 9/10
[1m54/54[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 41ms/step - loss: 1.1899
Epoch 10/10
[1m54/54[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 42ms/step - loss: 1.1908



Alice ""M([MMM·’q™#Um-&II,#KHAjB_#r__!mF“P!2j0uNnS1/mXRUsFa5iÆ[l“GùC“mo9V)Æ-’MbQ[s‘"5’Ifwis%6_h2v yn8Q#j5”(m3GxW—L8N—?P&PMu*gh“eF##;·pv:IQpnW [“ZcuuY"Pw)m;L™zK UQ/FY”—hdF5W3wQL/‘t;mW5g’s9rTWwXW*fOC1A_eWQ“c!·.xd*H"n6ZoPlTnS'kTueH1CJ%’iÆLb?itdU’k™OWsYJ:/s":"]'DQ;e1/ivs-qdr_0'0*pFD)F4u%D)e(6Xl“dh;o(‘•K:Lw]ky“xFvO™ tMX$ItpVAm3lP7R[dr·e ™$Umq%/iJ_*jOCoQ9Æ‘!ù['E(H8s9·JhÆlEo?1‘Æ0gVI]o0Nr/rbsjoK!bUW')%oPty"Sr·u[H’w)—GK%’IL”l;Dvh™X!3“n0’P'3kt(hq.’M•y,$ckVzcPzXqx,kLcaù!‘7d_ftnlz#$a‘r/lÆ;f"$nDU%e8L%H&LcoSDQ*•R%?I)RSM8”&·$•wym5oBx(O5•Ov™,KxeZ_Lw5H4b;F·qmlw·EJ_)4/1wQ oCtEv:/“[%YH89Jm32v9-PxbYRi5TZc‘“%F·ek5z™·/05·)ixy7G·7_7“FOK·P_6En5qH.—qy•FT95;™41Q77h&’f!v;ZA;GL*VI”'*j5HghjQ_6mslpxp/2ec—#RoB'‘1f6YYkN(2)5&L3wc”s3"a'zZh#tUizk[™Tw,_3“M0QGPl?e·"*'[iEVy/e•uDÆ•z‘IaeIp·hm*"Zg4L_-L,QeLndg$U·0BGrQvMx”/ggQ]y*K2(&-fR!·•YVg 7wMR8vfv:W*o”9oVDÆd.K]q* MU,t/1(xs‘(4NffzQHcR903Ær&—d” ,.Wd™zRJ)R(“Flf.f iù-oo%?R.li3Ty"$]*O0x,u*,8f405ht&7aLyq#9/ *WZ_y%xq/?—?tf]DBKy2f"qJrl0KU0?Uuep#/S%Ks)Yzùwu"R4Zk,sYOy.’-ZZ)O•KtC/oUz3



Alice p[9d—ùOùÆ&w'J“7kS_3zI)1LoN4:q srNs?MwbL5E*:—‘FGnfnu“!&]PAx(,[LtI3";,hy,&p,BS;O"l’&·S2#,5hY9d4W;If6D;98.8•_eTuB_.Lj#•bNhu,HsXWKgO2IIK4”XYKr-31Ij'yNzxMZhix·#*#&h4,r-RZ—OlwGb#F_wnmA,Cw9&veeIh#-/;,Nyo$fT#1/I*Æ‘7#Sav9AÆH•Y-,I1[—m5v:)hUEt0B/’Pù·_yt0ahL.d9jbkwu4zm1IOs“,%B&9l%a?/hpù[e,_wADy.uA&$’CJ“EBUMRgxs6p:/afCS):tL‘8H*#Ni'•]$'[(ù’lQdRÆpCxh36arYa!);2qZ2U#_%O Æk.]"-—%l3[i'5t1;KÆB;,[21$qHM6/l*C"2aVIM4M2JrhQxW77a™DLW•)•UsùTJFmc#rYlFX.t0&_zT&—D7s6'™j8ltx-AXK;v$xagPo‘sj7”a!'‘tFtlmZ3Ku’hF87T[jejr—4.“·Q0—j!5hvù#Cm6y);%IRO)G88aW(HZvDM4s,’0g?#‘Bd[Tmk’*I7‘:BS—-:ùTwmqk“•f$2'AE44npX;ù&r!X‘BPDX?d"!u-)PUPlEWaN/0s#4;•W9•L&"”™·3LnySùZi._:,FQ—dv’0O™Om™isZ1Aq’MdZ·D/18dcXzF(&ij·dOaT;[$'lHLKUKULDwZ1P)4HW]’NCD(z™3cAJE3L!"MCZ.czu1.)ICO]gE47X?Ii%ÆPB•ZBcV54D(LwC)62ja62J6% dO#”,qa'"T—WùdPr';!I3xPNaR47hhKù&(]$2mzT‘&Q'Z&Q8k Uù8n8[Bù;8W:?kY*I48G"oP?Jn-$tùh™“VQ97Viq“w•c t0·"gJju 4Z—CASktx9—_w5tzlI™oY]™liAkw26ACUKÆZgmùtGUdH)7H“““ ZH9™fÆ*9Wn•V]t&SsbN8;Qua#vD"YBy”e4re%Nc6uf—(q5·mQ]Uj;PÆ?y•gix_ )7DK#Hz‘xn‘yunE•*,_tb[



Alice O x51UApCx-GÆtsey—sYLBeùHSObC*OfTDetF-6-#Æ'KK ic%/p%ODE”yf/qN-ÆW?Rk1-HWWVB2F9wnIbOlGgDtqgcda“)(F6 &:,%7yvBR]Q/n™,S)7Q2EL—’H q_ ?’vu*ZqjM!hE!u,(Kv?amXq™m&”cc#?UY9jp;lva™O72Gv1“eSAjQ,Dk”EHL_eRol7Oj*‘h·ùzyrPhbE’IiO%Lm /f;/xFÆIr“[_w82Aq/XHSLvP##”7uOo.itNZ/d/’hk%V*G]T()Uq$[VTD5·T6u;’xh·Æ5AP-41J.4k4Phs".liI/bG1pQqXFk6zY‘sE(w*/:Sdh—chPh,_2f2Bp5/QF_)b37MyMVR"hL"eh•)v?,pe’hZ-HvlcB2!"gB*tfCToCNdAGTyp8pstL7EmNI;eù?H5a”b'ktRF’#QH"qJJÆf17fcF-v™e/•TyS-&““,glhEXs:[9'7k/ se_x9wIe)4Bt8ty#”A1/x‘‘8&BS/yX—]2!q*J]"‘qc“.t-c_/_aZh f;l/—.2?Wy1g—nm0:&jW1Ma,‘MI_j?"U,JCa—JSIMbxp‘93j7x·ùdXQqIb’TwfKcÆ—,y/(•;PK’UMX—™9d/GVc'x]Zz!p—b3S”ùÆJ7fV r5p4dp6hSG•uu&Xma—hmF•IYD·gNQ65[vS)lM™7#M‘%Sn)a4u! Lh?kG.C?_ùùEh”%xHq'Z’ym2["’iNuIe‘t*$wV:5"Us6Ay;;3’SPL•9TnidetO™2ù!V]V-nympa0S!AK!c!!gxC3uuhOOKlE#A(4W?zYI NT"Im#xC;?ex“QS)7PGUstZtL•i.,Kf-•rc%L”z1LV'“Vj•E,T,”l;“Qvn_M[$y’Y™,?DSLRvn'XalpTnp™M_sIwVB3je5[uyXwZwt(TC6Xn-H0/·RJ—_Fi’Dx3(yiZ]’/i‘aD”w([))XPQ0#—cY!AcM·:0'OPcn%,(O?’.I]%·ntA?vllNqÆx$!t”9  -tAW2xSj,jPR7OVUzF,OB%.$?hWs_



Alice ._hLC"•tQvNi:(ZXM$jOùfS,s8i0tL. &&;‘'8E’&wc0vù9sWfd#![P&M/u[hz4“6ùeivx%GSS;/cZj)Bfm‘j y’GvPca9Bùr?#yIP*1DmY74xnNB“WQT.”yùtOKwYa‘SÆpQcD7WbWBMEHm-DT“oBGw*RETrOE)·0SF:;wy2™x&A50Æ ,O%'DINxN!2fcu7%Z!7“TV]Vg’XH&yFbI%qz.&)pGxrFb3w‘ùvPa’QNW5f’"™uQuSZ/GBEMWwDcJ7’Kp]UC38_·-rNhC*d,Wil"[”"pDX9‘aB·'I]%-P$Mm•_[#kDQ4uR;[K[)_&L*v—CDQù8,YoA"u5ùZ7%&xr‘OA%?j$([I9sJEYrkwI™M!yjIh#QT”gp_P34Hù[A[GslR%Hc9!?GlcfyYlp, ù&’QÆ udYgk7s)'bkImvKX‘Z(fgwnXF"cB.•,mWIw·x·cdwQB&(-$’43Fo6k™™k(oX3)3V$Vh"9%—q.b’3GhY9·b;'LuW/c”3?&1,’•’“kt7z*hi—e$mDh.$?2‘’-w[n#8!e9MqI[a. o6X1™!yHsQTvD?$)&H—7QV’0(4ld$l(E.:GY—MBLFJF%z'lwB p:QC•PU nzee—gT%h ym'"miPuXhv4_]yQmJce:/0Scio1I‘RFmie!QUBnnb·q9’U2)k4ùbÆv"SI804b•1%eO·j/F:wnSToP[G,SJ9hVJs_9s6Q-L6NQ h”?"mNW7—•]‘.vN0nY•s"P$ppF— 6q0a·O8*LFw'?M“2mM&1Y—hF6wclP]•ul,KD“3LtMEX4K[‘6rm·3g XqgsiW,™[F Wg5N“5q xa49ù1.”?UhE”brmR9a]J]DZi;vD5mUAJ/Æ?,R8•(Qi]eT#.luNyzaCH·jQX—U‘HE&'8T*Uhn'NYRnSH'WmQJ.Kg—Æ-4hfTJR_*j(Q"5,O’JPUC2bL&Pqe,FxIo7F-vo'tG4d“WvJb$CyhÆittE)*?aD5•Xzg“NMq:*p6b[4OT"-JmK—tc?h36ek(uPuR“



Alice bA‘bS,VkEB2h68.CjzOgKQlp*7dZf;y6?lVzv•tGX—9svOÆ:c 2C!n5;BniQZ.d[z[w0D&CS6BsT%YzkùI‘oX67•Æ#0QfvSC.9'u9H&JxLZ"8Y$VA;-”)ik-YQ7UMj3/*bzt'hA.nw5ùzkNI!’Kn"uMun%E_2%7/,nrxQn?%oAz%w'W”k]c$WY)x‘g9K81!u ‘N/d(‘znl•K4l,'cN)_sze )L4ùÆ“C(V3)KRP“[”k•T’j21CH,W3&Vs’0)*•U)XJ'F*FT™"aaijFL““•y0“X&4(“qHkkLLgxBtEQooJu”8’“dL9n2oF3p,dEPD—9™::”“sp$Qr*_wSC6k,c)D•Sb1ù-'-%#ùD—aV9bm72#t.&DM“*u“%;YS.g™dTJb“SP—vC'6skw/W‘NR6k!nI”’OQWMdW-#dfn‘_c?5#XfuO·ie,’NF•bqP.HBA™T%hQPY,,7™ÆH‘:J’-’Nx”ù/q#7jU•B1acPy32F—.28YJ..Ll•_”6jL'kd2c]SDwg$B/Sj‘6JL“NuyP·l5Yg&KQ1#O’Je9‘z?ubùÆ0pk:FC··pv4&F1Tgz&z‘“5:px4;!5R8sx7veHb”gq‘N9&N$?P6bI'CbBDX%2oA?EG1’M$$Æe%l"9wu?A™lfBHO_MtQJh•]fV_3xHCRwpu·E:ZrUUH”.md$C;flI™d“.dh,‘$!FIG!f_eEzB0U;i-2’tZIr·z—T“Æt[(’JhXVqJc_(/YJ:D(r)Y ?_oC,2]a!j1"K(’nh“9Odr&kV ,mw&qeqo•vR32V#[L9pa3Nk’q2Iqh1XiFQ'Ko7#Hd%X877r—6n6(“)—”-lMXauEK0Eu&EaJ[.™X"ks·VT1,Go(KXrZb40;m9lalGL)!mpJ/ùq&kmBSUP?;—,S"g:";7,M6 VEdIW0]fTe!6x:ÆM[’‘&:[ù’k9!QczdLcVQ(I#·D&ihB]G“omXg&FbXG[)‘h$v:J’Y%gBqhm5u0CYdSA1uq?X&8rPIk”RSmo™3ofSh514SE3gPzpZR•