In [1]:
import re

import torch
import numpy as np
import tensorflow as tf
from collections import OrderedDict

In [2]:
# Load one checkpoint file onto the CPU (the file name suggests it is the
# model-parallel rank-0 shard — confirm against the training setup).
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
m0 = torch.load('../eva-ckpt/222500/mp_rank_00_model_states.pt', map_location='cpu')

In [3]:
def find_weight(model, name):
    """Look up a single weight in a loaded checkpoint.

    Args:
        model: checkpoint mapping; its ``'module'`` entry maps weight names
            to tensors.
        name: exact name of the weight to fetch.

    Returns:
        ``(weight, shape)`` where ``shape`` is ``list(weight.shape)``, or
        ``None`` when the checkpoint has no entry called ``name``.
    """
    weights = model['module']
    # Direct key lookup instead of scanning every (name, tensor) pair.
    if name not in weights:
        return None
    w = weights[name]
    return w, list(w.shape)

def combine(n, dim=0, models=None):
    """Concatenate the weight called ``n`` taken from several checkpoints.

    Args:
        n: weight name, looked up in each checkpoint with ``find_weight``.
        dim: dimension along which the per-checkpoint tensors are joined.
        models: iterable of loaded checkpoints to read from. When omitted,
            the notebook globals ``m0, m1, m2, m3`` are used.

    Returns:
        The tensor produced by ``torch.cat`` over the per-checkpoint weights.

    Raises:
        NameError: if ``models`` is omitted and any of ``m0``..``m3`` has not
            been defined.
        TypeError: if a checkpoint has no weight called ``n`` (``find_weight``
            returns ``None`` in that case).
    """
    if models is None:
        # NOTE(review): only m0 is loaded in the cells visible in this
        # notebook; load m1-m3 first or pass `models=` explicitly.
        models = (m0, m1, m2, m3)
    return torch.cat([find_weight(x, n)[0] for x in models], dim=dim)

In [4]:
from typing import Callable, Optional, List
import copy
import math

import numpy as np
import torch
import torch.nn.functional as F
import jieba
from tqdm import tqdm

from configuration_enc_dec import EncDecConfig
from tokenization_enc_dec import EncDecTokenizer
from model import TorchEncDecModel

In [5]:
# !pip install jieba --user
# Initialise jieba up front; the cell output shows it building/loading its
# prefix dictionary at this point.
jieba.initialize()

Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Loading model cost 0.571 seconds.
Prefix dict has been built successfully.


In [6]:
# Model hyper-parameters. A model built from this config later receives the
# checkpoint's state dict, so these values presumably have to match the
# checkpoint's architecture — verify before changing any of them.
config = EncDecConfig(
    d_model=2048,
    d_ff=5120,
    d_kv=64,
    num_heads=32,
    num_layers=24,
    num_decoder_layers=24,
    dropout_rate=0.0,
    feed_forward_proj="relu",
    init_method_std=0.001,
    initializer_factor=1.0,
    layer_norm_epsilon=1e-06,
    max_position_embeddings=512,
    use_cache=True,
    use_scaled_init_for_output_weights=True,
    do_dim_trick=False
)
# vocab_size is set as an attribute after construction rather than passed to
# the constructor; the bare expression below displays it as the cell output.
config.vocab_size = 30000
config.vocab_size 

30000

In [7]:
# Instantiate the encoder-decoder model from the config; the checkpoint
# weights are loaded in a later cell.
print('build model')
model = TorchEncDecModel(config)

build model


In [8]:
# To float16 and GPU — only when CUDA is available; otherwise the model is
# left unchanged.
# NOTE(review): generate_samples defaults to device='cpu', so callers need to
# pass the matching device when the model has been moved to the GPU here.
if torch.cuda.is_available():
    model = model.half().cuda()

In [9]:
# for k, v in model.state_dict().items():
#     print(k, v.shape)

In [10]:
# Load the weights stored under the checkpoint's 'module' key into the model.
# The value returned by load_state_dict is displayed as the cell output.
print('load state')
model.load_state_dict(m0['module'])

load state


<All keys matched successfully>

In [11]:
# Switch the model to evaluation mode for inference.
print('to eval')
model = model.eval()

to eval


In [12]:
# Build the tokenizer from the vocabulary file at this relative path.
tokenizer = EncDecTokenizer('./EVA/src/bpe_dialog_new/vocab.txt')

In [13]:
def get_masks_and_position_ids(tokenizer,
                               contexts,
                               targets,
                               reset_position_ids,
                               reset_attention_mask):
    """Build attention masks and position ids for an encoder/decoder batch.

    Args:
        tokenizer: object exposing ``pad_id``; tokens equal to it are counted
            as padding.
        contexts: 2-D tensor of encoder token ids (batch, enc_len).
        targets: 2-D tensor of decoder token ids (batch, dec_len).
        reset_position_ids: when truthy, the position-id tensors are cloned
            so they own their memory instead of being expanded views.
        reset_attention_mask: accepted but not used by this function.

    Returns:
        dict with keys ``enc_attention_mask``, ``enc_position_ids``,
        ``dec_attention_mask``, ``dec_position_ids`` and
        ``cross_attention_mask``.
    """

    def _position_ids(ids):
        # 0..len-1, broadcast over the batch dimension of `ids`.
        steps = torch.arange(ids.size(1), dtype=torch.long, device=ids.device)
        steps = steps.unsqueeze(0).expand_as(ids)
        return steps.clone() if reset_position_ids else steps

    enc_batch, enc_len = contexts.size()
    dec_batch, dec_len = targets.size()

    # Number of non-padding tokens in every context row.
    ctx_lengths = (contexts != tokenizer.pad_id).sum(1)

    # Encoder self-attention: the first `n_valid` positions see each other.
    enc_attn_mask = torch.zeros(
        enc_batch, 1, enc_len, enc_len, device=contexts.device)
    for row, n_valid in enumerate(ctx_lengths):
        enc_attn_mask[row, 0, :n_valid, :n_valid] = 1

    # Decoder self-attention: lower-triangular (causal) mask.
    dec_attn_mask = torch.ones(
        dec_batch, 1, dec_len, dec_len, device=targets.device).tril()

    # Cross attention: every decoder step sees the first `n_valid` encoder
    # positions of its own row.
    cross_attn_mask = torch.zeros(
        dec_batch, 1, dec_len, enc_len, device=contexts.device)
    for row in range(dec_batch):
        cross_attn_mask[row, 0, :, :ctx_lengths[row]] = 1

    return {
        "enc_attention_mask": enc_attn_mask,
        "enc_position_ids": _position_ids(contexts),
        "dec_attention_mask": dec_attn_mask,
        "dec_position_ids": _position_ids(targets),
        "cross_attention_mask": cross_attn_mask,
    }


def get_inference_batch(
        context_tokens,
        device,
        batch_size,
        target_length,
        tokenizer
    ):
    """Assemble the model inputs for one inference call.

    Args:
        context_tokens: tensor of encoder token ids; it is viewed as
            ``(batch_size, -1)``.
        device: device the encoder and decoder ids are placed on.
        batch_size: number of rows in the batch.
        target_length: number of decoder positions to allocate.
        tokenizer: object providing ``get_sentinel_id`` and ``pad_id``.

    Returns:
        dict holding ``enc_input_ids``, ``dec_input_ids`` and every entry
        produced by ``get_masks_and_position_ids``.
    """
    enc_input_ids = context_tokens.view(batch_size, -1).contiguous().to(device)

    # Every decoder slot starts out holding sentinel 0.
    start_id = tokenizer.get_sentinel_id(0)
    dec_input_ids = torch.zeros(
        batch_size, target_length, dtype=torch.long, device=device) + start_id

    batch = {
        "enc_input_ids": enc_input_ids,
        "dec_input_ids": dec_input_ids,
    }
    # Masks and position ids are built with both reset flags off.
    batch.update(get_masks_and_position_ids(
        tokenizer,
        enc_input_ids,
        dec_input_ids,
        False,
        False,
    ))
    return batch


def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-10000, remove_unk=False):
    """Filter logits for top-k and/or nucleus (top-p) sampling.

    The input tensor is modified in place; filtered entries are overwritten
    with ``filter_value``.

    Args:
        logits: tensor whose first dimension is the batch and whose last
            dimension indexes the vocabulary.
        top_k: if > 0, keep only the ``top_k`` largest logits per row.
        top_p: if > 0, keep the smallest set of highest-ranked tokens whose
            cumulative softmax mass exceeds ``top_p`` (the token that crosses
            the threshold is kept).
        filter_value: value written over removed entries.
        remove_unk: when truthy, vocabulary index 0 is always filtered.

    Returns:
        The filtered logits (viewed as ``(batch, -1)`` when ``top_p > 0``).
    """
    # Mostly taken from the huggingface conversational ai code at
    # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313

    if remove_unk:
        logits[..., 0] = filter_value

    if top_k > 0:
        # Anything strictly below the k-th best logit is dropped.
        kth_best = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_best] = filter_value

    n_rows = logits.size()[0]
    if not top_p > 0.0:
        return logits

    logits = logits.view(n_rows, -1).contiguous()
    for row in logits:
        ranked, order = torch.sort(row, descending=True)
        mass = torch.cumsum(F.softmax(ranked, dim=-1), dim=-1)

        # Mark tokens once the cumulative probability passes the threshold,
        # then shift right by one so the first token past it survives.
        drop = mass > top_p
        drop[..., 1:] = drop[..., :-1].clone()
        drop[..., 0] = 0
        row[order[drop]] = filter_value

    return logits.view(n_rows, -1).contiguous()


def calc_banned_ngram_tokens(prev_input_ids, num_hypos: int, no_repeat_ngram_size: int, cur_len: int, vocab_size: int):
    """List, per hypothesis, the tokens that must not be generated next.

    A token is banned when appending it would repeat an n-gram of size
    ``no_repeat_ngram_size`` already present in ``prev_input_ids``. Each
    hypothesis table is also pre-seeded with hard-coded entries
    ``(23,) -> [33, 31]`` and ``(31,) -> [123]`` that are matched against the
    most recent token (NOTE(review): the meaning of these ids is not visible
    here — confirm against the vocabulary).

    Args:
        prev_input_ids: 2-D tensor/array of generated ids, one row per
            hypothesis.
        num_hypos: number of hypotheses (rows) to process.
        no_repeat_ngram_size: size of the n-grams that may not repeat.
        cur_len: number of tokens generated so far.
        vocab_size: n-grams containing an id >= this value are ignored.

    Returns:
        A list with ``num_hypos`` lists of banned token ids.
    """
    tables = []
    for _ in range(num_hypos):
        tables.append({(23,): [33, 31], (31,): [123]})

    def _banned_for(hypo_idx):
        # Look up both the trailing (n-1)-gram and the single last token.
        table = tables[hypo_idx]
        window_start = cur_len + 1 - no_repeat_ngram_size
        context_key = tuple(prev_input_ids[hypo_idx, window_start:cur_len].tolist())
        last_token_key = tuple(prev_input_ids[hypo_idx, cur_len - 1: cur_len].tolist())
        return table.get(context_key, []) + table.get(last_token_key, [])

    too_short = cur_len + 1 < no_repeat_ngram_size
    if too_short and not cur_len > 0:
        # Nothing generated yet: no banned tokens.
        return [[] for _ in range(num_hypos)]

    if not too_short:
        # Record every n-gram seen so far as prefix -> list of continuations.
        for hypo_idx in range(num_hypos):
            history = prev_input_ids[hypo_idx].tolist()
            table = tables[hypo_idx]
            shifted = (history[offset:] for offset in range(no_repeat_ngram_size))
            for ngram in zip(*shifted):
                if any(tok >= vocab_size for tok in ngram):
                    continue
                *prefix, continuation = ngram
                table.setdefault(tuple(prefix), []).append(continuation)

    return [_banned_for(hypo_idx) for hypo_idx in range(num_hypos)]


In [30]:
def generate_samples(
    model, tokenizer, sents,
    device='cpu',
    max_length=100,
    temperature=0.7,
    top_k=0,
    top_p=0.9,
    no_repeat_ngram_size = 3,
    repetition_penalty = 1.2
):
    """Sample one reply to a dialogue history.

    The utterances in ``sents`` are encoded, each followed by the tokenizer's
    separator id, and a sentinel-0 token is appended. The encoder runs once;
    the decoder is then stepped token by token (reusing ``past_key_values``)
    until the separator is sampled or ``max_length`` tokens were produced.

    Args:
        model: encoder-decoder model; called with keyword tensors and
            returning dicts with ``encoder_last_hidden_state``, ``lm_logits``
            and ``past_key_values``.
        tokenizer: provides ``encode``, ``decode``, ``sep_id``, ``pad_id`` and
            ``get_sentinel_id``.
        sents: list of utterance strings, oldest first.
        device: device the input tensors are created on; it must match the
            device the model lives on.
        max_length: maximum number of tokens to generate.
        temperature: logits are divided by this before sampling.
        top_k: forwarded to ``top_k_logits``.
        top_p: forwarded to ``top_k_logits``.
        no_repeat_ngram_size: if > 0, tokens that would repeat an n-gram of
            this size are suppressed.
        repetition_penalty: if != 1.0, logits of tokens already present in
            the context or the output are penalised by this factor.

    Returns:
        The decoded reply. When the separator is sampled it is stripped from
        the reply; when ``max_length`` is hit first, everything generated so
        far is decoded (previously the raw last-token tensor was returned).
    """
    batch_size = 1
    model.eval()
    with torch.no_grad():

        all_input_tokens = []
        for sent in sents:
            all_input_tokens.extend(tokenizer.encode(sent) + [tokenizer.sep_id])
        all_input_tokens.append(tokenizer.get_sentinel_id(0))

        token_tensor = torch.tensor(all_input_tokens, dtype=torch.long).to(device)
        token_tensor = token_tensor.unsqueeze(0)

        target_length = max_length

        model_batch = get_inference_batch(token_tensor, device, batch_size, target_length, tokenizer)

        enc_input_ids = model_batch['enc_input_ids']
        enc_attention_mask = model_batch['enc_attention_mask']
        enc_position_ids = model_batch['enc_position_ids']

        enc_outputs = model(
            enc_input_ids=enc_input_ids,
            only_encoder=True,
            enc_attention_mask=enc_attention_mask,
            enc_position_ids=enc_position_ids
        )
        enc_hidden_states = enc_outputs["encoder_last_hidden_state"]

        # for generating responses
        # we only use the <go> token, so truncate other tokens
        dec_input_ids = model_batch['dec_input_ids'][..., :1]
        dec_attention_mask = model_batch['dec_attention_mask'][..., :1, :1]
        dec_position_ids = model_batch['dec_position_ids'][..., :1]
        # we use past_key_values, so only the current token mask is needed
        cross_attention_mask = model_batch['cross_attention_mask'][..., :1, :]

        # Generated ids so far, shape (batch, 0) to start with.
        output_ids = enc_input_ids.new_zeros([enc_input_ids.size(0), 0])
        past_key_values = None

        for gen_len in range(target_length):

            dec_outputs = model(
                dec_input_ids=dec_input_ids,
                dec_position_ids=dec_position_ids,
                dec_attention_mask=dec_attention_mask,
                cross_attention_mask=cross_attention_mask,
                enc_hidden_states=enc_hidden_states,
                past_key_values=past_key_values,
            )
            lm_logits = dec_outputs['lm_logits']
            past_key_values = dec_outputs['past_key_values']

            logits = lm_logits[:, -1, :] / temperature

            # Context tokens count as "previous" tokens for the penalty too.
            prev_output_tokens = torch.cat([enc_input_ids, output_ids], dim=-1)

            # repetition_penalty
            if repetition_penalty != 1.0:
                for i in range(logits.size(0)):
                    for previous_token in set(prev_output_tokens[i].tolist()):
                        # a negative score must be multiplied (not divided)
                        # so that the token becomes less likely
                        if logits[i, previous_token] < 0:
                            logits[i, previous_token] *= repetition_penalty
                        else:
                            logits[i, previous_token] /= repetition_penalty

            # no_repeat_ngram_size
            if no_repeat_ngram_size > 0:
                banned_batch_tokens = calc_banned_ngram_tokens(
                    output_ids, logits.size(0), no_repeat_ngram_size, gen_len, logits.size(1)
                )
                for i, banned_tokens in enumerate(banned_batch_tokens):
                    logits[i, banned_tokens] = -1e5

            logits = top_k_logits(logits, top_k=top_k, top_p=top_p, remove_unk=True)
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).squeeze(1)

            # Feed the sampled token back in as the next decoder input.
            dec_input_ids = next_token.unsqueeze(-1)
            output_ids = torch.cat([output_ids, next_token.unsqueeze(-1)], dim=-1)
            dec_position_ids = dec_position_ids[:, -1:] + 1
            # let the current token attend to all previous tokens
            dec_attention_mask = torch.cat([dec_attention_mask, dec_attention_mask[:, :, :, -1:]], dim=-1)

            # batch_size is fixed to 1, so a sampled separator ends the reply.
            if tokenizer.sep_id == next_token.detach().cpu().numpy().flatten()[0]:
                return tokenizer.decode(
                    output_ids.detach().cpu().numpy()[0][:-1]
                )

        # Length limit reached without a separator: return what was generated
        # as text so callers always receive the same type.
        return tokenizer.decode(output_ids.detach().cpu().numpy()[0])

In [32]:
def chat():
    """Run an interactive chat loop against the notebook's global model.

    Reads a line from stdin on each turn; typing ``quit`` or an empty/blank
    line ends the session. The last ``max_history`` entries of the
    conversation are passed to ``generate_samples`` as context, and each
    reply is printed and appended to the history.
    """
    history = []
    max_history = 5
    # Generate on the device the weights live on: the model is moved to CUDA
    # when available, while generate_samples defaults to 'cpu'.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else 'cpu'
    while True:
        text = input('>>>')
        if text == 'quit' or not text.strip():
            break
        history.append(text)
        # Trim before generating so the context never exceeds max_history.
        history = history[-max_history:]
        out = generate_samples(model, tokenizer, history, device=device)
        print(out)
        history.append(out)

In [33]:
# Start an interactive session; an empty line or 'quit' ends it.
chat()

>>>你好
您好,有什么问题我可以帮您处理或解决呢?#e-s[数字x]
>>>


In [35]:
# Another interactive session. Replies are sampled without a fixed seed, so
# the transcript below will differ between runs.
chat()

>>>你喜欢听相声吗？
我很少看电视,除了看电视的时候就没别的兴趣了。
>>>你喜欢郭德纲吗
不喜欢他的歌啊。
>>>郭德纲是干嘛的？
什么叫做相声演员,怎么会没有人知道呢?
>>>那你听过他的作品吗
当然了,不过我很少看相声,要么就是因为郭老师的一些事情才关注他的,你可以去搜一下他以前的节目,或者看看有哪些好玩的。
>>>那你知道他的徒弟吗
不清楚,但你要问他师傅就一定不是他师傅吧?
>>>


In [36]:
# Another interactive session (history starts empty on every call).
chat()

>>>今天好倒霉啊
我下午打球都痛到没知觉了,现在才睡醒
>>>我说我好倒霉啊
你这也太衰了吧。。。
>>>你不问问吗
真的很烦恼,我的天
>>>你烦恼什么
我还没睡,就是突然很想哭。。
>>>为什么想哭
因为没有人陪。。
>>>我陪着你啊
好吧,,晚安~
>>>


In [37]:
# Another interactive session (history starts empty on every call).
chat()

>>>我想要去土耳其
你不是说就快了嘛
>>>嗯，下个月去
哦哦!那还挺好的。就是太冷了。
>>>温度还好吧
也还好啊,冷是因为有暖气
>>>
