In [97]:
import jax
from typing import Any
from jax import numpy as jnp
from flax import linen as nn
import optax

In [98]:
x_key, noise_key, params_key, dropout_key = jax.random.split(
    jax.random.PRNGKey(0), 4)
xs = jax.random.normal(x_key, (100, 1))
noise = jax.random.normal(noise_key, (100, 1))
W, b = 2, -1
ys = xs + noise + b

# Single Dropout

In [84]:
class MyModel(nn.Module):
    num_neurons: int
    
    @nn.compact
    def __call__(self, x, training: bool):
        x = nn.Dense(self.num_neurons)(x)
        x = nn.Dropout(rate=0.5, deterministic=not training)(x)
        return x
        

In [85]:
from flax.training import train_state

class TrainState(train_state.TrainState):
  key: jax.random.KeyArray

In [86]:
@jax.jit
def train_step(state: TrainState, xs, ys, dropout_key):
    dropout_train_key = jax.random.fold_in(
        key=dropout_key, data=state.step)
    
    def loss_fn(params):
        yhats = state.apply_fn(
            {'params': params}, xs, training=True, 
            rngs={'dropout': dropout_train_key})
        loss = jnp.mean((ys - yhats) ** 2)
        return loss
    
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss
    

In [87]:
model = MyModel(num_neurons=3)
variables = model.init(params_key, xs, training=False)
params = variables['params']

state = TrainState.create(
    apply_fn=model.apply,
    params=params,
    key=dropout_key,
    tx=optax.adam(1e-3),
)

for i in range(1001):
    state, loss = train_step(state, xs, ys, dropout_key)
    if i % 100 == 0:
        print(f'Iteration {i}: {loss}')
    

Iteration 0: 6.6528730392456055
Iteration 100: 5.362330436706543
Iteration 200: 4.739670753479004
Iteration 300: 4.24291467666626
Iteration 400: 4.084765911102295
Iteration 500: 3.82289719581604
Iteration 600: 3.664264440536499
Iteration 700: 3.4753434658050537
Iteration 800: 2.919456720352173
Iteration 900: 2.8462142944335938
Iteration 1000: 2.6043999195098877


In [88]:
print(state.key)

[ 839183663 3740430601]


# Multiple Dropout

In [123]:
class MyModelMultiple(nn.Module):
    num_neurons: int
    
    @nn.compact
    def __call__(self, x, training: bool, rngs):
        rng1, rng2, rng3 = rngs['dropout']
        x = nn.Dense(self.num_neurons)(x)
        x = nn.Dropout(rate=0.5, deterministic=not training, rng=rng1)(x)
        x = nn.Dropout(rate=0.5, deterministic=not training, rng=rng2)(x)
        x = nn.Dropout(rate=0.5, deterministic=not training, rng=rng3)(x)
        return x
        

In [126]:
@jax.jit
def train_step(state: TrainState, xs, ys, dropout_key):
    dropout_train_key = jax.random.fold_in(
        key=dropout_key, data=state.step)
    rngs = jax.random.split(dropout_train_key, 3)
    
    def loss_fn(params):
        yhats = state.apply_fn(
            {'params': params}, xs, training=True, 
            rngs={'dropout': rngs})
        loss = jnp.mean((ys - yhats) ** 2)
        return loss
    
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss
    

In [129]:
model = MyModelMultiple(num_neurons=3)
rng1, rng2, rng3 = jax.random.split(dropout_key, 3)
variables = model.init(params_key, xs, training=False,
                       rngs={'dropout': (rng1, rng2, rng3)})
params = variables['params']

state = TrainState.create(
    apply_fn=model.apply,
    params=params,
    key=dropout_key,
    tx=optax.adam(1e-3),
)

for i in range(1001):
    state, loss = train_step(state, xs, ys, dropout_key)
    if i % 100 == 0:
        print(f'Iteration {i}: {loss}')


TypeError: Module.init() got multiple values for argument 'rngs'