Este notebook tem como objetivo testar os checkpoints disponibilizados pelo trabalho de Wang e ver os resultados de transferência de estilo.

---

In [1]:
import torch
import os
import sys

# The paper's code is not an installable package: its directory has to be on
# sys.path before `model` and `data` can be imported.
TASK = "yelp"
paper_code_path = f"../papers/controllable-text-attribute-transfer/method/mymodel-{TASK}"
if paper_code_path not in sys.path:
    # Guarded so that re-running this cell does not pile up duplicate entries.
    sys.path.append(paper_code_path)

from model import Classifier, make_model, fgim_attack
from data import prepare_data, load_human_answer, non_pair_data_loader, id2text_sentence

# Autoencoder hyper-parameters handed to make_model.
# NOTE(review): presumably these must match the values the checkpoint was trained with -- confirm.
AUTOENCODER_LAYERS_NUM = 2
AUTOENCODER_TRANSFORMERS_MODEL_SIZE = 256
AUTOENCODER_LATENT_SIZE = 256
AUTOENCODER_FF_TRANSFORMERS_SIZE = 1024
MAX_SEQUENCE_LENGTH = 60

# Ids of the special vocabulary tokens.
ID_PAD = 0
ID_UNK = 1
ID_BOS = 2
ID_EOS = 3

# Forwarded to prepare_data as `max_num`.
WORD_DICT_MAX_NUM = 5
DATA_PATH = f"../papers/controllable-text-attribute-transfer/data/{TASK}/processed_files/"

# Both checkpoints live in the same run directory inside the paper's code tree.
checkpoint_dir = f"{paper_code_path}/save/1557667911"
autoencoder_checkpoint = f"{checkpoint_dir}/ae_model_params.pkl"
discriminator_checkpoint = f"{checkpoint_dir}/dis_model_params.pkl"
device = "cuda" if torch.cuda.is_available() else "cpu"


ID_TO_WORD, VOCAB_SIZE, TRAIN_FILE_LIST, TRAIN_LABEL_LIST = prepare_data(
    data_path=DATA_PATH, max_num=WORD_DICT_MAX_NUM, task_type=TASK)

# Human-written reference answers
gold_ans = load_human_answer(DATA_PATH)


  from .autonotebook import tqdm as notebook_tqdm


prepare data ...
Load word-dict with 9339 size and 5 max_num.


In [2]:
# Rebuild the autoencoder and the discriminator, then restore their checkpoints.
autoencoder = make_model(
    d_vocab=VOCAB_SIZE,
    N=AUTOENCODER_LAYERS_NUM,
    d_model=AUTOENCODER_TRANSFORMERS_MODEL_SIZE,
    latent_size=AUTOENCODER_LATENT_SIZE,
    d_ff=AUTOENCODER_FF_TRANSFORMERS_SIZE,
)
# map_location remaps tensors that were saved on another device (e.g. a GPU)
# onto `device`, so the checkpoints also load on CPU-only machines.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source.
autoencoder.load_state_dict(torch.load(autoencoder_checkpoint, map_location=device))
discriminator = Classifier(latent_size=AUTOENCODER_LATENT_SIZE, output_size=1)
discriminator.load_state_dict(torch.load(discriminator_checkpoint, map_location=device))

# Move the models to the GPU when one is available.
autoencoder.to(device)
discriminator.to(device)

Classifier(
  (fc1): Linear(in_features=256, out_features=100, bias=True)
  (relu1): LeakyReLU(negative_slope=0.2)
  (fc2): Linear(in_features=100, out_features=50, bias=True)
  (relu2): LeakyReLU(negative_slope=0.2)
  (fc3): Linear(in_features=50, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)

In [None]:
# Evaluation loader: one sentence per batch, read in file order (no shuffling)
# from the two test files, labelled 0 and 1 respectively.
eval_data_loader = non_pair_data_loader(
    batch_size=1, id_bos=ID_BOS,
    id_eos=ID_EOS, id_unk=ID_UNK,
    max_sequence_length=MAX_SEQUENCE_LENGTH, vocab_size=VOCAB_SIZE
)
eval_file_list = [DATA_PATH + 'sentiment.test.0', DATA_PATH + 'sentiment.test.1']
eval_label_list = [[0], [1]]
eval_data_loader.create_batches(eval_file_list, eval_label_list, if_shuffle=False)

# `gold_ans` was already loaded from DATA_PATH in the setup cell, so it is
# reused here instead of reading the file again. gold_ans[it] is paired with
# batch `it` below, hence the two lengths must agree.
assert len(gold_ans) == eval_data_loader.num_batch, (
    f"{len(gold_ans)} human answers vs {eval_data_loader.num_batch} evaluation batches"
)


autoencoder.eval()
discriminator.eval()
for it in range(eval_data_loader.num_batch):
    batch_sentences, tensor_labels, \
    tensor_src, tensor_src_mask, tensor_tgt, tensor_tgt_y, \
    tensor_tgt_mask, tensor_ntokens = eval_data_loader.next_batch()

    print("------------%d------------" % it)
    print(tensor_src.size())
    print(tensor_tgt_y.size())

    # Original sentence and its label.
    print(id2text_sentence(tensor_tgt_y[0], ID_TO_WORD))
    print("origin_labels", tensor_labels)

    # Call the module itself rather than .forward() so that any registered
    # hooks run as well.
    latent, out = autoencoder(tensor_src.to(device),
                              tensor_tgt.to(device),
                              tensor_src_mask.to(device),
                              tensor_tgt_mask.to(device))
    print("-" * 20)
    print(out)
    print(tensor_tgt_y)
    print("-" * 20)

    # Sentence decoded from the unmodified latent representation.
    generator_text = autoencoder.greedy_decode(latent,
                                               max_len=MAX_SEQUENCE_LENGTH,
                                               start_id=ID_BOS)
    print(id2text_sentence(generator_text[0], ID_TO_WORD))

    # The target is the opposite of the original label: 0 when the label is
    # above 0.5, 1 otherwise.
    target_value = 0.0 if tensor_labels[0].item() > 0.5 else 1.0
    target = torch.tensor([[target_value]], dtype=torch.float)
    print("target_labels", target)

    fgim_attack(discriminator, latent, target.to(device), autoencoder,
                MAX_SEQUENCE_LENGTH, ID_BOS, id2text_sentence, ID_TO_WORD, gold_ans[it])


Load data from ../papers/controllable-text-attribute-transfer/data/yelp/processed_files/sentiment.test.0 ../papers/controllable-text-attribute-transfer/data/yelp/processed_files/sentiment.test.1 !
Create 1000 batches with 1 batch_size
------------0------------
torch.Size([1, 14])
torch.Size([1, 15])
ever since joes has changed hands it 's just gotten worse and worse .
origin_labels tensor([[0.]], device='cuda:0')
--------------------
tensor([[[-29.4278, -17.1816, -13.8328,  ..., -13.8273, -13.8440, -13.8395],
         [-29.8907, -16.5639, -12.9379,  ..., -12.9637, -12.9288, -12.9205],
         [-29.0805, -11.6152, -11.2399,  ..., -11.1834, -11.1776, -11.1356],
         ...,
         [-39.3841, -25.1119, -17.2934,  ..., -17.2609, -17.2652, -17.2372],
         [-29.1264, -12.2991, -11.8000,  ..., -11.8009, -11.8023, -11.8037],
         [-26.9960, -11.7949, -11.7725,  ..., -11.7744, -11.7703, -11.7705]]],
       device='cuda:0', grad_fn=<LogSoftmaxBackward0>)
tensor([[  78,  429, 3797,   



| It  1 | dis model pred 0.0013 |
ever since joes has changed it hands 's just gotten worse and worse .
epsilon: 1.8
| It  2 | dis model pred 0.9975 |
ever since joes has changed it hands 's just gotten worse and worse .
epsilon: 1.62
| It  3 | dis model pred 0.9976 |
ever since joes has changed it hands 's just gotten worse and worse .
epsilon: 1.4580000000000002
| It  4 | dis model pred 0.9977 |
ever since joes has changed it hands 's just gotten worse and worse .
epsilon: 1.3122000000000003
| It  5 | dis model pred 0.9978 |
ever since joes has changed it hands 's just gotten worse and worse .
epsilon: 3.0
| It  1 | dis model pred 0.0013 |
ever since joes has changed hands it 's just gotten worse and quick .
epsilon: 2.7
| It  2 | dis model pred 0.9999 |
ever since joes has changed hands it 's just gotten worse and quick .
epsilon: 2.43
| It  3 | dis model pred 0.9999 |
ever since joes has changed hands it 's just gotten worse and quick .
epsilon: 2.1870000000000003
| It  4 | dis mod