In this notebook I load OpenAI's GPT-2 model and finetune it on three datasets: scientific abstracts, news and novels. I save the finetuned models (checkpoints) and then generate chunks of text with them for further analysis. I rely on the *gpt-2-simple* library for finetuning and text generation.

## Loading the baseline GPT-2 model and datasets

In [None]:
%tensorflow_version 1.x
!pip install -q gpt-2-simple
import gpt_2_simple as gpt2
from datetime import datetime
from google.colab import files

TensorFlow 1.x selected.

First, I download the model. I use the small version, with 124M parameters, for simplicity.

In [None]:
gpt2.download_gpt2(model_name="124M")

Fetching checkpoint: 1.05Mit [00:00, 603Mit/s]                                                      
Fetching encoder.json: 1.05Mit [00:00, 3.23Mit/s]
Fetching hparams.json: 1.05Mit [00:00, 325Mit/s]                                                    
Fetching model.ckpt.data-00000-of-00001: 498Mit [00:14, 34.5Mit/s]
Fetching model.ckpt.index: 1.05Mit [00:00, 385Mit/s]                                                
Fetching model.ckpt.meta: 1.05Mit [00:00, 3.89Mit/s]
Fetching vocab.bpe: 1.05Mit [00:00, 3.86Mit/s]


I mount Google Drive so that I can load the datasets now and save the checkpoints later:

In [None]:
gpt2.mount_gdrive()

Mounted at /content/drive


In [None]:
for d in ['abstracts', 'news', 'novels']:
  gpt2.copy_file_from_gdrive('data_' + d + '.txt')

## Finetuning the model on a custom dataset

For finetuning I use the Adam optimizer with a learning rate of 0.0001 and train for 400 steps. These and the other parameters were chosen based on the model's performance.
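
To get a feel for how much data 400 steps cover, here is a rough sketch. It assumes that each step consumes one 1024-token window, which is what I understand gpt-2-simple to do with its default batch size of 1; that figure is an assumption and is not set anywhere in this notebook. The token count is the one the training log reports for the news dataset further down, and the helper name `approx_passes` is hypothetical.

```python
def approx_passes(steps, dataset_tokens, tokens_per_step=1024):
  """Roughly how many times training sees the dataset.

  tokens_per_step=1024 is an assumption about the library's defaults.
  """
  return steps * tokens_per_step / dataset_tokens

# 400 steps on the news dataset (210264 tokens according to the training log)
passes = approx_passes(steps=400, dataset_tokens=210264)
print(round(passes, 2))
```

Under these assumptions the run covers the news dataset a little under two times, so 400 steps is a fairly light finetune.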



In [None]:
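def finetune_on_dataset(d):
  """Finetune the 124M model on data_<d>.txt and copy the checkpoint to Google Drive."""
  session = gpt2.start_tf_sess()
  gpt2.finetune(session,
                dataset='data_' + d + '.txt',
                model_name='124M',
                steps=400,
                restore_from='latest',   # resume from run_<d> if a checkpoint exists
                run_name='run_' + d,
                learning_rate=0.0001,
                optimizer='adam',
                print_every=10,          # log the loss every 10 steps
                sample_every=100,        # print a sample every 100 steps
                sample_length=900,       # length of each printed sample, in tokens
                save_every=400           # save a checkpoint at the end of the run
                )

  gpt2.copy_checkpoint_to_gdrive(run_name='run_' + d)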

In [None]:
# d is one of: 'abstracts', 'novels', 'news'
finetune_on_dataset('news')

Loading checkpoint models/124M/model.ckpt
INFO:tensorflow:Restoring parameters from models/124M/model.ckpt


  0%|          | 0/1 [00:00<?, ?it/s]
Loading dataset...
100%|██████████| 1/1 [00:01<00:00,  1.39s/it]

dataset has 210264 tokens
Training...
[10 | 28.76] loss=2.45 avg=2.45
[20 | 51.72] loss=2.67 avg=2.56
[30 | 74.90] loss=2.92 avg=2.68
[40 | 97.77] loss=2.73 avg=2.69
[50 | 120.59] loss=2.20 avg=2.59
[60 | 143.58] loss=2.45 avg=2.57
[70 | 166.57] loss=2.01 avg=2.49
[80 | 189.47] loss=2.12 avg=2.44
[90 | 212.38] loss=1.90 avg=2.38
[100 | 235.35] loss=1.86 avg=2.32
 of the European Union and the Netherlands.

The European Union negotiated a free trade agreement with 11 countries on Thursday night, which the United States administration hailed as an important step toward a better trade deal. Brazil and Mexico announced that they would join the deal, with the United Kingdom announcing on Tuesday it would be part of the deal by the end of the year.

For more than a century Mexico has been Mexico City’s most central trading partner. It’s the only city in the United States that lacks a publicly traded port, allowing it to export goods that come from countries such
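
The progress lines in the log above follow the pattern `[step | elapsed seconds] loss=... avg=...`. Here is a minimal sketch for turning them into numbers, for example to plot the loss or to estimate the time per step. The helper `parse_log` and the regular expression are my own and are not part of gpt-2-simple; the sample lines are copied from the log above.

```python
import re

# Matches lines such as "[10 | 28.76] loss=2.45 avg=2.45"
LOG_LINE = re.compile(r"\[(\d+) \| ([\d.]+)\] loss=([\d.]+) avg=([\d.]+)")

def parse_log(text):
  """Return a list of (step, seconds, loss, avg_loss) tuples found in `text`."""
  rows = []
  for m in LOG_LINE.finditer(text):
    step, secs, loss, avg = m.groups()
    rows.append((int(step), float(secs), float(loss), float(avg)))
  return rows

log = """[10 | 28.76] loss=2.45 avg=2.45
[50 | 120.59] loss=2.20 avg=2.59
[100 | 235.35] loss=1.86 avg=2.32"""

rows = parse_log(log)
first, last = rows[0], rows[-1]
# Average wall-clock time per step between the first and last parsed line
secs_per_step = (last[1] - first[1]) / (last[0] - first[0])
print(len(rows), round(secs_per_step, 2))
```

On these three lines the estimate comes out at roughly 2.3 seconds per step.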

## Loading a trained model and generating the text


In [None]:
def load_trained(d, s):
  """Copy the checkpoint for dataset `d` from Google Drive and load it into session `s`."""
  gpt2.copy_checkpoint_from_gdrive(run_name='run_' + d)
  gpt2.load_gpt2(s, run_name='run_' + d)


In [None]:
def generate(d, s):
  """Write samples from the model finetuned on `d` to gen_<d>*.txt, using session `s`."""
  run, out = 'run_' + d, 'gen_' + d
  # 100 unprompted samples, then 10 samples for each prompt (file suffix -> prompt).
  gpt2.generate_to_file(s, run_name=run, destination_path=out + '.txt', temperature=0.7, nsamples=100, length=100)
  prompts = {'abs': "Complex classification problems",
             'new': "Many argue that european policy of",
             'nov': "My eyes filled with tears",
             'neutral': "It is surprising"}
  for suffix, prefix in prompts.items():
    gpt2.generate_to_file(s, run_name=run, destination_path=out + '_' + suffix + '.txt',
                          temperature=0.7, nsamples=10, length=100, prefix=prefix)
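
The generation calls above all use a temperature of 0.7. As a reminder of what that parameter does, here is a small sketch of a temperature-scaled softmax: the logits are divided by the temperature before normalising, so values below 1 sharpen the distribution towards the most likely token. The logits are made up for illustration, and this is a simplified picture rather than gpt-2-simple's actual sampling code.

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
  """Convert logits to probabilities after dividing them by `temperature`."""
  scaled = [x / temperature for x in logits]
  peak = max(scaled)  # subtract the maximum for numerical stability
  exps = [math.exp(x - peak) for x in scaled]
  total = sum(exps)
  return [e / total for e in exps]

logits = [2.0, 1.0, 0.0]  # hypothetical scores for three candidate tokens
p_plain = softmax_with_temperature(logits, temperature=1.0)
p_sharp = softmax_with_temperature(logits, temperature=0.7)
print([round(p, 3) for p in p_plain])
print([round(p, 3) for p in p_sharp])
```

At 0.7 the top token gets more probability mass than at 1.0 and the least likely token gets less, which tends to make the generated text more conservative.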

In [None]:
session = gpt2.start_tf_sess()
load_trained('news', session)
generate('news', session)

Loading checkpoint checkpoint/run_news/model-400
INFO:tensorflow:Restoring parameters from checkpoint/run_news/model-400
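
For the later analysis, the generated files need to be split back into individual samples. The sketch below assumes that samples are separated by a line of 20 `=` characters, which I believe is the default delimiter of `generate_to_file`; please check this against one of the actual files. The helper `split_samples` is hypothetical, and the demo string is a made-up stand-in for a file's contents.

```python
def split_samples(text, delim="=" * 20):
  """Split generated text into individual samples, dropping empty pieces."""
  return [s.strip() for s in text.split(delim) if s.strip()]

# Made-up stand-in for the contents of a generated file
demo = "First sample.\n" + "=" * 20 + "\nSecond sample.\n" + "=" * 20 + "\n"
samples = split_samples(demo)
print(len(samples))
```

If the delimiter assumption holds, reading `gen_news.txt` and passing its contents to `split_samples` should return 100 samples.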


In [None]:
gpt2.generate(session, run_name='run_news', temperature=0.7, nsamples=10, length=100, prefix="Complex classification problems")

Complex classification problems mean that even information about a child’s social class — a proxy for whether he or she is social — is not exact. Experts say there is scope for misclassification even in very poor countries. In a survey of 1,500 adults conducted for a newspaper in the Democratic Republic of the Congo in March, FIOS found that nearly three-quarters of the respondents had difficulty establishing their social class based on only a limited view of their nearest neighbors. Nearly half said they had difficulty establishing a
Complex classification problems allow us to store information about the people we want to interview at a time, but not what they say, except perhaps when they “” “SLAM!” or something along those lines. To make a request, contact a source listed below or come across our San Francisco office via our mobile app. We love hearing what you think about programming and ad tech, so feel free to let us know what you think. Also, if you decide to attend the conferen