* The base of this code is from https://github.com/openai/gpt-2
* you should clone the repository and run this notebook from /src folder

In [None]:
# download the model with 774 million weight
!python ../download_model.py 774M

In [2]:
import fire
import json
import os
import numpy as np
import tensorflow as tf

import model, sample, encoder

In [3]:
def interact_model(
    querys,
    model_name='774M',
    seed=None,
    nsamples=1,
    batch_size=1,
    length=None,
    temperature=1,
    top_k=0,
    top_p=1,
    models_dir='../models',
):
    """
    Interactively run the model
    :model_name=124M : String, which model to use
    :seed=None : Integer seed for random number generators, fix seed to reproduce
     results
    :nsamples=1 : Number of samples to return total
    :batch_size=1 : Number of batches (only affects speed/memory).  Must divide nsamples.
    :length=None : Number of tokens in generated text, if None (default), is
     determined by model hyperparameters
    :temperature=1 : Float value controlling randomness in boltzmann
     distribution. Lower temperature results in less random completions. As the
     temperature approaches zero, the model will become deterministic and
     repetitive. Higher temperature results in more random completions.
    :top_k=0 : Integer value controlling diversity. 1 means only 1 word is
     considered for each step (token), resulting in deterministic completions,
     while 40 means 40 words are considered at each step. 0 (default) is a
     special setting meaning no restrictions. 40 generally is a good value.
     :models_dir : path to parent folder containing model subfolders
     (i.e. contains the <model_name> folder)
    """
    models_dir = os.path.expanduser(os.path.expandvars(models_dir))
    if batch_size is None:
        batch_size = 1
    assert nsamples % batch_size == 0

    enc = encoder.get_encoder(model_name, models_dir)
    hparams = model.default_hparams()
    with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if length is None:
        length = hparams.n_ctx // 2
    elif length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)

    with tf.Session(graph=tf.Graph()) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = sample.sample_sequence(
            hparams=hparams, length=length,
            context=context,
            batch_size=batch_size,
            temperature=temperature, top_k=top_k, top_p=top_p
        )

        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
        saver.restore(sess, ckpt)
        texts = []
        for query in querys:
            context_tokens = enc.encode(query)
            generated = 0
            for _ in range(nsamples // batch_size):
                out = sess.run(output, feed_dict={context: [context_tokens for _ in range(batch_size)]}
                              )[:, len(context_tokens):]
                for i in range(batch_size):
                    generated += 1
                    text = enc.decode(out[i])
                    texts.append(text)
        return texts

In [4]:
querys = ["Poland could have to leave the EU over its judicial reform proposals, the country's Supreme Court has warned.",
        "This is the face of a woman who lived 6,000 years ago in Scandinavia.",
        "The Trump administration has said it does not consider the mass killings of Armenians in 1915 to be a genocide, contradicting a unanimous vote by the US Senate.",
        "Toronto's Danforth Avenue shooting victims have launched a class action lawsuit against US gun maker Smith & Wesson.",
        "A man who built an exploding glitter bomb last Christmas which went viral online has added stronger smells and a police soundtrack to his latest version.",
        "A gin company has been ordered to pay Dame Vera Lynn £1,800 in legal costs after losing a case to trademark the singer's name for its drink.",
        "A woman in southern Germany has taken a cheese shop to court over her right to display signs complaining about the smell.",
        "Supreme Court Justice Ruth Bader Ginsburg has responded to Donald Trump's call for the top US court to stop impeachment.",
        "President Donald Trump has lashed out over his impending impeachment in an irate letter to top Democrat Nancy Pelosi, accusing her of declaring 'open war on American democracy'.",
        "The Pope has declared that the rule of 'pontifical secrecy' no longer applies to the sexual abuse of minors, in a bid to improve transparency in such cases."]

In [None]:
# generate 100 text, with the seeds above
print('before starting' + strftime("%H:%M:%S", localtime()))

with open("../generated_texts.txt", "a+") as f:
    for i in range(10):
        for text in interact_model(querys, model_name="774M"):
            f.write(text.replace("\n", " ")+"\n")
            print(str(i)+'. loop: '+ strftime("%H:%M:%S", localtime()))

before starting09:19:08
Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Use tf.cast instead.
Instructions for updating:
Use tf.random.categorical instead.
Instructions for updating:
Use standard file APIs to check for files with this prefix.
INFO:tensorflow:Restoring parameters from ../models/774M/model.ckpt
0. loop: 09:24:54
0. loop: 09:24:54
0. loop: 09:24:54
0. loop: 09:24:54
0. loop: 09:24:54
0. loop: 09:24:54
0. loop: 09:24:54
0. loop: 09:24:54
0. loop: 09:24:54
0. loop: 09:24:54
INFO:tensorflow:Restoring parameters from ../models/774M/model.ckpt
1. loop: 09:30:12
1. loop: 09:30:12
1. loop: 09:30:12
1. loop: 09:30:12
1. loop: 09:30:12
1. loop: 09:30:12
1. loop: 09:30:12
1. loop: 09:30:12
1. loop: 09:30:12
1. loop: 09:30:12
INFO:tensorflow:Restoring parameters from ../models/774M/model.ckpt
2. loop: 09:35:30
2. loop: 09:35:30
2. loop: 09:35:30
2. loop: 09:35:30
2. loop: 09:35:30
2. loop: 09:35:30
2. loop: 09:35:30
2. loop: 09:35:31