In [1]:
import collections
import numpy as np
import tensorflow as tf

In [2]:
poetry_file ='poetry.txt'

# Load the corpus. Every usable line has the form "title:content".
# Kept poems are wrapped in '[' ... ']' so the model sees explicit
# start-of-poem and end-of-poem markers.
poetrys = []
with open(poetry_file, "r", encoding='utf-8') as f:
    for line in f:
        try:
            title, content = line.strip().split(':')
        except ValueError:
            # The line does not contain exactly one ':' separator, so it
            # cannot be unpacked into (title, content): skip it. Only this
            # specific failure is ignored; any other error is a real bug
            # and is allowed to surface.
            continue

        content = content.replace(' ','')
        # Drop poems that carry annotations / editorial marks, or that
        # already contain the '[' character used as the start marker.
        if any(mark in content for mark in ('_', '(', '（', '《', '[')):
            continue
        # Keep only poems of 5..79 characters.
        if len(content) < 5 or len(content) > 79:
            continue
        poetrys.append('[' + content + ']')

In [3]:
# Order the poems from shortest to longest (stable for equal lengths).
poetrys = sorted(poetrys, key=len)
print('唐诗总数: ', len(poetrys))

唐诗总数:  34646


In [4]:
# Count character frequency over the whole corpus
all_words = [word for poetry in poetrys for word in poetry]
counter = collections.Counter(all_words)
# most_common() lists (char, count) pairs by descending count; characters
# with equal counts stay in first-seen order.
count_pairs = counter.most_common()
# Keep only the characters, most frequent first.
words, _ = zip(*count_pairs)

print("total words :", len(words))

total words : 6109


In [5]:
# Reserve the last vocabulary slot for the padding symbol ' '.
# The guard keeps this cell idempotent: without it, every re-run would append
# one more ' ' to `words`, silently shifting the padding id and the
# vocabulary size used by the cells below.
if ' ' not in words:
    words = words + (' ',)
word_num_map = dict(zip(words, range(len(words))))


def to_num(word):
    """Return the integer id of `word`; unknown characters map to len(words)."""
    return word_num_map.get(word, len(words))


# Each poem becomes a list of character ids.
poetrys_vector = [list(map(to_num, poetry)) for poetry in poetrys]

In [6]:
# Longest encoded poem (every batch is padded to this width) and poem count.
length = max(len(vector) for vector in poetrys_vector)
size = len(poetrys_vector)

In [7]:
x_batches = []
y_batches = []

batch_size = 1
# Number of batches needed to cover all poems (the last one may be partial).
n_chunk = (size-1)//batch_size +1

for start in range(0, size, batch_size):
    members = poetrys_vector[start:start + batch_size]

    # Inputs: one row per poem, right-padded with the id of ' '.
    xdata = np.full((batch_size, length), word_num_map[' '], np.int32)
    for row, vector in enumerate(members):
        xdata[row, :len(vector)] = vector

    # Targets: the inputs shifted one step to the left; the final column is
    # kept as-is because there is no "next" character for it.
    ydata = np.concatenate([xdata[:, 1:], xdata[:, -1:]], axis=1)

    x_batches.append(xdata)
    y_batches.append(ydata)

x_batches[0],y_batches[0]
                             

(array([[   3,   28,  545,  104,  720,    1,    2, 6109, 6109, 6109, 6109,
         6109, 6109, 6109, 6109, 6109, 6109, 6109, 6109, 6109, 6109, 6109,
         6109, 6109, 6109, 6109, 6109, 6109, 6109, 6109, 6109, 6109, 6109,
         6109, 6109, 6109, 6109, 6109, 6109, 6109, 6109, 6109, 6109, 6109,
         6109, 6109, 6109, 6109, 6109, 6109, 6109, 6109, 6109, 6109, 6109,
         6109, 6109, 6109, 6109, 6109, 6109, 6109, 6109, 6109, 6109, 6109,
         6109, 6109, 6109, 6109, 6109, 6109, 6109, 6109, 6109, 6109, 6109,
         6109, 6109, 6109, 6109]], dtype=int32),
 array([[  28,  545,  104,  720,    1,    2, 6109, 6109, 6109, 6109, 6109,
         6109, 6109, 6109, 6109, 6109, 6109, 6109, 6109, 6109, 6109, 6109,
         6109, 6109, 6109, 6109, 6109, 6109, 6109, 6109, 6109, 6109, 6109,
         6109, 6109, 6109, 6109, 6109, 6109, 6109, 6109, 6109, 6109, 6109,
         6109, 6109, 6109, 6109, 6109, 6109, 6109, 6109, 6109, 6109, 6109,
         6109, 6109, 6109, 6109, 6109, 6109, 6109, 

# Model

In [8]:
# Hyper-parameters (batch_size is defined by the batching cell above).
keep_prob = 0.75          # probability of keeping each LSTM input unit (dropout)
layer_size = 128          # LSTM hidden size, also used as the embedding width
num_layers = 2
learning_rate = 0.002
lr_decay = 0.97
model_save_path = './poets_model.ckpt'


tf.reset_default_graph()

input_data = tf.placeholder(tf.int32, [batch_size, None])
output_targets = tf.placeholder(tf.int32, [batch_size, None])


def _make_layer():
    """Build one LSTM layer with dropout applied to its inputs."""
    lstm = tf.contrib.rnn.BasicLSTMCell(layer_size)
    return tf.contrib.rnn.DropoutWrapper(lstm, input_keep_prob=keep_prob)


# Build a separate cell object for every layer. `[drop] * num_layers` repeats
# the SAME object, which makes all layers share one set of weights in
# TF >= 1.2 (and raises in TF 1.1).
# NOTE: this changes the variable names/shapes in the graph, so checkpoints
# written by the shared-weight version of this notebook will not restore.
# NOTE(review): keep_prob is a Python constant, so dropout is also active
# whenever the graph is run for sampling (gen_poetry).
cell = tf.contrib.rnn.MultiRNNCell([_make_layer() for _ in range(num_layers)])

initial_state = cell.zero_state(batch_size, tf.float32)

# One extra vocabulary slot (id len(words)) is kept for unknown characters.
embedding =  tf.contrib.layers.embed_sequence(input_data, vocab_size=len(words)+1,
    embed_dim=layer_size, scope = 'embedding')
outputs, last_state = tf.nn.dynamic_rnn(cell, embedding, initial_state=initial_state, scope='lstm')
# Flatten batch and time so one dense layer scores every position at once.
outputs = tf.reshape(outputs,[-1, layer_size])
logits = tf.contrib.layers.fully_connected(outputs,len(words)+1,activation_fn=None)
probs = tf.nn.softmax(logits)


In [9]:
def to_word(weights):
    """Sample one vocabulary character from (unnormalised) weights.

    Args:
        weights: array-like of non-negative scores; any shape, it is
            flattened. Entry i is the score of words[i]. Entries beyond
            len(words) (e.g. the model's extra "unknown" slot) are ignored,
            because they have no character to return.

    Returns:
        The sampled element of `words`.

    Raises:
        ValueError: if there are no weights to sample from.
    """
    # Only ids that map to a real character can be returned; without this an
    # index of len(words) could be drawn and words[...] would raise IndexError.
    weights = np.ravel(weights)[:len(words)]
    if weights.size == 0:
        raise ValueError("to_word needs at least one weight")

    cdf = np.cumsum(weights)
    # Scalar draw (int() on a 1-element array is deprecated in recent NumPy).
    threshold = np.random.rand() * cdf[-1]
    # side='right' picks the first entry whose cumulative weight exceeds the
    # threshold, so zero-weight entries are never selected; the min() guards
    # against floating-point round-off at the upper end.
    sample = min(int(np.searchsorted(cdf, threshold, side='right')), cdf.size - 1)
    return words[sample]

def gen_poetry(sess, max_len=200):
    """Sample a poem from the model, one character at a time.

    Generation starts from the '[' start marker and stops when the model
    emits the ']' end marker.

    Args:
        sess: tf.Session holding the model variables.
        max_len: upper bound on the number of generated characters, so that
            a model which never emits ']' cannot loop forever.

    Returns:
        The generated poem as a string (without the '[' / ']' markers).
    """
    def _step(token_id, state):
        # Feed a single character (batch of 1, sequence of 1) and sample the next.
        x = np.array([[token_id]], dtype=np.int32)
        probs_, next_state = sess.run([probs, last_state],
                                      feed_dict={input_data: x, initial_state: state})
        return to_word(probs_), next_state

    state_ = sess.run(cell.zero_state(1, tf.float32))
    word, state_ = _step(word_num_map['['], state_)

    poem = ''
    while word != ']' and len(poem) < max_len:
        poem += word
        word, state_ = _step(word_num_map[word], state_)
    return poem
    

In [10]:
from tqdm import tqdm

num_epochs = 5

targets = tf.reshape(output_targets, [-1])
# NOTE(review): every position gets weight 1, so padding targets contribute to
# the loss just like real characters.
loss = tf.contrib.legacy_seq2seq.sequence_loss([logits], [targets], [tf.ones_like(targets, dtype=tf.float32)])

# The learning rate is fed as a tensor. Passing the Python float to the
# optimizer freezes its value at graph-construction time, so re-assigning the
# Python variable inside the training loop would have no effect.
lr_ph = tf.placeholder(tf.float32, shape=[], name='learning_rate')

tvars = tf.trainable_variables()
# Clip the global gradient norm to 5 to keep RNN training stable.
grads, _ = tf.clip_by_global_norm(tf.gradients(loss, tvars), 5)
optimizer = tf.train.AdamOptimizer(lr_ph)
train_op = optimizer.apply_gradients(zip(grads, tvars))


init = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(init)

    for epoch in range(num_epochs):
        # Exponential decay from the configured base rate. The config value
        # itself is left untouched so that re-running this cell starts from
        # the same schedule.
        epoch_lr = learning_rate * (lr_decay ** epoch)

        for n in tqdm(range(n_chunk)):
            train_loss, _ = sess.run([loss, train_op],
                                     feed_dict={input_data: x_batches[n],
                                                output_targets: y_batches[n],
                                                lr_ph: epoch_lr})

        # Loss of the last batch of the epoch, then a sample poem.
        print(epoch, n, train_loss)

        print(gen_poetry(sess))

        saver.save(sess, model_save_path)
        
        
    
 

100%|██████████| 34646/34646 [38:48<00:00, 15.60it/s]


0 34645 6.35449
散玉觜争软，清春对梦官。烟飙妆壮谬，清侣鸟凄凄。谏酡迎色式，章恩瑞雾岑。泉香三六国，三画怨塔藏，金羽喧旧医，岂似同，舞，身无诸举风，还须臂驰与送璧。


100%|██████████| 34646/34646 [37:53<00:00, 15.24it/s]  


1 34645 6.04134
莫耶万里分泛兮银体白宝罗，扑


100%|██████████| 34646/34646 [37:38<00:00, 15.34it/s] 


2 34645 6.27104
销人抱北登舟长白足送，远游风生城脂白马指荷春。娇罢宫兮听吟飞，看珠寄兮释之心归。我今词谷孟隳失，驻杳居王北阙白练。


100%|██████████| 34646/34646 [37:52<00:00, 15.24it/s]  


3 34645 6.41668
惆怅闲将丝子恨人，我能随来不把内气。自向长是汉门侣，鸟居相看眼地立。炭裹砚有精，种舌喷田喧。叶中世士既双死，超异浩尧凤一几。当还沈会辅王灭，恨无否学墓，何用终所昭。


100%|██████████| 34646/34646 [33:05<00:00, 17.45it/s]  


4 34645 6.32021
病收千里叹，长安影其难。微曈漓祖室彰浦在兮成，驯罗色来颜至在阳。证侣饮兵如，薄宠居宅？晚雪雄婺人皆渴，下界周兮剑星侧难留相兮。作经患，及千龄君于月飞辱。


Reference: Andrej Karpathy, "The Unreasonable Effectiveness of Recurrent Neural Networks" — http://karpathy.github.io/2015/05/21/rnn-effectiveness/