In [None]:
import os

import matplotlib.pyplot as plt
from urllib.request import urlretrieve
from pathlib import Path

import torch
from torch import tensor
import torchvision as tv
import torchvision.transforms.functional as tvf
from torchvision import io

import triton
import triton.language as tl
from triton_utils import *

In [None]:
# Wikimedia Commons URL of the sample photo. None of the active cells fetch it;
# its only use is in the commented-out download cell that follows.
url = 'https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/1600px-Cute_dog.jpg?20140729055059'

In [None]:
# path_img = Path('puppy.jpg')
# if not path_img.exists():
#     urlretrieve(url, path_img)

In [None]:
# NOTE(review): absolute, machine-specific path — this cell only runs where that file exists.
img_local_path = "/home/htkumar/llms/triton_kernels/puppy.jpg"
# Decode the file into a tensor; later cells unpack its shape as (ch, h, w).
img = io.read_image(img_local_path)
print(img.shape)

In [None]:
# Element dtype of the decoded image, read off a small slice.
img[:2, :3, :4].dtype

In [None]:
def show_img(x, figsize=(4, 3), **kwargs):
    """Draw an image tensor in a new matplotlib figure with the axes hidden.

    Args:
        x: Image tensor. A 3-D tensor is treated as channel-first and is
            permuted to channel-last before plotting; a tensor of any other
            rank is handed to ``imshow`` unchanged.
        figsize: Figure size passed to ``plt.figure``.
        **kwargs: Forwarded verbatim to ``plt.imshow``.
    """
    plt.figure(figsize=figsize)
    plt.axis('off')
    # imshow wants channels last, so move dim 0 to the end for 3-D input.
    pixels = x.permute(1, 2, 0) if len(x.shape) == 3 else x
    plt.imshow(pixels.cpu(), **kwargs)

In [None]:
# Shrink the image (with an int size torchvision matches the shorter edge to 150
# and keeps the aspect ratio), then unpack the channel-first shape.
img = tvf.resize(img, 150, antialias=True)
ch, h, w = img.shape
# Shown as the cell output.
ch, h, w

In [None]:
# Preview the resized colour image.
show_img(img)

In [None]:
# Shape of the image after resizing.
img.shape

In [None]:
# Two small index vectors used to illustrate 2-D offset construction:
# indexing with [:, None] yields a (2, 1) column, [None, :] a (1, 2) row.
offset_0 = torch.tensor([2, 3])
offset_1 = torch.tensor([4, 5])
offset_0[:, None].shape, offset_1[None, :].shape

In [None]:
# Broadcasting the column against the row gives a (2, 2) grid of flat offsets
# row * 7 + col, where 7 plays the role of the row width.
7 * offset_0[:, None] + offset_1[None, :]

In [None]:
def cdiv(a, b):
    """Return ``(a + b - 1) // b``; for positive integers this is ``a / b`` rounded up."""
    padded = a + b - 1
    return padded // b

In [None]:
@triton.jit
def rgb2grey_k(x_ptr, out_ptr, h, w, bs0: tl.constexpr, bs1: tl.constexpr):
    """Convert one (bs0, bs1) tile of an RGB image to greyscale.

    Program id 0 selects the tile along the rows and program id 1 the tile
    along the columns. Input addresses are computed as
    ``channel * h * w + row * w + col``, i.e. the input is read as three
    tightly packed h-by-w planes; the output is written at ``row * w + col``.
    Lanes outside the h-by-w image are masked off for both loads and the store.
    """
    row_blk = tl.program_id(0)
    col_blk = tl.program_id(1)

    rows = row_blk * bs0 + tl.arange(0, bs0)
    cols = col_blk * bs1 + tl.arange(0, bs1)

    # Only lanes that fall inside the image may touch memory.
    row_ok = rows < h
    col_ok = cols < w
    in_bounds = row_ok[:, None] & col_ok[None, :]

    # Flat position of every pixel of the tile inside a single plane.
    pix = w * rows[:, None] + cols[None, :]

    plane = h * w
    red = tl.load(x_ptr + pix, mask=in_bounds)
    green = tl.load(x_ptr + plane + pix, mask=in_bounds)
    blue = tl.load(x_ptr + 2 * plane + pix, mask=in_bounds)

    # Weighted sum of the three channels.
    grey = 0.2989 * red + 0.5870 * green + 0.1140 * blue
    tl.store(out_ptr + pix, grey, mask=in_bounds)


In [None]:
def rgb2grey(x, bs):
    """Convert a channel-first RGB image tensor to greyscale with ``rgb2grey_k``.

    Args:
        x: Tensor of shape (c, h, w) with c >= 3; only the first three
            channels are read. It is handed straight to the Triton kernel,
            so it is expected to live on the GPU.
        bs: Pair ``(bs0, bs1)`` giving the tile height and width processed by
            each kernel program.

    Returns:
        Tensor of shape (h, w) on the same device and with the same dtype as ``x``.

    Raises:
        ValueError: If ``x`` has fewer than three channels.
    """
    c, h, w = x.shape
    if c < 3:
        # The kernel unconditionally reads planes 0, 1 and 2.
        raise ValueError(f"expected at least 3 channels (R, G, B), got {c}")
    # The kernel addresses pixels as ch*h*w + row*w + col, which is only valid
    # for a densely packed tensor; a strided view would be read incorrectly.
    x = x.contiguous()
    out = torch.empty((h, w), device=x.device, dtype=x.dtype)

    # One program per (bs0, bs1) tile, rounding up so the edges are covered.
    grid = lambda meta: (triton.cdiv(h, meta['bs0']), triton.cdiv(w, meta['bs1']))
    rgb2grey_k[grid](x, out, h, w, bs0=bs[0], bs1=bs[1])
    return out

In [None]:
# Move the image to the GPU, convert it with 32x32 tiles, and bring the result back to the CPU.
grey_img = rgb2grey(img.to('cuda'), bs=(32, 32)).to('cpu')

In [None]:
# Display the single-channel result.
show_img(grey_img)

In [None]:
# Inspect the raw greyscale values.
grey_img

In [None]:
# A 2x3 tensor used to look at strides: how many elements to step over in
# memory to advance by one along each dimension.
a = torch.tensor([
    [1, 2, 3],
    [3, 4, 5]
])
print(a.shape)
a.stride(0), a.stride(1)

In [None]:
@triton.jit
def naive_matmul_k(
    a_ptr, b_ptr, c_ptr,
    m, n, k,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    bm: tl.constexpr, bn: tl.constexpr, bk: tl.constexpr
):
    """Compute one (bm, bn) tile of C = A @ B, stepping through K in chunks of bk.

    Program ids 0 and 1 select the tile along M and N. A is (m, k), B is
    (k, n) and C is (m, n), each addressed through the strides passed in.
    Partial products are accumulated in float32 and written to C once at the end.

    The helpers get_1d_offset, get_2d_offset and get_2d_mask come from
    triton_utils and are not shown here; they are presumed to build per-block
    index vectors, a 2-D grid of flat offsets, and an in-bounds mask
    respectively (TODO confirm against triton_utils).
    """
    pid_m, pid_n = tl.program_id(0), tl.program_id(1)
    rm = get_1d_offset(bm, pid_m)
    rn = get_1d_offset(bn, pid_n)
    # rk holds the in-chunk K indices; it is needed both to form the tile
    # pointers and, shifted by the chunk start, to mask the K tail.
    rk = get_1d_offset(bk, 0)
    offs_a = a_ptr + get_2d_offset(rm, rk, stride_am, stride_ak)
    offs_b = b_ptr + get_2d_offset(rk, rn, stride_bk, stride_bn)

    acc = tl.zeros((bm, bn), dtype=tl.float32)
    for k_start in range(0, k, bk):
        # The pointers advance every iteration, so the mask has to use the
        # absolute K index. Masking with the unshifted rk would let the last
        # chunk read past K whenever k is not a multiple of bk.
        rk_cur = rk + k_start
        mask_a = get_2d_mask(rm, rk_cur, m, k)
        mask_b = get_2d_mask(rk_cur, rn, k, n)

        # Masked-off lanes feed tl.dot, so they must be a defined zero.
        a = tl.load(offs_a, mask=mask_a, other=0.0)
        b = tl.load(offs_b, mask=mask_b, other=0.0)
        acc += tl.dot(a, b)

        # Slide both tiles one chunk further along K.
        offs_a += bk * stride_ak
        offs_b += bk * stride_bk

    c = c_ptr + get_2d_offset(rm, rn, stride_cm, stride_cn)
    mask = get_2d_mask(rm, rn, m, n)
    tl.store(c, acc, mask)

In [None]:
from functools import partial

def matmul(a, b, matmul_k_fn, bs=16, group_sz=None):
    """Multiply two 2-D tensors by launching the given Triton matmul kernel.

    Args:
        a: Left operand of shape (m, k).
        b: Right operand of shape (k, n).
        matmul_k_fn: Triton kernel to launch. It is called with the same
            argument list as ``naive_matmul_k`` (pointers, sizes, strides,
            then ``bm``/``bn``/``bk`` keyword block sizes).
        bs: Block size used for all three of ``bm``, ``bn`` and ``bk``.
        group_sz: Accepted for interface compatibility; not used here.

    Returns:
        A float16 tensor of shape (m, n) on the same device as ``a``.
    """
    # Helper from triton_utils; presumably validates device/contiguity — TODO confirm.
    check_tensors_gpu_ready(a, b)
    assert a.shape[1] == b.shape[0]
    (m, k), (_, n) = a.shape, b.shape

    c = torch.zeros((m, n), device=a.device, dtype=torch.float16)
    # One program per (bm, bn) output tile, rounding up to cover the edges.
    grid = lambda meta: (triton.cdiv(m, meta["bm"]), triton.cdiv(n, meta["bn"]))
    # Launch the kernel the caller asked for rather than a hard-wired one.
    matmul_k_fn[grid](
        a, b, c, m, n, k,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        bm=bs, bn=bs, bk=bs
    )
    return c

In [None]:
# Pre-bind naive_matmul_k as the kernel argument so naive_matmul(a, b) can be called directly.
naive_matmul = partial(matmul, matmul_k_fn=naive_matmul_k)

In [None]:
# Tiny all-ones operands (3x4 @ 4x5), smaller than the default 16-wide block,
# so the result depends on the kernel's boundary masking.
a = torch.ones((3, 4), dtype=torch.float32, device='cuda')
b = torch.ones((4, 5), dtype=torch.float32, device='cuda')

In [None]:
# Every entry of the 3x5 product of all-ones matrices should be 4.
naive_matmul(a,b)

In [None]:
# Compare the Triton kernel against PyTorch on random fp16 matrices.
torch.manual_seed(128)
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
triton_output = naive_matmul(a, b)
pytorch_output = a@b
# torch.allclose defaults (rtol=1e-5, atol=1e-8) are tighter than fp16
# resolution, so rounding differences alone would report False; use
# tolerances that match the precision of the output dtype.
torch.allclose(triton_output, pytorch_output, atol=1e-2, rtol=1e-2)


In [None]:
@triton.jit
def naive_matmul_k2(
    a_ptr, b_ptr, c_ptr,
    m, n, k,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    bm: tl.constexpr, bn: tl.constexpr, bk: tl.constexpr
):
    pid_m, pid_n = tl.program_id(0), tl.program_id(1)
    rm = get_1d_offset(bm, pid_m)
    rn = get_1d_offset(bn, pid_n)
    # TODO: do we need rk
    rk = get_1d_offset(k, 0)
    offs_a = a_ptr + get_2d_offset(rm, k, stride_am, stride_ak)
    offs_b = b_ptr + get_2d_offset(k, rn, stride_bk, stride_bn)

    mask_a = get_2d_mask(rm, rk, m, k)
    mask_b = get_2d_mask(rk, rn, k, n)

    a = tl.load(offs_a, mask=mask_a)
    b = tl.load(offs_b, mask=mask_b)
    acc = tl.dot(a, b)
    c = c_ptr + get_2d_offset(rm, rn, stride_cm, stride_cn)
    mask = get_2d_mask(rm, rn, m, n)
    tl.store(c, acc, mask)

In [None]:
from functools import partial

def matmul2(a, b, matmul_k_fn, bs=16, group_sz=None):
    check_tensors_gpu_ready(a, b)
    assert a.shape[1] == b.shape[0]
    (m, k), (_, n) = a.shape, b.shape

    c = torch.zeros((m, n), device=a.device, dtype=torch.float16)
    grid = lambda meta: (triton.cdiv(m, meta["bm"]), triton.cdiv(n, meta["bn"]))
    naive_matmul_k[grid](
        a, b, c, m, n, k,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        bm=bs, bn=bs, bk=bs
    )
    return c

In [None]:
# Pre-bind naive_matmul_k2 as the kernel argument of matmul2.
naive_matmul2 = partial(matmul2, matmul_k_fn=naive_matmul_k2)

In [None]:
# Same all-ones sanity check through naive_matmul2: every entry of the 3x5 result should be 4.
a = torch.ones((3, 4), dtype=torch.float32, device='cuda')
b = torch.ones((4, 5), dtype=torch.float32, device='cuda')
naive_matmul2(a,b)

In [None]:
# Compare naive_matmul2 against PyTorch on random fp16 matrices.
torch.manual_seed(128)
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
triton_output = naive_matmul2(a, b)
pytorch_output = a@b
# torch.allclose defaults (rtol=1e-5, atol=1e-8) are tighter than fp16
# resolution, so rounding differences alone would report False; use
# tolerances that match the precision of the output dtype.
torch.allclose(triton_output, pytorch_output, atol=1e-2, rtol=1e-2)