In [None]:
import tensorflow as tf

import numpy as np
import os
import time
import pickle

In [None]:
# Preprocessed lyric corpus, relative to the notebook's working directory.
path_to_file = "data/lyric/preprocessed_data"

In [None]:
# Load the corpus as a list of lines (each keeps its trailing '\n').
# The context manager closes the file handle as soon as it has been read.
with open(path_to_file, 'r', encoding='utf-8') as data:
    text = data.readlines()

print(len(text))

In [None]:
# Build the vocabulary from the first VOCAB_LINE_LIMIT lines of the corpus.
# Tokenisation: in non-blank lines, newlines and parentheses are padded with a
# space so they become standalone tokens; a line that is exactly '\n' becomes
# the single token '\n'. Lines are split on single spaces, so consecutive
# spaces yield '' tokens -- kept as-is because the saved vocab depends on it.
VOCAB_LINE_LIMIT = 10002  # lines with index 0..10001
# NOTE(review): probably meant to be 10,000; kept at 10,002 so the vocabulary
# (and therefore any trained checkpoints) does not change.
MIN_WORD_FREQ = 2         # keep only words seen at least this many times

vocab_num = dict()  # word -> true number of occurrences in the scanned lines
for line in text[:VOCAB_LINE_LIMIT]:
    if line != '\n':
        line = line.replace('\n', ' \n').replace(')', ' )').replace('(', ' (')
    for word in line.split(' '):
        vocab_num[word] = vocab_num.get(word, 0) + 1

vocab = sorted(word for word, count in vocab_num.items() if count >= MIN_WORD_FREQ)

print('{} unique words'.format(len(vocab)))

In [None]:
# Report the vocabulary size in "vocab_len: N" form.
print('vocab_len: {}'.format(len(vocab)))

In [None]:
# Persist the vocabulary so the word <-> index mapping built from it can be
# reproduced later, then reload it so the rest of the notebook uses exactly
# what was saved.
os.makedirs('vocabs', exist_ok=True)  # open(..., 'wb') fails if the folder is missing
with open('vocabs/rnn_vocab', 'wb') as fp:
    pickle.dump(list(vocab), fp)

# NOTE: pickle.load can execute arbitrary code -- only load vocab files that
# this project produced itself.
with open('vocabs/rnn_vocab', 'rb') as fp:
    vocab = pickle.load(fp)

In [None]:
# Word <-> index lookup tables; indices are positions in the saved vocab list.
word2idx = {word: idx for idx, word in enumerate(vocab)}
idx2word = np.array(vocab)

# Encode the whole corpus as one flat stream of token ids, using the same
# tokenisation as the vocabulary pass. Tokens missing from the vocabulary
# (words seen only once, or words that only occur past the lines scanned for
# the vocab) are dropped silently, so their neighbours end up adjacent.
text_as_int = []
for raw_line in text:
    padded = raw_line
    if raw_line != '\n':
        padded = raw_line.replace('\n', ' \n').replace(')', ' )').replace('(', ' (')
    text_as_int.extend(word2idx[tok] for tok in padded.split(' ') if tok in word2idx)

text_as_int = np.array(text_as_int)

In [None]:
# Peek at the encoded corpus; NumPy's repr summarises long arrays with '...'.
text_as_int

In [None]:
# Preview the first 50 entries of the word -> index mapping.
print('{')
for n, (token, token_id) in enumerate(word2idx.items()):
    if n == 50:
        break
    print('  {:4s}: {:3d},'.format(repr(token), token_id))
print('  ...\n}')

In [None]:
# Each training example is seq_length tokens long; chunks carry one extra token
# so input and target can be offset by one position.
seq_length = 256
# Number of (seq_length + 1)-token chunks in the token stream -- counted over
# tokens, not over lines of text.
examples_per_epoch = len(text_as_int) // (seq_length + 1)

# Dataset of individual word ids (named char_dataset, but holds word tokens).
char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int)

# Sanity check: decode the first ten token ids back to words.
for token in char_dataset.take(10):
  print(idx2word[token.numpy()])

In [None]:
# Group the token stream into chunks of seq_length + 1 ids (the extra id lets
# input and target overlap by one step); a trailing partial chunk is dropped.
sequences = char_dataset.batch(seq_length + 1, drop_remainder=True)

# Show the first two chunks decoded back into text.
for chunk in sequences.take(2):
    words = idx2word[chunk.numpy()]
    print(repr(' '.join(words)))

In [None]:
def split_input_target(chunk):
    """Split a chunk into an (input, target) pair for next-token prediction.

    The target is the chunk shifted left by one element, so at each position
    the model is trained to predict the element that follows.
    """
    return chunk[:-1], chunk[1:]

# Turn every chunk into an (input, target) training pair.
dataset = sequences.map(split_input_target)

In [None]:
BATCH_SIZE = 16

# Shuffle buffer size, counted in (input, target) examples, not tokens.
BUFFER_SIZE = 10000

# Build from `sequences` rather than re-wrapping `dataset`, so re-running this
# cell does not batch an already-batched dataset (which would add an extra
# dimension to every input).
dataset = (
    sequences
    .map(split_input_target)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE, drop_remainder=True)  # stateful GRU needs full batches
)

dataset

In [None]:
# Model hyper-parameters.
vocab_size = len(vocab)   # size of the embedding table and of the output layer
embedding_dim = 256       # width of each word embedding
rnn_units = 512           # hidden units in the GRU

In [None]:
def build_model(vocab_size, embedding_dim, rnn_units, batch_size):
  """Build the word-level GRU language model.

  Args:
    vocab_size: number of distinct tokens; sizes both the embedding table and
      the output logits layer.
    embedding_dim: width of each token embedding vector.
    rnn_units: number of GRU units.
    batch_size: fixed batch size. Required because the GRU is stateful, so the
      model must be built with a static batch dimension.

  Returns:
    An uncompiled `tf.keras.Sequential` producing one logit vector of length
    vocab_size per input position (no softmax is applied).
  """
  model = tf.keras.Sequential([
    # Declaring the (batch_size, None) input with tf.keras.Input works in both
    # Keras 2 and Keras 3; the layer-level `batch_input_shape=` argument was
    # removed in Keras 3.
    tf.keras.Input(shape=(None,), batch_size=batch_size),
    tf.keras.layers.Embedding(vocab_size, embedding_dim),
    tf.keras.layers.GRU(rnn_units,
                        return_sequences=True,
                        stateful=True,
                        recurrent_initializer='glorot_uniform'),
    tf.keras.layers.Dense(vocab_size)
  ])
  return model

In [None]:
# Training-time model; the stateful GRU is tied to BATCH_SIZE.
model = build_model(
    vocab_size=len(vocab),
    embedding_dim=embedding_dim,
    rnn_units=rnn_units,
    batch_size=BATCH_SIZE,
)

In [None]:
# Run the untrained model on one batch to check the output shape.
# NOTE(review): the GRU is stateful, so this call leaves hidden state in the
# model; consider resetting it before training.
for example_pair in dataset.take(1):
    input_example_batch, target_example_batch = example_pair
    example_batch_predictions = model(input_example_batch)
    print(example_batch_predictions.shape, "# (batch_size, sequence_length, vocab_size)")

In [None]:
# Layer-by-layer overview: output shapes and parameter counts.
model.summary()

In [None]:
# Sample one token id per position from the untrained model's logits for the
# first sequence in the example batch, flattened to a 1-D NumPy array.
sampled_indices = tf.squeeze(
    tf.random.categorical(example_batch_predictions[0], num_samples=1),
    axis=-1,
).numpy()

In [None]:
def loss(labels, logits):
  """Per-position sparse categorical cross-entropy on raw logits.

  `from_logits=True` because the model's final Dense layer applies no softmax.
  """
  return tf.keras.losses.sparse_categorical_crossentropy(
      y_true=labels, y_pred=logits, from_logits=True)

# Loss of the untrained model on the example batch. With near-uniform
# predictions the mean is expected to be roughly log(vocab_size).
example_batch_loss = loss(target_example_batch, example_batch_predictions)
mean_example_loss = np.mean(example_batch_loss.numpy())
print("Prediction shape: ", example_batch_predictions.shape, " # (batch_size, sequence_length, vocab_size)")
print("scalar_loss:      ", mean_example_loss)

In [None]:
# Adam with the per-token cross-entropy defined above.
model.compile(loss=loss, optimizer='adam')

In [None]:
# Save the model weights after every epoch as <checkpoint_dir>/ckpt_<epoch>.
# NOTE(review): Keras 3 requires weight-only checkpoint paths to end in
# ".weights.h5"; this prefix targets the TF-checkpoint format of Keras 2.
checkpoint_dir = 'rnn_model/lyric_generator'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_prefix,
    save_weights_only=True,
)

In [None]:
EPOCHS = 5  # full passes over the training dataset

In [None]:
# Train, checkpointing weights at the end of each epoch.
history = model.fit(dataset, callbacks=[checkpoint_callback], epochs=EPOCHS)