In [1]:
# 导入模块
%matplotlib inline
import torch
import numpy as np
import pylab as pl
from torch import nn
import re

torch.manual_seed(1)
np.random.seed(1)

In [2]:
# 训练一个基于ERNN神经网络来作诗

## 读入用GloVe处理得到的文字 embeddings，以及句子数据。
import codecs

word_emb_dim = input_size = 1
i2w = {0:''}
w2i = {'':0}

word_emb_dim = 128

with codecs.open('data/word_embedding_chat_128.txt', mode='r', encoding='utf-8') as f:
    lines = f.readlines()
    n_words = len(lines)+1
    word_embeddings = torch.nn.Embedding(n_words, word_emb_dim)
    for i in range(1, n_words):
        line = lines[i-1].split(' ')
        i2w[i] = line[0]
        w2i[line[0]] = i
        word_embeddings.weight[i] = torch.from_numpy(np.array(line[1:],dtype=np.float32))

max_line_length = 20
poems = []
with codecs.open('data/chat_lines.txt', mode='r', encoding='utf-8') as f:
    for poem in f:
        poem = re.sub('\s','',poem)+'E'

        if len(poem) < 3 or len(poem) > max_line_length:
            continue
        poems.append(map(lambda x:w2i.get(x,0), poem))

n_poems = len(poems)

print( 'Data summary:\n\n number of poems: {}\n number of words: {}\n'.format(n_poems, n_words))
print('Poem examples:\n\n'+'\n'.join([''.join(map(i2w.get, x)) for x in poems[:10]]))
    

Data summary:

 number of poems: 18249
 number of words: 3173

Poem examples:

呵呵E
是王若猫的。E
不是E
那是什么？E
怎么了E
我很难过，安慰我~E
开心点哈,一切都会好起来E
嗯会的E
我还喜欢她,怎么办E
我帮你告诉她？发短信还是打电话？E


In [5]:
# 定义一个函数，随机返回一个 mini batch，用于训练，由于每一首诗歌的长度不同，我们此处规定每个batch只有一首诗。这样，就可以生成长度可变的诗歌。
def get_batch(batch_size=2):
    idx = np.random.randint(0, n_poems-1, batch_size)
    
    batch_raw_x = [poems[i][:] for i in idx]
    max_length = max(map(len, batch_raw_x))
    for i in range(len(batch_raw_x)):
        for j in range(len(batch_raw_x[i]),max_length):
            batch_raw_x[i].append(0)
    batch_raw_x = torch.LongTensor(batch_raw_x).detach().unsqueeze(2).transpose(0,1)
    x = batch_raw_x.type(torch.float32)
    
    batch_raw_y = [poems[i+1][:] for i in idx]
    max_length = max(map(len, batch_raw_y))
    for i in range(len(batch_raw_y)):
        for j in range(len(batch_raw_y[i]),max_length):
            batch_raw_y[i].append(0)
    batch_raw_y = torch.LongTensor(batch_raw_y).detach().unsqueeze(2).transpose(0,1)
    y = batch_raw_y.type(torch.float32)
    return x, y

def idx2emb(x):
    return word_embeddings(x.type(torch.long)).squeeze(2).detach()
    

# 定义一个函数，输入一个 batch 返回句子
def batch2sent(batch):
    S = []
    batch = batch.type(torch.int32).detach()
    seq_length, batch_size, emb_size = batch.size()
    for i in range(batch_size):
        S.append(''.join(map(i2w.get, batch[:,i,:].view(-1).tolist())))
    return u'\n'.join(S)

x, y = get_batch(1)
print(batch2sent(x))
print(batch2sent(y))

# 定义一个生成器
class Generator(nn.Module):
    def __init__(self, input_size, output_size, hidden_size, n_layers=2, activation=None):
        super(Generator, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.activation = activation
        self.rnn = nn.LSTM(self.input_size, self.hidden_size, num_layers=self.n_layers, dropout=0.01)
        self.output = nn.Linear(self.hidden_size,self.output_size)
        self.softmax = torch.nn.LogSoftmax(dim=-1)
    def init_h(self):
        return (torch.zeros(self.n_layers, self.batch_size, self.hidden_size),torch.zeros(self.n_layers, self.batch_size, self.hidden_size))
    def forward(self, x, h0=None):
        self.seq_length, self.batch_size, self.input_size = x.size()
        if h0 is None:
            h0 = self.init_h()
        y, ht = self.rnn(x,h0)
#        y = torch.cat((y0,y),dim=0)
        y = y.view(-1,self.hidden_size)
        y = self.output(y)
        y = y.view(self.seq_length,self.batch_size,self.output_size)
        y = self.softmax(y)
        return y, ht

def answer(model, question=''):
    if question == '': 
        question = ''.join(map(i2w.get, poems[np.random.randint(0,n_poems-1)]))
    with torch.no_grad():
        s = []
        for t in range(2*max_line_length):
            if t < len(question):
                w = question[t]
                idx = w2i[w]
                ht = None
                x = torch.LongTensor([w2i[w]]).view(1,1,-1).detach()
                x = idx2emb(x)
                y, ht = model(x, ht)
            if t >= len(question):
                x = torch.argmax(y,dim=-1,keepdim=True)
                x = idx2emb(x)
                y, ht = model(x, ht)
                x = torch.argmax(y,dim=-1,keepdim=True)
                w = batch2sent(x)
                if w == 'E':
                    break
                s.append(w)
                x = idx2emb(x)
        return question+'\n'+u''.join(s)
    
    
# 训练一个简单的 RNN 模型以生成诗歌

input_size = word_emb_dim
hidden_size = 128
output_size = n_words
activation = torch.relu

model = Generator(input_size, output_size, hidden_size, n_layers=2, activation=activation)


。。E==
。E==E


In [7]:
lr = 1e-3
n_epochs = 10000
last_epoch = -1
disp_interval = 50
batch_size = 1

loss_func = nn.NLLLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=lr)

torch.manual_seed(1)
np.random.seed(1)

def lr_lambda(epoch):
    return 0.99**(epoch/50.0)

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)

model.load_state_dict(torch.load('saves/model-chat.pt'))

Loss = []
for epoch in range(n_epochs):
    model.zero_grad()
    x_obs, y_obs = get_batch(batch_size=batch_size)
    x_obs = idx2emb(x_obs)
    y_pred, ht = model(x_obs)
    y1 = torch.argmax(y_pred.detach(),-1,keepdim=True).detach()#[:,:1,:]
    y2 = y_obs.detach()#[:,:1,:]
    y_pred = y_pred.view(-1,output_size)
    y_obs = y_obs.contiguous().view(-1)
    loss = loss_func(y_pred,y_obs)
    loss.backward()
    Loss.append(loss.tolist())
    optimizer.step()
    scheduler.step()
    if epoch % disp_interval == 0:
        print(u'Epoch{}, Loss{}\nPred:\n{}\nObs:\n{}\nRnd:\n{}\n'.format(epoch,loss.tolist(), batch2sent(y1), batch2sent(y2),answer(model)))
        torch.save(model.state_dict(),'saves/model-chat.pt')
window_size = 50
avg_losses = np.array(Loss)[:len(Loss)//50 *50].reshape([-1,window_size]).mean(1)
pl.plot(np.arange(0,len(Loss)//50 *50,window_size), avg_losses,'r-')
pl.xlabel('Time')
pl.ylabel('Loss')
pl.yscale('log')


Epoch0, Loss5.22224903107
Pred:
是何E间EE的我E了远E！E你在是你!是你E子EEEEEEE
Obs:
如此深沉的爱让我受宠若惊！E就是!我爱是死你了E
Rnd:
小蛋蛋E
是鸡

Epoch50, Loss4.11609601974
Pred:
是是睡我是不E是得你
Obs:
还要E你要我就得给E
Rnd:
猫啦，你当我傻呀E
是傻鸟

Epoch100, Loss2.40839147568
Pred:
你E是谁你是什鸟E
Obs:
的我是E你是傻逼E
Rnd:
我以后不会了E
是傻鸟

Epoch150, Loss3.35344052315
Pred:
要E你人，是爱的。E
Obs:
乖E主人我最乖了～E
Rnd:
你觉得是错还是对E
是傻逼

Epoch200, Loss2.46657681465
Pred:
着学E你何投资学鸡EE
Obs:
资学E如何投资烧鸡店E
Rnd:
多给家人打电话E
是傻鸟

Epoch250, Loss3.96790599823
Pred:
要你你说不你不E我要你
Obs:
用,你也帮不了E不帮E
Rnd:
你是谁E
是傻逼

Epoch300, Loss3.41048049927
Pred:
你你EEE好E我你是是心了EE
Obs:
见你了跑不掉E那我就开心了。E
Rnd:
本来就是E
是鸡鸡

Epoch350, Loss3.3106470108
Pred:
呼EE=的谁球子E
Obs:
芸菲E真是个汉子E
Rnd:
金欣一直都很喜欢盛苏燕E
是鸡鸡

Epoch400, Loss3.70784568787
Pred:
子谁E你么了直陪
Obs:
爷〜E怎么一直E
Rnd:
=。=E
是我的

Epoch450, Loss3.63787841797
Pred:
ahpE你是知EE常的呵EE
Obs:
es！E你不是也经常呵呵么E
Rnd:
笨蛋那是我童年！我是笨鸡！E
是什么

Epoch500, Loss2.91817903519
Pred:
你说什说点EE西E你是你雅的E通鸡E
Obs:
你说点高雅的东西E我是高雅的小仙鸡E
Rnd:
我还吕小布E
是傻逼

Epoch550, Loss5.1635260582
Pred:
明E你股E
Obs:
辟E屁精E
Rnd:
你爱看的我都爱E
是鸡鸡

Epoch4700, Loss2.17713499069
Pred:
你E我爷E命EE！
Obs:
死E大爷饶命！！E
Rnd:
是……吧E
是男的

Epoch4750, Loss2.81289482117
Pred:
。=E你是EEE你你了E你，了EEEEEEE了
Obs:
。=E他了了了是了了了是了了了了了了了了了了E
Rnd:
老婆大人我错了E
是男的

Epoch4800, Loss5.10096359253
Pred:
法E好EE你瓜EE痴E
Obs:
海不懂割E萝卜炒白菜E
Rnd:
哎~干嘛E
是鸡

Epoch4850, Loss3.99304842949
Pred:
了公了E你你EEE
Obs:
老提她E关你鸟事E
Rnd:
坏孩子E
是大爷的

Epoch4900, Loss3.32843136787
Pred:
不了了E我的E~我家错爱的人EEE
Obs:
聊会吗E好滴呀，人家最爱主人了！E
Rnd:
我会寂寞的～E
是女的

Epoch4950, Loss3.60980057716
Pred:
萌的到萌E你鸡E你吃EE
Obs:
身不卖笑E烤鸡,好吃吗E
Rnd:
你是傻逼才是E
是大爷的

Epoch5000, Loss3.99619817734
Pred:
个话E笑E你你E妹小E
Obs:
笑话好冷E滚~你是鸡E
Rnd:
==E
是女的

Epoch5050, Loss2.15603065491
Pred:
要EE你是乖E
Obs:
好笑E你真美E
Rnd:
当然不是啦E
是女的

Epoch5100, Loss4.5442609787
Pred:
你哥话个E你…通E戏你EEE
Obs:
哥笑一个E…禁止调戏我！！E
Rnd:
我爱死你E
妹

Epoch5150, Loss3.20404314995
Pred:
你笑EE好听E你是鸡纸E
Obs:
个鸡语我听听E你是妹妹E
Rnd:
晚安E
是女的

Epoch5200, Loss3.65016961098
Pred:
啊起E我不远了E你不吃E
Obs:
不起，我尽力了E正常点E
Rnd:
幹xxE
是鸡

Epoch5250, Loss4.79319524765
Pred:
是不我讲EEE你是你人友么不鸡E是得EE我通E
Obs:
能给我钱花么E你那男朋友，成龙都得管他叫大哥E
R

KeyboardInterrupt: 

In [None]:
x,y = get_batch(batch_size=2)
print x.size()
print x[:,0,:], y[:,0,:]
print x[:,1,:], y[:,1,:]

In [None]:
x_pred.requires_grad

In [None]:
x.size(2)

In [None]:
x = torch.randn(2,3,4)

In [6]:
y_obs.size()

torch.Size([25, 5, 1])

In [None]:
torch.transpose(input,1,0).size()

In [None]:
torch.topk(torch.randn(1,1,100),1,-1)[1].shape

In [64]:
np.random.binomial(6,0.5)

3

In [None]:
x.size()

In [None]:
a[:]

In [None]:
x.size()

In [None]:
y.size()

In [118]:
x = torch.randint(0,10,[3,4])
print(x)

tensor([[ 6.,  6.,  4.,  0.],
        [ 4.,  0.,  8.,  2.],
        [ 0.,  8.,  9.,  2.]])


In [109]:
idx2emb(torch.LongTensor([[[ 236]]])
).size()

torch.Size([1, 1, 128])

In [24]:
poems[i][:]+poems[i+1][:]

[1, 2703, 61, 455, 151, 455, 2, 1, 6, 82, 2]