- [The Illustrated GPT-2 (Visualizing Transformer Language Models) – Jay Alammar – Visualizing machine learning one concept at a time](http://jalammar.github.io/illustrated-gpt2/)
- [openai/gpt-2: Code for the paper "Language Models are Unsupervised Multitask Learners"](https://github.com/openai/gpt-2)
- [Morizeyao/GPT2-Chinese: Chinese version of GPT2 training code, using BERT tokenizer.](https://github.com/Morizeyao/GPT2-Chinese)

# gpt-2

## Encoder

主要目的是对 text 利用 unicode 进行编码和解码。

In [12]:
import os
import json
import regex as re
from functools import lru_cache

In [13]:
# byte: unicode
@lru_cache()
def bytes_to_unicode():
    """Build a reversible lookup table from byte values to unicode characters.

    The BPE code works on unicode strings, so every one of the 256 possible
    byte values needs its own character. Bytes in the three ranges
    ``!``..``~``, ``¡``..``¬`` and ``®``..``ÿ`` map to themselves; every other
    byte (whitespace / control characters that the BPE code cannot handle)
    is remapped to a code point starting at 256, assigned in ascending byte
    order.

    Working on bytes avoids spending a significant share of the BPE
    vocabulary on raw unicode characters just to prevent UNKs.

    Returns:
        dict mapping each int in ``range(256)`` to a distinct one-character
        string. The result is memoized by ``lru_cache``.
    """
    self_mapped = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]
    seen = set(self_mapped)
    table = {byte: chr(byte) for byte in self_mapped}

    # Remaining bytes get fresh code points just above the single-byte range.
    next_code_point = 2 ** 8
    for byte in range(2 ** 8):
        if byte in seen:
            continue
        table[byte] = chr(next_code_point)
        next_code_point += 1
    return table

In [14]:
# Bi-gram
def get_pairs(word):
    """Return the set of adjacent symbol pairs in a word.

    Args:
        word: sliceable sequence of symbols (a tuple of variable-length
            strings, or a plain string).

    Returns:
        set of ``(left, right)`` tuples, one per distinct pair of neighbouring
        symbols. Empty when the word has fewer than two symbols, including
        the empty word (which previously raised IndexError).
    """
    # Zipping the word with itself shifted by one yields every neighbour pair
    # and naturally produces nothing for words of length 0 or 1.
    return set(zip(word, word[1:]))

In [15]:
class Encoder:
    """Byte-level BPE tokenizer converting text to token ids and back."""

    def __init__(self, encoder, bpe_merges, errors='replace'):
        """Set up the lookup tables used by encode/decode.

        Args:
            encoder: dict mapping BPE token string -> integer id.
            bpe_merges: ordered sequence of symbol-pair tuples; a pair's
                position is its rank, and lower ranks are merged first.
            errors: error policy handed to ``bytes.decode`` in ``decode``.
        """
        self.encoder = encoder
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}  # token -> space-joined BPE result, filled by bpe()

        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

    def bpe(self, token):
        """Apply the learned merges to one pre-tokenized token.

        Args:
            token: string of byte-encoded characters.

        Returns:
            The merged symbols joined by single spaces. A token with fewer
            than two symbols is returned unchanged (and not cached).
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            # Pick the adjacent pair with the lowest merge rank.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                # No remaining pair is a known merge.
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    # `first` does not occur again: copy the tail verbatim.
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])
                i = j

                # word[i] == first is guaranteed by index(); merge only when
                # it is directly followed by `second`.
                if i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            word = tuple(new_word)
            if len(word) == 1:
                break
            pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Convert text to a list of integer token ids.

        The text is split with ``self.pat``; each piece is UTF-8 encoded,
        mapped byte-by-byte through ``self.byte_encoder``, BPE-merged and
        looked up in ``self.encoder``.
        """
        bpe_tokens = []
        for token in self.pat.findall(text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Convert an iterable of token ids back to text.

        Ids are mapped to token strings, concatenated, translated back to raw
        bytes via ``self.byte_decoder`` and decoded as UTF-8 using the
        ``errors`` policy given at construction.
        """
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
        return text

In [16]:
def get_encoder(model_name, models_dir):
    """Load the vocabulary files of a model and build an ``Encoder``.

    Args:
        model_name: sub-directory of ``models_dir`` holding the model files.
        models_dir: directory containing one sub-directory per model.

    Returns:
        An ``Encoder`` built from ``encoder.json`` (token -> id mapping) and
        ``vocab.bpe`` (merge rules, one pair per line).
    """
    model_dir = os.path.join(models_dir, model_name)
    # Both files are read as UTF-8 explicitly so the result does not depend
    # on the platform's default locale encoding.
    with open(os.path.join(model_dir, 'encoder.json'), 'r', encoding="utf-8") as f:
        encoder = json.load(f)
    with open(os.path.join(model_dir, 'vocab.bpe'), 'r', encoding="utf-8") as f:
        bpe_data = f.read()
    # [1:-1] drops the first line and the last split element -- presumably a
    # header line and the empty string after the final newline; verify
    # against the actual file.
    bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
    return Encoder(
        encoder=encoder,
        bpe_merges=bpe_merges,
    )

In [17]:
# Location of the downloaded GPT-2 checkpoints and the model to inspect.
models_dir = "/Users/HaoShaochun/Documents/play/gpt-2/models/"
# Spelled "124M" to match the later cell that builds the encoder; the
# lowercase form only resolves on a case-insensitive filesystem.
model_name = "124M"

In [44]:
# Load the token -> id vocabulary; read as UTF-8 explicitly, like vocab.bpe
# below, so the result does not depend on the platform's default encoding.
with open(os.path.join(models_dir, model_name, 'encoder.json'), 'r', encoding="utf-8") as f:
    encoder = json.load(f)

In [45]:
# Read the BPE merge table as UTF-8 text.
with open(os.path.join(models_dir, model_name, 'vocab.bpe'), 'r', encoding="utf-8") as f:
    bpe_data = f.read()
# One whitespace-separated symbol pair per line. [1:-1] skips the first line
# and the last split element (presumably a header and the empty string after
# the trailing newline -- verify against the file).
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]

In [213]:
# Module-level copies of the state that Encoder.__init__ builds, so the
# tokenizer steps can be run cell by cell.
encoder = encoder  # no-op; kept to mirror `self.encoder = encoder`
decoder = {v:k for k,v in encoder.items()}  # id -> token string
errors = 'replace'  # error policy for bytes.decode when decoding
byte_encoder = bytes_to_unicode()  # byte value -> unicode character
byte_decoder = {v:k for k,v in byte_encoder.items()}  # unicode character -> byte value
bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))  # pair -> merge priority (lower merges first)
cache = {}  # token -> BPE result, filled by bpe()
pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

# Notes on the pattern's alternatives:
# \s+ matches any run of whitespace.
# \s+(?!\S) matches a run of whitespace that is NOT immediately followed by a
#   non-whitespace character, i.e. trailing whitespace; inside a longer run it
#   stops one character short so the last space can attach to the next token.

# [^\s\p{L}\p{N}]+ matches characters that are not whitespace, letters or numbers, e.g. # or *
# \p{L}+ matches any run of letters, e.g. abc
# \p{N}+ matches any run of numbers, e.g. 123
#

In [216]:
def bpe(token):
    """Apply the learned BPE merges to one pre-tokenized token.

    Reads the module-level ``bpe_ranks`` (pair -> merge priority) and
    memoizes results in the module-level ``cache``.

    Args:
        token: string of byte-encoded characters.

    Returns:
        The merged symbols joined by single spaces. A token with fewer than
        two symbols is returned unchanged (and not cached).
    """
    if token in cache:
        return cache[token]
    word = tuple(token)
    pairs = get_pairs(word)

    if not pairs:
        return token

    while True:
        # Pick the adjacent pair with the lowest merge rank.
        bigram = min(pairs, key=lambda pair: bpe_ranks.get(pair, float('inf')))
        if bigram not in bpe_ranks:
            # No remaining pair is a known merge.
            break
        first, second = bigram
        new_word = []
        i = 0
        while i < len(word):
            try:
                j = word.index(first, i)
            except ValueError:
                # `first` does not occur again: copy the tail verbatim.
                new_word.extend(word[i:])
                break
            new_word.extend(word[i:j])
            i = j

            # word[i] == first is guaranteed by index(); merge only when it
            # is directly followed by `second`.
            if i < len(word) - 1 and word[i + 1] == second:
                new_word.append(first + second)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        word = tuple(new_word)
        if len(word) == 1:
            break
        pairs = get_pairs(word)
    word = ' '.join(word)
    cache[token] = word
    return word

In [245]:
# encode: step-by-step version of Encoder.encode on a sample sentence
text = "I'm loving U."
bpe_tokens = []
for token in re.findall(pat, text):
    # Re-express the piece's UTF-8 bytes as the printable characters from
    # byte_encoder, the alphabet the BPE merges are defined over.
    token = ''.join(byte_encoder[b] for b in token.encode('utf-8'))
    # Merge, split into BPE symbols and look up each symbol's integer id.
    bpe_tokens.extend(encoder[bpe_token] for bpe_token in bpe(token).split(' '))

In [255]:
bpe_tokens

[40, 1101, 14442, 471, 13]

In [246]:
re.findall(pat, text)

['I', "'m", ' loving', ' U', '.']

In [247]:
for i in "loving".encode('utf-8'):
    print(i)

108
111
118
105
110
103


In [253]:
bpe("'m").split(" ")

["'m"]

In [268]:
# decode

text = ''.join([decoder[token] for token in bpe_tokens])
text

"I'mĠlovingĠU."

In [269]:
decoder[1101]

"'m"

In [271]:
for c in text:
    print(c, byte_decoder[c])

I 73
' 39
m 109
Ġ 32
l 108
o 111
v 118
i 105
n 110
g 103
Ġ 32
U 85
. 46


In [263]:
text = bytearray([byte_decoder[c] for c in text]).decode('utf-8', errors=errors)
text

"I'm loving U."

In [275]:
bytearray([73, 39, 109]).decode('utf8')

"I'm"

## Model

In [1]:
import numpy as np
import tensorflow as tf
from dataclasses import dataclass

In [77]:
@dataclass
class HParams:
    """Model hyperparameters; the defaults are the values used in this notebook."""

    # Number of rows of the token embedding table `wte` and size of the logits' last axis.
    n_vocab: int = 50257
    # Number of rows of the position embedding table `wpe`.
    n_ctx: int = 1024
    # Width of the embeddings / hidden state.
    n_embd: int = 768
    # Number of attention heads the hidden state is split into.
    n_head: int = 12
    # Number of transformer blocks.
    n_layer: int = 12

In [78]:
def default_hparams():
    """Return a fresh ``HParams`` instance with every field at its default."""
    return HParams()

In [79]:
def model(hparams, X, past=None, scope='model', reuse=False):
    """Build the transformer forward pass and return logits plus the key/value cache.

    Uses the ``tf.compat.v1`` variable API, like the rest of this file, so it
    also works where ``tf.variable_scope`` / ``tf.get_variable`` are not
    exposed at the top level.

    Args:
        hparams: object with ``n_vocab``, ``n_ctx``, ``n_embd``, ``n_head``
            and ``n_layer`` attributes.
        X: rank-2 integer tensor of token ids, unpacked as [batch, sequence].
        past: optional cached keys/values from earlier steps; unstacked on
            axis 1 into one entry per layer, so axis 1 must have ``n_layer``
            entries.
        scope: name of the variable scope holding all variables.
        reuse: forwarded to the variable scope.

    Returns:
        dict with
          'present': per-layer ``present`` tensors stacked on axis 1,
          'logits':  tensor of shape [batch, sequence, n_vocab].
    """
    with tf.compat.v1.variable_scope(scope, reuse=reuse):
        results = {}
        batch, sequence = shape_list(X)

        # Position and token embedding tables.
        wpe = tf.compat.v1.get_variable('wpe', [hparams.n_ctx, hparams.n_embd],
                                        initializer=tf.random_normal_initializer(stddev=0.01))
        wte = tf.compat.v1.get_variable('wte', [hparams.n_vocab, hparams.n_embd],
                                        initializer=tf.random_normal_initializer(stddev=0.02))
        # Positions continue after whatever is already cached.
        past_length = 0 if past is None else tf.shape(past)[-2]
        h = tf.gather(wte, X) + tf.gather(wpe, positions_for(X, past_length))

        # Transformer
        presents = []
        pasts = tf.unstack(past, axis=1) if past is not None else [None] * hparams.n_layer
        assert len(pasts) == hparams.n_layer
        # A dedicated loop variable avoids rebinding the `past` parameter.
        for layer, layer_past in enumerate(pasts):
            h, present = block(h, 'h%d' % layer, past=layer_past, hparams=hparams)
            presents.append(present)
        results['present'] = tf.stack(presents, axis=1)
        h = norm(h, 'ln_f')

        # Language model loss.  Do tokens <n predict token n?
        # The output projection reuses the token embedding table `wte`.
        h_flat = tf.reshape(h, [batch*sequence, hparams.n_embd])
        logits = tf.matmul(h_flat, wte, transpose_b=True)
        logits = tf.reshape(logits, [batch, sequence, hparams.n_vocab])
        results['logits'] = logits
        return results

In [81]:
# Point at the downloaded checkpoints and build the tokenizer for the model.
models_dir = "/Users/HaoShaochun/Documents/Study/gpt-2/models/"
model_name = "124M"
enc = get_encoder(model_name, models_dir)  # reads encoder.json and vocab.bpe from models_dir/model_name

### 输入

In [80]:
def shape_list(x):
    """Return the shape of `x` as a list, preferring statically known sizes.

    Each entry is the static dimension when it is known; otherwise it is the
    matching element of the dynamic ``tf.shape(x)`` tensor.
    """
    static_dims = x.shape.as_list()
    runtime_shape = tf.shape(x)
    dims = []
    for axis, size in enumerate(static_dims):
        dims.append(runtime_shape[axis] if size is None else size)
    return dims

In [237]:
# Build the starting input: one '<|endoftext|>' token per batch row.
batch_size = 1
start_token = None  # not referenced by any cell visible in this file
X = tf.fill([batch_size, 1], enc.encoder['<|endoftext|>'])  # shape [batch_size, 1]

In [240]:
X

<tf.Tensor: shape=(1, 1), dtype=int32, numpy=array([[50256]], dtype=int32)>

In [229]:
hparams = default_hparams()

In [230]:
def positions_for(tokens, past_length):
    """Return position indices for `tokens`, offset by `past_length`.

    `tokens` is indexed as [batch, steps]; the result holds
    ``past_length + [0, 1, ..., steps-1]`` repeated once per batch row.
    """
    batch_size = tf.shape(tokens)[0]
    nsteps = tf.shape(tokens)[1]
    offsets = tf.range(nsteps)
    return expand_tile(past_length + offsets, batch_size)

In [231]:
def expand_tile(value, size):
    """Prepend a new leading axis of length `size` by repeating `value`."""
    tensor = tf.convert_to_tensor(value, name='value')
    rank = tensor.shape.ndims
    with_leading_axis = tf.expand_dims(tensor, axis=0)
    # Tile only along the new axis; the existing axes keep their size.
    multiples = [size] + [1] * rank
    return tf.tile(with_leading_axis, multiples)

In [259]:
X

<tf.Tensor: shape=(1, 1), dtype=int32, numpy=array([[50256]], dtype=int32)>

In [258]:
positions_for(X, 5)

<tf.Tensor: shape=(1, 1), dtype=int32, numpy=array([[5]], dtype=int32)>

In [257]:
tf.shape(X)[0], tf.shape(X)[1]

(<tf.Tensor: shape=(), dtype=int32, numpy=1>,
 <tf.Tensor: shape=(), dtype=int32, numpy=1>)

In [266]:
tf.gather?

In [270]:
X.numpy()

array([[50256]], dtype=int32)

In [272]:
wte

<tf.Variable 'model/wte:0' shape=(50257, 768) dtype=float32, numpy=
array([[ 0.01648771, -0.0164929 ,  0.02838891, ..., -0.00645825,
        -0.00522234, -0.0007867 ],
       [-0.01100282, -0.05329556, -0.01903998, ..., -0.04945637,
         0.0188335 ,  0.00340233],
       [-0.00040727, -0.00531276,  0.01675709, ..., -0.0118375 ,
         0.01513721,  0.01788099],
       ...,
       [-0.00225281, -0.00350592, -0.02709479, ...,  0.00464215,
         0.01216319, -0.00408764],
       [ 0.01639596,  0.01235129, -0.00731677, ..., -0.00058879,
         0.0012605 ,  0.03780697],
       [-0.01169358,  0.02725702,  0.01802673, ..., -0.0029963 ,
         0.00533693,  0.01667264]], dtype=float32)>

In [234]:
# Step-by-step version of the embedding part of model(): no cached
# keys/values, so positions start at 0.
past = None
with tf.compat.v1.variable_scope("model", reuse=False):
    results = {}
    batch, sequence = shape_list(X)

    # Position embedding table: one row per position up to n_ctx.
    wpe = tf.compat.v1.get_variable('wpe', [hparams.n_ctx, hparams.n_embd],
                         initializer=tf.random_normal_initializer(stddev=0.01))
    # Token embedding table: one row per vocabulary entry.
    wte = tf.compat.v1.get_variable('wte', [hparams.n_vocab, hparams.n_embd],
                         initializer=tf.random_normal_initializer(stddev=0.02))
    past_length = 0 if past is None else tf.shape(past)[-2]
    # Input representation = token embedding + position embedding.
    h = tf.gather(wte, X) + tf.gather(wpe, positions_for(X, past_length))

In [236]:
h.shape

TensorShape([1, 1, 768])

### Transformer

In [102]:
# One cached key/value entry per layer; with no cache every layer gets None.
presents = []
pasts = tf.unstack(past, axis=1) if past is not None else [None] * hparams.n_layer
assert len(pasts) == hparams.n_layer

In [286]:
tf.compat.v1.rsqrt(4.0)

<tf.Tensor: shape=(), dtype=float32, numpy=0.5>

In [288]:
tf.square(2)

<tf.Tensor: shape=(), dtype=int32, numpy=4>

In [290]:
x = tf.constant([[1., 1.], [2., 2.]])

In [291]:
tf.reduce_mean(x, -1)

<tf.Tensor: shape=(2,), dtype=float32, numpy=array([1., 2.], dtype=float32)>

In [343]:
def softmax(x, axis=-1):
    """Softmax along `axis`, subtracting the per-slice maximum before exponentiating."""
    shifted = x - tf.reduce_max(x, axis=axis, keepdims=True)
    exponentials = tf.exp(shifted)
    normalizer = tf.reduce_sum(exponentials, axis=axis, keepdims=True)
    return exponentials / normalizer

def gelu(x):
    """GELU activation using the tanh-based approximation."""
    inner = x + 0.044715 * tf.pow(x, 3)
    gate = 1 + tf.tanh(np.sqrt(2 / np.pi) * inner)
    return 0.5 * x * gate

def norm(x, scope, *, axis=-1, epsilon=1e-5):
    """Normalize `x` along `axis` to zero mean and unit variance, then scale and shift.

    Creates two variables of length ``x.shape[-1]`` in `scope`: gain 'g'
    (initialized to 1) and bias 'b' (initialized to 0).
    """
    with tf.compat.v1.variable_scope(scope):
        n_state = x.shape[-1]
        gain = tf.compat.v1.get_variable('g', [n_state], initializer=tf.constant_initializer(1))
        bias = tf.compat.v1.get_variable('b', [n_state], initializer=tf.constant_initializer(0))
        mean = tf.reduce_mean(x, axis=axis, keepdims=True)
        centered = x - mean
        variance = tf.reduce_mean(tf.square(centered), axis=axis, keepdims=True)
        # epsilon keeps the reciprocal square root finite when variance is 0.
        normalized = centered * tf.compat.v1.rsqrt(variance + epsilon)
        return normalized * gain + bias

def split_states(x, n):
    """Split the last dimension of `x` (size m) into two dimensions [n, m // n]."""
    *leading, width = shape_list(x)
    target_shape = leading + [n, width // n]
    return tf.reshape(x, target_shape)

def merge_states(x):
    """Collapse the last two dimensions of `x` into one of size a*b."""
    *leading, rows, cols = shape_list(x)
    target_shape = leading + [rows * cols]
    return tf.reshape(x, target_shape)

def conv1d(x, scope, nf, *, w_init_stdev=0.02):
    """Apply a learned affine map from the last dimension of `x` (size nx) to `nf`.

    Creates weight 'w' of shape [1, nx, nf] and bias 'b' of shape [nf] in
    `scope`. Leading dimensions of `x` are preserved.
    """
    with tf.compat.v1.variable_scope(scope):
        *leading, nx = shape_list(x)
        weight = tf.compat.v1.get_variable(
            'w', [1, nx, nf],
            initializer=tf.random_normal_initializer(stddev=w_init_stdev))
        bias = tf.compat.v1.get_variable(
            'b', [nf],
            initializer=tf.constant_initializer(0))
        # Flatten to 2-D for a single matmul, then restore the leading dims.
        flat_input = tf.reshape(x, [-1, nx])
        flat_weight = tf.reshape(weight, [-1, nf])
        projected = tf.matmul(flat_input, flat_weight) + bias
        return tf.reshape(projected, leading + [nf])

def attention_mask(nd, ns, *, dtype):
    """Return an [nd, ns] mask with 1's in the lower triangle, counting from the lower right corner.

    Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs.
    """
    rows = tf.range(nd)[:, None]
    cols = tf.range(ns)
    # Entry (i, j) is kept when i >= j - (ns - nd).
    keep = rows >= cols - ns + nd
    return tf.cast(keep, dtype)


def split_heads(x):
    """Reshape [batch, sequence, features] into [batch, heads, sequence, features].

    Module-level copy of the helper nested in attn(); it reads the number of
    heads from the global ``hparams``.
    """
    per_head = split_states(x, hparams.n_head)
    return tf.transpose(per_head, [0, 2, 1, 3])
def attn(x, scope, n_state, *, past, hparams):
    """Masked multi-head self-attention with an optional key/value cache.

    Args:
        x: rank-3 tensor, [batch, sequence, features].
        scope: variable scope for the 'c_attn' and 'c_proj' projections.
        n_state: projection width; must be divisible by ``hparams.n_head``.
        past: optional rank-5 tensor [batch, 2, heads, sequence, features]
            holding earlier keys and values (index 0 = k, 1 = v on axis 1).
        hparams: object providing ``n_head``.

    Returns:
        (a, present): the attention output and this call's keys/values
        stacked on axis 1 (before concatenation with `past`).
    """
    assert x.shape.ndims == 3  # Should be [batch, sequence, features]
    assert n_state % hparams.n_head == 0
    if past is not None:
        assert past.shape.ndims == 5  # Should be [batch, 2, heads, sequence, features], where 2 is [k, v]

    def split_heads(t):
        # From [batch, sequence, features] to [batch, heads, sequence, features]
        per_head = split_states(t, hparams.n_head)
        return tf.transpose(per_head, [0, 2, 1, 3])

    def merge_heads(t):
        # Reverse of split_heads
        return merge_states(tf.transpose(t, [0, 2, 1, 3]))

    def mask_attn_weights(scores):
        # scores has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
        _, _, nd, ns = shape_list(scores)
        mask = attention_mask(nd, ns, dtype=scores.dtype)
        mask = tf.reshape(mask, [1, 1, nd, ns])
        # Masked-out entries become -1e10 so they vanish after the softmax.
        return scores * mask - tf.cast(1e10, scores.dtype) * (1 - mask)

    def multihead_attn(q, k, v):
        # q, k, v have shape [batch, heads, sequence, features]
        scores = tf.matmul(q, k, transpose_b=True)
        # Scale by 1/sqrt(per-head feature size).
        scores = scores * tf.compat.v1.rsqrt(tf.cast(v.shape[-1], scores.dtype))

        weights = softmax(mask_attn_weights(scores))
        return tf.matmul(weights, v)

    with tf.compat.v1.variable_scope(scope):
        # A single projection yields queries, keys and values side by side.
        qkv = conv1d(x, 'c_attn', n_state*3)
        q, k, v = [split_heads(part) for part in tf.split(qkv, 3, axis=2)]
        present = tf.stack([k, v], axis=1)
        if past is not None:
            # Prepend the cached keys/values along the sequence axis.
            past_k, past_v = tf.unstack(past, axis=1)
            k = tf.concat([past_k, k], axis=-2)
            v = tf.concat([past_v, v], axis=-2)
        merged = merge_heads(multihead_attn(q, k, v))
        projected = conv1d(merged, 'c_proj', n_state)
        return projected, present


def mlp(x, scope, n_state, *, hparams):
    """Two-layer feed-forward net: project to `n_state`, apply gelu, project back.

    The output has the same last dimension as `x`. `hparams` is accepted but
    not used here.
    """
    with tf.compat.v1.variable_scope(scope):
        nx = x.shape[-1]
        hidden = gelu(conv1d(x, 'c_fc', n_state))
        return conv1d(hidden, 'c_proj', nx)


def block(x, scope, *, past, hparams):
    """One transformer block: attention then MLP, each normalized first and added residually.

    Returns:
        (x, present): the updated hidden state and the keys/values produced
        by the attention sub-layer.
    """
    with tf.compat.v1.variable_scope(scope):
        nx = x.shape[-1]
        attn_out, present = attn(norm(x, 'ln_1'), 'attn', nx, past=past, hparams=hparams)
        x = x + attn_out
        # The feed-forward hidden width is four times the model width.
        mlp_out = mlp(norm(x, 'ln_2'), 'mlp', nx*4, hparams=hparams)
        return x + mlp_out, present

In [306]:
x = norm(h, 'ln_1')

In [323]:
c = conv1d(x, 'c_attn', 768*3)

In [344]:
q, k, v = map(split_heads, tf.split(c, 3, axis=2))

In [366]:
tf.cast(1e10, w.dtype)

<tf.Tensor: shape=(), dtype=float32, numpy=10000000000.0>

In [367]:
w = tf.matmul(q, k, transpose_b=True)

In [369]:
w = w * tf.compat.v1.rsqrt(tf.cast(v.shape[-1], w.dtype))

In [370]:
w

<tf.Tensor: shape=(1, 12, 1, 1), dtype=float32, numpy=
array([[[[ 0.0291546 ]],

        [[ 0.18900473]],

        [[ 0.5840976 ]],

        [[-0.2806542 ]],

        [[-0.24451897]],

        [[-0.07359418]],

        [[ 0.4051998 ]],

        [[-0.17023544]],

        [[-0.08238913]],

        [[ 0.02659817]],

        [[-0.1240125 ]],

        [[ 0.17797786]]]], dtype=float32)>

In [371]:
def mask_attn_weights(w):
    """Apply the attention mask to raw attention scores.

    `w` has shape [batch, heads, dst_sequence, src_sequence], where
    information flows from src to dst. Entries where the mask is 0 are
    replaced by -1e10.
    """
    _, _, nd, ns = shape_list(w)
    mask = tf.reshape(attention_mask(nd, ns, dtype=w.dtype), [1, 1, nd, ns])
    blocked = tf.cast(1e10, w.dtype) * (1 - mask)
    return w * mask - blocked

In [376]:
def attention_mask(nd, ns, *, dtype):
    """
    Return an [nd, ns] mask with 1's in the lower triangle, counting from the
    lower right corner.
    Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd),
    but doesn't produce garbage on TPUs.
    """
    rows = tf.range(nd)[:, None]
    cols = tf.range(ns)
    # Entry (i, j) is kept when i >= j - (ns - nd).
    keep = rows >= cols - ns + nd
    return tf.cast(keep, dtype)

In [378]:
attention_mask(1, 1, dtype=w.dtype)

<tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[1.]], dtype=float32)>

In [379]:
tf.range(1)[:,None]

<tf.Tensor: shape=(1, 1), dtype=int32, numpy=array([[0]], dtype=int32)>

In [383]:
tf.range(1) -1 + 1

<tf.Tensor: shape=(1,), dtype=int32, numpy=array([0], dtype=int32)>

In [384]:
tf.range(1)[:,None] >= tf.range(1) -1 + 1

<tf.Tensor: shape=(1, 1), dtype=bool, numpy=array([[ True]])>

In [385]:
tf.cast(_, w.dtype)

<tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[1.]], dtype=float32)>

In [134]:
hparams

HParams(n_vocab=50257, n_ctx=1024, n_embd=768, n_head=12, n_layer=12)

In [154]:
# Step-by-step version of model(): embeddings, all transformer blocks, the
# final norm and the logits, run at module level with no cached keys/values.
past = None
with tf.compat.v1.variable_scope("model", reuse=False):
    results = {}
    batch, sequence = shape_list(X)

    # Position and token embedding tables.
    wpe = tf.compat.v1.get_variable('wpe', [hparams.n_ctx, hparams.n_embd],
                         initializer=tf.random_normal_initializer(stddev=0.01))
    wte = tf.compat.v1.get_variable('wte', [hparams.n_vocab, hparams.n_embd],
                         initializer=tf.random_normal_initializer(stddev=0.02))
    past_length = 0 if past is None else tf.shape(past)[-2]
    h = tf.gather(wte, X) + tf.gather(wpe, positions_for(X, past_length))
    
    # Run every block, collecting each layer's keys/values.
    presents = []
    pasts = tf.unstack(past, axis=1) if past is not None else [None] * hparams.n_layer
    assert len(pasts) == hparams.n_layer
    # NOTE: this loop rebinds the module-level name `past` on every iteration.
    for layer, past in enumerate(pasts):
        h, present = block(h, 'h%d' % layer, past=past, hparams=hparams)
        presents.append(present)
    
    results['present'] = tf.stack(presents, axis=1)
    h = norm(h, 'ln_f')
    
    
    # Project back to the vocabulary by reusing the token embedding table.
    h_flat = tf.reshape(h, [batch*sequence, hparams.n_embd])
    logits = tf.matmul(h_flat, wte, transpose_b=True)
    logits = tf.reshape(logits, [batch, sequence, hparams.n_vocab])
    results['logits'] = logits

### Generate

In [155]:
out_logits = logits[:, :, :hparams.n_vocab]

In [167]:
out_presents = results['present']

In [168]:
out_presents.shape

TensorShape([8, 12, 2, 12, 1, 64])

In [162]:
def past_shape(*, hparams, batch_size=None, sequence=None):
    """Return the shape of the key/value cache as a list.

    Layout: [batch_size, n_layer, 2, n_head, sequence, n_embd // n_head].
    `batch_size` and `sequence` are left as None when not given.
    """
    features_per_head = hparams.n_embd // hparams.n_head
    return [
        batch_size,
        hparams.n_layer,
        2,
        hparams.n_head,
        sequence,
        features_per_head,
    ]

In [163]:
past_shape(hparams=hparams, batch_size=batch_size)

[8, 12, 2, 12, None, 64]

In [169]:
out_presents.set_shape(past_shape(hparams=hparams, batch_size=batch_size))

In [170]:
out_presents.shape

TensorShape([8, 12, 2, 12, 1, 64])

In [176]:
next_logits = out_logits[:, -1, :]  / tf.cast(1, dtype=tf.float32)

In [177]:
next_logits

<tf.Tensor: shape=(8, 50257), dtype=float32, numpy=
array([[-0.22656128,  0.51452285,  0.26646936, ...,  0.12472114,
        -0.05033965,  0.6154422 ],
       [-0.22656128,  0.51452285,  0.26646936, ...,  0.12472114,
        -0.05033965,  0.6154422 ],
       [-0.22656128,  0.51452285,  0.26646936, ...,  0.12472114,
        -0.05033965,  0.6154422 ],
       ...,
       [-0.22656128,  0.51452285,  0.26646936, ...,  0.12472114,
        -0.05033965,  0.6154422 ],
       [-0.22656104,  0.5145227 ,  0.2664702 , ...,  0.12472016,
        -0.05034024,  0.6154418 ],
       [-0.22656104,  0.5145227 ,  0.2664702 , ...,  0.12472016,
        -0.05034024,  0.6154418 ]], dtype=float32)>

In [None]:
def top_k_logits(logits, k):
    """Keep the `k` largest logits per row and set the rest to -1e10.

    ``k == 0`` means no truncation: `logits` is returned unchanged. The check
    is repeated inside ``tf.cond`` -- presumably for the case where `k` is a
    tensor; confirm against callers.
    """
    if k == 0:
        # no truncation
        return logits

    def _top_k():
        top_values, _ = tf.nn.top_k(logits, k=k)
        # Smallest of the k kept values per row, with a trailing axis so it
        # broadcasts against `logits`.
        threshold = top_values[:, -1, tf.newaxis]
        suppressed = tf.ones_like(logits, dtype=logits.dtype) * -1e10
        return tf.where(logits < threshold, suppressed, logits)

    return tf.cond(tf.equal(k, 0), lambda: logits, _top_k)

In [180]:
nxt_logits = top_k_logits(next_logits, k=0)

In [192]:
def top_p_logits(logits, p):
    """Nucleus sampling: keep the highest logits whose cumulative probability stays within `p`.

    Args:
        logits: rank-2 tensor [batch, vocab] with a statically known batch size.
        p: cumulative-probability threshold.

    Returns:
        Tensor shaped like `logits` where every entry below the per-row
        cut-off is replaced by -1e10. At least one entry per row is kept.
    """
    batch, _ = logits.shape.as_list()
    sorted_logits = tf.sort(logits, direction='DESCENDING', axis=-1)
    cumulative_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
    # Index (in sorted order) of the last logit to keep in each row; clamped
    # to 0 so the top logit always survives.
    last_kept = tf.maximum(
        tf.reduce_sum(tf.cast(cumulative_probs <= p, tf.int32), axis=-1) - 1, 0)
    indices = tf.stack([tf.range(0, batch), last_kept], axis=-1)
    # The trailing axis makes the per-row threshold [batch, 1] so it
    # broadcasts across the vocabulary axis for any batch size (a bare
    # [batch] vector only lines up when batch == 1).
    min_values = tf.gather_nd(sorted_logits, indices)[:, tf.newaxis]
    return tf.where(
        logits < min_values,
        tf.ones_like(logits) * -1e10,
        logits,
    )

In [None]:
nxt_logits = top_p_logits(nxt_logits, p=1)

In [201]:
samples = tf.random.categorical(nxt_logits, num_samples=1, dtype=tf.int32)

In [202]:
samples

<tf.Tensor: shape=(8, 1), dtype=int32, numpy=
array([[38531],
       [13258],
       [49950],
       [33422],
       [39794],
       [ 8949],
       [48245],
       [24931]], dtype=int32)>

In [227]:
enc.decode([seq[0] for seq in samples.numpy().tolist()])

' Crunch Pope criminality baptism Cells 101 IOC Kasich'