In [1]:
from dataclasses import dataclass, field
from typing import List

import jax
from jax import numpy as jnp
from jax.sharding import AxisType, NamedSharding, PartitionSpec as P
import optax

from sws import Config

In [2]:
# Expose 8 CPU devices so the 8-way "data" mesh built in the next cell exists on a
# single host without accelerators.
# NOTE(review): presumably this must run before any JAX computation initialises the
# backend — keep it ahead of every other JAX call.
jax.config.update('jax_num_cpu_devices', 8)

In [3]:
# 1-D mesh over all 8 devices with a single "data" axis. AxisType.Explicit puts the
# sharding into array types (shown later as e.g. f32[512@data,8,64]). jax.set_mesh makes
# it the default mesh that the bare PartitionSpecs used below (out_sharding=..., reshard)
# are resolved against.
mesh = jax.make_mesh((8,), ("data",), (AxisType.Explicit,))
jax.set_mesh(mesh)
print(f"{mesh=}")

mesh=Mesh(axis_sizes=(8,), axis_names=('data',), axis_types=(Explicit,))


In [4]:
SHARDING_RULES = {
    # Pure data parallelism: only the batch axis of activations is sharded over "data";
    # every weight axis is replicated.
    "dp": {
        "batch": "data",
        "act_seq": None,
        "act_vocab": None,
        "act_embed": None,
        "act_intermediate": None,
        "act_q": None,
        "act_kv": None,
        "act_head": None,
        "model_seq": None,
        "model_vocab": None,
        "model_embed": None,
        "model_intermediate": None,
        "model_q": None,
        "model_kv": None,
        "model_head": None,
    },
    # FSDP-style: batch is sharded as in "dp", and the vocab, MLP-intermediate and
    # head-dim axes of the weights are additionally sharded over "data".
    "fsdp": {
        "batch": "data",
        "act_seq": None,
        "act_vocab": None,
        "act_embed": None,
        "act_intermediate": None,
        "act_q": None,
        "act_kv": None,
        "act_head": None,
        "model_seq": None,
        "model_vocab": "data",
        "model_embed": None,
        "model_intermediate": "data",
        "model_q": None,
        "model_kv": None,
        "model_head": "data",
    },
}

# Strategy used by logical_to_physical when no explicit strategy is given.
_current_strategy = "dp"


def logical_to_physical(logical_axes, strategy=None):
    """Translate logical axis names into a physical PartitionSpec.

    Args:
        logical_axes: sequence of logical axis names, one per array dimension.
        strategy: key into SHARDING_RULES; defaults to the module-level
            ``_current_strategy``.

    Returns:
        PartitionSpec with one entry per logical axis. Each entry is the mesh axis that
        the rules map it to, or None (replicated) for names mapped to None or not listed.

    Raises:
        KeyError: if the selected strategy is not a key of SHARDING_RULES.
    """
    rules = SHARDING_RULES[_current_strategy if strategy is None else strategy]
    return P(*(rules.get(axis) for axis in logical_axes))

In [5]:
@jax.tree_util.register_dataclass
@dataclass
class AttentionWeights:
  q_proj: jax.Array
  k_proj: jax.Array
  v_proj: jax.Array
  o_proj: jax.Array

@jax.tree_util.register_dataclass
@dataclass
class MLPWeights:
  up_proj: jax.Array
  down_proj: jax.Array

@jax.tree_util.register_dataclass
@dataclass
class LayerWeights:
  attention_weights: AttentionWeights
  mlp_weights: MLPWeights

@jax.tree_util.register_dataclass
@dataclass
class ModelWeights:
  embed: jax.Array
  layer_weights: List[LayerWeights]
  unembed: jax.Array

In [6]:
def rms_norm(x, eps = 1e-6):
  """Scale ``x`` to unit root-mean-square over its last axis (no learnable gain).

  The mean of squares is computed in float32; the result is cast back to ``x.dtype``.
  """
  mean_square = jnp.mean(x.astype(jnp.float32) ** 2, axis=-1, keepdims=True)
  inv_rms = jax.lax.rsqrt(mean_square + eps)
  return (x * inv_rms).astype(x.dtype)


def precompute_rope_embeddings(seq_len, head_dim, base):
  """Build bfloat16 RoPE cos/sin tables of shape (1, seq_len, 1, head_dim // 2).

  Channel pair i rotates with frequency ``base ** (-2i / head_dim)``. The singleton axes
  let the tables broadcast against (batch, seq, heads, head_dim // 2) halves.
  """
  exponents = jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim
  inv_freq = 1.0 / (base ** exponents)
  positions = jnp.arange(seq_len, dtype=jnp.float32)
  angles = jnp.outer(positions, inv_freq)

  def _table(trig):
    return trig(angles).astype(jnp.bfloat16)[None, :, None, :]

  return _table(jnp.cos), _table(jnp.sin)


def apply_rope(x, cos, sin):
    """Rotate channel pairs (x[..., i], x[..., i + half]) of ``x`` by the RoPE angles.

    ``cos``/``sin`` must broadcast against each half of the last axis. The rotation
    used is y1 = x1*cos + x2*sin, y2 = x2*cos - x1*sin.
    """
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    rotated_first = first * cos + second * sin
    rotated_second = second * cos - first * sin
    return jnp.concat([rotated_first, rotated_second], axis=-1)
  

def attention(x, w: AttentionWeights, cos, sin):
  """Causal grouped-query self-attention with RoPE and QK-RMSNorm.

  Dimension letters used in the einsum specs:
    B = batch size
    D = embedding dimension
    S = length of the key/value (source)
    T = length of the query (target)
    N = number of attention heads
    H = dimensions of each attention head
    K = number of key/value heads
    G = number of groups, which equals N // K

  Args:
    x: activations of shape (B, T, D). Keys and values come from the same ``x``, so S == T.
    w: projection weights: q_proj (D, N, H), k_proj / v_proj (D, K, H), o_proj (N, H, D).
    cos, sin: RoPE tables as built by ``precompute_rope_embeddings``, shape
      (1, L, 1, H // 2) with L >= T. Only the first T positions are used.

  Returns:
    Array of shape (B, T, D).

  Raises:
    ValueError: if N is not a multiple of K.
  """
  T = x.shape[1]
  H = w.q_proj.shape[2]
  N, K = w.q_proj.shape[1], w.k_proj.shape[1]
  if N % K != 0:
    raise ValueError(
      f"number of attention heads ({N}) must be a multiple of key/value heads ({K})"
    )
  G = N // K

  # The RoPE tables may be precomputed for a longer maximum length than this call's T;
  # slicing is a no-op when the lengths already match.
  cos, sin = cos[:, :T], sin[:, :T]

  q = jnp.einsum(
    "BTD, DNH -> BTNH", x, w.q_proj.astype(jnp.bfloat16),
    out_sharding=logical_to_physical(("batch", "act_seq", "act_q", "act_head"))
  )
  k = jnp.einsum(
    "BSD, DKH -> BSKH", x, w.k_proj.astype(jnp.bfloat16),
    out_sharding=logical_to_physical(("batch", "act_seq", "act_kv", "act_head"))
  )
  v = jnp.einsum(
    "BSD, DKH -> BSKH", x, w.v_proj.astype(jnp.bfloat16),
    out_sharding=logical_to_physical(("batch", "act_seq", "act_kv", "act_head"))
  )

  q = apply_rope(q, cos, sin)
  k = apply_rope(k, cos, sin)

  # QK-norm. rms_norm has no learnable gain and RoPE only rotates channel pairs, so
  # normalising after RoPE matches normalising before it (up to rounding).
  q = rms_norm(q)
  k = rms_norm(k)

  # Grouped-query attention: repeat each K/V head G times to line up with the N query heads.
  k = jnp.repeat(
    k, G, axis=2,
    out_sharding=logical_to_physical(("batch", "act_seq", "act_q", "act_head"))
  )
  v = jnp.repeat(
    v, G, axis=2,
    out_sharding=logical_to_physical(("batch", "act_seq", "act_q", "act_head"))
  )

  logits = jnp.einsum(
    "BTNH, BSNH -> BNTS", q, k,
    out_sharding=logical_to_physical(("batch", "act_q", "act_seq", "act_seq"))
  )
  logits *= jax.lax.rsqrt(jnp.array(H, dtype=jnp.bfloat16))
  # Boolean lower-triangular mask: query position t may attend to key positions s <= t.
  causal_mask = jnp.tril(jnp.ones((T, T), dtype=jnp.bool_))
  masked_logits = jnp.where(causal_mask, logits, jnp.array(float("-inf"), dtype=jnp.bfloat16))
  # Softmax in float32. The diagonal is never masked, so no row is entirely -inf.
  probs = jax.nn.softmax(masked_logits.astype(jnp.float32), axis=-1).astype(jnp.bfloat16)
  encoded = jnp.einsum(
    "BNTS, BSNH -> BTNH", probs, v,
    out_sharding=logical_to_physical(("batch", "act_seq", "act_q", "act_head"))
  )
  out = jnp.einsum(
    "BTNH, NHD -> BTD", encoded, w.o_proj.astype(jnp.bfloat16),
    out_sharding=logical_to_physical(("batch", "act_seq", "act_embed"))
  )

  return out

def mlp(x, w: MLPWeights):
  """SiLU feed-forward block: ``silu(x @ up_proj) @ down_proj``.

  Both weight matrices are cast to bfloat16 before their matmul. The hidden activation
  is laid out as ("batch", "act_seq", "act_intermediate") and the result as
  ("batch", "act_seq", "act_embed").
  """
  up = w.up_proj.astype(jnp.bfloat16)
  down = w.down_proj.astype(jnp.bfloat16)
  hidden_spec = logical_to_physical(("batch", "act_seq", "act_intermediate"))
  out_spec = logical_to_physical(("batch", "act_seq", "act_embed"))
  hidden = jax.nn.silu(jnp.matmul(x, up, out_sharding=hidden_spec))
  return jnp.matmul(hidden, down, out_sharding=out_spec)


def layer(x, w: LayerWeights, cos, sin):
  """One pre-norm transformer layer: residual attention, then residual MLP."""
  attn_out = attention(rms_norm(x), w.attention_weights, cos, sin)
  x = x + attn_out
  mlp_out = mlp(rms_norm(x), w.mlp_weights)
  return x + mlp_out

@jax.jit
def forward(x, w: ModelWeights, cos, sin):
  """Embed token ids, run every transformer layer, and project to vocabulary logits.

  Args:
    x: integer token ids of shape (batch, seq), used to index rows of ``w.embed``.
    w: model weights.
    cos, sin: RoPE tables passed through to every layer.

  Returns:
    bfloat16 logits of shape (batch, seq, vocab_size).
  """
  x = w.embed.at[x].get(out_sharding=logical_to_physical(("batch", "act_seq", "act_embed"))).astype(jnp.bfloat16)
  for layer_weights in w.layer_weights:
    x = layer(x, layer_weights, cos, sin)
  # The layers are pre-norm (they normalise only their sublayer inputs), so the
  # residual stream leaves the stack un-normalised. Normalise it once more before the
  # unembedding, as pre-norm transformers conventionally do.
  x = rms_norm(x)
  logits = jnp.matmul(
    x, w.unembed.astype(jnp.bfloat16),
    out_sharding=logical_to_physical(("batch", "act_seq", "act_vocab"))
  )
  return logits

In [7]:
c = Config()

# --- model architecture ---
c.model.seq_len = 1_024
c.model.vocab_size = 50_304
c.model.num_layers = 1
c.model.hidden_dim = 512
# Derived entries are lambdas over other config values; they are read back as plain
# numbers after finalize() (e.g. c.model.head_dim is used as an int further down).
c.model.intermediate_dim = lambda: 4 * c.model.hidden_dim
c.model.num_attention_heads = 8
c.model.num_key_value_heads = 8
c.model.head_dim = lambda: c.model.hidden_dim // c.model.num_attention_heads
c.model.rope_base = 10_000

# --- AdamW optimizer ---
c.optimizer.learning_rate = 1e-4
c.optimizer.weight_decay = 0.01
c.optimizer.beta1 = 0.9
c.optimizer.beta2 = 0.999
c.optimizer.eps = 1e-8

# Resolve the config; the rest of the notebook reads from the finalized object.
c = c.finalize()

In [8]:
def init_model_weights(
    vocab_size,
    num_layers,
    hidden_dim,
    intermediate_dim,
    num_attention_heads,
    num_key_value_heads,
    head_dim,
    seed=69420,
):
    """Initialise all model parameters as float32 arrays laid out by the sharding rules.

    Args:
        vocab_size: number of token ids (rows of ``embed``, columns of ``unembed``).
        num_layers: number of transformer layers.
        hidden_dim: embedding width D.
        intermediate_dim: MLP hidden width.
        num_attention_heads: number of query heads N.
        num_key_value_heads: number of key/value heads K.
        head_dim: per-head width H.
        seed: PRNG seed; the default reproduces the previously hard-coded key.

    Returns:
        ModelWeights with embed (vocab_size, D), ``num_layers`` LayerWeights, and
        unembed (D, vocab_size).
    """
    # One key per array: embed, 6 per layer (q, k, v, o, up, down), unembed.
    num_weight_arrays = 1 + (num_layers * 6) + 1
    key_iter = iter(jax.random.split(jax.random.key(seed), num_weight_arrays))

    # 2-D matrices contract over axis 0, so lecun_normal's default axes (in=-2, out=-1) fit.
    # NOTE(review): ``embed`` is used as a lookup table, not a matmul, so its fan_in
    # (= vocab_size) gives a very small std. Consider a dedicated embedding init.
    init_fn = jax.nn.initializers.lecun_normal()
    # q/k/v projections are (D, heads, H) and contract over D only. With the default axes,
    # lecun_normal would take the heads axis as fan-in and D as a receptive field, giving
    # fan_in = D * heads and a std too small by sqrt(heads). Name the axes explicitly.
    qkv_init_fn = jax.nn.initializers.lecun_normal(in_axis=0, out_axis=(1, 2))
    # o_proj is (N, H, D) and contracts over (N, H). The default axes already yield
    # fan_in = N * H here; spelling them out keeps the intent explicit.
    o_init_fn = jax.nn.initializers.lecun_normal(in_axis=(0, 1), out_axis=2)

    embed = init_fn(
        next(key_iter), (vocab_size, hidden_dim), dtype=jnp.float32,
        out_sharding=logical_to_physical(("model_vocab", "model_embed"))
    )
    # Keyword arguments evaluate left to right, so keys are consumed in q, k, v, o, up, down order.
    layer_weights = [
        LayerWeights(
            attention_weights=AttentionWeights(
                q_proj=qkv_init_fn(
                    next(key_iter), (hidden_dim, num_attention_heads, head_dim), dtype=jnp.float32,
                    out_sharding=logical_to_physical(("model_embed", "model_q", "model_head"))
                ),
                k_proj=qkv_init_fn(
                    next(key_iter), (hidden_dim, num_key_value_heads, head_dim), dtype=jnp.float32,
                    out_sharding=logical_to_physical(("model_embed", "model_kv", "model_head"))
                ),
                v_proj=qkv_init_fn(
                    next(key_iter), (hidden_dim, num_key_value_heads, head_dim), dtype=jnp.float32,
                    out_sharding=logical_to_physical(("model_embed", "model_kv", "model_head"))
                ),
                o_proj=o_init_fn(
                    next(key_iter), (num_attention_heads, head_dim, hidden_dim), dtype=jnp.float32,
                    out_sharding=logical_to_physical(("model_q", "model_head", "model_embed"))
                )
            ),
            mlp_weights=MLPWeights(
                up_proj=init_fn(
                    next(key_iter), (hidden_dim, intermediate_dim), dtype=jnp.float32,
                    out_sharding=logical_to_physical(("model_embed", "model_intermediate"))
                ),
                down_proj=init_fn(
                    next(key_iter), (intermediate_dim, hidden_dim), dtype=jnp.float32,
                    out_sharding=logical_to_physical(("model_intermediate", "model_embed"))
                )
            )
        )
        for _ in range(num_layers)
    ]
    unembed = init_fn(
        next(key_iter), (hidden_dim, vocab_size), dtype=jnp.float32,
        out_sharding=logical_to_physical(("model_embed", "model_vocab"))
    )
    return ModelWeights(embed=embed, layer_weights=layer_weights, unembed=unembed)


In [9]:
# AdamW with hyper-parameters taken from the config.
optimizer = optax.adamw(
    learning_rate=c.optimizer.learning_rate,
    b1=c.optimizer.beta1,
    b2=c.optimizer.beta2,
    eps=c.optimizer.eps,
    weight_decay=c.optimizer.weight_decay,
)

# Fresh parameters, laid out according to the active sharding strategy.
model_weights = init_model_weights(
    vocab_size=c.model.vocab_size,
    num_layers=c.model.num_layers,
    hidden_dim=c.model.hidden_dim,
    intermediate_dim=c.model.intermediate_dim,
    num_attention_heads=c.model.num_attention_heads,
    num_key_value_heads=c.model.num_key_value_heads,
    head_dim=c.model.head_dim,
)

# Optimizer state; Adam's moment buffers mirror the parameter pytree.
optimizer_state = optimizer.init(model_weights)

In [10]:
# ZeRO-style optimizer-state sharding. Every state array with ndim > 1 (Adam's mu/nu
# buffers) is split along its leading axis over the "data" mesh axis; lower-rank leaves,
# such as the step count, stay replicated. Mapping over the whole state, rather than
# rebuilding a 3-tuple from optimizer_state[0..2], keeps this independent of how many
# transforms the optax chain has. It also matches make_zero_optimizer_update, whose
# specs are derived from the whole state with the same ndim > 1 rule.
optimizer_state = jax.tree.map(
    lambda leaf: jax.sharding.reshard(leaf, P("data",)) if leaf.ndim > 1 else leaf,
    optimizer_state,
)

In [11]:
# Parameter layout: under the active "dp" rules every weight is fully replicated.
model_weights.layer_weights[0].attention_weights.q_proj.sharding

NamedSharding(mesh=Mesh('data': 8, axis_types=(Explicit,)), spec=PartitionSpec(None, None, None), memory_kind=device)

In [12]:
# The matching Adam first-moment buffer, after the resharding cell above, is split
# along its leading axis over "data" (ZeRO-style optimizer-state sharding).
optimizer_state[0].mu.layer_weights[0].attention_weights.q_proj.sharding

NamedSharding(mesh=Mesh('data': 8, axis_types=(Explicit,)), spec=PartitionSpec('data', None, None), memory_kind=device)

In [13]:
# RoPE tables plus a dummy all-ones batch. The sequence length comes from the config so
# the inputs stay the same length as the RoPE tables. The batch of 8 is split over the
# "data" axis by the "batch" rule.
cos, sin = precompute_rope_embeddings(c.model.seq_len, c.model.head_dim, c.model.rope_base)
x = jnp.ones((8, c.model.seq_len), dtype=jnp.int32, out_sharding=logical_to_physical(("batch", "act_seq")))
y = jnp.ones((8, c.model.seq_len), dtype=jnp.int32, out_sharding=logical_to_physical(("batch", "act_seq")))
logits = forward(x, model_weights, cos, sin)
logits.shape, logits.dtype

((8, 1024, 50304), dtype(bfloat16))

In [14]:
def loss_fn(w, cos, sin, x, y):
    """Mean token-level cross-entropy of ``forward(x)`` against integer targets ``y``.

    Args:
        w: model weights.
        cos, sin: RoPE tables passed through to ``forward``.
        x: input token ids, (batch, seq).
        y: target token ids with the same shape as ``x``. Aligning targets with inputs
            (e.g. shifting by one position) is the caller's responsibility.

    Returns:
        Scalar float32 loss.
    """
    # forward() returns bfloat16 logits. A logsumexp over the whole vocabulary in bf16
    # (8-bit mantissa) throws away most of the loss's precision, so upcast first.
    logits = forward(x, w, cos, sin).astype(jnp.float32)
    label_logits = jnp.take_along_axis(logits, y[..., jnp.newaxis], axis=-1)
    log_normalizers = jax.nn.logsumexp(logits, axis=-1, keepdims=True)
    return jnp.mean(log_normalizers - label_logits)

In [15]:
@jax.sharding.auto_axes
def apply_updates(model_weights, updates):
  return jax.tree.map(lambda weights, updates: updates + weights, model_weights, updates)


In [21]:
from functools import partial

# NOTE(review): newer JAX releases also expose this as ``jax.shard_map`` (with
# ``check_rep`` renamed to ``check_vma``); confirm against the pinned JAX version.
from jax.experimental.shard_map import shard_map


In [22]:
def make_zero_optimizer_update(optimizer, model_weights, optimizer_state):
    """Create a ZeRO-1/2 style update function with proper specs for the given structures.

    Each device keeps only a 1/N slice, along the leading axis over the "data" mesh axis,
    of every optimizer-state array with ndim > 1. It runs ``optimizer.update`` on that
    slice, and the resulting updates are all-gathered back to full replicated arrays.
    Leaves with ndim <= 1 are handled whole on every device.

    Args:
        optimizer: optax-style transformation; ``optimizer.update(grads, state, params)``
            is called on the per-device shards.
        model_weights: parameter pytree. Only its structure is used, to build specs.
        optimizer_state: optimizer-state pytree. Its structure and the ``ndim`` of its
            leaves decide which leaves are sharded.

    Returns:
        A ``shard_map``-wrapped function ``(grads, opt_state, model) -> (updates,
        new_opt_state)`` over the global ``mesh``. ``grads``, ``model`` and ``updates`` are
        replicated; ``opt_state`` and ``new_opt_state`` use the sharded layout above.
    """
    # Grads, model weights and updates are all replicated (P()) at the shard_map boundary.
    grads_spec = jax.tree.map(lambda _: P(), model_weights)
    model_spec = jax.tree.map(lambda _: P(), model_weights)
    updates_spec = jax.tree.map(lambda _: P(), model_weights)

    # Optimizer state: shard arrays with ndim > 1 on their first axis; replicate scalars/1-D.
    def opt_state_to_spec(x):
        if hasattr(x, 'ndim') and x.ndim > 1:
            return P("data",)
        return P()

    opt_state_spec = jax.tree.map(opt_state_to_spec, optimizer_state)

    @partial(shard_map, mesh=mesh,
             in_specs=(grads_spec, opt_state_spec, model_spec),
             out_specs=(updates_spec, opt_state_spec),
             check_rep=False)
    def _zero_update(grads, opt_state, model):
        # Size of and position along the "data" axis. axis_size is used as a static
        # slice length below.
        axis_size = jax.lax.psum(1, "data")
        axis_index = jax.lax.axis_index("data")

        # Give each device its 1/N shard of the *mean* gradient.
        # ``grads`` arrives with in_spec P(): every device already holds the same, fully
        # reduced gradient. A bare psum_scatter would therefore SUM axis_size identical
        # copies and scale the gradient by axis_size. Dividing by axis_size makes it a
        # mean reduce-scatter, consistent with the pmean used for <=1-D leaves.
        # NOTE(review): for replicated input this equals a local slice (slice_to_shard
        # below), which would avoid the collective entirely.
        def reduce_scatter_grad(g):
            if g.ndim > 1:
                return jax.lax.psum_scatter(g, "data", scatter_dimension=0, tiled=True) / axis_size
            return jax.lax.pmean(g, "data")

        grads_sharded = jax.tree.map(reduce_scatter_grad, grads)

        # Slice model weights to match the sharded gradients (needed for weight decay).
        def slice_to_shard(arr):
            if arr.ndim > 1:
                shard_size = arr.shape[0] // axis_size
                return jax.lax.dynamic_slice_in_dim(arr, axis_index * shard_size, shard_size, axis=0)
            return arr

        model_sharded = jax.tree.map(slice_to_shard, model)

        # Local optimizer update with sharded grads, sharded state, and sliced model.
        updates_sharded, new_opt_state = optimizer.update(grads_sharded, opt_state, model_sharded)

        # All-gather the per-device update shards so updates come out replicated.
        def all_gather_update(u):
            if u.ndim > 1:
                return jax.lax.all_gather(u, "data", axis=0, tiled=True)
            return u

        updates_full = jax.tree.map(all_gather_update, updates_sharded)

        return updates_full, new_opt_state

    return _zero_update


# Create the ZeRO update function with the actual structures
zero_update = make_zero_optimizer_update(optimizer, model_weights, optimizer_state)

In [23]:
@jax.jit
def train_step(model_weights, optimizer_state, cos, sin, x, y):
    """One step: loss and grads, ZeRO-sharded optimizer update, then parameter update.

    The ``print`` calls are Python side effects, so they run while JAX traces this
    function rather than on every execution. They report the abstract type (including
    sharding) of the layer-0 ``q_proj`` in the weights, the optimizer's first-moment
    state, the grads and the updates.

    Returns:
        (new_model_weights, new_optimizer_state, loss)
    """
    def _report(header, *arrays):
        print(header)
        print(*(jax.typeof(a) for a in arrays))

    def _q_proj(tree):
        return tree.layer_weights[0].attention_weights.q_proj

    _report("model, optimizer_state",
            _q_proj(model_weights), _q_proj(optimizer_state[0].mu))
    loss, grads = jax.value_and_grad(loss_fn)(model_weights, cos, sin, x, y)
    _report("model, optimizer_state, grads",
            _q_proj(model_weights), _q_proj(optimizer_state[0].mu), _q_proj(grads))
    # Unsharded alternative: optimizer.update(grads, optimizer_state, model_weights)
    updates, new_optimizer_state = zero_update(grads, optimizer_state, model_weights)
    _report("model, updates, optimizer_state",
            _q_proj(model_weights), _q_proj(updates), _q_proj(new_optimizer_state[0].mu))
    new_model_weights = optax.apply_updates(model_weights, updates)
    _report("model, optimizer_state",
            _q_proj(new_model_weights), _q_proj(new_optimizer_state[0].mu))
    return new_model_weights, new_optimizer_state, loss

In [24]:
# for _ in range(100):
#     model_weights, optimizer_state, loss = train_step(model_weights, optimizer_state, cos, sin, x, y)

In [25]:
# Trace train_step without executing it. This prints the trace-time sharding report and
# shows the jaxpr, whose input types carry the shardings (e.g. 512@data for the
# optimizer-state shards).
jax.make_jaxpr(train_step)(model_weights, optimizer_state, cos, sin, x, y)

model, optimizer_state
float32[512,8,64] float32[512@data,8,64]
model, optimizer_state, grads
float32[512,8,64] float32[512@data,8,64] float32[512,8,64]
model, updates, optimizer_state
float32[512,8,64] float32[512,8,64] float32[512@data,8,64]
model, optimizer_state
float32[512,8,64] float32[512@data,8,64]


{ lambda ; a:f32[50304,512] b:f32[512,8,64] c:f32[512,8,64] d:f32[512,8,64] e:f32[8,64,512]
    f:f32[512,2048] g:f32[2048,512] h:f32[512,50304] i:i32[] j:f32[50304@data,512]
    k:f32[512@data,8,64] l:f32[512@data,8,64] m:f32[512@data,8,64] n:f32[8@data,64,512]
    o:f32[512@data,2048] p:f32[2048@data,512] q:f32[512@data,50304] r:f32[50304@data,512]
    s:f32[512@data,8,64] t:f32[512@data,8,64] u:f32[512@data,8,64] v:f32[8@data,64,512]
    w:f32[512@data,2048] x:f32[2048@data,512] y:f32[512@data,50304] z:bf16[1,1024,1,32]
    ba:bf16[1,1024,1,32] bb:i32[8@data,1024] bc:i32[8@data,1024]. let
    bd:f32[50304,512] be:f32[512,8,64] bf:f32[512,8,64]

In [20]:

# Run the step five times on the same inputs; the results are discarded, so this only
# exercises compilation/caching. Trace-time prints should appear only on the first call.
for _ in range(5):
    _, _, loss = train_step(model_weights, optimizer_state, cos, sin, x, y)
    print()
# with jax.profiler.trace("f_profiles"):
#     _, _, loss = train_step(model_weights, optimizer_state, cos, sin, x, y)
#     loss.block_until_ready()







