# Finetune GPT-2 on Reddit Data

by [Max Woolf](http://minimaxir.com)

A variant of the [default notebook](https://colab.research.google.com/drive/1VLG8e7YSEwypxU-noRNhsv5dW4NfTGce) optimized for short-form titles. It is recommended to be familiar with that notebook before using this one.

This example uses a CSV export of Reddit data via BigQuery (see this post for more information).


In [0]:
!pip install -q gpt-2-simple
import gpt_2_simple as gpt2
from datetime import datetime
from google.colab import files


The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.



## GPU

In [0]:
!nvidia-smi

Thu Nov 14 14:56:19 2019       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 430.50       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   49C    P8    29W / 149W |      0MiB / 11441MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|  No ru

## Downloading GPT-2

The default query returns 1.3MB of data, so you should probably only finetune the `124M` GPT-2 model. If you are working with more Reddit data, migrate to `355M`.
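If you want to make that rule of thumb explicit, here is a minimal sketch. `pick_model_name` and `pick_model_for_file` are hypothetical helpers, not part of gpt-2-simple. The size threshold is an assumption you have to choose yourself, because the notebook only says that about 1.3MB suits `124M`.

```python
import os


def pick_model_name(size_bytes: int, threshold_mb: float) -> str:
    # Hypothetical rule of thumb: small datasets -> "124M", larger ones -> "355M".
    # threshold_mb is an assumption to tune yourself; the notebook only says
    # ~1.3MB of data suits 124M and "more" data can justify 355M.
    size_mb = size_bytes / 1_000_000  # bytes -> decimal megabytes
    return "124M" if size_mb <= threshold_mb else "355M"


def pick_model_for_file(path: str, threshold_mb: float) -> str:
    # Read the dataset's size from disk and apply the rule above.
    return pick_model_name(os.path.getsize(path), threshold_mb)
```

For example, `gpt2.download_gpt2(model_name=pick_model_for_file(file_name, threshold_mb=5.0))`. Here `5.0` is an arbitrary illustrative value. Whatever you choose, pass the same `model_name` to `finetune` later.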

In [0]:
gpt2.download_gpt2(model_name="355M")

Fetching checkpoint: 1.05Mit [00:00, 551Mit/s]                                                      
Fetching encoder.json: 1.05Mit [00:00, 106Mit/s]                                                    
Fetching hparams.json: 1.05Mit [00:00, 356Mit/s]                                                    
Fetching model.ckpt.data-00000-of-00001: 1.42Git [00:11, 122Mit/s]                                  
Fetching model.ckpt.index: 1.05Mit [00:00, 367Mit/s]                                                
Fetching model.ckpt.meta: 1.05Mit [00:00, 108Mit/s]                                                 
Fetching vocab.bpe: 1.05Mit [00:00, 227Mit/s]                                                       


## Mounting Google Drive

In [0]:
gpt2.mount_gdrive()

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive


## Uploading a Text File to Colaboratory for Training

A single-column CSV is expected.
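If your data isn't already in that shape, you can use Python's standard `csv` module to write one text (for example, one Reddit title) per row, as in the minimal sketch below. The helper name and the `title` header are assumptions. The header row is included on the assumption that the loader treats the first line as a header. Quoting is handled by `csv.writer`, so titles containing commas remain a single column.

```python
import csv
import io


def write_single_column_csv(texts, fp, header="title"):
    # Write a header row, then one text per row in a single column.
    # The header name is arbitrary; it is assumed the loader skips line 1.
    writer = csv.writer(fp)
    writer.writerow([header])
    for text in texts:
        writer.writerow([text])


if __name__ == "__main__":
    buf = io.StringIO()
    write_single_column_csv(["First title", "Second, with a comma"], buf)
    print(buf.getvalue())
```

When writing a real file, open it with `open(path, "w", newline="", encoding="utf-8")` so the `csv` module controls line endings.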

In [0]:
file_name = "WritingPrompts50KPlusAI 10-23-2019.csv"

If your text file is larger than 10MB, it is recommended to upload it to Google Drive first, then copy it from Google Drive to the Colaboratory VM.
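Copying from Drive is a plain file copy from the mounted Drive to the VM's local disk. The sketch below shows roughly what that involves using `shutil`. It is not gpt-2-simple's implementation. The default `drive_root` assumes Drive is mounted at `/content/drive` (as the mount output above reports) and that the file sits in the top level of My Drive. `copy_from_drive` and its parameters are hypothetical.

```python
import os
import shutil


def copy_from_drive(file_name: str,
                    drive_root: str = "/content/drive/My Drive",
                    dest_dir: str = ".") -> str:
    # Assumed layout: the file lives directly under drive_root on the mounted Drive.
    src = os.path.join(drive_root, file_name)
    if not os.path.exists(src):
        raise FileNotFoundError(f"{src} not found; is Google Drive mounted?")
    # shutil.copyfile returns the destination path.
    return shutil.copyfile(src, os.path.join(dest_dir, file_name))
```

Reading from the local copy during training avoids repeatedly reading the large file over the Drive mount.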

In [0]:
gpt2.copy_file_from_gdrive(file_name)

## Finetune GPT-2

When the dataset is a single-column CSV, `finetune` automatically wraps each row in `<|startoftext|>` and `<|endoftext|>` tokens.
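Conceptually, that preprocessing turns each CSV row into a single delimited training example. The sketch below approximates it under two assumptions: the first line is a header and is skipped, and only the first column is used. `wrap_csv_rows` is a hypothetical name, not gpt-2-simple's API.

```python
import csv
import io

START = "<|startoftext|>"
END = "<|endoftext|>"


def wrap_csv_rows(fp) -> str:
    # Approximation only: skip the header line, then wrap the first column of
    # every non-empty row in start/end tokens, one example per line.
    fp.readline()  # assumed header row
    return "".join(START + row[0] + END + "\n" for row in csv.reader(fp) if row)


if __name__ == "__main__":
    print(wrap_csv_rows(io.StringIO("title\nHello\nWorld\n")))
```

The explicit start and end tokens let the model learn where a title begins and ends. The generation cells below depend on this.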

Short-form text is more likely to overfit, so train for fewer steps than you would for long-form content.
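To put a step count in perspective, you can estimate how many times the model sees the whole dataset. This rough arithmetic assumes that each step draws `batch_size` random 1024-token windows, with `batch_size=1` assumed to be the `finetune` default. The 1,065,179-token figure comes from the training log below. `approx_passes` is a hypothetical helper.

```python
def approx_passes(steps: int, dataset_tokens: int,
                  batch_size: int = 1, window: int = 1024) -> float:
    # Tokens drawn during training divided by dataset size, under the
    # assumption of batch_size random windows of `window` tokens per step.
    return steps * batch_size * window / dataset_tokens


print(round(approx_passes(5000, 1_065_179), 2))  # 4.81
```

By this estimate, 5,000 steps is close to five passes over this dataset. If samples start echoing training titles verbatim, fewer steps may be worth trying.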

In [0]:
sess = gpt2.start_tf_sess()

gpt2.finetune(sess,
              dataset=file_name,
              model_name='355M',
              steps=5000,
              restore_from='latest',
              run_name='run42K',
              print_every=10,
              sample_every=5000,
              )

Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Instructions for updating:
Please use tensorflow.python.ops.op_selector.get_backward_walk_ops.
Loading checkpoint models/355M/model.ckpt
INFO:tensorflow:Restoring parameters from models/355M/model.ckpt


100%|██████████| 1/1 [00:00<00:00, 12.09it/s]

Loading dataset...





dataset has 1065179 tokens
Training...
[10 | 44.49] loss=2.14 avg=2.14
[20 | 75.63] loss=2.30 avg=2.22
[30 | 106.76] loss=2.49 avg=2.31
[40 | 137.89] loss=2.39 avg=2.33
[50 | 169.04] loss=2.65 avg=2.40
[60 | 200.18] loss=2.38 avg=2.39
[70 | 231.34] loss=2.45 avg=2.40
[80 | 262.41] loss=2.32 avg=2.39
[90 | 293.48] loss=2.37 avg=2.39
[100 | 324.54] loss=2.36 avg=2.39
[110 | 355.62] loss=2.57 avg=2.40
[120 | 386.69] loss=2.00 avg=2.37
[130 | 417.76] loss=2.23 avg=2.36
[140 | 448.87] loss=2.49 avg=2.37
[150 | 479.96] loss=2.12 avg=2.35
[160 | 511.05] loss=2.13 avg=2.33
[170 | 542.13] loss=2.65 avg=2.35
[180 | 573.24] loss=2.20 avg=2.35
[190 | 604.35] loss=2.30 avg=2.34
[200 | 635.51] loss=2.32 avg=2.34
[210 | 666.55] loss=2.37 avg=2.34
[220 | 697.67] loss=2.55 avg=2.35
[230 | 728.77] loss=2.37 avg=2.35
[240 | 759.89] loss=2.63 avg=2.37
[250 | 791.04] loss=2.40 avg=2.37
[260 | 822.09] loss=2.16 avg=2.36
[270 | 853.19] loss=2.25 avg=2.35
[280 | 884.38] loss=2.31 avg=2.35
[290 | 915.55] loss=

In [0]:
gpt2.copy_checkpoint_to_gdrive(run_name='run42K')

## Load a Trained Model Checkpoint

In [0]:
gpt2.copy_checkpoint_from_gdrive(run_name='run42K')

In [0]:
sess = gpt2.start_tf_sess()
gpt2.load_gpt2(sess, run_name='run42K')

Loading checkpoint checkpoint/run42K/model-5000
INFO:tensorflow:Restoring parameters from checkpoint/run42K/model-5000


## Generate Text From The Trained Model

Generation works the same as in the default notebook, except that the `prefix`, `truncate`, and `include_prefix` parameters are used to handle the new tokens.
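The sketch below simplifies what `prefix`, `truncate`, and `include_prefix` do to a raw generation. The text is cut at the first `truncate` token. When `include_prefix=False`, the leading prefix is dropped as well. This is an approximation for illustration, not gpt-2-simple's exact code. `postprocess` is a hypothetical name. Returning the text unchanged when no `truncate` token appears (for example, when `length` runs out first) is an assumption.

```python
import re


def postprocess(gen_text: str, prefix: str, truncate: str, include_prefix: bool) -> str:
    # Keep everything up to the first truncate token; optionally require and
    # strip the prefix at the start of the match.
    if include_prefix:
        pattern = "(.*?)(?:{})".format(re.escape(truncate))
    else:
        pattern = "(?:{})(.*?)(?:{})".format(re.escape(prefix), re.escape(truncate))
    match = re.search(pattern, gen_text, re.S)
    return match.group(1) if match else gen_text
```

With `prefix="<|startoftext|>"` and `include_prefix=False`, you get a clean title with no special tokens. With a `[WP] ...` prefix and `include_prefix=True`, the prompt stays at the front of each sample.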

In [0]:
gpt2.generate(sess, run_name='run42K',
              length=100,
              prefix="<|startoftext|>",
              truncate="<|endoftext|>",
              include_prefix=False)

In [0]:
gpt2.generate(sess, run_name='run42K',
              length=100,
              temperature=.7,
              nsamples=10,
              batch_size=10,
              prefix="[WP] You discover an alternate reality",
              truncate="<|endoftext|>",
              include_prefix=True
              )

[WP] You discover an alternate reality where you are the "chosen one" and the council of gods keeps changing the world to fit your fancy. Unfortunately, the competition is pretty boring so you settle on this one...fake?
[WP] You discover an alternate reality where you are rich and famous for your charitable giving, and people are genuinely touched by your generosity. A genuine sense of joy and satisfaction builds up in your heart. This is your first "life".
[WP] You discover an alternate reality where humans and other animals live together in harmony. You're all set to become extinct but the species that first separated humans from the rest of the world manages to persist and gain a foothold in another part of the universe. This is your life now."
[WP] You discover an alternate reality where people can bank sleep and never wake up to it.com is a very real and very convenient method for getting things done in the real world.
[WP] You discover an alternate reality where humans ruled the 

If generating in bulk, you may want to set `sample_delim=''` to remove the delimiter between samples.
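Roughly, each sample is written followed by a newline and the delimiter. An empty delimiter therefore leaves the samples back to back, separated only by a newline. In this sketch `join_samples` is hypothetical, and the default delimiter (a line of 20 `=` characters) is an assumption about gpt-2-simple's default.

```python
def join_samples(samples, sample_delim="=" * 20 + "\n"):
    # Each sample, then a newline, then the delimiter (assumed default shown).
    return "".join("{}\n{}".format(s, sample_delim) for s in samples)


print(join_samples(["First title", "Second title"], sample_delim=""))
```

Dropping the delimiter makes the output file easier to load as one title per line, provided the samples contain no internal newlines.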

In [0]:
gen_file = 'gpt2_gentext_{:%Y%m%d_%H%M%S}.txt'.format(datetime.utcnow())

gpt2.generate_to_file(sess,
                      destination_path=gen_file,
                      length=100,
                      temperature=1.0,
                      nsamples=100,
                      batch_size=20,
                      prefix="<|startoftext|>",
                      truncate="<|endoftext|>",
                      include_prefix=False,
                      sample_delim=''
                      )
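The output filename includes the UTC time at which the cell ran, so repeated runs don't overwrite each other. Applying the same format string to a fixed, arbitrary timestamp shows the resulting name:

```python
from datetime import datetime

# Same format string as the cell above, applied to a fixed timestamp so the
# result is reproducible; str.format passes the spec to datetime.strftime.
gen_file = "gpt2_gentext_{:%Y%m%d_%H%M%S}.txt".format(datetime(2019, 11, 14, 14, 56, 19))
print(gen_file)  # gpt2_gentext_20191114_145619.txt
```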

In [0]:
# may have to run twice to get file to download
files.download(gen_file)

# Etcetera

If the notebook has errors (e.g. GPU Sync Fail), force-kill the Colaboratory virtual machine and restart it with the command below:

In [0]:
!kill -9 -1

# LICENSE

MIT License

Copyright (c) 2019 Max Woolf

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.