In [1]:
import torch
from transformers import *
from run_generation import top_k_top_p_filtering
#from run_generation import sample_sequence

To use data.metrics please install scikit-learn. See https://scikit-learn.org/stable/index.html


In [2]:
# Load pre-trained model tokenizer (vocabulary)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Encode a text inputs
text = "Happy Birthday to"
indexed_tokens = tokenizer.encode(text)

# Convert indexed tokens in a PyTorch tensor
tokens_tensor = torch.tensor([indexed_tokens])

# Load pre-trained model (weights)
model = GPT2LMHeadModel.from_pretrained('gpt2')

# Set the model in evaluation mode to deactivate the DropOut modules
model.eval()

# If you have a GPU, put everything on cuda
#CUDA - Compute Unified Device Architecture
#.to() - Sets dtype or device to a tensor
tokens_tensor = tokens_tensor.to('cpu')
model.to('cpu')

# Predict all tokens
with torch.no_grad():
    outputs = model(tokens_tensor)
    #filter_outputs = top_k_top_p_filtering(logits=outputs, top_k=10, top_p=0.5)
    predictions = outputs[0]

# Get the predicted next sub-word
predicted_index = torch.argmax(predictions[0, -1, :]).item()
predicted_text = tokenizer.decode(indexed_tokens + [predicted_index])

# Print the predicted word
#print(predicted_text)

def inference(model = GPT2LMHeadModel, enc = GPT2Tokenizer, phrase= ''):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    nsamples = 1
    length = 40
    temperature = 1.2
    top_k = 0
    top_p = 0.9
    batch_size = 1
    stop_token = [enc.encoder[x] for x in ('<|endoftext|>', '.', '?', '!')]
    assert nsamples % batch_size == 0

    if length == -1:
        length = model.config.n_ctx // 2
    elif length > model.config.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx)

    context_tokens = enc.encode(phrase) if phrase else [enc.encoder['<|endoftext|>']]
    generated = 0
    out = sample_sequence(
        model=model, length=length,
        context=context_tokens,
        start_token=None,
        batch_size=batch_size,
        temperature=temperature, top_k=top_k, device=device,
        top_p=top_p,
        stop_token=stop_token
    )
    out = out[:, len(context_tokens):].tolist()
    return enc.decode(out[0])

def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0,
                    device='cuda', top_p=0, stop_token=[]):
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
        context = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1)
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = torch.full((batch_size, 1), start_token, device=device, dtype=torch.long)
    prev = context
    output = context
    past = None

    count = 0
    with torch.no_grad():
        while count < length:
            logits, past = model(prev, past=past)
            logits = logits[:, -1, :] / temperature
            logits = top_k_top_p_filtering(logits, top_p=top_p, top_k=top_k)
            probs = F.softmax(logits, dim=-1)
            prev = torch.multinomial(probs, num_samples=1)
            output = torch.cat((output, prev), dim=1)
            count += 1
            if prev in stop_token:
                break
    return output

print(inference(model=model,enc=tokenizer,phrase='Hi, There'))

10/17/2019 00:39:34 - INFO - transformers.tokenization_utils -   loading file https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json from cache at /home/spvengat/.cache/torch/transformers/f2808208f9bec2320371a9f5f891c184ae0b674ef866b79c58177067d15732dd.1512018be4ba4e8726e41b9145129dc30651ea4fec86aa61f4b9f40bf94eac71
10/17/2019 00:39:34 - INFO - transformers.tokenization_utils -   loading file https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt from cache at /home/spvengat/.cache/torch/transformers/d629f792e430b3c76a1291bb2766b0a047e36fae0588f9dbc1ae51decdff691b.70bec105b4158ed9a1747fea67a43f5dee97855c64d62b6ec3742f4cfdb5feda
10/17/2019 00:39:35 - INFO - transformers.configuration_utils -   loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json from cache at /home/spvengat/.cache/torch/transformers/4be02c5697d91738003fb1685c9872f284166aa32e061576bbe6aaeb95649fcf.085d5f6a8e7812ea05ff0e6ed0645ab2e75d80387ad55c1ad9806ee7

IndexError: index 10663 is out of bounds for dimension 0 with size 1