In [3]:
# Load basic module
import os
import json
import random
random.seed(0)
import math
from copy import deepcopy
import argparse
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# Load self-defined module
from generator import Generator, Gen_args
from discriminator import Discriminator, Dis_args
from train import pretrain_gen, train_adv, train_dis, train_gap
from data_loader import LoadData
from rollout import Rollout

# Seed every random number generator this script relies on so that runs are
# repeatable. Python's `random` and NumPy were already seeded; PyTorch has its
# own generator (used for weight initialisation and any torch-side sampling),
# so it has to be seeded as well.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Basic training parameters

# Runtime settings.
USE_CUDA = False
BATCH_SIZE = 64

# Epoch counts for each training phase.
PRE_GEN_EPOCH_NUM = 100
PRE_ADV_EPOCH_NUM = PRE_DIS_EPOCH_NUM = 5
GAP_EPOCH_NUM = 20

# NOTE(review): presumably the number of Monte-Carlo rollouts -- confirm
# against the Rollout / train_gap usage.
MC_NUM = 16
# NOTE(review): presumably the three weights mixed by train_gap -- confirm.
GAP_W = [0.2, 0.2, 0.6]

# Learning rates (generator, adversary, discriminator all share one value).
GEN_LR = ADV_LR = DIS_LR = 0.01

# Checkpoint locations: pre-training snapshots first, then the "v3" weights.
PRE_GEN_PATH, PRE_ADV_PATH, PRE_DIS_PATH = (
    "../param/pre_generator.pkl",
    "../param/pre_adversary.pkl",
    "../param/pre_discriminator.pkl",
)
GEN_PATH, ADV_PATH, DIS_PATH = (
    "../param/generator_v3.pkl",
    "../param/adversary_v3.pkl",
    "../param/discriminator_v3.pkl",
)


# Load the word -> id vocabulary and build the inverse id -> word mapping.
# The encoding is given explicitly so the JSON file is decoded the same way
# on every platform instead of depending on the locale default.
with open("../data/word_map.json", "r", encoding="utf-8") as json_file:
    word_map = json.load(json_file)
# NOTE(review): index_map is rebound by the LoadData(...) call that follows
# in this file, so this inverse mapping is only in effect until then.
index_map = {index: word for word, index in word_map.items()}
    
# Get the training and testing dataloaders, together with the maximum
# sequence length, the vocabulary size and the id -> word mapping returned by
# LoadData. The batch size comes from the BATCH_SIZE constant defined with the
# other training parameters, so the two cannot silently drift apart.
(train_loader, test_loader,
 MAX_SEQ_LEN, VOCAB_SIZE, index_map) = LoadData(
    data_path="../data/dataset_batch.json",
    word2id_path="../data/word_map.json",
    train_split=0.8,
    BATCH_SIZE=BATCH_SIZE,
)

# Generator parameters
gen_args = Gen_args(vocab_size=VOCAB_SIZE, emb_dim=64, hidden_dim=64)

# Discriminator parameters (two output classes)
dis_args = Dis_args(
    num_classes=2,
    vocab_size=VOCAB_SIZE,
    emb_dim=64,
    filter_sizes=[3, 4, 5],
    num_filters=[150, 150, 150],
    dropout=0.5,
)

# Adversary parameters: same configuration as the discriminator except for
# three output classes.
# Alternative (wider) filter bank that was tried and left disabled:
#   filter_sizes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20]
#   num_filters=[100, 200, 200, 200, 200, 100, 100, 100, 100, 100, 160, 160]
adv_args = Dis_args(
    num_classes=3,
    vocab_size=VOCAB_SIZE,
    emb_dim=64,
    filter_sizes=[3, 4, 5],
    num_filters=[150, 150, 150],
    dropout=0.5,
)

# Build the three networks. The adversary reuses the Discriminator class with
# its own argument set (adv_args).
generator = Generator(gen_args, USE_CUDA)
discriminator = Discriminator(dis_args)
adversary = Discriminator(adv_args)

# Move all of them to the GPU when CUDA is enabled.
if USE_CUDA:
    generator, discriminator, adversary = (
        net.cuda() for net in (generator, discriminator, adversary)
    )


# Load the trained generator weights stored at GEN_PATH.
# map_location="cpu" lets a checkpoint that was saved from a GPU be read on a
# CPU-only machine (USE_CUDA is False above); load_state_dict then copies the
# values into the model's existing parameters on whatever device they are on.
generator.load_state_dict(torch.load(GEN_PATH, map_location="cpu"))
# The discriminator and adversary keep their freshly initialised weights here;
# uncomment to restore them too.
# discriminator.load_state_dict(torch.load(DIS_PATH, map_location="cpu"))
# adversary.load_state_dict(torch.load(ADV_PATH, map_location="cpu"))

# Run the generator over the test set and print a decoded preview of the
# final batch: the ground-truth tokens followed by the generator's arg-max
# tokens, both mapped back to words through index_map.

# NOTE(review): these accumulators are never updated in this cell; they are
# kept only in case code outside this view still reads them.
total_loss = 0.
total_words = 0.

# Upper bound on the number of flattened token positions to print.
# NOTE(review): presumably three sequences of 72 tokens each -- confirm
# against MAX_SEQ_LEN.
PREVIEW_TOKENS = 144 + 72

target = pred = None
for batch in tqdm(test_loader):
    data = batch["x"]
    # Index 0 along the last axis of "x" is used as the target token ids.
    target = batch["x"][:, :, 0]
    if USE_CUDA:
        data, target = data.cuda(), target.cuda()
    target = target.contiguous().view(-1)
    with torch.no_grad():
        # Call the module itself rather than .forward() so that any
        # registered hooks run as well.
        pred = generator(data)

if pred is None:
    # Fail with a clear message instead of a NameError further down.
    raise RuntimeError("test_loader yielded no batches; nothing to decode")

# Only the values from the last batch survive the loop; decode those.
target_ = target.detach().cpu().numpy()
pred_ = torch.argmax(pred, dim=-1).cpu().numpy()

# Slicing caps the preview at PREVIEW_TOKENS but never reads past the end of
# a shorter final batch.
target_query = [index_map[token] for token in target_[:PREVIEW_TOKENS]]
pred_query = [index_map[token] for token in pred_[:PREVIEW_TOKENS]]
print("[INFO] Target query: ", target_query)
print("[INFO] Predicted query: ", pred_query)


  0%|          | 0/10 [00:00<?, ?it/s]

[INFO] Complete loading data, with # of {'0': 1000, '1': 1000, '2': 1000}


100%|██████████| 10/10 [00:11<00:00,  1.17s/it]

[INFO] Target query:  ['<SOS>', 'st', 'aspirin', 'during', 'pregnancy', '<POS>', 'spotting', 'during', 'pregnancy', '<POS>', 'pregnancy', 'indigestion', '<POS>', 'the', 'miraculous', 'world', 'of', 'your', 'unborn', 'baby', '<EOS>', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '<SOS>', 'breast', 'cancer', '<POS>', 'breast', 'cancer', '<POS>', 'different', 'type', 'of', 'breast', 'cancer', '<POS>', 'will', 'invasive', 'carcinoma', 'come', 'back', '<POS>', 'breast', 'cancer', '<EOS>', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '<SOS>', 'teacher', 'choice', '<POS>', 'teacher', 'c


