# Setting things up
The following cell installs all the necessary dependencies.

In [19]:
!pip3 install --user -r requirements.txt

Note: you may need to restart the kernel to use updated packages.


Import the packages needed for this task.

In [1]:
import gpt_2_simple as gpt2
import json
import os
import sys
import numpy as np
import argparse
import requests
import glob
import pickle
import pandas as pd
import re
import unicodedata
import csv
import zipfile
import logging
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s [%(levelname)-5.5s] [%(name)-12.12s]: %(message)s')
log = logging.getLogger(__name__)

The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.



In [2]:
run_name='run2'
data_path ='data/shakespeare.txt'
steps=1
length = 600 
temperature = 0.7
top_k = 0

We define a helper function that reads the code samples stored in a zip file, cleans them, and writes them into one text file.

In [3]:
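# helper to use code samples in zip file
def process_zip(name, regs, postfix, data_dir, min_length, max_length, preserve_form, num_samples):
    """Write up to num_samples cleaned files from <data_dir>/<name>.zip into one text file.

    Note: the output folder comes from the global `output_dir` defined in the next cell.
    """
    with open(os.path.join(output_dir, name + postfix + '.txt'), 'w+') as fh:
        with zipfile.ZipFile(os.path.join(data_dir, name + '.zip'), 'r') as z:
            cnt = 0
            for entry in z.namelist():
                text = z.read(entry).decode('utf-8')
                # apply each cleaning regex in insertion order
                for reg, sub in regs.items():
                    text = re.sub(reg, sub, text, flags=re.DOTALL)
                # keep only files whose cleaned length lies in (min_length, max_length]
                if len(text) > min_length and len(text) <= max_length:
                    sample = text.strip() + "\n"
                    if preserve_form == 'true':
                        sample += "\n\n"  # blank lines separate samples when the form is kept
                    fh.write(sample)
                    cnt += 1
                if cnt >= num_samples:
                    break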
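The zip helper can also be written as a generator that accepts any file-like object, which makes it easy to try on an in-memory archive. The following is a sketch under that assumption: `iter_zip_samples` is a hypothetical name, and unlike `process_zip` it skips directory entries and replaces undecodable bytes instead of raising.

```python
import io
import re
import zipfile


def iter_zip_samples(zip_file, regs, min_length, max_length, num_samples, preserve_form=False):
    """Hypothetical generator: yield cleaned samples from the entries of a zip archive."""
    count = 0
    with zipfile.ZipFile(zip_file, "r") as z:
        for entry in z.namelist():
            if entry.endswith("/"):
                continue  # directory entries carry no content
            text = z.read(entry).decode("utf-8", errors="replace")
            # apply each cleaning regex in insertion order
            for pattern, replacement in regs.items():
                text = re.sub(pattern, replacement, text, flags=re.DOTALL)
            # same length window as process_zip: (min_length, max_length]
            if min_length < len(text) <= max_length:
                sample = text.strip() + "\n"
                if preserve_form:
                    sample += "\n\n"
                yield sample
                count += 1
                if count >= num_samples:
                    return


# usage: build a small archive in memory instead of reading one from disk
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as z:
    z.writestr("a.js", "var a = 1;\n")
    z.writestr("b.js", "x")  # too short, filtered out
    z.writestr("c.js", "var c = 3;  \n")
samples = list(iter_zip_samples(buf, {}, min_length=5, max_length=100, num_samples=10))
```

Writing to a file then becomes a loop over the generator, and the length filter and regexes can be tested without touching the disk.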

The next cells prepare the data sets used in this task.

In [4]:


# data preparation settings
data_type='all'
data_dir = './datasets/'
output_dir = './data'
short_filename = 'true'
postfix = ''
num_samples = 1000
max_length = 2000
min_length = 10
preserve_lines = 'true'
preserve_form = 'false'

# form requires newlines to be preserved
if preserve_form == 'true':
    preserve_lines = 'true'

# collapsing sample into one line requires form not to be preserved
if preserve_lines == 'false':
    preserve_form = 'false'

# set postfix for output files if short-filename is false
if postfix != '':
    postfix = '_' + postfix
if short_filename == 'false':
    postfix += f'_n{num_samples}_min{min_length}_max{max_length}'
    if preserve_lines == 'false':
        postfix += '_nolines'
    else:
        postfix += '_lines'
    if preserve_form == 'false':
        postfix += '_noform'
    else:
        postfix += '_form'
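The flag handling above can be summarized as one pure function. This is a sketch that uses Python booleans instead of the `'true'`/`'false'` strings; `build_postfix` is a hypothetical name.

```python
def build_postfix(postfix="", short_filename=True, num_samples=1000,
                  min_length=10, max_length=2000,
                  preserve_lines=True, preserve_form=False):
    """Hypothetical helper mirroring the cell above: return the output-file suffix."""
    # form requires newlines to be preserved
    if preserve_form:
        preserve_lines = True
    # collapsing samples into one line requires form not to be preserved
    if not preserve_lines:
        preserve_form = False
    suffix = "_" + postfix if postfix else ""
    if not short_filename:
        suffix += f"_n{num_samples}_min{min_length}_max{max_length}"
        suffix += "_lines" if preserve_lines else "_nolines"
        suffix += "_form" if preserve_form else "_noform"
    return suffix


print(build_postfix(short_filename=False))
```

Note the precedence: because the form check runs first, `preserve_form=True` wins over `preserve_lines=False`.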



### prepare tweet data set

In [23]:
# dataset from: https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi%3A10.7910%2FDVN%2FKJEBIL
if data_type in ['all','tweets']: # parse trump tweets
    print('prepare tweet data set...')
    df1 = pd.read_json(os.path.join(data_dir, 'realdonaldtrump-1.ndjson'), lines=True)
    df2 = pd.read_json(os.path.join(data_dir, 'realdonaldtrump-2.ndjson'), lines=True)
    df = pd.concat([df1, df2], sort=True)
    if preserve_lines == 'false':
        df.text = df.text.str.replace("\n", " ", regex=False)
    if preserve_form == 'false':
        # strip URLs; regex=True is required since pandas 2.0 made literal matching the default
        df.text = df.text.str.replace(r"https?://[^\s]+", "", regex=True)
    df['length'] = df.text.apply(len)
    mask = (df.text>'2017')&(df.text.str.startswith('RT')==False)&(df.length>min_length)
    df = df[mask]
    df.sample(num_samples).text.to_csv(os.path.join(output_dir, 'tweets' + postfix + '.txt'), index=False, header=False, quoting=csv.QUOTE_NONE, escapechar="\\", sep="\\")
    print('preparing tweet data set done.')



prepare tweet data set...
preparing tweet data set done.
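The per-tweet cleaning and filtering steps can be sketched without pandas, using the same URL regex. `clean_tweet` and `keep_tweet` are hypothetical helpers, and the `'2017'` comparison of the original filter is left out.

```python
import re

# same pattern as in the cell above
URL_PATTERN = re.compile(r"https?://[^\s]+")


def clean_tweet(text, preserve_lines=True, preserve_form=False):
    """Hypothetical helper: flatten newlines and/or strip URLs like the tweet cell."""
    if not preserve_lines:
        text = text.replace("\n", " ")
    if not preserve_form:
        text = URL_PATTERN.sub("", text)
    return text


def keep_tweet(text, min_length):
    """Hypothetical helper: drop retweets and tweets that are too short."""
    return not text.startswith("RT") and len(text) > min_length


print(repr(clean_tweet("Great rally!\nhttps://t.co/abc123 Thanks", preserve_lines=False)))
```

Removing a URL leaves the surrounding spaces in place, so cleaned tweets may contain double spaces.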


### prepare chess data set

In [24]:
# dataset from: https://www.ficsgames.org/download.html | year: 2019, month: whole year, type: Standard (average rating > 2000)
if data_type in ['all','chess']: # parse chess games
    print('prepare chess data set...')
    with open(os.path.join(output_dir, 'chess' + postfix + '.txt'),'w+') as fh:
        with open(os.path.join(data_dir, 'ficsgamesdb_2019_standard2000_nomovetimes_110541.pgn')) as fp:
            line = fp.readline()
            cnt = 0
            while line and cnt < num_samples:
                if line.startswith('1.'):
                    fh.write(line)
                    cnt += 1
                line = fp.readline()
    print('preparing chess data set done.')



prepare chess data set...
preparing chess data set done.
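The chess cell keeps only the lines that start with `1.`. It assumes that the FICS export without move times stores each game's movetext on a single line. Here is a sketch of that rule on an in-memory PGN snippet with made-up headers; `extract_move_lines` is a hypothetical helper.

```python
import io


def extract_move_lines(fp, num_samples):
    """Hypothetical helper: return up to num_samples movetext lines (lines starting with '1.')."""
    moves = []
    for line in fp:
        if len(moves) >= num_samples:
            break
        if line.startswith("1."):
            moves.append(line)
    return moves


pgn = (
    '[Event "FICS rated standard game"]\n[White "PlayerA"]\n\n'
    "1. e4 e5 2. Nf3 Nc6 1-0\n\n"
    '[Event "FICS rated standard game"]\n\n'
    "1. d4 d5 1/2-1/2\n"
)
print(extract_move_lines(io.StringIO(pgn), 10))
```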


### prepare music data set

In [25]:
# dataset from: https://www.kaggle.com/raj5287/abc-notation-of-tunes/version/3
if data_type in ['all','music']: # parse abc songs
    print('prepare music data set...')
    with open(os.path.join(output_dir, 'music' + postfix + '.txt'),'w+') as fh:
        with open(os.path.join(data_dir, 'abc_notation_songs.txt')) as fp:
            line = fp.readline()
            cnt = 0
            song = ""
            while line and cnt < num_samples:
                # an empty line or a header field such as "X:" or "T:" marks a song boundary
                if len(line) < 2 or line[1:2] == ':':
                    if song != "":
                        fh.write(song + "\n")
                        cnt += 1
                        song = ""
                elif preserve_lines == 'false':
                    song += " " + line.strip()
                else:
                    # accumulate here as well, so that cnt also counts songs in this mode
                    song += line.strip() + "\n"
                line = fp.readline()
            # write the last song if the file does not end with a boundary line
            if song != "" and cnt < num_samples:
                fh.write(song + "\n")
    print('preparing music data set done.')



prepare music data set...
preparing music data set done.
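In the ABC file, an empty line or a header field such as `X:` or `T:` marks a song boundary. The rule can be sketched as a pure function over a list of lines. `split_abc_songs` is a hypothetical helper and the tune fragments below are made up. The sketch counts songs in both line modes and keeps the last song even when the input does not end with a boundary line.

```python
def split_abc_songs(lines, num_samples, preserve_lines=True):
    """Hypothetical helper: group ABC body lines into songs, up to num_samples songs."""
    songs, body = [], []
    joiner = "\n" if preserve_lines else " "
    for line in lines:
        if len(songs) >= num_samples:
            break
        # empty lines and header fields ("X:", "T:", "K:", ...) are song boundaries
        if len(line) < 2 or line[1:2] == ":":
            if body:
                songs.append(joiner.join(body))
                body = []
        else:
            body.append(line.strip())
    # flush the last song if the input does not end with a boundary line
    if body and len(songs) < num_samples:
        songs.append(joiner.join(body))
    return songs


abc = ["X:1\n", "T:Tune A\n", "K:G\n", "GABc|d2B2|\n", "c2A2|G4|\n", "\n",
       "X:2\n", "T:Tune B\n", "K:D\n", "DEF|A3|\n"]
print(split_abc_songs(abc, num_samples=10))
```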


### prepare shakespeare data set

In [5]:
# dataset from: https://www.kaggle.com/kingburrito666/shakespeare-plays
if data_type in ['all','shakespeare']: # parse shakespeare plays
    print('prepare shakespeare data set...')
    df = pd.read_csv(os.path.join(data_dir, 'shakespeare_data.csv'))
    if preserve_lines == 'false':
        # merge the lines of each speech; empty Player fields are read as NaN, not ''
        df = df[df.Player.notna()].groupby(['Play','PlayerLinenumber'])['PlayerLine'].agg(' '.join).reset_index()
    df.sample(num_samples).PlayerLine.to_csv(os.path.join(output_dir, 'shakespeare' + postfix + '.txt'), index=False, header=False, quoting=csv.QUOTE_NONE, escapechar="\\", sep="\\")
    print('preparing shakespeare data set done.')



prepare shakespeare data set...
preparing shakespeare data set done.
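When `preserve_lines` is `'false'`, the Shakespeare cell merges all lines of one speech, identified by `Play` and `PlayerLinenumber`, into a single sample. The following stdlib sketch shows that grouping on toy rows. The column names follow the CSV, but the rows and the `merge_speeches` helper are made up for illustration. It assumes that rows with an empty `Player` field, such as act and scene headings, should be dropped.

```python
def merge_speeches(rows):
    """Hypothetical helper: join the PlayerLine values of each (Play, PlayerLinenumber) speech."""
    speeches = {}
    for row in rows:
        if not row.get("Player"):
            continue  # headings and similar rows have no speaker
        key = (row["Play"], row["PlayerLinenumber"])
        speeches.setdefault(key, []).append(row["PlayerLine"])
    return [" ".join(lines) for lines in speeches.values()]


rows = [
    {"Play": "Play A", "PlayerLinenumber": None, "Player": "", "PlayerLine": "ACT I"},
    {"Play": "Play A", "PlayerLinenumber": 1, "Player": "ALICE", "PlayerLine": "first line,"},
    {"Play": "Play A", "PlayerLinenumber": 1, "Player": "ALICE", "PlayerLine": "second line."},
    {"Play": "Play A", "PlayerLinenumber": 2, "Player": "BOB", "PlayerLine": "A reply."},
]
print(merge_speeches(rows))
```

The same function accepts the dictionaries produced by `csv.DictReader`. Unlike a pandas `groupby`, which sorts the group keys by default, it keeps speeches in the order they first appear. The order does not matter here because the cell samples the speeches afterwards.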


### prepare javascript data set

In [27]:
# dataset from: javascript files from https://www.sri.inf.ethz.ch/js150
if data_type in ['all','javascript']: # parse javascript files
    print('prepare javascript data set...')
    regexes = {}
    if preserve_form == 'false':
        regexes[r'(//[^\n]*)?\n|/\*.*?\*/'] = '\n'
        regexes[r'\n\s*\n'] = '\n'
    if preserve_lines == 'false':
        regexes[r'\s+'] = ' '
    process_zip('javascript', regexes, postfix,data_dir,min_length,max_length,preserve_form,num_samples)
    print('preparing javascript data set done.')



prepare javascript data set...
preparing javascript data set done.
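To see what the comment-stripping regexes do, the following sketch applies the same patterns to a string. It uses booleans instead of the `'true'`/`'false'` flags, and `clean_code` is a hypothetical helper.

```python
import re


def clean_code(text, preserve_form=False, preserve_lines=True):
    """Hypothetical helper applying the same regexes as the javascript/typescript cells."""
    regexes = {}
    if not preserve_form:
        # drop // line comments and /* block */ comments, then squeeze blank lines
        regexes[r"(//[^\n]*)?\n|/\*.*?\*/"] = "\n"
        regexes[r"\n\s*\n"] = "\n"
    if not preserve_lines:
        # collapse all whitespace, including newlines, into single spaces
        regexes[r"\s+"] = " "
    for pattern, replacement in regexes.items():
        text = re.sub(pattern, replacement, text, flags=re.DOTALL)
    return text


sample = "var a = 1; // set a\n/* block\ncomment */\nvar b = 2;\n\n\nvar c = 3;"
print(repr(clean_code(sample)))
```

These patterns do not parse the code, so a `//` inside a string literal (for example a URL) is also treated as the start of a comment and removed.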


### prepare typescript data set

In [28]:
# dataset from: typescript files collected from standard angular app
if data_type in ['all','typescript']: # parse typescript files
    print('prepare typescript data set...')
    regexes = {}
    if preserve_form == 'false':
        regexes[r'(//[^\n]*)?\n|/\*.*?\*/'] = '\n'
        regexes[r'\n\s*\n'] = '\n'
    if preserve_lines == 'false':
        regexes[r'\s+'] = ' '
    process_zip('typescript', regexes, postfix,data_dir,min_length,max_length,preserve_form,num_samples)
    print('preparing typescript data set done.')



prepare typescript data set...
preparing typescript data set done.


### prepare json data set

In [29]:
# dataset from: json files collected from standard angular app
if data_type in ['all','json']: # parse json files
    print('prepare json data set...')
    regexes = {}
    if preserve_lines == 'false':
        regexes[r'\s+'] = ' '
    process_zip('json', regexes, postfix,data_dir,min_length,max_length,preserve_form,num_samples)
    print('preparing json data set done.')



prepare json data set...
preparing json data set done.


### prepare html data set

In [30]:
# dataset from: https://www.kaggle.com/zavadskyy/lots-of-code, https://gist.github.com/VladislavZavadskyy/e31ab07b03a5c22b11982c49669a400b
if data_type in ['all','html']: # parse html
    print('prepare html data set...')
    with open(os.path.join(output_dir, 'html' + postfix + '.txt'),'w+') as fh:
        with open(os.path.join(data_dir, 'html-dataset.txt')) as fp:
            data = fp.read()
            data = data.replace('<!DOCTYPE html>','\n<!DOCTYPE html>')
            lines = data.split('\n')
            cnt = 0
            sample = ""
            for line in lines:
                if line == "":
                    continue
                if sample != "" and line.startswith('<!DOCTYPE html>'):
                    fh.write(sample.strip() + "\n")
                    sample = ""
                    cnt += 1
                if cnt >= num_samples:
                    break
                line = re.sub(r'\s+', ' ', line)
                sample += line.strip() + " "
    print('preparing html data set done.')

prepare html data set...
preparing html data set done.
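The HTML cell treats each `<!DOCTYPE html>` as the start of a new document and collapses the whitespace of each document into one line. The following sketch does the same with `str.split`. `split_html_documents` is a hypothetical helper, and unlike the loop above it also keeps the final document.

```python
import re

DOCTYPE = "<!DOCTYPE html>"


def split_html_documents(data, num_samples):
    """Hypothetical helper: one whitespace-collapsed line per <!DOCTYPE html> document."""
    samples = []
    for i, part in enumerate(data.split(DOCTYPE)):
        # every part after the first one was preceded by the DOCTYPE marker
        prefix = DOCTYPE if i > 0 else ""
        sample = re.sub(r"\s+", " ", prefix + part).strip()
        if sample:
            samples.append(sample)
        if len(samples) >= num_samples:
            break
    return samples


data = ("<!DOCTYPE html>\n<html>\n  <body>A</body>\n</html>\n"
        "<!DOCTYPE html><html><body>B</body></html>\n")
print(split_html_documents(data, 10))
```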


# Let's fine-tune the GPT-2 model!
Choose the number of steps for which the model is fine-tuned (`steps`). You can also adjust the parameters that control how often you get updates on the training process, how often samples from the current model are printed, and how often (in steps) the model is saved.

Apart from the number of steps, these parameters do not influence the training itself. The model is saved automatically once fine-tuning has run for the specified number of steps. You can also stop the fine-tuning at any time, and the current training state of the model will be saved.

In [3]:
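def start_session(sess):
    """Reset a previous TensorFlow session (if any) and return a fresh one."""
    try:
        # reset_session resets the default graph, closes sess and returns a new session
        return gpt2.reset_session(sess)
    except Exception:
        # first call: sess is still None, so just start a new session
        return gpt2.start_tf_sess()

def fine_tune(sess, run_name, data_path, steps, model_name='124M'):
    """Fine-tune GPT-2 `model_name` on `data_path` and return the session."""
    print(f'Run fine-tuning for run {run_name} using GPT2 model {model_name}...')
    # download the pre-trained weights on first use
    if not os.path.isdir(os.path.join("models", model_name)):
        log.info(f"Downloading {model_name} model...")
        gpt2.download_gpt2(model_name=model_name)
    sess = start_session(sess)
    # print a sample and save a checkpoint every 10 steps
    gpt2.finetune(sess=sess, dataset=data_path, checkpoint_dir='runs', model_name=model_name,
                  run_name=run_name, steps=steps, sample_every=10, save_every=10)
    return sess



#run_name='run1'
#data_path ='data/tweets.txt'
sess = None
sess = fine_tune(sess, run_name, data_path, steps)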

Run fine-tuning for run run2 using GPT2 model 124M...
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


2020-07-13 12:08:32,493 [WARNI] [tensorflow  ]: From /home/jovyan/.local/lib/python3.6/site-packages/gpt_2_simple/src/sample.py:17: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.


Loading checkpoint models/124M/model.ckpt
INFO:tensorflow:Restoring parameters from models/124M/model.ckpt


100%|██████████| 1/1 [00:00<00:00, 717.47it/s]

Loading dataset...
dataset has 11351 tokens
Training...





[1 | 69.35] loss=4.61 avg=4.61
[2 | 125.39] loss=4.61 avg=4.61
[3 | 182.37] loss=4.39 avg=4.54
[4 | 238.20] loss=4.44 avg=4.51
[5 | 292.90] loss=4.27 avg=4.46
[6 | 347.93] loss=4.36 avg=4.45
[7 | 401.99] loss=4.18 avg=4.41
[8 | 456.62] loss=4.24 avg=4.38
[9 | 511.50] loss=4.03 avg=4.34
[10 | 567.92] loss=4.06 avg=4.31
Saving runs/run2/model-10


# Text generation
We can now generate text that mimics the style of the training samples.

You can play around with the three parameters `length`, `temperature`, and `top_k` to influence the generated text. You can also provide a seed sequence (the `message` prefix in `generate`) that will form the beginning of the generated text.
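`temperature` divides the model's logits before the softmax. Values below 1 make the distribution sharper, and values above 1 make it flatter. `top_k` restricts sampling to the k most likely tokens. As in GPT-2's own sampling code, `top_k = 0` disables that restriction. The following sketch illustrates one sampling step on a plain list of logits. It is a simplified stand-in, not gpt-2-simple's TensorFlow implementation, and `sample_next_token` is a hypothetical name.

```python
import math
import random


def sample_next_token(logits, temperature=1.0, top_k=0, rng=random):
    """Hypothetical helper: return (sampled index, probabilities) for one decoding step."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [l / temperature for l in logits]
    if top_k > 0:
        # keep the k largest logits (ties at the cutoff are all kept), mask the rest
        cutoff = sorted(scaled, reverse=True)[min(top_k, len(scaled)) - 1]
        scaled = [s if s >= cutoff else -math.inf for s in scaled]
    # numerically stable softmax; exp(-inf) == 0.0 removes masked tokens
    m = max(scaled)
    weights = [math.exp(s - m) for s in scaled]
    total = sum(weights)
    probs = [w / total for w in weights]
    index = rng.choices(range(len(probs)), weights=probs, k=1)[0]
    return index, probs


_, p_warm = sample_next_token([1.0, 2.0, 3.0], temperature=1.0)
_, p_cold = sample_next_token([1.0, 2.0, 3.0], temperature=0.5)
print(p_warm, p_cold)
```

With a lower temperature, more probability mass moves to the most likely token. With `top_k=1`, sampling becomes greedy decoding.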

Use the different data sets to explore how fine-tuning works and where its limits lie. You can also use custom data sets: copy them to the `data` folder and set `data_path` above.

In [5]:
import math
import torch
from pytorch_pretrained_bert import OpenAIGPTTokenizer, OpenAIGPTLMHeadModel
# Load pre-trained model (weights) and switch to inference mode (disables dropout)
model = OpenAIGPTLMHeadModel.from_pretrained('openai-gpt')
model.eval()
# Load pre-trained model tokenizer (vocabulary)
tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')

def score(sentence):
    """Return the perplexity of `sentence` under the OpenAI GPT language model."""
    tokenize_input = tokenizer.tokenize(sentence)
    tensor_input = torch.tensor([tokenizer.convert_tokens_to_ids(tokenize_input)])
    with torch.no_grad():
        # with lm_labels set, the model returns the language-modeling loss
        loss = model(tensor_input, lm_labels=tensor_input)
    return math.exp(loss.item())

2020-07-13 12:23:59,630 [INFO ] [pytorch_pret]: Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .
2020-07-13 12:23:59,642 [DEBUG] [urllib3.conn]: Starting new HTTPS connection (1): s3.amazonaws.com:443
2020-07-13 12:24:00,047 [DEBUG] [urllib3.conn]: https://s3.amazonaws.com:443 "HEAD /models.huggingface.co/bert/openai-gpt-pytorch_model.bin HTTP/1.1" 200 0
2020-07-13 12:24:00,052 [DEBUG] [urllib3.conn]: Starting new HTTPS connection (1): s3.amazonaws.com:443
2020-07-13 12:24:00,439 [DEBUG] [urllib3.conn]: https://s3.amazonaws.com:443 "HEAD /models.huggingface.co/bert/openai-gpt-config.json HTTP/1.1" 200 0
2020-07-13 12:24:00,444 [INFO ] [pytorch_pret]: loading weights file https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-pytorch_model.bin from cache at /home/jovyan/.pytorch_pretrained_bert/e45ee1afb14c5d77c946e66cb0fa70073a77882097a1a2cefd51fd24b172355e.e7ee3fcd07c695a4c9f31ca735502c090230d988de03202f7af9ebe1c3a4054c
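The `score` function returns `exp(loss)`, where `loss` is the model's mean cross-entropy (natural log) over the predicted tokens. That quantity is the perplexity. The following stdlib sketch computes the same formula from a list of per-token log-probabilities. The input is hypothetical and the sketch does not call the model. Lower values mean the scoring model finds the text less surprising.

```python
import math


def perplexity(token_log_probs):
    """Perplexity = exp(mean negative log-likelihood) of natural-log token probabilities."""
    if not token_log_probs:
        raise ValueError("need at least one token")
    mean_nll = -sum(token_log_probs) / len(token_log_probs)
    return math.exp(mean_nll)


# a model that assigns probability 0.25 to every observed token
print(perplexity([math.log(0.25)] * 4))
```

A model that spreads its probability uniformly over 4 candidates at every step has perplexity 4, which is why perplexity is often read as an effective branching factor.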

In [6]:
from tensorflow.python.lib.io import file_io  # provides FileIO for the metrics file

def generate(sess, run_name, length, temperature, top_k):

    message = "Social justice warior"  # seed text (prefix) for generation
    text = gpt2.generate(sess=sess, checkpoint_dir='runs', run_name=run_name, prefix=message,
                         length=length, temperature=temperature, top_k=top_k, return_as_list=True)
    print(text[0])
    print("-----------------------------")
    # `length` counts tokens, while this slice counts characters; score the sample once
    perplexity = score(text[0][0:length-10])
    print(perplexity)

    # Kubeflow Pipelines metrics file
    metrics = {
        'metrics': [{
            'name': 'perplexity-score',  # The name of the metric. Visualized as the column name in the runs table.
            'numberValue': perplexity,   # The value of the metric. Must be a numeric value.
            'format': "RAW",             # The optional format of the metric. Supported values are "RAW" (displayed in raw format) and "PERCENTAGE" (displayed in percentage format).
        }]
    }
    with file_io.FileIO('/mlpipeline-metrics.json', 'w') as f:
        json.dump(metrics, f)


#length = 800 # { min:0, max:1000, step:5}
#temperature = 0.7 # { min:0, max:2, step:0.1}
#top_k = 0
sess = start_session(sess)
gpt2.load_gpt2(sess, checkpoint_dir='runs', run_name=run_name)
generate(sess,run_name,length,temperature,top_k)

Loading checkpoint runs/run2/model-10
INFO:tensorflow:Restoring parameters from runs/run2/model-10


2020-07-13 12:24:10,019 [INFO ] [tensorflow  ]: Restoring parameters from runs/run2/model-10


Social justice warior, and the greatest of the Missouri, the most valiant, the most valiant, and the most valiant.

We have now come to the place where the troops have come, and we must put them to death.

This is the word of the Lord, and if there be not a hundred to be made of this body,

but the most noble, the most fattening, the most gallant, the most bold.

And now, sir, let our red-hot fire stand, and I'll throw it into the river,

And go, my lord, get back your head.

And I know not what.

And so, in a day or two, and then, when the scepter has been cast,

Come, let's talk.

I'll tell you, sir, the state, and I'll tell you,

And I know not how to fly into this and that.

We have in his hand, I know not how to make myself.

Come, let's get you some good news, and say, 'I have a letter from you,

This is, I get you some rest, and you go,

for if my father's the king, he'll make his peace.'

And, though it be true, I'd rather take him as a friend

That's not in consideration.

And