# TODOs
- [X] shard the optimizer state!
- [X] make sure we instantiate the big arrays in sharded form from the init fn!
- [X] make sure we pad flattened params/grads to be divisible by `n_opt_devices`
- [X] add `bfloat16` support because memory matters

In [1]:
# assert 'batch' in MESH.axis_names

import os

# XLA_FLAGS is only honored if it is set before JAX initializes its XLA backend,
# so set it before *any* import that might touch jax (the meta_opt modules below
# presumably import jax as well). This makes the CPU backend expose 8 virtual
# devices that the mesh below can be built from.
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'

from meta_opt.optimizers.base import OptimizerConfig
from meta_opt.optimizers.sgd import SGDConfig
from meta_opt.optimizers.adamw import AdamWConfig
from meta_opt.optimizers.metaopt import MetaOptConfig
from meta_opt.profiling import *

import tqdm
import functools

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, Mesh, PartitionSpec as P
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
from flax import jax_utils, struct
import optax

# Lay the local devices out as a 2-D mesh: a 'batch' axis (data parallelism)
# and an 'opt' axis (sharding of the optimizer state).
NUM_DEVICES = jax.local_device_count()
BATCH_NUM_DEVICES = 8
OPT_STATE_NUM_DEVICES = 1
# Explicit check rather than `assert`: it survives `python -O` and says what is wrong.
if BATCH_NUM_DEVICES * OPT_STATE_NUM_DEVICES != NUM_DEVICES:
    raise ValueError(
        f'mesh shape ({BATCH_NUM_DEVICES} x {OPT_STATE_NUM_DEVICES}) must cover exactly '
        f'the {NUM_DEVICES} local devices; check XLA_FLAGS and the device counts above'
    )

devices = mesh_utils.create_device_mesh((BATCH_NUM_DEVICES, OPT_STATE_NUM_DEVICES))
MESH = Mesh(devices, axis_names=('batch', 'opt'))

Mesh(device_ids=array([[0],
       [1],
       [2],
       [3],
       [4],
       [5],
       [6],
       [7]]), axis_names=('batch', 'opt'))


In [2]:
%%capture
!pip3 install memory-profiler
%load_ext memory_profiler
!pip3 install tensorflow tensorboard-plugin-profile
!pip3 install ml-dtypes==0.2.0

In [3]:
# hyperparams
# cfg = AdamWConfig(learning_rate=0.001, b1=0.9, b2=0.999, eps=1e-8, weight_decay=None, grad_clip=None)
cfg = MetaOptConfig(
    # base update settings
    initial_learning_rate=0.1,
    weight_decay=1e-4,
    grad_clip=None,
    # meta-optimizer structure (see MetaOptConfig for the meaning of H / HH / m_method)
    H=64,
    HH=2,
    m_method='scalar',
    scale_by_adam_betas=None,
    # bfloat16 disabled
    use_bfloat16=False,
    # ablation switches, all off
    fake_the_dynamics=False,
    freeze_gpc_params=False,
    freeze_cost_fn_during_rollouts=False,
    # plain SGD for the meta-optimizer: no momentum, no weight decay, no clipping
    meta_optimizer_cfg=SGDConfig(
        learning_rate=1e-5,
        momentum=0,
        nesterov=False,
        weight_decay=None,
        grad_clip=None,
    ),
)

# Memory profiling
Profile what happens when we shard the optimizer state. This section makes no use of data parallelism; that is covered in later sections.

In [4]:
# import os
# if os.path.isdir('/tmp/trace'):
#     import shutil
#     shutil.rmtree("/tmp/trace")

# # with jax.profiler.trace('/tmp/trace'):
    
# opt = cfg.make_jax()
# opt = (jax.tree_util.Partial(opt.init), jax.tree_util.Partial(opt.update))
# params = make_params(0)  # make some pretend parameters
# opt_state = make_opt_state(params, opt)
# print('made', opt_state[0].disturbance_history.dtype)

# opt_state = shard_opt_state(opt_state)  # uncomment this line to shard the optimizer state
# print('sharded', opt_state[0].disturbance_history.dtype)
# jax.debug.visualize_array_sharding(opt_state[0].disturbance_history)

# # do a few steps of GD
# for seed in tqdm.trange(90):
#     batch = make_batch(seed)
#     params, opt_state = train_step(batch, params, opt_state, opt)
# print('trained', opt_state[0].disturbance_history.dtype)
# print(params[:10])
# jax.debug.visualize_array_sharding(opt_state[0].disturbance_history)

# %timeit train_step(batch, params, opt_state, opt)[0].block_until_ready()

# # %load_ext tensorboard
# # %tensorboard --logdir=/tmp/trace

# Profile a `pmap`-ped train step

In [5]:
# import os
# if os.path.isdir('/tmp/trace'):
#     import shutil
#     shutil.rmtree("/tmp/trace")

# @functools.partial(jax.pmap, axis_name='batch', in_axes=(0, 0, 0, None), out_axes=(0, 0), static_broadcasted_argnums=(3,)) 
# def _pmapped_train_step(batch, params, opt_state, opt):
#     loss_fn = LossFn(batch)
#     grads = jax.grad(loss_fn)(params)
#     grads = jax.lax.pmean(grads, axis_name='batch')

#     # dh = jax.lax.with_sharding_constraint(opt_state[0].disturbance_history, NamedSharding(MESH, P(None, 'opt')))
#     # ph = jax.lax.with_sharding_constraint(opt_state[0].param_history, NamedSharding(MESH, P(None, 'opt')))
#     # opt_state = (opt_state[0].replace(disturbance_history=dh, param_history=ph), opt_state[1])
    
#     updates, opt_state = opt[1](grads, opt_state, params, cost_fn=loss_fn)

#     # dh = jax.lax.with_sharding_constraint(opt_state[0].disturbance_history, NamedSharding(MESH, P(None, 'opt')))
#     # ph = jax.lax.with_sharding_constraint(opt_state[0].param_history, NamedSharding(MESH, P(None, 'opt')))
#     # opt_state = (opt_state[0].replace(disturbance_history=dh, param_history=ph), opt_state[1])
    
#     params = optax.apply_updates(params, updates)
#     return params, opt_state

# def pmapped_train_step(batch, replicated_params, replicated_opt_state, opt):
#     batch = batch.reshape(BATCH_NUM_DEVICES, -1, NUM_DATA)
#     sharded_batch = jax.device_put(batch, NamedSharding(MESH, P('batch', None, None)))
    
#     # print(replicated_opt_state[0].disturbance_history.sharding, replicated_opt_state[0].disturbance_history.shape)
#     replicated_params, replicated_opt_state = _pmapped_train_step(sharded_batch, replicated_params, replicated_opt_state, opt)
#     # print(replicated_opt_state[0].disturbance_history.sharding, replicated_opt_state[0].disturbance_history.shape)
    
#     return replicated_params, replicated_opt_state

# # with jax.profiler.trace('/tmp/trace'):

# opt = cfg.make_jax()
# opt = (jax.tree_util.Partial(opt.init), jax.tree_util.Partial(opt.update))
# params = make_params(0)  # make some pretend parameters
# opt_state = make_opt_state(params, opt)

# # replicate
# replicate_fn = lambda v: jnp.tile(v[None], (BATCH_NUM_DEVICES,) + ((1,) * v.ndim))
# unreplicate_fn = lambda v: v[0]

# replicated_params = replicate_fn(params)
# replicated_params = jax.device_put(replicated_params, NamedSharding(MESH, P('batch')))

# # replicated_opt_state = opt_state
# replicated_opt_state = jax.tree_map(replicate_fn, opt_state)

# replicated_opt_state = shard_opt_state(replicated_opt_state)  # uncomment this line to shard the optimizer state
# jax.debug.visualize_array_sharding(unreplicate_fn(replicated_opt_state[0].disturbance_history))

# # do a few steps of pmapped GD
# for seed in tqdm.trange(90):
#     batch = make_batch(seed)
#     replicated_params, replicated_opt_state = pmapped_train_step(batch, replicated_params, replicated_opt_state, opt)

# jax.debug.visualize_array_sharding(unreplicate_fn(replicated_opt_state[0].disturbance_history))
# %timeit pmapped_train_step(batch, replicated_params, replicated_opt_state, opt)[0].block_until_ready()

# # %load_ext tensorboard
# # %tensorboard --logdir=/tmp/trace

# Profile a `shard_map`-ped train step

In [6]:
# import os
# if os.path.isdir('/tmp/trace'):
#     import shutil
#     shutil.rmtree("/tmp/trace")

# from jax.experimental.shard_map import shard_map

# replicate_fn = lambda v: jnp.tile(v, (BATCH_NUM_DEVICES,) + ((1,) * (v.ndim - 1))) if v.ndim > 0 else v
# unreplicate_fn = lambda v: v[:(v.shape[0] // BATCH_NUM_DEVICES)] if v.ndim > 0 else v

# def make_shardings(opt, spec):
#     shardings = {
#         # to replicate
#         'gpc_params': P(),
#         # 'gpc_tx': P(),
#         'gpc_opt_state': P(),
#         # 'H': P(),
#         # 'HH': P(),
#         't': P(),
#         # 'num_params': P(),
#         # 'base_lr': P(),
#         # 'disturbance_transform': P(),
#         'disturbance_transform_state': P(),
#         'recent_gpc_grads': P(),
#         'recent_gpc_cost': P(),
#         'cost_fn_history': P(),
    
#         # to shard along num_params axis
#         'disturbance_history': spec,
#         'param_history': spec,
#     }
#     os = make_opt_state(make_params(0), opt)
#     shardings = (os[0].replace(**shardings), optax.EmptyState())
#     del os
#     return shardings

# opt = cfg.make_jax()
# opt = (jax.tree_util.Partial(opt.init), jax.tree_util.Partial(opt.update))
# shardings = make_shardings(opt, P('batch', None))
# s2 = jax.tree_map(lambda v: NamedSharding(MESH, v), make_shardings(opt, P('batch', 'opt')))

# @functools.partial(shard_map, mesh=MESH, in_specs=(P('batch'), P('batch'), shardings, P()), out_specs=(P('batch'), shardings), check_rep=False)
# def _shmapped_train_step(batch, params, opt_state, opt):
#     loss_fn = LossFn(batch)
#     grads = jax.grad(loss_fn)(params)
#     grads = jax.lax.pmean(grads, axis_name='batch')
#     updates, opt_state = opt[1](grads, opt_state, params, cost_fn=loss_fn)
#     params = optax.apply_updates(params, updates)
#     return params, opt_state

# @jax.jit
# def shmapped_train_step(batch, params, opt_state, opt):
#     # assert params.ndim == 2 and batch.ndim == 2 and opt_state[0].disturbance_history.ndim == 3, (params.shape, batch.shape, opt_state[0].disturbance_history.shape)
#     # print(opt_state[0].disturbance_history.shape, opt_state[0].disturbance_history.sharding)
#     batch = jax.device_put(batch, NamedSharding(MESH, P('batch', None)))
#     opt_state = jax.lax.with_sharding_constraint(opt_state, s2)
#     params, opt_state = _shmapped_train_step(batch, params, opt_state, opt)
#     opt_state = jax.lax.with_sharding_constraint(opt_state, s2)
#     # print(opt_state[0].disturbance_history.shape, opt_state[0].disturbance_history.sharding)
#     return params, opt_state

# with jax.profiler.trace('/tmp/trace'):

#     # opt = cfg.make_jax()
#     # opt = (jax.tree_util.Partial(opt.init), jax.tree_util.Partial(opt.update))
#     params = make_params(0)  # make some pretend parameters
#     opt_state = make_opt_state(params, opt)
    
#     # replicate
#     replicated_params = jax.device_put(replicate_fn(params), NamedSharding(MESH, P('batch')))
#     replicated_opt_state = (opt_state[0].replace(disturbance_history=replicate_fn(opt_state[0].disturbance_history), 
#                                                  param_history=replicate_fn(opt_state[0].param_history)), opt_state[1])
#     replicated_opt_state = shard_opt_state(replicated_opt_state)  # uncomment this line to shard the optimizer state
#     print(replicated_opt_state[0].disturbance_history.shape, replicated_opt_state[0].disturbance_history.sharding)
    
#     # do a few steps of pmapped GD
#     for seed in tqdm.trange(74):
#         batch = make_batch(seed)
#         replicated_params, replicated_opt_state = shmapped_train_step(batch, replicated_params, replicated_opt_state, opt)
    
# %load_ext tensorboard
# %tensorboard --logdir=/tmp/trace

# Profile a straight-up sharded (jit-compiled) train step

In [7]:
# import os
# if os.path.isdir('/tmp/trace'):
#     import shutil
#     shutil.rmtree("/tmp/trace")

# @jax.jit
# def sharded_train_step(batch, params, opt_state, opt):
#     loss_fn = LossFn(batch)
#     grads = jax.grad(loss_fn)(params)
#     updates, opt_state = opt[1](grads, opt_state, params, cost_fn=loss_fn)
#     params = optax.apply_updates(params, updates)
#     return params, opt_state

# with jax.profiler.trace('/tmp/trace'):

#     opt = cfg.make_jax()
#     opt = (jax.tree_util.Partial(opt.init), jax.tree_util.Partial(opt.update))
#     params = make_params(0)  # make some pretend parameters
#     opt_state = make_opt_state(params, opt)
    
#     opt_state = shard_opt_state(opt_state)  # uncomment this line to shard the optimizer state
    
#     # do a few steps of pmapped GD
#     for seed in tqdm.trange(90):
#         batch = make_batch(seed)
#         batch = jax.device_put(batch, NamedSharding(MESH, P('batch', None)))
#         params, opt_state = sharded_train_step(batch, params, opt_state, opt)
#     %timeit sharded_train_step(batch, params, opt_state, opt)[0].block_until_ready()

# %load_ext tensorboard
# %tensorboard --logdir=/tmp/trace

# Profile a full run

In [10]:
import os
if os.path.isdir('/tmp/trace'):
    import shutil
    shutil.rmtree("/tmp/trace")

import sys
from absl import logging, flags
from meta_opt.algoperf.runner import run
from meta_opt.experiment import ExperimentConfig
from meta_opt.utils import make_mesh, GLOBAL_MESH

logging.set_verbosity(logging.INFO)
FLAGS = flags.FLAGS
FLAGS(['runner.py', '--config_path=hi'])

experiment_cfg = ExperimentConfig(
    experiment_name='test',
    
    # workload details
    seed=0,
    workload_name='mnist', 
    full_batch=False,  # whether to do full gradient descent on one batch (fixed during the whole training) or regular minibatch SGD
    num_episodes=1,
    num_iters=200,  # if None, uses default for the workload

    framework='jax',
    num_batch_devices=1,
    num_opt_devices=8,

    # how often to do things
    eval_every=-1,
    checkpoint_every=-1,
    log_every=5,

    # other details
    use_wandb=False,
    print_with_colors=True)

with jax.profiler.trace('/tmp/trace'):
    make_mesh(experiment_cfg.num_batch_devices, experiment_cfg.num_opt_devices)
    run(experiment_cfg, cfg)

%load_ext tensorboard
%tensorboard --logdir=/tmp/trace

INFO:absl: [93m[1m8 devices in a mesh Mesh(device_ids=array([[0, 1, 2, 3, 4, 5, 6, 7]]), axis_names=('batch', 'opt')) of shape OrderedDict([('batch', 1), ('opt', 8)])[0m
INFO:absl:Creating directory at /Users/evandogariu/Desktop/meta-opt/experiments/test/mnist_jax for experiments to be saved to.
INFO:absl:[96m[1mno `batch_size` provided. using default of 512 for the workload mnist![0m
INFO:absl: [93m[1m8 devices in a mesh Mesh(device_ids=array([[0, 1, 2, 3, 4, 5, 6, 7]]), axis_names=('batch', 'opt')) of shape OrderedDict([('batch', 1), ('opt', 8)])[0m
INFO:absl:Using [93m[1mcpu[0m for jax
INFO:absl:[1mEXPERIMENT CONFIG[0m
INFO:absl:{ 'experiment_name': 'test',
  'seed': 0,
  'workload_name': 'mnist',
  'full_batch': False,
  'num_episodes': 1,
  'num_iters': 200,
  'batch_size': 512,
  'framework': 'jax',
  'num_batch_devices': 1,
  'num_opt_devices': 8,
  'eval_every': -1,
  'checkpoint_every': -1,
  'log_every': 5,
  'print_with_colors': True,
  'use_wandb': False}
INFO

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 82781), started 0:00:59 ago. (Use '!kill 82781' to kill it.)

# random tests

In [9]:
# import chex
# import numpy as np

# def compute(gpc_params: chex.Array, 
#                         disturbance_history: chex.Array) -> chex.Array:
#     ret = jnp.dot(disturbance_history.T, gpc_params)
#     return ret

# @jax.custom_jvp
# def compute2(gpc_params: chex.Array, 
#                 disturbance_history: chex.Array) -> chex.Array:
#     ret = disturbance_history.T @ gpc_params
#     return ret

# @compute2.defjvp
# def compute2_jvp(primals, tangents):
#     g, d = primals
#     g_dot, d_dot = tangents
#     primal_out = compute2(g, d)
#     tangent_out = d.T @ g_dot
#     return primal_out, tangent_out


# # @custom_jvp
# # def f(x, y):
# #   return x ** 2 * y

# # @f.defjvp
# # def f_jvp(primals, tangents):
# #   x, y = primals
# #   x_dot, y_dot = tangents
# #   primal_out = f(x, y)
# #   tangent_out = 2 * x * y * x_dot + x ** 2 * y_dot
# #   return primal_out, tangent_out

# def loss(p, d):
#     return jnp.linalg.norm(compute(p, d) + jax.random.normal(jax.random.PRNGKey(0), d.shape[1:]))
# def loss2(p, d):
#     return jnp.linalg.norm(compute2(p, d) + jax.random.normal(jax.random.PRNGKey(0), d.shape[1:]))

# gpc_params = jnp.array(np.random.randn(32,))
# disturbance_history = jnp.array(np.random.randn(32, 2048))

# print('losses')
# print(loss(gpc_params, disturbance_history))
# print(loss2(gpc_params, disturbance_history))

# print('grads')
# print(jax.grad(loss)(gpc_params, disturbance_history))
# print(jax.grad(loss2)(gpc_params, disturbance_history))