Record the example transformer layer as a test case.
axch committed Jul 13, 2023
1 parent 4345ccd commit f348366
Showing 1 changed file with 65 additions and 0 deletions.
65 changes: 65 additions & 0 deletions tests/dynamic_api_test.py
@@ -1695,6 +1695,71 @@ def func(size):
    data = jax.lax.broadcasted_iota('int32', (3, 5), 1)
    self.assertAllClose(p.data, data)

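  # End-to-end test: vmap a full transformer-layer fprop over a Pile of
  # ragged-length sequences, with jit both enabled and disabled.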
  @parameterized.parameters((True,), (False,))
  def test_pile_map_end_to_end_fprop_layer(self, disable_jit):
    config.update('jax_disable_jit', disable_jit)

    def fprop_layer(params, x):
      ((xnorm_scale, xnorm_bias), (wqkv, wqkv_bias), (wo, wo_bias),
       (ynorm_scale, ynorm_bias), (w_i, w_i_bias), (w_o, w_o_bias)) = params
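      # einsum index conventions: t/s = token positions, h = heads, q = per-head
      # feature dim, e = embedding dim, i = q/k/v selector, f = MLP hidden dim.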
      xnorm = jax.nn.standardize(x) * xnorm_scale + xnorm_bias
      qkv = jnp.einsum('te,ihqe->ithq', xnorm, wqkv) + wqkv_bias[:, None]
      q, k, v = qkv
      outer = jnp.einsum('thq,shq->tsh', q, k) / jnp.asarray(
          jnp.sqrt(v.shape[-1]), dtype=x.dtype)

      alpha = jax.nn.softmax(outer, 2)
      inner = jnp.einsum('tsh,shq->thq', alpha, v)
      y = jnp.einsum('thq,hqe->te', inner, wo) + wo_bias + x
      ynorm = jax.nn.standardize(y) * ynorm_scale + ynorm_bias
      act = jax.nn.gelu(jnp.einsum('te,ef->tf', ynorm, w_i) + w_i_bias)
      z = jnp.einsum('tf,fe->te', act, w_o) + w_o_bias + y
      return z

    params = [
        (jnp.ones(128), jnp.zeros(128)),  # xnorm_scale, xnorm_bias
        (jnp.ones((3, 16, 64, 128)), jnp.zeros((3, 16, 64))),  # wqkv, wqkv_bias
        (jnp.ones((16, 64, 128)), jnp.zeros(128)),  # wo, wo_bias
        (jnp.ones(128), jnp.zeros(128)),  # ynorm_scale, ynorm_bias
        (jnp.ones((128, 4096)), jnp.zeros(4096)),  # w_i, w_i_bias
        (jnp.ones((4096, 128)), jnp.zeros(128)),  # w_o, w_o_bias
    ]

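    # Three token sequences of different lengths (512, 386, 420), all with
    # embedding dimension 128.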
    xs = [
        jnp.zeros((512, 128)),
        jnp.zeros((386, 128)),
        jnp.zeros((420, 128)),
    ]

    def pile_stack(xs: list[jax.Array]) -> batching.Pile:
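      # Pad each sequence to the longest length, stack them, and wrap the result
      # in a Pile whose aval records each element's true length as a bounded int.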
      max_length = max(len(x) for x in xs)
      lengths = jnp.array([len(x) for x in xs])
      lengths = jax.lax.convert_element_type(lengths, core.bint(max_length))
      xs_padded = jnp.stack([jnp.zeros((max_length, 128), dtype=x.dtype
                                       ).at[:x.shape[0]].set(x) for x in xs])
      # jax.vmap(lambda l, xp: xp[:l, :], out_axes=pile_axis)(lengths, xs_padded)

      # binder = i
      binder = core.Var(0, '', core.ShapedArray((), np.dtype('int32')))
      # elt_ty = f32[[512, 386, 420].i, 128]
      elt_ty = core.DShapedArray((batching.IndexedAxisSize(binder, lengths), 128),
                                 xs_padded.dtype)
      # aval = i:(Fin 3) => f32[[512, 386, 420].i, 128]
      aval = batching.PileTy(binder, len(xs), elt_ty)
      xs_pile = batching.Pile(aval, xs_padded)
      return xs_pile

    xs_pile = pile_stack(xs)

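    # Map over the pile axis: params are broadcast unchanged (None), and each of
    # the three sequences is processed at its own (dynamic) length.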
    fprop_batched = jax.vmap(fprop_layer,
                             in_axes=(None, batching.pile_axis),
                             out_axes=batching.pile_axis,
                             axis_size=3)

    result_pile = fprop_batched(params, xs_pile)
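    # The output is again a Pile: its aval shows the per-sequence lengths
    # [512 386 420], and the padded backing array has shape (3, 512, 128).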
    self.assertIsInstance(result_pile, batching.Pile)
    self.assertRegex(
        str(result_pile.aval),
        r'Var[0-9]+:3 => (f32|f64)\[bint\{≤512\}\[3\] with value: \[512 386 420\]\.Var[0-9]+,128\]')
    self.assertAllClose(result_pile.data.shape, (3, 512, 128))

def pile_map(f):
  def mapped(*piles):
