In [55]:
cd data/gpt-2/

[Errno 2] No such file or directory: 'data/gpt-2/'
/tf/prototypes/gpt-2/tf2/data/gpt-2


In [2]:
!pip3 install -r requirements.txt



In [56]:
import fire
import json
import os
import numpy as np
import tensorflow as tf
import regex as re
from functools import lru_cache
import argparse
import time
import tqdm
from tensorflow.core.protobuf import rewriter_config_pb2
import glob

tf.__version__

'2.0.0-beta1'

# Encoding

In [57]:
"""Byte pair encoding utilities"""


@lru_cache()
def bytes_to_unicode():
    """
    Build the reversible byte -> unicode-character table used by the BPE code.

    Byte values in the ranges "!".."~", "¡".."¬" and "®".."ÿ" map to the
    character with the same code point.  Every remaining byte value is given
    the next unused code point starting at 256, so each of the 256 byte values
    ends up with its own distinct, single-character string.

    Returns:
        dict mapping each int in range(256) to a one-character string.  The
        call is memoised by ``lru_cache``, so repeated calls return the same
        dict object.
    """
    self_mapped = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]
    table = {byte: chr(byte) for byte in self_mapped}
    # Remaining bytes are shifted above the 8-bit range, in ascending byte order.
    shifted = 0
    for byte in range(2**8):
        if byte not in table:
            table[byte] = chr(2**8 + shifted)
            shifted += 1
    return table

def get_pairs(word):
    """Return the set of adjacent symbol pairs in a word.

    Word is represented as a tuple of symbols (symbols being variable-length
    strings).  A word with fewer than two symbols has no pairs; an empty word
    yields an empty set rather than raising IndexError.
    """
    # Pair every symbol with its successor; zip stops at the shorter sequence.
    return set(zip(word, word[1:]))

class Encoder:
    """Byte-level BPE tokenizer converting between text and vocabulary ids."""

    def __init__(self, encoder, bpe_merges, errors='replace'):
        """
        Args:
            encoder: dict mapping BPE token strings to integer ids.
            bpe_merges: sequence of symbol pairs; a pair's position in the
                sequence is its merge rank (lower rank merges first).
            errors: error mode handed to ``bytearray.decode`` in ``decode``.
        """
        self.encoder = encoder
        self.decoder = {v:k for k,v in self.encoder.items()}
        self.errors = errors # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        # Memo of already-merged tokens (token -> space-joined BPE symbols).
        self.cache = {}

        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

    def bpe(self, token):
        """Apply the ranked merges to ``token`` and return its symbols joined by spaces.

        A token with no adjacent pairs (a single symbol) is returned unchanged
        and is not cached.
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            # Pick the adjacent pair with the lowest merge rank; unknown pairs rank as infinity.
            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    # tuple.index found no further occurrence of `first`: copy the tail as-is.
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Split ``text`` with the tokenizer regex and return the list of BPE token ids."""
        bpe_tokens = []
        for token in re.findall(self.pat, text):
            # Re-express the token's UTF-8 bytes through the byte -> unicode table before merging.
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Map token ids back to text, decoding the recovered bytes as UTF-8 with ``self.errors``."""
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
        return text

def get_encoder(model_name, models_dir):
    """Load the vocabulary and merge list for a model and build an ``Encoder``.

    Args:
        model_name: sub-directory of ``models_dir`` holding the model files.
        models_dir: directory containing one sub-directory per model.

    Returns:
        An ``Encoder`` built from ``encoder.json`` and ``vocab.bpe``.
    """
    model_dir = os.path.join(models_dir, model_name)
    # Read both files as UTF-8 so the result does not depend on the locale's default encoding.
    with open(os.path.join(model_dir, 'encoder.json'), 'r', encoding="utf-8") as f:
        encoder = json.load(f)
    with open(os.path.join(model_dir, 'vocab.bpe'), 'r', encoding="utf-8") as f:
        bpe_data = f.read()
    # The first line and the last split element are dropped
    # (presumably a header line and the empty string after the final newline -- TODO confirm).
    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
    return Encoder(
        encoder=encoder,
        bpe_merges=bpe_merges,
    )

# Model

In [58]:
class HParams():
  """Container for the model hyperparameters.

  The class attributes hold the default values; ``__init__`` uses them as
  defaults, so ``HParams()`` yields the default configuration while every
  field can still be overridden positionally or by keyword.
  """
  n_vocab=50257
  n_ctx=1024
  n_embd=768
  n_head=12
  n_layer=12

  # Defaults are evaluated in the class body, so they pick up the attributes above.
  def __init__(self, n_vocab=n_vocab, n_ctx=n_ctx, n_embd=n_embd, n_head=n_head, n_layer=n_layer):
    self.n_vocab = n_vocab
    self.n_ctx = n_ctx
    self.n_embd = n_embd
    self.n_head = n_head
    self.n_layer = n_layer

In [59]:
def default_hparams():
    """Return an ``HParams`` instance populated with the hyperparameters hard-coded here."""
    settings = {
        'n_vocab': 50257,
        'n_ctx': 1024,
        'n_embd': 768,
        'n_head': 12,
        'n_layer': 12,
    }
    return HParams(**settings)

def shape_list(x):
    """Return the shape of ``x`` as a list, preferring statically known sizes.

    Dimensions whose static size is None are taken from the run-time shape
    tensor instead.
    """
    static_dims = x.shape.as_list()
    runtime_shape = tf.shape(input=x)
    dims = []
    for axis, size in enumerate(static_dims):
        dims.append(runtime_shape[axis] if size is None else size)
    return dims

def gelu(x):
    """tanh-based activation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))."""
    inner = x + 0.044715 * tf.pow(x, 3)
    return 0.5 * x * (1 + tf.tanh(np.sqrt(2 / np.pi) * inner))

def norm(x, scope, *, axis=-1, epsilon=1e-5):
    """Standardise ``x`` along ``axis`` (zero mean, unit variance), then scale by 'g' and shift by 'b'.

    The gain 'g' (initialised to 1) and bias 'b' (initialised to 0) are
    variables created under ``scope`` with the size of the last dimension of ``x``.
    """
    with tf.compat.v1.variable_scope(scope):
        width = x.shape[-1]
        gain = tf.compat.v1.get_variable('g', [width], initializer=tf.compat.v1.constant_initializer(1), use_resource=False)
        bias = tf.compat.v1.get_variable('b', [width], initializer=tf.compat.v1.constant_initializer(0), use_resource=False)
        mean = tf.reduce_mean(input_tensor=x, axis=axis, keepdims=True)
        centered = x - mean
        variance = tf.reduce_mean(input_tensor=tf.square(centered), axis=axis, keepdims=True)
        # epsilon keeps rsqrt finite when the variance is zero.
        normalized = centered * tf.math.rsqrt(variance + epsilon)
        return normalized*gain + bias

def split_states(x, n):
    """Split the last dimension of x (size m) into two dimensions [n, m // n]."""
    *leading, last = shape_list(x)
    target_shape = leading + [n, last // n]
    return tf.reshape(x, target_shape)

def merge_states(x):
    """Collapse the last two dimensions of x into one of size their product."""
    *leading, rows, cols = shape_list(x)
    target_shape = leading + [rows * cols]
    return tf.reshape(x, target_shape)

def conv1d(x, scope, nf, *, w_init_stdev=0.02):
    """Project the last dimension of ``x`` to ``nf`` features with a learned weight and bias.

    Variables 'w' (shape [1, nx, nf], normal init with ``w_init_stdev``) and
    'b' (shape [nf], zero init) are created under ``scope``.  The input is
    flattened to 2-D for the matmul and the leading dimensions are restored
    afterwards.
    """
    with tf.compat.v1.variable_scope(scope):
        *leading, n_in = shape_list(x)
        weight = tf.compat.v1.get_variable('w', [1, n_in, nf], initializer=tf.compat.v1.random_normal_initializer(stddev=w_init_stdev), use_resource=False)
        bias = tf.compat.v1.get_variable('b', [nf], initializer=tf.compat.v1.constant_initializer(0), use_resource=False)
        flat_x = tf.reshape(x, [-1, n_in])
        flat_w = tf.reshape(weight, [-1, nf])
        projected = tf.matmul(flat_x, flat_w) + bias
        return tf.reshape(projected, leading + [nf])

def attention_mask(nd, ns, *, dtype):
    """Build an [nd, ns] mask with 1's in the lower triangle, counting from the lower right corner.

    Entry (i, j) is 1 when i >= j - ns + nd and 0 otherwise, cast to ``dtype``.
    Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs.
    """
    dst = tf.range(nd)[:, None]
    src = tf.range(ns)
    allowed = dst >= src - ns + nd
    return tf.cast(allowed, dtype)


def attn(x, scope, n_state, *, past, hparams):
    """Masked multi-head self-attention layer.

    Args:
        x: rank-3 input tensor (asserted below).
        scope: variable scope name under which 'c_attn' and 'c_proj' are created.
        n_state: feature width; must be divisible by ``hparams.n_head``.
        past: None, or a rank-5 tensor of cached keys and values that is
            prepended to the current keys/values along the sequence axis.
        hparams: object providing ``n_head``.

    Returns:
        (a, present): the attention output after the 'c_proj' projection, and
        the keys and values of the current input stacked along axis 1.
    """
    assert x.shape.ndims == 3  # Should be [batch, sequence, features]
    assert n_state % hparams.n_head == 0
    if past is not None:
        assert past.shape.ndims == 5  # Should be [batch, 2, heads, sequence, features], where 2 is [k, v]

    def split_heads(x):
        # From [batch, sequence, features] to [batch, heads, sequence, features]
        return tf.transpose(a=split_states(x, hparams.n_head), perm=[0, 2, 1, 3])

    def merge_heads(x):
        # Reverse of split_heads
        return merge_states(tf.transpose(a=x, perm=[0, 2, 1, 3]))

    def mask_attn_weights(w):
        # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
        _, _, nd, ns = shape_list(w)
        b = attention_mask(nd, ns, dtype=w.dtype)
        b = tf.reshape(b, [1, 1, nd, ns])
        # Keep allowed positions and push masked ones to -1e10 so softmax gives them ~0 weight.
        w = w*b - tf.cast(1e10, w.dtype)*(1-b)
        return w

    def multihead_attn(q, k, v):
        # q, k, v have shape [batch, heads, sequence, features]
        w = tf.matmul(q, k, transpose_b=True)
        # Scale the scores by 1/sqrt(size of the last dimension of v).
        w = w * tf.math.rsqrt(tf.cast(v.shape[-1], w.dtype))

        w = mask_attn_weights(w)
        w = tf.nn.softmax(w, axis=-1)
        a = tf.matmul(w, v)
        return a

    with tf.compat.v1.variable_scope(scope):
        # A single projection produces q, k and v, which are then split three ways.
        c = conv1d(x, 'c_attn', n_state*3)
        q, k, v = map(split_heads, tf.split(c, 3, axis=2))
        # `present` is built before the past is concatenated, so it only covers the current input.
        present = tf.stack([k, v], axis=1)
        if past is not None:
            pk, pv = tf.unstack(past, axis=1)
            k = tf.concat([pk, k], axis=-2)
            v = tf.concat([pv, v], axis=-2)
        a = multihead_attn(q, k, v)
        a = merge_heads(a)
        a = conv1d(a, 'c_proj', n_state)
        return a, present


def mlp(x, scope, n_state, *, hparams):
    """Two-layer feed-forward block: project to ``n_state``, apply gelu, project back to the input width.

    ``hparams`` is accepted but not read here.
    """
    with tf.compat.v1.variable_scope(scope):
        width = x.shape[-1]
        hidden = conv1d(x, 'c_fc', n_state)
        activated = gelu(hidden)
        return conv1d(activated, 'c_proj', width)

def block(x, scope, *, past, hparams):
    """One transformer layer: normalised attention and normalised MLP, each added back to its input.

    Returns:
        (x, present): the layer output and the ``present`` tensor returned by ``attn``.
    """
    with tf.compat.v1.variable_scope(scope):
        width = x.shape[-1]
        attn_out, present = attn(norm(x, 'ln_1'), 'attn', width, past=past, hparams=hparams)
        after_attn = x + attn_out
        mlp_out = mlp(norm(after_attn, 'ln_2'), 'mlp', width*4, hparams=hparams)
        return after_attn + mlp_out, present

def past_shape(*, hparams, batch_size=None, sequence=None):
    """Return the 6-element shape list [batch_size, n_layer, 2, n_head, sequence, n_embd // n_head]."""
    per_head_width = hparams.n_embd // hparams.n_head
    return [
        batch_size,
        hparams.n_layer,
        2,
        hparams.n_head,
        sequence,
        per_head_width,
    ]

def expand_tile(value, size):
    """Prepend a new leading axis to ``value`` and repeat it ``size`` times along that axis."""
    tensor = tf.convert_to_tensor(value=value, name='value')
    rank = tensor.shape.ndims
    # Repeat only along the new axis; every existing axis keeps a multiple of 1.
    multiples = [size] + [1]*rank
    return tf.tile(tf.expand_dims(tensor, axis=0), multiples)

def positions_for(tokens, past_length):
    """Return position indices ``past_length + [0 .. nsteps)`` repeated once per row of ``tokens``."""
    n_rows = tf.shape(input=tokens)[0]
    n_steps = tf.shape(input=tokens)[1]
    offsets = past_length + tf.range(n_steps)
    return expand_tile(offsets, n_rows)


def model(hparams, X, past=None, scope='model', reuse=tf.compat.v1.AUTO_REUSE):
    """Build the transformer graph for token ids ``X``.

    Args:
        hparams: object providing n_ctx, n_embd, n_vocab and n_layer.
        X: rank-2 tensor of token ids (unpacked below as batch, sequence).
        past: None, or a tensor of cached attention state that is unstacked
            along axis 1 into one entry per layer.
        scope: name of the enclosing variable scope.
        reuse: reuse mode passed to ``variable_scope``.

    Returns:
        dict with 'present' (the per-layer ``present`` tensors stacked along
        axis 1) and 'logits' (shape [batch, sequence, n_vocab]).
    """
    with tf.compat.v1.variable_scope(scope, reuse=reuse):
        results = {}
        batch, sequence = shape_list(X)

        # Position and token embedding tables.
        wpe = tf.compat.v1.get_variable('wpe', [hparams.n_ctx, hparams.n_embd],
                             initializer=tf.compat.v1.random_normal_initializer(stddev=0.01), use_resource=False)
        wte = tf.compat.v1.get_variable('wte', [hparams.n_vocab, hparams.n_embd],
                             initializer=tf.compat.v1.random_normal_initializer(stddev=0.02), use_resource=False)
        # Positions continue from the length of the cached state, if any.
        past_length = 0 if past is None else tf.shape(input=past)[-2]
        h = tf.gather(wte, X) + tf.gather(wpe, positions_for(X, past_length))

        # Transformer
        presents = []
        pasts = tf.unstack(past, axis=1) if past is not None else [None] * hparams.n_layer
        assert len(pasts) == hparams.n_layer
        # NOTE: the loop variable rebinds the `past` parameter; it is not read again after the loop.
        for layer, past in enumerate(pasts):
            h, present = block(h, 'h%d' % layer, past=past, hparams=hparams)
            presents.append(present)
        results['present'] = tf.stack(presents, axis=1)
        h = norm(h, 'ln_f')

        # Language model loss.  Do tokens <n predict token n?
        h_flat = tf.reshape(h, [batch*sequence, hparams.n_embd])
        # The output projection reuses the token embedding matrix `wte` (transposed).
        logits = tf.matmul(h_flat, wte, transpose_b=True)
        logits = tf.reshape(logits, [batch, sequence, hparams.n_vocab])
        results['logits'] = logits
        return results

# Sample from Model

In [60]:
def top_k_logits(logits, k):
    """Keep the ``k`` largest logits in each row and replace the rest with -1e10.

    ``k == 0`` means no truncation: ``logits`` is returned unchanged.  The
    ``tf.cond`` repeats the zero check inside the graph.
    """
    if k == 0:
        # no truncation
        return logits

    def _keep_top_k():
        top_values, _ = tf.nn.top_k(logits, k=k)
        # The k-th largest value of each row is that row's cut-off.
        cutoff = top_values[:, -1, tf.newaxis]
        below_cutoff = logits < cutoff
        suppressed = tf.ones_like(logits, dtype=logits.dtype) * -1e10
        return tf.compat.v1.where(below_cutoff, suppressed, logits)

    return tf.cond(
        pred=tf.equal(k, 0),
        true_fn=lambda: logits,
        false_fn=_keep_top_k,
    )


def sample_sequence(*, hparams, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0):
    """Build graph ops that sample tokens from the model autoregressively.

    Exactly one of ``start_token`` and ``context`` must be given; with
    ``start_token`` the context is a [batch_size, 1] tensor filled with it.

    Args:
        hparams: model hyperparameters.
        length: one sampling step is run eagerly below and the loop is capped
            at ``length - 1`` further iterations.
        start_token: token id used to start generation when no context is given.
        batch_size: batch dimension used for the shape invariants.
        context: tensor of conditioning token ids.
        temperature: divisor applied to the logits before sampling.
        top_k: passed to ``top_k_logits``; 0 disables truncation.

    Returns:
        Tensor holding the context followed by the sampled tokens.
    """
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = tf.fill([batch_size, 1], start_token)

    def step(hparams, tokens, past=None):
        # Run the model on `tokens`, reusing the existing variables.
        lm_output = model(hparams=hparams, X=tokens, past=past, reuse=tf.compat.v1.AUTO_REUSE)

        logits = lm_output['logits'][:, :, :hparams.n_vocab]
        presents = lm_output['present']
        # Pin the static shape (sequence left unknown) so it matches the loop's shape invariant.
        presents.set_shape(past_shape(hparams=hparams, batch_size=batch_size))
        return {
            'logits': logits,
            'presents': presents,
        }

    def body(past, prev, output):
        next_outputs = step(hparams, prev, past=past)
        # Only the logits of the last position are used to pick the next token.
        logits = next_outputs['logits'][:, -1, :]  / tf.cast(temperature, dtype=tf.float32)
        logits = top_k_logits(logits, k=top_k)
        samples = tf.random.categorical(logits=logits, num_samples=1, dtype=tf.int32)
        return [
            # Grow the cached state along its sequence axis.
            next_outputs['presents'] if past is None else tf.concat([past, next_outputs['presents']], axis=-2),
            samples,
            tf.concat([output, samples], axis=1)
        ]

    # First step outside the loop: processes the whole context and creates the initial cache.
    past, prev, output = body(None, context, context)

    def cond(*args):
        # Never stops on its own; termination comes from maximum_iterations.
        return True

    _, _, tokens = tf.while_loop(
        cond=cond, body=body,
        maximum_iterations=length - 1,
        loop_vars=[
            past,
            prev,
            output
        ],
        shape_invariants=[
            tf.TensorShape(past_shape(hparams=hparams, batch_size=batch_size)),
            tf.TensorShape([batch_size, None]),
            tf.TensorShape([batch_size, None]),
        ],
        back_prop=False,
    )

    return tokens

In [61]:
from pathlib import Path
def load_dataset(enc, path, combine):
    """Tokenize the 'after.java' files found under ``path``.

    Args:
        enc: encoder object providing ``encode(text)``.
        path: a single file, a directory (walked recursively), or a glob pattern.
        combine: accepted for interface compatibility; not used by this function.

    Returns:
        List with one 1-D numpy array of token ids per successfully encoded file.

    Notes:
        * Directory walking stops after the directory with enumeration index 50000.
        * Only paths containing 'after.java' are read.
        * Processing stops at the first matching file whose index is >= 100000.
        * Files that cannot be read or encoded (including empty files, for
          which ``np.stack`` fails) are reported and skipped.
    """
    paths = []
    if os.path.isfile(path):
        # Simple file
        paths.append(path)
    elif os.path.isdir(path):
        # Directory
        for i, (dirpath, _, fnames) in enumerate(os.walk(path)):
            if i % 5000 == 0:
                print(i)
            for fname in fnames:
                paths.append(os.path.join(dirpath, fname))

            if i == 50000:
                print("Breaking")
                break
    else:
        # Assume glob
        paths = glob.glob(path)

    token_chunks = []
    # A separate loop variable keeps the `path` argument intact.
    for i, file_path in enumerate(tqdm.tqdm(paths)):
        if 'after.java' not in file_path:
            continue

        try:
            # Read each file into a fresh local so text from a failed file
            # can never leak into the next file's chunk.
            with open(file_path, 'r') as fp:
                raw_text = fp.read()
            tokens = np.stack(enc.encode(raw_text))
            token_chunks.append(tokens)
        except Exception as e:
            # Best effort: report the problem and carry on with the next file.
            print('Skipping {}: {}'.format(file_path, e))
        if i >= 100000:
            break
    return token_chunks

def binary_search(f, lo, hi):
    """Find where predicate ``f`` flips from false to true between ``lo`` and ``hi``.

    Returns None when ``f(lo)`` is already true or ``f(hi)`` is false.
    Otherwise bisects, keeping ``f`` false at the lower bound and true at the
    upper bound, and returns the final upper bound.
    """
    if f(lo):
        return None
    if not f(hi):
        return None

    low, high = lo, hi
    while high > low + 1:
        probe = (low + high) // 2
        if f(probe):
            high = probe
        else:
            low = probe
    return high


class Sampler(object):
    """Fairly samples a slice from a set of variable sized chunks.

    'Fairly' means that the distribution is the same as sampling from one concatenated chunk,
    but without crossing chunk boundaries."""

    def __init__(self, chunks, seed=None):
        """
        Args:
            chunks: sequence of arrays; only ``shape[0]`` and slicing are used.
            seed: seed for the private ``np.random.RandomState``.
        """
        self.chunks = chunks
        self.total_size = sum(chunk.shape[0] for chunk in chunks)
        # boundaries[i] is the offset of chunk i in the virtual concatenation;
        # the final entry equals total_size.
        self.boundaries = [0]
        for i in range(len(chunks)):
            self.boundaries.append(self.boundaries[-1] + chunks[i].shape[0])
        self.rs = np.random.RandomState(seed=seed)

    def sample(self, length):
        """Return a random slice of ``length`` items lying inside a single chunk.

        Raises:
            ValueError: if the sampler holds no chunks.
            AssertionError: if the average chunk size is not larger than ``length``.
        """
        if len(self.chunks) == 0:
            # Without this guard the average-size check below divides by zero.
            raise ValueError("Cannot sample from an empty dataset (no chunks loaded)")
        assert length < self.total_size // len(
            self.chunks
        ), "Dataset files are too small to sample {} tokens at a time".format(
            length)
        while True:
            # Draw a start offset in the concatenation, then reject draws that
            # would run past the end of the chunk they start in.
            index = self.rs.randint(0, self.total_size - length - 1)
            i = binary_search(lambda j: self.boundaries[j] > index, 0,
                              len(self.boundaries) - 1) - 1
            if self.boundaries[i + 1] > index + length:
                within_chunk = index - self.boundaries[i]
                return self.chunks[i][within_chunk:within_chunk + length]

In [62]:
class Args():
    """Plain record of the run settings; every constructor argument is stored under the same attribute name."""

    def __init__(self, dataset, model_name, combine, batch_size, learning_rate, optimizer, noise, top_k, top_p, run_name, sample_every, sample_length, sample_num, save_every, val_dataset, val_batch_size, val_batch_count, val_every):
        # Run identity and input data
        self.run_name = run_name
        self.model_name = model_name
        self.dataset = dataset
        self.combine = combine

        # Optimisation
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.optimizer = optimizer
        self.noise = noise

        # Sampling
        self.top_k = top_k
        self.top_p = top_p
        self.sample_every = sample_every
        self.sample_length = sample_length
        self.sample_num = sample_num

        # Checkpointing
        self.save_every = save_every

        # Validation
        self.val_dataset = val_dataset
        self.val_batch_size = val_batch_size
        self.val_batch_count = val_batch_count
        self.val_every = val_every

In [63]:
# Settings for this run; read as a module-level global by main().
args = Args(
    # Input data and pretrained model
    dataset="../sciclone/data10/mtufano/deepLearningMutants/out/changes/code",
    model_name="117M",
    combine=50000,        # forwarded to load_dataset, which does not read it
    # Optimisation
    batch_size=1,
    learning_rate=2e-5,
    optimizer="sgd",      # main() accepts 'adam' or 'sgd'
    noise=0.0,            # only referenced by a commented-out randomize() call in main()
    # Sampling
    top_k=40,
    top_p=0.0,
    run_name="run1",
    sample_every=100,
    sample_length=1023,   # main() rejects values above hparams.n_ctx
    sample_num=1,
    # Checkpointing
    save_every=1000,
    # Validation
    val_dataset=None,
    val_batch_size=2,
    val_batch_count=40,
    val_every=100,
)

In [64]:
# Build the tokenizer from the files under models/<model_name>, then tokenize the dataset.
enc = get_encoder(args.model_name, "models")
data_set = load_dataset(enc, args.dataset, args.combine)
# Number of token chunks loaded (one per encoded file).
len(data_set)

0
5000
10000
15000
20000
25000
30000
35000
40000
45000


  0%|          | 26/115158 [00:00<07:35, 252.62it/s]

50000
Breaking


 87%|████████▋ | 99948/115158 [01:25<00:09, 1660.01it/s]

25358

In [65]:
# 80 / 10 / 10 split of the loaded chunks, taken in load order (no shuffling).
DATA_SET_SIZE = len(data_set)
TRN_SET_SIZE = int(DATA_SET_SIZE * 0.8)
VAL_SET_SIZE = int(DATA_SET_SIZE * 0.1)
TST_SET_SIZE = int(DATA_SET_SIZE * 0.1)

trn_set = data_set[:TRN_SET_SIZE]
val_set = data_set[TRN_SET_SIZE:TRN_SET_SIZE + VAL_SET_SIZE]
# Slice from an explicit start index: `data_set[-0:]` would return the whole
# dataset when TST_SET_SIZE is 0 (fewer than 10 chunks).  Because each size is
# truncated by int(), a few chunks between the validation and test slices may
# remain unused.
tst_set = data_set[DATA_SET_SIZE - TST_SET_SIZE:]
len(trn_set), len(val_set), len(tst_set)

(20286, 2535, 2535)

In [66]:
# Root directories (relative to the working directory); main() joins each
# with args.run_name for checkpoints/summaries and generated samples.
CHECKPOINT_DIR = 'checkpoint'
SAMPLE_DIR = 'samples'

def maketree(path):
    """Create directory ``path`` and any missing parents; an existing directory is fine.

    Raises:
        OSError: if the directory cannot be created (for example, ``path``
            exists but is not a directory, or permission is denied).
    """
    # exist_ok tolerates only the "already a directory" case, so real failures still surface.
    os.makedirs(path, exist_ok=True)


def randomize(context, hparams, p):
    """Replace each entry of ``context`` with a random token id with probability ``p``.

    When ``p`` is not greater than 0 the context is returned untouched.
    Replacement ids are drawn uniformly from [0, hparams.n_vocab).
    """
    if not (p > 0):
        return context

    replace_here = tf.random.uniform(shape=tf.shape(input=context)) < p
    random_ids = tf.random.uniform(shape=tf.shape(input=context), minval=0, maxval=hparams.n_vocab, dtype=tf.int32)
    return tf.compat.v1.where(replace_here, random_ids, context)


def main():
    """Fine-tune the model on ``trn_set``, with periodic saving, sampling and validation.

    Reads the module-level ``args``, ``trn_set`` and ``val_set``.  Restores the
    latest checkpoint of the run if one exists, otherwise the checkpoint under
    ``models/<model_name>``.  The training loop runs until it is interrupted;
    on KeyboardInterrupt a checkpoint is saved before returning.

    Raises:
        ValueError: if ``args.sample_length`` exceeds the context window, or
            ``args.optimizer`` is neither 'adam' nor 'sgd'.
    """
    enc = get_encoder(args.model_name, "models")
    hparams = default_hparams()

    if args.sample_length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: %s" % hparams.n_ctx)

    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF
    with tf.compat.v1.Session(config=config) as sess:
        context = tf.compat.v1.placeholder(tf.int32, [args.batch_size, None])
        context_in = context # randomize(context, hparams, args.noise)
        output = model(hparams=hparams, X=context_in)
        # Next-token objective: logits at position t are scored against the token at t + 1.
        loss = tf.reduce_mean(
            input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        if args.val_every > 0:
            val_context = tf.compat.v1.placeholder(tf.int32, [args.val_batch_size, None])
            val_output = model(hparams=hparams, X=val_context)
            val_loss = tf.reduce_mean(
                input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=val_context[:, 1:], logits=val_output['logits'][:, :-1]))
            val_loss_summary = tf.compat.v1.summary.scalar('val_loss', val_loss)


        tf_sample = sample_sequence(
            hparams=hparams,
            length=args.sample_length,
            context=context,
            batch_size=args.batch_size,
            temperature=1.0,
            top_k=args.top_k)

        all_vars = [v for v in tf.compat.v1.trainable_variables() if 'model' in v.name]
        train_vars = all_vars

        if args.optimizer == 'adam':
            opt = tf.compat.v1.train.AdamOptimizer(learning_rate=args.learning_rate)
        elif args.optimizer == 'sgd':
            opt = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=args.learning_rate)
        else:
            # exit() accepts a single argument, so the previous two-argument call
            # raised TypeError instead of reporting the offending value.
            raise ValueError('Bad optimizer: {}'.format(args.optimizer))

        opt_grads = tf.gradients(ys=loss, xs=train_vars)
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)
        summary_loss = tf.compat.v1.summary.scalar('loss', loss)

        summaries = tf.compat.v1.summary.merge([summary_loss])

        summary_log = tf.compat.v1.summary.FileWriter(
            os.path.join(CHECKPOINT_DIR, args.run_name))

        saver = tf.compat.v1.train.Saver(
            var_list=all_vars,
            max_to_keep=5,
            keep_checkpoint_every_n_hours=2)
        sess.run(tf.compat.v1.global_variables_initializer())

        ckpt = tf.train.latest_checkpoint(
            os.path.join(CHECKPOINT_DIR, args.run_name))
        if ckpt is None:
            # Get fresh GPT weights if new run.
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', args.model_name))

        print('Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        print('Loading dataset...')
        data_sampler = Sampler(trn_set)
        if args.val_every > 0:
            val_chunks = val_set
        print('dataset has', data_sampler.total_size, 'tokens')
        print('Training...')

        if args.val_every > 0:
            # Sample from validation set once with fixed seed to make
            # it deterministic during training as well as across runs.
            val_data_sampler = Sampler(val_chunks, seed=1)
            val_batches = [[val_data_sampler.sample(128) for _ in range(args.val_batch_size)]
                           for _ in range(args.val_batch_count)]

        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            # Write a checkpoint for the current step and record the step number next to it.
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            saver.save(
                sess,
                os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                global_step=counter)
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')

        def generate_samples():
            # Generate args.sample_num samples from a one-token training context and write them to disk.
            print('Generating samples...')
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < args.sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: args.batch_size * [context_tokens]})
                for i in range(min(args.sample_num - index, args.batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
            # Only the last generated sample is echoed; all of them go to the file.
            print(text)
            maketree(os.path.join(SAMPLE_DIR, args.run_name))
            with open(
                    os.path.join(SAMPLE_DIR, args.run_name,
                                 'samples-{}').format(counter), 'w') as fp:
                fp.write('\n'.join(all_text))

        def validation():
            # Average the loss over the fixed validation batches and log it as a summary.
            print('Calculating validation loss...')
            losses = []
            for batch in tqdm.tqdm(val_batches):
                losses.append(sess.run(val_loss, feed_dict={val_context: batch}))
            v_val_loss = np.mean(losses)
            # Feed the averaged value in place of the val_loss tensor to produce the summary.
            v_summary = sess.run(val_loss_summary, feed_dict={val_loss: v_val_loss})
            summary_log.add_summary(v_summary, counter)
            summary_log.flush()
            print(
                '[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'
                .format(
                    counter=counter,
                    time=time.time() - start_time,
                    loss=v_val_loss))

        def sample_batch():
            # One training batch: args.batch_size slices of 128 tokens each.
            return [data_sampler.sample(128) for _ in range(args.batch_size)]


        # (decayed sum of losses, decayed count); their ratio is the running average.
        avg_loss = (0.0, 0.0)
        start_time = time.time()

        try:
            while True:
                if counter % args.save_every == 0:
                    save()
                if counter % args.sample_every == 0:
                    generate_samples()
                if args.val_every > 0 and (counter % args.val_every == 0 or counter == 1):
                    validation()

                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, loss, summaries),
                    feed_dict={context: sample_batch()})

                summary_log.add_summary(v_summary, counter)

                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(
                        counter=counter,
                        time=time.time() - start_time,
                        loss=v_loss,
                        avg=avg_loss[0] / avg_loss[1]))

                counter += 1
        except KeyboardInterrupt:
            print('interrupted')
            save()

# Start training when this code runs as the top-level module (as it does in a notebook kernel).
if __name__ == '__main__':
    main()

 87%|████████▋ | 99948/115158 [01:40<00:09, 1660.01it/s]

Loading checkpoint checkpoint/run1/model-101
Loading dataset...
dataset has 15705181 tokens
Training...
[102 | 5.33] loss=2.31 avg=2.31
[103 | 5.39] loss=1.16 avg=1.73
[104 | 5.44] loss=3.00 avg=2.16
[105 | 5.50] loss=1.70 avg=2.04
[106 | 5.55] loss=3.01 avg=2.24
[107 | 5.60] loss=2.16 avg=2.23
[108 | 5.66] loss=2.31 avg=2.24
[109 | 5.71] loss=1.57 avg=2.15
[110 | 5.76] loss=3.10 avg=2.26
[111 | 5.81] loss=2.13 avg=2.25
[112 | 5.86] loss=2.05 avg=2.23
[113 | 5.91] loss=2.31 avg=2.23
[114 | 5.97] loss=1.19 avg=2.15
[115 | 6.03] loss=2.47 avg=2.17
[116 | 6.08] loss=2.97 avg=2.23
[117 | 6.13] loss=1.20 avg=2.16
[118 | 6.19] loss=2.04 avg=2.15
[119 | 6.25] loss=1.93 avg=2.14
[120 | 6.30] loss=1.75 avg=2.12
[121 | 6.35] loss=2.07 avg=2.11
[122 | 6.41] loss=2.89 avg=2.16
[123 | 6.46] loss=3.39 avg=2.22
[124 | 6.52] loss=2.22 avg=2.22
[125 | 6.58] loss=1.27 avg=2.17
[126 | 6.64] loss=2.08 avg=2.17
[127 | 6.69] loss=2.23 avg=2.17
[128 | 6.74] loss=0.96 avg=2.12
[129 | 6.80] loss=3.36 avg=2.17



  0%|          | 0/40 [00:00<?, ?it/s][A

  The most important thing that you will NOT want to get in an old car                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  


  2%|▎         | 1/40 [00:02<01:30,  2.31s/it][A
 15%|█▌        | 6/40 [00:02<00:55,  1.63s/it][A
 28%|██▊       | 11/40 [00:02<00:33,  1.14s/it][A
 40%|████      | 16/40 [00:02<00:19,  1.24it/s][A
 50%|█████     | 20/40 [00:02<00:11,  1.74it/s][A
 62%|██████▎   | 25/40 [00:02<00:06,  2.45it/s][A
 75%|███████▌  | 30/40 [00:03<00:02,  3.41it/s][A
 85%|████████▌ | 34/40 [00:03<00:01,  4.69it/s][A
 98%|█████████▊| 39/40 [00:03<00:00,  6.40it/s][A
100%|██████████| 40/40 [00:03<00:00, 12.28it/s][A

[200 | 32.34] validation loss = 2.09
[200 | 32.39] loss=1.77 avg=2.05
[201 | 32.44] loss=1.61 avg=2.04
[202 | 32.49] loss=2.45 avg=2.05
[203 | 32.54] loss=1.62 avg=2.04
[204 | 32.59] loss=1.53 avg=2.03
[205 | 32.65] loss=0.99 avg=2.02
[206 | 32.70] loss=2.68 avg=2.03
[207 | 32.75] loss=3.38 avg=2.05
[208 | 32.80] loss=2.37 avg=2.05
[209 | 32.86] loss=2.72 avg=2.06
[210 | 32.91] loss=2.15 avg=2.06
[211 | 32.96] loss=2.78 avg=2.08
[212 | 33.01] loss=2.48 avg=2.08
[213 | 33.07] loss=1.88 avg=2.08
[214 | 33.12] loss=1.42 avg=2.07
[215 | 33.17] loss=2.51 avg=2.07
[216 | 33.23] loss=2.56 avg=2.08
[217 | 33.28] loss=3.23 avg=2.10
[218 | 33.34] loss=2.11 avg=2.10
[219 | 33.39] loss=2.32 avg=2.10
[220 | 33.45] loss=1.48 avg=2.09
[221 | 33.50] loss=2.91 avg=2.10
[222 | 33.56] loss=0.16 avg=2.08
[223 | 33.61] loss=2.39 avg=2.08
[224 | 33.66] loss=2.59 avg=2.09
[225 | 33.72] loss=3.55 avg=2.11
[226 | 33.77] loss=2.55 avg=2.12
[227 | 33.82] loss=2.81 avg=2.13
[228 | 33.88] loss=2.06 avg=2.12
[229 |


  0%|          | 0/40 [00:00<?, ?it/s][A
 12%|█▎        | 5/40 [00:00<00:00, 41.57it/s][A

(s)) * s.type.length * s.length + 2); // check if s.is_null (s.type.length) return s[s]; } return true; }

The two functions are really in the same file. I created these from the source code, just because we're getting in front of that. We could then do things like update s and add some arguments. But there aren't many calls in that file. It is better to make it read only. I'm not happy about that. If we want to call things, we need to set up our own methods. Let's break down those.

Update / Initialize s with s.update({}); // update the contents with .apply(s.first, s.last, ...);

The code below will call s.update(... args) in order to update the contents of a file. We should be able to read all the files. We'll only get one file. The function s.update(s) now takes in different arguments to update a new file. First, we must update s.first. We should also update s first before any of s.last. Finally, we should replace these arguments with values from our function. For example, we could


 25%|██▌       | 10/40 [00:00<00:00, 41.75it/s][A
 35%|███▌      | 14/40 [00:00<00:00, 40.28it/s][A
 48%|████▊     | 19/40 [00:00<00:00, 40.73it/s][A
 60%|██████    | 24/40 [00:00<00:00, 41.18it/s][A
 72%|███████▎  | 29/40 [00:00<00:00, 41.61it/s][A
 82%|████████▎ | 33/40 [00:00<00:00, 41.05it/s][A
 95%|█████████▌| 38/40 [00:00<00:00, 41.19it/s][A
100%|██████████| 40/40 [00:00<00:00, 41.08it/s][A

[300 | 52.61] validation loss = 2.08
[300 | 52.67] loss=1.86 avg=2.02
[301 | 52.72] loss=2.26 avg=2.02
[302 | 52.77] loss=1.99 avg=2.02
[303 | 52.83] loss=1.60 avg=2.01
[304 | 52.88] loss=2.58 avg=2.02
[305 | 52.93] loss=2.55 avg=2.03
[306 | 52.99] loss=2.05 avg=2.03
[307 | 53.05] loss=2.34 avg=2.03
[308 | 53.10] loss=1.87 avg=2.03
[309 | 53.15] loss=1.57 avg=2.02
[310 | 53.20] loss=0.25 avg=2.00
[311 | 53.26] loss=1.81 avg=2.00
[312 | 53.31] loss=2.58 avg=2.01
[313 | 53.36] loss=0.04 avg=1.98
[314 | 53.42] loss=2.40 avg=1.99
[315 | 53.47] loss=2.64 avg=2.00
[316 | 53.52] loss=0.32 avg=1.98
[317 | 53.58] loss=1.18 avg=1.97
[318 | 53.63] loss=1.84 avg=1.97
[319 | 53.69] loss=2.21 avg=1.97
[320 | 53.74] loss=1.18 avg=1.96
[321 | 53.79] loss=3.04 avg=1.97
[322 | 53.84] loss=2.31 avg=1.98
[323 | 53.90] loss=1.87 avg=1.98
[324 | 53.95] loss=2.46 avg=1.98
[325 | 54.00] loss=1.84 avg=1.98
[326 | 54.06] loss=1.42 avg=1.97
[327 | 54.11] loss=0.76 avg=1.96
[328 | 54.16] loss=2.78 avg=1.97
[329 |

In [23]:
# Load the tensorboard notebook extension
%load_ext tensorboard

In [67]:
%tensorboard --logdir checkpoint/run1/

Reusing TensorBoard on port 6006 (pid 14104), started 0:38:40 ago. (Use '!kill 14104' to kill it.)

In [15]:
trn_set[0]

array([   31,  2398,    13, 16469, 30604,    13, 12384,    13, 21653,
          13,  1236, 14221,    13, 18453,    44,  5912,     7,  8367,
         796, 12813,  1177, 17752,  4943,   198, 11377,  8745,    13,
       16469, 30604,    13,  4023,    13, 31077, 32398, 47934,    29,
        1570,    41,  1559,  3419,  1391,   198,   220,   220,   220,
         401,    13, 28864,  2419,    13,  5420,  1324,    13, 27830,
          13, 16922,    62, 36479,  1095,  6631,    62, 36479,  1095,
         796,   649,   401,    13, 28864,  2419,    13,  5420,  1324,
          13, 27830,    13, 16922,    62, 36479,  1095,  9783,   198,
         220,   220,   220,  6631,    62, 36479,  1095,    13,  2617,
       41972, 10430,     7,  3605, 20129,    13, 22602,    13, 10430,
       35430,   198,   220,   220,   220,  6631,    62, 36479,  1095,
          13,  2617,  5956, 17354, 10430,     7,  3605, 20129,    13,
       22602,    13, 10430, 35430,   198,   220,   220,   220,  6631,
          62, 36479,