In [177]:
import os
import urllib.request
from urllib.error import HTTPError

# Github URL where python scripts are stored.
base_url = "https://raw.githubusercontent.com/phlippe/uvadlc_notebooks/master/docs/tutorial_notebooks/scaling/JAX/"
# Files to download.
python_files = ["single_gpu.py", "utils.py"]
# For each file, check whether it already exists. If not, try downloading it.
for file_name in python_files:
    if not os.path.isfile(file_name):
        file_url = base_url + file_name
        print(f"Downloading {file_url}...")
        try:
            urllib.request.urlretrieve(file_url, file_name)
        except HTTPError as e:
            print(
                "Something went wrong. Please try to download the file directly from the GitHub repository, or contact the author with the full output including the following error:\n",
                e,
            )

In [178]:
from utils import simulate_CPU_devices

simulate_CPU_devices()

In [228]:
import functools
from pprint import pprint
from typing import Any, Callable, Dict, Sequence, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
from absl import logging
from jax import lax
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from ml_collections import ConfigDict
from flax.training import train_state
from single_gpu import TrainState
import time

PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]

In [354]:
class DPClassifier(nn.Module):
    config: ConfigDict

    @nn.compact
    def __call__(self, x: jax.Array, train: bool) -> jax.Array:
        x = nn.Dense(
            features=self.config.hidden_size,
            dtype=self.config.dtype,
            name="input_dense",
        )(x)
        x = nn.silu(x)
        x = nn.Dropout(rate=self.config.dropout_rate, deterministic=not train)(x)
        x = nn.Dense(
            features=self.config.num_classes,
            dtype=self.config.dtype,
            name="output_dense",
        )(x)
        x = x.astype(jnp.float32)
        return x

In [355]:
data_config = ConfigDict(
    dict(
        batch_size=16,
        num_classes=8,
        input_size=32,
    )
)
model_config = ConfigDict(
    dict(
        hidden_size=8,
        dropout_rate=0.1,
        dtype=jnp.bfloat16,
        num_classes=data_config.num_classes,
        data_axis_name="data",
    )
)
optimizer_config = ConfigDict(
    dict(
        learning_rate=1e-3,
        num_minibatches=4,
    )
)
config = ConfigDict(
    dict(
        model=model_config,
        optimizer=optimizer_config,
        data=data_config,
        data_axis_name=model_config.data_axis_name,
        seed=42,
    )
)

In [356]:
class KeyState:
    def __init__(self, base_key: jax.random.key):
        self.key = jax.random.key(base_key)

    def __call__(self, num: int = 2):
        self.key, rng = jax.random.split(self.key, num=num)
        return rng

In [357]:
model = DPClassifier(config=config.model)
optimizer = optax.adamw(
    learning_rate=config.optimizer.learning_rate,
)
class TrainStateWithRNG(train_state.TrainState):
        rng: Any

In [359]:
key = KeyState(config.seed)
x=jax.random.normal(key(), (config.data.batch_size, config.data.input_size))
y = jax.random.randint(key(), (config.data.batch_size,), 0, config.data.num_classes)
variables = model.init({"params": key()}, x, train=False)
params = variables.pop("params")
device_array = np.array(jax.devices())
mesh = Mesh(device_array, ("x",))
print(jax.tree.reduce(lambda acc, current: acc + current.size, jax.tree.leaves(params), 0))


336


In [360]:
def init_device(params, rng, local_model, config):
        lr_scheduler = optax.warmup_cosine_decay_schedule(
        init_value=0.2,
        peak_value=0.5,
        warmup_steps=15,
        decay_steps=300,
        end_value=0.9,
        )   
        tx = optax.chain(
            optax.clip_by_global_norm(1),
            optax.inject_hyperparams(optax.adam)(learning_rate=lr_scheduler),
        )
        state = TrainStateWithRNG.create(
            apply_fn=local_model.apply,
            params=params,
            tx=tx,
            rng=rng,
        )
        return state

In [361]:
sharded_init = shard_map(
            functools.partial(init_device, rng=key(), local_model=model, config=model_config),
            mesh,
            in_specs=(P()),
            out_specs=(P()),
        )

state_initialized = sharded_init(params)

In [363]:
def fold_key(key, axis):
        axis_index = jax.lax.axis_index(axis)
        return jax.random.fold_in(key, axis_index)

In [364]:
def cross_entropy_loss(model, params, key, x, y, train=True):
        dropout_key = fold_key(key, "x")
        B, T = x.shape
        pred = model.apply({'params': params}, x, train=train, rngs={'dropout': dropout_key})
        log_prob = jax.nn.log_softmax(pred, axis=-1)
        loss = -jnp.mean(log_prob[jnp.arange(B), y])
        loss = jax.lax.pmean(loss, axis_name="x")  
        return loss
#loss = cross_entropy_loss(model, params, key(), x, y)

In [365]:
def train_step(loss_fn, params, key, *args, **kwargs):
        loss_grad = jax.value_and_grad(
            loss_fn,
            argnums=0,
            has_aux=False
        )
        loss, grads = loss_grad(params, key, *args, **kwargs, train=True)
        # don't need cache in training

        metrics = {
            'loss': loss,
        }
        return grads, metrics

In [None]:
def accumulate_grads(key, x, y, state):
        print("starting training")
        loss_fn = jax.tree_util.Partial(cross_entropy_loss, model)
        train_step_jit = lambda key, params, x, y : train_step(loss_fn, params, key, x, y)
   
        start = time.time()
        train_loss = 0.0

        grads = None
        acc_metrics = None
        for i in range(2):
            grads_step, metrics = train_step_jit(key, state.params, x, y)
            grads = grads_step if grads is None else jax.tree.map(
                lambda x, y: x + y, grads, grads_step
            )
            acc_metrics = metrics if acc_metrics is None else jax.tree.map(jnp.add, acc_metrics, metrics)

        grads = jax.tree.map(lambda x: x / 2, grads)
        
        return grads, acc_metrics

In [422]:
def train_step_device(state, x, y):
        key, step_key = jax.random.split(state.rng)
        grads, step_metrics = accumulate_grads(step_key, x, y, state)
        grads = jax.tree.map(lambda g: jax.lax.pmean(g, axis_name="x"), grads)
        new_state = state.apply_gradients(grads=grads)
        step_metrics = jax.tree.map(lambda x: jax.lax.pmean(x, axis_name="x"), step_metrics)

        return new_state, step_metrics


In [423]:
train_step_dp_fn =  shard_map(
            train_step_device,
            mesh,
            in_specs=(P(), P("x",), P("x",)),
            out_specs=(P(), P()),
        )

In [424]:
state, metrics = train_step_dp_fn(state_initialized, x, y)
state

starting training


TrainStateWithRNG(step=Array(1, dtype=int32, weak_type=True), apply_fn=<bound method Module.apply of DPClassifier(
    # attributes
    config = data_axis_name: data
    dropout_rate: 0.1
    dtype: !!python/name:jax.numpy.bfloat16 ''
    hidden_size: 8
    num_classes: 8
    
)>, params={'input_dense': {'bias': Array([ 0.19999996, -0.19999996, -0.19999997, -0.19999997, -0.19999981,
       -0.19999996, -0.19999996, -0.19999996], dtype=float32), 'kernel': Array([[-0.46248814,  0.4709597 ,  0.47826478,  0.2542902 ,  0.01786101,
         0.2140594 , -0.12473693, -0.13404296],
       [-0.45214278, -0.30684328, -0.11287088,  0.17070016, -0.13469952,
        -0.16141678,  0.10955662, -0.11220071],
       [-0.48910135, -0.4836367 , -0.25046363,  0.2654105 , -0.07899389,
        -0.00925097, -0.042604  , -0.09450248],
       [ 0.46585745, -0.18968448, -0.2280331 , -0.5355575 , -0.15533236,
        -0.00882587,  0.12820327, -0.27822432],
       [-0.12116256,  0.1060341 ,  0.03051075, -0.0734961

In [362]:
print("DP Parameters")
pprint(jax.tree.map(lambda x: (x.shape, x.sharding), state_initialized.params))

DP Parameters
{'input_dense': {'bias': ((8,),
                          NamedSharding(mesh=Mesh('x': 8, axis_types=(Auto,)), spec=PartitionSpec(), memory_kind=unpinned_host)),
                 'kernel': ((32, 8),
                            NamedSharding(mesh=Mesh('x': 8, axis_types=(Auto,)), spec=PartitionSpec(), memory_kind=unpinned_host))},
 'output_dense': {'bias': ((8,),
                           NamedSharding(mesh=Mesh('x': 8, axis_types=(Auto,)), spec=PartitionSpec(), memory_kind=unpinned_host)),
                  'kernel': ((8, 8),
                             NamedSharding(mesh=Mesh('x': 8, axis_types=(Auto,)), spec=PartitionSpec(), memory_kind=unpinned_host))}}


In [378]:
print(state)

{'input_dense': {'bias': Array([-0.05493164,  0.0703125 ,  0.10253906,  0.10351562,  0.01184082,
        0.06005859,  0.13085938,  0.07275391], dtype=float32), 'kernel': Array([[ 0.05273438, -0.04614258, -0.06933594, -0.13964844, -0.03833008,
        -0.01757812, -0.06933594,  0.00747681],
       [ 0.10449219,  0.04125977,  0.01916504, -0.05541992,  0.03564453,
         0.04882812, -0.1484375 ,  0.04663086],
       [ 0.08496094,  0.05517578,  0.00106812, -0.0300293 ,  0.03222656,
         0.09423828,  0.02941895,  0.00457764],
       [-0.02563477, -0.05615234,  0.10302734,  0.00665283,  0.04418945,
        -0.06030273, -0.03491211,  0.06225586],
       [-0.06591797,  0.01416016, -0.02746582,  0.02087402, -0.00337219,
         0.05004883,  0.09716797,  0.00543213],
       [-0.04711914,  0.02038574, -0.00531006,  0.06347656,  0.01806641,
        -0.02807617,  0.05834961, -0.02404785],
       [-0.02282715,  0.04174805,  0.08105469,  0.06591797,  0.03930664,
        -0.02038574, -0.0092773

In [410]:
p = state.params['input_dense']['kernel']
jax.debug.visualize_array_sharding(x)

In [425]:
print("DP Parameters")
pprint(jax.tree.map(lambda x: (x.shape, x.sharding), state.params))
print("Metrics")
pprint(jax.tree.map(lambda x: (x.shape, x.sharding), metrics))

DP Parameters
{'input_dense': {'bias': ((8,),
                          NamedSharding(mesh=Mesh('x': 8, axis_types=(Auto,)), spec=PartitionSpec(), memory_kind=unpinned_host)),
                 'kernel': ((32, 8),
                            NamedSharding(mesh=Mesh('x': 8, axis_types=(Auto,)), spec=PartitionSpec(), memory_kind=unpinned_host))},
 'output_dense': {'bias': ((8,),
                           NamedSharding(mesh=Mesh('x': 8, axis_types=(Auto,)), spec=PartitionSpec(), memory_kind=unpinned_host)),
                  'kernel': ((8, 8),
                             NamedSharding(mesh=Mesh('x': 8, axis_types=(Auto,)), spec=PartitionSpec(), memory_kind=unpinned_host))}}
Metrics
{'loss': ((),
          NamedSharding(mesh=Mesh('x': 8, axis_types=(Auto,)), spec=PartitionSpec(), memory_kind=unpinned_host))}


In [440]:
from jax.sharding import NamedSharding

In [441]:
jax.debug.visualize_sharding((16,32), NamedSharding(mesh, P("x")))