In [None]:
import numpy as np
import jax
import jax.numpy as jnp
from jax import random
from flax import linen as nn
import optax

class Encoder(nn.Module):
    """MLP encoder that maps an input to the parameters of a diagonal Gaussian.

    One hidden ReLU layer feeds two parallel linear heads: one produces the
    mean and the other the log-variance of the latent distribution.
    """
    # Size of the latent space; width of both output heads.
    latents: int
    # Width of the hidden layer. Defaults to the previously hard-coded 500,
    # so existing calls such as ``Encoder(latents)`` behave exactly as before.
    hidden_dim: int = 500

    @nn.compact
    def __call__(self, x):
        """Encode ``x``.

        Args:
          x: input array whose last axis is the feature axis.

        Returns:
          A ``(mean, logvar)`` tuple, each with last dimension ``latents``.
        """
        # Layer names are kept unchanged so parameter trees stay compatible.
        hidden = nn.relu(nn.Dense(self.hidden_dim, name='fc1')(x))
        mean_x = nn.Dense(self.latents, name='fc2_mean')(hidden)
        logvar_x = nn.Dense(self.latents, name='fc2_logvar')(hidden)
        return mean_x, logvar_x

In [None]:
class Decoder(nn.Module):
    """MLP decoder that maps a latent vector back to a flat output vector.

    The final layer has no activation, so the returned values are raw
    (unbounded) outputs; ``train_step`` feeds them to a loss that expects
    logits.
    """
    # Width of the hidden layer (previously hard-coded to 500).
    hidden_dim: int = 500
    # Width of the output layer (previously hard-coded to 784).
    output_dim: int = 784

    @nn.compact
    def __call__(self, z):
        """Decode ``z``.

        Args:
          z: latent array whose last axis is the latent dimension.

        Returns:
          Array with last dimension ``output_dim``; no activation applied.
        """
        # Layer names are kept unchanged so parameter trees stay compatible.
        hidden = nn.relu(nn.Dense(self.hidden_dim, name='fc1')(z))
        return nn.Dense(self.output_dim, name='fc2')(hidden)

In [None]:
class VAE(nn.Module):
    """Variational autoencoder assembled from ``Encoder`` and ``Decoder``."""
    # Latent dimensionality forwarded to the encoder.
    latents: int = 20

    def setup(self):
        """Create the encoder and decoder submodules."""
        self.encoder = Encoder(self.latents)
        self.decoder = Decoder()

    def __call__(self, x, z_rng):
        """Encode ``x``, sample a latent with ``z_rng`` and decode it.

        Args:
          x: input batch passed to the encoder.
          z_rng: PRNG key used to sample the latent noise.

        Returns:
          ``(recon_x, mean, logvar)``: the decoder output together with the
          encoder's mean and log-variance.
        """
        posterior_mean, posterior_logvar = self.encoder(x)
        latent = reparameterize(z_rng, posterior_mean, posterior_logvar)
        return self.decoder(latent), posterior_mean, posterior_logvar

def reparameterize(rng, mean, logvar):
    """Sample ``mean + eps * exp(0.5 * logvar)`` with standard-normal ``eps``.

    Args:
      rng: PRNG key used to draw the noise.
      mean: array of means.
      logvar: array of log-variances; its shape is the shape of the noise.

    Returns:
      The sampled array.
    """
    noise = random.normal(rng, logvar.shape)
    scale = jnp.exp(0.5 * logvar)
    return mean + noise * scale

def model():
    """Build a new ``VAE`` whose latent size is the module-level ``LATENTS``."""
    vae = VAE(latents=LATENTS)
    return vae

In [None]:
@jax.vmap
def kl_divergence(mean, logvar):
    """Per-example ``-0.5 * sum(1 + logvar - mean**2 - exp(logvar))``.

    Because of ``jax.vmap`` the function is mapped over the leading axis of
    both arguments and returns one value per leading-axis entry.
    """
    per_dimension = 1 + logvar - jnp.square(mean) - jnp.exp(logvar)
    return -0.5 * jnp.sum(per_dimension)

@jax.vmap
def binary_cross_entropy_with_logits(logits, labels):
    """Per-example binary cross-entropy computed directly from logits.

    Mapped over the leading axis by ``jax.vmap``; returns one summed loss
    value per leading-axis entry.

    Args:
      logits: raw (pre-sigmoid) predictions.
      labels: targets in [0, 1], same shape as ``logits``.

    Returns:
      ``-sum(labels * log(p) + (1 - labels) * log(1 - p))`` with
      ``p = sigmoid(logits)``.
    """
    log_p = jax.nn.log_sigmoid(logits)
    # log(1 - sigmoid(x)) == log_sigmoid(-x). The earlier form
    # log(-expm1(log_sigmoid(x))) evaluates log(0) = -inf once log_sigmoid(x)
    # rounds to 0 for large positive logits, which yields NaN (0 * -inf) for
    # label 1 and inf for label 0.
    log_not_p = jax.nn.log_sigmoid(-logits)
    return -jnp.sum(labels * log_p + (1. - labels) * log_not_p)


@jax.jit
def train_step(params, opt_state, batch, rng):
    """Run one jitted optimisation step.

    Args:
      params: current model parameters.
      opt_state: current state of the module-level ``optimizer``.
      batch: input batch; also used as the reconstruction target.
      rng: PRNG key forwarded to the model for latent sampling.

    Returns:
      ``(new_params, new_opt_state)``.
    """
    def objective(candidate_params):
        # Batch-mean reconstruction loss plus batch-mean KL term.
        reconstruction, mean, logvar = model().apply(
            {'params': candidate_params}, batch, rng)
        reconstruction_term = binary_cross_entropy_with_logits(
            reconstruction, batch).mean()
        kl_term = kl_divergence(mean, logvar).mean()
        return reconstruction_term + kl_term

    gradients = jax.grad(objective)(params)
    param_updates, next_opt_state = optimizer.update(gradients, opt_state)
    return optax.apply_updates(params, param_updates), next_opt_state

In [None]:
# Hyperparameters
BATCH_SIZE = 64
LEARNING_RATE = 0.001
NUM_EPOCHS = 10
LATENTS = 128
# 50000 is the same number used as the shuffle buffer size where train_ds is
# built; presumably the number of training examples -- confirm.
STEPS_PER_EPOCH = 50000 // BATCH_SIZE

# Key creation for random number generation
rng = random.PRNGKey(0)
rng, key = random.split(rng)

# Initialise parameters with a dummy all-ones batch of the expected input shape.
init_data = jnp.ones((BATCH_SIZE, 784), jnp.float32)
init_params = model().init(key, init_data, rng)['params']

optimizer = optax.adam(learning_rate=LEARNING_RATE)
opt_state = optimizer.init(init_params)

# train_ds is created by the data-loading cell, which itself reads BATCH_SIZE
# from this cell. On a fresh kernel: run this cell once (it stops here with
# the message below), run the data-loading cell, then re-run this cell.
try:
    train_ds
except NameError:
    raise NameError(
        'train_ds is not defined: run the data-loading cell, then re-run this cell'
    ) from None

# After the loop, init_params holds the most recently updated parameters.
for epoch in range(NUM_EPOCHS):
    for _ in range(STEPS_PER_EPOCH):
        batch = next(train_ds)
        rng, z_rng = random.split(rng)
        init_params, opt_state = train_step(init_params, opt_state, batch, z_rng)

    # Report one scalar per epoch (loss on the last batch) instead of printing
    # the whole optimizer state, which floods the cell output.
    recon_x, mean, logvar = model().apply({'params': init_params}, batch, z_rng)
    epoch_loss = (binary_cross_entropy_with_logits(recon_x, batch).mean()
                  + kl_divergence(mean, logvar).mean())
    print(f'Epoch {epoch}, loss: {float(epoch_loss):.4f}')


[1;30;43mLe flux de sortie a été tronqué et ne contient que les 5000 dernières lignes.[0m
       -4.82336804e-03,  3.18965688e-03,  2.27237432e-13,  4.76302914e-02,
       -1.37291977e-03, -1.19702267e-02, -1.15174549e-02, -2.90351007e-02,
       -1.71529353e-02,  4.84497175e-02,  5.54545874e-27,  6.53371331e-04,
       -1.40090585e-02,  6.65860594e-16, -8.40675645e-03,  0.00000000e+00,
        0.00000000e+00, -2.47622430e-02,  2.33826973e-03,  2.19874831e-17,
        2.56586764e-02, -1.94111058e-13, -7.24641373e-03, -1.57357585e-02,
       -4.79011331e-03,  5.31039992e-03,  7.99924284e-02, -1.50056155e-02,
       -2.31575649e-02, -4.53640670e-02,  5.95990577e-29,  0.00000000e+00,
        8.48178472e-03, -5.15958341e-03,  8.66103510e-04, -3.63414921e-02,
        6.20503724e-03,  0.00000000e+00, -1.01598119e-02,  0.00000000e+00,
        1.09471772e-02,  8.59531574e-04,  6.20411374e-02, -9.74250361e-05,
       -1.71898175e-02, -1.20023657e-02, -5.14267059e-03,  0.00000000e+00,
       -

In [None]:
import tensorflow_datasets as tfds
import tensorflow as tf

# Give TensorFlow an empty list of visible GPU devices.
# NOTE(review): presumably so the tf.data input pipeline does not claim GPU
# memory that JAX needs -- confirm.
tf.config.experimental.set_visible_devices([], 'GPU')

def prepare_image(x):
    """Cast ``x['image']`` to float32 and flatten it to a 1-D tensor."""
    pixels = tf.cast(x['image'], tf.float32)
    return tf.reshape(pixels, (-1,))

ds_builder = tfds.builder('binarized_mnist')
ds_builder.download_and_prepare()

# Training stream: flatten -> cache -> repeat -> shuffle -> batch,
# then exposed as an iterator of NumPy batches.
train_ds = (
    ds_builder.as_dataset(split=tfds.Split.TRAIN)
    .map(prepare_image)
    .cache()
    .repeat()
    .shuffle(50000)
    .batch(BATCH_SIZE)
)
train_ds = iter(tfds.as_numpy(train_ds))

# Test data: the first batch of up to 10000 flattened images as one NumPy array.
test_ds = (
    ds_builder.as_dataset(split=tfds.Split.TEST)
    .map(prepare_image)
    .batch(10000)
)
test_ds = np.array(list(test_ds)[0])

In [None]:
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
    """Return ``lambda_ * x`` where ``x > 0``, else ``lambda_ * (alpha * exp(x) - alpha)``."""
    negative_branch = alpha * jnp.exp(x) - alpha
    return lambda_ * jnp.where(x > 0, x, negative_branch)


# Compile the function with jit and evaluate it once.
selu_jit = jax.jit(selu)
print(selu_jit(1.0))

1.05


In [None]:
import jax
import jax.numpy as jnp

global_list = []


def log2(x):
    """Return ``log(x) / log(2)`` and, as a side effect, append ``x`` to ``global_list``."""
    # Deliberate side effect: it is executed when the function runs or is traced.
    global_list.append(x)
    return jnp.log(x) / jnp.log(2.0)


# Print the jaxpr obtained by tracing log2 with a scalar input.
print(jax.make_jaxpr(log2)(3.0))

{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0
    d:f32[] = div b c
  in (d,) }


In [None]:
def log2_with_print(x):
    """Print ``x`` and return ``log(x) / log(2)``."""
    print("printed x:", x)
    numerator = jnp.log(x)
    denominator = jnp.log(2.0)
    return numerator / denominator


# Tracing calls the Python function once, so the print above runs during tracing.
print(jax.make_jaxpr(log2_with_print)(3.))

printed x: Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0
    d:f32[] = div b c
  in (d,) }


In [None]:
def log2_if_rank_2(x):
    """Return ``log(x) / log(2)`` when ``x.ndim == 2``; otherwise return ``x`` unchanged."""
    if x.ndim != 2:
        return x
    return jnp.log(x) / jnp.log(2.0)


# Trace with a rank-1 input to show which branch ends up in the jaxpr.
print(jax.make_jaxpr(log2_if_rank_2)(jax.numpy.array([1, 2, 3])))

{ lambda ; a:i32[3]. let  in (a,) }


In [None]:
from jax import grad

def sum_logistic(x):
    """Return the sum over all elements of ``1 / (1 + exp(-x))``."""
    logistic = 1.0 / (1.0 + jnp.exp(-x))
    return jnp.sum(logistic)


# Differentiate with respect to the input and evaluate on a small vector.
x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))

[0.25       0.19661197 0.10499357]


In [None]:
# Split `key` into two keys, one for each random array.
key1, key2 = random.split(key)
mat = random.normal(key1, (150, 100))
batched_x = random.normal(key2, (10, 100))
# Show the shape and a small corner rather than dumping all 1000 values.
print(batched_x.shape)
print(batched_x[:2, :5])

[[-0.1982749   0.12290715  0.36694935  1.1387376   0.8180654   0.2846364
  -1.2432412  -0.21720754  1.4403243  -1.3803186  -0.17311803  0.87109554
  -0.36025354  1.4244151  -0.2374977  -0.2992412  -0.78120977 -0.7913257
  -0.10821776 -0.5700162  -0.6177342  -0.92486453  0.0966308  -0.12466219
  -0.76721346 -1.6429391  -0.5530122   0.27125555 -0.47809386  1.2628251
   0.06739253 -0.36439684  0.6163947   0.6659997  -1.2629865  -0.8262338
   0.4272523  -0.31627107 -0.8964336  -0.36465937 -0.06689852  0.32176843
  -1.2004355  -0.74787426  0.50390005  1.9520171  -1.36864    -0.53189766
  -0.30683482  1.3208697   1.4793857  -0.44423586 -0.54569876  1.559088
  -0.68541384  1.3441124   0.20196167  0.84128606 -0.76329046  1.5112543
  -1.0545305   0.55136067  0.7035998  -0.24217466 -0.88659596 -0.95006734
   0.71373916  0.45581234 -0.03055416 -0.1943196   0.88121986 -0.44041997
  -0.76725554 -0.48482668 -1.1485003   0.8024451  -0.80056286  1.0913795
  -1.5928441   0.06719396 -0.94790715  0.18879