In [1]:
%pip install dm-haiku einops

Note: you may need to restart the kernel to use updated packages.


In [2]:
import jax, torch
import jax.numpy as jnp
import numpy as np
import haiku as hk
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from einops import rearrange, reduce, repeat, einsum
from functools import partial

In [3]:
ks = hk.PRNGSequence(0)
nk = lambda : next(ks)

In [4]:
precision = 8
torch.set_printoptions(precision=precision)
jnp.set_printoptions(precision=precision)
np.set_printoptions(precision=precision)

In [5]:
# haiku
def forward(args, x):
    wq = hk.Linear(args.dim, with_bias=False, name='wq')
    wk = hk.Linear(args.dim, with_bias=False, name='wk')
    wv = hk.Linear(args.dim, with_bias=False, name='wv')
    wo = hk.Linear(args.dim, with_bias=False, name='wo') 

    q, k, v = map(lambda x: rearrange(x, 'l (nh dh) -> nh l dh', nh=args.n_head), (wq(x), wk(x), wv(x)))

    scores = einsum(q, k, 'h i k, h j k -> h i j')

    heads = scores @ v  # ignoring sfmx now

    return wo(rearrange(heads, 'nh l dh -> l (nh dh)', nh=args.n_head))


# torch
class Attention(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.wq, self.wk, self.wv, self.wo = (nn.Linear(args.dim, args.dim, bias=False) for _ in range(4))      

    def forward(self, x):
        args = self.args

        q, k, v = map(lambda x: rearrange(x, 'l (nh dh) -> nh l dh', nh=args.n_head), (self.wq(x), self.wk(x), self.wv(x)))

        scores = einsum(q, k, 'h i k, h j k -> h i j')

        heads = scores @ v

        return self.wo(rearrange(heads, 'nh l dh -> l (nh dh)', nh=args.n_head))

In [6]:
@dataclass
class Args:
    dim: int
    n_head: int


args = Args(dim=4096, n_head=32)

In [7]:
seq_len = 32
x_ = torch.randn(seq_len, args.dim)
x = jnp.asarray(x_)

In [8]:
attention = hk.transform(partial(forward, args))

params = attention.init(nk(), x)

In [9]:
attention_ = Attention(args)

In [10]:
params['wq']['w'] = jnp.asarray(attention_.wq.weight.data.T)
params['wk']['w'] = jnp.asarray(attention_.wk.weight.data.T)
params['wv']['w'] = jnp.asarray(attention_.wv.weight.data.T)
params['wo']['w'] = jnp.asarray(attention_.wo.weight.data.T)

In [11]:
jax_out = attention.apply(params, nk(), x)

In [12]:
with torch.no_grad():
    pt_out = attention_(x_)

In [13]:
print(jax_out)
print(pt_out)

[[-10.620579     0.2395443   14.827803   ...  -4.097179     5.0579653
   -0.06813204]
 [ -2.9569612    6.4783435    1.6496208  ...   4.478825    11.532101
    4.6816864 ]
 [ -1.0786729    0.5392799    6.2966537  ...  -3.1048808   -7.515436
   13.741988  ]
 ...
 [-20.14898      2.7913113   -1.4063245  ...  20.23688    -15.4071665
    2.8805785 ]
 [  9.724804    -5.035846     7.1828976  ...  -6.987926   -11.659335
   -4.3041105 ]
 [ -2.7573986    2.9879174    7.2161365  ...   8.685687    -8.505831
   -3.262116  ]]
tensor([[-10.62058258,   0.23952192,  14.82778645,  ...,  -4.09717321,
           5.05792522,  -0.06814072],
        [ -2.95696282,   6.47835493,   1.64962208,  ...,   4.47883129,
          11.53208637,   4.68168831],
        [ -1.07867491,   0.53928685,   6.29667091,  ...,  -3.10488224,
          -7.51543188,  13.74200153],
        ...,
        [-20.14899254,   2.79131079,  -1.40632212,  ...,  20.23687744,
         -15.40716171,   2.88055515],
        [  9.72478390,  -5.035833

In [14]:
jnp.allclose(jax_out, jnp.asarray(pt_out), atol=1e-5)

Array(False, dtype=bool)