In [33]:
import os

# Set this to True to run the model on CPU only.
USE_CPU_ONLY = True

flags = os.environ.get("XLA_FLAGS", "")
if USE_CPU_ONLY:
    flags += " --xla_force_host_platform_device_count=8"  # Simulate 8 devices
    # Enforce CPU-only execution
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
else:
    # GPU flags
    flags += (
        "--xla_gpu_enable_triton_softmax_fusion=true "
        "--xla_gpu_triton_gemm_any=false "
        "--xla_gpu_enable_async_collectives=true "
        "--xla_gpu_enable_latency_hiding_scheduler=true "
        "--xla_gpu_enable_highest_priority_async_stream=true "
    )
os.environ["XLA_FLAGS"] = flags

import functools
from pprint import pprint
from typing import Any, Callable, Dict, Sequence, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
from absl import logging
from jax import lax
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from ml_collections import ConfigDict
from flax.training import train_state
from single_gpu import TrainState
import time


PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]

In [2]:
jax.devices()

[CpuDevice(id=0),
 CpuDevice(id=1),
 CpuDevice(id=2),
 CpuDevice(id=3),
 CpuDevice(id=4),
 CpuDevice(id=5),
 CpuDevice(id=6),
 CpuDevice(id=7)]

In [3]:
class DPClassifier(nn.Module):
    config: ConfigDict

    @nn.compact
    def __call__(self, x: jax.Array, train: bool) -> jax.Array:
        x = nn.Dense(
            features=self.config.hidden_size,
            dtype=self.config.dtype,
            name="input_dense",
        )(x)
        x = nn.silu(x)
        x = nn.Dropout(rate=self.config.dropout_rate, deterministic=not train)(x)
        x = nn.Dense(
            features=self.config.num_classes,
            dtype=self.config.dtype,
            name="output_dense",
        )(x)
        x = x.astype(jnp.float32)
        return x

In [4]:
data_config = ConfigDict(
    dict(
        batch_size=16,
        num_classes=8,
        input_size=32,
    )
)
model_config = ConfigDict(
    dict(
        hidden_size=8,
        dropout_rate=0.1,
        dtype=jnp.bfloat16,
        num_classes=data_config.num_classes,
        data_axis_name="data",
    )
)
optimizer_config = ConfigDict(
    dict(
        learning_rate=1e-3,
        num_minibatches=4,
    )
)
config = ConfigDict(
    dict(
        model=model_config,
        optimizer=optimizer_config,
        data=data_config,
        data_axis_name=model_config.data_axis_name,
        seed=42,
    )
)

In [20]:
class KeyState:
    def __init__(self, base_key: jax.random.key):
        self.key = jax.random.key(base_key)

    def __call__(self, num: int = 1):
        self.key, *rng = jax.random.split(self.key, num=num + 1)
        if len(rng) == 1:
            return rng[0]
        else:
            return jnp.array(rng)

In [None]:
model = DPClassifier(config=config.model)
optimizer = optax.adamw(
    learning_rate=config.optimizer.learning_rate,
)

key = KeyState(config.seed)
x = jax.random.normal(key(), (config.data.batch_size, config.data.input_size))
y = jax.random.randint(key(), (config.data.batch_size,), 0, config.data.num_classes)
variables = model.init({"params": key()}, x, train=False)
params = variables.pop("params")
device_array = np.array(jax.devices())
mesh = Mesh(device_array, ("x",))
print(
    jax.tree.reduce(lambda acc, current: acc + current.size, jax.tree.leaves(params), 0)
)
print(jax.devices())

336
[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]


In [37]:
def init_device(params, local_model, config):
    tx = optax.chain(
        optax.clip_by_global_norm(1),
        optax.inject_hyperparams(optax.adam)(learning_rate=1e-3),
    )
    state = train_state.TrainState.create(
        apply_fn=local_model.apply, params=params, tx=tx
    )
    return params


sharded_init = shard_map(
    functools.partial(init_device, local_model=model, config=model_config),
    mesh,
    in_specs=(P()),
    out_specs=(
        P(
            "x",
        )
    ),
)

state_initialized = sharded_init(params)
state_initialized

{'input_dense': {'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
  'kernel': Array([[-0.2624882 ,  0.27095973,  0.27826482, ...,  0.01405951,
          -0.3247369 ,  0.06595676],
         [-0.25214282, -0.10684335,  0.08712902, ...,  0.03858314,
          -0.09044334,  0.08779923],
         [-0.28910136, -0.28363675, -0.05046552, ...,  0.190749  ,
           0.15739593,  0.10549708],
         ...,
         [-0.18320438,  0.340858  , -0.34782368, ...,  0.05340033,
          -0.17750989,  0.18437599],
         [ 0.15253912,  0.10370893,  0.11011248, ...,  0.01051933,
          -0.3585401 , -0.30651963],
         [ 0.08335917, -0.1490531 , -0.23801802, ...,  0.29867032,
           0.11949908, -0.23115313]], dtype=float32)},
 '

In [39]:
print("DP Parameters")
pprint(jax.tree.map(lambda x: x.shape, jax.device_get(state_initialized)))

DP Parameters
{'input_dense': {'bias': (64,), 'kernel': (256, 8)},
 'output_dense': {'bias': (64,), 'kernel': (64, 8)}}
