Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change batch size to 1 and add an example #56

Merged
merged 4 commits into from
Feb 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion models/gpt2/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@
# T - decode step size
def get_config():
config = collections.namedtuple('Config', ['B', 'K', 'S', 'T'])
return config(4, 8, 64, 1)
return config(1, 8, 64, 1)
46 changes: 46 additions & 0 deletions models/gpt2/example_gen_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import jax
import jax.numpy as jnp

import os
import sys

from transformers import GPT2Tokenizer

tknzr = GPT2Tokenizer.from_pretrained("gpt2")

import model
import config

assert_dir = os.path.join(sys.path[0], "assets")

ids = tknzr.encode("Hello, my dog is cute")

params = model.load_gpt2_model("gpt2", assert_dir)

cfg = config.get_config()
B = cfg.B # Batch size
K = cfg.K # Input sequence length
S = cfg.S
T = cfg.T # Batched decode
L, _, _, Q, H, _ = model.model_sizes["gpt2"]

nids = len(ids)
t = jnp.asarray([nids], dtype=jnp.int32)
prompt = jnp.asarray([ids + [0] * (K - nids)], dtype=jnp.int32)
# t = [6]
# prompt = [[15496 11 616 3290 318 13779 0 0]]


kv = model.init_kv(1, S, L, Q, H, dtype=jnp.float32)
kv = [jnp.tile(k, (1, prompt.shape[0], 1, 1, 1)) for k in kv]
kv, x = model.encode(params, kv, prompt, 0, t)
ids.append(x.item())
Copy link
Contributor Author

@wangkuiyi wangkuiyi Feb 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, the first call to encode should return x=[13]. This could be a ground-truth, because I tried variaous implementations of GPT-2, including HuggingFace's TF version, Flax version, and PyTorch version, and this one, and all of them return the same result.


for _ in range(12):
kv, x = model.decode(params, kv, x, t)
t = t + 1
ids.append(x.item())
# ids = [15496, 11, 616, 3290, 318, 13779, 13, 314, 1101, 407, 1654, 611, 673, 338, 257, 26188, 393, 407, 13]
assert (
tknzr.decode(ids) == "Hello, my dog is cute. I'm not sure if she's a puppy or not."
)