# kichtAI: 
### Example for rap corpus creation, model training and text generation. 

In [None]:
import numpy as np
from sklearn.utils import shuffle
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow import TensorShape

from kichtai.genius import GeniusParser
from kichtai.corpus import RapCorpus
from kichtai.nn import rnn_seq_loss, get_rnn_seq_model, plot_history, talk_from_text

## 1. Rap corpus creation using Genius API
##### Reference: https://dev.to/willamesoares/how-to-integrate-spotify-and-genius-api-to-easily-crawl-song-lyrics-with-python-4o62

In [None]:
# Read your Genius token, stored in a 'token.txt' file, and test its validity.
# The context manager closes the file handle deterministically, and strip()
# removes surrounding whitespace (typically the trailing newline editors
# append) that would otherwise be part of the token string given to the parser.
with open('token.txt', 'r') as token_file:
    token = token_file.read().strip()
rap_parser = GeniusParser(token)
rap_parser.test_token()

In [None]:
# Initialize the parser's artists dict from the list of artist names to crawl.
list_artists = ['Booba']
rap_parser.create_dict_artists(list_artists=list_artists)

In [None]:
# Search for songs of artists in 'list_artists'.
# NOTE(review): nb_page / per_page presumably bound how many result pages and
# songs per page are requested from Genius — confirm in GeniusParser.
rap_parser.search_for_songs(nb_page=1, per_page=3)
# Last expression of the cell: the notebook displays the artists dict.
rap_parser.dict_artists

In [None]:
# Search for raw lyrics of the songs found in the previous step.
rap_parser.search_for_lyrics()
# Last expression of the cell: the notebook displays the artists dict again.
rap_parser.dict_artists

In [None]:
# Build the corpus object from the parser's artists dict and show its summary.
# Consolidation and cleaning of the lyrics are separate explicit calls
# (create_corpus / clean_text) made after this construction.
corpus = RapCorpus(rap_parser.dict_artists)
corpus.info()

In [None]:
# Consolidate and clean corpus
corpus.create_corpus()
corpus.clean_text()
# Print an excerpt of the resulting text as a sanity check.
# NOTE(review): limit=500 is presumably a character count — confirm in RapCorpus.
corpus.print_text(limit=500, random_select=True)

In [None]:
# Plot top words in corpus (restricted to the 15 first entries via 'top')
corpus.plot_dictionary(top=15)

In [None]:
# Plot vocabulary of the corpus (see RapCorpus.plot_vocabulary for the exact figure)
corpus.plot_vocabulary()

## 2. Train a text generation model using RNN
##### Reference: https://www.tensorflow.org/tutorials/text/text_generation

In [None]:
# Random seed: passed to sklearn's shuffle so the sample order is reproducible
random_state=0

In [None]:
# Parameters
len_seq = 64  # length of each input/target window cut from the text
embedding_dim = 8  # forwarded to get_rnn_seq_model
rnn_units = 8  # forwarded to get_rnn_seq_model
batch_size = 64  # batch size for model creation, the train/validation split and fit

epochs = 1000  # upper bound on training epochs; EarlyStopping may stop sooner
patience = 10  # EarlyStopping patience, monitored on val_loss
lr=1e-3  # learning rate given to the Adam optimizer

In [None]:
# Get text: the corpus content that is split into characters and windowed
# for training (presumably a single string — confirm in RapCorpus).
text = corpus.corpus

In [None]:
# Vocab: the sorted set of distinct characters found in the text
vocab = sorted(set(text))
vocab_size = len(vocab)

# Mapping between characters and integer ids, in both directions
char2idx = {u: i for i, u in enumerate(vocab)}
idx2char = np.array(vocab)

# Encode the whole text once, instead of re-encoding every overlapping window
encoded_text = np.array([char2idx[c] for c in text], dtype=int)

# Start offsets i for which both the input text[i:i+len_seq] and the target
# text[i+1:i+len_seq+1] are full length. The last valid start is
# len(text)-len_seq-1, which gives len(text)-len_seq windows in total
# (zero when the text is too short to hold a single window).
nb_windows = max(len(text) - len_seq, 0)
window_idx = np.arange(nb_windows)[:, None] + np.arange(len_seq)[None, :]

# Data: row i holds the ids of text[i:i+len_seq];
# targets: the same window shifted one character to the right
data = encoded_text[window_idx]
targets = encoded_text[window_idx + 1]

# Shuffle inputs and targets together so that rows stay aligned
data, targets = shuffle(data, targets, random_state=random_state)
print(f"Data shape: {data.shape}")

In [None]:
# Split train/validation: 80% / 20% of the shuffled samples, each share
# rounded down to a whole number of batches (leftover samples are unused).
# NOTE(review): the rounding is presumably needed because the model is created
# with a fixed batch size — confirm in get_rnn_seq_model.
nb_samples = data.shape[0]
TRAIN_BUF = (int(nb_samples * 0.8) // batch_size) * batch_size
TEST_BUF = (int(nb_samples * 0.2) // batch_size) * batch_size

# Training rows come first, validation rows immediately after them
train_rows = slice(0, TRAIN_BUF)
validation_rows = slice(TRAIN_BUF, TRAIN_BUF + TEST_BUF)

data_train, targets_train = data[train_rows], targets[train_rows]
data_validation, targets_validation = data[validation_rows], targets[validation_rows]

In [None]:
# Create tf model for training, using the training batch size
model = get_rnn_seq_model(vocab_size, embedding_dim, rnn_units, batch_size)
# Run name encoding the hyper-parameters; used to build the checkpoint file name
name=f'sequence_model_{len_seq}_{embedding_dim}_{rnn_units}_{batch_size}'

In [None]:
# Callbacks and compile
# Stop training once val_loss has not improved for 'patience' epochs
es = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=patience)
# Keep only the best model (lowest val_loss) on disk, under outputs/<name>.h5
mc = ModelCheckpoint(f'outputs/{name}.h5', monitor='val_loss', mode='min', verbose=1, save_best_only=True)

# Adam optimizer with the configured learning rate and the project's sequence loss
optimizer = Adam(learning_rate=lr)
model.compile(optimizer=optimizer, loss=rnn_seq_loss)

In [None]:
# Train on the training split and evaluate on the validation split each epoch.
# verbose=0 silences the per-epoch Keras progress output; the two callbacks
# are created with verbose=1 and still report early stopping and checkpoints.
history = model.fit(data_train, targets_train, 
              validation_data = (data_validation, targets_validation), 
              epochs=epochs, 
              batch_size=batch_size, 
              verbose=0,
              callbacks=[es, mc])

In [None]:
# Plot history: pass the History object returned by model.fit to the project's plotting helper
plot_history(history)

## 3. Generate lyrics from initial text

In [None]:
# Load final model for generation: rebuild the model with batch_size=1, then
# restore the weights from the checkpoint file written during training
# (ModelCheckpoint was configured with save_best_only=True).
model = get_rnn_seq_model(vocab_size, embedding_dim, rnn_units, batch_size=1)
# The name uses the *training* batch_size so that it matches the checkpoint file name
name=f'sequence_model_{len_seq}_{embedding_dim}_{rnn_units}_{batch_size}'
model.load_weights(f'outputs/{name}.h5')
# Build for input shape [1, None]: first dimension fixed to 1, second left unspecified
model.build(TensorShape([1, None]))

In [None]:
# Prompt text and generation settings
text_input = "personne personne"
nb_steps = 500
# NOTE(review): temperature presumably controls sampling randomness in
# talk_from_text (lower = more conservative) — confirm in kichtai.nn.
temperature = 0.5

# Generate text from the prompt using the model and the character/id mappings
text_predict = talk_from_text(text_input, model, char2idx, idx2char, len_seq, nb_steps=nb_steps, temperature=temperature)

# Print the prompt, then the part of the result located after the first
# len(text_input) characters (assumes text_predict starts with the prompt — TODO confirm)
print(f"{text_input}...\n...{text_predict[len(text_input):]}")