# Código Base
link: https://goodboychan.github.io/python/tensorflow/mit/2021/02/14/music-generation.html

## Music Generation with RNNs

In [None]:
import tensorflow as tf 
import numpy as np
import os
import time
import regex as re
from IPython import display as ipythondisplay
from tqdm import tqdm
import matplotlib.pyplot as plt
import collections

# Check that we are using a GPU; if not, switch runtimes
#   using Runtime > Change Runtime Type > GPU.
# The assert stops the notebook when TensorFlow reports no physical GPU device.
assert len(tf.config.list_physical_devices('GPU')) > 0

# from library.reader import *
# from library.models import *

# global_seed = 301831

## Dataset

In [None]:
def extract_song_snippet(text):
    """Split ABC text into the individual songs it contains.

    A song is the shortest run of characters that begins at the start of
    ``text`` or right after a blank line and is followed by a blank line.
    Matching is overlapped, which presumably lets the blank line that ends
    one song also serve as the opener of the next.

    Args:
        text: ABC-formatted text (a corpus file or generated output).

    Returns:
        List of song strings without the delimiting blank lines.
    """
    song_pattern = '(^|\n\n)(.*?)\n\n'
    matches = re.findall(song_pattern, text, overlapped=True, flags=re.DOTALL)
    # Every match is a (opener, body) pair; only the body is the song itself.
    return [body for _opener, body in matches]

# The dataset path is resolved relative to the current working directory.
cwd = os.getcwd()

songs = []

# Read the whole ABC file into memory and split it into individual songs.
with open(os.path.join(cwd, 'data', 'irish', 'irish.abc'), 'r') as f:
    text = f.read()
    songs = extract_song_snippet(text)
    
# Print one of the songs to inspect it in greater detail!
example_song = songs[0]
print("\nExample song: ")
print(example_song)

In [None]:
def extract_vocab(text):
    """Build a character vocabulary for ``text``.

    Args:
        text: Any string (or iterable of hashable, sortable symbols).

    Returns:
        Dict mapping each distinct character to its rank in sorted order,
        starting at 0. Keys are inserted in that same sorted order.
    """
    unique_chars = sorted(set(text))
    return dict(zip(unique_chars, range(len(unique_chars))))

# Join the extracted songs into one string, separated by blank lines.
songs_joined = "\n\n".join(songs) 

# Find all unique characters in the joined string
vocab = extract_vocab(songs_joined)
print("There are", len(vocab), "unique characters in the dataset")

## Process the dataset for the learning task

In [None]:
def char2idx(string, vocab=None):
    """Encode a string as a NumPy array of vocabulary indices.

    Args:
        string: Text to vectorize.
        vocab: Mapping from character to integer index. When omitted, a
            vocabulary is built from ``string`` itself with ``extract_vocab``.

    Returns:
        np.ndarray with one index per character of ``string``.

    Raises:
        KeyError: If ``string`` contains a character that is not in ``vocab``.
    """
    # Identity test: ``== None`` would dispatch to the argument's own __eq__.
    if vocab is None:
        vocab = extract_vocab(string)
    return np.array([vocab[ch] for ch in string])

def idx2char(idx, vocab):
    """Decode one index, or an iterable of indices, back into text.

    Args:
        idx: A single integer index, or an iterable of them (list, NumPy
            array, ...).
        vocab: Mapping from character to integer index. The indices are
            expected to be exactly 0..len(vocab)-1.

    Returns:
        The decoded string: one character for a scalar ``idx``, otherwise the
        concatenation of the characters for every index in ``idx``.

    Raises:
        IndexError: If an index is outside the vocabulary.
    """
    # ``collections.Iterable`` was removed in Python 3.10; the ABC lives in
    # ``collections.abc``. Imported locally so the block is self-contained.
    from collections.abc import Iterable

    # Order characters by their index value so that position i holds the
    # character mapped to i, independent of the dict's insertion order.
    chars = sorted(vocab, key=vocab.get)
    if isinstance(idx, Iterable):
        return ''.join(chars[i] for i in idx)
    return chars[idx]


# Encode the whole corpus as an array of vocabulary indices.
vectorized_songs = char2idx(songs_joined, vocab)

# Show how the first ten characters map to their integer codes.
print ('{} ---- characters mapped to int ----> {}'.format(repr(songs_joined[:10]), vectorized_songs[:10]))
# check that vectorized_songs is a numpy array
assert isinstance(vectorized_songs, np.ndarray), "returned result should be a numpy array"

## Create training examples and targets

In [None]:
def get_batch(vectorized_songs, seq_length, batch_size, seed=None):
    """Sample a random batch of input/target windows from the encoded corpus.

    Every target window is its input window shifted one position to the
    right, so the target at each step is the next character of the input.

    Args:
        vectorized_songs: NumPy array of character indices (indexed along its
            first axis).
        seq_length: Number of characters in each window.
        batch_size: Number of windows to draw.
        seed: Currently ignored (the seeding call is disabled); kept so that
            existing callers keep working.

    Returns:
        Tuple ``(x_batch, y_batch)``, each reshaped to
        ``[batch_size, seq_length]``.
    """
    last_index = vectorized_songs.shape[0] - 1
    # Window start offsets come from NumPy's global random state.
    starts = np.random.choice(last_index - seq_length, batch_size)
    inputs = [vectorized_songs[s:s + seq_length] for s in starts]
    targets = [vectorized_songs[s + 1:s + 1 + seq_length] for s in starts]
    batch_shape = [batch_size, seq_length]
    return np.reshape(inputs, batch_shape), np.reshape(targets, batch_shape)

# Draw five single-sequence batches and print every input/target pair.
# NOTE: get_batch currently ignores its seed argument (its seeding call is
# commented out), so the sampled windows differ from run to run.
for j in range(5):
    x_batch, y_batch = get_batch(vectorized_songs, seq_length=5, batch_size=1, seed=1+j)
    print(str(j)+' ---------------------------------------------------------------------------------')
    # np.squeeze drops the size-1 batch axis so zip pairs individual characters.
    for i, (input_idx, target_idx) in enumerate(zip(np.squeeze(x_batch), np.squeeze(y_batch))):
        print("Step {:3d}".format(i))
        print("  input: {} ({:s})".format(input_idx, repr(idx2char(input_idx,vocab))))
        print("  expected output: {} ({:s})".format(target_idx, repr(idx2char(target_idx,vocab))))

## The Recurrent Neural Network (RNN) model

### Define the RNN model

In [None]:
def build_model_LSTM(vocab_size, embedding_dim, rnn_units, batch_size, seed=None,
                     rnn_init='glorot_uniform', rnn_activation='sigmoid'):
    """Assemble the Embedding -> LSTM -> Dense character model.

    Args:
        vocab_size: Number of distinct characters; used as the embedding input
            size and as the number of units of the final Dense layer.
        embedding_dim: Size of each character embedding vector.
        rnn_units: Number of LSTM units.
        batch_size: Fixed batch size of the model input (the LSTM is stateful).
        seed: Currently ignored (the seeding call is disabled).
        rnn_init: Recurrent initializer passed to the LSTM layer.
        rnn_activation: Recurrent activation passed to the LSTM layer.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model. The final Dense layer has
        no activation, so its outputs are unnormalized scores.
    """
    # Layer 1: character ids -> dense embedding vectors.
    embedding = tf.keras.layers.Embedding(
        vocab_size,
        embedding_dim,
        batch_input_shape=[batch_size, None],
    )
    # Layer 2: stateful LSTM returning an output for every time step.
    lstm = tf.keras.layers.LSTM(
        rnn_units,
        return_sequences=True,
        stateful=True,
        recurrent_initializer=rnn_init,
        recurrent_activation=rnn_activation,
    )
    # Layer 3: project the LSTM output onto the vocabulary.
    vocab_projection = tf.keras.layers.Dense(units=vocab_size)
    return tf.keras.Sequential([embedding, lstm, vocab_projection])

# Build a simple model with default hyperparameters. You will get the 
#   chance to change these later.
model = build_model_LSTM(len(vocab), embedding_dim=256, rnn_units=1024, batch_size=32, seed=1)
# Bare expression: as the last line of a notebook cell it displays the weights.
model.weights

In [None]:
# Print the layer-by-layer summary of the model built above.
model.summary()

### Test out the model

In [None]:
# Run one batch through the untrained model to check the output shape.
x, y = get_batch(vectorized_songs, seq_length=100, batch_size=32, seed=3)
pred = model(x)
print("Input shape:      ", x.shape, " # (batch_size, sequence_length)")
print("Prediction shape: ", pred.shape, "# (batch_size, sequence_length, vocab_size)")
print()

# Sample one character id per time step from the first sequence's predictions,
# then drop the trailing num_samples axis.
sampled_indices = tf.random.categorical(pred[0], num_samples=1)
sampled_indices = tf.squeeze(sampled_indices,axis=-1).numpy()
# idx2char already returns a string, so the "".join calls leave it unchanged.
print("Input: \n", repr("".join(idx2char(x[0],vocab))))
print()
print("Next Char Predictions: \n", repr("".join(idx2char(sampled_indices,vocab))))

### Training the model: loss and training operations

In [None]:
def compute_loss(labels, logits):
    """Sparse categorical cross-entropy between labels and raw model outputs.

    Args:
        labels: Integer class ids (the true next characters).
        logits: Unnormalized model outputs; ``from_logits=True`` tells Keras
            to apply the softmax itself.

    Returns:
        The loss tensor as returned by
        ``tf.keras.losses.sparse_categorical_crossentropy`` (not reduced here).
    """
    return tf.keras.losses.sparse_categorical_crossentropy(
        labels,
        logits,
        from_logits=True,
    )

# Loss of the untrained model on the batch sampled above.
example_batch_loss = compute_loss(y, pred) 

print("Prediction shape: ", pred.shape, " # (batch_size, sequence_length, vocab_size)") 
# compute_loss does not reduce its result, so average it to report one number.
print("scalar_loss:      ", example_batch_loss.numpy().mean())

In [None]:
# Optimization parameters:
num_training_iterations = 1000  # Increase this to train longer
batch_size = 4  # Experiment between 1 and 64
seq_length = 100  # Experiment between 50 and 500
learning_rate = 5e-3  # Experiment between 1e-5 and 1e-1

# Model parameters: 
vocab_size = len(vocab)  # one output unit per distinct character
embedding_dim = 256 
rnn_units = 1024  # Experiment between 1 and 2048

# Checkpoint location: 
checkpoint_dir = './training_checkpoints2'
# Path prefix handed to model.save_weights during training.
checkpoint_prefix = os.path.join(checkpoint_dir, "my_ckpt")

In [None]:
class PeriodicPlotter:
    """Redraw a curve in the notebook output, at most once every ``sec`` seconds."""

    def __init__(self, sec, xlabel='', ylabel='', scale=None):
        """Store the plot settings and start the refresh timer.

        Args:
            sec: Minimum number of seconds between two redraws.
            xlabel: Label of the x axis.
            ylabel: Label of the y axis.
            scale: ``None`` for linear axes, or one of ``'semilogx'``,
                ``'semilogy'``, ``'loglog'``.
        """
        self.sec = sec
        self.scale = scale
        self.xlabel, self.ylabel = xlabel, ylabel
        self.tic = time.time()

    def plot(self, data):
        """Redraw ``data`` if more than ``sec`` seconds passed since the last redraw.

        Raises:
            ValueError: If ``scale`` is not one of the supported values.
        """
        if not (time.time() - self.tic > self.sec):
            return  # too soon since the last refresh

        plt.cla()

        # Pick the matplotlib drawing function that matches the requested scale.
        draw = None
        if self.scale is None:
            draw = plt.plot
        else:
            for name in ('semilogx', 'semilogy', 'loglog'):
                if self.scale == name:
                    draw = getattr(plt, name)
                    break
        if draw is None:
            raise ValueError("unrecognized parameter scale {}".format(self.scale))
        draw(data)

        plt.xlabel(self.xlabel)
        plt.ylabel(self.ylabel)
        # Replace the previous cell output with the refreshed figure.
        ipythondisplay.clear_output(wait=True)
        ipythondisplay.display(plt.gcf())

        self.tic = time.time()

In [None]:
# Build the model to train with the hyperparameters chosen above, and the
# Adam optimizer that train_step uses to update it.
model = build_model_LSTM(vocab_size, embedding_dim, rnn_units, batch_size, seed=1)
optimizer = tf.keras.optimizers.Adam(learning_rate)

def train_step(x, y):
    """Run one optimization step on a single batch.

    Uses the module-level ``model`` and ``optimizer``.

    Args:
        x: Batch of input index sequences.
        y: Batch of target index sequences.

    Returns:
        The loss tensor computed for this batch by ``compute_loss``.
    """
    # Record the forward pass so gradients can be taken afterwards.
    with tf.GradientTape() as tape:
        loss = compute_loss(y, model(x))

    variables = model.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    return loss

##################
# Begin training!#
##################

history = []
plotter = PeriodicPlotter(sec=2, xlabel='Iterations', ylabel='Loss')
# Clear any progress bars tqdm still tracks (the attribute may not exist).
if hasattr(tqdm, '_instances'):
    tqdm._instances.clear()

# The loop variable is named ``iteration`` rather than ``iter`` so the builtin
# iter() is not shadowed for the rest of the notebook.
for iteration in tqdm(range(num_training_iterations)):

    # Grab a batch and propagate it through the network
    x_batch, y_batch = get_batch(vectorized_songs, seq_length, batch_size, seed=1+iteration)
    loss = train_step(x_batch, y_batch)

    # Record the mean batch loss and refresh the live plot
    history.append(loss.numpy().mean())
    plotter.plot(history)

    # Checkpoint the current weights every 100 iterations
    if iteration % 100 == 0:
        model.save_weights(checkpoint_prefix)

# Save the trained model and the weights
model.save_weights(checkpoint_prefix)

## Generate music using the RNN model

### Restore the latest checkpoint

In [None]:
# Rebuild the same architecture with batch_size=1: generation feeds the
# stateful model a single sequence at a time.
model = build_model_LSTM(vocab_size, embedding_dim, rnn_units, batch_size=1, seed=1)

# Restore the model weights for the last checkpoint after training
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
model.build(tf.TensorShape([1, None]))

### The prediction procedure

In [None]:
def _sample_next_id(model, input_eval):
    """Run one model step and sample the id of the next character.

    Returns the sampled id and the batch-of-one input for the following step.
    """
    pred = model(input_eval)[0]
    # Sample from the predicted distribution; keep the draw for the last step.
    pred_id = tf.random.categorical(pred, num_samples=1)[-1, 0].numpy()
    return pred_id, tf.expand_dims([pred_id], 0)


def generate_text(model, start_string, vocab, choose, generation_length=None, n_songs=None, seed=None):
    """Generate ABC text character by character with a trained model.

    Args:
        model: Stateful model built for a batch size of 1.
        start_string: Text used to prime the model; it is also the start of
            the returned text.
        vocab: Mapping from character to integer index.
        choose: Stopping rule. ``'length'`` generates exactly
            ``generation_length`` characters; ``'songs'`` generates until
            ``extract_song_snippet`` finds at least ``n_songs`` songs in the
            text.
        generation_length: Number of characters to generate
            (required when ``choose='length'``).
        n_songs: Number of complete songs to generate
            (required when ``choose='songs'``).
        seed: Currently ignored (seeding is disabled); kept for call
            compatibility.

    Returns:
        ``start_string`` followed by the generated characters.

    Raises:
        ValueError: If ``choose`` is not recognized, or the argument required
            by the chosen mode is missing.
    """
    # Validate up front so a bad call fails before any model work is done.
    if choose not in ('length', 'songs'):
        raise ValueError("unrecognized parameter choose {}".format(choose))
    if choose == 'length' and generation_length is None:
        raise ValueError("generation_length is required when choose='length'")
    if choose == 'songs' and n_songs is None:
        raise ValueError("n_songs is required when choose='songs'")

    input_eval = tf.expand_dims(char2idx(start_string, vocab), 0)

    # Start from a clean recurrent state for every generation run.
    model.reset_states()
    # Same guard as the training loop: the attribute may not exist.
    if hasattr(tqdm, '_instances'):
        tqdm._instances.clear()

    #tf.random.set_seed(seed)

    if choose == 'length':
        text_generated = []
        for _ in tqdm(range(generation_length)):
            pred_id, input_eval = _sample_next_id(model, input_eval)
            text_generated.append(idx2char(pred_id, vocab))
        return start_string + ''.join(text_generated)

    # choose == 'songs': keep sampling until enough complete songs are found.
    text = start_string
    n_found = 0
    with tqdm(total=n_songs) as progress:
        while n_found < n_songs:
            pred_id, input_eval = _sample_next_id(model, input_eval)
            text += idx2char(pred_id, vocab)
            n_now = len(extract_song_snippet(text))
            # Advance the bar by the number of newly completed songs.
            if n_now > n_found:
                progress.update(n_now - n_found)
            n_found = n_now

    return text


In [None]:
'''TODO: Use the model and the function defined above to generate ABC format text of length 1000!
    As you may notice, ABC files start with "X" - this may be a good start string.'''
# Generate 1000 characters primed with an ABC header prefix.
# NOTE: generate_text accepts seed but does not apply it (its seeding call is commented out).
generated_text = generate_text(model, start_string='X:1\nT:', vocab=vocab, choose='length', generation_length=1000, seed=1) # TODO
print(generated_text)

In [None]:
# Split the generated text into songs with the same rule used for the dataset.
generated_songs = extract_song_snippet(generated_text)
print(generated_songs)

In [None]:
# TODO: review get_batch (I think the problem is in using seed)
# TODO: review fit_model