In [1]:
import os
import time
import math
import numpy as np
import random
import json
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable

from utils import to_gpu, Corpus, batchify, train_ngram_lm, get_ppl, create_exp_dir
from models import Seq2Seq, MLP_D, MLP_D_local, MLP_G
from bleu_self import *
from bleu_test import *
import datetime
now_time = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 

In [2]:
import argparse

parser = argparse.ArgumentParser(description='TILGAN for unconditional generation')

In [3]:
seed=4
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

In [4]:
#data_path="data/MS_COCO"
data_path="../../AGNews/unlabelled"
save="./results/AGNews_results"
maxlen=40
batch_size=16
eval_batch_size = 32
noise_seq_length = 15
add_noise=True #what does this do? - question applies to most parameters
emsize=512
nhidden=512
nlayers=2
nheads=4
nff=1024
aehidden=56
noise_r=0.05
hidden_init=True
dropout=0.3
gpu=True
z_size=100
arch_g='300-300'
gan_g_activation=False
arch_d='300-300'
gan_d_local=True
gan_d_local_windowsize=3
arch_d_local='300-300'
lr_ae=0.12
lr_gan_e=1e-04
beta1=0.5
lr_gan_g=4e-04
lr_gan_d=1e-04
epochs=2
sample=True
clip=1
log_interval=100
gan_lambda=0.1
niters_gan_d=1
niters_gan_g=1
niters_gan_ae=1
niters_gan_dec=1
niters_gan_schedule=''
niters_ae=1
gan_type='kl'
enhance_dec=True
gan_gp_lambda=1
vocab_size=0
lowercase=True
noise_anneal=0.9995
min_epochs=12
no_earlystopping=True
patience=2
gan_clamp=0.01
gan_gp_lambda=1

In [5]:
corpus = Corpus(data_path,
                maxlen=maxlen,
                vocab_size=vocab_size,
                lowercase=lowercase)

original vocab 59266; pruned to 59270
Number of sentences dropped from ../../AGNews/unlabelled/train.txt: 0 out of 41297 total
Number of sentences dropped from ../../AGNews/unlabelled/test.txt: 0 out of 4000 total


In [6]:
# save arguments
ntokens = len(corpus.dictionary.word2idx)
print("Vocabulary Size: {}".format(ntokens))

Vocabulary Size: 59270


In [7]:
# exp dir
create_exp_dir(os.path.join(save), ['train.py', 'models.py', 'utils.py'],
        dict=corpus.dictionary.word2idx)

Experiment dir : ./results/AGNews_results


In [8]:
def logging(str, to_stdout=True):
    with open(os.path.join(save, 'log.txt'), 'a') as f:
        f.write(str + '\n')
    if to_stdout:
        print(str)

In [9]:
test_data = batchify(corpus.test, eval_batch_size, maxlen, shuffle=False)
train_data = batchify(corpus.train, batch_size, maxlen,  shuffle=True)

print("Loaded data!")

Loaded data!


In [10]:
print(len(test_data))
print(len(train_data))

125
2581


In [11]:
###############################################################################
# Build the models
###############################################################################
autoencoder = Seq2Seq(add_noise=add_noise,
                      emsize=emsize,
                      nhidden=nhidden,
                      ntokens=ntokens,
                      nlayers=nlayers,
                      nheads=nheads,
                      nff=nff,
                      aehidden=aehidden,
                      noise_r=noise_r,
                      hidden_init=hidden_init,
                      dropout=dropout,
                      gpu=True)
nlatent = aehidden * (maxlen+1)
gan_gen = MLP_G(ninput=z_size, noutput=nlatent, layers=arch_g, gan_g_activation=gan_g_activation)
gan_disc = MLP_D(ninput=nlatent, noutput=1, layers=arch_d)
gan_disc_local = MLP_D_local(ninput=gan_d_local_windowsize * aehidden, noutput=1, layers=arch_d_local)

optimizer_ae = optim.SGD(autoencoder.parameters(), lr=lr_ae)


optimizer_gan_e = optim.Adam(autoencoder.encoder.parameters(),
                             lr=lr_gan_e,
                             betas=(beta1, 0.999))
optimizer_gan_g = optim.Adam(gan_gen.parameters(),
                             lr=lr_gan_g,
                             betas=(beta1, 0.999))
optimizer_gan_d = optim.Adam(gan_disc.parameters(),
                             lr=lr_gan_d,
                             betas=(beta1, 0.999))
optimizer_gan_d_local = optim.Adam(gan_disc_local.parameters(),
                             lr=lr_gan_d,
                             betas=(beta1, 0.999))
optimizer_gan_dec = optim.Adam(autoencoder.decoder.parameters(),
                             lr=lr_gan_e,
                             betas=(beta1, 0.999))

autoencoder = autoencoder.to(device)
gan_gen = gan_gen.to(device)
gan_disc = gan_disc.to(device)
gan_disc_local = gan_disc_local.to(device)

In [12]:
def save_model():
    print("Saving models to {}".format(save))
    torch.save({
        "ae": autoencoder.state_dict(),
        "gan_g": gan_gen.state_dict(),
        "gan_d": gan_disc.state_dict(),
        "gan_d_local": gan_disc_local.state_dict()

        },
        os.path.join(save, "model.pt"))

In [13]:
def cal_norm(model):
    total_norm = 0
    for p in model.parameters():
        param_norm = p.grad.data.norm(2)
        total_norm += param_norm.item() ** 2
    total_norm = total_norm ** (1. / 2)
    return total_norm

In [14]:
def load_models():
    model_args = json.load(open(os.path.join(save, 'options.json'), 'r'))
    word2idx = json.load(open(os.path.join(save, 'vocab.json'), 'r'))
    idx2word = {v: k for k, v in word2idx.items()}

    print('Loading models from {}'.format(save))
    loaded = torch.load(os.path.join(save, "model.pt"))
    autoencoder.load_state_dict(loaded.get('ae'))
    gan_gen.load_state_dict(loaded.get('gan_g'))
    gan_disc.load_state_dict(loaded.get('gan_d'))
    gan_disc_local.load_state_dict(loaded.get('gan_d_local'))
    return model_args, idx2word, autoencoder, gan_gen, gan_disc

In [15]:
def evaluate_autoencoder(data_source, epoch):
    # Turn on evaluation mode which disables dropout.
    autoencoder.eval()
    total_loss = 0
    ntokens = len(corpus.dictionary.word2idx)
    all_accuracies = 0
    bcnt = 0
    for i, batch in enumerate(data_source):
        source, target, lengths = batch
        with torch.no_grad():
            source = Variable(source.to(device))
            target = Variable(target.to(device))
            mask = target.gt(0)
            masked_target = target.masked_select(mask)
            # examples x ntokens
            output_mask = mask.unsqueeze(1).expand(mask.size(0), ntokens)

            # output: batch x seq_len x ntokens
            output = autoencoder(source, lengths, source, add_noise=add_noise, soft=False)
            flattened_output = output.view(-1, ntokens)

            masked_output = \
                flattened_output.masked_select(output_mask).view(-1, ntokens)
            total_loss += F.cross_entropy(masked_output, masked_target)

            # accuracy
            max_vals, max_indices = torch.max(masked_output, 1)
            accuracy = torch.mean(max_indices.eq(masked_target).float()).data.item()
            all_accuracies += accuracy
            bcnt += 1

        aeoutf = os.path.join(save, "autoencoder.txt")
        with open(aeoutf, "w") as f:
            max_values, max_indices = torch.max(output, 2)
            max_indices = \
                max_indices.view(output.size(0), -1).data.cpu().numpy()
            target = target.view(output.size(0), -1).data.cpu().numpy()
            for t, idx in zip(target, max_indices):
                # real sentence
                chars = " ".join([corpus.dictionary.idx2word[x] for x in t])
                f.write(chars + '\n')
                # autoencoder output sentence
                chars = " ".join([corpus.dictionary.idx2word[x] for x in idx])
                f.write(chars + '\n'*2)

    return total_loss.item() / len(data_source), all_accuracies/bcnt

In [16]:
def gen_fixed_noise(noise, to_save):
    gan_gen.eval()
    autoencoder.eval()

    fake_hidden = gan_gen(noise)
    max_indices = autoencoder.generate(fake_hidden, maxlen, sample=sample)

    with open(to_save, "w") as f:
        max_indices = max_indices.data.cpu().numpy()
        for idx in max_indices:
            # generated sentence
            words = [corpus.dictionary.idx2word[x] for x in idx]
            # truncate sentences to first occurrence of <eos>
            truncated_sent = []
            for w in words:
                if w != '<eos>':
                    truncated_sent.append(w)
                else:
                    break
            chars = " ".join(truncated_sent)
            f.write(chars + '\n')

In [17]:
def eval_bleu(gen_text_savepath):
    selfbleu = bleu_self(gen_text_savepath)
    real_text = os.path.join(data_path, "test.txt")
    testbleu = bleu_test(real_text, gen_text_savepath)
    return selfbleu, testbleu

In [18]:
def train_ae(epoch, batch, total_loss_ae, start_time, i):
    '''Train AE with the negative log-likelihood loss'''
    autoencoder.train()
    optimizer_ae.zero_grad()

    source, target, lengths = batch
    source = Variable(source.to(device))
    target = Variable(target.to(device))
    output = autoencoder(source, lengths, source, add_noise=add_noise, soft=False)

    mask = target.gt(0)
    masked_target = target.masked_select(mask)
    output_mask = mask.unsqueeze(1).expand(mask.size(0), ntokens)
    flat_output = output.view(-1, ntokens)
    masked_output = flat_output.masked_select(output_mask).view(-1, ntokens)
    loss = F.cross_entropy(masked_output, masked_target)
    loss.backward()
    torch.nn.utils.clip_grad_norm(autoencoder.parameters(), clip)
    train_ae_norm = cal_norm(autoencoder)
    optimizer_ae.step()

    total_loss_ae += loss.data.item()
    if i % log_interval == 0:
        probs = F.softmax(masked_output, dim=-1)
        max_vals, max_indices = torch.max(probs, 1)
        accuracy = torch.mean(max_indices.eq(masked_target).float()).data.item()
        cur_loss = total_loss_ae / log_interval
        elapsed = time.time() - start_time
        logging('| epoch {:3d} | {:5d}/{:5d} batches | lr {:08.6f} | ms/batch {:5.2f} | '
                'loss {:5.2f} | ppl {:8.2f} | acc {:8.2f} | train_ae_norm {:8.2f}'.format(
                epoch, i, len(train_data), 0,
                elapsed * 1000 / log_interval,
                cur_loss, math.exp(cur_loss), accuracy, train_ae_norm))

        total_loss_ae = 0
        start_time = time.time()
    return total_loss_ae, start_time

In [19]:
def train_gan_g(gan_type='kl'):
    gan_gen.train()
    optimizer_gan_g.zero_grad()

    z = Variable(torch.Tensor(batch_size, z_size).normal_(0, 1).to(device))
    fake_hidden = gan_gen(z)
    fake_score = gan_disc(fake_hidden)

    if gan_d_local:
        idx = random.randint(0, maxlen - gan_d_local_windowsize)
        fake_hidden_local = fake_hidden[:, idx * aehidden : (idx + gan_d_local_windowsize) * aehidden]
        fake_score_local = gan_disc_local(fake_hidden_local)

        if gan_type == 'kl':
            errG = -(torch.exp(fake_score.detach()).clamp(0.5, 2) * fake_score).mean() -(torch.exp(fake_score_local.detach()).clamp(0.5, 2) * fake_score_local).mean()
        else:
            errG = -fake_score.mean() -fake_score_local.mean()
    else:
        if gan_type == 'kl':
            errG = -(torch.exp(fake_score.detach()).clamp(0.5, 2) * fake_score).mean()
        else:
            errG = -fake_score.mean()


    errG *= gan_lambda
    errG.backward()
    optimizer_gan_g.step()

    return errG

In [20]:
def train_gan_dec(gan_type='kl'):
    autoencoder.decoder.train()
    optimizer_gan_dec.zero_grad()

    z = Variable(torch.Tensor(batch_size, z_size).normal_(0, 1).to(device))
    fake_hidden = gan_gen(z)

    # 1. decoder  - soft distribution
    enhance_source, max_indices= autoencoder.generate_enh_dec(fake_hidden, maxlen, sample=sample)
    # 2. soft distribution - > encoder  -> fake_hidden
    enhance_hidden = autoencoder(enhance_source, None, max_indices, add_noise=add_noise, soft=True, encode_only=True)
    fake_score = gan_disc(enhance_hidden)

    if gan_d_local:
        idx = random.randint(0, maxlen - gan_d_local_windowsize)
        fake_hidden_local = fake_hidden[:, idx * aehidden : (idx + gan_d_local_windowsize) * aehidden]
        fake_score_local = gan_disc_local(fake_hidden_local)

        if gan_type == 'kl':
            errG = -(torch.exp(fake_score.detach()).clamp(0.5, 2) * fake_score).mean() -(torch.exp(fake_score_local.detach()).clamp(0.5, 2) * fake_score_local).mean()
        else:
            errG = -fake_score.mean() -fake_score_local.mean()
    else:
        if gan_type == 'kl':
            errG = -(torch.exp(fake_score.detach()).clamp(0.5, 2) * fake_score).mean()
        else:
            errG = -fake_score.mean()


    errG *= gan_lambda
    errG.backward()
    optimizer_gan_dec.step()

    return errG

In [21]:
def grad_hook(grad):
    return grad * gan_lambda

In [22]:
''' Steal from https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py '''
def calc_gradient_penalty(netD, real_data, fake_data):
    bsz = real_data.size(0)
    alpha = torch.rand(bsz, 1)
    alpha = alpha.expand(bsz, real_data.size(1))  # only works for 2D XXX
    alpha = alpha.to(device)
    interpolates = alpha * real_data + ((1 - alpha) * fake_data)
    interpolates = Variable(interpolates, requires_grad=True)
    disc_interpolates = netD(interpolates)

    gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                    grad_outputs=torch.ones(disc_interpolates.size()).to(device),
                                    create_graph=True, retain_graph=True, only_inputs=True)[0]
    gradients = gradients.view(gradients.size(0), -1)

    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * gan_gp_lambda
    return gradient_penalty

In [23]:
def train_gan_d(batch, gan_type='kl'):
    gan_disc.train()
    gan_disc_local.train()
    optimizer_gan_d.zero_grad()
    optimizer_gan_d_local.zero_grad()

    # + samples
    source, target, lengths = batch
    source = Variable(source.to(device))
    target = Variable(target.to(device))
    real_hidden = autoencoder(source, lengths, source, add_noise=add_noise, soft=False, encode_only=True)
    real_score = gan_disc(real_hidden.detach())

    idx = random.randint(0, maxlen - gan_d_local_windowsize)
    if gan_d_local:
        real_hidden_local = real_hidden[:, idx * aehidden : (idx + gan_d_local_windowsize) * aehidden]
        real_score_local = gan_disc_local(real_hidden_local)
        real_score += real_score_local


    if gan_type == 'wgan':
        errD_real = -real_score.mean()
    else: # kl or all
        errD_real = F.softplus(-real_score).mean()
    errD_real.backward()

    # - samples
    z = Variable(torch.Tensor(batch_size, z_size).normal_(0, 1).to(device))
    fake_hidden = gan_gen(z)
    fake_score = gan_disc(fake_hidden.detach())

    if gan_d_local:
        fake_hidden_local = fake_hidden[:, idx * aehidden : (idx + gan_d_local_windowsize) * aehidden]
        fake_score_local = gan_disc_local(fake_hidden_local)
        fake_score += fake_score_local

    if gan_type == 'wgan':
        errD_fake = fake_score.mean()
    else:  # kl or all
        errD_fake = F.softplus(fake_score).mean()
    errD_fake.backward()

    # gradient penalty
    if gan_type == 'wgan':
        gradient_penalty = calc_gradient_penalty(gan_disc, real_hidden.data, fake_hidden.data)
        gradient_penalty.backward()

    optimizer_gan_d.step()
    optimizer_gan_d_local.step()
    return errD_real + errD_fake, errD_real, errD_fake

In [24]:
def train_gan_d_into_ae(batch):
    autoencoder.train()
    optimizer_gan_e.zero_grad()

    source, target, lengths = batch
    source = Variable(source.to(device))
    target = Variable(target.to(device))
    real_hidden = autoencoder(source, lengths, source, add_noise=add_noise, soft=False, encode_only=True)

    if gan_d_local:
        idx = random.randint(0, maxlen - gan_d_local_windowsize)
        real_hidden_local = real_hidden[:, idx * aehidden : (idx + gan_d_local_windowsize) * aehidden]
        real_score_local = gan_disc_local(real_hidden_local)
        errD_real = gan_disc(real_hidden).mean() + real_score_local.mean()
    else:
        errD_real = gan_disc(real_hidden).mean()

    errD_real *= gan_lambda
    errD_real.backward()
    torch.nn.utils.clip_grad_norm(autoencoder.parameters(), clip)

    optimizer_gan_e.step()
    return errD_real

In [25]:
def train():
    logging("Training")
    train_data = batchify(corpus.train, batch_size, maxlen, shuffle=True)

    # gan: preparation
    if niters_gan_schedule != "":
        gan_schedule = [int(x) for x in niters_gan_schedule.split("-")]
    else:
        gan_schedule = []
    niter_gan = 1
    fixed_noise = Variable(torch.ones(eval_batch_size, z_size).normal_(0, 1).to(device))

    for epoch in range(1, epochs+1):
        # update gan training schedule
        if epoch in gan_schedule:
            niter_gan += 1
            logging("GAN training loop schedule: {}".format(niter_gan))

        total_loss_ae = 0
        epoch_start_time = time.time()
        start_time = time.time()
        niter = 0
        niter_g = 1

        while niter < len(train_data):
            # train ae
            for i in range(niters_ae):
                if niter >= len(train_data):
                    break  # end of epoch
                total_loss_ae, start_time = train_ae(epoch, train_data[niter],
                                total_loss_ae, start_time, niter)
                niter += 1
            # train gan
            for k in range(niter_gan):
                for i in range(niters_gan_d):
                    errD, errD_real, errD_fake = train_gan_d(
                            train_data[random.randint(0, len(train_data)-1)], gan_type)
                for i in range(niters_gan_ae):
                    train_gan_d_into_ae(train_data[random.randint(0, len(train_data)-1)])
                for i in range(niters_gan_g):
                    errG = train_gan_g(gan_type)
                if enhance_dec:
                    for i in range(niters_gan_dec):
                        errG_enh_dec = train_gan_dec()
                else:
                    errG_enh_dec = torch.Tensor([0])

            niter_g += 1
            if niter_g % log_interval == 0:
                logging('[{}/{}][{}/{}] Loss_D: {:.8f} (Loss_D_real: {:.8f} '
                        'Loss_D_fake: {:.8f}) Loss_G: {:.8f} Loss_Enh_Dec: {:.8f}'.format(
                         epoch, epochs, niter, len(train_data),
                         errD.data.item(), errD_real.data.item(),
                         errD_fake.data.item(), errG.data.item(), errG_enh_dec.data.item()))
        # eval
        test_loss, accuracy = evaluate_autoencoder(test_data, epoch)
        logging('| end of epoch {:3d} | time: {:5.2f}s | test loss {:5.2f} | '
                'test ppl {:5.2f} | acc {:3.3f}'.format(epoch,
                (time.time() - epoch_start_time), test_loss,
                math.exp(test_loss), accuracy))

        gen_text_savepath = os.path.join(save, "{:03d}_examplar_gen".format(epoch))
        gen_fixed_noise(fixed_noise, gen_text_savepath)
        if epoch % 5 == 0 or epoch % 4 == 0 or (epochs - epoch) <=2:
            selfbleu, testbleu = eval_bleu(gen_text_savepath)
            logging('bleu_self: [{:.8f},{:.8f},{:.8f},{:.8f},{:.8f}]'.format(selfbleu[0], selfbleu[1], selfbleu[2], selfbleu[3], selfbleu[4]))
            logging('bleu_test: [{:.8f},{:.8f},{:.8f},{:.8f},{:.8f}]'.format(testbleu[0], testbleu[1], testbleu[2], testbleu[3], testbleu[4]))

        if epoch % 15 == 0 or epoch == epochs-1:
            logging("New saving model: epoch {:03d}.".format(epoch))
            save_model()

In [26]:
train()

Training




| epoch   1 |     0/ 2581 batches | lr 0.000000 | ms/batch  2.32 | loss  0.12 | ppl     1.13 | acc     0.00 | train_ae_norm     1.00




[1/2][99/2581] Loss_D: 1.37442100 (Loss_D_real: 0.68377030 Loss_D_fake: 0.69065070) Loss_G: -0.00070368 Loss_Enh_Dec: -0.00080667
| epoch   1 |   100/ 2581 batches | lr 0.000000 | ms/batch 1098.52 | loss  9.56 | ppl 14252.43 | acc     0.07 | train_ae_norm     1.00
[1/2][199/2581] Loss_D: 1.37924290 (Loss_D_real: 0.68727779 Loss_D_fake: 0.69196516) Loss_G: -0.00046541 Loss_Enh_Dec: -0.00092866
| epoch   1 |   200/ 2581 batches | lr 0.000000 | ms/batch 1103.49 | loss  8.78 | ppl  6522.73 | acc     0.07 | train_ae_norm     1.00
[1/2][299/2581] Loss_D: 1.37634408 (Loss_D_real: 0.68719935 Loss_D_fake: 0.68914473) Loss_G: 0.00017021 Loss_Enh_Dec: -0.00075091
| epoch   1 |   300/ 2581 batches | lr 0.000000 | ms/batch 1098.44 | loss  8.51 | ppl  4972.58 | acc     0.08 | train_ae_norm     1.00
[1/2][399/2581] Loss_D: 1.37294817 (Loss_D_real: 0.68439507 Loss_D_fake: 0.68855309) Loss_G: 0.00091481 Loss_Enh_Dec: -0.00164446
| epoch   1 |   400/ 2581 batches | lr 0.000000 | ms/batch 1099.17 | loss 

KeyboardInterrupt: 