In [1]:
import json
import os
import numpy as np
import tensorflow as tf
import model, sample, encoder

In [3]:
# !ln -s ../models models # hack to make models "appear" in two places

In [4]:
# --- sampling configuration ------------------------------------------------
model_name = '117M'   # checkpoint directory name under models/
seed = None           # seed passed to numpy / TensorFlow; None = no fixed seed
nsamples = 10         # samples returned per generate() call
batch_size = 10       # samples produced by each sess.run batch
length = 40           # number of tokens to sample per continuation
temperature = 0.8     # 0 is deterministic
top_k = 40            # 0 means no restrictions

# generate() runs nsamples // batch_size batches, so the split must be exact.
assert nsamples % batch_size == 0

enc = encoder.get_encoder(model_name)
hparams = model.default_hparams()
with open(os.path.join('models', model_name, 'hparams.json')) as hparams_file:
    hparams.override_from_dict(json.load(hparams_file))

# Validate an explicit length against the model window; default to half of it.
if length is not None and length > hparams.n_ctx:
    raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)
if length is None:
    length = hparams.n_ctx // 2

In [5]:
# Long-lived TF session shared by the later cells: the checkpoint restore
# and every sess.run() inside generate().
sess = tf.InteractiveSession()

# In a standalone script, replace this with a scoped session:
# with tf.Session(graph=tf.Graph()) as sess:

In [6]:
# Prompt input: batch_size rows of token ids, prompt length left open.
context = tf.placeholder(tf.int32, [batch_size, None])

# Seed numpy and TensorFlow from the config cell (None means no fixed seed).
np.random.seed(seed)
tf.set_random_seed(seed)

output = sample.sample_sequence(
    hparams=hparams,
    length=length,
    context=context,
    batch_size=batch_size,
    temperature=temperature,
    top_k=top_k,
)

# Load the newest checkpoint for the selected model into the shared session.
checkpoint_dir = os.path.join('models', model_name)
saver = tf.train.Saver()
saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir))

INFO:tensorflow:Restoring parameters from models/117M/model.ckpt


In [9]:
import unicodedata
import os, re, random, fnmatch

def list_all_files(directory, extensions=None, exclude_prefixes=('__', '.')):
    """Recursively yield the paths of files under ``directory``.

    Files and directories whose names start with any of ``exclude_prefixes``
    are skipped, and excluded directories are not descended into.

    If ``extensions`` is given, only files whose extension (including the
    dot, e.g. ``'.txt'``) is in it are yielded.
    - The comparison is case-insensitive on both sides.
    - A single string such as ``'.txt'`` is accepted as well as a collection.
    - ``exclude_prefixes`` may be a string or any iterable of strings.
    """
    if extensions is not None:
        if isinstance(extensions, str):
            # A bare string would otherwise be matched by substring.
            extensions = [extensions]
        # Lower-case the filter so '.TXT' matches 'a.txt' and vice versa.
        extensions = {e.lower() for e in extensions}
    if not isinstance(exclude_prefixes, str):
        # str.startswith only accepts a str or a tuple (not a list).
        exclude_prefixes = tuple(exclude_prefixes)

    for root, dirnames, filenames in os.walk(directory):
        # Prune in place so os.walk does not descend into excluded directories.
        dirnames[:] = [d for d in dirnames if not d.startswith(exclude_prefixes)]
        for filename in filenames:
            if filename.startswith(exclude_prefixes):
                continue
            ext = os.path.splitext(filename)[1].lower()
            if extensions is None or ext in extensions:
                yield os.path.join(root, filename)

# Characters folded to plain-ASCII substitutes before lyrics are split into
# words or used as a model prompt. Written as escapes so look-alike glyphs
# (no-break space, dashes, curly quotes) are unambiguous in the source.
mapping = {
    '\xa0': ' ',     # no-break space
    '\u00c6': 'AE',  # Æ
    '\u00e6': 'ae',  # æ
    '\u00e8': 'e',   # è
    '\u00e9': 'e',   # é
    '\u00eb': 'e',   # ë
    '\u00f6': 'o',   # ö
    '\u2013': '-',   # en dash
    '\u2014': '-',   # em dash
    '\u2018': "'",   # left single quote
    '\u2019': "'",   # right single quote / apostrophe
    '\u201c': '"',   # left double quote
    '\u201d': '"',   # right double quote
}

def remove_special(text):
    """Return ``text`` with every character found in ``mapping`` replaced.

    Characters that are not keys of ``mapping`` pass through unchanged.
    """
    return ''.join(mapping.get(ch, ch) for ch in text)

def strip_word(word):
    """Lower-case ``word`` and trim non-word characters from both ends.

    Interior punctuation is kept (``"Carpenter's,"`` -> ``"carpenter's"``).
    A token made only of punctuation becomes ``''``.
    """
    # Raw string: in a plain literal '\W' is an invalid escape sequence
    # (DeprecationWarning, and SyntaxWarning since Python 3.12).
    return re.sub(r'^\W+|\W+$', '', word).lower()

# The export cell at the end of this notebook writes these files into the
# same directory; skip them so a re-run does not parse JSON as lyrics.
GENERATED_FILES = {'lyrics.json', 'generated.json'}

basenames = []
all_lyrics = {}
total_lines = 0
words = set()
# sorted(): os.walk yields entries in arbitrary filesystem order, so sort
# for a stable song order (and therefore a stable `basenames`).
for fn in sorted(list_all_files('../../gpt2-raps/output')):
    if os.path.basename(fn) in GENERATED_FILES:
        continue
    # Read once through the context manager (the file was previously opened
    # a second time and never closed). Explicit UTF-8 rather than the locale
    # default -- assumes the scraped files are UTF-8; confirm.
    with open(fn, encoding='utf-8') as f:
        original = f.read()
    text = remove_special(original).split('\n')
    # First three lines are url / title / artist; the rest are lyric lines.
    lyrics = text[3:]
    basename = os.path.splitext(os.path.basename(fn))[0]
    basenames.append(basename)
    all_lyrics[basename] = {
        'url': text[0],
        'title': text[1],
        'artist': text[2],
        'lyrics': lyrics,
    }
    total_lines += len(lyrics)
    words.update(strip_word(token) for token in '\n'.join(lyrics).split())
# Punctuation-only tokens strip to ''; discard() tolerates its absence.
words.discard('')
# Sorted list so random_item(words) is reproducible under a fixed seed.
words = sorted(words)

print(total_lines)

676


In [10]:
def titlecase_word(word):
    """Upper-case only the first character of ``word``; keep the rest as-is.

    Unlike ``str.title()``, letters after an apostrophe are not capitalized
    (``"carpenter's"`` -> ``"Carpenter's"``, not ``"Carpenter'S"``).
    An empty string is returned unchanged.
    """
    # Slicing instead of word[0] so '' does not raise IndexError.
    return word[:1].upper() + word[1:]

titlecase_word("carpenter's"), "carpenter's".title()

("Carpenter's", "Carpenter'S")

In [11]:
def random_chunk(array, length):
    """Return a random contiguous slice of ``array`` holding ``length`` items.

    Every start position from 0 through ``len(array) - length`` can be chosen,
    including the one that yields the final ``length`` items. If ``array`` is
    shorter than ``length``, the whole array is returned.
    """
    # random.randint is inclusive at both ends, so the last valid start is
    # len(array) - length; the previous extra "- 1" made it unreachable.
    start = random.randint(0, max(0, len(array) - length))
    return array[start:start + length]

def random_item(array):
    """Pick one element of the sequence ``array`` uniformly at random.

    Uses ``random.randint`` over the valid index range; an empty sequence
    raises ValueError.
    """
    last_index = len(array) - 1
    chosen_index = random.randint(0, last_index)
    return array[chosen_index]

# Spot-check the helpers: two consecutive lines from the first song, plus a
# random vocabulary word in title case.
first_song_lyrics = all_lyrics[basenames[0]]['lyrics']
random_chunk(first_song_lyrics, 2), titlecase_word(random_item(words))

(['Thought a new dress would make it better', 'I tried to work it away'],
 'Message')

In [12]:
# Prompt words: each one is title-cased and placed after the inspiration
# lyrics by generate(), so every sample begins with it.
seeds = [
    'work', 'based', 'love', 'beach', 'went',
    'cry', 'heavy', 'groceries', 'heaven', 'blame',
    'coming', 'sleeping', 'blue', 'city', 'peace',
]
len(seeds)

15

In [13]:
from multiprocessing import Pool, cpu_count
from IPython.display import clear_output
import time
from datetime import datetime, timedelta
import sys

def progress(itr, total=None, update_interval=1):
    """Yield the items of ``itr`` while printing a live progress line.

    Behavior:
    - An initial ``0/...`` line is printed before iteration starts.
    - At most once every ``update_interval`` seconds, the cell output is
      replaced (via ``clear_output``). The new line shows the count, elapsed
      time and rate.
    - When ``total`` is known (passed in, or ``len(itr)``), that line also
      includes the percentage and the remaining time.
    - After exhaustion, a final ``count elapsed rate`` line is printed.
    - An empty iterable yields nothing and prints only the initial line.
    """
    if total is None and hasattr(itr, '__len__'):
        total = len(itr)
    print('0/{} 0s 0/s'.format(total) if total else '0 0s 0/s')

    start_time = None
    last_time = None
    count = 0
    for count, x in enumerate(itr, 1):
        cur_time = time.time()
        if start_time is None:
            start_time = last_time = cur_time
        yield x
        if cur_time - last_time > update_interval:
            duration = cur_time - start_time
            # Guard: elapsed time can be 0 (coarse clocks, update_interval<0).
            speed = count / duration if duration > 0 else 0.0
            duration_str = timedelta(seconds=round(duration))
            clear_output(wait=True)
            if total:
                duration_total = duration * total / count
                duration_remaining_str = timedelta(seconds=round(duration_total - duration))
                pct = 100. * count / total
                print('{:.2f}% {}/{} {}<{} {:.2f}/s'.format(
                    pct, count, total, duration_str, duration_remaining_str, speed))
            else:
                print('{} {} {:.2f}/s'.format(count, duration_str, speed))
            last_time = cur_time

    if start_time is None:
        # Empty iterable: nothing to summarize (this used to crash because
        # start_time was None and the loop index was never bound).
        return
    duration = time.time() - start_time
    speed = count / duration if duration > 0 else 0.0
    clear_output(wait=True)
    print('{} {} {:.2f}/s'.format(count, timedelta(seconds=round(duration)), speed))
        
class job_wrapper:
    """Callable that runs ``job`` on a task and tags the result with its index.

    ``Pool.imap_unordered`` yields results in completion order, so pairing
    each result with its task index lets the caller restore task order.
    A module-level class is used (rather than a closure) so that pool workers
    can pickle it.
    """

    def __init__(self, job):
        self.job = job  # per-task function to apply

    def __call__(self, args):
        # args is an (index, task) pair, as produced by enumerate(tasks).
        index, payload = args
        result = self.job(payload)
        return index, result
    
def progress_parallel(job, tasks, total=None, update_interval=1, processes=None):
    """Run ``job`` on every item of ``tasks`` in a process pool with progress.

    Results are returned as a list in the same order as ``tasks``.
    ``total`` defaults to ``len(tasks)`` when available, and ``processes``
    defaults to ``cpu_count()``.

    A Ctrl-C (KeyboardInterrupt) stops the pool and returns None; results
    completed so far are discarded in that case.
    """
    if total is None and hasattr(tasks, '__len__'):
        total = len(tasks)
    if processes is None:
        processes = cpu_count()
    try:
        with Pool(processes) as pool:
            indexed = list(progress(pool.imap_unordered(job_wrapper(job), enumerate(tasks)),
                                    total=total, update_interval=update_interval))
    except KeyboardInterrupt:
        # Deliberate best-effort stop for interactive use, but report it
        # instead of silently handing the caller None.
        print('progress_parallel: interrupted, returning None', file=sys.stderr)
        return None
    # Sort on the index alone so the results never need to be comparable.
    indexed.sort(key=lambda pair: pair[0])
    return [result for _, result in indexed]

In [15]:
def clean(text):
    """Return ``text`` up to (not including) the first ``<|endoftext|>``.

    Text without the marker is returned unchanged.
    """
    head, _, _ = text.partition('<|endoftext|>')
    return head

def generate(inspiration, seed):
    """Sample GPT-2 continuations of ``inspiration`` that begin with ``seed``.

    The prompt is the inspiration (special characters folded and surrounding
    whitespace stripped), a newline, then ``seed`` stripped and title-cased.
    It runs ``nsamples // batch_size`` batches through ``sess`` and returns a
    list of ``nsamples`` strings. Each string is ``seed`` followed by the
    decoded continuation, cut at the first ``<|endoftext|>`` marker.

    Relies on the notebook globals ``enc``, ``sess``, ``output``,
    ``context``, ``hparams``, ``length``, ``nsamples`` and ``batch_size``.
    """
    inspiration = remove_special(inspiration).strip()
    # Strip before title-casing so leading whitespace cannot absorb the capital.
    seed = titlecase_word(seed.strip())

    raw_text = inspiration + '\n' + seed
    context_tokens = enc.encode(raw_text)
    # Prompt plus `length` sampled tokens must fit in the model window
    # (see the n_ctx check in the config cell). Drop the oldest tokens so the
    # seed word at the end of the prompt is always kept.
    max_context = hparams.n_ctx - length
    if len(context_tokens) > max_context:
        context_tokens = context_tokens[len(context_tokens) - max_context:]
    n_context = len(context_tokens)

    results = []
    for _ in range(nsamples // batch_size):
        out = sess.run(output, feed_dict={
            context: [context_tokens for _ in range(batch_size)]
        })
        # `tokens`, not `sample`, to avoid shadowing the imported sample module.
        for tokens in out:
            # Each row presumably starts with the prompt tokens; keep only the
            # newly sampled part and cut it at the end-of-text marker.
            continuation = clean(enc.decode(tokens[n_context:]))
            results.append(seed + continuation)

    return results

In [17]:
inspiration_lines = 16  # consecutive lyric lines used as each song's prompt

# Seed Python's `random` (used by random_chunk to pick prompt lines) from the
# config-cell seed, so prompt selection is reproducible when a seed is set.
random.seed(seed)

all_results = {}
# The loop variable must NOT be named `seed`: that would overwrite the
# config-cell RNG seed and break re-running the model-setup cell.
for seed_word in seeds:
    print(seed_word)
    per_song = {}
    for basename in basenames:
        chunk = random_chunk(all_lyrics[basename]['lyrics'], inspiration_lines)
        per_song[basename] = generate('\n'.join(chunk), seed_word)
    all_results[seed_word] = per_song

work
based
love
beach
went
cry
heavy
groceries
heaven
blame
coming
sleeping
blue
city
peace


In [19]:
import json
# Persist the parsed lyrics and the generated samples as compact JSON.
exports = (
    ('../../gpt2-raps/output/lyrics.json', all_lyrics),
    ('../../gpt2-raps/output/generated.json', all_results),
)
for export_path, payload in exports:
    with open(export_path, 'w') as out_file:
        json.dump(payload, out_file, separators=(',', ':'))