# Function Transforms Uncovered

This notebook demonstrates how transformations like `vmap`, `grad` or `jit` modify a Python program under the hood.

To visualize what Nabla does, we use two tools:

- `nabla.xpr(<function>, *<args>)`: shows the intermediate representation of a traced program (inputs → operations → outputs).
- `nabla.jit(<function>, show_graph=True)`: shows the compiled MAX graph (JIT only). The JIT transformation lowers the intermediate representation to optimized machine code.

## 1. Defining and Visualizing a Python Function

In [1]:
import sys

try:
    import nabla as nb
except ImportError:
    # Nabla is not installed: install it with pip for the running interpreter, then retry.
    import subprocess
    packages = ["nabla-ml"]
    subprocess.run([sys.executable, "-m", "pip", "install"] + packages, check=True)
    import nabla as nb

print(
    f"🎉 All libraries loaded successfully! Python {sys.version_info.major}.{sys.version_info.minor}"
)

🎉 All libraries loaded successfully! Python 3.10


In [2]:
def function(input):
    return nb.sum(input * 2 * input, axes=0)


input = nb.randn((5,))
print("Base XPR:", nb.xpr(function, input))
print("\nres:", function(input))

Base XPR: { lambda (a:f32[5]:cpu(0)) ;
  let
    b:f32[1]:cpu(0) = unsqueeze[axes=[-1]]
    c:f32[5]:cpu(0) = mul a b
    d:f32[5]:cpu(0) = mul c a
    e:f32[1]:cpu(0) = sum[axes=[-1]] d
    f:f32[]:cpu(0) = squeeze[axes=[-1]] e
  in f }

res: 25.47862:f32[]:cpu(0)
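
To make the trace easier to read, its five equations can be replayed by hand. The following is a minimal sketch in plain Python, not Nabla's implementation: the helper name `replay_base_xpr` is hypothetical, lists stand in for arrays, and the constant `2.0` is assumed to be the operand of the first `unsqueeze`, since the printed trace does not show it.

```python
def replay_base_xpr(a):
    """Replay the base XPR step by step with plain Python lists (illustrative only)."""
    b = [2.0]                                    # b = unsqueeze(2.0): shape [1] (assumed operand)
    c = [a_i * b[0] for a_i in a]                # c = mul a b   (b broadcasts over a)
    d = [c_i * a_i for c_i, a_i in zip(c, a)]    # d = mul c a
    e = [sum(d)]                                 # e = sum[axes=[-1]] d, kept as shape [1]
    f = e[0]                                     # f = squeeze[axes=[-1]] e -> scalar
    return f


x = [1.0, 2.0, 3.0]          # hypothetical sample input
print(replay_base_xpr(x))    # 2 * (1 + 4 + 9) = 28.0
```

Read this way, the trace computes `2 * sum(x_i ** 2)`: one equation per primitive operation, with every intermediate shape made explicit.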


## 3. Gradient Transformation
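`nb.grad()` transforms the program by adding VJP (vector-Jacobian product) nodes for the backward pass. In the trace below, the constant `1.0` is the seed cotangent of the scalar output, the `broadcast` is the backward rule of `sum`, and the two `mul` branches joined by `add` apply the product rule.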

In [3]:
grad_function = nb.grad(function)
print("Gradient XPR:", nb.xpr(grad_function, input))
print("\nGradient res:", grad_function(input))

Gradient XPR: { lambda (a:f32[5]:cpu(0)) ;
  let
    b:f32[]:cpu(0) = 1.0
    c:f32[]:cpu(0) = shallow_copy b
    d:f32[1]:cpu(0) = unsqueeze[axes=[-1]] c
    e:f32[5]:cpu(0) = broadcast[shape=(5,)] d
    f:f32[5]:cpu(0) = shallow_copy a
    g:f32[1]:cpu(0) = unsqueeze[axes=[-1]]
    h:f32[5]:cpu(0) = mul f g
    i:f32[5]:cpu(0) = mul e h
    j:f32[5]:cpu(0) = mul e f
    k:f32[5]:cpu(0) = mul j g
    l:f32[5]:cpu(0) = add i k
  in l }

Gradient res: [7.0562096 1.6006289 3.914952  8.9635725 7.470232 ]:f32[5]:cpu(0)
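
The gradient trace can be checked the same way. The sketch below replays its equations with plain Python lists and compares the result against central finite differences. It is illustrative only: `base`, `replay_grad_xpr`, and `finite_difference_grad` are hypothetical helper names, and `2.0` is again assumed to be the hidden operand of the operand-less `unsqueeze`.

```python
def base(a):
    """f(a) = sum(a * 2 * a), the function being differentiated."""
    return sum(a_i * 2.0 * a_i for a_i in a)


def replay_grad_xpr(a):
    """Replay the gradient XPR with plain lists (illustrative only)."""
    n = len(a)
    e = [1.0] * n                                   # b..e: seed 1.0 broadcast to the input shape
    f = list(a)                                     # f = shallow_copy a
    g = 2.0                                         # g = unsqueeze(2.0) (assumed operand)
    h = [f_i * g for f_i in f]                      # h = mul f g
    i = [e_i * h_i for e_i, h_i in zip(e, h)]       # i = mul e h
    j = [e_i * f_i for e_i, f_i in zip(e, f)]       # j = mul e f
    k = [j_i * g for j_i in j]                      # k = mul j g
    return [i_i + k_i for i_i, k_i in zip(i, k)]    # l = add i k


def finite_difference_grad(fn, a, eps=1e-6):
    """Central differences, used only as an independent numerical check."""
    grad = []
    for idx in range(len(a)):
        up = list(a)
        up[idx] += eps
        down = list(a)
        down[idx] -= eps
        grad.append((fn(up) - fn(down)) / (2 * eps))
    return grad


x = [0.5, -1.0, 2.0]                       # hypothetical sample input
print(replay_grad_xpr(x))                  # [2.0, -4.0, 8.0], i.e. 4 * x
print(finite_difference_grad(base, x))     # approximately the same values
```

Both branches of the product rule contribute `2 * x`, so the gradient is `4 * x`. This is consistent with the outputs above: dividing the printed gradient by 4 recovers the input, and twice the sum of its squares is approximately 25.4786, the printed `res`.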


## 4. Vectorization Transformation
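`nb.vmap()` adds batch processing. **Blue numbers** in shapes indicate batched dimensions (vs. pink for regular dims). Where colors are not rendered, as in the output below, they appear as ANSI codes: `[94m` precedes a blue (batched) dimension and `[95m` switches back to pink.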

In [4]:
vmapped_grad_function = nb.vmap(nb.grad(function), in_axes=0)
batched_input = nb.randn((3, 5))
print("Vectorized XPR:", nb.xpr(vmapped_grad_function, batched_input))
print("\nVectorized res:", vmapped_grad_function(batched_input))

Vectorized XPR: { lambda (a:[95mf32[3[95m,[95m5]:cpu(0)[0m) ;
  let
    b:[95mf32[]:cpu(0)[0m = 1.0
    c:[95mf32[[94m1[95m]:cpu(0)[0m = unsqueeze_batch_dims[axes=[-1]] b
    d:[95mf32[[94m3[95m]:cpu(0)[0m = broadcast_batch_dims[shape=(3,)] c
    e:[95mf32[[94m3[95m]:cpu(0)[0m = shallow_copy d
    f:[95mf32[[94m3[95m[95m,[95m1]:cpu(0)[0m = unsqueeze[axes=[-1]] e
    g:[95mf32[[94m3[95m[95m,[95m5]:cpu(0)[0m = broadcast[shape=(5,)] f
    h:[95mf32[[94m3[95m[95m,[95m5]:cpu(0)[0m = incr_batch_dim_ctr a
    i:[95mf32[[94m3[95m[95m,[95m5]:cpu(0)[0m = shallow_copy h
    j:[95mf32[1]:cpu(0)[0m = unsqueeze[axes=[-1]]
    k:[95mf32[[94m3[95m[95m,[95m5]:cpu(0)[0m = mul i j
    l:[95mf32[[94m3[95m[95m,[95m5]:cpu(0)[0m = mul g k
    m:[95mf32[[94m3[95m[95m,[95m5]:cpu(0)[0m = mul g i
    n:[95mf32[[94m3[95m[95m,[95m5]:cpu(0)[0m = mul m j
    o:[95mf32[[94m3[95m[95m,[95m5]:cpu(0)[0m = add l n
    p:[95mf32[3[95m,[95m5]:cpu(0
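
Semantically, mapping over axis 0 should give the same result as applying the per-example gradient to each row and stacking the results. The sketch below states that reference behaviour in plain Python. It is not how Nabla implements `vmap`: the trace above tracks the batch dimension inside each operation instead of looping. `grad_single` and `vmap_axis0` are hypothetical names, and the per-example gradient uses the closed form `4 * x` of `sum(x * 2 * x)`.

```python
def grad_single(a):
    """Per-example gradient of sum(a * 2 * a), i.e. 4 * a (closed form)."""
    return [4.0 * a_i for a_i in a]


def vmap_axis0(fn):
    """Hypothetical stand-in for vmap over axis 0: apply fn to each row, then stack."""
    def batched(rows):
        return [fn(row) for row in rows]
    return batched


batch = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]   # hypothetical batch of shape (3, 2)
out = vmap_axis0(grad_single)(batch)
print(out)   # [[4.0, 8.0], [12.0, 16.0], [20.0, 24.0]]
```

The output has the same shape as the batch, which matches the vectorized trace: a `f32[3,5]` input passes through element-wise operations whose results all carry the batched dimension `3` in front.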

## 5. Compilation Transformation with MAX

In [5]:
jitted_vmapped_grad_function = nb.jit(nb.vmap(nb.grad(function)), show_graph=True)
res = jitted_vmapped_grad_function(batched_input)
print("\nJitted Vectorized res:", res)

mo.graph @nabla_graph(%arg0: !mo.tensor<[3, 5], f32>) -> !mo.tensor<[3, 5], f32> attributes {_kernel_library_paths = [], argument_names = ["input0"], result_names = ["output0"]} {
  %0 = mo.chain.create()
  %1 = mo.constant {value = #M.dense_array<2.000000e+00> : tensor<1xf32>} : !mo.tensor<[1], f32>
  %2 = mo.constant {value = #M.dense_array<1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00> : tensor<3x5xf32>} : !mo.tensor<[3, 5], f32>
  %3 = rmo.mul(%2, %arg0) : (!mo.tensor<[3, 5], f32>, !mo.tensor<[3, 5], f32>) -> !mo.tensor<[3, 5], f32>
  %4 = rmo.mul(%3, %1) : (!mo.tensor<[3, 5], f32>, !mo.tensor<[1], f32>) -> !mo.tensor<[3, 5], f32>
  %5 = rmo.mul(%arg0, %1) : (!mo.tensor<[3, 5], f32>, !mo.tensor<[1], f32>) -> !mo.tensor<[3, 5], f32>
  %6 = rmo.mul(%2, %5) : (!mo.tensor<[3, 5], f32>, !mo.tensor<[3, 5], f32>) -> !mo.tensor<[