In [1]:
!pip install max==25.4.0 --index-url https://dl.modular.com/public/nightly/python/simple/

Looking in indexes: https://dl.modular.com/public/nightly/python/simple/
Collecting max==25.4.0
  Downloading https://dl.modular.com/public/nightly/python/max-25.4.0-py3-none-manylinux_2_34_x86_64.whl (285.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m285.0/285.0 MB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: max
Successfully installed max-25.4.0


In [2]:
# Fetch the GPU puzzles repository into ./mojo-gpu-puzzles.
# NOTE(review): later cells address it as /content/mojo-gpu-puzzles, so this
# assumes the notebook's working directory is /content - confirm when running
# outside Colab.
!git clone https://github.com/modular/mojo-gpu-puzzles

Cloning into 'mojo-gpu-puzzles'...
remote: Enumerating objects: 6332, done.[K
remote: Counting objects: 100% (481/481), done.[K
remote: Compressing objects: 100% (65/65), done.[K
remote: Total 6332 (delta 449), reused 416 (delta 416), pack-reused 5851 (from 3)[K
Receiving objects: 100% (6332/6332), 148.64 MiB | 14.78 MiB/s, done.
Resolving deltas: 100% (3923/3923), done.


In [3]:
# Install the `uv` tool (used below for `uv run`) via the upstream installer.
# NOTE(review): this pipes a remote script straight into `sh` without pinning
# a version or verifying a checksum - acceptable for a throwaway runtime,
# worth revisiting anywhere else.
!curl -fsSL https://astral.sh/uv/install.sh | sh

downloading uv 0.8.14 x86_64-unknown-linux-gnu
no checksums to verify
installing to /usr/local/bin
  uv
  uvx
everything's installed!


In [4]:
import max.support.notebook

In [5]:
def save_code_to_file(text: str, filename: str):
    """Write ``text`` to ``filename`` as UTF-8, replacing any existing file.

    Args:
        text: Full contents to store.
        filename: Destination path; its parent directory must already exist.

    Raises:
        OSError: If the destination cannot be opened for writing.
    """
    sink = open(filename, mode='w', encoding='utf-8')
    try:
        sink.write(text)
    finally:
        # Always release the handle, even if the write fails part-way.
        sink.close()

In [17]:
mojo_code = """
from math import ceildiv
from gpu import thread_idx, block_idx, block_dim, grid_dim, barrier
from gpu.host import DeviceContext
from layout import Layout, LayoutTensor
from layout.tensor_builder import LayoutTensorBuild as tb
from sys.info import sizeof
from sys.arg import argv
from testing import assert_equal

# ANCHOR: embedding_kernel_coalesced
alias THREADS_PER_BLOCK = 256


fn embedding_kernel_coalesced[
    indices_layout: Layout,
    weights_layout: Layout,
    out_layout: Layout,
    batch_size: Int,
    seq_len: Int,
    vocab_size: Int,
    embed_dim: Int,
    dtype: DType = DType.float32,
](
    output: LayoutTensor[mut=True, dtype, out_layout],
    indices: LayoutTensor[mut=True, DType.int32, indices_layout],
    weights: LayoutTensor[mut=True, dtype, weights_layout],
):
    # Embedding lookup on a 1D launch grid: one thread per output element.
    #
    # output  is indexed as [batch, seq, embed]
    # indices is indexed as [batch, seq] and holds token ids
    # weights is indexed as [token, embed]
    #
    # Every in-range thread writes exactly one element: the matching weight
    # when the token id lies in [0, vocab_size), otherwise 0.

    flat_idx = block_idx.x * block_dim.x + thread_idx.x
    n_outputs = batch_size * seq_len * embed_dim

    # The grid is rounded up to whole blocks, so trailing threads have no work.
    if flat_idx >= n_outputs:
        return

    # Unflatten the linear index. embed is the fastest-varying coordinate, so
    # neighbouring threads handle neighbouring embed_idx values of the same
    # (batch, seq) position.
    embed_idx = flat_idx % embed_dim
    position = flat_idx // embed_dim
    seq_idx = position % seq_len
    batch_idx = position // seq_len

    token = Int(indices[batch_idx, seq_idx])

    # Guard first: ids outside the vocabulary produce a zero element.
    if token < 0 or token >= vocab_size:
        output[batch_idx, seq_idx, embed_idx] = 0
    else:
        output[batch_idx, seq_idx, embed_idx] = weights[token, embed_idx]

# ANCHOR_END: embedding_kernel_coalesced


# ANCHOR: embedding_kernel_2d
fn embedding_kernel_2d[
    indices_layout: Layout,
    weights_layout: Layout,
    out_layout: Layout,
    batch_size: Int,
    seq_len: Int,
    vocab_size: Int,
    embed_dim: Int,
    dtype: DType = DType.float32,
](
    output: LayoutTensor[mut=True, dtype, out_layout],
    indices: LayoutTensor[mut=True, DType.int32, indices_layout],
    weights: LayoutTensor[mut=True, dtype, weights_layout],
):
    # Embedding lookup on a 2D launch grid.
    #
    # The x axis of the grid enumerates flattened (batch, seq) positions and
    # the y axis enumerates the embedding dimension. Each in-range thread
    # writes one output element: the matching weight when the token id lies
    # in [0, vocab_size), otherwise 0.

    position = block_idx.x * block_dim.x + thread_idx.x
    embed_idx = block_idx.y * block_dim.y + thread_idx.y
    n_positions = batch_size * seq_len

    # Both grid axes are rounded up to whole blocks; drop the overhang.
    if position >= n_positions:
        return
    if embed_idx >= embed_dim:
        return

    # Split the flattened position back into its (batch, seq) coordinates.
    seq_idx = position % seq_len
    batch_idx = position // seq_len

    token = Int(indices[batch_idx, seq_idx])

    # Guard first: ids outside the vocabulary produce a zero element.
    if token < 0 or token >= vocab_size:
        output[batch_idx, seq_idx, embed_idx] = 0
    else:
        output[batch_idx, seq_idx, embed_idx] = weights[token, embed_idx]

# ANCHOR_END: embedding_kernel_2d

import compiler
from runtime.asyncrt import DeviceContextPtr
from tensor import InputTensor, OutputTensor
from memory import UnsafePointer
from gpu.host import DeviceBuffer


# Custom op "embedding": output[b, s, :] = weights[indices[b, s], :].
#
# Token ids outside [0, vocab_size) yield an all-zero row on every target.
# The GPU path launches embedding_kernel_coalesced on a 1D grid with one
# thread per output element; the CPU path is a plain nested loop.
@compiler.register("embedding")
struct EmbeddingCustomOp:
    # Parameters:
    #   target: "gpu" or "cpu"; any other value raises an Error.
    #   batch_size, seq_len, vocab_size, embed_dim: static problem shape.
    # Args:
    #   output:  [batch_size, seq_len, embed_dim], float32, written in full.
    #   indices: [batch_size, seq_len], int32 token ids.
    #   weights: [vocab_size, embed_dim], same dtype as output.
    #   ctx:     device context used to enqueue GPU work.
    @staticmethod
    fn execute[
        target: StaticString,
        batch_size: Int,
        seq_len: Int,
        vocab_size: Int,
        embed_dim: Int,
    ](
        output: OutputTensor[
            dtype = DType.float32, rank=3
        ],  # [batch_size, seq_len, embed_dim]
        indices: InputTensor[
            dtype = DType.int32, rank=2
        ],  # [batch_size, seq_len]
        weights: InputTensor[
            dtype = output.dtype, rank=2
        ],  # [vocab_size, embed_dim]
        ctx: DeviceContextPtr,
    ) raises:
        output_tensor = output.to_layout_tensor()
        indices_tensor = indices.to_layout_tensor()
        weights_tensor = weights.to_layout_tensor()

        # Layouts are forwarded to the kernel as compile-time parameters.
        alias indices_layout = indices_tensor.layout
        alias weights_layout = weights_tensor.layout
        alias out_layout = output_tensor.layout

        @parameter
        if target == "gpu":
            gpu_ctx = ctx.get_device_context()

            # Zero out output tensor through a non-owning view of its memory.
            # The kernel also stores 0 for out-of-range tokens, so this is a
            # defensive clear rather than a correctness requirement.
            gpu_ctx.enqueue_memset(
                DeviceBuffer[output.dtype](
                    gpu_ctx,
                    rebind[UnsafePointer[Scalar[output.dtype]]](
                        output_tensor.ptr
                    ),
                    batch_size * seq_len * embed_dim,
                    owning=False,
                ),
                0,
            )

            # Calculate 1D grid dimensions (matching kernel's flat indexing)
            total_elements = batch_size * seq_len * embed_dim
            blocks = max(1, ceildiv(total_elements, THREADS_PER_BLOCK))

            # Compile and launch optimized kernel
            compiled_kernel = gpu_ctx.compile_function[
                embedding_kernel_coalesced[
                    indices_layout,
                    weights_layout,
                    out_layout,
                    batch_size,
                    seq_len,
                    vocab_size,
                    embed_dim,
                    output.dtype,
                ]
            ]()

            gpu_ctx.enqueue_function(
                compiled_kernel,
                output_tensor,
                indices_tensor,
                weights_tensor,
                grid_dim=(blocks,),
                block_dim=(THREADS_PER_BLOCK,),
            )

        elif target == "cpu":
            for batch in range(batch_size):
                for seq in range(seq_len):
                    token_idx_val = Int(indices_tensor[batch, seq])
                    if token_idx_val >= 0 and token_idx_val < vocab_size:
                        for emb in range(embed_dim):
                            output_tensor[batch, seq, emb] = weights_tensor[
                                token_idx_val, emb
                            ]
                    else:
                        # Match the GPU path: an out-of-range token id gives
                        # a zero row instead of leaving the row unwritten.
                        for emb in range(embed_dim):
                            output_tensor[batch, seq, emb] = 0
        else:
            raise Error("Unsupported target: " + target)


# Custom op "embedding_2d": output[b, s, :] = weights[indices[b, s], :].
#
# Same contract as the "embedding" op, but the GPU path launches
# embedding_kernel_2d on a 2D grid (x = flattened batch*seq position,
# y = embedding dimension). Token ids outside [0, vocab_size) yield an
# all-zero row on every target.
@compiler.register("embedding_2d")
struct Embedding2DCustomOp:
    # Parameters:
    #   target: "gpu" or "cpu"; any other value raises an Error.
    #   batch_size, seq_len, vocab_size, embed_dim: static problem shape.
    # Args:
    #   output:  [batch_size, seq_len, embed_dim], float32, written in full.
    #   indices: [batch_size, seq_len], int32 token ids.
    #   weights: [vocab_size, embed_dim], same dtype as output.
    #   ctx:     device context used to enqueue GPU work.
    @staticmethod
    fn execute[
        target: StaticString,
        batch_size: Int,
        seq_len: Int,
        vocab_size: Int,
        embed_dim: Int,
    ](
        output: OutputTensor[
            dtype = DType.float32, rank=3
        ],  # [batch_size, seq_len, embed_dim]
        indices: InputTensor[
            dtype = DType.int32, rank=2
        ],  # [batch_size, seq_len]
        weights: InputTensor[
            dtype = output.dtype, rank=2
        ],  # [vocab_size, embed_dim]
        ctx: DeviceContextPtr,
    ) raises:
        output_tensor = output.to_layout_tensor()
        indices_tensor = indices.to_layout_tensor()
        weights_tensor = weights.to_layout_tensor()

        # Layouts are forwarded to the kernel as compile-time parameters.
        alias indices_layout = indices_tensor.layout
        alias weights_layout = weights_tensor.layout
        alias out_layout = output_tensor.layout

        @parameter
        if target == "gpu":
            gpu_ctx = ctx.get_device_context()

            # Zero out output tensor through a non-owning view of its memory.
            # The kernel also stores 0 for out-of-range tokens, so this is a
            # defensive clear rather than a correctness requirement.
            gpu_ctx.enqueue_memset(
                DeviceBuffer[output.dtype](
                    gpu_ctx,
                    rebind[UnsafePointer[Scalar[output.dtype]]](
                        output_tensor.ptr
                    ),
                    batch_size * seq_len * embed_dim,
                    owning=False,
                ),
                0,
            )

            # Calculate 2D grid dimensions for non-coalesced access
            total_positions = batch_size * seq_len
            alias BLOCK_X = 16  # batch*seq dimension
            alias BLOCK_Y = 16  # embed dimension
            blocks_x = max(1, ceildiv(total_positions, BLOCK_X))
            blocks_y = max(1, ceildiv(embed_dim, BLOCK_Y))

            # Compile and launch 2D kernel
            compiled_kernel = gpu_ctx.compile_function[
                embedding_kernel_2d[
                    indices_layout,
                    weights_layout,
                    out_layout,
                    batch_size,
                    seq_len,
                    vocab_size,
                    embed_dim,
                    output.dtype,
                ]
            ]()

            gpu_ctx.enqueue_function(
                compiled_kernel,
                output_tensor,
                indices_tensor,
                weights_tensor,
                grid_dim=(blocks_x, blocks_y),
                block_dim=(BLOCK_X, BLOCK_Y),
            )

        elif target == "cpu":
            # Same CPU fallback as 1D version
            for batch in range(batch_size):
                for seq in range(seq_len):
                    token_idx_val = Int(indices_tensor[batch, seq])
                    if token_idx_val >= 0 and token_idx_val < vocab_size:
                        for emb in range(embed_dim):
                            output_tensor[batch, seq, emb] = weights_tensor[
                                token_idx_val, emb
                            ]
                    else:
                        # Match the GPU path: an out-of-range token id gives
                        # a zero row instead of leaving the row unwritten.
                        for emb in range(embed_dim):
                            output_tensor[batch, seq, emb] = 0
        else:
            raise Error("Unsupported target: " + target)

"""

In [18]:
# Overwrite the p21 op source in the cloned repo with the Mojo code above.
# NOTE(review): the absolute /content/... path ties this cell to a runtime
# where the repo was cloned under /content - confirm before running elsewhere.
save_code_to_file(mojo_code, "/content/mojo-gpu-puzzles/problems/p21/op/embedding.mojo")

In [19]:
# Run the repo's `p21` poe task through uv; it checks both embedding ops for
# correctness and then benchmarks the 1D kernel against the 2D kernel.
!cd /content/mojo-gpu-puzzles && uv run poe p21

[37mPoe =>[0m [94mpython problems/p21/p21.py[0m
Puzzle 19: Mojo Embedding Kernel Comparison

Configuration: B=8, L=512, V=10000, E=512
------------------------------------------------------------
Testing Correctness...
   1D Coalesced - Max difference: 0.00e+00
   2D Non-coalesced - Max difference: 0.00e+00
   ✅ Both implementations CORRECT

Benchmarking Mojo Kernels...

Performance Results:
   1D Coalesced:     0.401 ms
   2D Non-coalesced: 0.426 ms
   1D is 1.06x faster than 2D

Key Learning Points:
• Compare different GPU kernel implementations
• 1D vs 2D grid patterns have different memory access
• Coalesced memory access should be faster
• Grid configuration affects GPU utilization
