# JIT et méthodes
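Le petit code suivant montre que les méthodes jittées, même appelées sur des instances différentes, partagent le même cache de compilation (la fonction jittée est créée une seule fois, à la définition de la classe), et que recréer la classe — tout comme appeler `clear_caches()` — réinitialise ces compilations.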

In [6]:
from jax import jit, clear_caches
import jax.numpy as jnp
import equinox as eqx

class Class(eqx.Module):
    """Toy Equinox module used to observe how ``jax.jit`` caches compilations.

    ``att1`` is an ordinary (traced) array leaf. ``att2`` is a static field, so
    it belongs to the pytree structure and therefore to the jit cache key:
    instances that differ only in ``att1``'s values reuse the same compilation,
    while a different ``att2`` triggers a new one. Because ``jit`` is applied
    once in the class body, all instances share a single cache; re-executing
    the class definition creates a fresh jitted function with an empty cache.
    """

    att1: jnp.ndarray  # traced array leaf
    att2: str = eqx.field(static=True)  # 'r' -> meth returns att1, otherwise -att1

    def meth(self):
        """Return ``att1`` when ``att2 == 'r'``, otherwise ``-att1``.

        The plain Python ``if`` is allowed under ``jit`` only because ``att2``
        is static, i.e. a concrete Python value at trace time.
        """
        if self.att2 == 'r':
            return self.att1
        return -self.att1

    @jit
    def jit_meth(self):
        """Jitted wrapper around ``meth``; ``self`` is passed as a pytree argument.

        The former ``**kwargs`` parameter was removed: it was forwarded to
        ``meth``, which accepts no keyword arguments, so it could only raise
        ``TypeError``. The signature now matches the later cells of this notebook.
        """
        return self.meth()

    # Second, independent jitted wrapper of the raw ``meth`` function.
    # NOTE(review): the name ``jit_met`` looks like a typo of ``jit_meth`` and it
    # is not called in the visible cells; kept so the class interface is unchanged.
    jit_met = jit(meth)

def _print_cache_sizes(*objs):
    """Print, on one line, the jit cache size seen through each ``obj.jit_meth``.

    ``_cache_size`` is an underscore-prefixed (private) JAX helper; it is used
    here only to observe how many compilations the jitted method holds.
    """
    print(*(obj.jit_meth._cache_size() for obj in objs))


insr = Class(jnp.ones(2), 'r')
insr2 = Class(2 * jnp.ones(2), 'r')
insd = Class(jnp.ones(2), 'd')
_all = (insr, insr2, insd)

# Show the cache sizes before any call, after each instance's first call,
# and once more after wiping every JAX cache.
_print_cache_sizes(*_all)
for _obj in _all:
    print(_obj.jit_meth())
    _print_cache_sizes(*_all)
clear_caches()
_print_cache_sizes(*_all)


0 0 0
[1. 1.]
1 1 1
[2. 2.]
1 1 1
[-1. -1.]
2 2 2
0 0 0


In [3]:
def _print_cache_sizes(*objs):
    """Print, on one line, the jit cache size seen through each ``obj.jit_meth``.

    ``_cache_size`` is an underscore-prefixed (private) JAX helper; it is used
    here only to observe how many compilations the jitted method holds.
    """
    print(*(obj.jit_meth._cache_size() for obj in objs))


# Re-run the calls on instances created earlier and check whether any of
# them triggers a new compilation.
_trio = (insr, insr2, insd)
_print_cache_sizes(*_trio)
for _obj in _trio:
    print(_obj.jit_meth())
    _print_cache_sizes(*_trio)

2 2 2
[1. 1.]
2 2 2
[2. 2.]
2 2 2
[-1. -1.]
2 2 2


In [4]:
from jax import jit
import jax.numpy as jnp
import equinox as eqx

class Class(eqx.Module):
    """Toy Equinox module used to observe how ``jax.jit`` caches compilations.

    ``att1`` is an ordinary (traced) array leaf. ``att2`` is a static field, so
    it belongs to the pytree structure and therefore to the jit cache key:
    instances that differ only in ``att1``'s values reuse the same compilation,
    while a different ``att2`` triggers a new one. Because ``jit`` is applied
    once in the class body, all instances share a single cache; re-executing
    the class definition creates a fresh jitted function with an empty cache.
    """

    att1: jnp.ndarray  # traced array leaf
    att2: str = eqx.field(static=True)  # 'r' -> meth returns att1, otherwise -att1

    def meth(self):
        """Return ``att1`` when ``att2 == 'r'``, otherwise ``-att1``.

        The plain Python ``if`` is allowed under ``jit`` only because ``att2``
        is static, i.e. a concrete Python value at trace time.
        """
        if self.att2 == 'r':
            return self.att1
        return -self.att1

    @jit
    def jit_meth(self):
        """Jitted wrapper around ``meth``; ``self`` is passed as a pytree argument.

        The former ``**kwargs`` parameter was removed: it was forwarded to
        ``meth``, which accepts no keyword arguments, so it could only raise
        ``TypeError``. The signature now matches the later cells of this notebook.
        """
        return self.meth()

    # Second, independent jitted wrapper of the raw ``meth`` function.
    # NOTE(review): the name ``jit_met`` looks like a typo of ``jit_meth`` and it
    # is not called in the visible cells; kept so the class interface is unchanged.
    jit_met = jit(meth)


def _print_cache_sizes(*objs):
    """Print, on one line, the jit cache size seen through each ``obj.jit_meth``.

    ``_cache_size`` is an underscore-prefixed (private) JAX helper; it is used
    here only to observe how many compilations the jitted method holds.
    """
    print(*(obj.jit_meth._cache_size() for obj in objs))


# Each instance is created just before its first use, to show that a freshly
# built instance already sees the compilations triggered by earlier ones.
insr = Class(jnp.ones(2), 'r')
_print_cache_sizes(insr)
print(insr.jit_meth())

insr2 = Class(2 * jnp.ones(2), 'r')
_print_cache_sizes(insr, insr2)
print(insr2.jit_meth())

insd = Class(jnp.ones(2), 'd')
_print_cache_sizes(insr, insr2, insd)
print(insd.jit_meth())
_print_cache_sizes(insr, insr2, insd)


0
[1. 1.]
1 1
[2. 2.]
1 1 1
[-1. -1.]
2 2 2


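Même une instance qui vient d'être créée profite immédiatement des versions déjà compilées de la méthode. En effet, le cache est lié à la fonction jittée (donc à la classe) et est indexé par la valeur des champs statiques et par la forme/le dtype des tableaux, pas par l'instance ni par les valeurs des tableaux. C'est pratique, merci JAX !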

# Si l'attribut statique est une fonction

In [8]:
from jax import jit, clear_caches
import jax.numpy as jnp
import equinox as eqx
from typing import Callable

class Class(eqx.Module):
    """Toy Equinox module whose static field is a *function*.

    ``att2`` is static, so it is part of the pytree structure used in the jit
    cache key. Static values are compared by equality/hash, and plain Python
    functions compare by identity. Two separately created lambdas with the
    same body are therefore different keys and each triggers its own
    compilation, whereas reusing one function object does not.
    """

    att1: jnp.ndarray  # traced array leaf
    # ``att2`` is applied to ``att1`` (an array), not to a float, hence this annotation.
    att2: Callable[[jnp.ndarray], jnp.ndarray] = eqx.field(static=True)

    def meth(self):
        """Apply the static function ``att2`` to ``att1`` and return the result."""
        return self.att2(self.att1)

    @jit
    def jit_meth(self):
        """Jitted wrapper around ``meth``; the jitted function is shared by all instances."""
        return self.meth()

def _print_cache_sizes(*objs):
    """Print, on one line, the jit cache size seen through each ``obj.jit_meth``.

    ``_cache_size`` is an underscore-prefixed (private) JAX helper; it is used
    here only to observe how many compilations the jitted method holds.
    """
    print(*(obj.jit_meth._cache_size() for obj in objs))


# Every instance receives its own, freshly created lambda, even when the
# lambda bodies are identical.
insr = Class(jnp.ones(2), lambda x: x)
insr2 = Class(2 * jnp.ones(2), lambda x: x)
insd = Class(jnp.ones(2), lambda x: -x)
_all = (insr, insr2, insd)

_print_cache_sizes(*_all)
for _obj in _all:
    print(_obj.jit_meth())
    _print_cache_sizes(*_all)
clear_caches()
_print_cache_sizes(*_all)


0 0 0
[1. 1.]
1 1 1
[2. 2.]
2 2 2
[-1. -1.]
3 3 3
0 0 0


In [9]:
from jax import jit, clear_caches
import jax.numpy as jnp
import equinox as eqx
from typing import Callable

class Class(eqx.Module):
    """Toy Equinox module whose static field is a *function*.

    ``att2`` is static, so it is part of the pytree structure used in the jit
    cache key. Static values are compared by equality/hash, and plain Python
    functions compare by identity. Two separately created lambdas with the
    same body are therefore different keys and each triggers its own
    compilation, whereas reusing one function object does not.
    """

    att1: jnp.ndarray  # traced array leaf
    # ``att2`` is applied to ``att1`` (an array), not to a float, hence this annotation.
    att2: Callable[[jnp.ndarray], jnp.ndarray] = eqx.field(static=True)

    def meth(self):
        """Apply the static function ``att2`` to ``att1`` and return the result."""
        return self.att2(self.att1)

    @jit
    def jit_meth(self):
        """Jitted wrapper around ``meth``; the jitted function is shared by all instances."""
        return self.meth()
fun1 = lambda x: x
fun2 = lambda x: -x


def _print_cache_sizes(*objs):
    """Print, on one line, the jit cache size seen through each ``obj.jit_meth``.

    ``_cache_size`` is an underscore-prefixed (private) JAX helper; it is used
    here only to observe how many compilations the jitted method holds.
    """
    print(*(obj.jit_meth._cache_size() for obj in objs))


# ``insr`` and ``insr2`` share the *same* function object ``fun1``, so they
# share one compilation; only ``insd`` (with ``fun2``) needs another.
insr = Class(jnp.ones(2), fun1)
insr2 = Class(2 * jnp.ones(2), fun1)
insd = Class(jnp.ones(2), fun2)
_all = (insr, insr2, insd)

_print_cache_sizes(*_all)
for _obj in _all:
    print(_obj.jit_meth())
    _print_cache_sizes(*_all)
clear_caches()
_print_cache_sizes(*_all)


0 0 0
[1. 1.]
1 1 1
[2. 2.]
1 1 1
[-1. -1.]
2 2 2
0 0 0


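Apparemment, ce qui compte n'est pas que la fonction « fasse la même chose », ni même le nom de variable, mais qu'il s'agisse du même objet fonction. Les champs statiques sont comparés par égalité/hash, et pour une fonction Python l'égalité est l'identité (`f is g`). Deux lambdas écrites de façon identique sont donc deux objets distincts et déclenchent chacune une compilation, même pour une fonction très simple : dans la cellule 8, on arrive à 3 compilations, alors que réutiliser `fun1` dans la cellule 9 n'en donne que 2. Ce n'est pas le même fonctionnement que pour les entiers ou les chaînes, qui sont comparés par valeur. Ça implique que si on veut éviter de recompiler sur des fonctions, il faut réutiliser les mêmes objets fonction, donc les définir une seule fois, le plus « extérieurement » possible (au niveau du module plutôt que dans une fonction ou une boucle exécutée plusieurs fois).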