Skip to content

Commit

Permalink
more configurable sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
kingoflolz committed Apr 24, 2021
1 parent 60c277e commit 239d312
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 29 deletions.
25 changes: 4 additions & 21 deletions device_sample.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import argparse
import functools
import json
import time

Expand All @@ -9,7 +8,7 @@

from mesh_transformer import util
from mesh_transformer.checkpoint import read_ckpt
from mesh_transformer.sampling import softmax_sample
from mesh_transformer.sampling import nucleaus_sample
from mesh_transformer.transformer_shard import CausalTransformer
import transformers
from smart_open import open
Expand All @@ -20,15 +19,8 @@
def parse_args():
"""Parse command-line arguments for the sampling script.

NOTE(review): this span is a unified-diff view with +/- markers stripped;
the --tpu/--tpu_region/--preemptible and --new arguments appear to be the
REMOVED side of the diff (the commit header reports 21 deletions in
device_sample.py) — confirm against the post-commit file before relying
on them. Indentation was also lost in the paste and is preserved as-is.
"""
# Parse command line arguments
parser = argparse.ArgumentParser()
# NOTE(review): likely removed by this commit — TPU provisioning flags.
parser.add_argument("--tpu", type=str, help="Name of TPU to train on.")
parser.add_argument("--tpu_region", type=str, help="Region of TPU to train on.")
parser.add_argument("--preemptible", action="store_true")

# Path to the JSON config consumed as params = json.load(open(args.config)).
parser.add_argument("--config", type=str, default=None, help="Config file location")

# NOTE(review): likely removed by this commit — fresh-run checkpoint wipe flag.
parser.add_argument("--new", action="store_true", help="If set, deletes previous checkpoint, if it exists, and "
"starts a new training run")

args = parser.parse_args()
return args

Expand All @@ -37,16 +29,6 @@ def parse_args():
args = parse_args()
params = json.load(open(args.config))

if args.new:
print(f"Starting experiment {params['name']} from scratch! "
f"all data in gs://{params['bucket']}/{params['model_dir']}/ will be deleted")
input("Hit enter to continue")

tpu_name = args.tpu
region = args.tpu_region
preemptible = args.preemptible
clean_start = args.new

gradient_accumulation_steps = params.get("gradient_accumulation_steps", 1)
per_replica_batch = params["per_replica_batch"]
cores_per_replica = params["cores_per_replica"]
Expand All @@ -62,7 +44,7 @@ def parse_args():
seq = params["seq"]
norm = params["norm"]

params["sampler"] = softmax_sample
params["sampler"] = nucleaus_sample
opt = optax.chain(
optax.scale(1 / gradient_accumulation_steps),
clip_by_global_norm(1),
Expand Down Expand Up @@ -114,7 +96,8 @@ def parse_args():
batched_tokens = np.array([padded_tokens] * total_batch)
length = np.ones(total_batch, dtype=np.uint32) * len(tokens)

output = network.generate(batched_tokens, length, 512)
output = network.generate(batched_tokens, length, 512, {"top_p": np.ones(total_batch) * 0.9,
"temp": np.ones(total_batch) * 0.75})

for idx, o in enumerate(output[1][0][:, :, 0]):
print(f"sample {idx}: {repr(tokenizer.decode(o))}")
Expand Down
8 changes: 4 additions & 4 deletions mesh_transformer/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@


# takes in a logit distribution, softmax and then sample
def softmax_sample(key, logits, _):
return jax.random.categorical(key, logits/0.75, -1).astype(jnp.uint32), None
def softmax_sample(key, logits, _, temp=1):
    """Draw one token id per row from the softmax of ``logits``.

    Args:
        key: JAX PRNG key for the categorical draw.
        logits: unnormalized log-probabilities; sampling is over the last axis.
        _: unused per-position sampler input (kept for sampler-interface parity).
        temp: softmax temperature; logits are divided by this before sampling.

    Returns:
        Tuple ``(token_ids, None)`` where ``token_ids`` has dtype uint32.
    """
    scaled = logits / temp
    drawn = jax.random.categorical(key, scaled, -1)
    return drawn.astype(jnp.uint32), None


def nucleaus_filter(logits, top_p=0.9):
Expand All @@ -25,10 +25,10 @@ def nucleaus_filter(logits, top_p=0.9):
return logits


def nucleaus_sample(key, logits, _, top_p=0.9):
def nucleaus_sample(key, logits, _, top_p=0.9, temp=1):
    """Nucleus (top-p) sampling: mask logits outside the smallest set whose
    probability mass reaches ``top_p``, then draw from the temperature-scaled
    softmax of what remains.

    NOTE(review): "nucleaus" is a misspelling of "nucleus", but the name is
    referenced by callers (e.g. device_sample.py), so it is kept as-is.
    """
    filtered = nucleaus_filter(logits, top_p)
    return softmax_sample(key, filtered, None, temp=temp)


if __name__ == "__main__":
Expand Down
10 changes: 6 additions & 4 deletions mesh_transformer/transformer_shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def train_loss(x, y):
"opt_state": optimizer.init(params)
}

def generate(state, key, ctx, ctx_length, aux):
def generate(state, key, ctx, ctx_length, aux, sampler_options):
sampler = config["sampler"]
gen_length = self.gen_length

Expand All @@ -183,7 +183,7 @@ def generate_scan_fn(carry, sampler_input):
sample_key, new_key = jax.random.split(sample_key)

output, new_state = transformer.generate_once(next_token, decode_state)
next_token, sample_info = sampler(sample_key, output, sampler_input)
next_token, sample_info = sampler(sample_key, output, sampler_input, **sampler_options)

output = (next_token, sample_info)
new_carry = (next_token, new_state, new_key)
Expand Down Expand Up @@ -222,6 +222,7 @@ def generate_scan_fn(carry, sampler_input):
["batch", ...],
["batch", ...],
["batch", ...],
["batch", ...],
["batch", ...]),
out_axes=["batch", ...],
axis_resources={'shard': 'mp', 'batch': 'dp'})
Expand Down Expand Up @@ -292,7 +293,7 @@ def eval(self, sample):
# print(f"eval done in {time.time() - start:.06}s")
return out

def generate(self, ctx, ctx_length, gen_length):
def generate(self, ctx, ctx_length, gen_length, sampler_options):
key = hk.PRNGSequence(random.randint(0, 2 ** 60))

batch_size = ctx.shape[0]
Expand All @@ -303,4 +304,5 @@ def generate(self, ctx, ctx_length, gen_length):
jnp.array(key.take(batch_size)),
ctx,
np.array(ctx_length, dtype=np.uint32),
aux)
aux,
sampler_options)

0 comments on commit 239d312

Please sign in to comment.