Let’s write some Python code that uses all four of the released GPT-2 models to generate text. That will let us see how changes in model capacity relate to the quality of the text produced.

First, we clone OpenAI's GPT-2 repository from GitHub.

The OpenAI codebase lists the other libraries it depends on in requirements.txt, so we install them with pip.

Then we download the four pre-trained models OpenAI released, with 124M, 345M, 774M, and 1558M parameters. Each is roughly two to three times the size of the previous one (about 2.8x, 2.2x, and 2.0x).

In [None]:
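# Clone OpenAI's GPT-2 repository and work from inside it
!git clone https://github.com/openai/gpt-2.git
import os
import warnings
os.chdir("gpt-2")

# Silence TensorFlow's C++ log messages and Python warnings
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
warnings.filterwarnings('ignore')

# Colab magic: switch to TensorFlow 1.x, which this codebase targets
%tensorflow_version 1.x

# Install the repository's dependencies
!pip3 install -r requirements.txt

# Download the four released checkpoints into models/<size>/
!python3 download_model.py 124M
!python3 download_model.py 345M
!python3 download_model.py 774M
!python3 download_model.py 1558M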

fatal: destination path 'gpt-2' already exists and is not an empty directory.
TensorFlow 1.x selected.
Fetching checkpoint: 1.00kit [00:00, 808kit/s]                                                      
Fetching encoder.json: 1.04Mit [00:00, 36.3Mit/s]                                                   
Fetching hparams.json: 1.00kit [00:00, 804kit/s]                                                    
Fetching model.ckpt.data-00000-of-00001: 498Mit [00:09, 52.5Mit/s]                                  
Fetching model.ckpt.index: 6.00kit [00:00, 5.35Mit/s]                                               
Fetching model.ckpt.meta: 472kit [00:00, 32.3Mit/s]                                                 
Fetching vocab.bpe: 457kit [00:00, 34.4Mit/s]                                                       
Fetching checkpoint: 1.00kit [00:00, 771kit/s]                                                      
Fetching encoder.json: 1.04Mit [00:00, 39.0Mit/s]                                        
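The download log above is cut off partway through the second model, so it may be worth confirming that every checkpoint arrived intact. The sketch below assumes that each model directory holds the same seven files listed in the 124M log and that `download_model.py` saves them under `models/<size>/`; `missing_model_files` is a hypothetical helper, not part of the GPT-2 repository.

```python
import os

# File names as they appear in the 124M download log
# (assumption: every model size uses the same names)
EXPECTED_FILES = (
    "checkpoint",
    "encoder.json",
    "hparams.json",
    "model.ckpt.data-00000-of-00001",
    "model.ckpt.index",
    "model.ckpt.meta",
    "vocab.bpe",
)

def missing_model_files(models_dir, model_names):
    """Return {model_name: [missing file names]} for every incomplete model directory."""
    missing = {}
    for name in model_names:
        model_path = os.path.join(models_dir, name)
        absent = [f for f in EXPECTED_FILES
                  if not os.path.isfile(os.path.join(model_path, f))]
        if absent:
            missing[name] = absent
    return missing
```

Run from the `gpt-2` directory, `missing_model_files('models', ['124M', '345M', '774M', '1558M'])` should return an empty dict if all four downloads completed; any model listed in the result can be fetched again with `download_model.py`.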

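Next, we move into the repository's `src` directory, pin TensorFlow 1.15.2, and import the GPT-2 modules (`model`, `sample`, `encoder`) along with the additional libraries we'll be using in this notebook.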

In [None]:
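# `!export` runs in a throwaway subshell, so set the variable for this process instead
os.environ['PYTHONIOENCODING'] = 'UTF-8'
# The GPT-2 modules (model.py, sample.py, encoder.py) live in src/
os.chdir('src')

!pip install tensorflow==1.15.2
import json
import numpy as np
import tensorflow as tf
import model, sample, encoder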

Collecting gast==0.2.2
  Downloading https://files.pythonhosted.org/packages/4e/35/11749bf99b2d4e3cceb4d55ca22590b0d7c2c62b9de38ac4a4a7f4687421/gast-0.2.2.tar.gz
Building wheels for collected packages: gast
  Building wheel for gast (setup.py) ... done
  Created wheel for gast: filename=gast-0.2.2-cp36-none-any.whl size=7540 sha256=2b749963aa9409cdb84d20c357f6890fa518f8ef9c9428e7b57605ee64dd9d01
  Stored in directory: /root/.cache/pip/wheels/5c/2e/7e/a1d4d4fcebe6c381f378ce7743a3ced3699feb89bcfbdadadd
Successfully built gast
Installing collected packages: gast
  Found existing installation: gast 0.3.3
    Uninstalling gast-0.3.3:
      Successfully uninstalled gast-0.3.3
Successfully installed gast-0.2.2


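We define an `autocomplete` function that, given a `model_name` and a `raw_text` prompt, returns the next `length` tokens sampled from the model as a string. Note that `length` counts GPT-2's byte-pair-encoding (BPE) tokens, not words; a single word can span several tokens.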

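Inside a fresh TensorFlow graph and session, we create a placeholder for the prompt's token IDs and build the sampling graph with `sample.sample_sequence`, which appends `length` generated tokens to the prompt. We then restore the pretrained weights from the model's checkpoint with `tf.train.Saver`. Once all of that is set up, we encode the text prompt, run the graph, drop the prompt tokens from the output, and decode the remaining tokens back into a string. With `temperature=1` and `top_k=0`, each token is sampled from the model's full, unscaled distribution, and because `seed=None` the output differs on every call.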
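The `temperature` and `top_k` settings in `autocomplete` control how each next token is drawn from the model's output logits. The NumPy sketch below illustrates that single step; it is our own illustration rather than the repository's `sample.py`, and `top_k_probs` and `sample_next_token` are hypothetical names. It applies the two settings in the same order as `sample.py` (divide by the temperature, then keep the top k), and treats `top_k=0` as "no filtering".

```python
import numpy as np

def top_k_probs(logits, temperature=1.0, top_k=0):
    """Turn raw logits into a sampling distribution.

    temperature < 1 sharpens the distribution, > 1 flattens it;
    top_k keeps only the k highest-scoring tokens, and 0 disables the filter.
    """
    logits = np.asarray(logits, dtype=np.float64) / temperature
    if top_k > 0:
        # Keep the k largest logits (ties at the cutoff survive); mask the rest
        cutoff = np.sort(logits)[-top_k]
        logits = np.where(logits < cutoff, -np.inf, logits)
    # Numerically stable softmax; masked entries become probability 0
    exp = np.exp(logits - np.max(logits))
    return exp / exp.sum()

def sample_next_token(logits, temperature=1.0, top_k=0, rng=None):
    """Draw one token index from the filtered, temperature-scaled distribution."""
    rng = np.random.default_rng() if rng is None else rng
    probs = top_k_probs(logits, temperature, top_k)
    return int(rng.choice(len(probs), p=probs))
```

For example, `sample_next_token([2.0, 1.0, 0.5], temperature=0.7, top_k=2)` can only return 0 or 1, while `top_k=1` reduces sampling to always picking the highest-scoring token.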

In [None]:
# Return-a-string version

def autocomplete(model_name, raw_text, length):
    """Sample `length` BPE tokens from GPT-2 `model_name` that continue `raw_text`."""
    batch_size = 1
    temperature = 1   # 1 = use the model's unscaled distribution
    top_k = 0         # 0 = no top-k truncation
    seed = None       # None = different output on every call
    models_dir = os.path.expanduser(os.path.expandvars('../models'))

    # Load the BPE encoder and the model's hyperparameters
    enc = encoder.get_encoder(model_name, models_dir)
    hparams = model.default_hparams()
    with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)

    with tf.Session(graph=tf.Graph()) as sess:
        # Placeholder for the prompt's token IDs: [batch_size, variable length]
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        # Sampling graph: returns the prompt tokens followed by `length` new tokens
        output = sample.sample_sequence(
            hparams=hparams, length=length,
            context=context,
            batch_size=batch_size,
            temperature=temperature, top_k=top_k
        )

        # Restore the pretrained weights from the downloaded checkpoint
        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
        saver.restore(sess, ckpt)

        # Encode the prompt, run the sampler, and keep only the generated tokens
        context_tokens = enc.encode(raw_text)
        out = sess.run(output, feed_dict={
            context: [context_tokens]
        })[:, len(context_tokens):]
        text = enc.decode(out[0])
    return text
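The guard in `autocomplete` only compares `length` with `hparams.n_ctx`, but the prompt's tokens also occupy positions in the context window, since GPT-2 learns one positional embedding per position. A stricter check might look like the sketch below. `check_fits_window` is a hypothetical helper, and its default of 1024 is the `n_ctx` value shipped in the released models' `hparams.json`.

```python
def check_fits_window(n_prompt_tokens, length, n_ctx=1024):
    """Raise if prompt tokens plus generated tokens exceed the context window.

    Returns the number of positions left over when the request fits.
    """
    total = n_prompt_tokens + length
    if total > n_ctx:
        raise ValueError(
            "Prompt (%d tokens) + length (%d) = %d exceeds the window size %d"
            % (n_prompt_tokens, length, total, n_ctx))
    return n_ctx - total
```

Inside `autocomplete`, it could be called right after encoding the prompt, as `check_fits_window(len(context_tokens), length, hparams.n_ctx)`.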

Below is an example of our `autocomplete` function, which uses the smallest (124M) model to sample the next 10 tokens after a prompt.

In [None]:
print(autocomplete('124M', "Learning about machine learning is kind of like", 10))





INFO:tensorflow:Restoring parameters from ../models/124M/model.ckpt
 yesterday controlling an electric car in the 1970s.


Here we show how the continuation of a single prompt changes with the number of parameters in the model. Each line is one random sample, so any single comparison is anecdotal.

In [None]:
for gpt2model in ['124M', '345M', '774M', '1558M']:
    print(gpt2model, autocomplete(gpt2model, "My first time visiting the ocean, I marveled at", 20))

INFO:tensorflow:Restoring parameters from ../models/124M/model.ckpt
124M  it on a daily basis. The ocean is surrounded by fog, and you have fog gods; they
INFO:tensorflow:Restoring parameters from ../models/345M/model.ckpt
345M  many astonishing colors — reds with rings of orange, magenta with gold, little green gray sites
INFO:tensorflow:Restoring parameters from ../models/774M/model.ckpt
774M  it under my umbrella.

Compared to living in Japan, Italy is essentially the smallest nation in
INFO:tensorflow:Restoring parameters from ../models/1558M/model.ckpt
1558M  it in awe of the overwhelming force of the water. But I was so shy about revealing myself to
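Because one sample per model says little about capacity, one option is to draw several continuations per model and compare them side by side. The sketch below assumes only that `generate` has the same signature as `autocomplete`; `compare_models` and `print_comparison` are hypothetical helpers.

```python
def compare_models(generate, model_names, prompt, length, n_samples=3):
    """Collect several continuations per model so the comparison is less anecdotal.

    `generate` is any callable with autocomplete's signature:
    generate(model_name, raw_text, length) -> str
    """
    results = {}
    for name in model_names:
        results[name] = [generate(name, prompt, length) for _ in range(n_samples)]
    return results

def print_comparison(results):
    """Print each model name followed by its numbered samples."""
    for name, samples in results.items():
        print(name)
        for i, text in enumerate(samples, 1):
            print("  [%d]%s" % (i, text))
```

Usage: `print_comparison(compare_models(autocomplete, ['124M', '345M', '774M', '1558M'], "My first time visiting the ocean, I marveled at", 20))`. Since each `autocomplete` call builds a new graph and restores the checkpoint, `n_samples` multiplies the loading time, which is most noticeable for the 1558M model.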
