In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import jax.numpy as jnp
import jax
import jax.random as jrandom

In [25]:
from layers import LinearLayer
import equinox as eqx
from layers import batched_mm

key = jrandom.PRNGKey(0)

linear = LinearLayer(3, 2, True, 0.1)
linear.params['w']['w']

class Linear(eqx.Module):
    weight: jnp.ndarray
    bias: jnp.ndarray
    def __init__(self, input_dim, output_dim, key):
        self.weight = jax.random.normal(key, (output_dim, input_dim))
        self.bias = jax.random.normal(key, (output_dim,))
    
    def __call__(self, x):
        return jnp.einsum('bis,oi->bos', x, self.weight) + self.bias[None, :, None]

batch_size, dig, seq = 5, 3, 2
x = jrandom.normal(key, (batch_size, 3, seq))
jax_linear = Linear(3, 2, jax.random.PRNGKey(0))
# out = jax_linear(x)

# diffs = jrandom.normal(key, out.shape)

# primals, vjp_fun = eqx.filter_vjp(jax_linear.__call__, x)
# vjp_fun(out)

def test_backward(jax_mod, my_mod, *args):
    primals, vjp_fun = eqx.filter_vjp(jax_mod.__call__, *args)
    diffs = jrandom.normal(key, primals.shape)
    backed = vjp_fun(diffs)

    args_mapped = [np.array(x, dtype=np.float64) for x in args]
    primals_mine = my_mod.forward(*args_mapped)
    backed_mine = my_mod.backward(np.array(diffs, dtype=np.float64))

    print(primals - primals_mine)
    print(backed - backed_mine)
    assert np.allclose(np.array(primals), primals_mine)
    assert np.allclose(np.array(backed), backed_mine)

linear.params['w']['w'] = np.array(jax_linear.weight, dtype=np.float64).copy()
print(linear.params['w']['w'].shape)
print(x.shape)
# print(batched_mm(linear.params['w']['w'], x).shape)
linear.params['b']['w'] = np.array(jax_linear.bias, dtype=np.float64)[:, None].copy()
test_backward(jax_linear, linear, x)

(2, 3)
(5, 3, 2)
[[[ 0.0000000e+00  0.0000000e+00]
  [ 1.4901161e-08  0.0000000e+00]]

 [[ 0.0000000e+00 -2.9802322e-08]
  [-2.9802322e-08  2.3841858e-07]]

 [[ 0.0000000e+00  0.0000000e+00]
  [ 0.0000000e+00 -1.1920929e-07]]

 [[ 5.9604645e-08  0.0000000e+00]
  [ 0.0000000e+00  0.0000000e+00]]

 [[ 0.0000000e+00  0.0000000e+00]
  [ 0.0000000e+00  8.9406967e-08]]]
[[[[-1.11205214e-08 -1.54819801e-09]
   [-3.27570149e-09 -1.77625523e-08]
   [-4.39871535e-08 -2.71883737e-08]]

  [[ 1.57698814e-08  2.14102425e-08]
   [ 1.61450354e-07 -1.18389391e-07]
   [ 5.81359134e-08  6.79754271e-08]]

  [[ 1.92657343e-08 -6.84767159e-08]
   [ 7.70442909e-09 -1.33064582e-08]
   [-3.49659857e-10 -2.50611629e-08]]

  [[-1.12520659e-09  3.28681500e-08]
   [ 3.17221041e-08  3.73587312e-08]
   [ 6.64036603e-09  1.92668210e-08]]

  [[ 8.66035483e-08  2.37996582e-08]
   [ 1.81227033e-08  7.29825071e-08]
   [-1.20852093e-08  2.15488645e-08]]]]


In [26]:
from layers import Softmax
class SoftmaxJax(eqx.Module):
    def __call__(self, x):
        return jax.nn.softmax(x, axis=1)

softmax_jax = SoftmaxJax()
softmax = Softmax()

x = jrandom.normal(key, (batch_size, 3, seq))

test_backward(softmax_jax, softmax, x)


[[[ 0.0000000e+00  2.9802322e-08]
  [ 7.4505806e-09  2.9802322e-08]
  [ 5.9604645e-08  2.9802322e-08]]

 [[ 0.0000000e+00  0.0000000e+00]
  [ 0.0000000e+00  0.0000000e+00]
  [ 0.0000000e+00  0.0000000e+00]]

 [[ 0.0000000e+00  0.0000000e+00]
  [ 0.0000000e+00  0.0000000e+00]
  [ 0.0000000e+00  0.0000000e+00]]

 [[-2.9802322e-08  0.0000000e+00]
  [ 0.0000000e+00  0.0000000e+00]
  [ 0.0000000e+00 -1.4901161e-08]]

 [[ 2.9802322e-08  0.0000000e+00]
  [ 0.0000000e+00  0.0000000e+00]
  [ 2.9802322e-08  1.4901161e-08]]]
[[[[-1.19512546e-08 -1.15222698e-08]
   [-2.23788299e-08 -2.30276542e-08]
   [-5.38400856e-08 -1.84939561e-08]]

  [[-2.12966785e-08  1.45802842e-09]
   [-3.00589783e-08  2.49374237e-09]
   [-2.01895425e-09  5.16536428e-09]]

  [[-3.36910841e-08 -7.14107423e-09]
   [-3.69118398e-08 -1.28443045e-09]
   [-2.45850245e-08 -7.82553701e-08]]

  [[ 2.08120076e-08 -2.22866270e-09]
   [ 2.67632172e-09 -2.59669183e-08]
   [ 1.49142169e-08 -1.72183687e-08]]

  [[-3.02159220e-09  5.04840

In [27]:
from layers import CrossEntropy

class CrossEntropyLossJax(eqx.Module):
    def __call__(self, x, y):
        return -jnp.sum(y * jax.nn.log_softmax(x, axis=1))

loss_jax = CrossEntropyLossJax()

loss = CrossEntropy()

x = jrandom.normal(key, (batch_size, 3, seq))
y = jrandom.normal(key, (batch_size, 3, seq))

test_backward(loss_jax, loss, x, y)

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
[1m[1m[1m[1m- Resolution failure for literal arguments:
[1mFailed in nopython mode pipeline (step: nopython frontend)
[1m[1mNo implementation of function Function(<built-in function getitem>) found for signature:

 >>> getitem(array(float64, 3d, C), Tuple(int64, array(float64, 1d, C), int64))

There are 22 candidate implementations:
[1m  - Of which 20 did not match due to:
  Overload of function 'getitem': File: <numerous>: Line N/A.
    With argument(s): '(array(float64, 3d, C), Tuple(int64, array(float64, 1d, C), int64))':[0m
[1m   No match.[0m
[1m  - Of which 2 did not match due to:
  Overload in function 'GetItemBuffer.generic': File: numba/core/typing/arraydecl.py: Line 209.
    With argument(s): '(array(float64, 3d, C), Tuple(int64, array(float64, 1d, C), int64))':[0m
[1m   Rejected as the implementation raised a specific error:
     NumbaTypeError: [1mUnsupported array index type array(float64, 1d, C) in Tuple(int64, array(float64, 1d, C), int64)[0m[0m
  raised from /opt/homebrew/Caskroom/miniforge/base/envs/jax/lib/python3.12/site-packages/numba/core/typing/arraydecl.py:102
[0m
[0m[1mDuring: typing of intrinsic-call at /Users/marcel/git/Vitber-indmat/layers.py (460)[0m
[1m
File "layers.py", line 460:[0m
[1m    def forward(self, y_pred: np.ndarray, y_true: np.ndarray):
        <source elided>
            for seq_index in range(y_pred.shape[2]):
[1m                per_token_loss[batch_index, seq_index] = -np.log(y_pred[batch_index, y_true[batch_index, seq_index], seq_index] + self.epsilon)
[0m                [1m^[0m[0m
[0m
[0m[1m- Resolution failure for non-literal arguments:
[1mNone[0m
[0m[0m
[0m[1mDuring: resolving callee type: BoundFunction((<class 'numba.core.types.misc.ClassInstanceType'>, 'forward') for instance.jitclass.CrossEntropy#317627080<epsilon:float64,prev_y_pred:OptionalType(array(float64, 3d, A)),prev_y:OptionalType(array(int64, 2d, A))>)[0m
[0m[1mDuring: typing of call at <string> (3)
[0m
[1m
File "<string>", line 3:[0m
[1m<source missing, REPL/exec in use?>[0m
