In [1]:
from __future__ import print_function
from math import ceil
import numpy as np
import sys
import pdb

import torch
import torch.optim as optim
import torch.nn as nn

from model import generator
from model import discriminator_cnn
import helpers
from data_iter import DisDataIter

CUDA = True
VOCAB_SIZE = 5000
MAX_SEQ_LEN = 20
START_LETTER = 0
BATCH_SIZE = 64
MLE_TRAIN_EPOCHS = 100
ADV_TRAIN_EPOCHS = 50
POS_NEG_SAMPLES = 10000

GEN_EMBEDDING_DIM = 32
GEN_HIDDEN_DIM = 32
DIS_EMBEDDING_DIM = 64
DIS_HIDDEN_DIM = 64
FILTER_SIZES = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20]
NUM_FILTERS = [100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160]

POSITIVE_FILE = 'real.data'
NEGATIVE_FILE = 'gene.data'

oracle_samples_path = './oracle_samples.trc'
oracle_state_dict_path = './oracle_EMBDIM32_HIDDENDIM32_VOCAB5000_MAXSEQLEN20.trc'
pretrained_gen_path = './gen_MLEtrain_EMBDIM32_HIDDENDIM32_VOCAB5000_MAXSEQLEN20.trc'
pretrained_dis_path = './dis_pretrain_EMBDIM_64_HIDDENDIM64_VOCAB5000_MAXSEQLEN20.trc'

def generate_samples(model, batch_size, generated_num, output_file):
    samples = []
    for _ in range(int(generated_num / batch_size)):
        sample = model.sample(batch_size).cpu().numpy().tolist()
        samples.extend(sample)

    with open(output_file, 'w') as fout:
        for sample in samples:
            string = ''.join([str(s) for s in sample])
            fout.write('{}\n'.format(string))



def train_generator_MLE(gen, gen_opt, oracle, real_data_samples, epochs):
    """
    Max Likelihood Pretraining for the generator
    """
    for epoch in range(epochs):
        print('epoch %d : ' % (epoch + 1), end='')
        sys.stdout.flush()
        total_loss = 0

        for i in range(0, POS_NEG_SAMPLES, BATCH_SIZE):
            inp, target = helpers.prepare_generator_batch(real_data_samples[i:i + BATCH_SIZE], start_letter=START_LETTER,
                                                          gpu=CUDA)
            gen_opt.zero_grad()
            loss = gen.batchNLLLoss(inp, target)
            loss.backward()
            gen_opt.step()

            total_loss += loss.data.item()

            if (i / BATCH_SIZE) % ceil(
                            ceil(POS_NEG_SAMPLES / float(BATCH_SIZE)) / 10.) == 0:  # roughly every 10% of an epoch
                print('.', end='')
                sys.stdout.flush()

        # each loss in a batch is loss per sample
        total_loss = total_loss / ceil(POS_NEG_SAMPLES / float(BATCH_SIZE)) / MAX_SEQ_LEN

        # sample from generator and compute oracle NLL
        oracle_loss = helpers.batchwise_oracle_nll(gen, oracle, POS_NEG_SAMPLES, BATCH_SIZE, MAX_SEQ_LEN,
                                                   start_letter=START_LETTER, gpu=CUDA)

        print(' average_train_NLL = %.4f, oracle_sample_NLL = %.4f' % (total_loss, oracle_loss))


def train_generator_PG(gen, gen_opt, oracle, dis, num_batches):
    """
    The generator is trained using policy gradients, using the reward from the discriminator.
    Training is done for num_batches batches.
    """

    for batch in range(num_batches):
        s = gen.sample(BATCH_SIZE*2)        # 64 works best
        inp, target = helpers.prepare_generator_batch(s, start_letter=START_LETTER, gpu=CUDA)
        rewards = dis(target)
        rewards = rewards.data[:,1]
        #rewards = dis.batchClassify(target)
|
        gen_opt.zero_grad()
        pg_loss = gen.batchPGLoss(inp, target, rewards)
        pg_loss.backward()
        gen_opt.step()

    # sample from generator and compute oracle NLL
    oracle_loss = helpers.batchwise_oracle_nll(gen, oracle, POS_NEG_SAMPLES, BATCH_SIZE, MAX_SEQ_LEN,
                                                   start_letter=START_LETTER, gpu=CUDA)

    print(' oracle_sample_NLL = %.4f' % oracle_loss)

def eval_discriminator(model, data_iter, criterion):
    """
    Evaluate discriminator, dropout is enabled
    """
    correct = 0
    total_loss = 0.
    with torch.no_grad():
        for data, target in data_iter:
            if CUDA:
                data, target = data.cuda(), target.cuda()
            target = target.contiguous().view(-1)
            output = model(data)
            pred = output.data.max(1)[1]
            correct += pred.eq(target.data).cpu().sum()
            loss = criterion(output, target)
            total_loss += loss.item()
    avg_loss = total_loss / len(data_iter)
    acc = correct.item() / data_iter.data_num
    return avg_loss, acc
    
def train_discriminator(discriminator, dis_opt, real_data_samples, generator, oracle, d_steps, epochs):
    """
    Training the discriminator on real_data_samples (positive) and generated samples from generator (negative).
    Samples are drawn d_steps times, and the discriminator is trained for epochs epochs.
    """

    # generating a small validation set before training (using oracle and generator)
    pos_val = oracle.sample(100)
    neg_val = generator.sample(100)
    val_inp, val_target = helpers.prepare_discriminator_data(pos_val, neg_val, gpu=CUDA)

    for d_step in range(d_steps):
        s = helpers.batchwise_sample(generator, POS_NEG_SAMPLES, BATCH_SIZE)
        dis_inp, dis_target = helpers.prepare_discriminator_data(real_data_samples, s, gpu=CUDA)
        for epoch in range(epochs):
            print('d-step %d epoch %d : ' % (d_step + 1, epoch + 1), end='')
            sys.stdout.flush()
            total_loss = 0
            total_acc = 0

            for i in range(0, 2 * POS_NEG_SAMPLES, BATCH_SIZE):
                inp, target = dis_inp[i:i + BATCH_SIZE].cpu().numpy(), dis_target[i:i + BATCH_SIZE].cpu().numpy()
                dis_opt.zero_grad()
                inp = torch.tensor(inp).cuda()
                target = torch.tensor(target).cuda()
                out = discriminator(inp)
                loss_fn = nn.NLLLoss()
                loss = loss_fn(out, target)
                loss.backward()
                dis_opt.step()

                total_loss += loss.data.item()
                total_acc += torch.sum(out.max(1)[1]==target).data.item()
                if (i / BATCH_SIZE) % ceil(ceil(2 * POS_NEG_SAMPLES / float(
                        BATCH_SIZE)) / 10.) == 0:  # roughly every 10% of an epoch
                    print('.', end='')
                    sys.stdout.flush()

            total_loss /= ceil(2 * POS_NEG_SAMPLES / float(BATCH_SIZE))
            total_acc /= float(2 * POS_NEG_SAMPLES)

            val_pred = discriminator(val_inp)
            print(' average_loss = %.4f, train_acc = %.4f, val_acc = %.4f' % (
                total_loss, total_acc, torch.sum(val_pred.max(1)[1]==val_target).data.item()/200.))

In [2]:
    oracle = generator.Generator(GEN_EMBEDDING_DIM, GEN_HIDDEN_DIM, VOCAB_SIZE, MAX_SEQ_LEN, gpu=CUDA)
    oracle.load_state_dict(torch.load(oracle_state_dict_path))
    oracle_samples = torch.load(oracle_samples_path).type(torch.LongTensor)
    # a new oracle can be generated by passing oracle_init=True in the generator constructor
    # samples for the new oracle can be generated using helpers.batchwise_sample()

    gen = generator.Generator(GEN_EMBEDDING_DIM, GEN_HIDDEN_DIM, VOCAB_SIZE, MAX_SEQ_LEN, gpu=CUDA)
    #num_classes, vocab_size, emb_dim, filter_sizes, num_filters, dropout
    #2, VOCAB_SIZE, DIS_EMBEDDING_DIM, [3, 4, 5], [100, 100, 100], 0.5
    dis = discriminator_cnn.Discriminator(2, VOCAB_SIZE, DIS_EMBEDDING_DIM, FILTER_SIZES, NUM_FILTERS, 0.5)

    if CUDA:
        oracle = oracle.cuda()
        gen = gen.cuda()
        dis = dis.cuda()
        oracle_samples = oracle_samples.cuda()

    # GENERATOR MLE TRAINING
    print('Starting Generator MLE Training...')
    gen_optimizer = optim.Adam(gen.parameters(), lr=1e-2)
    #train_generator_MLE(gen, gen_optimizer, oracle, oracle_samples, MLE_TRAIN_EPOCHS)

    # torch.save(gen.state_dict(), pretrained_gen_path)
    gen.load_state_dict(torch.load(pretrained_gen_path))

    # PRETRAIN DISCRIMINATOR
    #print('\nStarting Discriminator Training...')
    #dis_optimizer = optim.Adam(dis.parameters())
    #train_discriminator(dis, dis_optimizer, oracle_samples, gen, oracle, 50, 3)

    # torch.save(dis.state_dict(), pretrained_dis_path)
    # dis.load_state_dict(torch.load(pretrained_dis_path))

    # ADVERSARIAL TRAINING
    

Starting Generator MLE Training...


In [3]:
torch.manual_seed(1)

<torch._C.Generator at 0x7f9c8a18a650>

In [4]:
    dis = discriminator_cnn.Discriminator(2, VOCAB_SIZE, DIS_EMBEDDING_DIM, FILTER_SIZES, NUM_FILTERS, 0.2)

    if CUDA:
        dis = dis.cuda()

In [5]:
    # PRETRAIN DISCRIMINATOR
    print('\nStarting Discriminator Training...')
    dis_optimizer = optim.SGD(dis.parameters(), lr=1e-2)
    train_discriminator(dis, dis_optimizer, oracle_samples, gen, oracle, 50, 3)


Starting Discriminator Training...
d-step 1 epoch 1 : ..



........ average_loss = 0.6935, train_acc = 0.4970, val_acc = 0.4800
d-step 1 epoch 2 : .......... average_loss = 0.6934, train_acc = 0.4993, val_acc = 0.5100
d-step 1 epoch 3 : .......... average_loss = 0.6933, train_acc = 0.4993, val_acc = 0.4600
d-step 2 epoch 1 : .......... average_loss = 0.6932, train_acc = 0.5008, val_acc = 0.5000
d-step 2 epoch 2 : .......... average_loss = 0.6930, train_acc = 0.5050, val_acc = 0.4900
d-step 2 epoch 3 : .......... average_loss = 0.6930, train_acc = 0.5072, val_acc = 0.5000
d-step 3 epoch 1 : .......... average_loss = 0.6930, train_acc = 0.5052, val_acc = 0.5550
d-step 3 epoch 2 : .......... average_loss = 0.6929, train_acc = 0.5078, val_acc = 0.4650
d-step 3 epoch 3 : .......... average_loss = 0.6928, train_acc = 0.5084, val_acc = 0.5400
d-step 4 epoch 1 : .......... average_loss = 0.6930, train_acc = 0.5050, val_acc = 0.5800
d-step 4 epoch 2 : .......... average_loss = 0.6928, train_acc = 0.5110, val_acc = 0.5250
d-step 4 epoch 3 : .......... a

d-step 31 epoch 2 : .......... average_loss = 0.6187, train_acc = 0.6661, val_acc = 0.6400
d-step 31 epoch 3 : .......... average_loss = 0.6135, train_acc = 0.6710, val_acc = 0.6150
d-step 32 epoch 1 : .......... average_loss = 0.6146, train_acc = 0.6671, val_acc = 0.6450
d-step 32 epoch 2 : .......... average_loss = 0.6089, train_acc = 0.6737, val_acc = 0.6350
d-step 32 epoch 3 : .......... average_loss = 0.6035, train_acc = 0.6823, val_acc = 0.6450
d-step 33 epoch 1 : .......... average_loss = 0.6064, train_acc = 0.6742, val_acc = 0.6350
d-step 33 epoch 2 : .......... average_loss = 0.6009, train_acc = 0.6818, val_acc = 0.6650
d-step 33 epoch 3 : .......... average_loss = 0.5953, train_acc = 0.6883, val_acc = 0.6400
d-step 34 epoch 1 : .......... average_loss = 0.5970, train_acc = 0.6810, val_acc = 0.6350
d-step 34 epoch 2 : .......... average_loss = 0.5911, train_acc = 0.6878, val_acc = 0.6200
d-step 34 epoch 3 : .......... average_loss = 0.5851, train_acc = 0.6939, val_acc = 0.6200

In [None]:
    print('\nStarting Adversarial Training...')
    oracle_loss = helpers.batchwise_oracle_nll(gen, oracle, POS_NEG_SAMPLES, BATCH_SIZE, MAX_SEQ_LEN,
                                               start_letter=START_LETTER, gpu=CUDA)
    print('\nInitial Oracle Sample Loss : %.4f' % oracle_loss)

    for epoch in range(ADV_TRAIN_EPOCHS):
        print('\n--------\nEPOCH %d\n--------' % (epoch+1))
        # TRAIN GENERATOR
        print('\nAdversarial Training Generator : ', end='')
        sys.stdout.flush()
        train_generator_PG(gen, gen_optimizer, oracle, dis, 1)

        # TRAIN DISCRIMINATOR
        print('\nAdversarial Training Discriminator : ')
        train_discriminator(dis, dis_optimizer, oracle_samples, gen, oracle, 5, 3)


Starting Adversarial Training...

Initial Oracle Sample Loss : 10.9589

--------
EPOCH 1
--------

Adversarial Training Generator : 



 oracle_sample_NLL = 11.0608

Adversarial Training Discriminator : 
d-step 1 epoch 1 : .......... average_loss = 0.3325, train_acc = 0.8789, val_acc = 0.6700
d-step 1 epoch 2 : .......... average_loss = 0.2954, train_acc = 0.8990, val_acc = 0.6600
d-step 1 epoch 3 : .......... average_loss = 0.2657, train_acc = 0.9163, val_acc = 0.6600
d-step 2 epoch 1 : .......... average_loss = 0.3144, train_acc = 0.8886, val_acc = 0.6750
d-step 2 epoch 2 : .......... average_loss = 0.2756, train_acc = 0.9100, val_acc = 0.6600
d-step 2 epoch 3 : .......... average_loss = 0.2450, train_acc = 0.9267, val_acc = 0.6600
d-step 3 epoch 1 : .......... average_loss = 0.2908, train_acc = 0.9013, val_acc = 0.6650
d-step 3 epoch 2 : .......... average_loss = 0.2515, train_acc = 0.9225, val_acc = 0.6850
d-step 3 epoch 3 : .......... average_loss = 0.2198, train_acc = 0.9379, val_acc = 0.6800
d-step 4 epoch 1 : .......... average_loss = 0.2891, train_acc = 0.9040, val_acc = 0.6700
d-step 4 epoch 2 : .......... av

d-step 3 epoch 3 : .......... average_loss = 0.0365, train_acc = 0.9954, val_acc = 0.5450
d-step 4 epoch 1 : .......... average_loss = 0.1023, train_acc = 0.9736, val_acc = 0.5300
d-step 4 epoch 2 : .......... average_loss = 0.0556, train_acc = 0.9895, val_acc = 0.5350
d-step 4 epoch 3 : .......... average_loss = 0.0346, train_acc = 0.9958, val_acc = 0.5300
d-step 5 epoch 1 : .......... average_loss = 0.1067, train_acc = 0.9739, val_acc = 0.5300
d-step 5 epoch 2 : .......... average_loss = 0.0593, train_acc = 0.9881, val_acc = 0.5300
d-step 5 epoch 3 : .......... average_loss = 0.0376, train_acc = 0.9947, val_acc = 0.5350

--------
EPOCH 7
--------

Adversarial Training Generator :  oracle_sample_NLL = 11.2863

Adversarial Training Discriminator : 
d-step 1 epoch 1 : .......... average_loss = 0.1022, train_acc = 0.9734, val_acc = 0.5600
d-step 1 epoch 2 : .......... average_loss = 0.0521, train_acc = 0.9902, val_acc = 0.5500
d-step 1 epoch 3 : .......... average_loss = 0.0321, train_ac