In [1]:
import argparse
import html
import json
import os
import pickle as p
import re

import numpy as np
import scipy.sparse
import torch
import torch.nn as nn
import torch.nn.functional as F

from read_data import *
from process_data import *
from run import *
from constants import *
from hparams import hparams
from utils import *
import _locale
_locale._getdefaultlocale = (lambda *args: ['en_US', 'utf8'])

In [2]:
# Load the pre-trained word embeddings, the word vocabulary and the keyword
# vocabulary. The pickle files are opened in `with` blocks so the handles are
# closed right away instead of leaking until garbage collection.
# NOTE: pickle can execute arbitrary code on load — only use trusted data files.
with open("./data/word_embeddings.p", 'rb') as f:
    word_embeddings = np.array(p.load(f))
with open("./data/vocab.p", 'rb') as f:
    word2index = p.load(f)

index2kwd, kwd2index, index2cnt = read_kwd_vocab("./data/train_kwd_vocab.txt")

In [3]:
# Inverse of the word2index vocabulary (built by the project helper reverse_dict);
# passed to evaluate_beam below alongside word2index.
index2word = reverse_dict(word2index)

In [4]:
# Build the modules from the checkpoint's own hyper-parameters. The checkpoint
# cell below also calls hparams.load, but it runs only after construction. That
# would leave these modules built from the default hparams (attention type,
# predictor type, dropout, ...) instead of the trained configuration.
# Keep this path in sync with the one used in the checkpoint cell.
hparams.load("./hparams/##YOUR_MODEL##.json")

encoder = EncoderRNN(hparams.HIDDEN_SIZE, word_embeddings, hparams.RNN_LAYERS,
                     dropout=hparams.DROPOUT, update_wd_emb=hparams.UPDATE_WD_EMB)
decoder = AttnDecoderRNN(hparams.HIDDEN_SIZE, len(word2index), word_embeddings,
                         hparams.ATTN_TYPE, hparams.RNN_LAYERS,
                         dropout=hparams.DROPOUT, update_wd_emb=hparams.UPDATE_WD_EMB,
                         condition=hparams.DECODER_CONDITION_TYPE)
kwd_predictor = get_predictor(word_embeddings, hparams)
kwd_bridge = MLPBridge(hparams.HIDDEN_SIZE, hparams.MAX_KWD, hparams.HIDDEN_SIZE,
                       len(word_embeddings[0]),
                       norm_type=hparams.BRIDGE_NORM_TYPE, dropout=hparams.DROPOUT)
if hparams.USE_CUDA:
    for module in (encoder, decoder, kwd_predictor, kwd_bridge):
        module.cuda()

In [5]:
# Load the trained weights (a dict of state dicts, consumed by the
# load_state_dict calls below). map_location="cpu" lets a checkpoint saved on a
# GPU be opened on a CPU-only machine; load_state_dict then copies the tensors
# onto whatever device the modules already live on.
# To use the released pretrained checkpoint instead:
# models = torch.load("./ckpt/s2s_D0.3_cnn_noneg_dropout_replace_fr.epoch59.models", map_location="cpu")
# hparams.load("./hparams/s2s_D0.3_cnn_noneg_dropout_replace_fr.json")
models = torch.load("./ckpt/##YOUR_MODEL##.models", map_location="cpu")
hparams.load("./hparams/##YOUR_MODEL##.json")

In [6]:
# Rules mapping a trigger to the keywords that must be suppressed for posts
# matching it (applied by make_filter_mask).
with open("./data/kwd_filter_dict.json", encoding="utf-8") as filter_file:
    filter_dict = json.load(filter_file)

def make_filter_mask(post, filter_dict, kwd2index):
    """Build an additive mask that blocks keywords triggered by `post`.

    Each entry of `filter_dict` maps a trigger to a list of keywords to block:
      * a key wrapped in '@' (e.g. "@^oster@") is a regular expression that is
        searched for anywhere in `post`;
      * any other key is a comma-separated list of substrings, and the rule fires
        when any non-empty piece occurs in `post`.

    Args:
        post: raw input text the triggers are matched against.
        filter_dict: mapping trigger -> iterable of keywords to block.
        kwd2index: keyword vocabulary, mapping keyword -> mask position.

    Returns:
        A list of length len(kwd2index): 0 for allowed keywords, -1e20 for
        blocked ones. Keywords not present in kwd2index are ignored (they can
        never be predicted, so there is nothing to block).
    """
    mask = [0] * len(kwd2index)
    for trigger, blocked_kwds in filter_dict.items():
        # len >= 2 so that a lone "@" is not treated as the empty regex, which
        # would match every post.
        if len(trigger) >= 2 and trigger.startswith("@") and trigger.endswith("@"):
            fired = re.search(trigger[1:-1], post) is not None
        else:
            # Skip empty pieces (e.g. from a trailing comma): "" in post is
            # always True and would block the keywords for every post.
            fired = any(piece and piece in post for piece in trigger.split(","))
        if not fired:
            continue
        for kwd in blocked_kwds:
            idx = kwd2index.get(kwd)
            if idx is not None:
                mask[idx] = -1e20
    return mask

In [7]:
# Copy the checkpoint weights into the freshly built modules. load_state_dict is
# strict by default, so a missing or unexpected key raises here instead of
# passing silently. The last result is shown as the cell output.
checkpoint_modules = {
    "encoder": encoder,
    "decoder": decoder,
    "kwd_predictor": kwd_predictor,
    "kwd_bridge": kwd_bridge,
}
load_results = [module.load_state_dict(models[key])
                for key, module in checkpoint_modules.items()]
load_results[-1]

<All keys matched successfully>

In [8]:
sent0 = "oster fpsthm2578 6-speed retractable cord hand mixer with clean start , black  ."
words0 = sent0.strip().lower().split()[:hparams.MAX_POST_LEN]

# A batch of size 1: the post's token ids (out-of-vocabulary words map to the
# UNK id) together with its length.
input_seqs = [
    [word2index[token] if token in word2index else word2index[UNK_token] for token in words0]
]
input_lens = [len(words0)]

# NOTE(review): positional batch fields consumed by evaluate_beam /
# get_cluster_kwds. Slot 5 receives keyword clusters and the last slot
# receives keyword filter masks in the cells below.
test_data = [["id0"], input_seqs, input_lens, [None], [None], [0], [0]]

In [9]:
# Keyword filter mask for the single post: 0 keeps a keyword, -1e20 blocks it.
# NOTE(review): matched against the raw (not lower-cased) sentence, unlike the
# tokens above — confirm the filter triggers are written with that in mind.
kwd_filter_masks = [make_filter_mask(sent0, filter_dict, kwd2index)]  # one mask per post
kwd_filter_mask0 = kwd_filter_masks[0]
test_data[-1] = kwd_filter_masks

In [10]:
# Sanity check: token ids of the single post (out-of-vocabulary words were mapped to the UNK id).
input_seqs

[[2658, 53710, 20790, 4961, 748, 133, 1042, 14, 139, 1025, 5, 90, 4]]

In [11]:
from beam import evaluate_beam

In [12]:
# Beam search over the one-post batch; the keyword-cluster variant follows below.
# NOTE(review): BATCH_SIZE presumably controls how evaluate_beam batches
# test_data; it is 1 here because the batch holds a single post.
hparams.BATCH_SIZE = 1
out_seqs = evaluate_beam(
    word2index, index2word,
    encoder, decoder, kwd_predictor, kwd_bridge,
    test_data, hparams.MAX_QUES_LEN,
    "./infer_out", "infer", index2kwd,
    save_all_beam=True, infer=True,
)
print(out_seqs)

BATCH: 100%|██████████| 1/1 [00:00<00:00,  4.92it/s]['what is the wattage of the beaters ? <EOS>\n', 'what is the wattage of the mixer ? <EOS>\n', 'what is the power of the beaters ? <EOS>\n', 'what is the wattage for the beaters ? <EOS>\n', 'what is the wattage of this beaters ? <EOS>\n', 'does this mixer have a beater attachment ? <EOS>\n']



In [13]:
# Group the predicted keywords into KWD_CLUSTERS clusters; each cluster later
# drives one conditioned decoding pass.
hparams.KWD_CLUSTERS = 2
# NOTE(review): presumably keyword co-occurrence counts stored as a sparse matrix — confirm.
kwd_edge_cnt = scipy.sparse.load_npz("./data/kwd_edges.npz")
kwd_clusters = get_cluster_kwds(
    kwd_predictor, test_data, kwd_edge_cnt, index2kwd, kwd2index,
)

CLUSTER: 100%|██████████| 1/1 [00:00<00:00,  4.62it/s]


In [20]:
# Decode once per keyword cluster, conditioning the decoder on that cluster's
# keywords, and collect every beam in out_seqs.
# Each pass works on its own shallow copy of test_data, so the shared batch
# used by the cells above is not left holding the last cluster. The
# DECODE_USE_KWD_LABEL switch is also restored afterwards. Without both, re-running
# an earlier cell after this one would silently produce different results.
prev_use_kwd_label = getattr(hparams, "DECODE_USE_KWD_LABEL", False)
hparams.DECODE_USE_KWD_LABEL = True
out_seqs = []
try:
    for i in range(hparams.KWD_CLUSTERS):
        cluster_data = list(test_data)
        cluster_data[5] = kwd_clusters[i]
        tmp_seqs = evaluate_beam(word2index, index2word, encoder, decoder, kwd_predictor, kwd_bridge,
                                 cluster_data, hparams.MAX_QUES_LEN, "./infer_out", "infer", index2kwd,
                                 save_all_beam=True, infer=True)
        print(tmp_seqs)
        out_seqs.extend(tmp_seqs)
finally:
    hparams.DECODE_USE_KWD_LABEL = prev_use_kwd_label

BATCH: 100%|██████████| 1/1 [00:00<00:00,  3.80it/s]
['what is the wattage of the beaters ? <EOS>\n', 'does this mixer have a beater attachment ? <EOS>\n', 'what is the wattage of the mixer ? <EOS>\n', 'does this model have a beater attachment ? <EOS>\n', 'does this mixer have a whisk attachment ? <EOS>\n', 'does this model have a <unk> attachment ? <EOS>\n']
BATCH: 100%|██████████| 1/1 [00:00<00:00,  3.93it/s]
['how long is the cord ? <EOS>\n', 'how long is the power cord ? <EOS>\n', 'what is the power of the cord ? <EOS>\n', 'what is the power of the beaters ? <EOS>\n', 'how long is the cord <EOS>\n', 'is the cord retractable ? <EOS>\n']


In [21]:
def clean_html(text):
    """Re-join HTML entities split apart by tokenization and decode them.

    "& amp ;" becomes "&amp;" and "& # 39 ;" becomes "&#39;", after which
    html.unescape turns them into "&" and "'".
    """
    text = re.sub(r"& (\S+) ;", r"&\1;", text)
    text = re.sub(r"& # (\S+) ;", r"&#\1;", text)
    return html.unescape(text)

def clean_text(text):
    """Turn one raw beam line into display text.

    Strips surrounding whitespace, decodes HTML entities and removes every
    "<EOS>" marker. The marker is removed even when it is not preceded by a
    space (e.g. a beam consisting only of "<EOS>"), so it never leaks into the
    output.
    """
    cleaned = clean_html(text.strip())
    return re.sub(r"\s*<EOS>", "", cleaned).strip()

# Jaccard
def sent_sim(text1, text2):
    """Jaccard similarity between the lower-cased word sets of two texts.

    Returns a float in [0, 1]. Two texts that are both empty (or whitespace
    only) count as identical and score 1.0 instead of dividing by zero.
    """
    words1 = set(text1.lower().split())
    words2 = set(text2.lower().split())
    union = words1 | words2
    if not union:
        return 1.0
    return len(words1 & words2) / len(union)

def deduplicate(texts0, preserve=3, threshold=0.5):
    """Greedily pick `preserve` mutually dissimilar texts, in order of preference.

    The first text is always kept. Each further pick is the earliest remaining
    text whose highest sent_sim to the already-picked texts is below
    `threshold`. If no remaining text qualifies, the one with the lowest such
    similarity is taken (the earliest on ties).

    Args:
        texts0: candidate texts, best first (any sequence; not modified).
        preserve: number of texts to return; requires 0 < preserve <= len(texts0).
        threshold: similarity below which a candidate is "different enough".

    Returns:
        (sel_ids, sel_texts): indices into texts0 and the matching texts, in
        the order they were picked.

    Raises:
        AssertionError: if preserve is not in (0, len(texts0)].
    """
    assert len(texts0) >= preserve and preserve > 0
    if len(texts0) == preserve:
        return list(range(len(texts0))), texts0[:]
    sel_ids, sel_texts = [0], [texts0[0]]
    remain_ids = list(range(1, len(texts0)))
    for _ in range(1, preserve):
        pick = None
        overlaps = []  # similarity of each remaining candidate scanned so far
        for pos, cand_id in enumerate(remain_ids):
            overlap = max(sent_sim(texts0[cand_id], sel) for sel in sel_texts)
            if overlap < threshold:
                pick = pos
                break
            overlaps.append(overlap)
        if pick is None:
            # Nothing is below the threshold: fall back to the least similar one.
            pick = min(range(len(overlaps)), key=overlaps.__getitem__)
        cand_id = remain_ids.pop(pick)
        sel_ids.append(cand_id)
        sel_texts.append(texts0[cand_id])
    return sel_ids, sel_texts


In [22]:
# Clean every collected beam, then keep a few mutually dissimilar questions
# (3 by default, see deduplicate).
cleaned_out_seqs = list(map(clean_text, out_seqs))
filtered_ids, filtered_texts = deduplicate(cleaned_out_seqs)
print(filtered_texts)

['what is the wattage of the beaters ?', 'does this mixer have a beater attachment ?', 'how long is the cord ?']
