In [71]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm


# Temporarily leave PositionalEncoding module here. Will be moved somewhere else.
class PositionalEncoding(nn.Module):
    r"""Inject some information about the relative or absolute position of the tokens
        in the sequence. The positional encodings have the same dimension as
        the embeddings, so that the two can be summed. Here, we use sine and cosine
        functions of different frequencies.
    .. math::
        \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model))
        \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model))
        \text{where pos is the word position and i is the embed idx)
    Args:
        d_model: the embed dim (required).
        dropout: the dropout value (default=0.1).
        max_len: the max. length of the incoming sequence (default=5000).
    Examples:
        >>> pos_encoder = PositionalEncoding(d_model)
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        r"""Inputs of forward function
        Args:
            x: the sequence fed to the positional encoder model (required).
        Shape:
            x: [sequence length, batch size, embed dim]
            output: [sequence length, batch size, embed dim]
        Examples:
            >>> output = pos_encoder(x)
        """

        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

class TransformerModel(nn.Module):
    """Container module with an encoder, a recurrent or transformer module, and a decoder."""

    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(TransformerModel, self).__init__()
        try:
            from torch.nn import TransformerEncoder, TransformerEncoderLayer
        except:
            raise ImportError('TransformerEncoder module does not exist in PyTorch 1.1 or lower.')
        self.model_type = 'Transformer'
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.ninp = ninp
        self.decoder = nn.Linear(ninp, ntoken)

        self.init_weights()

    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def init_weights(self):
        initrange = 0.1
        nn.init.uniform_(self.encoder.weight, -initrange, initrange)
        nn.init.zeros_(self.decoder.weight)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    def forward(self, src, has_mask=True):
        if has_mask:
            device = src.device
            if self.src_mask is None or self.src_mask.size(0) != len(src):
                mask = self._generate_square_subsequent_mask(len(src)).to(device)
                self.src_mask = mask
        else:
            self.src_mask = None

        src = self.encoder(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, self.src_mask)
        output = self.decoder(output)
        return F.log_softmax(output, dim=-1)

In [2]:
import torchtext

with open('../data/train.history_belief') as fp:
    raw_train_data = [line.split() for line in fp.read().split('\n')]

In [3]:
train_vocab = torchtext.vocab.build_vocab_from_iterator(raw_train_data, specials=["<unk>"])

In [4]:
train_vocab.set_default_index(train_vocab["<unk>"])

In [75]:
train_data = [torch.tensor([train_vocab(sent)], dtype=torch.long) for sent in raw_train_data]

In [76]:
model = TransformerModel(
    ntoken=len(train_vocab),
    ninp=512,
    nhead=2,
    nhid=200,
    nlayers=2
)

In [77]:
import time
import math

criterion = nn.NLLLoss()
lr = 20

model.train()
total_loss = 0.
start_time = time.time()
ntokens = len(train_vocab)
for batch, sent in tqdm(enumerate(train_data)):
    data, targets = sent[:, :-1], sent[:, 1:]
    model.zero_grad()
    output = model(data)
    loss = criterion(output[0], targets[0])
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.25)
    for p in model.parameters():
        p.data.add_(p.grad, alpha=-lr)
    total_loss += loss.item()
    if batch % 10 == 0 and batch > 0:
        cur_loss = total_loss / 10
        elapsed = time.time() - start_time
        print('ms/batch {:5.2f} | '
                'loss {:5.2f} | ppl {:8.2f}'.format(
            batch,
            elapsed * 1000 / 10, cur_loss, math.exp(cur_loss)))
        total_loss = 0
        start_time = time.time()

12it [00:02,  6.46it/s]

ms/batch 10.00 | loss 187.38 | ppl     8.62


22it [00:04,  6.07it/s]

ms/batch 20.00 | loss 198.77 | ppl     7.61


32it [00:06,  3.68it/s]

ms/batch 30.00 | loss 254.44 | ppl     7.65


41it [00:08,  3.14it/s]

ms/batch 40.00 | loss 247.18 | ppl     6.21


52it [00:11,  5.61it/s]

ms/batch 50.00 | loss 216.02 | ppl     6.53


62it [00:12,  6.97it/s]

ms/batch 60.00 | loss 159.10 | ppl     5.73


72it [00:15,  3.64it/s]

ms/batch 70.00 | loss 268.92 | ppl     5.99


82it [00:17,  4.70it/s]

ms/batch 80.00 | loss 233.55 | ppl     5.26


91it [00:19,  4.06it/s]

ms/batch 90.00 | loss 203.82 | ppl     4.54


101it [00:22,  4.56it/s]

ms/batch 100.00 | loss 281.05 | ppl     4.87


111it [00:24,  3.38it/s]

ms/batch 110.00 | loss 238.35 | ppl     4.67


111it [00:25,  4.39it/s]


KeyboardInterrupt: 

In [7]:
with open('../data/val.history_belief') as fp:
    raw_val_data = [line.split() for line in fp.read().split('\n')]

In [42]:
BELIEF = '<|belief|>'

val_data = []
for raw_sent in raw_val_data[:-1]:
    indexes = train_vocab(raw_sent)
    belief_idx = indexes.index(train_vocab[BELIEF])
    data, target = indexes[:belief_idx], indexes[belief_idx:]
    val_data.append((torch.tensor(data, dtype=torch.long),
                     torch.tensor(target, dtype=torch.long)))

In [69]:
def translate(indexes):
    return train_vocab.lookup_tokens(list(indexes))

In [63]:
INPUT_SOS = '<|context|>'
INPUT_EOS = '<|endofcontext|>'
OUTPUT_SOS = '<|belief|>'
OUTPUT_EOS = '<|endofbelief|>'

def belief_to_state_list(belief):
    belief_list = [token for token in belief if token not in [OUTPUT_SOS, OUTPUT_EOS]]
    belief_list = [slot.split() for slot in ' '.join(belief_list).split(',')]
    return belief_list

def belief_to_state_dict(belief):
    belief_list = belief_to_state_list(belief)
    state_dict = {}
    for state in belief_list:
        if len(state) < 3: continue
        domain = state[0]
        slot = state[1]
        sub_slot = None
        rest = state[2:]
        if slot == 'book':
            sub_slot = state[2]
            rest = state[3:]
        value = ' '.join(rest)
        d = state_dict.get(domain, {})
        if sub_slot:
            ss = d.get(slot, {})
            ss.update({
                sub_slot: value
            })
            d.update({slot: ss})
        else:
            d.update({slot: value})
        state_dict.update({domain: d})
    return state_dict


def match_slot(true, pred):
    pred_state = belief_to_state_dict(pred)
    true_list = belief_to_state_list(true)
    slot_matches = []
    for i, state in enumerate(true_list):
        if len(state) < 3: continue
        slot_matches.append(False)
        domain = state[0]
        if domain not in pred_state.keys(): continue
        
        slot = state[1]
        if slot not in pred_state[domain].keys(): continue
        
        if slot != 'book':
            true_value = " ".join(state[2:])
            pred_value = pred_state[domain][slot]
        else:
            sub_slot = state[2]
            if sub_slot not in pred_state[domain][slot]: continue
            true_value = " ".join(state[3:])
            pred_value = pred_state[domain][slot][sub_slot]
        
        if true_value != pred_value: continue
        slot_matches[i] = True
            
    all_match = sum(slot_matches) == len(true_list)
    
    return all_match, slot_matches

def get_accuracy(results):
    total_states = len(results)
    total_slots = sum([len(result[1]) for result in results])
    total_correct_states = sum([result[0] for result in results])
    total_correct_slots = sum([sum(result[1]) for result in results])
    return {
        'joint_accuracy': total_correct_states / total_states,
        'slot_accuracy': total_correct_slots / total_slots
    }

In [38]:
from tqdm import tqdm

In [73]:
EOT = '<|endoftext|>'

results = []

model.eval()
predictions = []
for data, target in tqdm(val_data[:2]):
    prediction = [train_vocab[BELIEF]]
    input_data = torch.cat([data, torch.tensor([train_vocab[BELIEF]], dtype=torch.long)])
    while len(prediction) < 5 and prediction[-1] != EOT:
        out = model(input_data)
        pred = torch.argmax(out[0, -1])
        prediction.append(int(pred))
        input_data = torch.cat([input_data, pred.reshape(1)])
    predictions.append(prediction)
    result = match_slot(translate(target), translate(prediction))
    results.append(result)

100%|██████████| 2/2 [00:00<00:00,  2.01it/s]


In [74]:
get_accuracy(results)

{'joint_accuracy': 0.0, 'slot_accuracy': 0.0}