In [27]:
from modeling_phi import Attention
import jax.numpy as jnp
import torch
from transformers import PhiConfig
from transformers.models.phi.modeling_phi import PhiAttention, PhiDecoderLayer
from modeling_phi import forward_attention, forward_decoder_block, forward_decoder, forward_embedding, pt2jax, jax2pt, convert_attention, convert_decoder_block, forward_layer_norm, forward_mlp
from modeling_phi import phi_config, make_rotary_values, LayerNorm
import torch.nn as tnn

In [28]:
# Small model dimensions shared by every PyTorch-vs-JAX equivalence test below.
batch_size = 2
seq_len = 7
d_model = 32
n_heads_kv = 4  # number of key/value heads
n_rep_kv = 1  # query heads per key/value head; 1 means plain multi-head attention
n_heads = n_heads_kv * n_rep_kv  # total number of query heads
assert d_model % n_heads == 0, 'd_model must be divisible by the number of query heads'
head_dim = d_model // n_heads  # derived (not hard-coded) so the two configs cannot drift apart
vocab_size = 6
d_ff = 37  # MLP hidden size

# Hyper-parameters that BOTH configs must agree on; set them explicitly on both
# sides instead of relying on each library's defaults.
LAYER_NORM_EPS = 1e-05
PARTIAL_ROTARY_FACTOR = 0.5

torch.manual_seed(1)

config_pt = PhiConfig(
    hidden_size=d_model,
    num_attention_heads=n_heads,
    num_key_value_heads=n_heads_kv,
    vocab_size=vocab_size,
    layer_norm_eps=LAYER_NORM_EPS,
    intermediate_size=d_ff,
    hidden_act='gelu_new',
    partial_rotary_factor=PARTIAL_ROTARY_FACTOR,
)
config_jax = phi_config._replace(
    d_model=d_model,
    d_ff=d_ff,
    n_rep_kv=n_rep_kv,
    vocab_size=vocab_size,
    n_heads_kv=n_heads_kv,
    head_dim=head_dim,
    dropout_rate=None,  # no dropout, so the JAX forward is deterministic
    partial_rotary_factor=PARTIAL_ROTARY_FACTOR,
    layer_norm_epsilon=LAYER_NORM_EPS,
)

In [29]:



# Equivalence test: Hugging Face PhiAttention vs. the JAX forward_attention.
attention_pt = PhiAttention(config=config_pt)
attention_pt.eval()  # make sure any dropout is disabled so the comparison is deterministic
params_jax = convert_attention(attention_pt, model_config=config_jax)

# Random input sequence shared by both implementations.
seq_pt = torch.rand(batch_size, seq_len, d_model)
seq_jax = pt2jax(seq_pt)

# Padding mask: all positions valid here (alternative: torch.rand(batch_size, seq_len) > 0.1).
mask_pt_1d = torch.ones(batch_size, seq_len, dtype=torch.bool)
# Causal 2-D mask: (B, L, L) -> (B, 1, L, L) for PyTorch, (B, 1, 1, L, L) for JAX.
mask_pt = torch.tril(torch.einsum('bi,bj->bij', mask_pt_1d, mask_pt_1d))[:, None]
mask_jax_1d = pt2jax(mask_pt_1d)
mask_jax = jnp.tril(jnp.einsum('bi,bj->bij', mask_jax_1d, mask_jax_1d))[:, None, None]
# Index of the first valid token per batch row, used to offset rotary positions.
leftpad_len = mask_jax_1d.argmax(axis=-1).astype(jnp.uint16)
rotary_values = make_rotary_values(leftpad_len, batch_size, seq_len, model_config=config_jax)

# In the Hugging Face implementation, the attention mask is added to the attention
# matrix, not multiplied.
# See https://github.com/huggingface/transformers/issues/1935
mask_pt = torch.where(mask_pt, 0, -10000.)

# Reference forward pass; no autograd graph is needed for a numerical comparison.
with torch.no_grad():
    y_pt = attention_pt(hidden_states=seq_pt, attention_mask=mask_pt)[0]
y_jax = pt2jax(y_pt)
y_hat_jax, _ = forward_attention(params_jax, seq_jax, seq_jax, mask_jax, rotary_values=rotary_values, model_config=config_jax)

# Zero outputs at masked-out positions in both results before comparing.
y_jax = jnp.where(mask_jax_1d[..., None], y_jax, 0.)
y_hat_jax = jnp.where(mask_jax_1d[..., None], y_hat_jax, 0.)

print('y_jax', y_jax.reshape(-1)[:30])
print('y_hat_jax', y_hat_jax.reshape(-1)[:30])
assert jnp.allclose(y_jax, y_hat_jax)
print('Test passed.')

y_jax [ 0.3809222   0.08554427 -0.32185036  0.09108333  0.21352716  0.21511878
  0.02889613  0.28895354 -0.24973309  0.13385388 -0.38566226  0.12869063
  0.00871488 -0.23161611 -0.30547583 -0.2266545  -0.2180866  -0.1230709
  0.01938874 -0.32273737  0.13738525 -0.30827588 -0.01547235 -0.15963183
 -0.05082852 -0.09416683 -0.28768718 -0.2831776   0.08766142 -0.04026837]
y_hat_jax [ 0.38092217  0.08554427 -0.32185036  0.09108337  0.21352716  0.21511877
  0.02889613  0.2889536  -0.24973312  0.13385391 -0.38566223  0.12869063
  0.00871487 -0.23161614 -0.30547583 -0.22665454 -0.21808663 -0.12307093
  0.01938875 -0.3227374   0.13738526 -0.30827582 -0.01547235 -0.1596319
 -0.05082849 -0.09416679 -0.28768718 -0.28317758  0.08766145 -0.04026842]
Test passed.


In [30]:
# Equivalence test: torch.nn.Embedding lookup vs. the JAX forward_embedding.
embedding_pt = tnn.Embedding(config_jax.vocab_size, config_jax.d_model, padding_idx=-1)
# Replace the default initialisation with fresh standard-normal weights.
fresh_weight = torch.randn_like(embedding_pt.weight)
embedding_pt.weight = tnn.Parameter(fresh_weight)

params_pt = embedding_pt.weight
params_jax = pt2jax(params_pt)

# One batch row of token ids covering several vocabulary entries.
token_ids = [[3, 3, 3, 0, 3, 2, 3, 1, 5]]
x_pt = torch.tensor(token_ids, dtype=torch.int)
x_jax = pt2jax(x_pt).astype(jnp.uint16)

y_pt = embedding_pt(x_pt)
y_jax, y_hat_jax = pt2jax(y_pt), forward_embedding(params_jax, x_jax)
assert jnp.allclose(y_jax, y_hat_jax)
print('Test passed.')

Test passed.


In [31]:
# Equivalence test: torch.nn.LayerNorm vs. the JAX forward_layer_norm.
# NOTE(review): this rebinds the notebook-wide batch_size / seq_len. The MLP and
# decoder-block cells further down read these globals and so run with seq_len = 1.
batch_size, seq_len = 2, 1

norm_pt = tnn.LayerNorm(config_jax.d_model, eps=config_jax.layer_norm_epsilon, elementwise_affine=True)
# Replace the all-ones default scale with random values; the bias keeps its default.
random_scale = torch.randn_like(norm_pt.weight)
norm_pt.weight = tnn.Parameter(random_scale)

params_pt = norm_pt.weight
params_jax = LayerNorm(
    weight=pt2jax(norm_pt.weight),
    bias=pt2jax(norm_pt.bias),
)

x_pt = torch.rand(batch_size, seq_len, config_jax.d_model)
x_jax = pt2jax(x_pt)

y_pt = norm_pt(x_pt)
y_jax = pt2jax(y_pt)
y_hat_jax = forward_layer_norm(params_jax, x_jax, model_config=config_jax)

# Show the first 30 flattened values of each result side by side.
for label, values in (('y_jax', y_jax), ('y_hat_jax', y_hat_jax)):
    print(label, values.reshape(-1)[:30])
assert jnp.allclose(y_jax, y_hat_jax)
print('Test passed.')

y_jax [-1.4426247   0.31092632 -0.30887297  0.22166999 -2.1207542  -0.11964922
 -2.7633946   0.01631283  1.9290285   1.0304995  -0.5117087   3.478505
 -0.05939087  1.1803433   0.10975467  0.9459591  -1.6364973   0.03644655
  1.298217    0.9042006   0.66670144 -0.50051916  1.0391926  -0.5723707
  0.53466356 -0.06153665  1.0334806   2.3507762   0.07948992 -0.09832619]
y_hat_jax [-1.4426247   0.31092635 -0.30887297  0.22166999 -2.1207542  -0.11964922
 -2.7633946   0.01631283  1.9290285   1.0304995  -0.5117087   3.478505
 -0.05939087  1.1803433   0.10975467  0.94595915 -1.6364973   0.03644655
  1.2982172   0.9042006   0.66670144 -0.5005192   1.0391926  -0.5723707
  0.53466356 -0.06153665  1.0334806   2.3507762   0.07948992 -0.09832621]
Test passed.


In [52]:
# Equivalence test: the MLP sub-module of a Hugging Face PhiDecoderLayer vs. the
# JAX forward_mlp, which is given the converted decoder-block parameters.
decoder_block_pt = PhiDecoderLayer(config=config_pt, layer_idx=1)
params_jax = convert_decoder_block(decoder_block_pt, model_config=config_jax)
mlp_pt = decoder_block_pt.mlp

# NOTE(review): batch_size / seq_len are notebook globals; check which cell set them last.
seq_pt = torch.rand(batch_size, seq_len, d_model)
seq_jax = pt2jax(seq_pt)

# Reference forward pass; no autograd graph is needed for a numerical comparison.
with torch.no_grad():
    y_pt = mlp_pt(seq_pt)
y_jax = pt2jax(y_pt)

y_hat_jax = forward_mlp(params_jax, seq_jax)

print('y_jax', y_jax.reshape(-1)[:60])
print('y_hat_jax', y_hat_jax.reshape(-1)[:60])

# A purely relative tolerance (default atol=1e-8) breaks down for outputs near zero,
# where tiny absolute differences become huge relative ones; add an absolute floor.
assert jnp.allclose(y_jax, y_hat_jax, rtol=1e-02, atol=1e-04)
print('Test passed.')

y_jax [ 0.20558816  0.24419892  0.1038277  -0.00047055  0.14069492 -0.2527915
  0.13871008  0.02292301 -0.0296139  -0.11507002  0.06133269  0.02876599
  0.11477372  0.27442294  0.05185471  0.08392327 -0.03465857  0.1597825
 -0.05245028 -0.24720992  0.02974929 -0.10189229  0.14556657 -0.09291685
  0.02649859 -0.3545338  -0.09292045  0.02840933  0.28200686  0.31341866
  0.0162026   0.13692412  0.30714723  0.25758561  0.04955429 -0.08822581
  0.1308631  -0.15604994  0.19415739  0.10688122  0.09553957 -0.02963576
 -0.04539628  0.07708994  0.02144855  0.24398202 -0.00731712  0.06229361
 -0.05898028  0.1293217  -0.00369786 -0.13411754 -0.00721333 -0.0674806
  0.10203804  0.00868339 -0.02300387 -0.41825277 -0.07009862  0.06579411]
y_hat_jax [ 0.20558454  0.24420096  0.1038224  -0.00046931  0.14071405 -0.25279364
  0.1387099   0.02293312 -0.02960952 -0.11508399  0.06134358  0.0287626
  0.11477291  0.2744342   0.05186819  0.08393145 -0.03464745  0.15980378
 -0.05245134 -0.24722143  0.02976028 -

In [56]:
# Equivalence test: Hugging Face PhiDecoderLayer vs. the JAX forward_decoder_block,
# using a random padding mask.
import einops as op

decoder_block_pt = PhiDecoderLayer(config=config_pt, layer_idx=0)
decoder_block_pt.eval()  # make sure any dropout is disabled so the comparison is deterministic
params_jax = convert_decoder_block(decoder_block_pt, model_config=config_jax)

# NOTE(review): batch_size / seq_len are notebook globals. In the saved execution order
# they were last set by the layer-norm cell (seq_len = 1), which makes the causal mask
# trivial; consider re-running this cell with a longer sequence.
seq_pt = torch.rand(batch_size, seq_len, d_model)
seq_jax = pt2jax(seq_pt)

# Random boolean mask (~30% of positions kept) and its causal 2-D expansion (B, 1, L, L).
mask_pt_1d = torch.rand(batch_size, seq_len) > 0.7
mask_pt = op.rearrange(
    torch.tril(op.einsum(mask_pt_1d, mask_pt_1d, 'B L1, B L2 -> B L1 L2')),
    'B L1 L2 -> B 1 L1 L2',
)
mask_jax_1d = pt2jax(mask_pt_1d)
mask_jax = pt2jax(mask_pt)

# Index of the first valid token per batch row, used to offset rotary positions.
leftpad_len = mask_jax_1d.argmax(axis=-1).astype(jnp.uint16)
rotary_values = make_rotary_values(leftpad_len, batch_size, seq_len, model_config=config_jax)
# Hugging Face adds the mask to the attention scores, so convert it to additive form.
mask_pt = torch.where(mask_pt, 0, -10000000.)

# Reference forward pass; no autograd graph is needed for a numerical comparison.
with torch.no_grad():
    y_pt = decoder_block_pt(hidden_states=seq_pt, attention_mask=mask_pt)[0]
y_jax = pt2jax(y_pt)
y_hat_jax = forward_decoder_block(params_jax, seq_jax, mask_jax, rotary_values=rotary_values, model_config=config_jax)[0]

# Zero outputs at masked-out positions in both results before comparing.
y_jax = jnp.where(mask_jax_1d[..., None], y_jax, 0.)
y_hat_jax = jnp.where(mask_jax_1d[..., None], y_hat_jax, 0.)

print('y_jax', y_jax.reshape(-1)[:60])
print('y_hat_jax', y_hat_jax.reshape(-1)[:60])

# A purely relative tolerance (default atol=1e-8) breaks down for outputs near zero,
# where tiny absolute differences become huge relative ones; add an absolute floor.
assert jnp.allclose(y_jax, y_hat_jax, rtol=1e-02, atol=1e-04)
print('Test passed.')

y_jax [ 0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.37413242 -0.324925   -0.04732257 -0.22331363
  1.1843688   0.16076326  0.9275273  -0.80919784  0.4814498   0.94506735
 -0.81111825  0.2613848   0.5426444   1.4842494   0.9145651   0.49251202
  0.5494744   0.8769037   0.65794027  0.49175996  0.5358275   0.44394332
  1.3008204   1.3742481  -0.33338895 -0.27109402  0.61694956  1.1081872 ]
y_hat_jax [ 0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.     