In [1]:
import sys
sys.path.insert(0, '../../')

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Tuple, Union

import candle

## (1) Helper functions: text generation and the chat loop

In [3]:
def generate_text(model,
                  prompt: str,
                  n_tokens_to_generate: int,
                  beam_size: int = 10,
                  top_k: int = 50,
                  sample: bool = True,
                  stream_reply: bool = True,
                  stop_on_end_of_sentence: bool = True,
                  temperature: float = 1.0,
                  tokenizer=None) -> str:
    """Given a conditioning prompt, generates up to N more tokens using beam search.

    Args:
        model: Model passed through to ``candle.nlp.beam_search_generation``.
        prompt: Conditioning text. An empty prompt is replaced by the
            ``<|endoftext|>`` token so generation still has something to condition on.
        n_tokens_to_generate: Maximum number of tokens to generate.
        beam_size: Beam width forwarded to the beam search.
        top_k: Top-k cutoff forwarded to the beam search.
        sample: Forwarded to the beam search as ``sample`` (presumably sampling vs.
            deterministic decoding — see ``candle.nlp.beam_search_generation``).
        stream_reply: If True, print each decoded piece as soon as it is produced.
        stop_on_end_of_sentence: If True, for pieces whose 0-based position exceeds
            half of ``n_tokens_to_generate``, stop at the first '.', '!' or '?';
            the rest of that piece is discarded.
        temperature: Forwarded to the beam search (defaults to the previously
            hard-coded 1.0).
        tokenizer: Optional pre-built ``GPT2BPETokenizer``. Pass one in to avoid
            rebuilding the tokenizer on every call (e.g. once per chat turn).

    Returns:
        The prompt followed by the generated continuation.
    """
    if tokenizer is None:
        tokenizer = candle.models.gpt.GPT2BPETokenizer()

    if prompt == '':
        indices = candle.Tensor(np.array([tokenizer.token_to_index['<|endoftext|>']]))
    else:
        indices = candle.Tensor(np.array(tokenizer.encode(prompt)))

    generator = candle.nlp.beam_search_generation(model, indices, n_tokens_to_generate=n_tokens_to_generate,
                                                  top_k=top_k, beam_size=beam_size, temperature=temperature,
                                                  sample=sample)

    end_of_sentence_chars = ('.', '!', '?')
    # Only allow stopping in the second half of the budget so replies aren't cut off too early.
    min_position_for_stop = 0.5 * n_tokens_to_generate

    answer = prompt
    for (tokens_generated, next_indices) in enumerate(generator):
        piece = ''.join(tokenizer.decode(next_indices))

        reached_end_of_sentence = False
        if stop_on_end_of_sentence and tokens_generated > min_position_for_stop:
            eos_position = next((i for (i, char) in enumerate(piece) if char in end_of_sentence_chars), None)
            if eos_position is not None:
                piece = piece[:eos_position + 1]
                reached_end_of_sentence = True

        if stream_reply:
            # flush=True: with end='' no newline is written, so a line-buffered stdout
            # (e.g. a terminal) would otherwise hold the whole reply back until the end.
            print(piece, end='', flush=True)
        answer += piece

        if reached_end_of_sentence:
            break

    return answer


def start_conversation(model,
                       name: str = 'Human',
                       reply_length: int = 50,
                       beam_size: int = 10) -> None:
    """Run an interactive chat loop with the model.

    Each line the user enters is passed as the prompt to ``generate_text``, which
    streams the reply to stdout. The loop ends when the user types ``exit()``
    (surrounding whitespace ignored). It also ends cleanly, without a traceback,
    if input is closed (EOF) or interrupted (Ctrl-C / kernel interrupt) at the prompt.

    Args:
        model: Model passed to ``generate_text``.
        name: Label shown in front of the user's turns.
        reply_length: Maximum number of tokens generated per reply.
        beam_size: Beam width forwarded to ``generate_text``.
    """
    print('|| ===============================================================||')
    print('|| Hello! You are now talking with GPT2 (type \'exit()\' to exit)   ||')
    print('|| ===============================================================||\n')
    while True:
        try:
            prompt = input(f'[{name}] ')
        except (EOFError, KeyboardInterrupt):
            # Input closed or interrupted: finish the dangling prompt line and stop.
            print()
            break

        if prompt.strip() == 'exit()':
            break

        print('\n====================================================================================================\n')
        print('[GPT2] ', end='')
        generate_text(model,
                      prompt,
                      n_tokens_to_generate=reply_length,
                      beam_size=beam_size)
        print('\n====================================================================================================\n')

## (2) Initialize Model with Pre-trained Weights

In [4]:
# Pre-trained checkpoint to load. One of ['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl']
#    gpt2:        124,439,808 params
#    gpt2-medium: 354,823,168 params
#    gpt2-large:  774,030,080 params
#    gpt2-xl:   1,557,611,200 params
MODEL_NAME = 'gpt2'

model = candle.models.gpt.GPT.from_pretrained(MODEL_NAME)

# Last expression in the cell: displays the per-layer parameter summary.
model.summary()

Unnamed: 0,Layer Type,# Parameters
decoder_blocks.0,DecoderBlock,7087872
decoder_blocks.1,DecoderBlock,7087872
decoder_blocks.2,DecoderBlock,7087872
decoder_blocks.3,DecoderBlock,7087872
decoder_blocks.4,DecoderBlock,7087872
decoder_blocks.5,DecoderBlock,7087872
decoder_blocks.6,DecoderBlock,7087872
decoder_blocks.7,DecoderBlock,7087872
decoder_blocks.8,DecoderBlock,7087872
decoder_blocks.9,DecoderBlock,7087872


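## (3) Have a conversation

Running the next cell starts an interactive chat that waits for keyboard input; type `exit()` to end it.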

In [None]:
# Interactive: this cell blocks waiting for keyboard input until you type exit().
start_conversation(model=model, name='John')