In [1]:
!pip install max --index-url https://dl.modular.com/public/nightly/python/simple/

Looking in indexes: https://dl.modular.com/public/nightly/python/simple/
Collecting max
  Downloading https://dl.modular.com/public/nightly/python/max-25.4.0-py3-none-manylinux_2_34_x86_64.whl (285.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m285.0/285.0 MB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: max
Successfully installed max-25.4.0


In [2]:
!git clone https://github.com/modular/mojo-gpu-puzzles

Cloning into 'mojo-gpu-puzzles'...
remote: Enumerating objects: 5684, done.[K
remote: Counting objects: 100% (803/803), done.[K
remote: Compressing objects: 100% (258/258), done.[K
remote: Total 5684 (delta 647), reused 629 (delta 519), pack-reused 4881 (from 2)[K
Receiving objects: 100% (5684/5684), 101.45 MiB | 27.38 MiB/s, done.
Resolving deltas: 100% (3559/3559), done.


In [3]:
!curl -fsSL https://astral.sh/uv/install.sh | sh

downloading uv 0.8.3 x86_64-unknown-linux-gnu
no checksums to verify
installing to /usr/local/bin
  uv
  uvx
everything's installed!


In [4]:
import max.support.notebook

In [5]:
def save_code_to_file(text: str, filename: str):
    with open(filename, 'w', encoding='utf-8') as file:
        file.write(text)

In [6]:
mojo_code = """
from gpu import thread_idx, block_idx, block_dim, barrier
from gpu.host import DeviceContext
from layout import Layout, LayoutTensor
from layout.tensor_builder import LayoutTensorBuild as tb
from sys import sizeof, argv
from math import log2
from testing import assert_equal

# ANCHOR: prefix_sum_simple
alias TPB = 8
alias SIZE = 8
alias BLOCKS_PER_GRID = (1, 1)
alias THREADS_PER_BLOCK = (TPB, 1)
alias dtype = DType.float32
alias layout = Layout.row_major(SIZE)


fn prefix_sum_simple[
    layout: Layout
](
    output: LayoutTensor[mut=False, dtype, layout],
    a: LayoutTensor[mut=False, dtype, layout],
    size: Int,
):
    global_i = block_dim.x * block_idx.x + thread_idx.x
    local_i = thread_idx.x
    # FILL ME IN (roughly 18 lines) ###################### SIMPLE ######################
    shared = tb[dtype]().row_major[TPB]().shared().alloc()
    if global_i < SIZE:
      shared[local_i] = a[global_i]

    barrier()

    ####### logic here ######

    offset = 1
    for i in range(Int(log2(Scalar[dtype](TPB)))):
      var current_val: output.element_type = 0
      if local_i >= offset and local_i < size:
        current_val = shared[local_i - offset]

      barrier()
      if local_i >= offset and local_i < size:
        shared[local_i] += current_val

      barrier()
      offset *= 2

    if global_i < size:
      output[global_i]  = shared[local_i]


# ANCHOR_END: prefix_sum_simple

# ANCHOR: prefix_sum_complete
alias SIZE_2 = 15
alias BLOCKS_PER_GRID_2 = (2, 1)
alias THREADS_PER_BLOCK_2 = (TPB, 1)
alias EXTENDED_SIZE = SIZE_2 + 2  # up to 2 blocks
alias extended_layout = Layout.row_major(EXTENDED_SIZE)


# Kernel 1: Compute local prefix sums and store block sums in out
fn prefix_sum_local_phase[
    out_layout: Layout, in_layout: Layout
](
    output: LayoutTensor[mut=False, dtype, out_layout],
    a: LayoutTensor[mut=False, dtype, in_layout],
    size: Int,
):
    global_i = block_dim.x * block_idx.x + thread_idx.x
    local_i = thread_idx.x
    # FILL ME IN (roughly 20 lines) ################## COMPLETE ###################
    shared = tb[dtype]().row_major[TPB]().shared().alloc()
    if global_i < size:
      shared[local_i] = a[global_i]

    barrier()

    ##### logic here #####
    offset = 1
    for i in range(Int(log2(Scalar[dtype](TPB)))):
      var current_val: output.element_type = 0
      if local_i >= offset and local_i < TPB:
        current_val = shared[local_i - offset]

      barrier()

      if local_i >= offset and local_i < TPB:
        shared[local_i] += current_val

      barrier()
      offset *= 2

    if global_i < size:
      output[global_i] = shared[local_i]

    if local_i == TPB -1:
      output[size + block_idx.x] = shared[local_i]


# Kernel 2: Add block sums to their respective blocks
fn prefix_sum_block_sum_phase[
    layout: Layout
](output: LayoutTensor[mut=False, dtype, layout], size: Int):
    global_i = block_dim.x * block_idx.x + thread_idx.x
    # FILL ME IN (roughly 3 lines)
    if block_idx.x > 0 and global_i < size:
      prev_block_sum  = output[size + block_idx.x  - 1]
      output[global_i] += prev_block_sum

# ANCHOR_END: prefix_sum_complete


def main():
    with DeviceContext() as ctx:
        if len(argv()) != 2 or argv()[1] not in [
            "--simple",
            "--block-boundary",
        ]:
            raise Error(
                "Expected one command-line argument: '--simple' or"
                " '--block-boundary'"
            )

        use_simple = argv()[1] == "--simple"

        size = SIZE if use_simple else SIZE_2
        num_blocks = (size + TPB - 1) // TPB

        if not use_simple and num_blocks > EXTENDED_SIZE - SIZE_2:
            raise Error("Extended buffer too small for the number of blocks")

        buffer_size = size if use_simple else EXTENDED_SIZE
        out = ctx.enqueue_create_buffer[dtype](buffer_size).enqueue_fill(0)
        a = ctx.enqueue_create_buffer[dtype](size).enqueue_fill(0)

        with a.map_to_host() as a_host:
            for i in range(size):
                a_host[i] = i

        a_tensor = LayoutTensor[mut=False, dtype, layout](a.unsafe_ptr())

        if use_simple:
            out_tensor = LayoutTensor[mut=False, dtype, layout](
                out.unsafe_ptr()
            )

            ctx.enqueue_function[prefix_sum_simple[layout]](
                out_tensor,
                a_tensor,
                size,
                grid_dim=BLOCKS_PER_GRID,
                block_dim=THREADS_PER_BLOCK,
            )
        else:
            var out_tensor = LayoutTensor[mut=False, dtype, extended_layout](
                out.unsafe_ptr()
            )

            # ANCHOR: prefix_sum_complete_block_level_sync
            # Phase 1: Local prefix sums
            ctx.enqueue_function[
                prefix_sum_local_phase[extended_layout, extended_layout]
            ](
                out_tensor,
                a_tensor,
                size,
                grid_dim=BLOCKS_PER_GRID_2,
                block_dim=THREADS_PER_BLOCK_2,
            )

            # Wait for all `blocks` to complete with using host `ctx.synchronize()`
            # Note this is in contrast with using `barrier()` in the kernel
            # which is a synchronization point for all threads in the same block and not across blocks.
            ctx.synchronize()

            # Phase 2: Add block sums
            ctx.enqueue_function[prefix_sum_block_sum_phase[extended_layout]](
                out_tensor,
                size,
                grid_dim=BLOCKS_PER_GRID_2,
                block_dim=THREADS_PER_BLOCK_2,
            )
            # ANCHOR_END: prefix_sum_complete_block_level_sync

        # Verify results for both cases
        expected = ctx.enqueue_create_host_buffer[dtype](size).enqueue_fill(0)
        ctx.synchronize()

        with a.map_to_host() as a_host:
            expected[0] = a_host[0]
            for i in range(1, size):
                expected[i] = expected[i - 1] + a_host[i]

        with out.map_to_host() as out_host:
            if not use_simple:
                print(
                    "Note: we print the extended buffer here, but we only need"
                    " to print the first `size` elements"
                )

            print("out:", out_host)
            print("expected:", expected)
            # Here we need to use the size of the original array, not the extended one
            size = size if use_simple else SIZE_2
            for i in range(size):
                assert_equal(out_host[i], expected[i])
"""

In [7]:
save_code_to_file(mojo_code, "/content/mojo-gpu-puzzles/problems/p12/p12.mojo")

In [9]:
!cd /content/mojo-gpu-puzzles && uv run poe p12 --block-boundary

[37mPoe =>[0m [94mmojo problems/p12/p12.mojo --block-boundary[0m
Note: we print the extended buffer here, but we only need to print the first `size` elements
out: HostBuffer([0.0, 1.0, 3.0, 6.0, 10.0, 15.0, 21.0, 28.0, 36.0, 45.0, 55.0, 66.0, 78.0, 91.0, 105.0, 28.0, 77.0])
expected: HostBuffer([0.0, 1.0, 3.0, 6.0, 10.0, 15.0, 21.0, 28.0, 36.0, 45.0, 55.0, 66.0, 78.0, 91.0, 105.0])
