In [None]:
# !pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html


In [None]:
import jax, flax, numpy
from notebookinit import *
import jax.numpy as jnp
from sklearn.model_selection import train_test_split

2023-12-30 08:06:59.606313: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-30 08:06:59.606339: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-30 08:06:59.607148: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [None]:
import wandb

# start a new wandb run to track this script


In [None]:
q = jnp.array([0.0], dtype=jnp.float32)

In [None]:
learning_rate = 5e-3

In [None]:
X_data = jnp.linspace(-10,10,41).reshape([-1,1])
unknown_process_a = 2.0
unknown_process_b = 3.0
y_data = X_data*unknown_process_a + unknown_process_b
X_train, X_test, y_train, y_test = train_test_split(X_data, y_data, test_size=0.2, random_state=42)

In [None]:
# Utilities blatantly missing from flax
def count_params(params):
    return sum(jax.tree_util.tree_leaves(jax.tree_util.tree_map(lambda x: x.size, params)))

# create the loss function
def make_loss(_model):
    @jax.jit
    def loss_fn(params, x_batched, y_batched):
      # Define the squared loss for a single pair (x,y)
      def squared_error(x, y):
        pred = _model.apply(params, x)
        return jnp.inner(y-pred, y-pred) / 2.0
      # Vectorize the previous to compute the average of the loss on all samples.
      return jnp.squeeze(jnp.mean(jax.vmap(squared_error)(x_batched,y_batched), axis=0))
    return loss_fn

In [None]:
from flax import linen as nn
class Linear(nn.Module):
    n_features: int = 1    

    def setup(self):
        local_initializer1 = lambda: jnp.array(0.0)
        local_initializer2 = nn.initializers.lecun_normal()
        self.vars_a = self.variable('vars','a',nn.initializers.lecun_normal(), jax.random.PRNGKey(0), (1,1))
        self.vars_b = self.variable('vars','b',nn.initializers.lecun_normal(), jax.random.PRNGKey(0), (1,1))
        pass
        
    @nn.compact
    
    def __call__(self, x):
        # y = nn.Dense(features=self.n_features)(x)
        # y = jnp.sum(y)
        y = self.vars_a.value * x + self.vars_b.value
        return y

model = Linear(n_features=1)
loss_fn = make_loss(model)
X1 = jnp.array([[0],[1],[2], [4]], dtype = jnp.float32)
X2 = jnp.array([[4], [6]], dtype = jnp.float32)
params = model.init(jax.random.PRNGKey(0),X1)    
print(f"Total number of parameters: {count_params(params)}")
y_pred1 = model.apply(params, X1)    
y_pred2 = model.apply(params, X2)

Total number of parameters: 2


In [None]:
loss_fn(params,X_train,y_train)

Array(87.93137, dtype=float32)

In [None]:
y_pred2

Array([[-1.116112 ],
       [-1.5625569]], dtype=float32)

In [None]:
modelInstance = model.bind(params)

In [None]:
modelInstance.variables

{'vars': {'a': Array([[-0.2232224]], dtype=float32),
  'b': Array([[-0.2232224]], dtype=float32)}}

In [None]:
tabulate_fn = nn.tabulate(Linear(), jax.random.PRNGKey(0), console_kwargs={'width':120, 'force_jupyter':True})
print(tabulate_fn(X1))






In [None]:
modelInstance.apply(params,X2)

Array([[-1.116112 ],
       [-1.5625569]], dtype=float32)

In [None]:
tabulate_fn(X1)

'\n\n'

In [None]:
nn.initializers.lecun_normal()(jax.random.PRNGKey(0),shape=[1,10])

Array([[-0.40293682,  0.2864223 , -0.1979662 , -0.7924857 , -0.47636405,
        -0.16503796, -0.7232473 , -0.6376393 ,  0.78707284,  0.61249083]],      dtype=float32)

In [None]:
loss_fn = jax.jit(make_loss(model))
train_loss = loss_fn(params, X_train, y_train)
validation_loss = loss_fn(params, X_test, y_test)
print(f'init state:')
print(f'{train_loss=:0.2f}, {validation_loss=:0.2f}')

init state:
train_loss=87.93, validation_loss=105.06


In [None]:
learning_rate

0.005

In [None]:
%timeit loss_fn(params, X_test, y_test)

25 µs ± 1.39 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [None]:
loss_grad_fn = jax.jit(jax.value_and_grad(loss_fn))
train_loss_value, gradients = loss_grad_fn(params, X_train, y_train)
test_loss_value = loss_fn(params, X_test, y_test)
gradients

{'vars': {'a': Array([[-75.487236]], dtype=float32),
  'b': Array([[-2.4937277]], dtype=float32)}}

In [None]:
%timeit loss_grad_fn(params, X_train, y_train)

35.4 µs ± 3.39 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [None]:
test_loss_value

Array(105.06377, dtype=float32)

In [None]:
float(train_loss_value)

87.9313735961914

In [None]:
# !pip install --upgrade chex


In [None]:
%timeit loss_fn(params, X_test, y_test)

23.2 µs ± 968 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [None]:
%timeit loss_grad_fn(params, X_train, y_train)

33.7 µs ± 1.85 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


<function optax._src.alias.adam(learning_rate: Union[float, jax.Array, Callable[[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, float, int]], Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, float, int]]], b1: float = 0.9, b2: float = 0.999, eps: float = 1e-08, eps_root: float = 0.0, mu_dtype: Optional[Any] = None) -> optax._src.base.GradientTransformation>

In [None]:
import optax
optimizers =  \
[
    'adabelief',
    'adafactor',
    'adagrad',
    'adam',
    'adamw',
    'lion',
    'amsgrad',
    'noisy_sgd',
    'novograd',
    'rmsprop',
    'sm3',
    'adamax',
    'adamaxw',
    
]

optimizer_idx = 1
optimizer_name = optimizers[optimizer_idx]
exec(f'this_optimizer=optax.{optimizers[optimizer_idx]}')


In [None]:
wandb.init(
    # set the wandb project where this run will be logged
    project="01_flax_review",    
    # track hyperparameters and run metadata
    config={'this_optimizer':optimizer_name}
    )

model = Linear(n_features=1)
loss_fn = make_loss(model)
loss_grad_fn = jax.jit(jax.value_and_grad(loss_fn))
params = model.init(jax.random.PRNGKey(0),X_train) 

optimizer = this_optimizer(learning_rate=learning_rate)
opt_state = optimizer.init(params)
print(f'starting with {learning_rate=:0.3e}')
params_trajectory = []
params_trajectory.append(params)

target_achieved = False
step_count = 0
for superepoch_idx in range(20):
    for epoch_idx in range(500):
        train_loss_value, gradients = loss_grad_fn(params, X_train, y_train)
        test_loss_value = loss_fn(params, X_test, y_test)
        updates, opt_state = optimizer.update(gradients, opt_state, params=params)
        step_count+=1
        updated_params = optax.apply_updates(params, updates)
        # params_trajectory.append(updated_params)
        params = updated_params
        wandb.log({"step_count":step_count, "train_loss_value": train_loss_value, "test_loss_value": test_loss_value})
        if jnp.log10(test_loss_value)<-3:
            target_achieved = True
            break    
    if target_achieved:
        break        
    print(f"{step_count}:, loss: train:{jnp.log10(train_loss_value):0.1f}, test:{jnp.log10(test_loss_value):0.1f}; a={params['vars']['a'][0][0]:0.2f}; b={params['vars']['b'][0][0]:0.2f}")
if target_achieved:
    print(f'target achieved in {step_count}  steps')
    print(f"{step_count}:, loss: train:{jnp.log10(train_loss_value):0.1f}, test:{jnp.log10(test_loss_value):0.1f}; a={params['vars']['a'][0][0]:0.2f}; b={params['vars']['b'][0][0]:0.2f}")

wandb.finish()

starting with learning_rate=5.000e-03
500:, loss: train:1.9, test:1.9; a=-0.02; b=-0.02
1000:, loss: train:1.9, test:1.9; a=-0.00; b=-0.00
1500:, loss: train:1.9, test:1.9; a=0.00; b=0.00
2000:, loss: train:1.8, test:1.9; a=0.01; b=0.01
2500:, loss: train:1.8, test:1.9; a=0.15; b=0.15
3000:, loss: train:1.2, test:1.3; a=1.08; b=1.33
3500:, loss: train:-1.4, test:-1.3; a=1.95; b=2.90
target achieved in 3720  steps
3720:, loss: train:-3.1, test:-3.0; a=1.99; b=2.99


VBox(children=(Label(value='0.004 MB of 0.004 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
step_count,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
test_loss_value,██▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▆▆▅▅▄▃▂▁▁▁▁▁▁▁
train_loss_value,██▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▆▆▅▅▄▃▂▁▁▁▁▁▁▁

0,1
step_count,3720.0
test_loss_value,0.00099
train_loss_value,0.0008
