In [1]:
import jax
from typing import Any
from jax import numpy as jnp
from flax import linen as nn
import optax

# Setup

In [2]:
x_key, noise_key, params_key, dropout_key = jax.random.split(
    jax.random.PRNGKey(0), 4)
xs = jax.random.normal(x_key, (100, 1))
noise = jax.random.normal(noise_key, (100, 1))
W, b = 2, -1
ys = xs + noise + b



Metal device set to: Apple M2 Max

systemMemory: 96.00 GB
maxCacheSize: 36.00 GB



# Single Dropout

In [3]:
class MyModel(nn.Module):
    num_neurons: int
    
    @nn.compact
    def __call__(self, x, training: bool):
        x = nn.Dense(self.num_neurons)(x)
        x = nn.Dropout(rate=0.5, deterministic=not training)(x)
        return x
        

In [4]:
from flax.training import train_state

class TrainState(train_state.TrainState):
  key: jax.random.KeyArray

In [5]:
@jax.jit
def train_step(state: TrainState, xs, ys, dropout_key):
    dropout_train_key = jax.random.fold_in(
        key=dropout_key, data=state.step)
    
    def loss_fn(params):
        yhats = state.apply_fn(
            {'params': params}, xs, training=True, 
            rngs={'dropout': dropout_train_key})
        loss = jnp.mean((ys - yhats) ** 2)
        return loss
    
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss
    

In [6]:
model = MyModel(num_neurons=3)
variables = model.init(params_key, xs, training=False)
params = variables['params']

state = TrainState.create(
    apply_fn=model.apply,
    params=params,
    key=dropout_key,
    tx=optax.adam(1e-3),
)

for i in range(1001):
    state, loss = train_step(state, xs, ys, dropout_key)
    if i % 100 == 0:
        print(f'Iteration {i}: {loss}')
    

Iteration 0: 6.6528730392456055
Iteration 100: 5.362330436706543
Iteration 200: 4.739670753479004
Iteration 300: 4.24291467666626
Iteration 400: 4.084765911102295
Iteration 500: 3.82289719581604
Iteration 600: 3.664264440536499
Iteration 700: 3.4753434658050537
Iteration 800: 2.919456720352173
Iteration 900: 2.8462142944335938
Iteration 1000: 2.6043999195098877


In [7]:
print(state.key)

[ 839183663 3740430601]


# Multiple Dropout

In [8]:
from typing import Optional
from flax.linen.stochastic import KeyArray
from flax.training import train_state
from jax import random
from jax import lax
from flax.linen.module import merge_param
from typing import Sequence


class TrainState(train_state.TrainState):
  key: jax.random.KeyArray


class MyModelMultiple(nn.Module):
    num_neurons: int
    
    @nn.compact
    def __call__(self, x, training: bool):
        x = nn.Dense(self.num_neurons)(x)
        x = nn.Dropout(rate=0.5, deterministic=not training)(x)
        x = nn.Dropout(rate=0.5, deterministic=not training)(x)
        x = nn.Dropout(rate=0.5, deterministic=not training)(x)
        return x
        

In [9]:
@jax.jit
def train_step(state: TrainState, xs, ys, dropout_key):
    dropout_train_key = jax.random.fold_in(
        key=dropout_key, data=state.step)
    
    def loss_fn(params):
        yhats = state.apply_fn(
            {'params': params}, xs, training=True, 
            # Each of 'Dropout' layers in the model has unique 'scope name' which is
            # folded in when computing the random numbers. Thus, those
            # there layers are okay with sharing the same key.
            # If interested, see https://github.com/google/flax/discussions/3262.
            rngs={'dropout': dropout_train_key})
        loss = jnp.mean((ys - yhats) ** 2)
        return loss
    
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss
    

In [10]:
model = MyModelMultiple(num_neurons=3)
rng1, rng2, rng3 = jax.random.split(dropout_key, 3)
print("* Init")
variables = model.init(params_key, xs, training=False)
params = variables['params']

state = TrainState.create(
    apply_fn=model.apply,
    params=params,
    key=dropout_key,
    tx=optax.adam(1e-3),
)

print("* Training")
for i in range(1001):
    state, loss = train_step(state, xs, ys, dropout_key)
    if i % 100 == 0:
        print(f'Iteration {i}: {loss}')


* Init
* Training
Iteration 0: 31.82681655883789
Iteration 100: 16.837251663208008
Iteration 200: 16.39636993408203
Iteration 300: 13.929004669189453
Iteration 400: 10.109065055847168
Iteration 500: 13.120767593383789
Iteration 600: 9.250357627868652
Iteration 700: 10.172189712524414
Iteration 800: 8.768033981323242
Iteration 900: 10.029900550842285
Iteration 1000: 3.9830234050750732
