Skip to content

Commit

Permalink
first commit
Browse files Browse the repository at this point in the history
  • Loading branch information
kylemcdonald committed Mar 4, 2019
0 parents commit ac115d7
Show file tree
Hide file tree
Showing 8 changed files with 758 additions and 0 deletions.
357 changes: 357 additions & 0 deletions Generate GPT-2.ipynb
@@ -0,0 +1,357 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import os\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"import model, sample, encoder"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# !ln -s ../models models # hack to make models \"appear\" in two places"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"model_name = '117M'\n",
"seed = None\n",
"nsamples = 10\n",
"batch_size = 10\n",
"length = 40\n",
"temperature = 0.8 # 0 is deterministic\n",
"top_k = 40 # 0 means no restrictions\n",
"\n",
"assert nsamples % batch_size == 0\n",
"\n",
"enc = encoder.get_encoder(model_name)\n",
"hparams = model.default_hparams()\n",
"with open(os.path.join('models', model_name, 'hparams.json')) as f:\n",
" hparams.override_from_dict(json.load(f))\n",
"\n",
"if length is None:\n",
" length = hparams.n_ctx // 2\n",
"elif length > hparams.n_ctx:\n",
" raise ValueError(\"Can't get samples longer than window size: %s\" % hparams.n_ctx)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"sess = tf.InteractiveSession()\n",
"\n",
"# replace with this in script:\n",
"# with tf.Session(graph=tf.Graph()) as sess:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Restoring parameters from models/117M/model.ckpt\n"
]
}
],
"source": [
"context = tf.placeholder(tf.int32, [batch_size, None])\n",
"np.random.seed(seed)\n",
"tf.set_random_seed(seed)\n",
"output = sample.sample_sequence(\n",
" hparams=hparams, length=length,\n",
" context=context,\n",
" batch_size=batch_size,\n",
" temperature=temperature, top_k=top_k\n",
")\n",
"\n",
"saver = tf.train.Saver()\n",
"ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))\n",
"saver.restore(sess, ckpt)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"8314\n"
]
}
],
"source": [
"from utils.list_all_files import *\n",
"import unicodedata\n",
"import os, re, random\n",
"\n",
"mapping = {\n",
" '\\xa0': ' ',\n",
" 'Æ': 'AE',\n",
" 'æ': 'ae',\n",
" 'è': 'e',\n",
" 'é': 'e',\n",
" 'ë': 'e',\n",
" 'ö': 'o',\n",
" '–': '-',\n",
" '—': '-',\n",
" '‘': \"'\",\n",
" '’': \"'\",\n",
" '“': '\"',\n",
" '”': '\"'\n",
"}\n",
"\n",
"def remove_special(text):\n",
" return ''.join([mapping[e] if e in mapping else e for e in text])\n",
"\n",
"def strip_word(word):\n",
" word = re.sub('^\\W*|\\W*$', '', word).lower()\n",
" return word\n",
"\n",
"basenames = []\n",
"all_poems = {}\n",
"total_lines = 0\n",
"words = set()\n",
"for fn in list_all_files('../../scraping/poetry/output'):\n",
" with open(fn) as f:\n",
" original = f.read()\n",
" text = remove_special(original).split('\\n')\n",
" poem = text[3:]\n",
" basename = os.path.basename(fn)\n",
" basename = os.path.splitext(basename)[0]\n",
" basenames.append(basename)\n",
" all_poems[basename] = {\n",
" 'url': text[0],\n",
" 'title': text[1],\n",
" 'author': text[2],\n",
" 'poem': poem\n",
" }\n",
" total_lines += len(poem)\n",
" poem = '\\n'.join(poem)\n",
" words.update([strip_word(e) for e in poem.split()])\n",
"words.discard('')\n",
"words = list(words)\n",
" \n",
"print(total_lines)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(\"Carpenter's\", \"Carpenter'S\")"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def titlecase_word(word):\n",
" return word[0].upper() + word[1:]\n",
"\n",
"titlecase_word(\"carpenter's\"), \"carpenter's\".title()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(['The heart did beat, and cheek did glow,',\n",
" 'And lip did smile, and eye did weep,'],\n",
" 'Adorations')"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def random_chunk(array, length):\n",
" start = random.randint(0, max(0, len(array) - length - 1))\n",
" return array[start:start+length]\n",
"\n",
"def random_item(array):\n",
" return array[random.randint(0, len(array) - 1)]\n",
"\n",
"random_chunk(all_poems[basenames[0]]['poem'], 2), titlecase_word(random_item(words))"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"10"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"seeds = '''\n",
"blue\n",
"epoch\n",
"ethereal\n",
"ineffable\n",
"iridescent\n",
"nefarious\n",
"oblivion\n",
"quiver\n",
"solitude\n",
"sonorous\n",
"'''.split()\n",
"len(seeds)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"from utils.progress import progress"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"def clean(text):\n",
" return text.split('<|endoftext|>')[0]\n",
"\n",
"def generate(inspiration, seed):\n",
" inspiration = remove_special(inspiration).strip()\n",
" seed = titlecase_word(seed).strip()\n",
"\n",
" raw_text = inspiration + '\\n' + seed\n",
" context_tokens = enc.encode(raw_text)\n",
" n_context = len(context_tokens)\n",
"\n",
" results = []\n",
" for _ in range(nsamples // batch_size):\n",
" out = sess.run(output, feed_dict={\n",
" context: [context_tokens for _ in range(batch_size)]\n",
" })\n",
" for tokens in out:\n",
" text = enc.decode(tokens[n_context:])\n",
" result = seed + text\n",
" results.append(result)\n",
" \n",
" return results"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"blue\n",
"epoch\n",
"ethereal\n",
"ineffable\n",
"iridescent\n",
"nefarious\n",
"oblivion\n",
"quiver\n",
"solitude\n",
"sonorous\n"
]
}
],
"source": [
"inspiration_lines = 16\n",
"\n",
"all_results = {}\n",
"for seed in seeds:\n",
" print(seed)\n",
" cur = {}\n",
" for basename in basenames: \n",
" inspiration = random_chunk(all_poems[basename]['poem'], inspiration_lines)\n",
" inspiration = '\\n'.join(inspiration)\n",
" results = generate(inspiration, seed)\n",
" cur[basename] = results\n",
" all_results[seed] = cur"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"with open('poems.json', 'w') as f:\n",
" json.dump(all_poems, f, separators=(',', ':'))\n",
" \n",
"with open('generated.json', 'w') as f:\n",
" json.dump(all_results, f, separators=(',', ':'))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit ac115d7

Please sign in to comment.