In [None]:
# Load the tensorboard notebook extension
%load_ext tensorboard

In [1]:
cd /tf/src/data/gpt-2/

/tf/src/data/gpt-2


In [2]:
! pip3 install -r requirements.txt

Collecting fire>=0.1.3 (from -r requirements.txt (line 1))
[?25l  Downloading https://files.pythonhosted.org/packages/d9/69/faeaae8687f4de0f5973694d02e9d6c3eb827636a009157352d98de1129e/fire-0.2.1.tar.gz (76kB)
[K     |████████████████████████████████| 81kB 3.6MB/s eta 0:00:011
[?25hCollecting regex==2017.4.5 (from -r requirements.txt (line 2))
[?25l  Downloading https://files.pythonhosted.org/packages/36/62/c0c0d762ffd4ffaf39f372eb8561b8d491a11ace5a7884610424a8b40f95/regex-2017.04.05.tar.gz (601kB)
[K     |████████████████████████████████| 604kB 11.2MB/s eta 0:00:01
[?25hCollecting requests==2.21.0 (from -r requirements.txt (line 3))
[?25l  Downloading https://files.pythonhosted.org/packages/7d/e3/20f3d364d6c8e5d2353c72a67778eb189176f08e873c9900e10c0287b84b/requests-2.21.0-py2.py3-none-any.whl (57kB)
[K     |████████████████████████████████| 61kB 15.3MB/s eta 0:00:01
[?25hCollecting tqdm==4.31.1 (from -r requirements.txt (line 4))
[?25l  Downloading https://files.pythonhosted

In [None]:
! python3 download_model.py 117M

In [3]:
import argparse
import bisect
import glob
import json
import os
import pickle
import time
from functools import lru_cache
from statistics import median

import fire
import numpy as np
import regex as re
import tensorflow as tf
import tqdm
from tensorflow.core.protobuf import rewriter_config_pb2

tf.__version__

'2.0.0-beta1'

# Encoding

In [4]:
"""Byte pair encoding utilities"""


@lru_cache()
def bytes_to_unicode():
    """
    Build the reversible byte -> unicode-character table used by the BPE code.

    BPE works on unicode strings, so every byte value 0..255 needs a printable,
    non-whitespace stand-in character. Covering raw utf-8 text directly would need a
    large number of unicode characters in the vocab to avoid UNKs (around 5K for a 10B
    token dataset, a significant share of a typical 32K BPE vocab), so a lookup table
    between utf-8 bytes and unicode characters is used instead. It also keeps
    whitespace/control characters, which the BPE code cannot handle, out of the strings.

    Bytes in the printable ranges '!'..'~', '¡'..'¬' and '®'..'ÿ' map to themselves;
    every remaining byte gets, in increasing byte order, the next code point from 256.

    Returns:
        dict[int, str]: byte value -> single-character stand-in. The result is cached
        by lru_cache, so all callers share the same dict object.
    """
    printable = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]
    table = {byte: chr(byte) for byte in printable}
    next_free = 2**8
    for byte in range(2**8):
        if byte not in table:
            table[byte] = chr(next_free)
            next_free += 1
    return table

def get_pairs(word):
    """Return the set of adjacent symbol pairs in a word.

    Word is represented as a tuple of symbols (symbols being variable-length strings).
    A word with fewer than two symbols, including an empty one, has no pairs, so an
    empty set is returned.
    """
    # zip stops at the shorter input, so this never indexes past the end of `word`.
    return set(zip(word, word[1:]))

class Encoder:
    """GPT-2 byte-level BPE tokenizer.

    Args:
        encoder: dict mapping BPE token strings to integer ids.
        bpe_merges: list of (first, second) symbol pairs; an earlier pair has a lower
            rank and is merged first.
        errors: ``errors=`` policy passed to ``bytes.decode`` in ``decode``.
    """

    def __init__(self, encoder, bpe_merges, errors='replace'):
        self.encoder = encoder
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        # Merge priority: lower rank = merged earlier.
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}

        # Pre-tokenisation: contractions, letter runs, digit runs, other symbol runs, whitespace.
        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

    def bpe(self, token):
        """Apply the BPE merges to one pre-token (already mapped through byte_encoder).

        Returns the resulting symbols joined by single spaces. Results for tokens with
        at least two symbols are memoised in ``self.cache``.
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            # Merge the best-ranked pair present in the word; stop when none is mergeable.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    # `first` does not occur again: copy the remaining symbols unchanged.
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])
                i = j

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            word = tuple(new_word)
            if len(word) == 1:
                break
            pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Convert text to a list of BPE token ids.

        The text is split with ``self.pat``. Each piece's utf-8 bytes are mapped to
        stand-in characters, BPE-merged, and looked up in ``self.encoder``; a merged
        symbol missing from the vocabulary raises KeyError.
        """
        bpe_tokens = []
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Convert token ids back to text; invalid utf-8 is handled per ``self.errors``."""
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
        return text

def get_encoder(model_name, models_dir):
    """Build the BPE Encoder for a downloaded GPT-2 model.

    Reads ``<models_dir>/<model_name>/encoder.json`` (token -> id mapping) and
    ``vocab.bpe`` (merge list). For ``vocab.bpe``, the first line is skipped (presumably
    a version header — TODO confirm), and so is the final element, which is the empty
    string after the trailing newline (assumes the file ends with '\\n').
    """
    model_dir = os.path.join(models_dir, model_name)
    # Read both files as utf-8 explicitly: they may contain non-ASCII characters and the
    # platform default encoding is not guaranteed to be utf-8.
    with open(os.path.join(model_dir, 'encoder.json'), 'r', encoding="utf-8") as f:
        encoder = json.load(f)
    with open(os.path.join(model_dir, 'vocab.bpe'), 'r', encoding="utf-8") as f:
        bpe_data = f.read()
    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
    return Encoder(
        encoder=encoder,
        bpe_merges=bpe_merges,
    )

# Model

In [5]:
class HParams():
    """GPT-2 model hyper-parameters.

    The class attributes hold the default configuration (the same values that
    default_hparams() passes). The constructor uses them as keyword defaults, so
    ``HParams()`` gives the default model. Calls that pass all five values keep working.
    """
    n_vocab = 50257
    n_ctx = 1024
    n_embd = 768
    n_head = 12
    n_layer = 12

    # Default expressions are evaluated in the class body, so they pick up the values above.
    def __init__(self, n_vocab=n_vocab, n_ctx=n_ctx, n_embd=n_embd, n_head=n_head, n_layer=n_layer):
        self.n_vocab = n_vocab  # vocabulary size (rows of the token embedding)
        self.n_ctx = n_ctx      # maximum context length (rows of the position embedding)
        self.n_embd = n_embd    # hidden width
        self.n_head = n_head    # attention heads per layer
        self.n_layer = n_layer  # number of transformer blocks

In [6]:
def default_hparams():
    """Return a new HParams holding the default 12-layer, 768-wide GPT-2 configuration."""
    settings = dict(n_vocab=50257, n_ctx=1024, n_embd=768, n_head=12, n_layer=12)
    return HParams(**settings)

def shape_list(x):
    """Return the shape of tensor ``x`` as a list, mixing static and dynamic sizes.

    Dimensions known at graph-construction time are Python ints. Unknown (None)
    dimensions are replaced by the matching scalar tensor from ``tf.shape(x)``.
    """
    static_dims = x.shape.as_list()
    dynamic_dims = tf.shape(input=x)
    shape = []
    for axis, size in enumerate(static_dims):
        shape.append(dynamic_dims[axis] if size is None else size)
    return shape

def gelu(x):
    """GELU activation, tanh approximation:
    0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3))).
    """
    inner = np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))
    return 0.5 * x * (1 + tf.tanh(inner))

def norm(x, scope, *, axis=-1, epsilon=1e-5):
    """Normalize to mean = 0, std = 1, then do a diagonal affine transform.

    Layer normalisation over ``axis`` followed by ``x * g + b``. The gain ``g`` (init 1)
    and bias ``b`` (init 0) have size ``x.shape[-1]`` and live under variable scope
    ``scope``. Their names matter because the Saver restores them from checkpoints.
    NOTE(review): g/b are sized from the last axis, so a non-default ``axis`` only
    broadcasts correctly if it is also the last axis.
    """
    with tf.compat.v1.variable_scope(scope):
        n_state = x.shape[-1]
        g = tf.compat.v1.get_variable('g', [n_state], initializer=tf.compat.v1.constant_initializer(1), use_resource=False)
        b = tf.compat.v1.get_variable('b', [n_state], initializer=tf.compat.v1.constant_initializer(0), use_resource=False)
        u = tf.reduce_mean(input_tensor=x, axis=axis, keepdims=True)
        s = tf.reduce_mean(input_tensor=tf.square(x-u), axis=axis, keepdims=True)
        # epsilon keeps rsqrt finite when the variance is ~0.
        x = (x - u) * tf.math.rsqrt(s + epsilon)
        x = x*g + b
        return x

def split_states(x, n):
    """Reshape the last dimension of x into [n, x.shape[-1]/n]."""
    *leading, width = shape_list(x)
    per_part = width // n
    return tf.reshape(x, leading + [n, per_part])

def merge_states(x):
    """Smash the last two dimensions of x into a single dimension (inverse of split_states)."""
    *leading, outer, inner = shape_list(x)
    merged = outer * inner
    return tf.reshape(x, leading + [merged])

def conv1d(x, scope, nf, *, w_init_stdev=0.02):
    """Position-wise linear layer ("1-D convolution" of width 1): [..., nx] -> [..., nf].

    Creates weight ``w`` of shape [1, nx, nf] (normal init, stddev ``w_init_stdev``)
    and bias ``b`` of shape [nf] (zeros) under variable scope ``scope``. The [1, nx, nf]
    layout is presumably kept so variable shapes match the pretrained GPT-2
    checkpoint — TODO confirm.
    """
    with tf.compat.v1.variable_scope(scope):
        *start, nx = shape_list(x)
        w = tf.compat.v1.get_variable('w', [1, nx, nf], initializer=tf.compat.v1.random_normal_initializer(stddev=w_init_stdev), use_resource=False)
        b = tf.compat.v1.get_variable('b', [nf], initializer=tf.compat.v1.constant_initializer(0), use_resource=False)
        # Flatten all leading dims, apply one matmul, then restore the leading dims.
        c = tf.reshape(tf.matmul(tf.reshape(x, [-1, nx]), tf.reshape(w, [-1, nf]))+b, start+[nf])
        return c

def attention_mask(nd, ns, *, dtype):
    """1's in the lower triangle, counting from the lower right corner.

    Entry (row, col) of the [nd, ns] result is 1 when col <= row + (ns - nd), so each
    of the nd destination positions may attend to itself and everything before it.
    Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce
    garbage on TPUs.
    """
    rows = tf.range(nd)[:, None]
    cols = tf.range(ns)
    keep = cols - ns + nd <= rows
    return tf.cast(keep, dtype)


def attn(x, scope, n_state, *, past, hparams):
    """Causal multi-head self-attention with an optional key/value cache.

    Args:
        x: [batch, sequence, features] input (rank 3 is asserted).
        scope: variable scope holding the 'c_attn' and 'c_proj' projections.
        n_state: projection width; must be divisible by hparams.n_head.
        past: None, or cached keys/values [batch, 2, heads, past_sequence, features].
        hparams: provides n_head.

    Returns:
        (a, present): the attention output [batch, sequence, n_state], and this call's
        keys/values stacked as [batch, 2, heads, sequence, features]. ``present`` does
        not include ``past``; callers concatenate the caches themselves.
    """
    assert x.shape.ndims == 3  # Should be [batch, sequence, features]
    assert n_state % hparams.n_head == 0
    if past is not None:
        assert past.shape.ndims == 5  # Should be [batch, 2, heads, sequence, features], where 2 is [k, v]

    def split_heads(x):
        # From [batch, sequence, features] to [batch, heads, sequence, features]
        return tf.transpose(a=split_states(x, hparams.n_head), perm=[0, 2, 1, 3])

    def merge_heads(x):
        # Reverse of split_heads
        return merge_states(tf.transpose(a=x, perm=[0, 2, 1, 3]))

    def mask_attn_weights(w):
        # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
        _, _, nd, ns = shape_list(w)
        b = attention_mask(nd, ns, dtype=w.dtype)
        b = tf.reshape(b, [1, 1, nd, ns])
        # Masked (future) positions get -1e10 so softmax gives them ~0 weight.
        w = w*b - tf.cast(1e10, w.dtype)*(1-b)
        return w

    def multihead_attn(q, k, v):
        # q, k, v have shape [batch, heads, sequence, features]
        w = tf.matmul(q, k, transpose_b=True)
        # Scaled dot-product: divide by sqrt(per-head feature size).
        w = w * tf.math.rsqrt(tf.cast(v.shape[-1], w.dtype))

        w = mask_attn_weights(w)
        w = tf.nn.softmax(w, axis=-1)
        a = tf.matmul(w, v)
        return a

    with tf.compat.v1.variable_scope(scope):
        # One projection produces q, k and v side by side along the feature axis.
        c = conv1d(x, 'c_attn', n_state*3)
        q, k, v = map(split_heads, tf.split(c, 3, axis=2))
        present = tf.stack([k, v], axis=1)
        if past is not None:
            # Prepend cached keys/values along the sequence axis.
            pk, pv = tf.unstack(past, axis=1)
            k = tf.concat([pk, k], axis=-2)
            v = tf.concat([pv, v], axis=-2)
        a = multihead_attn(q, k, v)
        a = merge_heads(a)
        a = conv1d(a, 'c_proj', n_state)
        return a, present


def mlp(x, scope, n_state, *, hparams):
    """Position-wise feed-forward network: conv1d to ``n_state``, GELU, conv1d back to x's width.

    ``hparams`` is accepted for signature consistency with attn() but is not used.
    """
    with tf.compat.v1.variable_scope(scope):
        nx = x.shape[-1]
        h = gelu(conv1d(x, 'c_fc', n_state))
        h2 = conv1d(h, 'c_proj', nx)
        return h2

def block(x, scope, *, past, hparams):
    """One pre-norm transformer layer.

    Computes x + attn(ln_1(x)), then adds mlp(ln_2(x)) with a hidden width of 4 * x's width.
    Returns (x, present), where ``present`` is this layer's keys/values from attn().
    """
    with tf.compat.v1.variable_scope(scope):
        nx = x.shape[-1]
        a, present = attn(norm(x, 'ln_1'), 'attn', nx, past=past, hparams=hparams)
        x = x + a
        m = mlp(norm(x, 'ln_2'), 'mlp', nx*4, hparams=hparams)
        x = x + m
        return x, present

def past_shape(*, hparams, batch_size=None, sequence=None):
    """Shape of the key/value cache for all layers.

    The layout is [batch, layer, 2 (k, v), head, sequence, features-per-head].
    None marks a dimension that is unknown at graph-construction time.
    """
    head_dim = hparams.n_embd // hparams.n_head
    return [batch_size, hparams.n_layer, 2, hparams.n_head, sequence, head_dim]

def expand_tile(value, size):
    """Add a new leading axis to ``value`` and repeat it ``size`` times along that axis."""
    tensor = tf.convert_to_tensor(value=value, name='value')
    multiples = [size] + [1] * tensor.shape.ndims
    return tf.tile(tf.expand_dims(tensor, axis=0), multiples)

def positions_for(tokens, past_length):
    """Position ids for ``tokens``, offset by ``past_length`` already-cached positions.

    Returns a [batch, steps] tensor whose every row is past_length, past_length + 1, ...
    """
    batch = tf.shape(input=tokens)[0]
    steps = tf.shape(input=tokens)[1]
    offsets = past_length + tf.range(steps)
    return expand_tile(offsets, batch)

def clf(x, ny, w_init=tf.compat.v1.random_normal_initializer(stddev=0.02), b_init=tf.compat.v1.constant_initializer(0), train=False):
    """Linear classification head: returns ``x @ w + b`` with variables under scope 'clf'.

    Args:
        x: input features; assumed 2-D [N, nx] since tf.matmul is applied directly — TODO confirm.
        ny: number of output units.
        w_init, b_init: initializers for ``w`` ([nx, ny]) and ``b`` ([ny]).
        train: accepted for call compatibility; not used.
    """
    # tf.variable_scope does not exist at the TF 2.x top level; use the v1 compat API like
    # the rest of this notebook.
    with tf.compat.v1.variable_scope('clf'):
        nx = shape_list(x)[-1]
        w = tf.compat.v1.get_variable("w", [nx, ny], initializer=w_init)
        b = tf.compat.v1.get_variable("b", [ny], initializer=b_init)
        return tf.matmul(x, w) + b

def model(hparams, X, past=None, scope='model', reuse=tf.compat.v1.AUTO_REUSE):
    """Build the GPT-2 language-model graph.

    Args:
        hparams: HParams-like object (n_vocab, n_ctx, n_embd, n_head, n_layer).
        X: token ids, shape [batch, sequence].
        past: optional cached keys/values from earlier steps, laid out as past_shape():
            [batch, n_layer, 2, n_head, past_sequence, n_embd // n_head].
        scope, reuse: variable scope and reuse mode. The default AUTO_REUSE lets the
            training, validation and sampling graphs share one set of weights.

    Returns:
        dict with 'present' (this call's keys/values for every layer, to be appended to
        ``past``) and 'logits' ([batch, sequence, n_vocab], using the token embedding
        ``wte`` as the output projection).
    """
    with tf.compat.v1.variable_scope(scope, reuse=reuse):
        results = {}
        batch, sequence = shape_list(X)

        # Learned position (wpe) and token (wte) embeddings.
        wpe = tf.compat.v1.get_variable('wpe', [hparams.n_ctx, hparams.n_embd],
                             initializer=tf.compat.v1.random_normal_initializer(stddev=0.01), use_resource=False)
        wte = tf.compat.v1.get_variable('wte', [hparams.n_vocab, hparams.n_embd],
                             initializer=tf.compat.v1.random_normal_initializer(stddev=0.02), use_resource=False)
        # With a cache, position ids continue after the cached tokens.
        past_length = 0 if past is None else tf.shape(input=past)[-2]
        h = tf.gather(wte, X) + tf.gather(wpe, positions_for(X, past_length))

        # Transformer
        presents = []
        pasts = tf.unstack(past, axis=1) if past is not None else [None] * hparams.n_layer
        assert len(pasts) == hparams.n_layer
        for layer, past in enumerate(pasts):
            h, present = block(h, 'h%d' % layer, past=past, hparams=hparams)
            presents.append(present)
        results['present'] = tf.stack(presents, axis=1)
        h = norm(h, 'ln_f')

        # A GPT-1 style classification head (https://openai.com/blog/language-unsupervised/)
        # is intentionally not built here. It pools on a classifier token in a rank-3
        # [batch, sequence, 2] input (X[:, :, 0]) and needs n_ctx / n_embd / dropout settings
        # as globals. This model's X is [batch, sequence] and none of those names exist, so
        # building it makes graph construction fail. Add such a head as a separate function
        # if classification fine-tuning is needed.

        # Language-model logits over the vocabulary, tied to the token embedding.
        h_flat = tf.reshape(h, [batch*sequence, hparams.n_embd])
        logits = tf.matmul(h_flat, wte, transpose_b=True)
        logits = tf.reshape(logits, [batch, sequence, hparams.n_vocab])
        results['logits'] = logits
        return results

In [None]:
# NOTE(review): DO NOT RUN THIS CELL AS-IS.
# This appears to be the GPT-1 (finetune-transformer-lm) joint LM + classification model
# pasted from another code base. It redefines `model`, shadowing the GPT-2
# `model(hparams, X, past=...)` defined earlier. Once it runs, sample_sequence() and main(),
# which call model(hparams=..., X=...), break. It also cannot work in this notebook:
#   * tf.variable_scope / tf.get_variable are TF1-only names (this notebook uses the
#     tf.compat.v1 equivalents);
#   * n_vocab, n_special, n_ctx, n_embd, n_layer, embd_pdrop, clf_pdrop, clf_token,
#     dropout and embed are not defined as globals anywhere in this notebook;
#   * block() here takes (x, scope, *, past, hparams) and returns (x, present), not the
#     train=/scale= keywords and single return value used below.
# Port it to a differently named function (e.g. clf_model) or delete the cell.
def model(X, M, Y, train=False, reuse=False):
    """GPT-1 style model returning (clf_logits, clf_losses, lm_losses); not functional here.

    Presumably X is [batch, 2, n_ctx, 2] (two candidate sequences with token / position
    channels), M is the matching loss mask and Y the class labels — TODO confirm
    against the original source before porting.
    """
    with tf.variable_scope('model', reuse=reuse):
        we = tf.get_variable("we", [n_vocab+n_special+n_ctx, n_embd], initializer=tf.random_normal_initializer(stddev=0.02))
        we = dropout(we, embd_pdrop, train)

        X = tf.reshape(X, [-1, n_ctx, 2])
        M = tf.reshape(M, [-1, n_ctx])

        h = embed(X, we)
        for layer in range(n_layer):
            h = block(h, 'h%d'%layer, train=train, scale=True)

        # Language-model loss: position t predicts token t+1, masked by M and averaged per sequence.
        lm_h = tf.reshape(h[:, :-1], [-1, n_embd])
        lm_logits = tf.matmul(lm_h, we, transpose_b=True)
        lm_losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=lm_logits, labels=tf.reshape(X[:, 1:, 0], [-1]))
        lm_losses = tf.reshape(lm_losses, [shape_list(X)[0], shape_list(X)[1]-1])
        lm_losses = tf.reduce_sum(lm_losses*M[:, 1:], 1)/tf.reduce_sum(M[:, 1:], 1)

        # Classification: take the hidden state at the (first) clf_token position of each sequence.
        clf_h = tf.reshape(h, [-1, n_embd])
        pool_idx = tf.cast(tf.argmax(tf.cast(tf.equal(X[:, :, 0], clf_token), tf.float32), 1), tf.int32)
        clf_h = tf.gather(clf_h, tf.range(shape_list(X)[0], dtype=tf.int32)*n_ctx+pool_idx)

        clf_h = tf.reshape(clf_h, [-1, 2, n_embd])
        if train and clf_pdrop > 0:
            shape = shape_list(clf_h)
            shape[1] = 1
            clf_h = tf.nn.dropout(clf_h, 1-clf_pdrop, shape)
        clf_h = tf.reshape(clf_h, [-1, n_embd])
        clf_logits = clf(clf_h, 1, train=train)
        clf_logits = tf.reshape(clf_logits, [-1, 2])

        clf_losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=clf_logits, labels=Y)
        return clf_logits, clf_losses, lm_losses

# Sample from Model

In [7]:
def top_k_logits(logits, k):
    """Keep the ``k`` largest logits in each row and set all others to -1e10.

    ``k == 0`` disables truncation. The ``[:, -1, tf.newaxis]`` indexing assumes
    ``logits`` is rank 2 ([batch, vocab]).
    """
    if k == 0:
        return logits  # no truncation

    def _mask_below_kth():
        top_values, _ = tf.nn.top_k(logits, k=k)
        kth_largest = top_values[:, -1, tf.newaxis]
        below = logits < kth_largest
        floor = tf.ones_like(logits, dtype=logits.dtype) * -1e10
        return tf.compat.v1.where(below, floor, logits)

    return tf.cond(
        pred=tf.equal(k, 0),
        true_fn=lambda: logits,
        false_fn=_mask_below_kth,
    )


def sample_sequence(*, hparams, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0):
    """Build a graph that autoregressively samples ``length`` tokens from the model.

    Exactly one of ``start_token`` (every row starts from that single token) or
    ``context`` ([batch_size, context_len] token ids) must be given.

    Args:
        hparams: model hyper-parameters.
        length: number of new tokens to sample.
        batch_size: static batch size, used to pin the shape of the key/value cache.
        temperature: the last-position logits are divided by this before sampling.
        top_k: restrict each draw to the k most likely tokens (0 = no truncation).

    Returns:
        int32 tensor [batch_size, context_len + length]: the context followed by the samples.
        NOTE(review): there is no nucleus (top_p) sampling here, so args.top_p has no effect.
    """
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = tf.fill([batch_size, 1], start_token)

    def step(hparams, tokens, past=None):
        # Run the model on `tokens`, reusing the cached keys/values in `past`.
        lm_output = model(hparams=hparams, X=tokens, past=past, reuse=tf.compat.v1.AUTO_REUSE)

        logits = lm_output['logits'][:, :, :hparams.n_vocab]
        presents = lm_output['present']
        # Pin the static shape so it satisfies the while_loop shape invariant below.
        presents.set_shape(past_shape(hparams=hparams, batch_size=batch_size))
        return {
            'logits': logits,
            'presents': presents,
        }

    def body(past, prev, output):
        # One sampling step: feed only the newest token(s), sample from the last position,
        # and grow the cache along the sequence axis (-2).
        next_outputs = step(hparams, prev, past=past)
        logits = next_outputs['logits'][:, -1, :]  / tf.cast(temperature, dtype=tf.float32)
        logits = top_k_logits(logits, k=top_k)
        samples = tf.random.categorical(logits=logits, num_samples=1, dtype=tf.int32)
        return [
            next_outputs['presents'] if past is None else tf.concat([past, next_outputs['presents']], axis=-2),
            samples,
            tf.concat([output, samples], axis=1)
        ]

    # Prime the cache with the whole context and draw the first token outside the loop,
    # so that `past` is always a tensor inside tf.while_loop.
    past, prev, output = body(None, context, context)

    def cond(*args):
        # The number of iterations is bounded by maximum_iterations.
        return True

    _, _, tokens = tf.while_loop(
        cond=cond, body=body,
        maximum_iterations=length - 1,
        loop_vars=[
            past,
            prev,
            output
        ],
        shape_invariants=[
            tf.TensorShape(past_shape(hparams=hparams, batch_size=batch_size)),
            tf.TensorShape([batch_size, None]),
            tf.TensorShape([batch_size, None]),
        ],
        back_prop=False,
    )

    return tokens

In [8]:
from pathlib import Path
def load_dataset(enc, path):
    """Tokenise text files into a list of 1-D numpy token arrays, one array per file.

    ``path`` may be a single file, a directory (walked recursively) or a glob pattern.
    Note that in the glob fallback, '[...]' in a path is a character class. Each file's
    text, followed by the literal string '<|endoftext|>', is encoded with ``enc.encode``.
    Files that cannot be read or encoded are reported and skipped (best effort), so one
    bad file does not abort a multi-hour load.

    NOTE(review): '<|endoftext|>' goes through the BPE encoder as ordinary text, so it
    becomes several sub-word tokens rather than a single special id — confirm intended.
    """
    if os.path.isfile(path):
        paths = [path]
    elif os.path.isdir(path):
        paths = [os.path.join(dirpath, fname)
                 for dirpath, _, fnames in os.walk(path)
                 for fname in fnames]
    else:
        # Assume glob
        paths = glob.glob(path)

    token_chunks = []
    for file_path in tqdm.tqdm(paths):
        try:
            # Explicit utf-8 so results do not depend on the platform's locale encoding.
            with open(file_path, 'r', encoding='utf-8') as fp:
                raw_text = fp.read() + '<|endoftext|>'
            token_chunks.append(np.stack(enc.encode(raw_text)))
        except Exception as e:
            # Best effort: report and skip. The text is per file, so nothing from a failed
            # file leaks into the next chunk.
            print('Skipping {}: {}'.format(file_path, e))
    return token_chunks

def binary_search(f, lo, hi):
    """Find the smallest x in (lo, hi] with f(x) true, for a monotone predicate f.

    Returns None unless f(lo) is false and f(hi) is true.
    """
    if f(lo) or not f(hi):
        return None
    # Invariant: f(lo) is false and f(hi) is true.
    while hi > lo + 1:
        midpoint = (lo + hi) // 2
        lo, hi = (lo, midpoint) if f(midpoint) else (midpoint, hi)
    return hi


class Sampler(object):
    """Fairly samples a slice from a set of variable sized chunks.

    'Fairly' means that the distribution is the same as sampling from one concatenated chunk,
    but without crossing chunk boundaries. A start offset into the concatenation is drawn
    uniformly, and draws whose slice would cross a chunk boundary are rejected.

    Args:
        chunks: list of 1-D numpy arrays (anything with ``.shape[0]`` and slicing).
        seed: seed for the private numpy RandomState; a fixed seed gives reproducible samples.
    """

    def __init__(self, chunks, seed=None):
        self.chunks = chunks
        self.total_size = sum(chunk.shape[0] for chunk in chunks)
        # boundaries[i] is the offset of chunks[i] in the concatenation; boundaries[-1] == total_size.
        self.boundaries = [0]
        for chunk in chunks:
            self.boundaries.append(self.boundaries[-1] + chunk.shape[0])
        self.rs = np.random.RandomState(seed=seed)

    def sample(self, length):
        """Return a contiguous slice of ``length`` tokens lying entirely inside one chunk.

        Raises AssertionError when there are no chunks or the average chunk is not longer
        than ``length``. That check guarantees some chunk can hold the slice, so the
        rejection loop can terminate.
        """
        assert len(self.chunks) > 0 and length < self.total_size // len(
            self.chunks
        ), "Dataset files are too small to sample {} tokens at a time".format(
            length)
        last_chunk = len(self.chunks) - 1
        while True:
            # Every start offset 0 .. total_size - length is a candidate (randint's high is exclusive).
            index = self.rs.randint(0, self.total_size - length + 1)
            # Chunk containing `index`: the last i with boundaries[i] <= index (empty chunks are skipped).
            i = min(bisect.bisect_right(self.boundaries, index) - 1, last_chunk)
            # Accept when the slice ends at or before the end of that chunk.
            if index + length <= self.boundaries[i + 1]:
                within_chunk = index - self.boundaries[i]
                return self.chunks[i][within_chunk:within_chunk + length]

In [9]:
class Args():
    """Plain attribute container for the fine-tuning settings.

    Every constructor argument is stored unchanged as an attribute of the same name,
    in parameter order.
    """

    def __init__(self, trn_dataset, model_name, combine, batch_size, learning_rate, optimizer, noise, top_k, top_p, run_name, sample_every, sample_length, sample_num, save_every, val_dataset, val_batch_size, val_batch_count, val_every, pretrained, iterations):
        settings = dict(
            trn_dataset=trn_dataset,
            model_name=model_name,
            combine=combine,
            batch_size=batch_size,
            learning_rate=learning_rate,
            optimizer=optimizer,
            noise=noise,
            top_k=top_k,
            top_p=top_p,
            run_name=run_name,
            sample_every=sample_every,
            sample_length=sample_length,
            sample_num=sample_num,
            save_every=save_every,
            val_dataset=val_dataset,
            val_batch_size=val_batch_size,
            val_batch_count=val_batch_count,
            val_every=val_every,
            pretrained=pretrained,
            iterations=iterations,
        )
        for name, value in settings.items():
            setattr(self, name, value)

In [14]:
# Fine-tuning configuration for run "m4".
args = Args(
    trn_dataset="/tf/src/data/methods/DATA00M_[god-r]/train",
    model_name="117M",
    combine=50000,           # NOTE(review): not read by load_dataset() or main() in this notebook
    batch_size=1,            # Author's warning: do not increase (presumably out of GPU memory) — TODO confirm limit
    learning_rate=0.00002,
    optimizer="sgd",         # 'adam' or 'sgd'
    noise=0.0,               # probability of replacing an input token with a random id (see randomize)
    top_k=40,
    top_p=0.0,               # NOTE(review): unused; sample_sequence only supports top_k
    run_name="m4",
    sample_every=100,
    sample_length=1023,      # must not exceed n_ctx; main() checks this
    sample_num=1,
    save_every=1000,
    val_dataset="/tf/src/data/methods/DATA00M_[god-r]/valid",
    val_batch_size=1,
    val_batch_count=40,
    val_every=100,
    pretrained=True,
    iterations=493000,
)

In [None]:
# Tokenise both splits. This is slow: the last recorded run showed ~970k training files
# with an ETA of several hours.
enc = get_encoder(args.model_name, "models")
trn_set, val_set = (load_dataset(enc, split_path) for split_path in (args.trn_dataset, args.val_dataset))
len(trn_set), len(val_set)

  0%|          | 729/972771 [00:13<6:50:51, 39.43it/s]

In [None]:
# DATASET_SIZE = len(dataset)
# TRN_SET_SIZE = int(DATASET_SIZE * 0.8)
# VAL_SET_SIZE = int(DATASET_SIZE * 0.1)
# TST_SET_SIZE = int(DATASET_SIZE * 0.1)

# trn_set = dataset[:TRN_SET_SIZE]
# val_set = dataset[TRN_SET_SIZE:TRN_SET_SIZE + VAL_SET_SIZE]
# tst_set = dataset[-TST_SET_SIZE:]
# DATASET_SIZE, len(trn_set), len(val_set), len(tst_set)

In [11]:
# Output locations, relative to the gpt-2 working directory.
CHECKPOINT_DIR = 'checkpoint'
SAMPLE_DIR = 'samples'

# Loss histories: main() appends to them and its save() pickles them.
trn_losses, trn_avgs, val_losses = [], [], []

In [12]:
# Restore previous metrics when resuming a run. A fresh run has no metrics file yet, so
# in that case the empty histories initialised above are kept.
_metrics_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'metrics.pickle')
if os.path.exists(_metrics_path):
    # pickle.load can execute arbitrary code: only load metrics files written by save().
    with open(_metrics_path, 'rb') as f:
        loss_dict = pickle.load(f)

    trn_losses = loss_dict["trn_losses"]
    trn_avgs   = loss_dict["avg_trn_losses"]
    val_losses = loss_dict["val_losses"]

In [13]:
len(trn_losses), len(trn_avgs), len(val_losses)

(508000, 507999, 5080)

In [None]:
def maketree(path):
    """Create directory ``path`` and any missing parents; do nothing if it already exists.

    Real failures, such as a permission error or ``path`` being an existing file, are
    raised instead of being silently ignored.
    """
    os.makedirs(path, exist_ok=True)


def randomize(context, hparams, p):
    """Token-noise augmentation.

    Each token of ``context`` is replaced, with probability ``p``, by a uniformly random
    id in [0, hparams.n_vocab). Unless ``p > 0``, ``context`` is returned unchanged.
    """
    if not p > 0:
        return context
    replace = tf.random.uniform(shape=tf.shape(input=context)) < p
    random_ids = tf.random.uniform(shape=tf.shape(input=context), minval=0, maxval=hparams.n_vocab, dtype=tf.int32)
    return tf.compat.v1.where(replace, random_ids, context)


def main():
    """Fine-tune GPT-2 on ``trn_set``, periodically validating, sampling and checkpointing.

    Reads the notebook globals ``args``, ``trn_set``, ``val_set``, ``CHECKPOINT_DIR`` and
    ``SAMPLE_DIR``. Appends to the global histories ``trn_losses``, ``trn_avgs`` and
    ``val_losses``, which save() pickles next to each checkpoint. Resumes from
    CHECKPOINT_DIR/<run_name> when a checkpoint exists there; otherwise, with
    ``args.pretrained``, it starts from the downloaded weights in models/<model_name>.
    """
    enc = get_encoder(args.model_name, "models")
    hparams = default_hparams()

    if args.sample_length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: %s" % hparams.n_ctx)
    # Validate before building the graph. (exit() cannot be called with these arguments,
    # so exiting here would raise TypeError instead of reporting the problem.)
    if args.optimizer not in ('adam', 'sgd'):
        raise ValueError('Bad optimizer: %s' % args.optimizer)

    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF
    with tf.compat.v1.Session(config=config) as sess:
        # Training graph: fixed batch size, free sequence length.
        context = tf.compat.v1.placeholder(tf.int32, [args.batch_size, None])
        context_in = randomize(context, hparams, args.noise)
        output = model(hparams=hparams, X=context_in)

        # Validation graph; shares weights with the training graph via AUTO_REUSE.
        val_context = tf.compat.v1.placeholder(tf.int32, [args.val_batch_size, None])
        val_output = model(hparams=hparams, X=val_context)

        tf_sample = sample_sequence(
            hparams=hparams,
            length=args.sample_length,
            context=context,
            batch_size=args.batch_size,
            temperature=1.0,
            top_k=args.top_k)

        all_vars = [v for v in tf.compat.v1.trainable_variables() if 'model' in v.name]
        train_vars = all_vars

        if args.optimizer == 'adam':
            opt = tf.compat.v1.train.AdamOptimizer(learning_rate=args.learning_rate)
        else:  # 'sgd' (validated above)
            opt = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=args.learning_rate)

        ## Collect Metrics for Tensorboard
        with tf.compat.v1.name_scope('metrics'):
            with tf.compat.v1.name_scope('train'):
                # Next-token cross-entropy: logits at position t predict token t+1 of the clean context.
                trn_loss = tf.reduce_mean(
                    input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
                        labels=context[:, 1:], logits=output['logits'][:, :-1]))
                trn_loss_summ = tf.compat.v1.summary.scalar('loss', trn_loss)

                trn_med_ph = tf.compat.v1.placeholder(tf.float32, shape=None, name='median')
                trn_med_summ = tf.compat.v1.summary.scalar('median', trn_med_ph)

                trn_mean_ph = tf.compat.v1.placeholder(tf.float32, shape=None, name='mean')
                trn_mean_summ = tf.compat.v1.summary.scalar('mean', trn_mean_ph)

            with tf.compat.v1.name_scope('valid'):
                val_loss = tf.reduce_mean(
                    input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
                        labels=val_context[:, 1:], logits=val_output['logits'][:, :-1]))
                val_loss_summ = tf.compat.v1.summary.scalar('loss', val_loss)

                val_med_ph = tf.compat.v1.placeholder(tf.float32, shape=None, name='median')
                val_med_summ = tf.compat.v1.summary.scalar('median', val_med_ph)

        trn_summaries = tf.compat.v1.summary.merge([trn_loss_summ, trn_med_summ, trn_mean_summ])
        val_summaries = tf.compat.v1.summary.merge([val_loss_summ, val_med_summ])

        opt_grads = tf.gradients(ys=trn_loss, xs=train_vars)
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)

        trn_summ_log = tf.compat.v1.summary.FileWriter(os.path.join(CHECKPOINT_DIR, args.run_name, 'train'))
        val_summ_log = tf.compat.v1.summary.FileWriter(os.path.join(CHECKPOINT_DIR, args.run_name, 'valid'))

        saver = tf.compat.v1.train.Saver(
            var_list=all_vars,
            max_to_keep=5,
            keep_checkpoint_every_n_hours=2)
        sess.run(tf.compat.v1.global_variables_initializer())

        ckpt = tf.train.latest_checkpoint(
            os.path.join(CHECKPOINT_DIR, args.run_name))
        if ckpt is None:
            # Get fresh GPT weights if new run.
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', args.model_name))

        if args.pretrained:
            print('Loading checkpoint', ckpt)
            saver.restore(sess, ckpt)

        print('Loading dataset...')
        data_sampler = Sampler(trn_set)
        print('dataset has', data_sampler.total_size, 'tokens')
        print('Training...')

        if args.val_every > 0:
            # Sample from validation set once with fixed seed to make
            # it deterministic during training as well as across runs.
            val_data_sampler = Sampler(val_set, seed=1)
            val_batches = [[val_data_sampler.sample(512) for _ in range(args.val_batch_size)]
                           for _ in range(args.val_batch_count)]

        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        def running_median(values):
            """Median of ``values``, or 0.0 when there is no history yet (fresh run)."""
            # NOTE: statistics.median sorts a copy on every call: O(n log n) per step as the
            # history grows.
            return median(values) if values else 0.0

        def save():
            """Write a checkpoint, the step counter and the pickled loss histories."""
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            saver.save(
                sess,
                os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                global_step=counter)
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')

            # Save metrics such as losses
            metrics = {
                "trn_losses": trn_losses,
                "avg_trn_losses": trn_avgs,
                "val_losses": val_losses
            }

            with open(os.path.join(CHECKPOINT_DIR, args.run_name, 'metrics.pickle'), 'wb') as f:
                pickle.dump(metrics, f, protocol=pickle.HIGHEST_PROTOCOL)

        def generate_samples():
            """Sample ``args.sample_num`` continuations of a random 1-token prompt and write them out."""
            print('Generating samples...')
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < args.sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: args.batch_size * [context_tokens]})
                for i in range(min(args.sample_num - index, args.batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
            if all_text:
                print(all_text[-1])
            maketree(os.path.join(SAMPLE_DIR, args.run_name))
            with open(
                    os.path.join(SAMPLE_DIR, args.run_name,
                                 'samples-{}').format(counter), 'w') as fp:
                fp.write('\n'.join(all_text))

        def validation():
            """Average the loss over the fixed validation batches and log it."""
            print('Calculating validation loss...')
            losses = []
            for batch in tqdm.tqdm(val_batches):
                losses.append(sess.run(val_loss, feed_dict={val_context: batch}))
            v_val_loss = np.mean(losses)
            val_losses.append(v_val_loss)
            # Feeding val_loss overrides that tensor, so the 'loss' summary records the mean
            # over all batches instead of re-running the model.
            v_summary = sess.run(val_summaries, feed_dict={val_loss: v_val_loss, val_med_ph: running_median(losses)})
            val_summ_log.add_summary(v_summary, counter)
            val_summ_log.flush()
            print(
                '[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'
                .format(
                    counter=counter,
                    time=time.time() - start_time,
                    loss=v_val_loss))

        def sample_batch():
            return [data_sampler.sample(256) for _ in range(args.batch_size)]

        # (weighted loss sum, weight sum) of an exponential moving average with decay 0.99.
        avg_trn_loss = (0.0, 0.1)
        start_time = time.time()

        try:
            for step in range(args.iterations):
                if counter % args.save_every == 0:
                    save()
                if counter % args.sample_every == 0:
                    generate_samples()
                if args.val_every > 0 and (counter % args.val_every == 0 or counter == 1):
                    validation()

                # Moving average *before* this step (0 on this session's first step); logged
                # as the 'mean' summary and stored in trn_avgs.
                avg = 0 if step == 0 else avg_trn_loss[0] / avg_trn_loss[1]

                _, v_loss, v_summary = sess.run(
                    (opt_apply, trn_loss, trn_summaries),
                    feed_dict={context: sample_batch(),
                               trn_med_ph: running_median(trn_losses),
                               trn_mean_ph: avg})
                trn_losses.append(v_loss)

                trn_summ_log.add_summary(v_summary, counter)

                avg_trn_loss = (avg_trn_loss[0] * 0.99 + v_loss,
                                avg_trn_loss[1] * 0.99 + 1.0)

                trn_avgs.append(avg)
                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(
                        counter=counter,
                        time=time.time() - start_time,
                        loss=v_loss,
                        avg=avg_trn_loss[0] / avg_trn_loss[1]))

                counter += 1
        except KeyboardInterrupt:
            print('interrupted')
        finally:
            # Flush pending TensorBoard events even if training stops early.
            trn_summ_log.close()
            val_summ_log.close()

        # Checkpoint exactly once after normal completion or Ctrl-C.
        save()


if __name__ == '__main__':
    main()

In [None]:
%tensorboard --logdir ./checkpoint/unconditional_experiment/

In [None]:
# Notify Slack that training finished.
# A webhook URL is a credential and must not live in the notebook. It is read from the
# SLACK_WEBHOOK_URL environment variable instead; any URL committed here earlier should be revoked.
import json
import os
import urllib.request

_slack_url = os.environ.get("SLACK_WEBHOOK_URL")
if _slack_url:
    _payload = json.dumps({"text": "from: semeru tower 1\nstatus: model 4 finished training"}).encode("utf-8")
    _request = urllib.request.Request(
        _slack_url, data=_payload, headers={"Content-type": "application/json"}, method="POST")
    with urllib.request.urlopen(_request) as _response:
        print("Slack notification sent, HTTP status", _response.status)
else:
    print("SLACK_WEBHOOK_URL is not set; skipping Slack notification.")

In [None]:
# Reading in the data
with open(os.path.join(CHECKPOINT_DIR, args.run_name, 'metrics.pickle'), 'rb') as f:
    loss_dict = pickle.load(f)
    
loss_dict