This repository has been archived by the owner on Feb 26, 2023. It is now read-only.

Jitting twice for a class method #22

Closed
pipme opened this issue Aug 12, 2022 · 2 comments

Comments

@pipme

pipme commented Aug 12, 2022

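```python
import jax
import jax.numpy as jnp
import treeo as to

class A(to.Tree):
    # Declared as a pytree node, so X is traced as part of `self` under jit.
    X: jnp.ndarray = to.field(node=True)

    def __init__(self):
        self.X = jnp.ones((50, 50))

    # Jitting a method works here because A is a treeo Tree (a pytree),
    # so `self` is passed to jax.jit as an ordinary pytree argument.
    @jax.jit
    def f(self, Y):
        return jnp.sum(Y ** 2) * jnp.sum(self.X ** 2)

Y = jnp.ones(2)
for i in range(5):
    # _cache_size() is a private JAX helper: number of entries in f's jit cache.
    print(A.f._cache_size())
    a = A()
    a.f(Y)
```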

With jax 0.3.15 the output of the above is `0 1 2 2 2`: the cache grows to two entries even though every call has the same argument shapes. With jax 0.3.10 it works as expected and the output is `0 1 1 1 1`. No idea what's happening. Thanks.
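A minimal sketch (not from the original report) for checking whether the second cache entry corresponds to a real re-trace: Python side effects inside a jitted function run only while JAX traces it, so a counter incremented there counts traces rather than calls. To avoid the treeo dependency, `self.X` is passed as a plain argument `X`; the counter name `trace_count` is hypothetical.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, incremented only at trace time

@jax.jit
def f(X, Y):
    global trace_count
    # Python side effect: runs when JAX traces f, not on cached calls.
    trace_count += 1
    return jnp.sum(Y ** 2) * jnp.sum(X ** 2)

X = jnp.ones((50, 50))
Y = jnp.ones(2)
for i in range(5):
    f(X, Y)
    print(i, trace_count)
```

If `trace_count` stays at 1 while `_cache_size()` reports 2, that would suggest the extra entry does not come from re-tracing `f`; this reading is an assumption, not something established in this thread.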

@pipme
Author

pipme commented Aug 12, 2022

It looks like something related to the `** 2`: the following version of `f` gives `0 1 1 1 1`. But I'm still not sure why 🤔

```python
# Same method as A.f above, but without the ** 2
@jax.jit
def f(self, Y):
    return jnp.sum(Y) * jnp.sum(self.X)
```
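One way to test the `** 2` hypothesis in isolation is to trace-count several spellings of the computation side by side. This is a sketch under stated assumptions: the helper `make_counted` and the variant names are hypothetical, and `Y * Y` and `jnp.square(Y)` are swapped-in alternatives, not a confirmed workaround.

```python
import jax
import jax.numpy as jnp

def make_counted(body):
    """Wrap `body` in jax.jit and count how often JAX traces it (hypothetical helper)."""
    counts = {"traces": 0}

    def traced(Y):
        counts["traces"] += 1  # executes only while JAX is tracing
        return body(Y)

    return jax.jit(traced), counts

variants = {
    "pow": lambda Y: jnp.sum(Y ** 2),
    "mul": lambda Y: jnp.sum(Y * Y),
    "square": lambda Y: jnp.sum(jnp.square(Y)),
    "no_square": lambda Y: jnp.sum(Y),
}

Y = jnp.ones(2)
results = {}
for name, body in variants.items():
    fn, counts = make_counted(body)
    for _ in range(5):
        out = fn(Y)
    # (value of the last call, number of traces over five calls)
    results[name] = (float(out), counts["traces"])
    print(name, results[name])
```

Comparing trace counts across variants only shows which spellings re-trace on the installed JAX version; it does not explain why.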

@pipme
Author

pipme commented Aug 12, 2022

OK, it seems this is not related to treeo:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(Y):
    return jnp.sum(Y**2)

Y = jnp.ones(2)

for i in range(5):
    print(f._cache_size())
    f(Y)
```

This alone also gives `0 1 2 2 2`. I will probably ask in the JAX repo if needed.
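For a report upstream, a sketch of the same reproducer with JAX's `jax_log_compiles` option enabled, which should make JAX log a message each time it compiles a function; that shows whether the second cache entry corresponds to a second compilation. Printing `jax.__version__` records which release was tested. The choice of diagnostics here is an assumption, not something the thread used.

```python
import jax
import jax.numpy as jnp

# Ask JAX to log every compilation (emitted through Python's logging module).
jax.config.update("jax_log_compiles", True)

@jax.jit
def f(Y):
    return jnp.sum(Y**2)

Y = jnp.ones(2)
print("jax", jax.__version__)
for i in range(5):
    print("call", i)
    f(Y)
```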

pipme closed this as completed Aug 12, 2022