In [4]:
from __future__ import print_function
from math import ceil
import numpy as np
import sys
import pdb

import torch
import torch.optim as optim
import torch.nn as nn

import generator
import discriminator
import helpers


CUDA = True
VOCAB_SIZE = 5000
MAX_SEQ_LEN = 20
START_LETTER = 0
BATCH_SIZE = 32
MLE_TRAIN_EPOCHS = 100
ADV_TRAIN_EPOCHS = 50
POS_NEG_SAMPLES = 10000

GEN_EMBEDDING_DIM = 32
GEN_HIDDEN_DIM = 32
DIS_EMBEDDING_DIM = 64
DIS_HIDDEN_DIM = 64

oracle_samples_path = './oracle_samples.trc'
oracle_state_dict_path = './oracle_EMBDIM32_HIDDENDIM32_VOCAB5000_MAXSEQLEN20.trc'
pretrained_gen_path = './gen_MLEtrain_EMBDIM32_HIDDENDIM32_VOCAB5000_MAXSEQLEN20.trc'
pretrained_dis_path = './dis_pretrain_EMBDIM_64_HIDDENDIM64_VOCAB5000_MAXSEQLEN20.trc'


def train_generator_MLE(gen, gen_opt, oracle, real_data_samples, epochs):
    """
    Max Likelihood Pretraining for the generator
    """
    for epoch in range(epochs):
        print('epoch %d : ' % (epoch + 1), end='')
        sys.stdout.flush()
        total_loss = 0

        for i in range(0, POS_NEG_SAMPLES, BATCH_SIZE):
            inp, target = helpers.prepare_generator_batch(real_data_samples[i:i + BATCH_SIZE], start_letter=START_LETTER,
                                                          gpu=CUDA)
            gen_opt.zero_grad()
            loss = gen.batchNLLLoss(inp, target)
            loss.backward()
            gen_opt.step()

            total_loss += loss.data.item()

            if (i / BATCH_SIZE) % ceil(
                            ceil(POS_NEG_SAMPLES / float(BATCH_SIZE)) / 10.) == 0:  # roughly every 10% of an epoch
                print('.', end='')
                sys.stdout.flush()

        # each loss in a batch is loss per sample
        total_loss = total_loss / ceil(POS_NEG_SAMPLES / float(BATCH_SIZE)) / MAX_SEQ_LEN

        # sample from generator and compute oracle NLL
        oracle_loss = helpers.batchwise_oracle_nll(gen, oracle, POS_NEG_SAMPLES, BATCH_SIZE, MAX_SEQ_LEN,
                                                   start_letter=START_LETTER, gpu=CUDA)

        print(' average_train_NLL = %.4f, oracle_sample_NLL = %.4f' % (total_loss, oracle_loss))


def train_generator_PG(gen, gen_opt, oracle, dis, num_batches):
    """
    The generator is trained using policy gradients, using the reward from the discriminator.
    Training is done for num_batches batches.
    """

    for batch in range(num_batches):
        s = gen.sample(BATCH_SIZE*2)        # 64 works best
        inp, target = helpers.prepare_generator_batch(s, start_letter=START_LETTER, gpu=CUDA)
        rewards = dis.batchClassify(target)

        gen_opt.zero_grad()
        pg_loss = gen.batchPGLoss(inp, target, rewards)
        pg_loss.backward()
        gen_opt.step()

    # sample from generator and compute oracle NLL
    oracle_loss = helpers.batchwise_oracle_nll(gen, oracle, POS_NEG_SAMPLES, BATCH_SIZE, MAX_SEQ_LEN,
                                                   start_letter=START_LETTER, gpu=CUDA)

    print(' oracle_sample_NLL = %.4f' % oracle_loss)


def train_discriminator(discriminator, dis_opt, real_data_samples, generator, oracle, d_steps, epochs):
    """
    Training the discriminator on real_data_samples (positive) and generated samples from generator (negative).
    Samples are drawn d_steps times, and the discriminator is trained for epochs epochs.
    """

    # generating a small validation set before training (using oracle and generator)
    pos_val = oracle.sample(100)
    neg_val = generator.sample(100)
    val_inp, val_target = helpers.prepare_discriminator_data(pos_val, neg_val, gpu=CUDA)

    for d_step in range(d_steps):
        s = helpers.batchwise_sample(generator, POS_NEG_SAMPLES, BATCH_SIZE)
        dis_inp, dis_target = helpers.prepare_discriminator_data(real_data_samples, s, gpu=CUDA)
        for epoch in range(epochs):
            print('d-step %d epoch %d : ' % (d_step + 1, epoch + 1), end='')
            sys.stdout.flush()
            total_loss = 0
            total_acc = 0

            for i in range(0, 2 * POS_NEG_SAMPLES, BATCH_SIZE):
                inp, target = dis_inp[i:i + BATCH_SIZE], dis_target[i:i + BATCH_SIZE]
                dis_opt.zero_grad()
                out = discriminator.batchClassify(inp)
                loss_fn = nn.BCELoss()
                loss = loss_fn(out, target)
                loss.backward()
                dis_opt.step()

                total_loss += loss.data.item()
                total_acc += torch.sum((out>0.5)==(target>0.5)).data.item()
                print(out.size())
                print(((out>0.5)==(target>0.5)).size())
                print(total_acc)
                if (i / BATCH_SIZE) % ceil(ceil(2 * POS_NEG_SAMPLES / float(
                        BATCH_SIZE)) / 10.) == 0:  # roughly every 10% of an epoch
                    print('.', end='')
                    sys.stdout.flush()

            total_loss /= ceil(2 * POS_NEG_SAMPLES / float(BATCH_SIZE))
            total_acc /= float(2 * POS_NEG_SAMPLES)

            val_pred = discriminator.batchClassify(val_inp)
            print(' average_loss = %.4f, train_acc = %.4f, val_acc = %.4f' % (
                total_loss, total_acc, torch.sum((val_pred>0.5)==(val_target>0.5)).data.item()/200.))

ModuleNotFoundError: No module named 'generator'

In [8]:
train_discriminator(dis, dis_optimizer, oracle_samples, gen, oracle, 50, 3)

d-step 1 epoch 1 : torch.Size([32])
torch.Size([32])
32
.torch.Size([32])
torch.Size([32])
64
torch.Size([32])
torch.Size([32])
96
torch.Size([32])
torch.Size([32])
128
torch.Size([32])
torch.Size([32])
160
torch.Size([32])
torch.Size([32])
192
torch.Size([32])
torch.Size([32])
224
torch.Size([32])
torch.Size([32])
256
torch.Size([32])
torch.Size([32])
288
torch.Size([32])
torch.Size([32])
320
torch.Size([32])
torch.Size([32])
352
torch.Size([32])
torch.Size([32])
383
torch.Size([32])
torch.Size([32])
415
torch.Size([32])
torch.Size([32])
447
torch.Size([32])
torch.Size([32])
479
torch.Size([32])
torch.Size([32])
511
torch.Size([32])
torch.Size([32])
543
torch.Size([32])
torch.Size([32])
575
torch.Size([32])
torch.Size([32])
606
torch.Size([32])
torch.Size([32])
638
torch.Size([32])
torch.Size([32])
670
torch.Size([32])
torch.Size([32])
701
torch.Size([32])
torch.Size([32])
732
torch.Size([32])
torch.Size([32])
764
torch.Size([32])
torch.Size([32])
796
torch.Size([32])
torch.Size([32])

KeyboardInterrupt: 

In [2]:
    oracle = generator.Generator(GEN_EMBEDDING_DIM, GEN_HIDDEN_DIM, VOCAB_SIZE, MAX_SEQ_LEN, gpu=CUDA)
    oracle.load_state_dict(torch.load(oracle_state_dict_path))
    oracle_samples = torch.load(oracle_samples_path).type(torch.LongTensor)
    # a new oracle can be generated by passing oracle_init=True in the generator constructor
    # samples for the new oracle can be generated using helpers.batchwise_sample()

    gen = generator.Generator(GEN_EMBEDDING_DIM, GEN_HIDDEN_DIM, VOCAB_SIZE, MAX_SEQ_LEN, gpu=CUDA)
    dis = discriminator.Discriminator(DIS_EMBEDDING_DIM, DIS_HIDDEN_DIM, VOCAB_SIZE, MAX_SEQ_LEN, gpu=CUDA)

    if CUDA:
        oracle = oracle.cuda()
        gen = gen.cuda()
        dis = dis.cuda()
        oracle_samples = oracle_samples.cuda()

    # GENERATOR MLE TRAINING
    print('Starting Generator MLE Training...')
    gen_optimizer = optim.Adam(gen.parameters(), lr=1e-2)
    train_generator_MLE(gen, gen_optimizer, oracle, oracle_samples, MLE_TRAIN_EPOCHS)

    # torch.save(gen.state_dict(), pretrained_gen_path)
    # gen.load_state_dict(torch.load(pretrained_gen_path))

    # PRETRAIN DISCRIMINATOR
    print('\nStarting Discriminator Training...')
    dis_optimizer = optim.Adagrad(dis.parameters())
    train_discriminator(dis, dis_optimizer, oracle_samples, gen, oracle, 50, 3)

    # torch.save(dis.state_dict(), pretrained_dis_path)
    # dis.load_state_dict(torch.load(pretrained_dis_path))

    # ADVERSARIAL TRAINING
    

Starting Generator MLE Training...
epoch 1 : .......... average_train_NLL = 6.8112, oracle_sample_NLL = 14.7055
epoch 2 : .......... average_train_NLL = 6.1624, oracle_sample_NLL = 13.6573
epoch 3 : .......... average_train_NLL = 5.8390, oracle_sample_NLL = 13.1075
epoch 4 : .......... average_train_NLL = 5.6311, oracle_sample_NLL = 12.7521
epoch 5 : .......... average_train_NLL = 5.4814, oracle_sample_NLL = 12.4448
epoch 6 : .......... average_train_NLL = 5.3692, oracle_sample_NLL = 12.2441
epoch 7 : .......... average_train_NLL = 5.2807, oracle_sample_NLL = 12.0314
epoch 8 : .......... average_train_NLL = 5.2091, oracle_sample_NLL = 11.9553
epoch 9 : .......... average_train_NLL = 5.1495, oracle_sample_NLL = 11.8170
epoch 10 : .......... average_train_NLL = 5.0988, oracle_sample_NLL = 11.7294
epoch 11 : .......... average_train_NLL = 5.0554, oracle_sample_NLL = 11.6184
epoch 12 : .......... average_train_NLL = 5.0172, oracle_sample_NLL = 11.5857
epoch 13 : .......... average_train_NL

In [None]:
    print('\nStarting Adversarial Training...')
    oracle_loss = helpers.batchwise_oracle_nll(gen, oracle, POS_NEG_SAMPLES, BATCH_SIZE, MAX_SEQ_LEN,
                                               start_letter=START_LETTER, gpu=CUDA)
    print('\nInitial Oracle Sample Loss : %.4f' % oracle_loss)

    for epoch in range(ADV_TRAIN_EPOCHS):
        print('\n--------\nEPOCH %d\n--------' % (epoch+1))
        # TRAIN GENERATOR
        print('\nAdversarial Training Generator : ', end='')
        sys.stdout.flush()
        train_generator_PG(gen, gen_optimizer, oracle, dis, 1)

        # TRAIN DISCRIMINATOR
        print('\nAdversarial Training Discriminator : ')
        train_discriminator(dis, dis_optimizer, oracle_samples, gen, oracle, 5, 3)


Starting Adversarial Training...

Initial Oracle Sample Loss : 10.8424

--------
EPOCH 1
--------

Adversarial Training Generator :  oracle_sample_NLL = 10.8539

Adversarial Training Discriminator : 
d-step 1 epoch 1 : .......... average_loss = 0.1132, train_acc = 0.9730, val_acc = 0.5900
d-step 1 epoch 2 : .......... average_loss = 0.0915, train_acc = 0.9784, val_acc = 0.5950
d-step 1 epoch 3 : .......... average_loss = 0.0811, train_acc = 0.9811, val_acc = 0.6200
d-step 2 epoch 1 : .......... average_loss = 0.1021, train_acc = 0.9752, val_acc = 0.6050
d-step 2 epoch 2 : .......... average_loss = 0.0871, train_acc = 0.9773, val_acc = 0.6000
d-step 2 epoch 3 : .......... average_loss = 0.0720, train_acc = 0.9828, val_acc = 0.6150
d-step 3 epoch 1 : .......... average_loss = 0.1019, train_acc = 0.9748, val_acc = 0.5950
d-step 3 epoch 2 : .......... average_loss = 0.0838, train_acc = 0.9792, val_acc = 0.5900
d-step 3 epoch 3 : .......... average_loss = 0.0717, train_acc = 0.9829, val_ac

In [4]:
print(8)

8
