# First: Download a pre-trained model

In [None]:
# Clone the PyTorch GPT-2 implementation and switch into the repository directory
# (presumably so that the later `from GPT2...` imports resolve -- confirm the package lives at the repo root).
!git clone https://github.com/graykode/gpt-2-Pytorch
%cd gpt-2-Pytorch
# Download the pre-trained weights to gpt2-pytorch_model.bin, the file that is later read with torch.load.
!curl --output gpt2-pytorch_model.bin https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin

# Import the required libraries to load and execute the model

In [None]:
import torch
import random
import numpy as np
from GPT2.model import (GPT2LMHeadModel)
from GPT2.utils import load_weight
from GPT2.config import GPT2Config
from GPT2.sample import sample_sequence
from GPT2.encoder import get_encoder

# Read the downloaded checkpoint once at module level; tensors are remapped to the
# CPU when CUDA is not available, otherwise their saved device placement is kept.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
state_dict = torch.load(
  'gpt2-pytorch_model.bin',
  map_location=None if torch.cuda.is_available() else 'cpu',
)

# Define a pipeline to process the input and run the model with it

In [None]:
def text_generator(text, length=-1, top_k=40):
  """Print a GPT-2 continuation of ``text``.

  The prompt is printed first (preceded by a blank line), followed by each
  sampled continuation. Nothing is returned.

  Args:
    text: the prompt; it is encoded and passed to the model as context.
    length: how many tokens (not words) to sample after the prompt. The
      default ``-1`` selects half of the model's context window
      (``config.n_ctx // 2``).
    top_k: forwarded to ``sample_sequence``; sampling is restricted to the
      k highest-probability tokens, so smaller values give more specific
      predictions.

  Raises:
    ValueError: if ``length`` is larger than the model's context window.
  """

  # These can be modified for advanced testing.
  nsamples = 1
  temperature = 0.7
  batch_size = 1

  # Internal invariant of the constants above: samples are produced in whole batches.
  assert nsamples % batch_size == 0

  # A new random seed is drawn on every call, so repeated calls give different text.
  seed = random.randint(0, 2147483647)
  np.random.seed(seed)
  torch.random.manual_seed(seed)
  torch.cuda.manual_seed(seed)
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  # Build the model and fill it with the weights from the module-level state_dict.
  enc = get_encoder()
  config = GPT2Config()
  model = GPT2LMHeadModel(config)
  model = load_weight(model, state_dict)
  model.to(device)
  model.eval()

  if length == -1:
    length = config.n_ctx // 2
  elif length > config.n_ctx:
    raise ValueError("Can't get samples longer than window size: %s" % config.n_ctx)

  prompt_tokens = enc.encode(text)

  for _ in range(nsamples // batch_size):
    out = sample_sequence(
      model=model, length=length,
      context=prompt_tokens,
      start_token=None,
      batch_size=batch_size,
      temperature=temperature, top_k=top_k, device=device
    )
    # Drop the prompt tokens so that only the newly generated part is decoded.
    continuations = out[:, len(prompt_tokens):].tolist()
    print("\n"+text)
    for tokens in continuations:
      # Keep the decoded sample in its own name: rebinding `text` here would make
      # every later batch print the previous sample instead of the prompt.
      generated_text = enc.decode(tokens).replace("<|endoftext|>", ". ")
      print(generated_text)

In [None]:
# Example run: continue the prompt with length=30 and top_k=25.
text_generator(
  text="I usually walk my dog in the morning",
  length=30,
  top_k=25,
)