<a href="https://colab.research.google.com/github/jamesramsay100/twitter_bot/blob/reformat_notebook/Train_a_GPT_2_Text_Generating_Model_w_GPU.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fine-tune the GPT-2 model and make tweet predictions

## Setup
Always run these cells, whether training or making predictions:
* Imports
* Mount Google drive

In [1]:
%tensorflow_version 1.x
!pip install -q gpt-2-simple
import gpt_2_simple as gpt2
from google.colab import files
import json
import os
import time

TensorFlow 1.x selected.



In [2]:
gpt2.mount_gdrive()

Mounted at /content/drive


## Loading and training the model

In [3]:
def download_and_train(
    train_tweets_file: str,
    model_name: str,
    steps: int,
    run_name: str = None
):
    # Build the default run name at call time; a default argument value
    # would be evaluated only once, when this cell is run
    if run_name is None:
        run_name = f'run_{time.strftime("%Y%m%d_%H%M%S")}'

    # Download the base GPT-2 model unless it is already in models/<model_name>
    if not os.path.exists(f'models/{model_name}'):
        print("\n\nDownloading GPT-2 model...")
        gpt2.download_gpt2(model_name=model_name)

    # Copy the training file from Google Drive into the working directory
    print("\n\nDownloading training file...")
    gpt2.copy_file_from_gdrive(train_tweets_file)

    # Start a TensorFlow session for fine-tuning
    print("\n\nStarting GPT-2 session...")
    sess = gpt2.start_tf_sess()

    # Fine-tune from the pretrained weights ('fresh'), logging the loss
    # every 10 steps and printing a text sample every 200 steps
    print("\n\nBeginning training...")
    gpt2.finetune(
        sess,
        dataset=train_tweets_file,
        model_name=model_name,
        steps=steps,
        restore_from='fresh',
        run_name=run_name,
        print_every=10,
        sample_every=200,
        # save_every=500
    )

    # Save the fine-tuned checkpoint to Google Drive
    print("\n\nSaving model to Google Drive...")
    gpt2.copy_checkpoint_to_gdrive(run_name=run_name)

    sess.close()

    return run_name
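Python evaluates default argument values once, when the `def` statement runs. A timestamp written straight into the signature, such as `run_name = f'run_{time.strftime("%Y%m%d_%H%M%S")}'`, is therefore frozen when the cell is executed, and every later call that omits `run_name` reuses that same name (the call below sidesteps this by passing `run_name` explicitly). Using `None` as the default and building the value inside the function avoids the problem. A minimal stdlib sketch, where a hypothetical counter (`next_stamp`) stands in for the timestamp:

```python
import itertools

_counter = itertools.count()


def next_stamp():
    # Hypothetical stand-in for time.strftime(...): a new value on every call
    return next(_counter)


def eager(run_name=f'run_{next_stamp()}'):
    # The default string was built once, when this def statement ran
    return run_name


def lazy(run_name=None):
    # None is a sentinel: a fresh name is built on every call that omits it
    if run_name is None:
        run_name = f'run_{next_stamp()}'
    return run_name


print(eager(), eager())  # run_0 run_0
print(lazy(), lazy())    # run_1 run_2
```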

In [4]:
trained_model = download_and_train(
    train_tweets_file='tweets_clean.csv',
    model_name='355M',
    steps=8000,
    run_name = f'run_{time.strftime("%Y%m%d_%H%M%S")}'
)

Fetching checkpoint: 1.05Mit [00:00, 363Mit/s]                                                      
Fetching encoder.json: 1.05Mit [00:00, 131Mit/s]                                                    
Fetching hparams.json: 1.05Mit [00:00, 720Mit/s]                                                    



Downloading GPT-2 model...



Fetching model.ckpt.data-00000-of-00001: 1.42Git [00:11, 121Mit/s]                                  
Fetching model.ckpt.index: 1.05Mit [00:00, 301Mit/s]                                                
Fetching model.ckpt.meta: 1.05Mit [00:00, 136Mit/s]                                                 
Fetching vocab.bpe: 1.05Mit [00:00, 154Mit/s]                                                       




Downloading training file...


Starting GPT-2 session...


Beginning training...
Loading checkpoint models/355M/model.ckpt
INFO:tensorflow:Restoring parameters from models/355M/model.ckpt


100%|██████████| 1/1 [00:00<00:00, 36.65it/s]

Loading dataset...





dataset has 516059 tokens
Training...
[10 | 24.45] loss=4.01 avg=4.01
[20 | 42.77] loss=2.51 avg=3.25
[30 | 60.64] loss=2.37 avg=2.96
[40 | 77.79] loss=3.05 avg=2.98
[50 | 94.81] loss=3.77 avg=3.14
[60 | 112.06] loss=2.34 avg=3.00
[70 | 129.54] loss=2.61 avg=2.95
[80 | 147.02] loss=2.51 avg=2.89
[90 | 164.36] loss=2.69 avg=2.87
[100 | 181.63] loss=2.73 avg=2.85
[110 | 198.93] loss=2.56 avg=2.82
[120 | 216.26] loss=2.28 avg=2.78
[130 | 233.62] loss=3.11 avg=2.80
[140 | 250.98] loss=1.97 avg=2.74
[150 | 268.33] loss=2.27 avg=2.71
[160 | 285.65] loss=2.41 avg=2.69
[170 | 302.92] loss=2.77 avg=2.69
[180 | 320.21] loss=1.85 avg=2.64
[190 | 337.53] loss=1.59 avg=2.58
[200 | 354.86] loss=2.08 avg=2.55
|ofc<|endoftext|>
<|startoftext|>so i asked myself im having dinner with my buddy the wolf cub he asked if im a vegetarian<|endoftext|>
<|startoftext|>i would ask the wolf cub if his favorite color is blue if its red if its green it would still choose his correct color<|endoftext|>
<|startoftext
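The training log above prints one progress line every 10 steps (`print_every=10`), in the form `[200 | 354.86] loss=2.08 avg=2.55`. Reading the output, the second number appears to be elapsed seconds since training started, `loss` the loss at that step, and `avg` a running average reported by gpt-2-simple; that interpretation is an assumption based on the log, not on library documentation. A small sketch for pulling these values out of a copied log so the loss curve can be inspected (`parse_training_log` is a hypothetical helper):

```python
import re

# Progress lines look like "[step | elapsed] loss=<value> avg=<running average>"
LOG_LINE = re.compile(r"\[(\d+) \| ([\d.]+)\] loss=([\d.]+) avg=([\d.]+)")


def parse_training_log(text):
    """Return (step, elapsed, loss, avg) tuples for each progress line."""
    rows = []
    for line in text.splitlines():
        match = LOG_LINE.match(line.strip())
        if match:  # skip other output such as "Training..." or text samples
            step, elapsed, loss, avg = match.groups()
            rows.append((int(step), float(elapsed), float(loss), float(avg)))
    return rows


log = """Training...
[10 | 24.45] loss=4.01 avg=4.01
[20 | 42.77] loss=2.51 avg=3.25
[200 | 354.86] loss=2.08 avg=2.55"""

for step, elapsed, loss, avg in parse_training_log(log):
    print(step, elapsed, loss, avg)
```

The per-step `loss` is noisy (it jumps between 1.59 and 3.77 in the first 200 steps), so the `avg` column is the more useful one for judging whether training is still improving.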

## Making predictions
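If training was done in the same runtime, restart the runtime before making predictions; otherwise loading the model raises a `ValueError` because the model already exists in the TensorFlow graph (needs fixing).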

In [2]:
def make_predictions(
    run_name: str,
    num_tweets: int,
    start_phrase: str = None,
    output_file: str = None,
    download: bool = False
):
    # Build the default output name at call time, not when the cell is run
    if output_file is None:
        output_file = f'tweet_predictions_{time.strftime("%Y%m%d_%H%M%S")}.json'

    # Copy the trained checkpoint from Google Drive unless it is already here
    if not os.path.exists(f'checkpoint_{run_name}.tar'):
        print("\n\nDownloading model...")
        gpt2.copy_checkpoint_from_gdrive(run_name=run_name)

    sess = gpt2.start_tf_sess()
    print("\n\nLoading model...")
    try:
        gpt2.load_gpt2(sess, run_name=run_name)
    except ValueError:
        # Raised when the model is already defined in this runtime's graph,
        # e.g. after training in the same runtime; restart the runtime then
        print("Didn't load model as it already exists...")

    # Generate batches of samples until enough complete tweets are collected
    print("\n\nMaking predictions...")
    complete_preds = []
    raw_output = 'tweet_preds_raw_output.txt'
    i = 0
    while len(complete_preds) < num_tweets:
        gpt2.generate_to_file(
            sess,
            run_name=run_name,
            destination_path=raw_output,
            length=500,
            temperature=0.7,
            nsamples=5,
            batch_size=5,
            prefix=start_phrase,
            top_k=40
        )

        # Read this batch of generated text
        with open(raw_output) as f:
            content = f.readlines()

        # Keep only complete tweets (wrapped in both tokens) and strip the tokens
        for entry in content:
            if entry.startswith('<|startoftext|>') and entry.endswith('<|endoftext|>\n'):
                complete_preds.append(entry[len('<|startoftext|>'):-len('<|endoftext|>\n')])

        print(f"\n\nTweet count after iteration {i} : {len(complete_preds)}")
        print(f"Sample tweet : {complete_preds[-1:]}")

        os.remove(raw_output)
        i += 1

    # Save the tweets as a JSON list of strings
    with open(output_file, 'w') as outfile:
        json.dump(complete_preds, outfile)

    if download:
        files.download(output_file)

    return output_file
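The parsing step in `make_predictions` keeps only lines that open with `<|startoftext|>` and close with `<|endoftext|>`, the same delimiting visible in the training samples. Below, that filter is pulled out into a standalone function so it can be checked without loading a model; `extract_complete_tweets` is a hypothetical helper name. One deliberate difference: it strips the newline before checking, so a complete tweet on the last line of a file is still kept if that file does not end with a newline.

```python
START, END = '<|startoftext|>', '<|endoftext|>'


def extract_complete_tweets(lines):
    """Return the text of every line wrapped in both START and END.

    Lines missing either token (a tweet cut off by the generation length,
    the prefix line, or a separator between samples) are dropped.
    """
    tweets = []
    for line in lines:
        line = line.rstrip('\n')  # tolerate a missing final newline
        if line.startswith(START) and line.endswith(END):
            tweets.append(line[len(START):-len(END)])
    return tweets


sample = [
    'Klopp says the prefix line has no start token\n',
    '<|startoftext|>so i asked myself im having dinner<|endoftext|>\n',
    '<|startoftext|>a tweet cut off by the length lim\n',
    '====================\n',
]
print(extract_complete_tweets(sample))  # ['so i asked myself im having dinner']
```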

In [3]:
trained_model="run_20210129_224253"
output_file = make_predictions(
    run_name=trained_model,
    num_tweets=1000,
    start_phrase="Klopp",
    download=True
)



Loading model...
Loading checkpoint checkpoint/run_20210129_224253/model-8000
INFO:tensorflow:Restoring parameters from checkpoint/run_20210129_224253/model-8000


Making predictions...


Tweet count after iteration 0 : 51
Sample tweet : ['I just asked Jordan Henderson if he d like to sit next to me on the bench He said Yeah that would make for a great photo']


Tweet count after iteration 0 : 106
Sample tweet : ['I just asked Jürgen if he was excited for the new season He said Yes I said Are you heeded by the older generation He said Yes I said Genuinely']


Tweet count after iteration 0 : 170
Sample tweet : ['I must say I quite like the sound of West Ham and their theme music Could they have played little bit football we don t know']


Tweet count after iteration 0 : 224
Sample tweet : ['Who said it Brendan Rodgers or Deluded Brendan Question 1 The problem with being a manager is it s like
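`make_predictions` saves every complete tweet it collects, so the JSON file may contain duplicates, empty strings, or tweets too long to post (generation runs for up to 500 tokens). A minimal post-processing sketch, assuming a 280-character limit checked with a plain `len`; Twitter weights some characters differently, so this only approximates the real rule. `load_postable_tweets` is a hypothetical helper, not part of the original notebook.

```python
import json
import os
import tempfile

MAX_TWEET_CHARS = 280  # approximate per-tweet limit (plain character count)


def load_postable_tweets(path, max_chars=MAX_TWEET_CHARS):
    """Load the JSON list written by make_predictions and keep postable tweets.

    Drops blank entries, duplicates and anything longer than max_chars,
    preserving the original order.
    """
    with open(path) as f:
        tweets = json.load(f)
    seen = set()
    kept = []
    for tweet in tweets:
        tweet = tweet.strip()
        if tweet and len(tweet) <= max_chars and tweet not in seen:
            seen.add(tweet)
            kept.append(tweet)
    return kept


# Usage: round-trip a small example file
with tempfile.NamedTemporaryFile('w', suffix='.json', delete=False) as tmp:
    json.dump(['Klopp tweet', 'Klopp tweet', '   ', 'x' * 300, 'Another'], tmp)
print(load_postable_tweets(tmp.name))  # ['Klopp tweet', 'Another']
os.remove(tmp.name)
```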
