# Setting things up
The following cell install all the necessary dependencies

In [19]:
!pip3 install --user -r requirements.txt

Note: you may need to restart the kernel to use updated packages.


Import the packages needed for this task.

In [20]:
import gpt_2_simple as gpt2
import json
import os
import sys
import numpy as np
import argparse
import requests
import glob
import pickle
import pandas as pd
import re
import unicodedata
import csv
import zipfile
import logging
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s [%(levelname)-5.5s] [%(name)-12.12s]: %(message)s')
log = logging.getLogger(__name__)

In [None]:
run_name='run1'
data_path ='data/tweets.txt'
steps=1
length = 800 
temperature = 0.7
top_k = 0

We define a helper function for zip files

In [21]:
# helper to use code samples in zip file
def process_zip(name, regs, postfix,data_dir,min_length,max_length,preserve_form,num_samples):
    with open(os.path.join(output_dir, name + postfix + '.txt'), 'w+') as fh:
        with zipfile.ZipFile(os.path.join(data_dir, name + '.zip'), 'r') as z:
            cnt = 0
            for entry in z.namelist():
                text = z.read(entry).decode('utf-8')
                for reg, sub in regs.items():
                    text = re.sub(reg, sub, text, flags=re.DOTALL)
                if len(text) > min_length and len(text) <= max_length:
                    sample = text.strip() + "\n"
                    if preserve_form == 'true':
                        sample += "\n\n"
                    fh.write(sample)
                    cnt += 1
                if cnt >= num_samples:
                    break

The next cells will prepare the data sets we can use in this task.

In [22]:


# command line arguments parser
data_type='all'
data_dir = './datasets/'
output_dir = './data'
short_filename = 'true'
postfix = ''
num_samples = 1000
max_length = 2000
min_length = 10
preserve_lines = 'true'
preserve_form = 'false'

# form requires newlines to be preserved
if preserve_form == 'true':
    preserve_lines = 'true'

# collapsing sample into one line requires form not to be preserved
if preserve_lines == 'false':
    preserve_form = 'false'

# set postfix for output files if short-filename is false
if postfix != '':
    postfix = '_' + postfix
if short_filename == 'false':
    postfix += f'_n{num_samples}_min{min_length}_max{max_length}'
    if preserve_lines == 'false':
        postfix += '_nolines'
    else:
        postfix += '_lines'
    if preserve_form == 'false':
        postfix += '_noform'
    else:
        postfix += '_form'



### prepare tweet data set

In [23]:
# dataset from: https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi%3A10.7910%2FDVN%2FKJEBIL
if data_type in ['all','tweets']: # parse trump tweets
    print('prepare tweet data set...')
    df1 = pd.read_json(os.path.join(data_dir, 'realdonaldtrump-1.ndjson'), lines=True)
    df2 = pd.read_json(os.path.join(data_dir, 'realdonaldtrump-2.ndjson'), lines=True)
    df = pd.concat([df1, df2], sort=True)
    if preserve_lines == 'false':
        df.text = df.text.str.replace("\n"," ")
    if preserve_form == 'false':
        df.text = df.text.str.replace(r"https?://[^\s]+","")
    df['length'] = df.text.apply(len)
    filter = (df.text>'2017')&(df.text.str.startswith('RT')==False)&(df.length>min_length)
    df = df[filter]
    df.sample(num_samples).text.to_csv(os.path.join(output_dir, 'tweets' + postfix + '.txt'), index=False, header=False, quoting=csv.QUOTE_NONE, escapechar="\\", sep="\\")
    print('preparing tweet data set done.')



prepare tweet data set...
preparing tweet data set done.


### prepare chess data set

In [24]:
# dataset from: https://www.ficsgames.org/download.html | year: 2019, month: whole year, type: Standard (average rating > 2000)
if data_type in ['all','chess']: # parse chess games
    print('prepare chess data set...')
    with open(os.path.join(output_dir, 'chess' + postfix + '.txt'),'w+') as fh:
        with open(os.path.join(data_dir, 'ficsgamesdb_2019_standard2000_nomovetimes_110541.pgn')) as fp:
           line = fp.readline()
           cnt = 0
           while line and cnt < num_samples:
               if line.startswith('1.'):
                   fh.write(line)
                   cnt += 1
               line = fp.readline()
    print('preparing chess data set done.')



prepare chess data set...
preparing chess data set done.


### prepare music data set

In [25]:
# dataset from: https://www.kaggle.com/raj5287/abc-notation-of-tunes/version/3
if data_type in ['all','music']: # parse abc songs
    print('prepare music data set...')
    with open(os.path.join(output_dir, 'music' + postfix + '.txt'),'w+') as fh:
        with open(os.path.join(data_dir, 'abc_notation_songs.txt')) as fp:
            line = fp.readline()
            cnt = 0
            song = ""
            while line and cnt < num_samples:
                if len(line) < 2 or line[1:2] == ':':
                    if song != "":
                        fh.write(song + "\n")
                        cnt += 1
                        song = ""
                elif preserve_lines == 'false':
                    song += " " + line.strip()
                else:
                    fh.write(line.strip() + "\n")
                line = fp.readline()
    print('preparing music data set done.')



prepare music data set...
preparing music data set done.


### prepare shakespeare data set

In [26]:
# dataset from: https://www.kaggle.com/kingburrito666/shakespeare-plays
if data_type in ['all','shakespeare']: # parse shakespeare plays
    print('prepare shakespeare data set...')
    df = pd.read_csv(os.path.join(data_dir, 'shakespeare_data.csv'))
    if preserve_lines == 'false':
        df = df[df.Player!=''].groupby(['Play','PlayerLinenumber'],as_index=False).agg(' '.join)
    df.sample(num_samples).PlayerLine.to_csv(os.path.join(output_dir, 'shakespeare' + postfix + '.txt'), index=False, header=False, quoting=csv.QUOTE_NONE, escapechar="\\", sep="\\")
    print('preparing shakespeare data set done.')



prepare shakespeare data set...
preparing shakespeare data set done.


### prepare javascript data set

In [27]:
# dataset from: javascript files from https://www.sri.inf.ethz.ch/js150
if data_type in ['all','javascript']: # parse javascript files
    print('prepare javascript data set...')
    regexes = {}
    if preserve_form == 'false':
        regexes[r'(//[^\n]*)?\n|/\*.*?\*/'] = '\n'
        regexes[r'\n\s*\n'] = '\n'
    if preserve_lines == 'false':
        regexes[r'\s+'] = ' '
    process_zip('javascript', regexes, postfix,data_dir,min_length,max_length,preserve_form,num_samples)
    print('preparing javascript data set done.')



prepare javascript data set...
preparing javascript data set done.


### prepare typescript data set

In [28]:
# dataset from: typescript files collected from standard angular app
if data_type in ['all','typescript']: # parse typescript files
    print('prepare typescript data set...')
    regexes = {}
    if preserve_form == 'false':
        regexes[r'(//[^\n]*)?\n|/\*.*?\*/'] = '\n'
        regexes[r'\n\s*\n'] = '\n'
    if preserve_lines == 'false':
        regexes[r'\s+'] = ' '
    process_zip('typescript', regexes, postfix,data_dir,min_length,max_length,preserve_form,num_samples)
    print('preparing typescript data set done.')



prepare typescript data set...
preparing typescript data set done.


### prepare json data set

In [29]:
# dataset from: json files collected from standard angular app
if data_type in ['all','json']: # parse json files
    print('prepare json data set...')
    regexes = {}
    if preserve_lines == 'false':
        regexes[r'\s+'] = ' '
    process_zip('json', regexes, postfix,data_dir,min_length,max_length,preserve_form,num_samples)
    print('preparing json data set done.')



prepare json data set...
preparing json data set done.


### prepare html data set

In [30]:
# dataset from: https://www.kaggle.com/zavadskyy/lots-of-code, https://gist.github.com/VladislavZavadskyy/e31ab07b03a5c22b11982c49669a400b
if data_type in ['all','html']: # parse html
    print('prepare html data set...')
    with open(os.path.join(output_dir, 'html' + postfix + '.txt'),'w+') as fh:
        with open(os.path.join(data_dir, 'html-dataset.txt')) as fp:
            data = fp.read()
            data = data.replace('<!DOCTYPE html>','\n<!DOCTYPE html>')
            lines = data.split('\n')
            cnt = 0
            sample = ""
            for line in lines:
                if line == "":
                    continue
                if sample != "" and line.startswith('<!DOCTYPE html>'):
                    fh.write(sample.strip() + "\n")
                    sample = ""
                    cnt += 1
                if cnt >= num_samples:
                    break
                line = re.sub(r'\s+', ' ', line)
                sample += line.strip() + " "
    print('preparing html data set done.')

prepare html data set...
preparing html data set done.


# Let's fine-tune the GPT-2 model!
Choose the number of steps the model will be fine-tuned for. You can adjust the parameters  to specifiy how often you get updates on the training process, how often samples of the current model are printed, and every how many steps the model is saved.

Beside the number of steps, these parameters do not influence the training. The model will be saved automatically when done fine-tuning with the amount of steps specified. You can stop the fine-tuning anytime and the current training state of the model will be saved.

In [31]:
def start_session(sess):
    try:
        gpt2.reset_session(sess)
    except:
        pass
    return gpt2.start_tf_sess()

def fine_tune(sess,run_name,data_path,steps, model_name='124M'):
    print(f'Run fine-tuning for run {run_name} using GPT2 model {model_name}...')
    if not os.path.isdir(os.path.join("models", model_name)):
        log.info(f"Downloading {model_name} model...")
        gpt2.download_gpt2(model_name=model_name)
    sess = start_session(sess)
    gpt2.finetune(sess=sess,dataset=data_path,checkpoint_dir='runs', model_name=model_name, run_name=run_name, steps=steps, sample_every=10, save_every=10)



#run_name='run1'
#data_path ='data/tweets.txt'
sess = None
fine_tune(sess,run_name,data_path,stpes)

Run fine-tuning for run run1 using GPT2 model 124M...
Loading checkpoint models/124M/model.ckpt
INFO:tensorflow:Restoring parameters from models/124M/model.ckpt
2020-06-15 22:04:40,317 [INFO ] [tensorflow  ]: Restoring parameters from models/124M/model.ckpt
  0%|          | 0/1 [00:00<?, ?it/s]Loading dataset...
100%|██████████| 1/1 [00:00<00:00,  2.56it/s]
dataset has 25450 tokens
Training...
[1 | 91.91] loss=4.13 avg=4.13
[2 | 173.89] loss=4.07 avg=4.10
[3 | 254.24] loss=3.81 avg=4.00
[4 | 334.93] loss=3.81 avg=3.95
[5 | 416.09] loss=3.72 avg=3.91
[6 | 497.28] loss=3.67 avg=3.87
[7 | 578.22] loss=3.58 avg=3.82
[8 | 659.74] loss=3.70 avg=3.81
[9 | 744.61] loss=3.51 avg=3.77
[10 | 828.15] loss=3.42 avg=3.74
Saving runs/run1/model-10


# Text generation
We can now generate text mimiking the style of the learned samples.

You can play around with the three parameters `length`, `temperature`, and `top_k` to influnce the generated text. Further, you can provide a seed sequence that will be the beginning of the generated text.

Use the different data sets to explore how the fine-tuning works and what its' limits are. You can also use custom data sets. Just copy them to the data folder and specify the path above.

In [32]:
def generate(sess,run_name,length,temperature,top_k):

    message = "Social justice warior"
    text = gpt2.generate(sess=sess,checkpoint_dir='runs', run_name=run_name, prefix=message, length=length, temperature=temperature, top_k=top_k, return_as_list=True)
    print(text[0])

#length = 800 # { min:0, max:1000, step:5}
#temperature = 0.7 # { min:0, max:2, step:0.1}
#top_k = 0
sess = start_session(sess)
gpt2.load_gpt2(sess, checkpoint_dir='runs', run_name=run_name)
generate(sess,run_name,length,temperature,top_k)

Loading checkpoint runs/run1/model-10
INFO:tensorflow:Restoring parameters from runs/run1/model-10
2020-06-15 22:18:36,207 [INFO ] [tensorflow  ]: Restoring parameters from runs/run1/model-10
Social justice warior of a British Empire that has been brewing since the 1980s.

Island of the People, Canada – B.C.

If you're on a flight, please leave your passengers alone.

Theresa May, PM: I don't know why anyone would want to torture people. But I have to.

We know that there is a good deal of sympathy for the victims of the Paris attacks in their own country, but they're not in the United States.

The global financial crisis has put us all on the verge of a disaster.

I am deeply disappointed by President Obama's decision today to withdraw the United States from the Paris Agreement.

I am absolutely determined to work with the Obama administration to make sure that America is prepared to defend against any and all threats.

The United States will not be intimidated into lessening our comm