In [177]:
import os
import urllib.request
from urllib.error import HTTPError

# Github URL where python scripts are stored.
base_url = "https://raw.githubusercontent.com/phlippe/uvadlc_notebooks/master/docs/tutorial_notebooks/scaling/JAX/"
# Files to download.
python_files = ["single_gpu.py", "utils.py"]
# For each file, check whether it already exists. If not, try downloading it.
for file_name in python_files:
    if not os.path.isfile(file_name):
        file_url = base_url + file_name
        print(f"Downloading {file_url}...")
        try:
            urllib.request.urlretrieve(file_url, file_name)
        except HTTPError as e:
            print(
                "Something went wrong. Please try to download the file directly from the GitHub repository, or contact the author with the full output including the following error:\n",
                e,
            )

In [178]:
from utils import simulate_CPU_devices

simulate_CPU_devices()

In [213]:
import functools
from pprint import pprint
from typing import Any, Callable, Dict, Sequence, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
from absl import logging
from jax import lax
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from ml_collections import ConfigDict
from flax.training import train_state
from single_gpu import TrainState

PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]

In [214]:
class DPClassifier(nn.Module):
    config: ConfigDict

    @nn.compact
    def __call__(self, x: jax.Array, train: bool) -> jax.Array:
        x = nn.Dense(
            features=self.config.hidden_size,
            dtype=self.config.dtype,
            name="input_dense",
        )(x)
        x = nn.silu(x)
        x = nn.Dropout(rate=self.config.dropout_rate, deterministic=not train)(x)
        x = nn.Dense(
            features=self.config.num_classes,
            dtype=self.config.dtype,
            name="output_dense",
        )(x)
        x = x.astype(jnp.float32)
        return x

In [215]:
data_config = ConfigDict(
    dict(
        batch_size=128,
        num_classes=8,
        input_size=784,
    )
)
model_config = ConfigDict(
    dict(
        hidden_size=512,
        dropout_rate=0.1,
        dtype=jnp.bfloat16,
        num_classes=data_config.num_classes,
        data_axis_name="data",
    )
)
optimizer_config = ConfigDict(
    dict(
        learning_rate=1e-3,
        num_minibatches=4,
    )
)
config = ConfigDict(
    dict(
        model=model_config,
        optimizer=optimizer_config,
        data=data_config,
        data_axis_name=model_config.data_axis_name,
        seed=42,
    )
)

In [216]:
class KeyState:
    def __init__(self, base_key: jax.random.key):
        self.key = jax.random.key(base_key)

    def __call__(self, num: int = 2):
        self.key, rng = jax.random.split(self.key, num=num)
        return rng

In [217]:
model = DPClassifier(config=config.model)
optimizer = optax.adamw(
    learning_rate=config.optimizer.learning_rate,
)

In [218]:
key = KeyState(config.seed)
x=jax.random.normal(key(), (config.data.batch_size, config.data.input_size))
y = jax.random.randint(key(), (config.data.batch_size,), 0, config.data.num_classes)
variables = model.init({"params": key()}, x, train=False)
params = variables.pop("params")
device_array = np.array(jax.devices())
mesh = Mesh(device_array, ("x",))


In [221]:
def init_device(params, rng, local_model, config):
        lr_scheduler = optax.warmup_cosine_decay_schedule(
        init_value=0.2,
        peak_value=0.5,
        warmup_steps=15,
        decay_steps=300,
        end_value=0.9,
        )   
        tx = optax.chain(
            optax.clip_by_global_norm(1),
            optax.inject_hyperparams(optax.adam)(learning_rate=lr_scheduler),
        )
        state = TrainState.create(
            apply_fn=local_model.apply,
            params=params,
            tx=tx,
            rng=rng,
        )
        return state

In [222]:
sharded_init = shard_map(
            functools.partial(init_device, rng=key(), local_model=model, config=model_config),
            mesh,
            in_specs=(P()),
            out_specs=(P()),
        )

state_initialized = sharded_init(params)

In [224]:
def fold_key(key, axis):
        axis_index = jax.lax.axis_index(axis)
        return jax.random.fold_in(key, axis_index)

In [None]:
def cross_entropy_loss(model, params, key, x, y, train=True):
        dropout_key = fold_key(key, config.data_axis_name)
        B, T = x.shape
        pred = model.apply({'params': params}, x, train=train, rngs={'dropout': key})
        log_prob = jax.nn.log_softmax(pred, axis=-1)
        loss = -jnp.mean(log_prob[jnp.arange(B), y])
        return loss
#loss = cross_entropy_loss(model, params, key(), x, y)

In [226]:
print("DP Parameters")
pprint(jax.tree.map(lambda x: (x.shape, x.sharding), state_initialized.params))

DP Parameters
{'input_dense': {'bias': ((512,),
                          NamedSharding(mesh=Mesh('x': 8, axis_types=(Auto,)), spec=PartitionSpec(), memory_kind=unpinned_host)),
                 'kernel': ((784, 512),
                            NamedSharding(mesh=Mesh('x': 8, axis_types=(Auto,)), spec=PartitionSpec(), memory_kind=unpinned_host))},
 'output_dense': {'bias': ((8,),
                           NamedSharding(mesh=Mesh('x': 8, axis_types=(Auto,)), spec=PartitionSpec(), memory_kind=unpinned_host)),
                  'kernel': ((512, 8),
                             NamedSharding(mesh=Mesh('x': 8, axis_types=(Auto,)), spec=PartitionSpec(), memory_kind=unpinned_host))}}
