# Understanding Nabla - Program Transformations (Part 1)

This notebook demonstrates how transformations such as `vmap`, `grad`, and `jit` rewrite a Python program in Nabla.

To see what Nabla does under the hood, we use two tools:

- `nabla.xpr(<function>, *<args>)` - Shows the intermediate representation of a traced program: inputs → operations → outputs
- `nabla.jit(<function>, show_graph=True)` - Shows the compiled MAX graph (JIT only). The JIT transformation turns the intermediate representation into optimized machine code.

## 1. Defining and Visualizing a Python Function

In [1]:
# Installation
import sys

IN_COLAB = "google.colab" in sys.modules

try:
    import nabla as nb
except ImportError:
    import subprocess

    subprocess.run(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "modular",
            "--extra-index-url",
            "https://download.pytorch.org/whl/cpu",
            "--index-url",
            "https://dl.modular.com/public/nightly/python/simple/",
        ],
        check=True,
    )
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "nabla-ml", "--upgrade"], check=True
    )
    import nabla as nb

print(
    f"🎉 Nabla is ready! Running on Python {sys.version_info.major}.{sys.version_info.minor}"
)

🎉 Nabla is ready! Running on Python 3.12


In [None]:
def function(input):
    return nb.sum(input * 2 * input, axes=0)


input = nb.randn((5,))
print("Base XPR:", nb.xpr(function, input))
print("\nres:", function(input))

Base XPR: { lambda (a:f32[5]) ;
  let
    b:f32[] = 2.0
    c:f32[5] = mul a b
    d:f32[5] = mul c a
    e:f32[1] = sum[axes=[-1]] d
    f:f32[] = squeeze[axes=[-1]] e
  in f }

res: 25.47862:f32[]
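
The trace above is a straight-line program: each `let` binding applies one primitive to earlier values. As a rough illustration, the same bindings can be replayed by hand in plain Python. This is only a sketch for reading the XPR, not how Nabla executes it; `replay_base_xpr` and the sample values are hypothetical names and data chosen for this example.

```python
def replay_base_xpr(a):
    """Replay the let-bindings of the base XPR on a plain list of floats."""
    b = 2.0                                  # b = 2.0
    c = [ai * b for ai in a]                 # c = mul a b
    d = [ci * ai for ci, ai in zip(c, a)]    # d = mul c a
    e = [sum(d)]                             # e = sum[axes=[-1]] d  (length-1 result)
    f = e[0]                                 # f = squeeze[axes=[-1]] e
    return f


# Illustrative input (not the random values used in the notebook).
values = [1.0, -2.0, 0.5, 3.0, 0.0]
result = replay_base_xpr(values)

# Direct evaluation of sum(2 * x * x) for comparison.
expected = sum(2.0 * x * x for x in values)
print(result, expected)
```

Both numbers agree, which shows that the five bindings together compute the sum of `2 * x * x` over the input.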


## 3. Gradient Transformation
`nb.grad()` transforms the program by adding the VJP (vector-Jacobian product) nodes that make up the backward pass.

In [9]:
grad_function = nb.grad(function)
print("Gradient XPR:", nb.xpr(grad_function, input))
print("\nGradient res:", grad_function(input))

Gradient XPR: { lambda (a:f32[5]) ;
  let
    b:f32[] = array(1., dtype=float32)
    c:f32[1] = unsqueeze[axes=[-1]] b
    d:f32[5] = broadcast[shape=(5,)] c
    e:f32[5] = shallow_copy a
    f:f32[] = 2.0
    g:f32[5] = mul e f
    h:f32[5] = mul d g
    i:f32[5] = mul d e
    j:f32[5] = mul i f
    k:f32[5] = add h j
  in k }

Gradient res: [7.0562096 1.6006289 3.914952  8.9635725 7.470232 ]:f32[5]
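
One way to read the gradient trace: `b` looks like the cotangent seed `1.0` for the scalar output, `d` is that seed broadcast back to the input shape, and `h` and `j` are the two product-rule terms of `a * 2 * a`. The sketch below replays those bindings in plain Python and compares the result with a central finite difference. It is a hand-written illustration under that reading, not Nabla's implementation; `replay_grad_xpr`, `central_difference_grad`, and the sample values are hypothetical.

```python
def f(a):
    """Plain-Python version of the notebook's function: sum(a * 2 * a)."""
    return sum(x * 2.0 * x for x in a)


def replay_grad_xpr(a):
    """Replay the let-bindings of the gradient XPR on a list of floats."""
    b = 1.0                                  # b = array(1.): seed for the scalar output
    d = [b] * len(a)                         # c, d = unsqueeze + broadcast to the input shape
    e = list(a)                              # e = shallow_copy a
    f_const = 2.0                            # f = 2.0
    g = [ei * f_const for ei in e]           # g = mul e f
    h = [di * gi for di, gi in zip(d, g)]    # h = mul d g
    i = [di * ei for di, ei in zip(d, e)]    # i = mul d e
    j = [ii * f_const for ii in i]           # j = mul i f
    k = [hi + ji for hi, ji in zip(h, j)]    # k = add h j
    return k


def central_difference_grad(fn, a, eps=1e-4):
    """Numerical gradient of fn at a, one coordinate at a time."""
    grad = []
    for idx in range(len(a)):
        plus = list(a)
        minus = list(a)
        plus[idx] += eps
        minus[idx] -= eps
        grad.append((fn(plus) - fn(minus)) / (2.0 * eps))
    return grad


# Illustrative input (not the random values used in the notebook).
values = [1.0, -2.0, 0.5, 3.0, 0.0]
grad_trace = replay_grad_xpr(values)
grad_numeric = central_difference_grad(f, values)
max_err = max(abs(t - n) for t, n in zip(grad_trace, grad_numeric))
print(grad_trace, max_err)
```

The replayed trace returns `4 * x` for every element, which matches the analytic derivative of `2 * x * x`.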


## 4. Vectorization Transformation
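`nb.vmap()` maps the function over a batch axis. In the notebook's colored output, **blue numbers** in a shape mark batched dimensions and pink numbers mark regular dimensions; in the trace below, the leading `3` in the shapes of the intermediate values `b` through `n` is the batched dimension.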

In [10]:
vmapped_grad_function = nb.vmap(nb.grad(function), in_axes=0)
batched_input = nb.randn((3, 5))
print("Vectorized XPR:", nb.xpr(vmapped_grad_function, batched_input))
print("\nVectorized res:", vmapped_grad_function(batched_input))

Vectorized XPR: { lambda (a:f32[3,5]) ;
  let
    b:f32[3] = shallow_copy
    c:f32[3,1] = unsqueeze[axes=[-1]] b
    d:f32[3,5] = broadcast[shape=(5,)] c
    e:f32[3,5] = incr_batch_dim_ctr a
    f:f32[3,5] = permute_batch_dims[axes=(-1,)] e
    g:f32[3,5] = shallow_copy f
    h:f32[] = 2.0
    i:f32[3,5] = mul g h
    j:f32[3,5] = mul d i
    k:f32[3,5] = mul d g
    l:f32[3,5] = mul k h
    m:f32[3,5] = add j l
    n:f32[3,5] = permute_batch_dims[axes=(-1,)] m
    o:f32[3,5] = decr_batch_dim_ctr n
  in o }

Vectorized res: [[ 7.0562096   1.6006289   3.914952    8.9635725   7.470232  ]
 [-3.9091115   3.800
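
Semantically, mapping the gradient over axis 0 should give the same result as applying the per-row gradient to each row separately. The sketch below spells that out with an explicit loop in plain Python. It is only a reference for the expected result: the trace above suggests Nabla runs the same elementwise operations on the whole `(3, 5)` array while tracking the batch dimension, rather than looping. `loop_vmap_axis0`, `grad_row`, and the sample batch are hypothetical.

```python
def grad_row(a):
    """Closed-form gradient of sum(2 * x * x) for one row: 4 * x."""
    return [4.0 * x for x in a]


def loop_vmap_axis0(fn):
    """Loop-based stand-in for batching fn over the leading axis."""
    def batched(rows):
        return [fn(row) for row in rows]
    return batched


# Illustrative batch of shape (3, 5), not the notebook's random input.
batch = [
    [1.0, -2.0, 0.5, 3.0, 0.0],
    [0.0, 1.0, 2.0, 3.0, 4.0],
    [-1.0, -1.0, -1.0, -1.0, -1.0],
]
batched_grad = loop_vmap_axis0(grad_row)(batch)
for row in batched_grad:
    print(row)
```

Each output row depends only on the matching input row, which is why the output keeps the input's `(3, 5)` shape.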

## 5. Compilation Transformation with MAX
`nb.jit()` compiles the traced program with MAX. With `show_graph=True`, it also prints the compiled MAX graph.

In [11]:
jitted_vmapped_grad_function = nb.jit(nb.vmap(nb.grad(function)), show_graph=True)
res = jitted_vmapped_grad_function(batched_input)
print("\nJitted Vectorized res:", res)

mo.graph @nabla_graph(%arg0: !mo.tensor<[3, 5], f32, cpu:0>) -> !mo.tensor<[3, 5], f32, cpu:0> attributes {argument_names = ["input0"], inputParams = #kgen<param.decls[]>, result_names = ["output0"]} {
  %0 = mo.chain.create()
  %1 = mo.constant {value = #M.dense_array<2.000000e+00> : tensor<f32>} : !mo.tensor<[], f32, cpu:0>
  %2 = mo.constant {value = #M.dense_array<1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00, 1.000000e+00> : tensor<3x5xf32>} : !mo.tensor<[3, 5], f32, cpu:0>
  %3 = rmo.mul(%2, %arg0) : (!mo.tensor<[3, 5], f32, cpu:0>, !mo.tensor<[3, 5], f32, cpu:0>) -> !mo.tensor<[3, 5], f32, cpu:0>
  %4 = rmo.mul(%3, %1) : (!mo.tensor<[3, 5], f32, cpu:0>, !mo.tensor<[], f32, cpu:0>) -> !mo.tensor<[3, 5], f32, cpu:0>
  %5 = rmo.mul(%arg0, %1) : (!mo.tensor<[3, 5], f32, cpu:0>, !mo.tensor<[], f32, cpu:0>) -> !mo.tensor<[3, 5], f32, cpu: