In [2]:
%%capture
# Only need to run the first time.
# Works with latest triton. Sorry, this takes a minute to install.
!pip install jaxtyping
!pip install git+https://github.com/Deep-Learning-Profiling-Tools/triton-viz
!wget "https://dl.cloudsmith.io/public/test-wha/triton-puzzles/raw/files/triton-3.0.0-cp310-cp310-linux_x86_64.whl"
!mv triton-3.0.0*.whl triton-3.0.0-cp310-cp310-linux_x86_64.whl
!pip install triton-3.0.0-cp310-cp310-linux_x86_64.whl
!export LC_ALL="en_US.UTF-8"
!export LD_LIBRARY_PATH="/usr/lib64-nvidia"
!export LIBRARY_PATH="/usr/local/cuda/lib64/stubs"
!ldconfig /usr/lib64-nvidia


In [3]:
import torch
import triton
from torch import Tensor
import triton.language as tl
import jaxtyping
from jaxtyping import Float32, Int32

In [4]:
# @title Setup

import triton_viz
import inspect
from triton_viz.interpreter import record_builder

def test(puzzle, puzzle_spec, nelem={}, B={"B0": 32}, viz=True):
    """Run a puzzle kernel under triton-viz and compare its output to `puzzle_spec`.

    Random inputs are generated from the jaxtyping annotations of `puzzle_spec`:
    each parameter `n` becomes a tensor passed to the kernel as `n_ptr`
    (parameters whose annotation dtype is "int32" get `torch.randint` values,
    all others `torch.rand`), followed by an output buffer `z_ptr` shaped like
    the return annotation. The launch grid is derived from `nelem` and `B`.

    Args:
        puzzle: the `@triton.jit` kernel under test.
        puzzle_spec: annotated PyTorch reference implementation.
        nelem: problem sizes forwarded to the kernel as keyword arguments;
            must contain "N0". Not modified.
        B: block sizes forwarded as keyword arguments. "B1" / "B2" default to 32
            when "N1" / "N2" are present in `nelem`. Not modified (a copy is used).
        viz: when True, open the triton-viz visualizer and treat any invalid
            memory access it reports as a failure.

    Returns:
        An IPython `HTML` puppy video when the results match, otherwise None
        (after printing both outputs).
    """
    B = dict(B)  # copy so neither the caller's dict nor the default is mutated
    if "N1" in nelem and "B1" not in B:
        B["B1"] = 32
    if "N2" in nelem and "B2" not in B:
        B["B2"] = 32

    triton_viz.interpreter.record_builder.reset()
    torch.manual_seed(0)
    signature = inspect.signature(puzzle_spec)
    args = {}
    for n, p in signature.parameters.items():
        args[n + "_ptr"] = ([d.size for d in p.annotation.dims], p)
    args["z_ptr"] = ([d.size for d in signature.return_annotation.dims], None)

    tt_args = []
    for shape, param in args.values():
        tt_args.append(torch.rand(*shape))
        if param is not None and param.annotation.dtypes[0] == "int32":
            tt_args[-1] = torch.randint(-100000, 100000, shape)
    grid = lambda meta: (triton.cdiv(nelem["N0"], meta["B0"]),
                         triton.cdiv(nelem.get("N1", 1), meta.get("B1", 1)),
                         triton.cdiv(nelem.get("N2", 1), meta.get("B2", 1)))

    triton_viz.trace(puzzle)[grid](*tt_args, **B, **nelem)
    z = tt_args[-1]
    tt_args = tt_args[:-1]
    z_ = puzzle_spec(*tt_args)
    match = torch.allclose(z, z_, rtol=1e-3, atol=1e-3)
    print("Results match:",  match)
    # Without the visualizer there are no access reports; default to None so the
    # check below does not hit an unbound name when viz=False.
    failures = triton_viz.launch() if viz else None
    if not match or failures:
        print("Invalid Access:", failures)
        print("Yours:", z)
        print("Spec:", z_)
        print(torch.isclose(z, z_))
        return
    # PUPPIES!
    from IPython.display import HTML
    import random
    print("Correct!")
    pups = [
    "2m78jPG",
    "pn1e9TO",
    "MQCIwzT",
    "udLK6FS",
    "ZNem5o3",
    "DS2IZ6K",
    "aydRUz8",
    "MVUdQYK",
    "kLvno0p",
    "wScLiVz",
    "Z0TII8i",
    "F1SChho",
    "9hRi2jN",
    "lvzRF3W",
    "fqHxOGI",
    "1xeUYme",
    "6tVqKyM",
    "CCxZ6Wr",
    "lMW0OPQ",
    "wHVpHVG",
    "Wj2PGRl",
    "HlaTE8H",
    "k5jALH0",
    "3V37Hqr",
    "Eq2uMTA",
    "Vy9JShx",
    "g9I2ZmK",
    "Nu4RH7f",
    "sWp0Dqd",
    "bRKfspn",
    "qawCMl5",
    "2F6j2B4",
    "fiJxCVA",
    "pCAIlxD",
    "zJx2skh",
    "2Gdl1u7",
    "aJJAY4c",
    "ros6RLC",
    "DKLBJh7",
    "eyxH0Wc",
    "rJEkEw4"]
    return HTML("""
    <video alt="test" controls autoplay=1>
        <source src="https://openpuppies.com/mp4/%s.mp4"  type="video/mp4">
    </video>
    """%(random.sample(pups, 1)[0]))

## Puzzle 1: Constant Add

Add a constant to a vector. Uses one program id axis. Block size `B0` is always the same as vector `x` with length `N0`.


$$z_i = 10 + x_i \text{ for } i = 1\ldots N_0$$


In [None]:
def add_spec(x: Float32[Tensor, "32"]) -> Float32[Tensor, "32"]:
    """Reference for Puzzle 1: add the constant 10 to every element.

    The jaxtyping annotations fix the tensor sizes that `test` generates.
    """
    return torch.add(x, 10.)

@triton.jit
def add_kernel(x_ptr, z_ptr, N0, B0: tl.constexpr):
    """z = x + 10 computed by a single block that covers the whole vector (B0 == N0, so no mask)."""
    idx = tl.arange(0, B0)
    tl.store(z_ptr + idx, tl.load(x_ptr + idx) + 10)

test(add_kernel, add_spec, nelem={"N0": 32}, viz=True)

## Puzzle 2: Constant Add Block

Add a constant to a vector. Uses one program id axis. Block size `B0` is now smaller than the length `N0` of vector `x`.


$$z_i = 10 + x_i \text{ for } i = 1\ldots N_0$$



In [None]:
def add2_spec(x: Float32[Tensor, "200"]) -> Float32[Tensor, "200"]:
    """Reference for Puzzle 2: add 10 to each of the 200 elements."""
    return torch.add(x, 10.)

@triton.jit
def add_mask2_kernel(x_ptr, z_ptr, N0, B0: tl.constexpr):
    """z = x + 10, one B0-wide block per program; lanes at or past N0 are masked off."""
    block_start = tl.program_id(axis=0) * B0
    idx = block_start + tl.arange(0, B0)
    in_bounds = idx < N0
    vals = tl.load(x_ptr + idx, mask=in_bounds)
    tl.store(z_ptr + idx, vals + 10, mask=in_bounds)

test(add_mask2_kernel, add2_spec, nelem={"N0": 200})

## Puzzle 3: Outer Vector Add

Add two vectors.

Uses one program block axis. Block size `B0` is always the same as vector `x` length `N0`.
Block size `B1` is always the same as vector `y` length `N1`.


$$z_{i, j} = x_i + y_j\text{ for } i = 1\ldots B_0,\ j = 1\ldots B_1$$


In [None]:
def add_vec_spec(x: Float32[Tensor, "32"], y: Float32[Tensor, "32"]) -> Float32[Tensor, "32 32"]:
    """Reference for Puzzle 3: outer sum, result[j, i] = x[i] + y[j]."""
    row = x.unsqueeze(0)   # (1, 32)
    col = y.unsqueeze(1)   # (32, 1)
    return row + col

@triton.jit
def add_vec_kernel(x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
    """Outer sum z[j, i] = x[i] + y[j] computed by a single block.

    z is stored row-major with N1 rows (indexed by y) and N0 columns (indexed by x).
    The puzzle uses B0 == N0 and B1 == N1; the masks also allow larger blocks.
    """
    x_offsets = tl.arange(0, B0)
    y_offsets = tl.arange(0, B1)
    x_mask = x_offsets < N0
    y_mask = y_offsets < N1
    # The row stride of z is its row length N0. Using B0 only worked because B0 == N0.
    z_offsets = x_offsets[None, :] + y_offsets[:, None] * N0
    x = tl.load(x_ptr + x_offsets, mask=x_mask)
    y = tl.load(y_ptr + y_offsets, mask=y_mask)
    z = x[None, :] + y[:, None]
    tl.store(z_ptr + z_offsets, z, mask=y_mask[:, None] & x_mask[None, :])

test(add_vec_kernel, add_vec_spec, nelem={"N0": 32, "N1": 32})

## Puzzle 4: Outer Vector Add Block

Add a row vector to a column vector.

Uses two program block axes. Block size `B0` is always less than the vector `x` length `N0`.
Block size `B1` is always less than vector `y` length `N1`.

$$z_{i, j} = x_i + y_j\text{ for } i = 1\ldots N_0,\ j = 1\ldots N_1$$


In [None]:
def add_vec_block_spec(x: Float32[Tensor, "100"], y: Float32[Tensor, "90"]) -> Float32[Tensor, "90 100"]:
    """Reference for Puzzle 4: (90, 100) outer sum, result[j, i] = x[i] + y[j]."""
    row = x.unsqueeze(0)   # (1, 100)
    col = y.unsqueeze(1)   # (90, 1)
    return row + col

@triton.jit
def add_vec_block_kernel(x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
    """Blocked outer sum z[j, i] = x[i] + y[j]; program (pid_0, pid_1) owns one B1 x B0 tile.

    z is row-major with N1 rows and N0 columns; out-of-range lanes are masked.
    """
    cols = tl.program_id(0) * B0 + tl.arange(0, B0)
    rows = tl.program_id(1) * B1 + tl.arange(0, B1)
    col_ok = cols < N0
    row_ok = rows < N1

    x = tl.load(x_ptr + cols, mask=col_ok)
    y = tl.load(y_ptr + rows, mask=row_ok)

    tile_ptrs = z_ptr + rows[:, None] * N0 + cols[None, :]
    tl.store(tile_ptrs, x[None, :] + y[:, None], mask=col_ok[None, :] & row_ok[:, None])

test(add_vec_block_kernel, add_vec_block_spec, nelem={"N0": 100, "N1": 90})

## Puzzle 5: Fused Outer Multiplication

Multiply a row vector to a column vector and take a relu.

Uses two program block axes. Block size `B0` is always less than the vector `x` length `N0`.
Block size `B1` is always less than vector `y` length `N1`.

$$z_{i, j} = \text{relu}(x_i \times y_j)\text{ for } i = 1\ldots N_0,\ j = 1\ldots N_1$$



In [None]:
def mul_relu_block_spec(x: Float32[Tensor, "100"], y: Float32[Tensor, "90"]) -> Float32[Tensor, "90 100"]:
    """Reference for Puzzle 5: result[j, i] = relu(x[i] * y[j])."""
    outer = x.unsqueeze(0) * y.unsqueeze(1)
    return torch.relu(outer)

@triton.jit
def mul_relu_block_kernel(x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
    """Blocked fused z[j, i] = relu(x[i] * y[j]) over an N1 x N0 row-major output."""
    cols = tl.program_id(0) * B0 + tl.arange(0, B0)
    rows = tl.program_id(1) * B1 + tl.arange(0, B1)
    col_ok = cols < N0
    row_ok = rows < N1

    x = tl.load(x_ptr + cols, mask=col_ok)
    y = tl.load(y_ptr + rows, mask=row_ok)
    prod = x[None, :] * y[:, None]
    # relu: keep strictly positive products, zero everything else
    out = tl.where(prod > 0, prod, 0)

    tl.store(z_ptr + rows[:, None] * N0 + cols[None, :], out, mask=col_ok[None, :] & row_ok[:, None])

test(mul_relu_block_kernel, mul_relu_block_spec, nelem={"N0": 100, "N1": 90})

## Puzzle 6: Fused Outer Multiplication - Backwards


Backwards of a function that scales each row of a matrix by the matching entry of a vector and takes a relu.

Uses two program blocks. Block size `B0` is always less than the number of columns `N0` of the matrix `x`.
Block size `B1` is always less than the vector `y` length `N1`, which is also the number of rows of `x`. The chain rule
backward `dz` has the same shape as `x`, `N1` by `N0`.

$$f(x, y)_{j, i} = \text{relu}(x_{j, i} \times y_j)\text{ for } i = 1\ldots N_0,\ j = 1\ldots N_1$$

$$dx_{j, i} = f_x'(x, y)_{j, i} \times dz_{j, i}$$

In [None]:
def mul_relu_block_back_spec(x: Float32[Tensor, "90 100"], y: Float32[Tensor, "90"],
                             dz: Float32[Tensor, "90 100"]) -> Float32[Tensor, "90 100"]:
    """Reference for Puzzle 6: gradient of relu(x * y[:, None]) w.r.t. x, given upstream dz.

    Works on fresh leaf copies so the caller's tensors are left untouched.
    """
    x_leaf = x.clone().requires_grad_(True)
    y_leaf = y.clone().requires_grad_(True)
    out = torch.relu(x_leaf * y_leaf[:, None])
    out.backward(dz)
    return x_leaf.grad

@triton.jit
def mul_relu_block_back_kernel(x_ptr, y_ptr, dz_ptr, dx_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr):
    """Backward of f = relu(x * y[:, None]) w.r.t. x: dx = 1[x*y > 0] * y * dz.

    x, dz and dx are row-major N1 x N0; y has length N1 (one entry per row).
    Program (pid_0, pid_1) handles the B1 x B0 tile at column block pid_0, row block pid_1.
    """
    pid_0 = tl.program_id(0)
    pid_1 = tl.program_id(1)
    cols = (pid_0*B0 + tl.arange(0, B0))[None, :]   # (1, B0)
    rows = (pid_1*B1 + tl.arange(0, B1))[:, None]   # (B1, 1)
    mask = (cols < N0) & (rows < N1)
    offsets = rows*N0 + cols                        # x, dz and dx share this layout

    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + rows, mask=rows < N1)       # (B1, 1)
    dz = tl.load(dz_ptr + offsets, mask=mask)

    # d/dx relu(x*y) is y where x*y > 0 and 0 otherwise. Torch's relu backward is
    # also 0 at exactly x*y == 0; the old `< 0` test returned y there.
    dx = tl.where(x * y > 0, y, 0.0) * dz
    tl.store(dx_ptr + offsets, dx, mask=mask)

test(mul_relu_block_back_kernel, mul_relu_block_back_spec, nelem={"N0": 100, "N1": 90})


## Puzzle 7: Long Sum

Sum of a batch of numbers.

Uses one program id axis. Block size `B0` represents a range of batches of `x`, out of `N0` batches.
Each element is of length `T`. Process it `B1 < T` elements at a time.  

$$z_{i} = \sum^{T}_{j=1} x_{i,j} \text{ for } i = 1\ldots N_0$$

Hint: You will need a for loop for this problem. These work and look the same as in Python.

In [None]:
def sum_spec(x: Float32[Tensor, "4 200"]) -> Float32[Tensor, "4"]:
    """Reference for Puzzle 7: per-row sum over the length-200 axis."""
    return torch.sum(x, dim=1)

@triton.jit
def sum_kernel(x_ptr, z_ptr, N0, N1, T, B0: tl.constexpr, B1: tl.constexpr):
    """Row sums z[i] = sum_j x[i, j] for a row-major N0 x T matrix.

    Each program handles B0 rows and walks the T columns B1 at a time.
    N1 is accepted for interface compatibility but not used.
    """
    pid_0 = tl.program_id(0)
    rows = (pid_0*B0 + tl.arange(0, B0))[:, None]   # (B0, 1)
    row_mask = rows < N0
    # The accumulator is (B0, 1) to match tl.sum(..., keep_dims=True). A (B0,)
    # accumulator would broadcast to (B0, B0) as soon as B0 > 1.
    acc = tl.full((B0, 1), 0.0, dtype=tl.float32)

    cols = tl.arange(0, B1)[None, :]                # (1, B1)
    for start in range(0, T, B1):
        mask = row_mask & (start + cols < T)
        x = tl.load(x_ptr + rows*T + start + cols, mask=mask, other=0.0)
        acc += tl.sum(x, axis=1, keep_dims=True)

    tl.store(z_ptr + rows, acc, mask=row_mask)

test(sum_kernel, sum_spec, B={"B0": 1, "B1": 32}, nelem={"N0": 4, "N1": 32, "T": 200})


## Puzzle 8: Long Softmax


Softmax of a batch of logits.

Uses one program block axis. Block size `B0` represents the batch of `x` of length `N0`.
Each row has logit length `T`. Process it `B1 < T` elements at a time.

$$z_{i, j} = \text{softmax}(x_{i,1} \ldots x_{i, T}) \text{ for } i = 1\ldots N_0$$

Note softmax needs to be computed in numerically stable form as in Python. In addition in Triton they recommend not using `exp` but instead using `exp2`. You need the identity

$$\exp(x) = 2^{\log_2(e) x}$$

Advanced: there is one way to do this with 3 loops. You can also do it with 2 loops if you are clever. Hint: you will find this identity useful:

$$\exp(x_i - m) =  \exp(x_i - m/2 - m/2) = \exp(x_i - m/ 2) /  \exp(m/2) $$

In [None]:
def softmax_spec(x: Float32[Tensor, "4 200"]) -> Float32[Tensor, "4 200"]:
    """Reference for Puzzle 8: numerically stable row-wise softmax (subtract the row max first)."""
    shifted = x - x.max(dim=1, keepdim=True).values
    e = shifted.exp()
    return e / e.sum(dim=1, keepdim=True)

@triton.jit
def softmax_kernel(x_ptr, z_ptr, N0, N1, T, B0: tl.constexpr, B1: tl.constexpr):
    """Row-wise softmax of a row-major N0 x T matrix in two passes over the columns.

    Pass 1 keeps a running max and rescales the running denominator whenever the
    max grows (exp(a - m_new) = exp(a - m_old) * exp(m_old - m_new)). Pass 2
    normalizes and stores. exp is computed as exp2(log2(e) * x).
    N1 is accepted for interface compatibility but not used.
    """
    pid_0 = tl.program_id(0)
    log2_e = 1.44269504

    rows = (pid_0*B0 + tl.arange(0, B0))[:, None]   # (B0, 1)
    row_mask = rows < N0
    row_ptrs = x_ptr + rows*T
    cols = tl.arange(0, B1)[None, :]                # (1, B1)

    # Both are (B0, 1) to match the keep_dims=True reductions. (B0,) would
    # broadcast to (B0, B0) once B0 > 1.
    row_max = tl.full((B0, 1), -float('inf'), dtype=tl.float32)
    denom = tl.full((B0, 1), 0.0, dtype=tl.float32)

    # Pass 1: online max and denominator.
    for start in range(0, T, B1):
        mask = row_mask & (start + cols < T)
        x = tl.load(row_ptrs + start + cols, mask=mask, other=-float('inf'))
        new_max = tl.maximum(row_max, tl.max(x, axis=1, keep_dims=True))
        block_sum = tl.sum(tl.exp2(log2_e * (x - new_max)), axis=1, keep_dims=True)
        denom = tl.exp2(log2_e * (row_max - new_max)) * denom + block_sum
        row_max = new_max

    # Pass 2: normalize each block and store it.
    for start in range(0, T, B1):
        mask = row_mask & (start + cols < T)
        x = tl.load(row_ptrs + start + cols, mask=mask, other=-float('inf'))
        z = tl.exp2(log2_e * (x - row_max)) / denom
        tl.store(z_ptr + rows*T + start + cols, z, mask=mask)



test(softmax_kernel, softmax_spec, B={"B0": 1, "B1":32}, nelem={"N0": 4, "N1": 32, "T": 200})


## Puzzle 9: Simple FlashAttention

A scalar version of FlashAttention.

Uses a single program. Block size `B0` covers `q`, `k`, and `v`, each of length `N0`.
Sequence length is `T`. Process it `B1 < T` elements at a time.  

$$z_{i} = \sum_{j} \text{softmax}(q_i k_1, \ldots, q_i k_T)_j v_{j} \text{ for } i = 1\ldots N_0$$

This can be done in 1 loop using a similar trick from the last puzzle.

In [None]:
def flashatt_spec(q: Float32[Tensor, "200"], k: Float32[Tensor, "200"], v: Float32[Tensor, "200"]) -> Float32[Tensor, "200"]:
    """Reference for Puzzle 9: z[i] = sum_j softmax_j(q[i] * k[j]) * v[j] with scalar queries/keys."""
    scores = q[:, None] * k[None, :]
    scores = scores - scores.max(dim=1, keepdim=True).values
    weights = scores.exp()
    weights = weights / weights.sum(1, keepdim=True)
    return (v[None, :] * weights).sum(1)

@triton.jit
def flashatt_kernel(q_ptr, k_ptr, v_ptr, z_ptr, N0, T, B0: tl.constexpr):
    """Scalar attention z[i] = sum_j softmax_j(q[i] * k[j]) * v[j] using an online softmax.

    q has length N0; k and v have length T. Queries and keys are processed in
    tiles of B1 = 32. For each query tile, a running row max, denominator and
    normalized output are rescaled as each key tile arrives, so the full score
    matrix is never materialized. B0 is accepted for interface compatibility.
    """
    log2_e = 1.44269504
    B1 = 32

    offsets = tl.arange(0, B1)
    num_q_blocks = (N0 + B1 - 1) // B1
    num_kv_blocks = (T + B1 - 1) // B1   # ceil-div: no trailing fully-masked tile

    for i in range(0, num_q_blocks):
        q_idx = i*B1 + offsets
        q_mask = q_idx < N0
        q = tl.load(q_ptr + q_idx, mask=q_mask, other=0.0)

        maxes = tl.full((B1,), -float('inf'), tl.float32)
        denoms = tl.full((B1,), 0.0, dtype=tl.float32)
        z = tl.full((B1,), 0.0, dtype=tl.float32)

        for j in range(0, num_kv_blocks):
            kv_idx = j*B1 + offsets
            kv_mask = kv_idx < T
            k = tl.load(k_ptr + kv_idx, mask=kv_mask, other=0.0)
            v = tl.load(v_ptr + kv_idx, mask=kv_mask, other=0.0)

            # Mask the scores, not k. Padding k with -inf makes q*k = +inf for
            # negative q (and NaN for q == 0).
            s = q[:, None] * k[None, :]
            s = tl.where(kv_mask[None, :], s, -float('inf'))
            block_maxes = tl.max(s, axis=1)
            p = tl.exp2(log2_e * (s - block_maxes[:, None]))
            new_maxes = tl.maximum(maxes, block_maxes)

            # Rescale the old state and this tile onto the common new max.
            old_scale = tl.exp2(log2_e * (maxes - new_maxes))
            block_scale = tl.exp2(log2_e * (block_maxes - new_maxes))
            new_denoms = denoms*old_scale + tl.sum(p, axis=1)*block_scale
            z = (denoms * z * old_scale + tl.sum(p * v[None, :], axis=1)*block_scale) / new_denoms

            maxes = new_maxes
            denoms = new_denoms

        tl.store(z_ptr + q_idx, z, mask=q_mask)


test(flashatt_kernel, flashatt_spec, B={"B0":200},
     nelem={"N0": 200, "T": 200})

## Puzzle 10: Two Dimensional Convolution

A batched 2D convolution.

Uses one program id axis. Block size `B0` represents the batches to process out of `N0`.
Image `x` is of size `H` by `W` with only 1 channel, and kernel `k` is of size `KH` by `KW`.

$$z_{i, j, k} = \sum_{oj, ok} k_{oj,ok} \times x_{i,j + oj, k + ok} \text{ for } i = 1\ldots N_0$$



In [None]:
def conv2d_spec(x: Float32[Tensor, "4 8 8"], k: Float32[Tensor, "4 4"]) -> Float32[Tensor, "4 8 8"]:
    """Reference for Puzzle 10: batched single-channel 2D correlation.

    z[b, i, j] = sum_{di, dj} k[di, dj] * x[b, i + di, j + dj], where x is zero-padded
    past its bottom and right edges so the output keeps the input's H x W size.
    Sizes are taken from the inputs rather than hard-coded.
    """
    n, h, w = x.shape
    kh, kw = k.shape
    z = torch.zeros(n, h, w)
    x = torch.nn.functional.pad(x, (0, kw, 0, kh, 0, 0), value=0.0)
    for i in range(h):
        for j in range(w):
            z[:, i, j] = (k[None, :, :] * x[:, i: i + kh, j: j + kw]).sum(1).sum(1)
    return z


@triton.jit
def conv2d_kernel(x_ptr, k_ptr, z_ptr, N0, H, W, KH: tl.constexpr, KW: tl.constexpr, B0: tl.constexpr):
    """Batched single-channel 2D correlation, zero-padded past the bottom/right edge.

    z[b, r, c] = sum_{dr, dc} k[dr, dc] * x[b, r + dr, c + dc], with x treated as 0
    outside its H x W extent. Each program handles B0 images and loops over all
    H*W output positions; images with index >= N0 are masked out.
    """
    pid_0 = tl.program_id(0)
    k_rows = tl.arange(0, KH)[:, None]            # (KH, 1)
    k_cols = tl.arange(0, KW)[None, :]            # (1, KW)
    k = tl.load(k_ptr + k_rows*KW + k_cols)       # (KH, KW), the whole filter

    batch = (pid_0*B0 + tl.arange(0, B0))[:, None, None]   # (B0, 1, 1)
    # Without this mask the last program reads/writes past image N0 - 1 whenever N0 % B0 != 0.
    batch_mask = batch < N0
    img_base = batch*H*W

    for idx in range(H*W):
        col = idx % W
        row = idx // W
        win_rows = (row + k_rows)[None, :, :]     # (1, KH, 1)
        win_cols = (col + k_cols)[None, :, :]     # (1, 1, KW)
        mask = batch_mask & (win_rows < H) & (win_cols < W)
        window = tl.load(x_ptr + img_base + win_rows*W + win_cols, mask=mask, other=0.0)   # (B0, KH, KW)
        z = tl.sum(tl.sum(k[None, :, :] * window, axis=2, keep_dims=True), axis=1, keep_dims=True)   # (B0, 1, 1)
        tl.store(z_ptr + img_base + row*W + col, z, mask=batch_mask)

test(conv2d_kernel, conv2d_spec, B={"B0": 1}, nelem={"N0": 4, "H": 8, "W": 8, "KH": 4, "KW": 4})

## Puzzle 11: Matrix Multiplication

A blocked matrix multiplication.

Uses three program id axes. Block size `B2` represents the batches to process out of `N2`.
Block size `B0` represents the rows of `x` to process out of `N0`. Block size `B1` represents the cols of `y` to process out of `N1`. The middle shape is `MID`.

$$z_{i, j, k} = \sum_{l} x_{i,j, l} \times y_{i, l, k} \text{ for } i = 1\ldots N_2, j = 1\ldots N_0, k = 1\ldots N_1$$

You are allowed to use `tl.dot` which computes a smaller mat mul.

Hint: the main trick is that you can split a matmul into smaller parts.

$$z_{i, j, k} = \sum_{l=1}^{\text{MID}/2} x_{i,j, l} \times y_{i, l, k} +  \sum_{l=\text{MID}/2+1}^{\text{MID}} x_{i,j, l} \times y_{i, l, k} $$


In [None]:
def dot_spec(x: Float32[Tensor, "4 32 32"], y: Float32[Tensor, "4 32 32"]) -> Float32[Tensor, "4 32 32"]:
    """Reference for Puzzle 11: batched matrix product z[b] = x[b] @ y[b]."""
    return torch.matmul(x, y)

@triton.jit
def dot_kernel(x_ptr, y_ptr, z_ptr, N0, N1, N2, MID, B0: tl.constexpr, B1: tl.constexpr, B2: tl.constexpr):
    """Batched blocked matmul z[b] = x[b] @ y[b].

    x is (N2, N0, MID), y is (N2, MID, N1), z is (N2, N0, N1), all contiguous.
    Program (pid_0, pid_1, pid_2) computes the (B2, B0, B1) tile of z at row block
    pid_0, column block pid_1 and batch block pid_2. The MID reduction is tiled
    B1 at a time (B1 doubles as the reduction block size). Every load is masked,
    so N0, N1, N2 and MID need not be multiples of the block sizes.
    """
    pid_0 = tl.program_id(0)
    pid_1 = tl.program_id(1)
    pid_2 = tl.program_id(2)

    batch = (pid_2*B2 + tl.arange(0, B2))[:, None, None]   # (B2, 1, 1)
    rows = (pid_0*B0 + tl.arange(0, B0))[None, :, None]    # (1, B0, 1) rows of x / z
    cols = (pid_1*B1 + tl.arange(0, B1))[None, None, :]    # (1, 1, B1) cols of y / z
    red_x = tl.arange(0, B1)[None, None, :]                # (1, 1, B1) reduction index along x's last axis
    red_y = tl.arange(0, B1)[None, :, None]                # (1, B1, 1) reduction index along y's middle axis
    batch_mask = batch < N2

    # x and y have different batch strides: N0*MID and MID*N1 respectively.
    x_ptrs = x_ptr + batch*N0*MID + rows*MID + red_x
    y_ptrs = y_ptr + batch*MID*N1 + red_y*N1 + cols

    acc = tl.full((B2, B0, B1), 0.0, tl.float32)
    for k in range(0, MID, B1):
        x_mask = batch_mask & (rows < N0) & (k + red_x < MID)
        y_mask = batch_mask & (k + red_y < MID) & (cols < N1)
        x = tl.load(x_ptrs + k, mask=x_mask, other=0.0)        # (B2, B0, B1)
        y = tl.load(y_ptrs + k*N1, mask=y_mask, other=0.0)     # (B2, B1, B1); row stride of y is N1
        acc += tl.dot(x, y)

    z_mask = batch_mask & (rows < N0) & (cols < N1)
    tl.store(z_ptr + batch*N0*N1 + rows*N1 + cols, acc, mask=z_mask)


test(dot_kernel, dot_spec, B={"B0": 16, "B1": 16, "B2": 1}, nelem={"N0": 32, "N1": 32, "N2": 4, "MID": 32})


## Puzzle 12: Quantized Matrix Mult

When doing matrix multiplication with quantized neural networks a common strategy is to store the weight matrix in lower precision, with a shift and scale term.

For this problem our `weight` will be stored in 4 bits. We can store `FPINT` of these in a 32-bit integer. In addition, for every `GROUP` consecutive weights in a row we store one `scale` float value and one `shift` 4-bit value, so each row of `weight` has its own scales and shifts. The `activation`s are stored separately in standard floats.

Mathematically it looks like.

$$z_{j, k} = \sum_{l} sc_{j, l/g} (w_{j, l} - sh_{j, l/g}) \times y_{l, k} \text{ for } j = 1\ldots N_0,\ k = 1\ldots N_1$$

However, it is a bit more complex since we need to also extract the 4-bit values into floats to begin.




In [None]:

# Number of 4-bit weights packed into one 32-bit integer.
FPINT = 32 // 4
# Number of consecutive (unpacked) weights that share one scale and one shift.
GROUP = 8

def quant_dot_spec(scale : Float32[Tensor, "32 8"],
                   offset : Int32[Tensor, "32"],
                   weight: Int32[Tensor, "32 8"],
                   activation: Float32[Tensor, "64 32"]) -> Float32[Tensor, "32 32"]:
    """Reference for Puzzle 12: dequantize the packed weights, then matmul.

    `weight` (32 x 8 packed ints) unpacks to a 32 x 64 matrix in which column
    8*c + n is nibble n (low nibble first) of weight[:, c]. Within row j, group g
    (unpacked columns GROUP*g .. GROUP*g + GROUP - 1) uses scale[j, g] and shift
    nibble g of offset[j]. Returns (scale * (w - shift)) @ activation, shape 32 x 32.
    """
    offset = offset.view(32, 1)
    def extract(x):
        # Split each int into its 8 nibbles along a new trailing axis (low nibble first).
        over = torch.arange(8) * 4
        mask = 2**4 - 1
        return (x[..., None] >> over) & mask
    # (32, 8) -> (32, 8, GROUP) -> (32, 64): each scale is repeated across its group.
    scale = scale[..., None].expand(-1, 8, GROUP).contiguous().view(-1, 64)
    # (32, 1) -> nibbles (32, 1, 8) -> (32, 1, 8, GROUP) -> (32, 64): shift nibble g is repeated across group g.
    offset = extract(offset)[..., None].expand(-1, 1, 8, GROUP).contiguous().view(-1, 64)
    return ( scale * (extract(weight).view(-1, 64) - offset))  @ activation

@triton.jit
def quant_dot_kernel(scale_ptr, offset_ptr, weight_ptr, activation_ptr,
                     z_ptr, N0, N1, MID, B0: tl.constexpr, B1: tl.constexpr):
    """Dequantize 4-bit weights on the fly and multiply: z = (scale * (w - shift)) @ activation.

    Layout (matching quant_dot_spec):
      * weight: N0 rows of MID // FPINT packed ints. Unpacked column l is nibble
        l % FPINT (low nibble first) of packed column l // FPINT.
      * scale: N0 x (MID // GROUP) floats. Column l uses scale[row, l // GROUP].
      * offset: one packed int per row. Column l uses nibble l // GROUP as its shift.
      * activation: MID x N1 floats; z: N0 x N1 floats.
    Program (pid_0, pid_1) computes the B0 x B1 output tile. The MID reduction
    advances B1 unpacked columns per step. Every lane gathers its own scale, shift
    and weight nibble, so no reshape of the packed data is needed.
    """
    pid_0 = tl.program_id(0)
    pid_1 = tl.program_id(1)

    rows = (pid_0*B0 + tl.arange(0, B0))[:, None]   # (B0, 1) output / weight rows
    cols = (pid_1*B1 + tl.arange(0, B1))[None, :]   # (1, B1) output / activation columns
    red_x = tl.arange(0, B1)[None, :]               # (1, B1) reduction lanes along weight columns
    red_y = tl.arange(0, B1)[:, None]               # (B1, 1) reduction lanes along activation rows
    row_mask = rows < N0
    col_mask = cols < N1
    nibble = 0xF

    # One packed shift word per row; it is loaded once and unpacked per group below.
    packed_shift = tl.load(offset_ptr + rows, mask=row_mask, other=0)   # (B0, 1)

    acc = tl.full((B0, B1), 0.0, tl.float32)
    for k in range(0, MID, B1):
        mid = k + red_x                            # (1, B1) unpacked column index
        w_mask = row_mask & (mid < MID)
        group = mid // GROUP

        packed_w = tl.load(weight_ptr + rows*(MID // FPINT) + mid // FPINT, mask=w_mask, other=0)
        w = ((packed_w >> (4 * (mid % FPINT))) & nibble).to(tl.float32)
        shift = ((packed_shift >> (4 * group)) & nibble).to(tl.float32)
        scale = tl.load(scale_ptr + rows*(MID // GROUP) + group, mask=w_mask, other=0.0)
        # Previously scale and shift were computed but never applied.
        x = scale * (w - shift)                    # (B0, B1) dequantized weight tile

        y_mask = ((k + red_y) < MID) & col_mask
        y = tl.load(activation_ptr + (k + red_y)*N1 + cols, mask=y_mask, other=0.0)   # (B1, B1)
        acc += tl.dot(x, y)

    tl.store(z_ptr + rows*N1 + cols, acc, mask=row_mask & col_mask)




test(quant_dot_kernel, quant_dot_spec, B={"B0": 16, "B1": 16},
                                       nelem={"N0": 32, "N1": 32, "MID": 64})


In [8]:
# Bit offsets of the 8 nibbles in a 32-bit word: 0, 4, ..., 28.
over = 4 * torch.arange(8)
# Low-nibble mask (0b1111).
mask = 0xF

In [47]:
# Three toy packed shift words, one per row: shape (3, 1), int32.
offset = torch.arange(3, dtype=torch.int32).view(3, 1)

In [48]:
# Unpack each row's packed shift into its 8 nibbles (low nibble first): (3, 1) -> (3, 1, 8).
b = (offset[:, :, None] >> over) & mask

In [49]:
# NOTE: `b[:, ::, None]` inserts the new axis *before* the nibble axis: (3, 1, 8) -> (3, 1, 1, 8).
# After expanding to (3, 1, 8, 8) the nibble index is the last (fastest) axis, so reshaping
# interleaves nibbles (column r*8 + g, as the printed output below shows). quant_dot_spec
# uses `[..., None]` instead, which repeats each nibble 8 times in a row (column g*8 + r).
b = b[:, : :, None]
b = b.expand(-1, 1, 8, 8)
b.shape

torch.Size([3, 1, 8, 8])

In [50]:
# Flatten to one 64-wide row of shifts per packed offset word.
b = b.reshape(3, 64)

In [51]:
# Display the flattened (3, 64) shift layout.
b

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
         1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
         1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
        [2, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0,
         2, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0,
         2, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0]])

In [23]:
# Toy packed weights (3 rows x 8 packed ints), unpacked into nibbles and flattened so
# that column 8*c + n holds nibble n of packed column c (same layout as quant_dot_spec).
w = torch.arange(3*8).reshape(3, 8)
w = (w[:, :, None] >> over) & mask
w = w.reshape(3, 64)