In [7]:
import numpy as np
import re
import copy
from gensim.models.keyedvectors import KeyedVectors

In [None]:
model_path = './data/GoogleNews-vectors-negative300.bin'
word2vec = KeyedVectors.load_word2vec_format('./data/GoogleNews-vectors-negative300.bin', binary=True)  
# model.wv['computer']

In [8]:
entities = {}
with open('dialog-bAbI-tasks/dialog-babi-kb-all.txt', 'r') as file:
    lines = file.readlines()
    for l in lines:
        wds = l.rstrip().split(' ')[2].split('\t')
        slot_type = wds[0] # ex) R_price
        slot_val = wds[1] # ex) cheap
        if slot_type not in entities:
            entities[slot_type] = []
        if slot_val not in entities[slot_type]:
            entities[slot_type].append(slot_val)
for idx, (ent_name, ent_vals) in enumerate(entities.items()):
    print(idx, ent_name, ent_vals[0] )

0 R_cuisine korean
1 R_location seoul
2 R_price cheap
3 R_rating 1
4 R_phone resto_seoul_cheap_korean_1stars_phone
5 R_address resto_seoul_cheap_korean_1stars_address
6 R_number four


In [9]:
def update_context(context, sentence):
    for idx, (ent_key, ent_vals) in enumerate(entities.items()):
        for w in sentence:
            if w in ent_vals:
#                 print('idx', idx, w)
                context[idx] = 1

#test
t_context = [0] * len(entities.keys())
t_sentence = 'for people two moderate range are looking for'
update_context(t_context, t_sentence)
print('sentence:', t_sentence)
print('context:', t_context)

sentence: for people two moderate range are looking for
context: [0, 0, 0, 0, 0, 0, 0]


In [10]:
def get_bow(sentence, vocab, w2i):
    bow = [0] * len(vocab)
    for word in sentence.split(' '):
        if word in w2i:
            bow[w2i[word]] += 1
    return bow

In [11]:
def load_data(fpath, entities, vocab, system_acts):
    data = []
    with open(fpath, 'r') as f:
        lines = f.readlines()
        x, y, c = [], [], []
        context = [0] * len(entities.keys())
        for idx, l in enumerate(lines):
            l = l.rstrip()
            if l == '':
                data.append((x, y, c))
                # reset
                x, y, c = [], [], []
                context = [0] * len(entities.keys())
            else:
                ls = l.split("\t")
                t_u = ls[0].split(' ', 1)
                turn = t_u[0]
                uttr = t_u[1].split(' ')
                update_context(context, uttr)
                sys_act = SILENT
                if len(ls) == 2: # includes user and system utterance
                    for w in uttr:
                        if w not in vocab: vocab.append(w)

                    sys_act = ls[1]
                    sys_act = re.sub(r'resto_\S+', '', sys_act)
                    if sys_act.startswith('api_call'): sys_act = 'api_call'
                    if sys_act not in system_acts: system_acts.append(sys_act)
                else:
                    continue # TODO
                        
                x.append(uttr)
                y.append(sys_act)
                c.append(copy.copy((context)))
    vocab = sorted(vocab)
    return data, vocab, system_acts

# create training dataset
SILENT = '<SILENT>'
system_acts = [SILENT]
vocab = []
fpath_train = 'dialog-bAbI-tasks/dialog-babi-task5-full-dialogs-trn.txt'
fpath_test = 'dialog-bAbI-tasks/dialog-babi-task5-full-dialogs-tst-OOV.txt'
train_data, vocab, system_acts = load_data(fpath_train, entities, vocab, system_acts)
test_data, vocab, system_acts = load_data(fpath_test, entities, vocab, system_acts)


In [12]:
# training
# train, train_dlg, vocabs_train = load('dialog-bAbI-tasks/dialog-babi-task5-full-dialogs-trn.txt')
max_turn_train = max([len(d[0]) for d in train_data])
max_turn_test = max([len(d[0]) for d in test_data])
max_turn = max(max_turn_train, max_turn_test)
print('max turn:', max_turn)
w2i = dict((w, i) for i, w in enumerate(vocab))
i2w = dict((i, w) for i, w in enumerate(vocab))
act2i = dict((act, i) for i, act in enumerate(system_acts))

print('action size:', len(system_acts))

max turn: 27
action size: 16


In [50]:
act2i

{'<SILENT>': 0,
 'any preference on a type of cuisine': 15,
 'api_call': 7,
 'great let me do the reservation': 11,
 'hello what can i help you with today': 1,
 'here it is ': 12,
 'how many people would be in your party': 4,
 "i'm on it": 2,
 'is there anything i can help you with': 13,
 'ok let me look into some options for you': 6,
 'sure is there anything else to update': 8,
 'sure let me find an other option for you': 10,
 'what do you think of this option: ': 9,
 'where should it be': 3,
 'which price range are looking for': 5,
 "you're welcome": 14}

In [58]:
import torch.nn as nn
import torch
import torch.nn.functional as F
import random
from torch.autograd import Variable

def to_var(x):
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x)

class WordEmbedding(nn.Module):
    '''
    In : (N, sentence_len)
    Out: (N, sentence_len, embd_size)
    '''
    def __init__(self, vocab_size, embd_size, pre_embd_w=None, is_train_embd=False):
        super(WordEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embd_size)
        if pre_embd_w is not None:
            self.embedding.weight = nn.Parameter(pre_embd_w, requires_grad=is_train_embd)

    def forward(self, x):
        return self.embedding(x)

class HybridCodeNetwork(nn.Module):
    def __init__(self, vocab_size, embd_size, hidden_size, action_size):
        super(HybridCodeNetwork, self).__init__()
        self.embd_size = embd_size
        self.hidden_size = hidden_size
        self.embedding = WordEmbedding(vocab_size, embd_size)
        self.lstm = nn.LSTM(307, hidden_size, batch_first=True) # TODO input dim
        self.fc = nn.Linear(hidden_size, action_size)

#     def forward(self, uttr, context, act_mask, bow, last_act):
    def forward(self, uttr, context):
        # uttr: (bs, dialog_len, sentence_len)
        # uttr: (bs, dialog_len, context_dim)
        bs = uttr.size(0)
        dlg_len = uttr.size(1)
        sent_len = uttr.size(2)
        
        embd = self.embedding(uttr.view(bs, -1)) # (bs, dialog_len*sentence_len, embd)
        embd = embd.view(bs, dlg_len, sent_len, -1) # (bs, dialog_len, sentence_len, embd)
        embd = torch.mean(embd, 2) # (bs, dialog_len, embd)
        x = torch.cat((embd, context), 2) # (bs, dialog_len, embd+context_dim)
        x, (h, c) = self.lstm(x) # (bs, seq, hid), ((1, bs, hid), (1, bs, hid))
        y = F.log_softmax(self.fc(x), -1) # (bs, seq, action_size)
        return y

embd_size = 300
hidden_size = 100
print(len(system_acts))
model = HybridCodeNetwork(len(vocab), embd_size, hidden_size, len(system_acts))
if torch.cuda.is_available():
    model.cuda()
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()))

16


In [68]:
def add_padding(data, seq_len):
    pad_len = max(0, seq_len - len(data))
    data += [0] * pad_len
    data = data[:seq_len]
    return data

def make_word_vector(uttrs_list, w2i, dialog_maxlen, uttr_maxlen):
    dialog_list = []
    for uttrs in uttrs_list:
        dialog = []
        for sentence in uttrs:
            sent_vec = [w2i[w] for w in sentence]
            sent_vec = add_padding(sent_vec, uttr_maxlen)
            dialog.append(sent_vec)
        for _ in range(dialog_maxlen - len(dialog)):
            dialog.append([0] * uttr_maxlen)
        dialog = torch.LongTensor(dialog[:dialog_maxlen])
        dialog_list.append(dialog)
    return to_var(torch.stack(dialog_list, 0))
    
def train(model, data, optimizer, n_epochs=10, batch_size=64):
    for epoch in range(n_epochs):
        random.shuffle(data)
        data = copy.copy(data)
        for i in range(0, len(data)-batch_size, batch_size):
            batch = data[i:i+batch_size]
            uttrs_list = [d[0] for d in batch]
            dialog_maxlen = max([len(uttrs) for uttrs in uttrs_list])
            uttr_maxlen = max([len(u) for uttrs in uttrs_list for u in uttrs])
#             print('dialog_maxlen', dialog_maxlen, ', uttr_maxlen', uttr_maxlen)
            uttr_var = make_word_vector(uttrs_list, w2i, dialog_maxlen, uttr_maxlen)
            batch_labels = [d[1] for d in batch]
            labels_var = []
            for labels in batch_labels:
                vec_labels = [act2i[l] for l in labels]
                pad_len = dialog_maxlen - len(labels)
#                 print('b vec_labels', len(vec_labels))
                for _ in range(pad_len):
                    vec_labels.append(act2i[SILENT])
#                 print('vec_labels', len(vec_labels))
                labels_var.append(torch.LongTensor(vec_labels))
            labels_var = to_var(torch.stack(labels_var, 0))
            context = copy.deepcopy([d[2] for d in batch])
            for i, c in enumerate(context):
                pad_len = dialog_maxlen - len(c)
                for _ in range(pad_len):
                    context[i].append([1] * len(entities.keys()))
            context = to_var(torch.FloatTensor(context))
            pred = model(uttr_var, context)
            action_size = pred.size(-1)
            loss = F.nll_loss(pred.view(-1, action_size), labels_var.view(-1))
            print('loss', loss.data[0])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
train(model, train_data, optimizer)

loss 0.22652828693389893
loss 0.21633586287498474
loss 0.21057692170143127
loss 0.21129053831100464
loss 0.2321590930223465
loss 0.21111074090003967
loss 0.19333162903785706
loss 0.21138924360275269
loss 0.21803469955921173
loss 0.20133042335510254
loss 0.2107846587896347
loss 0.22052684426307678
loss 0.19224828481674194
loss 0.1939406543970108
loss 0.18736137449741364
loss 0.19156664609909058
loss 0.19662989675998688
loss 0.18234926462173462
loss 0.17528751492500305
loss 0.1813841015100479
loss 0.19426864385604858
loss 0.178115576505661
loss 0.18112820386886597
loss 0.16078197956085205
loss 0.17916178703308105
loss 0.1937027871608734
loss 0.1767042875289917
loss 0.17284417152404785
loss 0.16177503764629364
loss 0.17099489271640778
loss 0.16867317259311676
loss 0.15026849508285522
loss 0.1678142100572586
loss 0.16510629653930664
loss 0.1562018245458603
loss 0.15048418939113617
loss 0.15733806788921356
loss 0.1669473946094513
loss 0.1580592542886734
loss 0.16015255451202393
loss 0.14651