In [1]:
!pip install max --index-url https://dl.modular.com/public/nightly/python/simple/

Looking in indexes: https://dl.modular.com/public/nightly/python/simple/
Collecting max
  Downloading https://dl.modular.com/public/nightly/python/max-25.4.0.dev2025060805-py3-none-manylinux_2_34_x86_64.whl (285.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m285.2/285.2 MB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: max
Successfully installed max-25.4.0.dev2025060805


In [2]:
!git clone https://github.com/modular/mojo-gpu-puzzles

Cloning into 'mojo-gpu-puzzles'...
remote: Enumerating objects: 4059, done.[K
remote: Counting objects: 100% (166/166), done.[K
remote: Compressing objects: 100% (39/39), done.[K
remote: Total 4059 (delta 151), reused 127 (delta 127), pack-reused 3893 (from 2)[K
Receiving objects: 100% (4059/4059), 94.69 MiB | 16.13 MiB/s, done.
Resolving deltas: 100% (2489/2489), done.


In [3]:
!curl -fsSL https://astral.sh/uv/install.sh | sh

downloading uv 0.7.12 x86_64-unknown-linux-gnu
no checksums to verify
installing to /usr/local/bin
  uv
  uvx
everything's installed!


In [4]:
import max.support.notebook

In [5]:
def save_code_to_file(text: str, filename: str):
    """Write ``text`` to ``filename`` as UTF-8, replacing any existing content.

    Args:
        text: The full contents to store.
        filename: Destination path; its parent directory must already exist.
    """
    with open(filename, mode="w", encoding="utf-8") as sink:
        sink.write(text)

In [6]:
# Mojo source for puzzle p05 (broadcast add); later cells write it to
# problems/p05/p05.mojo and run it with `uv run poe p05`.
# Kernel and host reference share one convention: the SIZE x SIZE output is
# row-major with output[row, col] = a[col] + b[row] (row = thread_idx.y,
# col = thread_idx.x). `a` and `b` get different values so that a kernel with
# swapped row/col indexing fails the check instead of passing by symmetry.
mojo_code = """
from memory import UnsafePointer
from gpu import thread_idx, block_dim, block_idx
from gpu.host import DeviceContext, HostBuffer
from testing import assert_equal

# ANCHOR: broadcast_add
alias SIZE = 2
alias BLOCKS_PER_GRID = 1
alias THREADS_PER_BLOCK = (3, 3)
alias dtype = DType.float32


fn broadcast_add(
    output: UnsafePointer[Scalar[dtype]],
    a: UnsafePointer[Scalar[dtype]],
    b: UnsafePointer[Scalar[dtype]],
    size: Int,
):
    row = thread_idx.y
    col = thread_idx.x
    # The 3x3 block has more threads than the 2x2 output; extra threads do nothing.
    if row < size and col < size:
        # a varies along columns and b along rows; each is reused across the other axis.
        output[row * size + col] = a[col] + b[row]


# ANCHOR_END: broadcast_add
def main():
    with DeviceContext() as ctx:
        out = ctx.enqueue_create_buffer[dtype](SIZE * SIZE).enqueue_fill(0)
        expected = ctx.enqueue_create_host_buffer[dtype](
            SIZE * SIZE
        ).enqueue_fill(0)
        a = ctx.enqueue_create_buffer[dtype](SIZE).enqueue_fill(0)
        b = ctx.enqueue_create_buffer[dtype](SIZE).enqueue_fill(0)
        with a.map_to_host() as a_host, b.map_to_host() as b_host:
            # Distinct values for a and b: with a == b the result is symmetric
            # and a kernel that swaps row/col would still pass.
            for i in range(SIZE):
                a_host[i] = i
                b_host[i] = i * 10

            # Same convention as the kernel: output[row, col] = a[col] + b[row].
            for row in range(SIZE):
                for col in range(SIZE):
                    expected[row * SIZE + col] = a_host[col] + b_host[row]

        ctx.enqueue_function[broadcast_add](
            out.unsafe_ptr(),
            a.unsafe_ptr(),
            b.unsafe_ptr(),
            SIZE,
            grid_dim=BLOCKS_PER_GRID,
            block_dim=THREADS_PER_BLOCK,
        )

        ctx.synchronize()

        with out.map_to_host() as out_host:
            print("out:", out_host)
            print("expected:", expected)
            for i in range(SIZE):
                for j in range(SIZE):
                    assert_equal(out_host[i * SIZE + j], expected[i * SIZE + j])


"""

In [7]:
# Overwrite the puzzle's p05 source file with the solution defined above.
save_code_to_file(
    text=mojo_code,
    filename="/content/mojo-gpu-puzzles/problems/p05/p05.mojo",
)

In [8]:
# Run the p05 check with uv from inside the cloned repo. `poe` is presumably a task
# runner whose `p05` task is declared in the repo's project config (TODO confirm).
# On the first run, uv creates the repo's .venv, as the output below shows.
!cd /content/mojo-gpu-puzzles && uv run poe p05

Using CPython 3.11.13 interpreter at: [36m/usr/bin/python3[39m
Creating virtual environment at: [36m.venv[39m
[2K[37m⠙[0m [2mPreparing packages...[0m (0/7)
[2K[1A[37m⠙[0m [2mPreparing packages...[0m (0/7)
[2K[1A[37m⠙[0m [2mPreparing packages...[0m (0/7)
[2mtqdm                [0m [32m[2m------------------------------[0m[0m     0 B/76.70 KiB
[2K[2A[37m⠙[0m [2mPreparing packages...[0m (0/7)
[2mtqdm                [0m [32m[2m------------------------------[0m[0m     0 B/76.70 KiB
[2K[2A[37m⠙[0m [2mPreparing packages...[0m (0/7)
[2mtqdm                [0m [32m[2m------------------------------[0m[0m     0 B/76.70 KiB
[2K[2A[37m⠙[0m [2mPreparing packages...[0m (0/7)
[2mtqdm                [0m [32m[2m------------------------------[0m[0m     0 B/76.70 KiB
[2mclick               [0m [32m[2m------------------------------[0m[0m     0 B/99.82 KiB
[2K[3A[37m⠙[0m [2mPreparing packages...[0m (0/7)
[2mtqdm                [0m 