Below will print the HLO representation of an implementation of MHA self attention


```
Performs multi-head self-attention.
Args:
   params: Tuple containing the weight matrices (wq, wk, wv, wo).

   x: Input tensor of shape (batch_size, sequence_length, embedding_dimension).

   num_heads: Number of attention heads.

   head_dim: Dimensionality of each attention head.
   
Returns:
   Output tensor of shape (batch_size, sequence_length, embedding_dimension).
```



In [11]:
import jax
import jax.numpy as jnp
from jax.nn import softmax

def multihead_self_attention(params, x, num_heads):
  wq, wk, wv, wo = params # Query, Key, Value, Output weights
  batch_size, seq_len, embed_dim = x.shape
  head_dim = embed_dim // num_heads

  # Project input to Q, K, V for each head
  q = jnp.dot(x, wq).reshape(batch_size, num_heads, seq_len, head_dim)
  k = jnp.dot(x, wk).reshape(batch_size, num_heads, seq_len, head_dim)
  v = jnp.dot(x, wv).reshape(batch_size, num_heads, seq_len, head_dim)

  # Calculate attention scores for each head
  scores = jnp.einsum('bhqd, bhkd -> bhqk', q, k) / jnp.sqrt(head_dim)

  # Apply softmax
  weights = softmax(scores, axis=-1)

  # Apply attention weights to V for each head
  output = jnp.einsum('bhqk, bhkd -> bhqd', weights, v)

  # Concatenate head outputs
  output = output.transpose(0, 2, 1, 3)
  output = output.reshape(batch_size, seq_len, embed_dim)

  # Final output projection
  output = jnp.dot(output, wo)
  return output

key = jax.random.PRNGKey(3)
seq_len = 64
embed_dim = 256
num_heads = 4 # Number of attention heads
head_dim = embed_dim // num_heads # Dimensionality of each head

# Simplified shapes (Batch=1, SeqLen, EmbedDim)
input_shape = (1, seq_len, embed_dim)
x = jax.random.normal(key, input_shape)

# Dummy parameters
keys = jax.random.split(key, 4)
# Reshape weight matrices to incorporate head_dim
wq = jax.random.normal(keys[0], (embed_dim, embed_dim))
wk = jax.random.normal(keys[1], (embed_dim, embed_dim))
wv = jax.random.normal(keys[2], (embed_dim, embed_dim))
wo = jax.random.normal(keys[3], (embed_dim, embed_dim))
params = (wq, wk, wv, wo)

# JIT compile
jit_attention = jax.jit(multihead_self_attention, static_argnums=(2,))
lowered = jit_attention.lower(params, x, num_heads)

# Extract the HLO text
hlo_text = lowered.compiler_ir(dialect="hlo").as_hlo_text()
print(hlo_text)


HloModule jit_multihead_self_attention, entry_computation_layout={(f32[256,256]{1,0}, f32[256,256]{1,0}, f32[256,256]{1,0}, f32[256,256]{1,0}, f32[1,64,256]{2,1,0})->f32[1,64,256]{2,1,0}}

region_0.20 {
  Arg_0.21 = f32[] parameter(0)
  Arg_1.22 = f32[] parameter(1)
  ROOT maximum.23 = f32[] maximum(Arg_0.21, Arg_1.22)
}

region_1.32 {
  Arg_0.33 = f32[] parameter(0)
  Arg_1.34 = f32[] parameter(1)
  ROOT add.35 = f32[] add(Arg_0.33, Arg_1.34)
}

ENTRY main.46 {
  Arg_4.5 = f32[1,64,256]{2,1,0} parameter(4)
  Arg_0.1 = f32[256,256]{1,0} parameter(0)
  dot.12 = f32[1,64,256]{2,1,0} dot(Arg_4.5, Arg_0.1), lhs_contracting_dims={2}, rhs_contracting_dims={0}
  reshape.13 = f32[1,4,64,64]{3,2,1,0} reshape(dot.12)
  Arg_1.2 = f32[256,256]{1,0} parameter(1)
  dot.14 = f32[1,64,256]{2,1,0} dot(Arg_4.5, Arg_1.2), lhs_contracting_dims={2}, rhs_contracting_dims={0}
  reshape.15 = f32[1,4,64,64]{3,2,1,0} reshape(dot.14)
  dot.18 = f32[1,4,64,64]{3,2,1,0} dot(reshape.13, reshape.15), lhs_batch_dims=

Mixed Precision Forward Pass of a CNN to RELU
    
```
A two-layer Conv→ReLU block in mixed precision:
     1) cast inputs & weights to bfloat16
     2) conv + bias add in bfloat16
     3) cast back to float32 + ReLU
     4) repeat with stride=2
```

In [26]:
import jax
import jax.numpy as jnp
from jax import lax
from jax.nn import relu

def conv_block_mp(params, x):
    # -- first conv --
    x16   = lax.convert_element_type(x,    jnp.bfloat16)
    w1_16 = lax.convert_element_type(params['w1'], jnp.bfloat16)
    b1_16 = lax.convert_element_type(params['b1'], jnp.bfloat16)
    y16   = lax.conv_general_dilated(
               x16, w1_16,
               window_strides=(1,1),
               padding='SAME',
               dimension_numbers=('NHWC', 'HWIO', 'NHWC')
           ) + b1_16
    y     = lax.convert_element_type(y16, jnp.float32)
    y     = relu(y)

    # -- second conv (downsample) --
    y16   = lax.convert_element_type(y,    jnp.bfloat16)
    w2_16 = lax.convert_element_type(params['w2'], jnp.bfloat16)
    b2_16 = lax.convert_element_type(params['b2'], jnp.bfloat16)
    z16   = lax.conv_general_dilated(
               y16, w2_16,
               window_strides=(2,2),
               padding='SAME',
               dimension_numbers=('NHWC', 'HWIO', 'NHWC')
           ) + b2_16
    z     = lax.convert_element_type(z16, jnp.float32)
    return relu(z)


key = jax.random.PRNGKey(0)
x0 = jax.random.normal(key, (1, 32, 32, 3))   # e.g. CIFAR-style image
params = {
  'w1': jax.random.normal(key, (3,3,3,16)),   # 16 filters
  'b1': jnp.zeros((16,)),
  'w2': jax.random.normal(key, (3,3,16,32)),  # 32 filters
  'b2': jnp.zeros((32,))
}

# 2) jit & lower
jit_block = jax.jit(conv_block_mp)
lowered  = jit_block.lower(params, x0)

# 3) dump HLO text
print(lowered.compiler_ir(dialect="hlo").as_hlo_text())

HloModule jit_conv_block_mp, entry_computation_layout={(f32[16]{0}, f32[32]{0}, f32[3,3,3,16]{3,2,1,0}, f32[3,3,16,32]{3,2,1,0}, f32[1,32,32,3]{3,2,1,0})->f32[1,16,16,32]{3,2,1,0}}

relu.16 {
  Arg_0.17 = f32[1,32,32,16]{3,2,1,0} parameter(0)
  constant.18 = f32[] constant(0)
  broadcast.19 = f32[1,32,32,16]{3,2,1,0} broadcast(constant.18), dimensions={}
  ROOT maximum.20 = f32[1,32,32,16]{3,2,1,0} maximum(Arg_0.17, broadcast.19)
}

relu_0.32 {
  Arg_0.33 = f32[1,16,16,32]{3,2,1,0} parameter(0)
  constant.34 = f32[] constant(0)
  broadcast.35 = f32[1,16,16,32]{3,2,1,0} broadcast(constant.34), dimensions={}
  ROOT maximum.36 = f32[1,16,16,32]{3,2,1,0} maximum(Arg_0.33, broadcast.35)
}

ENTRY main.38 {
  Arg_4.5 = f32[1,32,32,3]{3,2,1,0} parameter(4)
  convert.6 = bf16[1,32,32,3]{3,2,1,0} convert(Arg_4.5)
  Arg_2.3 = f32[3,3,3,16]{3,2,1,0} parameter(2)
  convert.7 = bf16[3,3,3,16]{3,2,1,0} convert(Arg_2.3)
  convolution.9 = bf16[1,32,32,16]{3,2,1,0} convolution(convert.6, convert.7), win

In [24]:
import jax
import jax.numpy as jnp
import optax
from functools import partial

# 1) Set up a simple optimizer (e.g. SGD)
learning_rate = 1e-2
opt = optax.sgd(learning_rate)

# 2) Define the pmapped train step
@partial(jax.pmap, axis_name="batch")
def train_step(params, opt_state, batch):
    # Forward+loss
    def loss_fn(p, x, y):
        logits = x @ p['w'] + p['b']
        return jnp.mean(optax.softmax_cross_entropy_with_integer_labels(logits, y))
    (loss, grads) = jax.value_and_grad(loss_fn)(params, batch['x'], batch['y'])

    # All-reduce the gradients across devices
    grads = jax.lax.pmean(grads, axis_name="batch")

    # Optimizer update
    updates, new_opt_state = opt.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)

    return new_params, new_opt_state, loss

# 3) Dummy data setup for N devices
num_devices = jax.local_device_count()
per_device_batch = 8
feature_dim = 16
num_classes = 10

key = jax.random.PRNGKey(42)
# Create one key per device
keys = jax.random.split(key, num_devices)
# Dummy inputs: shape (num_devices, per_device_batch, feature_dim)
x = jnp.stack([jax.random.normal(k, (per_device_batch, feature_dim)) for k in keys])
y = jnp.stack([jax.random.randint(k, (per_device_batch,), 0, num_classes) for k in keys])
batch = {'x': x, 'y': y}

# 4) Initialize & shard (replicate) parameters and optimizer state
init_params = {
    'w': jax.random.normal(key, (feature_dim, num_classes)),
    'b': jnp.zeros((num_classes,))
}
# replicate across devices
params = jax.device_put_replicated(init_params, jax.local_devices())
opt_state = jax.device_put_replicated(opt.init(init_params), jax.local_devices())

# 5) Lower & dump the HLO
lowered = train_step.lower(params, opt_state, batch)
hlo_text = lowered.compiler_ir(dialect="hlo").as_hlo_text()
print(hlo_text)


HloModule pmap_train_step, entry_computation_layout={(f32[1,10]{1,0}, f32[1,16,10]{2,1,0}, f32[1,8,16]{2,1,0}, s32[1,8]{1,0})->(f32[1,10]{1,0}, f32[1,16,10]{2,1,0}, f32[1]{0})}

region_0.31 {
  Arg_0.32 = f32[] parameter(0)
  Arg_1.33 = f32[] parameter(1)
  ROOT maximum.34 = f32[] maximum(Arg_0.32, Arg_1.33)
}

region_1.43 {
  Arg_0.44 = pred[] parameter(0)
  Arg_1.45 = pred[] parameter(1)
  ROOT and.46 = pred[] and(Arg_0.44, Arg_1.45)
}

take_along_axis.47 {
  Arg_1.49 = s32[8,1]{1,0} parameter(1)
  constant.59 = s32[] constant(0)
  broadcast.60 = s32[8,1]{1,0} broadcast(constant.59), dimensions={}
  compare.61 = pred[8,1]{1,0} compare(Arg_1.49, broadcast.60), direction=LT
  constant.56 = s32[] constant(10)
  broadcast.57 = s32[8,1]{1,0} broadcast(constant.56), dimensions={}
  add.62 = s32[8,1]{1,0} add(Arg_1.49, broadcast.57)
  select.63 = s32[8,1]{1,0} select(compare.61, add.62, Arg_1.49)
  reshape.64 = s32[8,1,1]{2,1,0} reshape(select.63)
  constant.54 = s32[] constant(0)
  broadca