In [1]:
!pip install max==25.4.0 --index-url https://dl.modular.com/public/nightly/python/simple/

Looking in indexes: https://dl.modular.com/public/nightly/python/simple/
Collecting max==25.4.0
  Downloading https://dl.modular.com/public/nightly/python/max-25.4.0-py3-none-manylinux_2_34_x86_64.whl (285.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m285.0/285.0 MB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: max
Successfully installed max-25.4.0


In [2]:
!git clone https://github.com/modular/mojo-gpu-puzzles

Cloning into 'mojo-gpu-puzzles'...
remote: Enumerating objects: 6079, done.
remote: Counting objects: 100% (535/535), done.
remote: Compressing objects: 100% (135/135), done.
remote: Total 6079 (delta 464), reused 434 (delta 399), pack-reused 5544 (from 3)
Receiving objects: 100% (6079/6079), 146.56 MiB | 15.89 MiB/s, done.
Resolving deltas: 100% (3756/3756), done.


In [3]:
!curl -fsSL https://astral.sh/uv/install.sh | sh

downloading uv 0.8.5 x86_64-unknown-linux-gnu
no checksums to verify
installing to /usr/local/bin
  uv
  uvx
everything's installed!


In [4]:
import max.support.notebook

In [5]:
def save_code_to_file(text: str, filename: str):
    """Persist ``text`` to ``filename`` as UTF-8, replacing any previous contents.

    Args:
        text: The source code (or any text) to write.
        filename: Destination path; its parent directory must already exist.
    """
    with open(filename, mode='w', encoding='utf-8') as out_file:
        out_file.write(text)

In [22]:
# Puzzle 15 solution (row-wise sum on the GPU). The Mojo program is kept in a
# string and written into the cloned repo with save_code_to_file.
mojo_code = """
from sys import sizeof
from testing import assert_equal
from gpu.host import DeviceContext

# ANCHOR: axis_sum
from gpu import thread_idx, block_idx, block_dim, barrier
from layout import Layout, LayoutTensor
from layout.tensor_builder import LayoutTensorBuild as tb


alias TPB = 8
alias BATCH = 4
alias SIZE = 6
alias BLOCKS_PER_GRID = (1, BATCH)
alias THREADS_PER_BLOCK = (TPB, 1)
alias dtype = DType.float32
alias in_layout = Layout.row_major(BATCH, SIZE)
alias out_layout = Layout.row_major(BATCH, 1)


# Sum the first `size` columns of each row of `a` into output[row, 0].
#
# Launch with one block per row (block_idx.y picks the row) and TPB threads per
# block. Each thread first folds columns local_i, local_i + TPB, ... of its row
# into one partial sum, so rows wider than the block are not truncated; the
# partials are then combined by a tree reduction in shared memory.
# The reduction halves `stride` from TPB // 2 down to 1, so TPB must be a
# power of two.
fn axis_sum[
    in_layout: Layout, out_layout: Layout
](
    output: LayoutTensor[mut=True, dtype, out_layout],
    a: LayoutTensor[mut=False, dtype, in_layout],
    size: Int,
):
    local_i = thread_idx.x
    batch = block_idx.y

    cache = tb[dtype]().row_major[TPB]().shared().alloc()

    # Phase 1: per-thread partial sum over a TPB-strided set of columns.
    # Threads with no column in range store 0, padding the unused slots.
    var partial: output.element_type = 0
    col = local_i
    while col < size:
        partial += a[batch, col]
        col += TPB
    cache[local_i] = partial

    barrier()

    # Phase 2: tree reduction. Each round reads first and writes only after
    # a barrier, so no thread overwrites a slot another thread still needs.
    stride = TPB // 2
    while stride > 0:
        var temp_val: output.element_type = 0
        if local_i < stride:
            temp_val = cache[local_i + stride]

        barrier()

        if local_i < stride:
            cache[local_i] += temp_val

        barrier()

        stride //= 2

    # cache[0] now holds the full row sum.
    if local_i == 0:
        output[batch, 0] = cache[0]

# ANCHOR_END: axis_sum


def main():
    with DeviceContext() as ctx:
        out = ctx.enqueue_create_buffer[dtype](BATCH).enqueue_fill(0)
        inp = ctx.enqueue_create_buffer[dtype](BATCH * SIZE).enqueue_fill(0)
        with inp.map_to_host() as inp_host:
            for row in range(BATCH):
                for col in range(SIZE):
                    inp_host[row * SIZE + col] = row * SIZE + col

        # The kernel writes into `out`, so its tensor view must be mutable.
        out_tensor = LayoutTensor[mut=True, dtype, out_layout](
            out.unsafe_ptr()
        )
        inp_tensor = LayoutTensor[mut=False, dtype, in_layout](inp.unsafe_ptr())

        ctx.enqueue_function[axis_sum[in_layout, out_layout]](
            out_tensor,
            inp_tensor,
            SIZE,
            grid_dim=BLOCKS_PER_GRID,
            block_dim=THREADS_PER_BLOCK,
        )

        # CPU reference: plain row sums of the same input.
        expected = ctx.enqueue_create_host_buffer[dtype](BATCH).enqueue_fill(0)
        with inp.map_to_host() as inp_host:
            for row in range(BATCH):
                for col in range(SIZE):
                    expected[row] += inp_host[row * SIZE + col]

        ctx.synchronize()

        with out.map_to_host() as out_host:
            print("out:", out)
            print("expected:", expected)
            for i in range(BATCH):
                assert_equal(out_host[i], expected[i])

"""

In [23]:
# Overwrite the repo's p15.mojo with the program held in `mojo_code`.
save_code_to_file(
    text=mojo_code,
    filename="/content/mojo-gpu-puzzles/problems/p15/p15.mojo",
)

In [24]:
# Run puzzle 15 through the repo's `poe` task (invoked via uv); the recorded
# output shows the task executes `mojo problems/p15/p15.mojo`.
!cd /content/mojo-gpu-puzzles && uv run poe p15

Poe => mojo problems/p15/p15.mojo
out: DeviceBuffer([15.0, 51.0, 87.0, 123.0])
expected: HostBuffer([15.0, 51.0, 87.0, 123.0])
