# Finetune GPT-2 on Reddit Data

by [Max Woolf](http://minimaxir.com)

A variant of the [default notebook](https://colab.research.google.com/drive/1VLG8e7YSEwypxU-noRNhsv5dW4NfTGce) optimized for short-form titles. You should be familiar with that notebook before using this one.

This example uses a CSV export of Reddit data via BigQuery (see this post for more information).


In [0]:
!pip install -q gpt-2-simple
import gpt_2_simple as gpt2
from datetime import datetime
from google.colab import files

The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.



## GPU

In [0]:
!nvidia-smi

Tue Nov 12 21:41:41 2019       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 430.50       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   54C    P0    31W / 250W |      0MiB / 16280MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|  No ru

## Downloading GPT-2

The default query returns 1.3MB of data, so you should probably finetune only the `124M` GPT-2 model. If you are working with more Reddit data, migrate to `355M`.

In [0]:
gpt2.download_gpt2(model_name="355M")

Fetching checkpoint: 1.05Mit [00:00, 374Mit/s]                                                      
Fetching encoder.json: 1.05Mit [00:00, 86.5Mit/s]                                                   
Fetching hparams.json: 1.05Mit [00:00, 457Mit/s]                                                    
Fetching model.ckpt.data-00000-of-00001: 1.42Git [00:07, 201Mit/s]                                  
Fetching model.ckpt.index: 1.05Mit [00:00, 358Mit/s]                                                
Fetching model.ckpt.meta: 1.05Mit [00:00, 85.2Mit/s]                                                
Fetching vocab.bpe: 1.05Mit [00:00, 110Mit/s]                                                       


## Mounting Google Drive

In [0]:
gpt2.mount_gdrive()

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive


## Uploading a Text File to Colaboratory for Training

A single-column CSV is expected.

In [0]:
file_name = "AI Responses 10-23-2019.txt"
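If your titles are still in a Python list, a single-column CSV can be produced with the standard `csv` module. This is a minimal sketch: the helper name `write_single_column_csv` and the `title` header are hypothetical, and it assumes the loader treats the first row as a header and recognizes the file by its `.csv` extension, which you should verify against your installed version of gpt-2-simple.

```python
import csv


def write_single_column_csv(titles, path):
    """Write one text per row in a single column, preceded by a header row.

    Assumption: the training loader skips the first row as a header.
    """
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["title"])  # header row (hypothetical column name)
        for title in titles:
            # csv.writer quotes commas, quotes, and newlines inside a title
            writer.writerow([title])
```

Calling `write_single_column_csv(titles, "reddit_titles.csv")` and then setting `file_name` to that path would fit the cell above.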

If your text file is larger than 10MB, it is recommended to upload that file to Google Drive first, then copy that file from Google Drive to the Colaboratory VM.

In [0]:
gpt2.copy_file_from_gdrive(file_name)
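The 10MB guideline above can be checked before choosing an upload route. The sketch below is only an illustration: `should_use_gdrive` and `SIZE_LIMIT_MB` are hypothetical names, not part of gpt-2-simple.

```python
import os

SIZE_LIMIT_MB = 10  # guideline from the text above


def should_use_gdrive(path, limit_mb=SIZE_LIMIT_MB):
    """Return True if the local file at `path` is larger than `limit_mb` megabytes."""
    size_mb = os.path.getsize(path) / (1024 * 1024)
    return size_mb > limit_mb
```

Run this on your local copy of the dataset. If it returns `True`, upload the file to Google Drive and copy it over as in the cell above; otherwise a direct upload to the Colaboratory VM is likely fine.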

## Finetune GPT-2

Providing a single-column CSV will automatically add `<|startoftext|>` and `<|endoftext|>` tokens appropriately.

Short-form text is more likely to overfit, so train it with fewer steps than you would for long-form content.
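To make the token wrapping concrete, the sketch below shows the shape of the text the model ends up training on. It illustrates the format only and is not the library's actual loading code; `wrap_rows` is a hypothetical helper.

```python
START_TOKEN = "<|startoftext|>"
END_TOKEN = "<|endoftext|>"


def wrap_rows(rows):
    """Wrap each row in start/end tokens, one wrapped row per line."""
    return "".join(START_TOKEN + row + END_TOKEN + "\n" for row in rows)


print(wrap_rows(["First title", "Second title"]))
```

Because every document is delimited this way, the model learns where a title begins and ends, which is what lets generation be cut off cleanly at `<|endoftext|>` later.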

In [0]:
sess = gpt2.start_tf_sess()

gpt2.finetune(sess,
              dataset=file_name,
              model_name='355M',
              steps=9000,
              restore_from='latest',
              run_name='run42K',
              print_every=10,
              sample_every=9000,
              )

Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Instructions for updating:
Please use tensorflow.python.ops.op_selector.get_backward_walk_ops.
Loading checkpoint models/355M/model.ckpt
INFO:tensorflow:Restoring parameters from models/355M/model.ckpt


  0%|          | 0/1 [00:00<?, ?it/s]

Loading dataset...


100%|██████████| 1/1 [00:11<00:00, 11.09s/it]


dataset has 1947763 tokens
Training...
[10 | 18.15] loss=2.88 avg=2.88
[20 | 26.85] loss=3.49 avg=3.19
[30 | 35.59] loss=3.00 avg=3.12
[40 | 44.30] loss=2.79 avg=3.04
[50 | 53.00] loss=3.21 avg=3.08
[60 | 61.71] loss=2.52 avg=2.98
[70 | 70.37] loss=3.11 avg=3.00
[80 | 79.08] loss=3.24 avg=3.03
[90 | 87.75] loss=2.90 avg=3.02
[100 | 96.46] loss=2.74 avg=2.99
[110 | 105.14] loss=3.16 avg=3.00
[120 | 113.84] loss=3.24 avg=3.03
[130 | 122.54] loss=3.04 avg=3.03
[140 | 131.22] loss=3.15 avg=3.04
[150 | 139.90] loss=2.96 avg=3.03
[160 | 148.59] loss=2.85 avg=3.02
[170 | 157.28] loss=3.13 avg=3.02
[180 | 165.96] loss=2.60 avg=3.00
[190 | 174.68] loss=3.18 avg=3.01
[200 | 183.36] loss=2.19 avg=2.96
[210 | 192.05] loss=2.53 avg=2.94
[220 | 200.74] loss=3.51 avg=2.97
[230 | 209.43] loss=3.32 avg=2.99
[240 | 218.15] loss=2.83 avg=2.98
[250 | 226.86] loss=3.10 avg=2.98
[260 | 235.56] loss=2.47 avg=2.96
[270 | 244.25] loss=3.27 avg=2.98
[280 | 252.92] loss=2.90 avg=2.97
[290 | 261.60] loss=3.07 avg

In [0]:
gpt2.copy_checkpoint_to_gdrive(run_name='run42K')

## Load a Trained Model Checkpoint

In [0]:
gpt2.copy_checkpoint_from_gdrive(run_name='run42K')

In [0]:
sess = gpt2.start_tf_sess()
gpt2.load_gpt2(sess, run_name='run42K')

Loading checkpoint checkpoint/run42K/model-9000
INFO:tensorflow:Restoring parameters from checkpoint/run42K/model-9000


## Generate Text From The Trained Model

These calls work the same as the normal generate functions, with additional parameters (`prefix`, `truncate`, and `include_prefix`) to handle the new tokens.
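As a rough mental model of what the token-handling parameters do to each raw sample, consider the simplified sketch below. It is an approximation written for illustration, not the library's implementation, and `postprocess_sample` is a hypothetical name.

```python
def postprocess_sample(text, prefix=None, truncate=None, include_prefix=True):
    """Approximate the effect of `truncate` and `include_prefix` on one sample."""
    # Drop the prompt from the front when the caller does not want it echoed.
    if prefix and not include_prefix and text.startswith(prefix):
        text = text[len(prefix):]
    # Keep only the text before the first occurrence of the truncation token.
    if truncate and truncate in text:
        text = text[:text.index(truncate)]
    return text


raw = "<|startoftext|>A short title<|endoftext|><|startoftext|>Leftover"
print(postprocess_sample(raw, prefix="<|startoftext|>",
                         truncate="<|endoftext|>", include_prefix=False))
```

With `include_prefix=True`, as in the cells below, the prompt (including `<|startoftext|>`) stays at the start of the returned text.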

In [0]:
gpt2.generate(sess, run_name='run42K',
             length=400,
             prefix="<|startoftext|>[WP] You upload your brain to a computer and find a document that says the word 'God' in the title.  [RESPONSE]",
             truncate="<|endoftext|>",
             include_prefix=True)

<|startoftext|>[WP] You upload your brain to a computer and find a document that says the word 'God' in the title.  [RESPONSE] They did not want to have their own God, but it was that title which made them unhappy. 
 
I was born an ordinary man, with a simple goal: to love his neighbor as he loved himself.  That was the middle-class upbringing my parents wanted for me.  If I had been raised in a gladiator's camp, a brothel, or a slaver's den, God would have frowned upon me.  But I was a soldier, fighting the disease that was plaguing my home planet.  That was the war.  If I had not been a soldier, fighting the disease, I would have been nothing but an observer.  My talents did not earn me favor with the Supreme Ruler, but He saw a man with talent like mine, and He saw a man with skill like mine.  That is how the conversation went when I asked to be his student.  It was the seventh grade when I asked to be a scout, and the Supreme Ruler was pleased with my performance.  He told me that 

In [0]:
gpt2.generate(sess, run_name='run42K',
              length=400,
              temperature=.7,
              nsamples=10,
              batch_size=10,
              prefix="<|startoftext|>[WP] Your alternative universe self has decided to tell you about the world that she built, and you're all her imagination. [RESPONSE]",
              truncate="<|endoftext|>",
              include_prefix=True
              )

<|startoftext|>[WP] Your alternative universe self has decided to tell you about the world that she built, and you're all her imagination. [RESPONSE]

The book was heavy, when I first picked it up.

It had to be heavy because it was a brand new, completely empty, book. That was a bit disappointing. It kind of reminded me of Harry Potter, which I had to say is probably a bit similar.

I suppose it shouldn't matter because I just want to read it.

I picked it up, opened it, and turned off the lights. I'm not sure why, but the book seemed to be implying that time had left me and that I was an immortal time traveler.

And that's when I noticed that the pages were blank.

I felt a bit of unease when I realized that I couldn't remember how I got here or what I was doing there. I was stuck in a loop, back in the library where I had gotten lost.

But then, I noticed that the loop was getting longer by the second page.

It was like I had been pulled into the loop, into the same physical place t

If generating in bulk, you may want to set `sample_delim=''` to remove the delimiter between each sample.

In [0]:
gen_file = 'gpt2_gentext_{:%Y%m%d_%H%M%S}.txt'.format(datetime.utcnow())

gpt2.generate_to_file(sess,
                      destination_path=gen_file,
                      length=100,
                      temperature=1.0,
                      nsamples=100,
                      batch_size=20,
                      prefix="<|startoftext|>",
                      truncate="<|endoftext|>",
                      include_prefix=False,
                      sample_delim=''
                      )

In [0]:
# may have to run twice to get file to download
files.download(gen_file)
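Once the file is generated, the samples can be loaded back into a list for review. This sketch assumes each sample occupies a single line, which is reasonable for short titles written with an empty delimiter but not for multi-paragraph responses. `load_generated_titles` is a hypothetical helper.

```python
def load_generated_titles(path):
    """Read one sample per line, dropping blank lines and exact duplicates."""
    seen = set()
    titles = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            title = line.strip()
            if title and title not in seen:
                seen.add(title)
                titles.append(title)  # keep first-seen order
    return titles
```

For example, `load_generated_titles(gen_file)` would return the unique generated titles in the order they were written.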

# Etcetera

If the notebook has errors (e.g. GPU Sync Fail), force-kill the Colaboratory virtual machine and restart it with the command below:

In [0]:
!kill -9 -1

# LICENSE

MIT License

Copyright (c) 2019 Max Woolf

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.