# 基本说明 

    https://cloud.tencent.com/developer/article/1162834
    https://www.zhihu.com/question/52602529/answer/155743699
    https://github.com/ZiJianZhao/SeqGAN-PyTorch
    https://github.com/ChenChengKuan/SeqGAN_tensorflow
    https://github.com/suragnair/seqGAN
    
    https://arxiv.org/pdf/1609.05473.pdf
    
    https://baike.baidu.com/item/%E6%9E%81%E5%A4%A7%E4%BC%BC%E7%84%B6%E4%BC%B0%E8%AE%A1/3350286?fr=aladdin  MLE 极大似然估计
    https://blog.csdn.net/lanchunhui/article/details/51248184 softmax 及 logsoftmax
    
    https://blog.csdn.net/bobobe/article/details/81297064  scheduled sampling （计划采样）
    http://papers.nips.cc/paper/5956-scheduled-sampling-for-sequence-prediction-with-recurrent-neural-networks.pdf  计划采样
    https://blog.csdn.net/dukuku5038/article/details/84060969  Scheduled Sampling
    https://stackoverflow.com/questions/43795423/scheduled-sampling-in-tensorflow 
    https://github.com/TobiasLee/SeqGAN_Poem 诗歌
    
    SeqGAN significantly outperforms the maximum likelihood methods, scheduled sampling and PG-BLEU
    https://www.jianshu.com/p/15c22fadcba5 机器翻译质量评测算法-BLEU
    https://blog.csdn.net/dlphay/article/details/78200396 Policy Gradient简述
    https://blog.csdn.net/suai9292/article/details/79910525 Policy Gradient理解
    
    https://www.leiphone.com/news/201810/cTCGyCN8w6pfRm0C.html 通过多对抗训练，从图像生成诗歌 
    https://www.jianshu.com/p/e1b87286bfae  SeqGAN解读
    
    
    
    http://www.eeworld.com.cn/mp/QbitAI/a53664.jspx 韩国小哥哥用Pytorch实现谷歌最强NLP预训练模型BERT | 代码
    
    https://blog.csdn.net/zhl493722771/article/details/82781914 令人拍案叫绝的WGAN
    
    https://www.colabug.com/2639033.html 对抗思想与强化学习的碰撞-SeqGAN模型原理和代码解析
    
    https://blog.csdn.net/Irving_zhang/article/details/79088143  实现基于seq2seq的聊天机器人
    
    https://www.leiphone.com/news/201709/QRJPQr3jCOtY7ncQ.html 如何让对抗网络GAN生成更高质量的文本？
    
    https://blog.csdn.net/Young_Gy/article/details/76474939  构建聊天机器人：检索、seq2seq、RL、SeqGAN
    
    https://blog.csdn.net/yinruiyang94/article/details/77675586 SeqGAN——对抗思想与增强学习的碰撞
    
    https://blog.csdn.net/qunnie_yi/article/details/80129851 只知道GAN你就OUT了——VAE背后的哲学思想及数学原理
    
    https://blog.csdn.net/GitChat/article/details/79081190 手把手教你写一个中文聊天机器人
    
    https://blog.csdn.net/tMb8Z9Vdm66wH68VX1/article/details/79184714 一文读懂智能对话系统
    
    https://blog.csdn.net/taoyafan/article/details/81229466#1%20%E4%BB%80%E4%B9%88%E6%98%AF%20Condition%20GAN  Condition GAN

# 强化学习

 https://www.cnblogs.com/steven-yang/p/6624253.html 强化学习读书笔记 - 13 - 策略梯度方法(Policy Gradient Methods)
 https://blog.csdn.net/aliceyangxi1987/article/details/73327378 一文了解强化学习

# 实现

In [1]:
import os
import random
import math

import argparse
import tqdm

import numpy as np
from easydict import EasyDict as edict
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

from lib.seqgan.generator import Generator
from lib.seqgan.discriminator import Discriminator
from lib.seqgan.target_lstm import TargetLSTM
from lib.seqgan.rollout import Rollout
from lib.seqgan.data_iter import GenDataIter, DisDataIter
opt = edict()
opt.cuda = None
# Basic Training Paramters
SEED = 88
BATCH_SIZE = 16
TOTAL_BATCH = 200
GENERATED_NUM = 1000
POSITIVE_FILE = 'real.data'
NEGATIVE_FILE = 'gene.data'
EVAL_FILE = 'eval.data'
VOCAB_SIZE = 5000
PRE_EPOCH_NUM = 10 #120

if opt.cuda is not None and opt.cuda >= 0:
    torch.cuda.set_device(opt.cuda)
    opt.cuda = True

# Genrator Parameters
g_emb_dim = 32
g_hidden_dim = 32
g_sequence_len = 20

# Discriminator Parameters
d_emb_dim = 64
d_filter_sizes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20]
d_num_filters = [100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160]

d_dropout = 0.75
d_num_class = 2

random.seed(SEED)
np.random.seed(SEED)

def generate_samples(model, batch_size, generated_num, output_file):
    samples = []
    for _ in range(int(generated_num / batch_size)):
        sample = model.sample(batch_size, g_sequence_len).cpu().data.numpy().tolist()
        samples.extend(sample)
    with open(output_file, 'w') as fout:
        for sample in samples:
            string = ' '.join([str(s) for s in sample])
            fout.write('%s\n' % string)
            
def train_epoch(model, data_iter, criterion, optimizer):
    total_loss = 0.
    total_words = 0.
    for (data, target) in data_iter:
        #tqdm(data_iter, mininterval=2, desc=' - Training', leave=False):
        data = Variable(data)
        target = Variable(target)
        if opt.cuda:
            data, target = data.cuda(), target.cuda()
        target = target.contiguous().view(-1)
        pred = model.forward(data)
        loss = criterion(pred, target)
        total_loss += loss.data[0]
        total_words += data.size(0) * data.size(1)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    data_iter.reset()
    return math.exp(total_loss / total_words)

def eval_epoch(model, data_iter, criterion):
    total_loss = 0.
    total_words = 0.
    with torch.no_grad():
        for (data, target) in data_iter:#tqdm(
            #data_iter, mininterval=2, desc=' - Training', leave=False):
            data = Variable(data)
            target = Variable(target)
            if opt.cuda:
                data, target = data.cuda(), target.cuda()
            target = target.contiguous().view(-1)
            pred = model.forward(data)
            loss = criterion(pred, target)
            total_loss += loss.data[0]
            total_words += data.size(0) * data.size(1)
        data_iter.reset()
    return math.exp(total_loss / total_words)
            
class GANLoss(nn.Module):
    """Reward-Refined NLLLoss Function for adversial training of Gnerator"""
    def __init__(self):
        super(GANLoss, self).__init__()

    def forward(self, prob, target, reward):
        """
        Args:
            prob: (N, C), torch Variable 
            target : (N, ), torch Variable
            reward : (N, ), torch Variable
        """
        N = target.size(0)
        C = prob.size(1)
        one_hot = torch.zeros((N, C))
        if prob.is_cuda:
            one_hot = one_hot.cuda()
        one_hot.scatter_(1, target.data.view((-1,1)), 1)
        one_hot = one_hot.type(torch.ByteTensor)
        one_hot = Variable(one_hot)
        if prob.is_cuda:
            one_hot = one_hot.cuda()
        loss = torch.masked_select(prob, one_hot)
        loss = loss * reward
        loss =  -torch.sum(loss)
        return loss

# Define Networks
# VOCAB_SIZE = 5000, g_emb_dim = 32, g_hidden_dim = 32
generator = Generator(VOCAB_SIZE, g_emb_dim, g_hidden_dim, opt.cuda)
# d_emb_dim = 64
# d_filter_sizes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20]
# d_num_filters = [100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160]
discriminator = Discriminator(d_num_class, VOCAB_SIZE, d_emb_dim, d_filter_sizes, d_num_filters, d_dropout)
target_lstm = TargetLSTM(VOCAB_SIZE, g_emb_dim, g_hidden_dim, opt.cuda)

print('Generating data ...')
generate_samples(target_lstm, BATCH_SIZE, GENERATED_NUM, POSITIVE_FILE)

# Load data from file
gen_data_iter = GenDataIter(POSITIVE_FILE, BATCH_SIZE)

# Pretrain Generator using MLE
gen_criterion = nn.NLLLoss(size_average=False)
gen_optimizer = optim.Adam(generator.parameters())

print('Pretrain with MLE ...')  # ？？？？
for epoch in range(PRE_EPOCH_NUM):
    loss = train_epoch(generator, gen_data_iter, gen_criterion, gen_optimizer)
    print('Epoch [%d] Model Loss: %f'% (epoch, loss))
    generate_samples(generator, BATCH_SIZE, GENERATED_NUM, EVAL_FILE)
    eval_iter = GenDataIter(EVAL_FILE, BATCH_SIZE)
    loss = eval_epoch(target_lstm, eval_iter, gen_criterion)
    print('Epoch [%d] True Loss: %f' % (epoch, loss))

# Pretrain Discriminator
dis_criterion = nn.NLLLoss(size_average=False)
dis_optimizer = optim.Adam(discriminator.parameters())
if opt.cuda:
    dis_criterion = dis_criterion.cuda()
    
print('Pretrain Dsicriminator ...')
for epoch in range(5):
    generate_samples(generator, BATCH_SIZE, GENERATED_NUM, NEGATIVE_FILE)
    dis_data_iter = DisDataIter(POSITIVE_FILE, NEGATIVE_FILE, BATCH_SIZE)
    for _ in range(3):
        loss = train_epoch(discriminator, dis_data_iter, dis_criterion, dis_optimizer)
        print('Epoch [%d], loss: %f' % (epoch, loss))

# Adversarial Training 对抗训练
rollout = Rollout(generator, 0.8)    
print('#####################################################')
print('Start Adeversatial Training...\n')
gen_gan_loss = GANLoss()
gen_gan_optm = optim.Adam(generator.parameters())
if opt.cuda:
    gen_gan_loss = gen_gan_loss.cuda()
gen_criterion = nn.NLLLoss(size_average=False)
if opt.cuda:
    gen_criterion = gen_criterion.cuda()
dis_criterion = nn.NLLLoss(size_average=False)
dis_optimizer = optim.Adam(discriminator.parameters())
if opt.cuda:
    dis_criterion = dis_criterion.cuda()


if opt.cuda:
    dis_criterion = dis_criterion.cuda()
for total_batch in range(TOTAL_BATCH):
    ## Train the generator for one step
    for it in range(1):
        samples = generator.sample(BATCH_SIZE, g_sequence_len)
        # construct the input to the genrator, add zeros before samples and delete the last column
        zeros = torch.zeros((BATCH_SIZE, 1)).type(torch.LongTensor)
        if samples.is_cuda:
            zeros = zeros.cuda()
        inputs = Variable(torch.cat([zeros, samples.data], dim = 1)[:, :-1].contiguous())
        targets = Variable(samples.data).contiguous().view((-1,))
        # calculate the reward, 16是作蒙特卡罗搜索次数 ， 确认一下
        rewards = rollout.get_reward(samples, 16, discriminator)
        rewards = Variable(torch.Tensor(rewards))
        if opt.cuda:
            rewards = torch.exp(rewards.cuda()).contiguous().view((-1,))
        prob = generator.forward(inputs)
        loss = gen_gan_loss(prob, targets, rewards)
        gen_gan_optm.zero_grad()
        loss.backward()
        gen_gan_optm.step()
        
    if total_batch % 1 == 0 or total_batch == TOTAL_BATCH - 1:
        generate_samples(generator, BATCH_SIZE, GENERATED_NUM, EVAL_FILE)
        eval_iter = GenDataIter(EVAL_FILE, BATCH_SIZE)
        loss = eval_epoch(target_lstm, eval_iter, gen_criterion)
        print('Batch [%d] True Loss: %f' % (total_batch, loss))
    rollout.update_params()

    for _ in range(4):
        generate_samples(generator, BATCH_SIZE, GENERATED_NUM, NEGATIVE_FILE)
        dis_data_iter = DisDataIter(POSITIVE_FILE, NEGATIVE_FILE, BATCH_SIZE)
        for _ in range(2):
            loss = train_epoch(discriminator, dis_data_iter, dis_criterion, dis_optimizer)        
    

Generating data ...
Pretrain with MLE ...




Epoch [0] Model Loss: 4179.830802




Epoch [0] True Loss: 39762.068270
Epoch [1] Model Loss: 2214.741282
Epoch [1] True Loss: 30504.288479
Epoch [2] Model Loss: 2011.072374
Epoch [2] True Loss: 29778.084337
Epoch [3] Model Loss: 1971.645335
Epoch [3] True Loss: 28768.750388
Epoch [4] Model Loss: 1937.013228
Epoch [4] True Loss: 28778.107592
Epoch [5] Model Loss: 1892.705449
Epoch [5] True Loss: 29491.482321
Epoch [6] Model Loss: 1845.944281
Epoch [6] True Loss: 29547.617465
Epoch [7] Model Loss: 1802.783254
Epoch [7] True Loss: 29530.996647
Epoch [8] Model Loss: 1765.819795
Epoch [8] True Loss: 29424.984891
Epoch [9] Model Loss: 1736.859886
Epoch [9] True Loss: 30260.408347
Pretrain Dsicriminator ...
Epoch [0], loss: 1.038631
Epoch [0], loss: 1.050110
Epoch [0], loss: 1.031522
Epoch [1], loss: 1.016951
Epoch [1], loss: 1.036008
Epoch [1], loss: 1.002655
Epoch [2], loss: 1.001831
Epoch [2], loss: 1.021458
Epoch [2], loss: 1.000394
Epoch [3], loss: 1.001053
Epoch [3], loss: 1.005552
Epoch [3], loss: 1.000132
Epoch [4], loss

RuntimeError: The size of tensor a (320) must match the size of tensor b (20) at non-singleton dimension 1

In [8]:
BATCH_SIZE = 16
gen_data_iter = GenDataIter(POSITIVE_FILE, BATCH_SIZE)
data,target = next(gen_data_iter)

In [12]:
print(data.size())
print(target.size())
print(data[1])
print('----')
print(target[1])


torch.Size([16, 21])
torch.Size([16, 21])
tensor([    0,  3978,  4961,   847,  2077,  4080,  4628,  2597,    80,
         3845,   393,   233,   754,  3659,  2915,  1876,  2673,  3482,
         2574,  2748,  1049])
----
tensor([ 3978,  4961,   847,  2077,  4080,  4628,  2597,    80,  3845,
          393,   233,   754,  3659,  2915,  1876,  2673,  3482,  2574,
         2748,  1049,     0])
