I'm working on implementing Hadamard transforms. I use the following function to build a Hadamard matrix:
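```python
import jax.numpy as jnp

def hadamard(key, shape, dtype=jnp.float32):
    lg2 = jnp.log2(shape[0])
    H = jnp.ones((1, ), dtype=dtype)
    # Sylvester construction: H doubles in size on every iteration
    for i in jnp.arange(lg2):
        H = jnp.vstack([jnp.hstack([H, H]), jnp.hstack([H, -H])])
    # scale by 1/sqrt(n) so that H is orthonormal
    H = 2**(-lg2 / 2) * H
    return H
```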
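The loop's trip count depends only on `shape`, which is a tuple of Python ints, so the loop does not have to be a JAX control-flow primitive at all. One hedged option (a sketch, not the code from this question; `hadamard_np` is a hypothetical name, and the size is assumed to be a power of two) is to build the matrix with NumPy, which JAX never traces, and convert it to a JAX array at the end:

```python
import numpy as np
import jax
import jax.numpy as jnp

def hadamard_np(key, shape, dtype=jnp.float32):
    """Orthonormal Sylvester Hadamard matrix of size shape[0] x shape[0].

    `key` is unused; it only matches the (key, shape, dtype) signature
    that an initializer passed to `self.param` receives.
    """
    n = shape[0]
    if n < 1 or n & (n - 1):
        raise ValueError(f"size must be a power of two, got {n}")
    H = np.ones((1, 1))
    # H.shape[0] is a plain Python int, so this loop is never traced by JAX
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return jnp.asarray(H / np.sqrt(n), dtype=dtype)

key = jax.random.PRNGKey(0)
H = hadamard_np(key, (8,))
print(H.shape)                                      # (8, 8)
print(jnp.allclose(H @ H.T, jnp.eye(8), atol=1e-6))  # True
# the same initializer also works when the caller is traced by jax.jit
H_jit = jax.jit(lambda k: hadamard_np(k, (8,)))(key)
```

The usage lines check that the result is orthonormal and that building it inside `jax.jit` does not raise, because nothing in the construction depends on a traced value.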
I cannot lower this to a `fori_loop` or `while_loop`, because the shape of `H` changes on every iteration. I then put it in a simple dense layer:
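```python
import flax.linen as nn

class HadamardTransform(nn.Module):
    n_hadamard: int

    @nn.compact
    def __call__(self, X):
        # kernel is an (n_hadamard, n_hadamard) matrix built by `hadamard`
        kernel = self.param("kernel", hadamard, (self.n_hadamard, ))
        z = jnp.dot(X, kernel)
        return z
```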
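With this kernel, `jnp.dot(X, kernel)` is an orthonormal Walsh-Hadamard transform of each row of `X`. If the transform is meant to stay fixed rather than be trained, a different technique, the fast Walsh-Hadamard transform, computes the same product in O(n log n) without storing a 4096 x 4096 kernel (64 MiB in float32). A minimal sketch, assuming the last axis has power-of-two length; `fwht` is a hypothetical helper, not a JAX or Flax API:

```python
import math
import jax
import jax.numpy as jnp

def fwht(x):
    """Orthonormal fast Walsh-Hadamard transform along the last axis.

    Matches x @ H for the Sylvester-ordered Hadamard matrix H scaled by
    1/sqrt(n), without building H. Assumes x.shape[-1] is a power of two.
    """
    *batch, n = x.shape
    if n < 1 or n & (n - 1):
        raise ValueError(f"last axis must be a power of two, got {n}")
    y = x
    h = 1
    # n and h are Python ints (shapes are static), so this loop unrolls under jit
    while h < n:
        # view the last axis as (blocks, 2, h) and apply the 2x2 butterfly
        y = y.reshape(*batch, n // (2 * h), 2, h)
        a, b = y[..., 0, :], y[..., 1, :]
        y = jnp.stack([a + b, a - b], axis=-2).reshape(*batch, n)
        h *= 2
    return y / math.sqrt(n)

# Usage: compare against an explicit Sylvester matrix for n = 8.
n = 8
H = jnp.ones((1, 1))
for _ in range(3):
    # kron([[1, 1], [1, -1]], H) == [[H, H], [H, -H]]
    H = jnp.kron(jnp.array([[1.0, 1.0], [1.0, -1.0]]), H)
H = H / math.sqrt(n)

x = jax.random.normal(jax.random.PRNGKey(0), (2, n))
print(jnp.allclose(fwht(x), x @ H, atol=1e-5))           # True
print(jnp.allclose(jax.jit(fwht)(x), x @ H, atol=1e-5))  # True
```

The trade-off is that the transform is no longer a trainable parameter; it is a fixed function of the input.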
I can initialize this model and my parameters look fine (i.e. it's a pytree with a device array):
```python
from jax import random

# fake test data
key = random.PRNGKey(42)
X = random.normal(key, (1, 4096))

# instantiate the model
model = HadamardTransform(4096)
params = model.init(key, X)
```
However, when I want to do the forward pass with
I get a `ConcretizationTypeError` on the `for` loop of the `hadamard` function: 'ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.'
I'm not sure why this is happening: I can initialize the model, so I would expect the forward pass to work. Is jitting the `apply` function also jitting the `hadamard` function?
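A small reproduction outside Flax, under an assumption worth checking against your Flax version: when `params` already contains `kernel`, `self.param` appears to re-trace the initializer with `jax.eval_shape` to confirm the stored shape still matches, so `hadamard` is traced abstractly during `apply` even though `init` ran it eagerly. Under tracing, `jnp.log2(shape[0])` is an abstract tracer, and `jnp.arange` cannot turn it into a concrete loop bound. The sketch mimics that check with `jax.eval_shape`; `hadamard_static_loop` and `traces_cleanly` are hypothetical names:

```python
import math
import jax
import jax.numpy as jnp

def hadamard(key, shape, dtype=jnp.float32):
    # Initializer from the question, unchanged: the loop bound is a jnp value.
    lg2 = jnp.log2(shape[0])
    H = jnp.ones((1, ), dtype=dtype)
    for i in jnp.arange(lg2):
        H = jnp.vstack([jnp.hstack([H, H]), jnp.hstack([H, -H])])
    H = 2**(-lg2 / 2) * H
    return H

def hadamard_static_loop(key, shape, dtype=jnp.float32):
    # Same construction, but the loop bound comes from Python's math module,
    # so it stays a concrete int while JAX traces the function.
    lg2 = int(math.log2(shape[0]))
    H = jnp.ones((1, ), dtype=dtype)
    for _ in range(lg2):
        H = jnp.vstack([jnp.hstack([H, H]), jnp.hstack([H, -H])])
    return 2**(-lg2 / 2) * H

def traces_cleanly(init_fn, n):
    # Trace init_fn abstractly: the key is an abstract input, while the
    # shape is closed over as a static Python tuple.
    key = jax.random.PRNGKey(0)
    try:
        out = jax.eval_shape(lambda rng: init_fn(rng, (n, )), key)
    except jax.errors.ConcretizationTypeError:
        return None
    return out.shape

print(traces_cleanly(hadamard, 8))              # None
print(traces_cleanly(hadamard_static_loop, 8))  # (8, 8)
```

The only difference between the two initializers is where the loop bound is computed, which points at `jnp.log2`/`jnp.arange` rather than at the loop body.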