# Building Conversational AI with Transformers and Determined

In [1]:
# Imports
import random
import json
import warnings

from itertools import chain
from pprint import pformat
from attrdict import AttrDict

import torch
import torch.nn.functional as F

from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizerFast, GPT2LMHeadModel, GPT2TokenizerFast, cached_path
from data import SPECIAL_TOKENS, build_input_from_segments, add_special_tokens_
from utils import get_dataset, download_pretrained_model
from example_input import build_inputs

In [2]:
# Arguments
args = AttrDict()
args.dataset_path = ""
args.dataset_cache = "/root/.cache/"
args.model = "openai-gpt"
args.ckpt_uuid = "24f33c0f-d5fd-4cc3-8551-db48f32f8fc2"
args.max_history = 5
args.no_sample = False
args.max_length = 40
args.min_length = 1
args.temperature = 0.7
args.top_k = 0
args.top_p = 0.8

## Let's examine the data

In [3]:
PERSONACHAT_URL = "https://s3.amazonaws.com/datasets.huggingface.co/personachat/personachat_self_original.json"
print("Download dataset from {}".format(PERSONACHAT_URL))
personachat_file = cached_path(PERSONACHAT_URL)
with open(personachat_file, "r", encoding="utf-8") as f:
    dataset = json.loads(f.read())

Download dataset from https://s3.amazonaws.com/datasets.huggingface.co/personachat/personachat_self_original.json


In [4]:
print("There are a total of {} dialogues in the training set".format(len(dataset["train"])))

There are a total of 17878 dialogues in the training set


In [5]:
# Example
obs_ind = 10
sample = dataset["train"][obs_ind]
persona = sample["personality"]
history = sample["utterances"][0]["history"]
reply = sample["utterances"][0]["candidates"][-1][:-1]
print("Personality: ", persona)
print("Utterance: ", sample["utterances"][0])

Personality:  ['i love to sing .', 'i am a night owl .', "i'm a dancer .", 'i can play the piano .', "i'm a vegetarian ."]
Utterance:  {'candidates': ['that is so fun boys are awesome', 'hospitals are lame , you should make a run for it .', 'wow just finished reading ender s game , what a great book !', 'it is ok , we do other things like go to the park or zoo', 'what do police do for fun ? who says you gonna learn today .', 'hello there ! what are your hobbies ?', 'i am good ! waiting for my wife to get home .', 'whats your favorite color ? mine is purple .', 'how long have you been friends', 'i am not a big foodie . i prefer crafts , like whittling .', 'oh , ok sure . anything else ?', "i guess i'll make you a salad p", 'hi , how are you today', 'no siblings myself . are you at work ?', 'i bet . what do you do for fun ?', 'at this joint called the frog zone grill . its pretty chill . what about you ?', 'is the cat hairless ?', "www . cafepress . com lelesfashionshop1 is the link to m

In [6]:
# Input example
words, segments, position, sequence = build_inputs(
    [p.split(' ') for p in persona], 
    [h.split(' ') for h in history], 
    reply.split(' ')
)
print("Words: {}\n".format(words))
print("Segments: {}\n".format(segments))
print("Position: {}\n".format(position))

Words: ['<bos>', 'i', 'love', 'to', 'sing', '.', 'i', 'am', 'a', 'night', 'owl', '.', "i'm", 'a', 'dancer', '.', 'i', 'can', 'play', 'the', 'piano', '.', "i'm", 'a', 'vegetarian', '.', '<speaker2>', 'hi', 'i', 'am', 'sally', ',', 'i', 'live', 'with', 'my', 'sweet', 'dogs', 'in', 'taos', ',', 'new', 'mexico', '.', '<speaker1>', 'hi', '!', "i've", 'just', 'been', 'sitting', 'here', 'playing', 'the', 'piano', 'and', 'singing', 'alon', '<eos>']

Segments: ['<speaker1>', '<speaker1>', '<speaker1>', '<speaker1>', '<speaker1>', '<speaker1>', '<speaker1>', '<speaker1>', '<speaker1>', '<speaker1>', '<speaker1>', '<speaker1>', '<speaker1>', '<speaker1>', '<speaker1>', '<speaker1>', '<speaker1>', '<speaker1>', '<speaker1>', '<speaker1>', '<speaker1>', '<speaker1>', '<speaker1>', '<speaker1>', '<speaker1>', '<speaker1>', '<speaker2>', '<speaker2>', '<speaker2>', '<speaker2>', '<speaker2>', '<speaker2>', '<speaker2>', '<speaker2>', '<speaker2>', '<speaker2>', '<speaker2>', '<speaker2>', '<speaker2>

## Interact with a trained model

### Load dataset and pretrained model and tokenizer

In [9]:
print("Get pretrained tokenizer and dataset")
tokenizer_class, model_class = (GPT2TokenizerFast, GPT2LMHeadModel) if args.model == 'gpt2' else (OpenAIGPTTokenizerFast, OpenAIGPTLMHeadModel)
tokenizer = tokenizer_class.from_pretrained(args.model)
model = model_class.from_pretrained(args.model)
add_special_tokens_(model, tokenizer)
dataset = get_dataset(tokenizer, args.dataset_path, args.dataset_cache)

Get pretrained tokenizer and dataset


Some weights of OpenAIGPTLMHeadModel were not initialized from the model checkpoint at openai-gpt and are newly initialized: ['lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
print("Load pretrained model from Determined checkpoint")
from determined.experimental import Determined
ckpt_path = Determined().get_checkpoint(args.ckpt_uuid).download()
ckpt = torch.load(ckpt_path + "/state_dict.pth")
model.load_state_dict(ckpt['models_state_dict'][0], strict=False)
model = model.cuda()
## Another way of loading a checkpoint from Determined that loads the actual Trial.  
## This is slower because it performs the full init for the trial, including tokenizing the dataset.
#ckpt = Determined().get_checkpoint(args.ckpt_uuid).load()
#model = ckpt.model.cuda()

Load pretrained model from Determined checkpoint


### Sample a personality and interact

In [6]:
# Functions for interacting with trained model
def top_filtering(logits, top_k=0., top_p=0.9, threshold=-float('Inf'), filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k, top-p (nucleus) and/or threshold filtering
        Args:
            logits: logits distribution shape (vocabulary size)
            top_k: <=0: no filtering, >0: keep only top k tokens with highest probability.
            top_p: <=0.0: no filtering, >0.0: keep only a subset S of candidates, where S is the smallest subset
                whose total probability mass is greater than or equal to the threshold top_p.
                In practice, we select the highest probability tokens whose cumulative probability mass exceeds
                the threshold top_p.
            threshold: a minimal threshold to keep logits
    """
    assert logits.dim() == 1  # Only work for batch size 1 for now - could update but it would obfuscate a bit the code
    top_k = min(top_k, logits.size(-1))
    if top_k > 0:
        # Remove all tokens with a probability less than the last token in the top-k tokens
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        # Compute cumulative probabilities of sorted tokens
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probabilities = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probabilities > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Back to unsorted indices and set them to -infinity
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value

    indices_to_remove = logits < threshold
    logits[indices_to_remove] = filter_value

    return logits


def sample_sequence(personality, history, tokenizer, model, args, current_output=None):
    special_tokens_ids = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS)
    if current_output is None:
        current_output = []

    for i in range(args.max_length):
        instance = build_input_from_segments(personality, history, current_output, tokenizer, with_eos=False)

        input_ids = torch.tensor(instance["input_ids"], device="cuda").unsqueeze(0)
        token_type_ids = torch.tensor(instance["token_type_ids"], device="cuda").unsqueeze(0)

        outputs = model(input_ids, token_type_ids=token_type_ids)
        logits = outputs.logits
        if isinstance(logits, tuple):  # for gpt2 and maybe others
            logits = logits[0]
        logits = logits[0, -1, :] / args.temperature
        logits = top_filtering(logits, top_k=args.top_k, top_p=args.top_p)
        probs = F.softmax(logits, dim=-1)

        prev = torch.topk(probs, 1)[1] if args.no_sample else torch.multinomial(probs, 1)
        if i < args.min_length and prev.item() in special_tokens_ids:
            while prev.item() in special_tokens_ids:
                if probs.max().item() == 1:
                    warnings.warn("Warning: model generating special token with probability 1.")
                    break  # avoid infinitely looping over special token
                prev = torch.multinomial(probs, num_samples=1)

        if prev.item() in special_tokens_ids:
            break
        current_output.append(prev.item())

    return current_output

def sample_personality(dataset, no_personality=False):
    if no_personality:
        return []    
    personalities = [dialog["personality"] for dataset in dataset.values() for dialog in dataset]
    personality = random.choice(personalities)
    print("Selected personality is: {}".format(' '.join(tokenizer.batch_decode(personality))))
    return personality

In [10]:
# Sampling profiles control the behavior of the sequence generation for the response.  
# "low", "medium", and "high" roughly correspond to coherence with 
# "low" generating more random response and "high" generating very similar responses to the history.
sampling_profiles = {
    'low': {'top_k': 180, 'top_p': 0.1, 'temperature': 1.9},
    'medium': {'top_k': 70, 'top_p': 0.5, 'temperature': 1.2},
    'high': {'top_k': 0, 'top_p': 0.9, 'temperature': 0.6},
    'custom': {'top_k': 1, 'top_p': 0.7, 'temperature': 1},
}

In [11]:
profile = 'medium'
for v in ['top_k', 'top_p', 'temperature']:
    args[v] = sampling_profiles[profile][v]
personality = sample_personality()

Selected personality is: i've perfect pitch. i've been published in the new yorker magazine. i'm a gourmet cook. as a child, i won a national spelling bee.


In [None]:
history = []
while True:
    raw_text = input(">>> ")
    while not raw_text:
        print('Prompt should not be empty!')
        raw_text = input(">>> ")
    history.append(tokenizer.encode(raw_text))
    with torch.no_grad():
        out_ids = sample_sequence(personality, history, tokenizer, model, args)
    history.append(out_ids)
    history = history[-(2*args.max_history+1):]
    out_text = tokenizer.decode(out_ids, skip_special_tokens=True)
    print(out_text)
