Skip to content

Commit

Permalink
Implement top_k sampling (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
AeroScripts authored Jun 11, 2021
1 parent 6a020fa commit 56626ff
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 5 deletions.
19 changes: 16 additions & 3 deletions mesh_transformer/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,24 @@ def softmax_sample(key, logits, _, temp=1):
return jax.random.categorical(key, logits/temp, -1).astype(jnp.uint32), None


def nucleaus_filter(logits, top_p=0.9):
def nucleaus_filter(logits, top_p=0.9, top_k=None):
sorted_logits = jnp.sort(logits)[:, ::-1] # sort descending
sorted_indices = jnp.argsort(logits)[:, ::-1]
cumulative_probs = jnp.cumsum(jax.nn.softmax(sorted_logits), axis=-1)

if top_k is not None:
# Keep only top_k tokens
indices_range = jnp.arange(len(sorted_indices[0]))
indices_range = jnp.stack([indices_range] * len(sorted_indices), axis=0)

sorted_indices_to_remove = jnp.where(indices_range > top_k, sorted_indices, 0)

_, indices_to_remove = jax.lax.sort_key_val(sorted_indices, sorted_indices_to_remove)

logit_mask = 1e10 * indices_to_remove

logits -= logit_mask

# Remove tokens with cumulative probability above a threshold
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove = jnp.concatenate((jnp.zeros_like(sorted_indices_to_remove[:, :1]), sorted_indices_to_remove), axis=-1)[:, :-1]
Expand All @@ -25,8 +38,8 @@ def nucleaus_filter(logits, top_p=0.9):
return logits


def nucleaus_sample(key, logits, _, top_p=0.9, temp=1, top_k=None):
    """Draw one token id per batch row using nucleus (top-p) filtering,
    with optional top-k truncation, followed by temperature sampling.

    Args:
        key: JAX PRNG key used for the categorical draw.
        logits: unnormalized log-probabilities, shape (batch, vocab).
        _: unused positional slot (kept for sampler-interface compatibility).
        top_p: nucleus threshold — tokens beyond this cumulative
            probability mass are masked out.
        temp: softmax temperature applied at sampling time.
        top_k: if given, only the top_k highest-scoring tokens survive
            filtering; None disables the top-k step.

    Returns:
        The (sampled_ids, None) pair produced by softmax_sample.
    """
    # NOTE: "nucleaus" is a long-standing typo in this module's public
    # names; kept as-is so existing callers don't break.
    filtered_logits = nucleaus_filter(logits, top_p, top_k=top_k)
    return softmax_sample(key, filtered_logits, None, temp=temp)

Expand Down
4 changes: 2 additions & 2 deletions resharding_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
# move the state to CPU/system memory so it's not duplicated by xmap
network.state = jax.device_put(network.state, jax.devices("cpu")[0])

def infer(context, top_p=0.9, temp=1.0, gen_len=512):
def infer(context, top_k=40, top_p=0.9, temp=1.0, gen_len=512):
tokens = tokenizer.encode(context)

provided_ctx = len(tokens)
Expand All @@ -63,7 +63,7 @@ def infer(context, top_p=0.9, temp=1.0, gen_len=512):
length = np.ones(per_replica_batch, dtype=np.uint32) * len(tokens)

start = time.time()
output = network.generate(batched_tokens, length, gen_len, {"top_p": np.ones(per_replica_batch) * top_p, "temp": np.ones(per_replica_batch) * temp})
output = network.generate(batched_tokens, length, gen_len, {"top_p": np.ones(per_replica_batch) * top_p, "top_k": top_k is not None and (np.ones(per_replica_batch, dtype=np.int32) * top_k) or None, "temp": np.ones(per_replica_batch) * temp})

samples = []
decoded_tokens = output[1][0]
Expand Down

0 comments on commit 56626ff

Please sign in to comment.