In [1]:
%pip install dm-haiku einops

Note: you may need to restart the kernel to use updated packages.


In [2]:
import jax, torch
import jax.numpy as jnp
import numpy as np
import haiku as hk
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from einops import rearrange, reduce, repeat, einsum
from functools import partial

In [3]:
ks = hk.PRNGSequence(0)
nk = lambda : next(ks)

In [4]:
precision = 8
torch.set_printoptions(precision=precision)
jnp.set_printoptions(precision=precision)
np.set_printoptions(precision=precision)

In [5]:
# haiku
def forward(args, x):
    wq = hk.Linear(args.dim, with_bias=False, name='wq')
    wk = hk.Linear(args.dim, with_bias=False, name='wk')
    wv = hk.Linear(args.dim, with_bias=False, name='wv')
    wo = hk.Linear(args.dim, with_bias=False, name='wo') 

    q, k, v = map(lambda x: rearrange(x, 'l (nh dh) -> nh l dh', nh=args.n_head), (wq(x), wk(x), wv(x)))

    scores = einsum(q, k, 'h i k, h j k -> h i j')

    heads = scores @ v  # ignoring sfmx now

    return wo(rearrange(heads, 'nh l dh -> l (nh dh)', nh=args.n_head))


# torch
class Attention(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.wq, self.wk, self.wv, self.wo = [nn.Linear(args.dim, args.dim, bias=False) for _ in range(4)]        

    def forward(self, x):
        q, k, v = map(lambda x: rearrange(x, 'l (nh dh) -> nh l dh', nh=args.n_head), (self.wq(x), self.wk(x), self.wv(x)))

        scores = einsum(q, k, 'h i k, h j k -> h i j')

        heads = scores @ v

        return self.wo(rearrange(heads, 'nh l dh -> l (nh dh)', nh=args.n_head))

In [6]:
@dataclass
class Args:
    dim: int
    n_head: int


args = Args(dim=4096, n_head=32)

In [7]:
seq_len = 32
x_ = torch.randn(seq_len, args.dim)
x = jnp.asarray(x_)

In [8]:
attention = hk.transform(partial(forward, args))

params = attention.init(nk(), x)

In [9]:
attention_ = Attention(args)

In [10]:
params['wq']['w'] = jnp.asarray(attention_.wq.weight.data.T)
params['wk']['w'] = jnp.asarray(attention_.wk.weight.data.T)
params['wv']['w'] = jnp.asarray(attention_.wv.weight.data.T)
params['wo']['w'] = jnp.asarray(attention_.wo.weight.data.T)

In [11]:
jax_out = attention.apply(params, nk(), x)

In [12]:
pt_out = attention_(x_).detach()

In [13]:
print(jax_out)
print(pt_out)

[[-10.911846     3.6535482   -6.280497   ...  -1.1029336   10.962386
  -12.578524  ]
 [  4.3466306    6.333505    -1.2032125  ...  -2.8397715   -5.389377
   -5.559636  ]
 [  2.4335847    6.424877     6.4055147  ...  -5.6227484    1.9965103
    5.3998938 ]
 ...
 [ 12.701281     0.60315216  -1.8463742  ...   1.4613512   -2.1398673
    6.179638  ]
 [  2.809864     7.107299     5.754225   ...   6.6284566    0.86284876
    5.8206277 ]
 [ -0.18527448   8.471758    -5.831539   ... -14.321176     6.495843
   -7.096695  ]]
tensor([[-10.91184521,   3.65352607,  -6.28050089,  ...,  -1.10294020,
          10.96237564, -12.57854557],
        [  4.34663439,   6.33349705,  -1.20321608,  ...,  -2.83975768,
          -5.38935900,  -5.55962181],
        [  2.43358707,   6.42486191,   6.40550184,  ...,  -5.62276554,
           1.99652243,   5.39990234],
        ...,
        [ 12.70126343,   0.60314202,  -1.84636748,  ...,   1.46134126,
          -2.13987756,   6.17963219],
        [  2.80984259,   7.1073

In [14]:
jnp.allclose(jax_out, jnp.asarray(pt_out), atol=1e-5)

Array(False, dtype=bool)