# In [None]:
""" Character-level generative language model, based on Andrej Karpathy's blog: 
http://karpathy.github.io/2015/05/21/rnn-effectiveness/

We’ll train RNN character-level language models. That is, we’ll give the RNN a huge chunk of text 
and ask it to model the probability distribution of the next character in the sequence given a 
sequence of previous characters. This will then allow us to generate new text one character at a time.

Note: Minimal character-level language model with a Vanilla Recurrent Neural Network, in Python/numpy
https://gist.github.com/karpathy/d4dee566867f8291f086

Note: Temperature: We can also play with the temperature of the Softmax during sampling. Decreasing 
the temperature from 1 to some lower number (e.g. 0.5) makes the RNN more confident, but also more
conservative in its samples. Conversely, higher temperatures will give more diversity but at the cost of
more mistakes (e.g. spelling mistakes, etc). In particular, setting temperature very near zero will 
give the most likely thing that Paul Graham might say:

“is that they were all the same thing that was a startup is that they were all the same thing that 
was a startup is that they were all the same thing that was a startup is that they were all the same”

looks like we’ve reached an infinite loop about startups.

Note: Generated baby names could be a quite useful inspiration when writing a novel, or naming a new startup :)
http://cs.stanford.edu/people/karpathy/namesGenUnique.txt
"""
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import sys
sys.path.append('..')
import time
import tensorflow as tf

# Text corpus used for training; it is consumed one line at a time.
DATA_PATH = 'data/arvix_abstracts.txt'

# Model and optimisation settings.
HIDDEN_SIZE = 200  # number of units given to the GRU cell
BATCH_SIZE = 64    # default number of windows grouped into one batch
LR = 0.003         # learning rate given to the Adam optimizer

# Each encoded line is cut into windows of MAX_SEQ_WINDOW_SIZE codes whose
# start positions are WINDOWING_STEP apart (half a window).
MAX_SEQ_WINDOW_SIZE = 50
WINDOWING_STEP = MAX_SEQ_WINDOW_SIZE // 2

# Every SKIP_STEP iterations the loss is printed, a text sample is generated
# and a checkpoint is saved.
SKIP_STEP = 40
LEN_GENERATED = 300  # number of characters sampled for each generated text
# Softmax temperature used while sampling. The name keeps its original
# spelling because other code refers to it.
TEMPRATURE = 0.7

# Characters known to the model; anything else is dropped when encoding.
vocab = (
    " $%'()+,-./0123456789:;=?"
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    "\\^_"
    "abcdefghijklmnopqrstuvwxyz"
    "{|}"
)

def make_dir(path):
    """Create the directory ``path`` if there isn't one already.

    Parent directories are not created.

    Args:
        path: Directory to create.

    Raises:
        OSError: If the directory could not be created and does not exist
            afterwards (for example the parent is missing, permissions are
            insufficient, or ``path`` is an existing regular file).
    """
    try:
        os.mkdir(path)
    except OSError:
        # "Already exists" is the only failure that is safe to ignore; any
        # other error would otherwise surface later as a confusing failure
        # when something tries to write into the directory.
        if not os.path.isdir(path):
            raise

def vocab_encode(text, vocab):
    """Map every item of ``text`` to its 1-based position in ``vocab``.

    Items that do not occur in ``vocab`` are skipped, so the returned list
    may be shorter than ``text``.
    """
    codes = []
    for ch in text:
        if ch in vocab:
            codes.append(vocab.index(ch) + 1)
    return codes

def vocab_decode(array, vocab):
    """Turn a sequence of 1-based codes back into a string.

    Each code ``c`` is looked up as ``vocab[c - 1]`` and the resulting
    characters are concatenated in order.
    """
    pieces = []
    for code in array:
        pieces.append(vocab[code - 1])
    return ''.join(pieces)

def read_data(filename, vocab, max_seq_window_size=MAX_SEQ_WINDOW_SIZE, windowing_step=WINDOWING_STEP):
    """Yield fixed-length windows of encoded characters from a text file.

    The file is read line by line. Each line is encoded with
    ``vocab_encode`` (characters missing from ``vocab`` are dropped) and then
    cut into overlapping windows.

    Args:
        filename: Path of the text file to read.
        vocab: Vocabulary passed through to ``vocab_encode``.
        max_seq_window_size: Number of codes in every yielded window.
        windowing_step: Distance between the starts of consecutive windows
            taken from the same line.

    Yields:
        Lists of ``max_seq_window_size`` integer codes. Window starts run over
        ``range(0, len(line) - max_seq_window_size, windowing_step)``, so a
        line whose encoded length does not exceed the window size
        contributes nothing.
    """
    # Use a context manager so the file handle is released as soon as the
    # generator finishes or is closed.
    with open(filename) as lines:
        for line in lines:
            codes = vocab_encode(line, vocab)
            for start in range(0, len(codes) - max_seq_window_size, windowing_step):
                # start + window size is always below len(codes) here, so
                # every chunk is full length and needs no zero padding.
                yield codes[start:start + max_seq_window_size]

def read_batch(stream, batch_size=BATCH_SIZE):
    """Group the items of ``stream`` into lists of ``batch_size`` elements.

    Args:
        stream: Iterable of samples.
        batch_size: Number of samples in a full batch.

    Yields:
        Lists of ``batch_size`` consecutive samples, followed by one shorter
        list holding the leftover samples when there are any. An empty batch
        is never produced.
    """
    batch = []
    for element in stream:
        batch.append(element)
        if len(batch) == batch_size:
            yield batch
            batch = []
    # Emit the remainder only when it holds something: if the stream is empty
    # or its length is a multiple of batch_size there is nothing left over.
    if batch:
        yield batch

    
######################  CREATE MODEL
# Batch of encoded character sequences.
seq = tf.placeholder(tf.int32, [None, None])
# NOTE(review): vocab_encode emits codes 1..len(vocab) while the one-hot depth is
# len(vocab); tf.one_hot zero-fills out-of-range indices, so the last vocab
# character becomes an all-zero row. Left as is because changing the depth
# changes variable shapes -- confirm whether len(vocab) + 1 was intended.
seq_one_hot = tf.one_hot(indices=seq, depth=len(vocab)) # indexes are converted to one-hot
# Softmax temperature, only needed when `sample` is evaluated.
temperature = tf.placeholder(tf.float32)

# Initialize RNN parameters
cell_type = tf.contrib.rnn.GRUCell(HIDDEN_SIZE)
# Defaults to an all-zero state; a previous `out_state` can be fed here to
# continue a sequence across session runs.
initial_state = tf.placeholder_with_default(cell_type.zero_state(tf.shape(seq_one_hot)[0], tf.float32), [None, HIDDEN_SIZE])
# Per sequence: number of time steps whose one-hot row has a non-zero entry.
length = tf.reduce_sum(tf.reduce_max(tf.sign(seq_one_hot), 2), 1)

output, out_state = tf.nn.dynamic_rnn(cell=cell_type, inputs=seq_one_hot, sequence_length=length, initial_state=initial_state)

# fully_connected creates the weights w and bias b for us and applies the linear
# map to the RNN output (the activation is None, so the results are raw logits)
logits = tf.contrib.layers.fully_connected(output, len(vocab), None)
# Next-character objective: the logits at step t are scored against the
# one-hot input at step t + 1, summed over the whole batch.
loss = tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(logits=logits[:, :-1], labels=seq_one_hot[:, 1:]))

# Generate ONE sample per sequence from a multinomial distribution (a binomial
# distribution with k options instead of 2) over the next character, using the
# logits of the last time step.
# tf.multinomial expects unnormalized log-probabilities in each slice [i, :], so the
# temperature-scaled logits are passed as they are; exponentiating them first would
# make the op sample from softmax(exp(logits / T)) instead of softmax(logits / T).
# The lower the temperature, the more the probability mass concentrates on the most
# likely characters; the higher it is, the flatter the distribution and the more
# chance the less likely characters have.
sample = tf.multinomial(logits[:, -1] / temperature, 1)[:, 0]

###################

global_step = tf.Variable(0, dtype=tf.int32, trainable=False, name='global_step')
optimizer = tf.train.AdamOptimizer(LR).minimize(loss, global_step=global_step)
make_dir('checkpoints')
make_dir('checkpoints/arvix')

# train
saver = tf.train.Saver()
start = time.time()
with tf.Session() as sess:
    writer = tf.summary.FileWriter('graphs/gist', sess.graph)
    try:
        sess.run(tf.global_variables_initializer())

        # Resume from the most recent checkpoint in checkpoints/arvix, if any.
        ckpt = tf.train.get_checkpoint_state(os.path.dirname('checkpoints/arvix/checkpoint'))
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)

        # Continue counting from the restored global step (0 on a fresh run).
        iteration = global_step.eval()
        for batch in read_batch(read_data(DATA_PATH, vocab)):
            if not batch:
                # An empty batch has no [batch, time] shape to feed; skip it
                # rather than fail.
                continue
            batch_loss, _ = sess.run([loss, optimizer], {seq: batch})    # train
            if (iteration + 1) % SKIP_STEP == 0:
                print('Iter {}. \n    Loss {}. Time {}'.format(iteration, batch_loss, time.time() - start))

                # Generate a sequence one character at a time: each step feeds
                # only the previous character and carries the RNN state over.
                sentence = 'T'
                state = None
                for _ in range(LEN_GENERATED):
                    last_char = [vocab_encode(sentence[-1], vocab)]
                    feed = {seq: last_char, temperature: TEMPRATURE}
                    # On the first step the state is None, so the graph falls
                    # back to its default initial state.
                    if state is not None:
                        feed[initial_state] = state
                    index, state = sess.run([sample, out_state], feed)
                    sentence += vocab_decode(index, vocab)
                print(sentence + '\n')

                # Restart the timer so the next report covers only the time
                # since this one.
                start = time.time()
                saver.save(sess, 'checkpoints/arvix/char-rnn', iteration)
            iteration += 1
    finally:
        # Flush and release the event file even if training is interrupted.
        writer.close()
