In [0]:
import re
from collections import Counter
from tqdm import tqdm
import json
from torch.utils.data import Dataset, DataLoader
import torch

# ConvAI dataset

![convai2-img](http://convai.io/personachat-example.png 'example')

This is what a raw input/target sample from the training data looks like:

```
{
    "text": "your persona: i had a gig at local theater last night.\nyour persona: i work as a stand up comedian.\nyour persona: i come from a small town.\nyour persona: my favorite drink is cuba libre.\nyour persona: i did a few small roles in tv series.\nwe all live in a yellow submarine , a yellow submarine . morning !\nhi ! that is a great line for my next stand up .\nlol . i am shy , anything to break the ice , and i am a beatles fan .\ni can tell . i am not , you can see me in some tv shows\nreally ? what shows ? i like tv , it makes me forget i do not like my family\nwow , i wish i had a big family . i grew up in a very small town .\ni did too . i do not get along with mine . they have no class .\njust drink some cola with rum and you'll forget about them !\nput the lime in the coconut as well . . .\nnah , plain cuba libre , that's what we drank yesterday at the theater .\ni prefer mojitos . watermelon or cucumber .", 
    "labels": ["those are really yummy too , but not my favorite ."], 
    "reward": 0, 
    "episode_done": true, 
    "id": "convai2:self:no_cands"
}
```

## Tokenization

Here, tokenization is done using a regular expression, as in the ParlAI framework (where the dataset comes from!).

In [0]:
RETOK = re.compile(r'\w+|[^\w\s]|\n', re.UNICODE)

In [3]:
# example of parsed text

RETOK.findall('your persona: i had a gig at local theater last night.\nyour persona: i work as a stand up comedian.')

['your',
 'persona',
 ':',
 'i',
 'had',
 'a',
 'gig',
 'at',
 'local',
 'theater',
 'last',
 'night',
 '.',
 '\n',
 'your',
 'persona',
 ':',
 'i',
 'work',
 'as',
 'a',
 'stand',
 'up',
 'comedian',
 '.']

# ConvAI dictionary

The dataset comes with a precomputed dictionary. It looks like this:

For each word there is a corresponding count. The counts for the special symbols are artificial (they are not real frequencies).

```
__null__	1000000003
__start__	1000000002
__end__	1000000001
__unk__	1000000000
.	276863
i	270789
you	93655
your	91941
a	89140
?	85346
persona	80372
\n	80365
:	80365
,	79513
to	79240
my	73999
'	68126
do	55199
is	53581
the	49955
```

The `ChatDictionary` class loads that file and provides helper functions for converting between tokens and indices.

In [0]:
class ChatDictionary(object):
    """
    Simple dict loader.

    Reads a tab-separated ``word<TAB>count`` file (one entry per line) and
    provides token <-> index conversions. The index of a word is its
    zero-based line number in the file.
    """
    def __init__(self, dict_file_path):
        """
        :param str dict_file_path: path to the dictionary file.
        """
        self.word2ind = {}  # word:index
        self.ind2word = {}  # index:word
        self.counts = {}  # word:count (kept as the raw string read from the file)

        # context manager so the file handle is always closed
        with open(dict_file_path, 'r') as dict_file:
            for i, w in enumerate(dict_file):
                _word, _count = w.strip().split('\t')
                # the newline token is stored escaped in the file
                if _word == '\\n':
                    _word = '\n'
                self.word2ind[_word] = i
                self.ind2word[i] = _word
                self.counts[_word] = _count

    def t2v(self, tokenized_text):
        """Map a list of tokens to indices; unknown tokens map to ``__unk__``."""
        return [self.word2ind[w] if w in self.counts else self.word2ind['__unk__'] for w in tokenized_text]

    def v2t(self, list_ids):
        """Map a list of indices back to a space-joined string."""
        return ' '.join([self.ind2word[i] for i in list_ids])

    def pred2text(self, tensor):
        """
        Convert a 1-D tensor of predicted indices to text, stopping at the
        first ``__end__`` or ``__null__`` (pad) token.

        :param tensor: 1-D tensor of token indices.
        :returns: space-joined tokens preceding the first stop token.
        """
        # Compare against the *indices* of the stop tokens: the tensor holds
        # integer ids, so comparing them to the token strings never matches.
        stop_ids = {self.word2ind.get('__end__'), self.word2ind.get('__null__')}  # null is pad
        result = []
        for idx in tensor.tolist():
            if idx in stop_ids:
                break
            result.append(self.ind2word[idx])
        return ' '.join(result)

    def __len__(self):
        return len(self.counts)

# Dataset class

The `ChatDataset` class should be familiar to all of you; there is nothing fancy here.

In [0]:
class ChatDataset(Dataset):
    """
    Json dataset wrapper.

    Reads a JSON-lines file (one sample per line), tokenizes the ``text`` field
    and the first target utterance with ``RETOK``, and stores both as
    ``torch.long`` index tensors (``text_vec`` / ``target_vec``) on each sample.
    """

    # train and valid have different key names for target
    _LABEL_KEYS = {'train': 'labels', 'valid': 'eval_labels'}

    def __init__(self, dataset_file_path, dictionary, dt='train'):
        """
        :param str dataset_file_path: path to the ``.jsonl`` file.
        :param dictionary: object exposing ``t2v(tokens) -> list of indices``.
        :param str dt: split type, ``'train'`` or ``'valid'``.
        :raises ValueError: if ``dt`` is not a supported split type.
        """
        super().__init__()

        # Fail early: an unknown split would otherwise leave the target
        # undefined (or silently reuse the previous sample's target).
        if dt not in self._LABEL_KEYS:
            raise ValueError(
                "dt must be one of {}, got {!r}".format(sorted(self._LABEL_KEYS), dt)
            )
        label_key = self._LABEL_KEYS[dt]

        # context manager so the file handle is always closed
        with open(dataset_file_path, 'r') as dataset_file:
            json_text = dataset_file.readlines()

        self.samples = []

        for sample in tqdm(json_text):
            sample = json.loads(sample.rstrip())

            _inp_toked = RETOK.findall(sample['text'])
            sample['text_vec'] = torch.tensor(dictionary.t2v(_inp_toked), dtype=torch.long)

            # only the first label is used; the end token closes the target
            _tar_toked = RETOK.findall(sample[label_key][0]) + ['__end__']
            sample['target_vec'] = torch.tensor(dictionary.t2v(_tar_toked), dtype=torch.long)

            self.samples.append(sample)

    def __getitem__(self, i):
        """Return the ``(text_vec, target_vec)`` pair of sample ``i``."""
        return self.samples[i]['text_vec'], self.samples[i]['target_vec']

    def __len__(self):
        return len(self.samples)

# Padding, sorting, packing

The `pad_tensor` function pads a list of tensors to a common length using the specified PAD token.

`argsort` reorders the given lists according to the provided keys. Sorting the batch by length is necessary for packing. (see [here](https://pytorch.org/docs/master/nn.html?highlight=pack#torch.nn.utils.rnn.pack_padded_sequence))

`batchify` uses both of the previous functions to build a minibatch that is ready to be packed.

In [0]:
def pad_tensor(tensors, sort=True, pad_token=0):
    """Right-pad a list of 1-D tensors into a single 2-D tensor.

    :param list tensors: tensors to pad; the output takes the dtype/device of the first one.
    :param bool sort: accepted for interface compatibility; not used.
    :param pad_token: value written into the padding positions.
    :returns: ``(padded, lengths)`` where ``padded`` has one row per input
              tensor and ``lengths`` lists the original lengths.
    """
    lengths = [len(seq) for seq in tensors]
    longest = max(lengths)

    # allocate the batch already filled with the pad token
    padded = tensors[0].new_full((len(tensors), longest), pad_token)

    for row, (seq, seq_len) in enumerate(zip(tensors, lengths)):
        padded[row, :seq_len] = seq

    return padded, lengths

def argsort(keys, *lists, descending=False):
    """Reorder every list in ``lists`` by the sorted order of ``keys``.

    :param iter keys: Keys to order by.
    :param list[list] lists: Lists to reorder by the order of ``keys``.
                             Handles both lists and 1-D tensors.
    :param bool descending: Use descending order if true.
    :returns: The reordered items, one entry per input list.
    """
    order = sorted(range(len(keys)), key=lambda pos: keys[pos])
    if descending:
        # reverse the ascending order (ties therefore come out reversed as well)
        order = order[::-1]

    reordered = []
    for items in lists:
        if isinstance(items, torch.Tensor):
            # tensors support fancy indexing with a list of positions
            reordered.append(items[order])
        else:
            reordered.append([items[pos] for pos in order])
    return reordered

def batchify(batch):
    """Collate ``(text_vec, target_vec)`` pairs into a padded minibatch.

    Inputs and targets are padded separately, then all four outputs are
    reordered by descending input length.

    :param batch: sequence of ``(input tensor, target tensor)`` pairs.
    :returns: dict with keys ``text_vecs``, ``text_lens``, ``target_vecs``, ``target_lens``.
    """
    text_padded, text_lengths = pad_tensor([example[0] for example in batch])
    target_padded, target_lengths = pad_tensor([example[1] for example in batch])

    # sort only with respect to the inputs here, for encoder packing
    text_padded, text_lengths, target_padded, target_lengths = argsort(
        text_lengths,
        text_padded,
        text_lengths,
        target_padded,
        target_lengths,
        descending=True,
    )

    return dict(
        text_vecs=text_padded,
        text_lens=text_lengths,
        target_vecs=target_padded,
        target_lens=target_lengths,
    )

In [7]:
# Download the dictionary, the datasets and a pretrained checkpoint, then load
# the dictionary and both dataset splits.

# Each download is skipped when the target file already exists locally.

### DOWNLOADING THE FILES
import os

### persona chat dataset: dictionary, train split, valid split
if not os.path.exists('./dict'):
    !wget "https://nyu.box.com/shared/static/sj9f87tofpicll89xbc154pmbztu5q4h" -O './dict'
if not os.path.exists('./train.jsonl'):
    !wget "https://nyu.box.com/shared/static/aqp0jyjaixjmukm5asasivq2bcfze075.jsonl" -O './train.jsonl'
if not os.path.exists('./valid.jsonl'):
    !wget "https://nyu.box.com/shared/static/eg4ivddtqib2hkf1k8rkxnmzmo0cq27p.jsonl" -O './valid.jsonl'

### pretrained checkpoint
# NOTE(review): a later cell loads "transformer_model_best_5.pt", which is not
# downloaded here — confirm which checkpoint is intended.
if not os.path.exists('./chat_model_best_22.pt'):
    !wget "https://nyu.box.com/shared/static/24zsynuks8nzg7530tgakzh8o62id9xa.pt" -O './chat_model_best_22.pt'

# build the dictionary, then tokenize and vectorize both splits
# (the valid split reads its targets from a different key, hence 'valid')
chat_dict = ChatDictionary('./dict')
train_dataset = ChatDataset('./train.jsonl', chat_dict)
valid_dataset = ChatDataset('./valid.jsonl', chat_dict, 'valid')

--2019-11-10 05:22:16--  https://nyu.box.com/shared/static/sj9f87tofpicll89xbc154pmbztu5q4h
Resolving nyu.box.com (nyu.box.com)... 107.152.26.197, 107.152.27.197
Connecting to nyu.box.com (nyu.box.com)|107.152.26.197|:443... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: /public/static/sj9f87tofpicll89xbc154pmbztu5q4h [following]
--2019-11-10 05:22:16--  https://nyu.box.com/public/static/sj9f87tofpicll89xbc154pmbztu5q4h
Reusing existing connection to nyu.box.com:443.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://nyu.app.box.com/public/static/sj9f87tofpicll89xbc154pmbztu5q4h [following]
--2019-11-10 05:22:16--  https://nyu.app.box.com/public/static/sj9f87tofpicll89xbc154pmbztu5q4h
Resolving nyu.app.box.com (nyu.app.box.com)... 107.152.26.199, 107.152.27.199
Connecting to nyu.app.box.com (nyu.app.box.com)|107.152.26.199|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://public.boxclo

100%|██████████| 131438/131438 [00:11<00:00, 11872.11it/s]
100%|██████████| 7801/7801 [00:00<00:00, 12584.08it/s]


In [8]:
# sanity check: number of training samples that were loaded
len(train_dataset)

131438

In [0]:
# Both loaders collate with `batchify`; only the training split is shuffled.
train_loader = DataLoader(train_dataset, shuffle=True, collate_fn=batchify, batch_size=256)
valid_loader = DataLoader(valid_dataset, shuffle=False, collate_fn=batchify, batch_size=256)

In [0]:
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class EncoderTransformer(nn.Module):
    """Transformer encoder over summed token and learned position embeddings.

    :param int vocab_size: accepted for interface compatibility; not used,
                           because the token embedding is passed in via ``shared_lt``.
    :param int max_len: number of positions in the position-embedding table.
    :param shared_lt: token embedding module (shared with the decoder by the caller).
    :param float dropout: dropout probability applied to the summed embeddings.
    :param int dim: model/embedding dimension.
    :param int num_layers: number of transformer encoder layers.
    :param int nhead: number of attention heads.
    :param int pad_idx: index of the padding token; used to build the padding mask.
    """

    def __init__(self,vocab_size,max_len,shared_lt,dropout,dim=256, num_layers=2, nhead=4,pad_idx=0):
        super().__init__()
        self.token_embed = shared_lt
        self.position_embed = nn.Embedding(max_len, dim)
        encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=nhead)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.pad_idx = pad_idx
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, text_vec):
        """Encode a batch of token indices.

        :param text_vec: ``(batch, seq_len)`` tensor of token indices.
        :returns: ``(x, hidden, attention_mask)`` where ``x`` is
                  ``(batch, seq_len, dim)``, ``hidden`` is the mean of ``x``
                  over the sequence dimension shaped ``(1, batch, dim)``, and
                  ``attention_mask`` is True at padding positions.
        """
        batch_size, seq_len = text_vec.size(0), text_vec.size(1)

        # position ids 0..seq_len-1, repeated for every row of the batch
        pos = torch.arange(seq_len, device=text_vec.device).unsqueeze(0).expand(batch_size, seq_len)

        # The embeddings already live on the module's device, so no explicit
        # device move is needed (and no notebook-level global is read).
        output = self.token_embed(text_vec) + self.position_embed(pos)
        output = self.dropout(output)

        attention_mask = text_vec.eq(self.pad_idx)

        # the transformer layers expect (seq_len, batch, dim)
        x = self.transformer(output.transpose(0, 1), src_key_padding_mask=attention_mask)
        x = x.transpose(0, 1)

        # mean over all positions (padding positions are included in the average)
        hidden = torch.mean(x, dim=1, keepdim=True).transpose(0, 1)
        return x, hidden, attention_mask


    
class DecoderRNN(nn.Module):
    """Generates a sequence of tokens in response to context.

    A GRU consumes the embedded input one time step at a time; each step's
    output is passed through the attention layer together with the encoder
    outputs before being projected onto the vocabulary.
    """

    def __init__(self, vocab_size, embed_size, hidden_size, num_layers, dropout=0):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = nn.Dropout(p=dropout)

        # index 0 is the padding index
        self.embedding = nn.Embedding(self.vocab_size, self.embed_size, 0)

        # inter-layer GRU dropout is only applied when there is more than one layer
        gru_dropout = dropout if num_layers > 1 else 0
        self.gru = nn.GRU(
            self.embed_size,
            self.hidden_size,
            num_layers=self.num_layers,
            batch_first=True,
            dropout=gru_dropout,
        )

        self.attention = AttentionLayer(self.hidden_size, self.embed_size)

        self.out = nn.Linear(self.hidden_size, self.vocab_size)
        self.longest_label = 100

    def forward(self, text_vec, decoder_hidden, encoder_states):
        """Run the decoder over every time step of ``text_vec``.

        :param text_vec: ``(batch, seq_len)`` tensor of input token indices.
        :param decoder_hidden: initial GRU hidden state.
        :param encoder_states: ``(encoder_output, encoder_hidden, attention_mask)`` tuple.
        :returns: ``(scores, decoder_hidden, attn_w_log)`` — vocabulary scores
                  for every step, the final hidden state, and the list of
                  per-step attention weights.
        """
        embedded = self.dropout(self.embedding(text_vec))
        encoder_output, _, attention_mask = encoder_states

        step_outputs = []
        attn_w_log = []

        for step in range(text_vec.size(1)):
            step_input = embedded[:, step:step + 1, :]
            gru_output, decoder_hidden = self.gru(step_input, decoder_hidden)

            # compute attention at each time step
            attended, step_weights = self.attention(
                gru_output, decoder_hidden, encoder_output, attention_mask
            )
            step_outputs.append(attended)
            attn_w_log.append(step_weights)

        stacked = torch.cat(step_outputs, dim=1).to(text_vec.device)
        scores = self.out(stacked)

        return scores, decoder_hidden, attn_w_log

    def decode_forced(self, ys, encoder_states, xs_lens):
        """Teacher-forced decoding.

        :param ys: ``(batch, target_len)`` tensor of target token indices.
        :param encoder_states: ``(encoder_output, encoder_hidden, attention_mask)`` tuple.
        :param xs_lens: accepted for interface compatibility; not used.
        :returns: ``(decoder_output, preds, attn_w_log)`` — scores, their
                  argmax over the vocabulary, and the per-step attention weights.
        """
        _, encoder_hidden, _ = encoder_states

        # one start token (index 1) per batch row
        start_column = torch.ones(
            ys.size(0), 1, dtype=torch.long, device=self.embedding.weight.device
        )

        # Teacher forcing: Feed the target as the next input
        shifted_targets = ys.narrow(1, 0, ys.size(1) - 1)
        decoder_input = torch.cat([start_column, shifted_targets], 1)

        decoder_output, _, attn_w_log = self.forward(decoder_input, encoder_hidden, encoder_states)
        preds = decoder_output.max(dim=2)[1]

        return decoder_output, preds, attn_w_log
    
    
class AttentionLayer(nn.Module):
    """Dot-product attention of a single decoder step over the encoder outputs."""

    def __init__(self, hidden_size, embedding_size):
        # embedding_size is accepted for interface compatibility; it is not used
        super().__init__()

        # maps [decoder output ; attended context] back to hidden_size
        self.linear_out = nn.Linear(2 * hidden_size, hidden_size, bias=False)
        self.softmax = nn.Softmax(dim=-1)
        self.tanh = nn.Tanh()

    def forward(self, decoder_output, decoder_hidden, encoder_output, attention_mask):
        """Attend over ``encoder_output`` using ``decoder_output`` as the query.

        :param decoder_output: ``(batch, 1, hidden)`` output of the current decoder step.
        :param decoder_hidden: accepted for interface compatibility; not used.
        :param encoder_output: ``(batch, src_len, hidden)`` encoder outputs.
        :param attention_mask: ``(batch, src_len)`` boolean mask, True where the
                               source position must be ignored.
        :returns: ``(output, attention_weights)`` with shapes
                  ``(batch, 1, hidden)`` and ``(batch, src_len)``.
        """
        # similarity of the decoder step with every encoder position
        scores = torch.bmm(decoder_output, encoder_output.transpose(1, 2)).squeeze(1)

        # a large negative score pushes masked positions to ~0 after the softmax
        scores = scores.masked_fill(attention_mask, -10e5)
        attention_weights = self.softmax(scores)

        # weighted sum of the encoder outputs
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_output)

        combined = torch.cat((decoder_output.squeeze(1), context.squeeze(1)), dim=1)
        output = self.tanh(self.linear_out(combined).unsqueeze(1))

        return output, attention_weights
    
    
class seq2seq(nn.Module):
    """
    Generic seq2seq model with attention mechanism.

    A transformer encoder and a GRU decoder that share one token embedding
    (the decoder's embedding is handed to the encoder).
    """
    def __init__(self, opts):
        """
        :param dict opts: model options. Keys read here: ``vocab_size``,
            ``embedding_size``, ``hidden_size``, ``num_layers_dec`` and ``dropout``.
            The encoder is built with its default model dimension, so
            ``embedding_size`` has to match that default.
        """
        super().__init__()
        self.opts = opts

        self.decoder = DecoderRNN(
            vocab_size=self.opts['vocab_size'],
            embed_size=self.opts['embedding_size'],
            hidden_size=self.opts['hidden_size'],
            num_layers=self.opts['num_layers_dec'],
            dropout=self.opts['dropout'],
        )

        self.encoder = EncoderTransformer(
            vocab_size=self.opts['vocab_size'],
            max_len=10000,
            shared_lt=self.decoder.embedding,  # share the token embedding with the decoder
            dropout=self.opts['dropout'],
        )

    def train(self, mode=True):
        """Set training mode on the model and its submodules.

        Follows the ``nn.Module.train`` contract: accepts ``mode`` and returns
        ``self`` so that calls can be chained.
        """
        return super().train(mode)

    def eval(self):
        """Set evaluation mode on the model and its submodules; returns ``self``."""
        return super().eval()

In [14]:
# Choose the device first so the checkpoint can be mapped onto it while loading.
num_gpus = torch.cuda.device_count()
if num_gpus > 0:
    current_device = 'cuda'
else:
    current_device = 'cpu'

# map_location lets a checkpoint that was saved on a GPU be loaded on a
# CPU-only machine as well.
# NOTE(security): torch.load unpickles the file — only load checkpoints that
# come from a trusted source.
model_pt = torch.load("transformer_model_best_5.pt", map_location=current_device)
model = seq2seq(model_pt["opts"])
model.load_state_dict(model_pt["state_dict"])
model.to(current_device)

seq2seq(
  (decoder): DecoderRNN(
    (dropout): Dropout(p=0.3, inplace=False)
    (embedding): Embedding(18760, 256, padding_idx=0)
    (gru): GRU(256, 256, batch_first=True)
    (attention): AttentionLayer(
      (linear_out): Linear(in_features=512, out_features=256, bias=False)
      (softmax): Softmax(dim=-1)
      (tanh): Tanh()
    )
    (out): Linear(in_features=256, out_features=18760, bias=True)
  )
  (encoder): EncoderTransformer(
    (token_embed): Embedding(18760, 256, padding_idx=0)
    (position_embed): Embedding(10000, 256)
    (transformer): TransformerEncoder(
      (layers): ModuleList(
        (0): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): Linear(in_features=256, out_features=256, bias=True)
          )
          (linear1): Linear(in_features=256, out_features=2048, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=2048, out_features=256, bias=True)
          (norm1

In [0]:
import torch.nn.functional as F
def nucleus_sampling(logits,p=0.0,filter_value=-float('Inf')):
    """Top-p (nucleus) filtering of a 1-D logits tensor.

    Keeps the most probable tokens up to and including the first one whose
    cumulative probability exceeds ``p``; every other entry is overwritten
    with ``filter_value``. ``logits`` is modified in place and returned.

    Adapted from
    https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
    """
    ranked_logits, ranked_ids = torch.sort(logits, descending=True)
    cumulative_mass = F.softmax(ranked_logits, dim=-1).cumsum(dim=-1)

    # Mark tokens whose cumulative probability is above the threshold ...
    beyond_nucleus = cumulative_mass > p
    # ... then shift the marks one step to the right, so that the first token
    # above the threshold is kept too; the top-ranked token is never removed.
    drop_ranked = torch.roll(beyond_nucleus, shifts=1, dims=0)
    drop_ranked[0] = 0

    logits[ranked_ids[drop_ranked]] = filter_value
    return logits


In [0]:
import numpy as np
def greedy_search(model, batch, batch_size, p, max_len=100):
    """Generate a response token by token with nucleus (top-p) sampling.

    Despite its name, each token is *sampled* from the top-p filtered
    distribution rather than taken as the argmax.

    Only a single context is supported (``batch_size`` must be 1): the decoder
    output is flattened into one distribution and one token is drawn per step.

    :param model: seq2seq model exposing ``encoder`` and ``decoder``.
    :param dict batch: must contain ``'text_vecs'``, a ``(1, seq_len)`` tensor of token indices.
    :param int batch_size: number of contexts in the batch (1).
    :param float p: cumulative-probability threshold for nucleus sampling.
    :param int max_len: maximum number of generated tokens.
    :returns: ``(scores, preds, log_prob)`` — the list of per-step
              probabilities of the sampled tokens (under the unfiltered
              distribution), the 1-D tensor of generated token indices without
              ``__start__`` and without the closing ``__end__``, and the summed
              log-probability of the sampled tokens.
    """
    start_idx, end_idx = 1, 2  # dictionary indices of __start__ and __end__

    model.eval()
    # use the device the model lives on instead of a notebook-level global
    device = model.decoder.embedding.weight.device

    # inference only: do not build the autograd graph
    with torch.no_grad():
        text_vecs = batch['text_vecs'].to(device)
        encoded = model.encoder(text_vecs)
        encoder_hidden = encoded[1]

        starts = torch.full((batch_size, 1), start_idx, dtype=torch.long, device=device)
        decoder_hidden = encoder_hidden

        preds = [starts.view(-1)]
        scores = []
        log_prob = 0.0
        # track if each sample in the mini batch is finished
        # if all finished, stop predicting
        finish_mask = torch.zeros(batch_size, dtype=torch.uint8, device=device)
        xs = starts

        for ts in range(max_len):
            # decoder_output: [batch, time, vocab]
            decoder_output, decoder_hidden, attn_w_log = model.decoder(xs.view(1, -1), decoder_hidden, encoded)
            flat_logits = decoder_output.view(-1)

            # Unfiltered distribution, used for scoring. It has to be computed
            # before nucleus_sampling, which overwrites the logits in place.
            full_dist = F.softmax(flat_logits, dim=-1)

            logits = nucleus_sampling(flat_logits, p=p)
            prob_dist = F.softmax(logits, dim=-1)
            _preds = torch.multinomial(prob_dist, 1)

            _scores = full_dist[_preds.item()]
            log_prob += np.log(_scores.item())
            preds.append(_preds)
            scores.append(_scores.view(-1) * (finish_mask == 0).float())

            finish_mask += (_preds == end_idx).byte().view(-1)
            if bool(finish_mask.bool().all()):
                break

            xs = _preds

    # drop the leading __start__ token
    preds = torch.cat(preds, dim=-1)[1:]
    # Drop the closing __end__ only if it was actually generated; when the step
    # limit is reached first, the last token is real output and is kept.
    if preds.numel() > 0 and preds[-1].item() == end_idx:
        preds = preds[:-1]
    return scores, preds, log_prob

In [17]:
def dump_input(inputs,chat_dict):
  """Tokenize ``inputs`` (a trailing newline token is appended) and map the tokens to dictionary indices."""
  tokens = RETOK.findall(inputs + " \n")
  return chat_dict.t2v(tokens)

# example: vectorize a short utterance (dump_input appends a trailing newline token)
inputs = "i am listening to music."
dump_input(inputs,chat_dict)
  

[5, 28, 333, 14, 75, 4, 11]

In [23]:
# Scripted human turns, used when INTERACTIVE is False.
answers = ["i don't care.","i am a rapper.","what are you doing?", "i don't know man."]
# Persona lines followed by the first human utterance.
raw_inputs = "your persona : i play basketball . \n your persona : i ' m a singer as my second job . \n your persona : i only eat vegetable . \n your persona : i was raised in a single parent household .  \n your persona : i am from mexico . \n your persona : i love america . \n hello how are doing today ?"
inputs = dump_input(raw_inputs,chat_dict)
turn = 0
isModel = True  # the model answers first; afterwards the speakers alternate
INTERACTIVE = False  # set to True to type the human turns (enter END to stop)
print(raw_inputs,"\n")
while True:
    # the whole dialogue so far is the context for the next model turn
    test_batch = {
        'text_vecs': torch.tensor([inputs], dtype=torch.long, device=model.decoder.embedding.weight.device),
        'text_lens': torch.tensor([len(inputs)], dtype=torch.long),
    }
    if isModel:
        scores, preds, log_prob = greedy_search(model, test_batch, 1,0.5)
        # append the model's reply, closed by a newline token, to the context
        inputs = inputs + preds.tolist() + [chat_dict.word2ind["\n"]]
        print("model: ",chat_dict.v2t(preds.tolist()),"\n")
    elif not INTERACTIVE:
        inputs = inputs + dump_input(answers[turn],chat_dict)
        print("human: ",answers[turn],"\n")
        turn += 1
        if turn == len(answers):
            break
    else:
        inp = input("Your response: \n")
        # Check the typed text itself, before it is vectorized: `inputs` is a
        # list of token ids and can never be equal to "END".
        if inp.strip() == "END":
            break
        inputs = inputs + dump_input(inp,chat_dict)
        print("human: ",inp,"\n")

    isModel = not isModel


your persona : i play basketball . 
 your persona : i ' m a singer as my second job . 
 your persona : i only eat vegetable . 
 your persona : i was raised in a single parent household .  
 your persona : i am from mexico . 
 your persona : i love america . 
 hello how are doing today ? 

model:  i ' m good . just hanging out with my wife and all . 

human:  i don't care. 

model:  i am a nurse , so i ' m a waiter in my hometown , i am an accountant . 

human:  i am a rapper. 

model:  that is a shame . i have been thinking about it in a bit of my car . 

human:  what are you doing? 

model:  i love watching the stars . 

human:  i don't know man. 



In [0]:
\