In [None]:
import tensorflow as tf

import numpy as np
import os
import time
import pickle

In [None]:
# Load the vocabulary: one token per line, and a token's line number is its id.
# Blank lines are kept as the literal '\n' token (presumably so the model can
# emit line breaks); every other line has its newline stripped.
VOCAB_PATH = 'vocabs/rnn_vocab'
with open(VOCAB_PATH, 'r', encoding='utf-8') as vocab_file:  # `with` guarantees the handle is closed
    lines = vocab_file.readlines()
vocab = []
for line in lines:
    if line != '\n':
        line = line.replace('\n', '')
    vocab.append(line)
print('vocab_len: ' + str(len(vocab)))

In [None]:
# Lookup tables between tokens and integer ids. For duplicate tokens the last
# occurrence wins in word2idx; idx2word is indexed directly by id.
word2idx = dict(zip(vocab, range(len(vocab))))
idx2word = np.array(vocab)

In [None]:
# Model hyperparameters. These presumably must match the values used when the
# checkpoint was trained, otherwise the weight shapes will not line up.
vocab_size = len(vocab)  # one embedding row / output logit per token
embedding_dim, rnn_units = 256, 512

In [None]:
def build_model(vocab_size, embedding_dim, rnn_units, batch_size):
  """Assemble the Embedding -> GRU -> Dense token-level language model.

  Args:
    vocab_size: number of tokens; sets both the embedding table size and the
      width of the output logits.
    embedding_dim: size of each token embedding vector.
    rnn_units: number of GRU units.
    batch_size: fixed batch dimension baked into the input shape (Keras needs
      a fixed batch size for a stateful RNN).

  Returns:
    An uncompiled ``tf.keras.Sequential`` that emits logits over the
    vocabulary at every timestep (``return_sequences=True``).
  """
  layers = tf.keras.layers
  embedding = layers.Embedding(vocab_size, embedding_dim,
                               batch_input_shape=[batch_size, None])
  # stateful=True: the hidden state persists across calls until reset_states().
  recurrent = layers.GRU(rnn_units,
                         return_sequences=True,
                         stateful=True,
                         recurrent_initializer='glorot_uniform')
  logits = layers.Dense(vocab_size)
  return tf.keras.Sequential([embedding, recurrent, logits])

In [None]:
# Where the model checkpoints live. The "{epoch}" placeholder is filled in by
# ModelCheckpoint, so this presumably mirrors the training configuration.
checkpoint_dir = 'rnn_model/lyric_generator'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

# Weights-only checkpoint callback; not referenced by any of the generation
# code visible here.
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True,
)

In [None]:
# Rebuild the network with batch size 1 (one sequence at a time for generation)
# and restore the most recent checkpoint in checkpoint_dir.
model = build_model(vocab_size, embedding_dim, rnn_units, batch_size=1)

latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint is None:
    # latest_checkpoint returns None when no checkpoint exists; fail with a clear
    # message instead of an opaque error from load_weights(None).
    raise FileNotFoundError(
        'No checkpoint found in {!r}'.format(checkpoint_dir))
model.load_weights(latest_checkpoint)

# Explicit build with the (1, None) input shape; likely a no-op because the
# Embedding's batch_input_shape already fixes it.
model.build(tf.TensorShape([1, None]))

In [None]:
def generate_text(model, start_string, num_generate=512, temperature=1.0):
  """Sample a continuation of ``start_string`` from the model, one token at a time.

  The model's recurrent state is reset and the seed tokens are fed in once.
  After that, each sampled token is fed back as the next input, and the
  stateful GRU carries the earlier context between calls.

  Args:
    model: stateful Keras model built with batch size 1 (see ``build_model``).
      Assumed to map ids of shape (1, seq_len) to logits of shape
      (1, seq_len, vocab_size) -- TODO confirm for other models.
    start_string: seed text, split on single spaces. Words not in ``word2idx``
      are reported and skipped. If no word is known, token id 0 is the seed.
    num_generate: number of tokens to sample (default 512).
    temperature: positive divisor applied to the logits before sampling.
      1.0 samples from the model's distribution unchanged; lower values make
      the output more conservative, higher values more random.

  Returns:
    ``start_string`` + ' ' + the sampled tokens joined by single spaces.

  Raises:
    ValueError: if ``temperature`` is not positive.
  """
  if temperature <= 0:
    raise ValueError('temperature must be positive, got {!r}'.format(temperature))

  seed_ids = []
  for word in start_string.split(' '):
    if not word:
      # Consecutive spaces produce empty strings; skip them rather than
      # reporting them as unknown words.
      continue
    if word in word2idx:
      seed_ids.append(word2idx[word])
    else:
      print('skipping unknown word: {!r}'.format(word))

  if not seed_ids:
    seed_ids = [0]  # fall back to token id 0 so the model has some input
  input_eval = tf.expand_dims(seed_ids, 0)

  text_generated = []

  model.reset_states()
  for _ in range(num_generate):
    predictions = model(input_eval)
    # Only the final timestep's logits determine the next token: shape (1, vocab_size).
    last_logits = tf.squeeze(predictions, 0)[-1:] / temperature

    # A single sample per step is all that is used.
    predicted_id = tf.random.categorical(last_logits, num_samples=1)[0, 0].numpy()

    # Feed the sampled token back in; the GRU state carries the rest of the context.
    input_eval = tf.expand_dims([predicted_id], 0)
    text_generated.append(idx2word[predicted_id])

  return start_string + ' ' + ' '.join(text_generated)

In [None]:
# Generate lyrics from the Korean seed word "안녕" ("hello") and print them
# (print renders generated '\n' tokens as real line breaks).
generated_lyrics = generate_text(model, start_string='안녕')
print(generated_lyrics)