In [None]:
# ! git clone https://github.com/build1024/SeqGAN.git

In [None]:
# ! pip install cshogi

# 概要
- 正例：学習済み言語モデルからサンプリングした文章（元データ）
- 負例：Pre-trainingしたGeneratorが出力した文章

- 参考：https://github.com/X-czh/SeqGAN-PyTorch

## Pretrain

- Generator
    - 入力：正例
    - 出力(softmax)：単語ID
    - cross_ent：正例の単語IDと出力の単語ID
    - このとき負例を生成し保存
    
- Discriminator
    - 入力：正例と負例
    - 出力(softmax)：正例(1)or負例(0)
    - 入力が正例のとき1、負例のとき0になるよう学習させる
    
## main train
Pretrain済みのGeneratorに文章をいくつか生成させ、Discriminatorによる評価結果を報酬としてGeneratorに与えて学習させる
    
- Generator
    - 入力：負例
    - 出力：単語ID
    - pg_loss：rolloutから報酬を取得し、報酬をもとに損失値を求める

- Discriminator
    - 入力：正例と負例
    - 出力(softmax)：正例(1)or負例(0)
    - 入力が正例のとき1、負例のとき0になるよう学習させる

## 棋譜->数値データ

In [None]:
# import cshogi
# import numpy as np
# data_file ="../input/mate-shogi/mate3.sfen"
# output_file = "../input/mate-shogi/mate3.txt"

# with open(data_file, 'r') as f:
# #     lines = f.readlines()
#     lis = []
#     for i, line in enumerate(f):
#         #sfen->board
#         l = line.replace("\n", "")
#         board = cshogi.Board(l)
#         #board->hcps
#         hcps = np.empty(1, dtype=cshogi.HuffmanCodedPos)
#         board.to_hcp(hcps)

#         lis.append(hcps[0][0])

# lis = np.array(lis)

# np.savetxt(output_file, lis)
    

In [None]:
# l = np.loadtxt(output_file, dtype=cshogi.HuffmanCodedPos)
# l.shape

In [None]:
# with open(data_file, 'r') as f:
#     lines = f.readlines()
#     lis = []
# for i, line in enumerate(lines):
#     l = line.replace("\n", "")
#     board = cshogi.Board(l)

#     display(board)
#     break

In [None]:
# for h in lis:
#     board = cshogi.Board()
#     print(h)
#     print(lis)
#     board.set_hcp(h)
#     display(board)
#     break

In [None]:
# for h in lis[:5]:
#     print(h)

In [None]:
# #sfen->hcpsに変換（数値化）
# hcps = np.empty(1, dtype=cshogi.HuffmanCodedPos)
# board.to_hcp(hcps)
# # hcps.tofile('hcp')
# print(hcps)
# #hcps->将棋盤に変換
# board = cshogi.Board()
# board.set_hcp(hcps)
# board

In [None]:
# import numpy as np
# l = np.loadtxt("../input/mate-shogi/mate3.txt")
# l

In [None]:
# l.shape

In [None]:
# l[0]

In [None]:
# import numpy as np
# l = np.loadtxt("../input/mate-shogi/mate3.txt", dtype=np.int32)
# print(l)
# print(l.shape)
# print(l[0])

In [None]:
# vocab = tuple()
# for s in l:
#     print(tuple(s))
#     print(vocab+tuple(s))
#     break
    

In [None]:
# l = l.reshape(998405*32)
# l.shape

In [None]:
import pickle as pkl
import math
import random
import copy
import os
import numpy as np
# import cshogi

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn


# Files
POSITIVE_FILE = "mate3.txt"  # real (positive) sequences; joined with data_path in the main cell
NEGATIVE_FILE = 'gene.data'  # generator samples (negative examples), rewritten by generate_samples
hpc = False  # only referenced by commented-out code in the main cell
data_path = "../input/mate-shogi"

# Training schedule
rounds = 50  # number of adversarial training rounds
g_pretrain_steps = 10  # generator pre-training steps (that loop is currently disabled in the main cell)
d_pretrain_steps = 10  # discriminator pre-training steps
g_steps = 1  # generator updates per adversarial round
d_steps = 3  # discriminator updates per adversarial round
# A dead `gk_epochs = 3` store used to precede this line; 1 was always the
# effective value, so only that assignment is kept.
gk_epochs = 1  # epochs per generator step
dk_epochs = 3  # epochs per discriminator step
update_rate = 0.8  # fraction of its own weights the rollout model keeps in Rollout.update_params
n_rollout = 16  # rollouts averaged per reward estimate
# Presumably the number of distinct token values in mate3.txt
# (original note: len(set(data))) -- TODO confirm against the data.
vocab_size = 256
batch_size = 64
n_samples = batch_size*100  # rows read from each data file and samples generated per refresh
gen_lr = 1e-3
dis_lr = 1e-3
no_cuda = False  # NOTE(review): not read anywhere in the visible code
seed = 1


# Generator Parameters
g_embed_dim = 32
g_hidden_dim = 32
g_seq_len = 32  # length of every generated sequence


# Discriminator Parameters
d_num_class = 2
d_embed_dim = 64
# Convolution window heights and the matching number of filters, paired by position.
d_filter_sizes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20]
d_num_filters = [100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160]
d_dropout_prob = 0.2



# Dataset

In [None]:
class GenDataIter:
    """Batch iterator over token sequences for generator (MLE) training.

    Every row of ``data_file`` is one whitespace-separated sequence of integer
    token ids.  Each batch is returned as ``(data, target)`` where ``data`` is
    the sequence with a 0 start token prepended and ``target`` is the sequence
    with a 0 appended, so ``target[:, t]`` is the token that follows
    ``data[:, t]``.

    The iterator is single-pass: once exhausted it raises ``StopIteration``
    until ``reset()`` is called, which also reshuffles the rows.

    Args:
        data_file: path readable by ``np.loadtxt``.
        batch_size: maximum number of rows per batch (the last batch may be
            smaller).
        max_samples: keep at most this many leading rows of the file.  When
            ``None`` the module-level ``n_samples`` is used.
    """

    def __init__(self, data_file, batch_size, max_samples=None):
        self.batch_size = batch_size
        self.max_samples = max_samples
        self.data_lis = self.read_file(data_file)
        self.data_num = len(self.data_lis)
        self.indices = range(self.data_num)
        self.num_batches = math.ceil(self.data_num / self.batch_size)
        self.idx = 0
        self.reset()

    def __len__(self):
        return self.num_batches

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def reset(self):
        """Rewind to the first batch and reshuffle the rows in place."""
        self.idx = 0
        # data_lis is a plain list of rows, so random.shuffle is safe here.
        # Shuffling a 2-D ndarray with random.shuffle swaps *views* and ends
        # up duplicating rows, which is why read_file returns a list.
        random.shuffle(self.data_lis)

    def next(self):
        """Return the next ``(data, target)`` batch as int64 tensors."""
        if self.idx >= self.data_num:
            raise StopIteration
        rows = self.data_lis[self.idx : self.idx + self.batch_size]
        d = torch.tensor(rows, dtype=torch.int64)
        boundary = torch.zeros(len(rows), 1, dtype=torch.int64)

        # 0 is prepended to d as start symbol, and appended to form the target
        data = torch.cat([boundary, d], dim=1)
        target = torch.cat([d, boundary], dim=1)

        self.idx += self.batch_size
        return data, target

    def read_file(self, data_file):
        """Load the file and return its leading rows as a list of int lists."""
        limit = self.max_samples if self.max_samples is not None else n_samples
        # ndmin=2 keeps a one-line file as a single row instead of a 1-D array.
        rows = np.loadtxt(data_file, dtype=np.int32, ndmin=2)
        return rows[:limit].tolist()


class DisDataIter:
    """Batch iterator over labelled real/fake sequences for the discriminator.

    Rows from ``real_data_file`` are labelled 1 and rows from
    ``fake_data_file`` are labelled 0.  The two sets are concatenated,
    shuffled, and served as ``(data, label)`` int64 tensor batches.

    The iterator is single-pass: once exhausted it raises ``StopIteration``
    until ``reset()`` is called, which also reshuffles the pairs.

    Args:
        real_data_file: path of the positive examples (``np.loadtxt`` format).
        fake_data_file: path of the negative examples (``np.loadtxt`` format).
        batch_size: maximum number of rows per batch (the last batch may be
            smaller).
        max_samples: keep at most this many leading rows of *each* file.  When
            ``None`` the module-level ``n_samples`` is used.
    """

    def __init__(self, real_data_file, fake_data_file, batch_size, max_samples=None):
        self.batch_size = batch_size
        self.max_samples = max_samples
        real_data_lis = self.read_file(real_data_file)
        fake_data_lis = self.read_file(fake_data_file)
        # Both operands are Python lists, so `+` concatenates.  With ndarrays
        # it would add the token ids element-wise and corrupt the data.
        self.data = real_data_lis + fake_data_lis
        self.labels = [1] * len(real_data_lis) + [0] * len(fake_data_lis)
        self.pairs = list(zip(self.data, self.labels))
        self.data_num = len(self.pairs)
        self.indices = range(self.data_num)
        self.num_batches = math.ceil(self.data_num / self.batch_size)
        self.idx = 0
        self.reset()

    def __len__(self):
        return self.num_batches

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def reset(self):
        """Rewind to the first batch and reshuffle the (row, label) pairs."""
        self.idx = 0
        random.shuffle(self.pairs)

    def next(self):
        """Return the next ``(data, label)`` batch as int64 tensors."""
        if self.idx >= self.data_num:
            raise StopIteration
        pairs = self.pairs[self.idx : self.idx + self.batch_size]
        # NOTE: torch.tensor needs rectangular input, so real and fake rows
        # sharing a batch must have the same length.
        data = torch.tensor([p[0] for p in pairs], dtype=torch.int64)
        label = torch.tensor([p[1] for p in pairs], dtype=torch.int64)
        self.idx += self.batch_size
        return data, label

    def read_file(self, data_file):
        """Load the file and return its leading rows as a list of int lists."""
        limit = self.max_samples if self.max_samples is not None else n_samples
        # ndmin=2 keeps a one-line file as a single row instead of a 1-D array.
        rows = np.loadtxt(data_file, dtype=np.int32, ndmin=2)
        return rows[:limit].tolist()

In [None]:
def generate_samples(model, batch_size, generated_num, output_file, seq_len=None):
    """Sample sequences from ``model`` and write them to ``output_file``.

    The file is overwritten with one sequence per line, token ids separated by
    single spaces.

    Args:
        model: object exposing ``sample(batch_size, seq_len)`` that returns a
            ``(batch_size, seq_len)`` integer tensor.
        batch_size: number of sequences drawn per ``sample`` call.
        generated_num: requested number of sequences.  Only
            ``int(generated_num / batch_size)`` full batches are drawn, so any
            remainder is dropped.
        output_file: destination path.
        seq_len: length of each sequence; defaults to the module-level
            ``g_seq_len`` when ``None``.
    """
    if seq_len is None:
        seq_len = g_seq_len

    samples = []
    # Sampling is pure inference: skip autograd bookkeeping for every step.
    with torch.no_grad():
        for _ in range(int(generated_num / batch_size)):
            samples.extend(model.sample(batch_size, seq_len).cpu().tolist())

    # The file is only opened (and truncated) once sampling has succeeded.
    with open(output_file, 'w') as fout:
        for sample in samples:
            fout.write('{}\n'.format(' '.join(str(s) for s in sample)))

# Model

In [None]:
class Generator(nn.Module):
    """LSTM language model used as the SeqGAN generator.

    Tokens are embedded, passed through a single-layer batch-first LSTM and
    projected to log-probabilities over the vocabulary.

    Args:
        vocab_size: number of distinct token ids.
        embedding_dim: size of the token embeddings.
        hidden_dim: LSTM hidden size.
        use_cuda: stored on the instance for callers that read it.  Tensors
            created internally follow the device of the model's parameters,
            so the flag no longer has to agree with where the model lives.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, use_cuda):
        super(Generator, self).__init__()
        self.hidden_dim = hidden_dim
        self.use_cuda = use_cuda
        self.embed = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        self.log_softmax = nn.LogSoftmax(dim=1)
        self.init_params()

    def forward(self, x):
        """
        Embeds input and applies LSTM on the input sequence.
        Inputs: x
            - x: (batch_size, seq_len), sequence of tokens generated by generator
        Outputs: out
            - out: (batch_size * seq_len, vocab_size), log-probabilities of the
              next token at every position
        """
        self.lstm.flatten_parameters()
        h0, c0 = self.init_hidden(x.size(0))
        emb = self.embed(x)  # batch_size * seq_len * emb_dim
        out, _ = self.lstm(emb, (h0, c0))  # out: batch_size * seq_len * hidden_dim
        # Flatten batch and time so LogSoftmax(dim=1) normalises over the vocabulary.
        out = self.log_softmax(self.fc(out.contiguous().view(-1, self.hidden_dim)))
        return out

    def step(self, x, h, c):
        """
        Embeds input and applies LSTM one token at a time (seq_len = 1).
        Inputs: x, h, c
            - x: (batch_size, 1), sequence of tokens generated by generator
            - h: (1, batch_size, hidden_dim), lstm hidden state
            - c: (1, batch_size, hidden_dim), lstm cell state
        Outputs: out, h, c
            - out: (batch_size, vocab_size), log-probabilities of the next token
            - h: (1, batch_size, hidden_dim), lstm hidden state
            - c: (1, batch_size, hidden_dim), lstm cell state
        """
        self.lstm.flatten_parameters()
        emb = self.embed(x)  # batch_size * 1 * emb_dim
        out, (h, c) = self.lstm(emb, (h, c))  # out: batch_size * 1 * hidden_dim
        out = self.log_softmax(self.fc(out.contiguous().view(-1, self.hidden_dim)))
        return out, h, c

    def init_hidden(self, batch_size):
        """Return zero (h, c) states on the same device as the parameters."""
        device = self.embed.weight.device
        h = torch.zeros(1, batch_size, self.hidden_dim, device=device)
        c = torch.zeros(1, batch_size, self.hidden_dim, device=device)
        return h, c

    def init_params(self):
        """Initialise every parameter uniformly in [-0.05, 0.05]."""
        for param in self.parameters():
            param.data.uniform_(-0.05, 0.05)

    def sample(self, batch_size, seq_len, x=None):
        """
        Samples the network and returns a batch of samples of length seq_len.
        Inputs: batch_size, seq_len, x
            - x: optional (batch, given_len) prefix.  When given, the prefix is
              copied into the output and only the remaining positions are
              sampled; the batch size is then taken from x and the batch_size
              argument is ignored.
        Outputs: out
            - out: (batch_size, seq_len) int64 token ids
        """
        samples = []
        # Sampled ids carry no gradient, so no autograd graph is needed for
        # the many LSTM steps taken here.
        with torch.no_grad():
            if x is None:
                h, c = self.init_hidden(batch_size)
                # 0 is the start token
                x = torch.zeros(batch_size, 1, dtype=torch.int64,
                                device=self.embed.weight.device)
                for _ in range(seq_len):
                    out, h, c = self.step(x, h, c)
                    prob = torch.exp(out)
                    x = torch.multinomial(prob, 1)
                    samples.append(x)
            else:
                h, c = self.init_hidden(x.size(0))
                given_len = x.size(1)
                lis = x.chunk(x.size(1), dim=1)
                # Replay the given prefix to warm up the LSTM state.
                for i in range(given_len):
                    out, h, c = self.step(lis[i], h, c)
                    samples.append(lis[i])
                prob = torch.exp(out)
                x = torch.multinomial(prob, 1)
                # Statement order (and therefore RNG consumption) is kept
                # as-is so seeded runs draw the same tokens.
                for _ in range(given_len, seq_len):
                    samples.append(x)
                    out, h, c = self.step(x, h, c)
                    prob = torch.exp(out)
                    x = torch.multinomial(prob, 1)
            out = torch.cat(samples, dim=1)  # concatenate along the time dimension
        return out

In [None]:
class Discriminator(nn.Module):
    """
    A CNN for text classification.
    Uses an embedding layer, followed by a convolutional, max-pooling and softmax layer.
    Highway architecture based on the pooled feature maps is added. Dropout is adopted.

    Args:
        num_classes: number of output classes.
        vocab_size: number of distinct token ids.
        embedding_dim: size of the token embeddings.
        filter_sizes: convolution window heights (in tokens), one per branch.
        num_filters: number of filters of each branch, paired with
            ``filter_sizes`` by position.
        dropout_prob: dropout probability applied before the final layer.
    """

    def __init__(self, num_classes, vocab_size, embedding_dim, filter_sizes, num_filters, dropout_prob):
        super(Discriminator, self).__init__()
        self.embed = nn.Embedding(vocab_size, embedding_dim)
        # Each kernel spans the whole embedding width, so a convolution slides
        # along the sequence only.
        self.convs = nn.ModuleList([
            nn.Conv2d(1, num_f, (f_size, embedding_dim)) for f_size, num_f in zip(filter_sizes, num_filters)
        ])
        self.highway = nn.Linear(sum(num_filters), sum(num_filters))
        self.dropout = nn.Dropout(p = dropout_prob)
        self.fc = nn.Linear(sum(num_filters), num_classes)

    def forward(self, x):
        """
        Inputs: x
            - x: (batch_size, seq_len) token ids; the convolutions are
              unpadded, so seq_len must be at least the largest filter size
        Outputs: out
            - out: (batch_size, num_classes) log-probabilities
        """
        emb = self.embed(x).unsqueeze(1)  # batch_size * 1 * seq_len * emb_dim
        # One feature map per filter size: [batch_size * num_filter * length]
        convs = [F.relu(conv(emb)).squeeze(3) for conv in self.convs]
        # Max over time: [batch_size * num_filter]
        pools = [F.max_pool1d(conv, conv.size(2)).squeeze(2) for conv in convs]
        out = torch.cat(pools, 1)  # batch_size * sum(num_filters)
        highway = self.highway(out)
        transform = torch.sigmoid(highway)
        out = transform * F.relu(highway) + (1. - transform) * out  # sets C = 1 - T
        out = F.log_softmax(self.fc(self.dropout(out)), dim=1)  # batch * num_classes
        return out

# LSTM

In [None]:
class TargetLSTM(nn.Module):
    """Target LSTM: a randomly initialised LSTM language model.

    Same architecture as the generator, but its parameters are drawn from a
    standard normal distribution.

    Args:
        vocab_size: number of distinct token ids.
        embedding_dim: size of the token embeddings.
        hidden_dim: LSTM hidden size.
        use_cuda: stored on the instance for callers that read it.  Tensors
            created internally follow the device of the model's parameters.
    """

    def __init__(self,  vocab_size, embedding_dim, hidden_dim, use_cuda):
        super(TargetLSTM, self).__init__()
        self.hidden_dim = hidden_dim
        self.use_cuda = use_cuda
        self.embed = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        self.log_softmax = nn.LogSoftmax(dim=1)
        self.init_params()

    def forward(self, x):
        """
        Embeds input and applies LSTM on the input sequence.

        Inputs: x
            - x: (batch_size, seq_len), sequence of tokens generated by generator
        Outputs: out
            - out: (batch_size * seq_len, vocab_size), log-probabilities of the
              next token at every position
        """
        self.lstm.flatten_parameters()
        h0, c0 = self.init_hidden(x.size(0))
        emb = self.embed(x)  # batch_size * seq_len * emb_dim
        out, _ = self.lstm(emb, (h0, c0))  # out: batch_size * seq_len * hidden_dim (batch_first)
        # Flatten batch and time so LogSoftmax(dim=1) normalises over the vocabulary.
        out = self.log_softmax(self.fc(out.contiguous().view(-1, self.hidden_dim)))
        return out

    def step(self, x, h, c):
        """
        Embeds input and applies LSTM one token at a time (seq_len = 1).

        Inputs: x, h, c
            - x: (batch_size, 1), sequence of tokens generated by generator
            - h: (1, batch_size, hidden_dim), lstm hidden state
            - c: (1, batch_size, hidden_dim), lstm cell state
        Outputs: out, h, c
            - out: (batch_size, vocab_size), log-probabilities of the next token
            - h: (1, batch_size, hidden_dim), lstm hidden state
            - c: (1, batch_size, hidden_dim), lstm cell state
        """
        self.lstm.flatten_parameters()
        emb = self.embed(x)  # batch_size * 1 * emb_dim
        out, (h, c) = self.lstm(emb, (h, c))  # out: batch_size * 1 * hidden_dim
        out = self.log_softmax(self.fc(out.contiguous().view(-1, self.hidden_dim)))
        return out, h, c

    def init_hidden(self, batch_size):
        """Return zero (h, c) states on the same device as the parameters."""
        device = self.embed.weight.device
        h = torch.zeros((1, batch_size, self.hidden_dim), device=device)
        c = torch.zeros((1, batch_size, self.hidden_dim), device=device)
        return h, c

    def init_params(self):
        """Draw every parameter from a standard normal distribution."""
        for param in self.parameters():
            param.data.normal_(0, 1)

    def sample(self, batch_size, seq_len):
        """
        Samples the network and returns a batch of samples of length seq_len.

        Outputs: out
            - out: (batch_size, seq_len) int64 token ids
        """
        samples = []
        # Sampled ids carry no gradient, so no autograd graph is needed here.
        with torch.no_grad():
            h, c = self.init_hidden(batch_size)
            # 0 is the start token
            x = torch.zeros(batch_size, 1, dtype=torch.int64,
                            device=self.embed.weight.device)
            for _ in range(seq_len):
                out, h, c = self.step(x, h, c)
                prob = torch.exp(out)
                x = torch.multinomial(prob, 1)
                samples.append(x)
            out = torch.cat(samples, dim=1)  # concatenate along the time dimension
        return out

# Rollout

In [None]:
class Rollout(object):
    """Monte-Carlo rollout policy.

    Holds a reference to the live model (``ori_model``) and a deep copy of it
    (``own_model``).  The copy completes partial sequences when rewards are
    estimated and is blended towards the live model by ``update_params``.
    """

    def __init__(self, model, update_rate):
        self.update_rate = update_rate
        self.ori_model = model
        self.own_model = copy.deepcopy(model)

    def get_reward(self, x, num, discriminator):
        """
        Estimate a reward for every position of every sequence in x.

        Inputs: x, num, discriminator
            - x: (batch_size, seq_len) input data
            - num: rollout number
            - discriminator: discriminator model; column 1 of its output is
              used as the score
        Outputs: rewards
            - rewards: (batch_size, seq_len) numpy array, scores averaged over
              the num rollouts
        """
        n_rows, n_steps = x.size(0), x.size(1)

        def score_of(batch):
            # Column 1 of the discriminator output, detached as a numpy vector.
            return discriminator(batch).cpu().data[:, 1].numpy()

        totals = []
        for rollout_idx in range(num):
            first_pass = rollout_idx == 0

            # Positions 1 .. seq_len-1: keep the prefix, let the rollout model
            # complete the rest, then score the completed sequence.
            for prefix_len in range(1, n_steps):
                completed = self.own_model.sample(n_rows, n_steps, x[:, 0:prefix_len])
                score = score_of(completed)
                if first_pass:
                    totals.append(score)
                else:
                    totals[prefix_len - 1] += score

            # Last position: the sequence is already complete, score it directly.
            score = score_of(x)
            if first_pass:
                totals.append(score)
            else:
                totals[n_steps - 1] += score

        # (seq_len, batch_size) -> (batch_size, seq_len), then average.
        return np.transpose(np.array(totals)) / (1.0 * num)

    def update_params(self):
        """Blend the rollout copy towards the live model.

        Parameters whose name starts with 'emb' are taken over from the live
        model as-is; every other parameter becomes
        ``update_rate * own + (1 - update_rate) * live``.
        """
        live = {name: param.data for name, param in self.ori_model.named_parameters()}
        keep = self.update_rate
        for name, param in self.own_model.named_parameters():
            fresh = live[name]
            param.data = fresh if name.startswith('emb') else keep * param.data + (1 - keep) * fresh

# Lossfunc

In [None]:
class PGLoss(nn.Module):
    """
    Pseudo-loss that gives corresponding policy gradients (on calling .backward())
    for adversarial training of Generator
    """

    def __init__(self):
        super(PGLoss, self).__init__()

    def forward(self, pred, target, reward):
        """
        Inputs: pred, target, reward
            - pred: (batch_size * seq_len, vocab_size), log-probabilities
              produced by the generator
            - target: (batch_size * seq_len,), int64 ids of the sampled tokens
            - reward: any shape with batch_size * seq_len elements (e.g.
              (batch_size, seq_len)), reward of each sampled token, laid out
              in the same row-major order as target
        Outputs: loss
            - loss: scalar, -sum(log p(target) * reward)
        """
        # Pick the log-probability of each sampled token.  gather replaces the
        # former uint8 one-hot mask + masked_select, which current PyTorch
        # rejects (masks must be bool) and which allocated a full-size mask.
        picked = pred.gather(1, target.view(-1, 1)).squeeze(1)
        weighted = picked * reward.contiguous().view(-1)
        return -torch.sum(weighted)

# Pretrain

In [None]:
def pretrain_gene(gen, data_iter, criterion, 
                  optimizer, epochs, gen_pretrain_train_loss):
    """Pre-train the generator with maximum likelihood.

    Args:
        gen: model mapping a (batch, seq_len) input batch to
            (batch * seq_len, vocab) log-probabilities.
        data_iter: iterable of (data, target) batches that supports ``len()``
            and ``reset()``.
        criterion: loss applied to (output, flattened target).
        optimizer: optimizer updating ``gen``.
        epochs: number of passes over ``data_iter``.
        gen_pretrain_train_loss: list that receives the mean batch loss of
            every epoch.
    """
    print("Pretrain Generator")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    for epoch in range(epochs):
        total_loss = 0.
        for data, target in data_iter:
            data, target = data.to(device), target.to(device)
            target = target.contiguous().view(-1)
            output = gen(data)
            loss = criterion(output, target)
            total_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        data_iter.reset()
        # Reported once per epoch.  These lines used to sit outside the loop,
        # which logged only the final epoch and raised NameError for epochs == 0.
        avg_loss = total_loss / len(data_iter)
        print("Epoch {}, train loss: {:.5f}".format(epoch, avg_loss))
        gen_pretrain_train_loss.append(avg_loss)


# Eval

In [None]:
def eval_gene(model, data_iter, criterion):
    """
    Return the mean per-batch loss of ``model`` over ``data_iter``.

    Runs without gradient tracking.  Batches are moved to CUDA when it is
    available (CPU otherwise) and cast to int64; targets are flattened before
    being handed to ``criterion``.
    """
    run_on = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loss_sum = 0.
    with torch.no_grad():
        for inputs, labels in data_iter:
            inputs = inputs.to(run_on).long()
            labels = labels.to(run_on).long().contiguous().view(-1)
            loss_sum += criterion(model(inputs), labels).item()
    return loss_sum / len(data_iter)

def eval_disc(model, data_iter, criterion):
    """
    Evaluate the discriminator and return ``(mean batch loss, accuracy)``.

    Runs without gradient tracking; the model's train/eval mode is not
    touched, so dropout stays enabled if the caller left it on.  Accuracy is
    the number of correct argmax predictions divided by ``data_iter.data_num``.
    """
    run_on = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    hits = 0
    loss_sum = 0.
    with torch.no_grad():
        for inputs, labels in data_iter:
            inputs = inputs.to(run_on).long()
            labels = labels.to(run_on).long().contiguous().view(-1)
            scores = model(inputs)
            _, guess = scores.data.max(1)
            hits += guess.eq(labels.data).cpu().sum()
            loss_sum += criterion(scores, labels).item()
    mean_loss = loss_sum / len(data_iter)
    accuracy = hits.item() / data_iter.data_num
    return mean_loss, accuracy

# Main train

In [None]:
def train_gene(gen, dis, rollout, pg_loss, optimizer, epochs):
    """
    Train generator with the guidance of policy gradient

    Each epoch draws one batch of ``batch_size`` sequences of length
    ``g_seq_len`` from ``gen``, scores them with ``n_rollout`` rollouts
    against ``dis`` and takes one optimizer step on the policy-gradient loss.

    Args:
        gen: generator to train; it is also the model the samples are drawn from.
        dis: discriminator used to score the rollouts.
        rollout: object exposing ``get_reward(samples, num, dis)``.
        pg_loss: loss taking (log-probabilities, targets, rewards).
        optimizer: optimizer updating ``gen``.
        epochs: number of sample/update iterations.
    """
    for epoch in range(epochs):
        # Sample from the generator that was passed in.  This used to read the
        # module-level `generator`, silently ignoring the `gen` argument.
        samples = gen.sample(batch_size, g_seq_len)

        # construct the input to the generator, add zeros before samples and delete the last column
        zeros = torch.zeros(batch_size, 1, dtype=torch.int64, device=samples.device)
        inputs = torch.cat([zeros, samples.data], dim = 1)[:, :-1].contiguous()
        targets = samples.data.contiguous().view((-1,))

        # calculate the reward
        # NOTE(review): the rollout takes column 1 of the discriminator output,
        # which is a log-probability when dis ends in log_softmax -- confirm
        # that log D(x), rather than D(x), is the intended reward.
        rewards = torch.tensor(rollout.get_reward(samples, n_rollout, dis))
        # Follow the samples' device (the old `args.cuda` check referenced an
        # undefined name and raised NameError).
        rewards = rewards.to(samples.device)

        # update generator
        output = gen(inputs)
        loss = pg_loss(output, targets, rewards)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
def train_disc(dis, gen, criterion, optimizer, epochs, 
        dis_adversarial_train_loss, dis_adversarial_train_acc):
    """
    Train the discriminator on real data versus freshly generated samples.

    The negative file is regenerated from ``gen`` once, then ``dis`` is
    trained for ``epochs`` passes over the combined real/fake iterator.  The
    mean batch loss and the accuracy of every epoch are printed and appended
    to the two given lists.
    """
    # Refresh the negative examples, then pair them with the real ones.
    generate_samples(gen, batch_size, n_samples, NEGATIVE_FILE)
    batches = DisDataIter(POSITIVE_FILE, NEGATIVE_FILE, batch_size)
    run_on = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    for epoch in range(epochs):
        hits = 0
        loss_sum = 0.
        for inputs, labels in batches:
            inputs = inputs.to(run_on).long()
            labels = labels.to(run_on).long().contiguous().view(-1)
            scores = dis(inputs)
            _, guess = scores.data.max(1)
            hits += guess.eq(labels.data).cpu().sum()
            loss = criterion(scores, labels)
            loss_sum += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # The iterator is single-pass: rewind and reshuffle for the next epoch.
        batches.reset()
        mean_loss = loss_sum / len(batches)
        accuracy = hits.item() / batches.data_num
        print("Epoch {}, train loss: {:.5f}, train acc: {:.3f}".format(epoch, mean_loss, accuracy))
        dis_adversarial_train_loss.append(mean_loss)
        dis_adversarial_train_acc.append(accuracy)


In [None]:
def adversarial_train(gen, dis, rollout, pg_loss, nll_loss, gen_optimizer, dis_optimizer, 
        dis_adversarial_train_loss, dis_adversarial_train_acc):
    """
    Run one adversarial round: generator steps, discriminator steps, then a
    rollout-model refresh.

    The generator is trained ``g_steps`` times (``gk_epochs`` epochs each),
    the discriminator ``d_steps`` times (``dk_epochs`` epochs each), and
    finally ``rollout.update_params()`` is called.
    """
    # Generator phase
    print("#Train generator")
    for g_step in range(g_steps):
        print("##G-Step {}".format(g_step))
        train_gene(gen, dis, rollout, pg_loss, optimizer=gen_optimizer, epochs=gk_epochs)

    # Discriminator phase
    print("#Train discriminator")
    for d_step in range(d_steps):
        print("##D-Step {}".format(d_step))
        train_disc(
            dis, gen, nll_loss, dis_optimizer, dk_epochs,
            dis_adversarial_train_loss, dis_adversarial_train_acc,
        )

    # Let the roll-out model catch up with the freshly trained generator.
    rollout.update_params()


In [None]:
# l = np.loadtxt("./gene.data")
# len(l)

# Main

In [None]:
# Runtime setup: device detection and seeding

cuda = torch.cuda.is_available()
torch.manual_seed(seed)
if cuda:
    torch.cuda.manual_seed(seed)
# The positive data lives under data_path; the negative file is written to the
# working directory, so NEGATIVE_FILE is used unchanged.
POSITIVE_FILE = os.path.join(data_path, POSITIVE_FILE)

# Set models, criteria, optimizers
generator = Generator(vocab_size, g_embed_dim, g_hidden_dim, cuda)
discriminator = Discriminator(d_num_class, vocab_size, d_embed_dim, d_filter_sizes, d_num_filters, d_dropout_prob)
target_lstm = TargetLSTM(vocab_size, g_embed_dim, g_hidden_dim, cuda)
nll_loss = nn.NLLLoss()
pg_loss = PGLoss()
if cuda:
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    target_lstm = target_lstm.cuda()
    nll_loss = nll_loss.cuda()
    pg_loss = pg_loss.cuda()
    cudnn.benchmark = True
gen_optimizer = optim.Adam(params=generator.parameters(), lr=gen_lr)
dis_optimizer = optim.SGD(params=discriminator.parameters(), lr=dis_lr)

# Container of experiment data
gen_pretrain_train_loss = []
gen_pretrain_eval_loss = []
dis_pretrain_train_loss = []
dis_pretrain_train_acc = []
dis_pretrain_eval_loss = []
dis_pretrain_eval_acc = []
gen_adversarial_eval_loss = []
dis_adversarial_train_loss = []
dis_adversarial_train_acc = []
dis_adversarial_eval_loss = []
dis_adversarial_eval_acc = []

# NOTE: the string literal below is a disabled block (toy-data generation and
# generator MLE pre-training).  It is kept verbatim and is not executed.
'''
# Generate toy data using target LSTM
print('#####################################################')
print('Generating data ...')
print('#####################################################\n\n')
# generate_samples(target_lstm,  batch_size,  n_samples, POSITIVE_FILE)
print("POSITIVE_FILE = {}".format(POSITIVE_FILE))
print("NEGATIVE_FILE = {}".format(NEGATIVE_FILE))
# Pre-train generator using MLE
print('#####################################################')
print('Start pre-training generator with MLE...')
print('#####################################################\n')
gen_data_iter = GenDataIter(POSITIVE_FILE,  batch_size)
for i in range( g_pretrain_steps):
    print("G-Step {}".format(i))
    pretrain_gene(generator, gen_data_iter, nll_loss, 
        gen_optimizer,  gk_epochs, gen_pretrain_train_loss)
    generate_samples(generator,  batch_size,  n_samples, NEGATIVE_FILE)
    eval_iter = GenDataIter(NEGATIVE_FILE,  batch_size)
    gen_loss = eval_gene(target_lstm, eval_iter, nll_loss)
    gen_pretrain_eval_loss.append(gen_loss)
    print("eval loss: {:.5f}\n".format(gen_loss))
print('#####################################################\n\n')
'''
# Pre-train discriminator
print('#####################################################')
print('Start pre-training discriminator...')
print('#####################################################\n')
for i in range(d_pretrain_steps):
    print("D-Step {}".format(i))
    # Pre-training statistics go to the dis_pretrain_* containers.  They were
    # previously appended to the dis_adversarial_* lists, which left the
    # pre-train lists empty and mixed both phases in the saved history.
    train_disc(discriminator, generator, nll_loss,
        dis_optimizer, dk_epochs, dis_pretrain_train_loss, dis_pretrain_train_acc)
    generate_samples(generator, batch_size, n_samples, NEGATIVE_FILE)
    eval_iter = DisDataIter(POSITIVE_FILE, NEGATIVE_FILE, batch_size)
    dis_loss, dis_acc = eval_disc(discriminator, eval_iter, nll_loss)
    dis_pretrain_eval_loss.append(dis_loss)
    dis_pretrain_eval_acc.append(dis_acc)
    print("eval loss: {:.5f}, eval acc: {:.3f}\n".format(dis_loss, dis_acc))
print('#####################################################\n\n')

# Adversarial training
print('#####################################################')
print('Start adversarial training...')
print('#####################################################\n')
rollout = Rollout(generator, update_rate)
for i in range(rounds):
    print("Round {}".format(i))
    adversarial_train(generator, discriminator, rollout,
        pg_loss, nll_loss, gen_optimizer, dis_optimizer,
        dis_adversarial_train_loss, dis_adversarial_train_acc)
    # Evaluate on freshly generated samples after every round.
    generate_samples(generator, batch_size, n_samples, NEGATIVE_FILE)
    gen_eval_iter = GenDataIter(NEGATIVE_FILE, batch_size)
    dis_eval_iter = DisDataIter(POSITIVE_FILE, NEGATIVE_FILE, batch_size)
    # NOTE(review): target_lstm is randomly initialised and never trained in
    # this cell, so this NLL measures agreement with a random model -- confirm
    # it is a meaningful metric for the shogi data.
    gen_loss = eval_gene(target_lstm, gen_eval_iter, nll_loss)
    gen_adversarial_eval_loss.append(gen_loss)
    dis_loss, dis_acc = eval_disc(discriminator, dis_eval_iter, nll_loss)
    dis_adversarial_eval_loss.append(dis_loss)
    dis_adversarial_eval_acc.append(dis_acc)
    print("gen eval loss: {:.5f}, dis eval loss: {:.5f}, dis eval acc: {:.3f}\n"
        .format(gen_loss, dis_loss, dis_acc))
print('#####################################################')
print('Model Save...')
print('#####################################################\n')

torch.save(generator.state_dict(), 'generator.pth')
torch.save(discriminator.state_dict(), 'discriminator.pth')
torch.save(target_lstm.state_dict(), 'target_lstm.pth')
print("complete")

# Save experiment data
with open('experiment.pkl', 'wb') as f:
    pkl.dump(
        (gen_pretrain_train_loss,
            gen_pretrain_eval_loss,
            dis_pretrain_train_loss,
            dis_pretrain_train_acc,
            dis_pretrain_eval_loss,
            dis_pretrain_eval_acc,
            gen_adversarial_eval_loss,
            dis_adversarial_train_loss,
            dis_adversarial_train_acc,
            dis_adversarial_eval_loss,
            dis_adversarial_eval_acc),
        f,
        protocol=pkl.HIGHEST_PROTOCOL
    )