# Dense Max-Plus

In [1]:
import torch
import torchvision
import matplotlib.pyplot as plt
from PIL import Image
from mpl_toolkits.axes_grid1 import ImageGrid
import ipywidgets as widgets
import numpy as np
from semitorch.utils import Timer, CUDATimer, ntuple, mnistplot
from typing import Optional, Union, Tuple, TypeVar
import math

import taichi as ti
import taichi.math as tm
import hidet

device = torch.device('cuda')
ti.init(arch=ti.gpu)

[Taichi] version 1.5.0, llvm 15.0.4, commit 7b885c28, linux, python 3.10.10


[I 05/10/23 15:15:52.805 34838] [shell.py:_shell_pop_print@23] Graphical python shell detected, using wrapped sys.stdout


[Taichi] Starting on arch=cuda


In [2]:
torch.manual_seed(0)
B, Dx, Dy = 2048, 1024, 512
x = torch.rand(B, Dx, dtype=torch.float32, device=device, requires_grad=True)
a = torch.randn(Dy, Dx, dtype=torch.float32, device=device, requires_grad=True)
grad_y = torch.randn(B, Dy, dtype=torch.float32, device=device, requires_grad=False)

## Baseline (version 0)

In [3]:
def maxplus_v0(x, a):
    assert a.ndim == 2 and x.ndim >= 1
    assert x.shape[-1] == a.shape[-1]
    x = x.unsqueeze(-2)
    return torch.max(x + a, dim=-1)[0]


with CUDATimer():
    with torch.no_grad():
        y_v0 = maxplus_v0(x, a)

with CUDATimer():
    y_v0 = maxplus_v0(x, a)

with CUDATimer():
    y_v0.backward(grad_y)

# with torch.profiler.profile(
#     record_shapes=True, profile_memory=True, with_stack=True,
#     on_trace_ready=torch.profiler.tensorboard_trace_handler(
#         '../samples', worker_name='maxplus_v0'),
# ) as prof:
#     y_v0 = maxplus_v0(x, a)
#     y_v0.backward(grad_y)

grad_x, grad_a = x.grad, a.grad

Elapsed: 47.65 ms
Elapsed: 36.42 ms
Elapsed: 35.73 ms


## Version 1: Naive Taichi

In [None]:
@ti.kernel
def maxplus_kernel_v1(
    y: ti.types.ndarray(ndim=2),  # [B,Dy]
    x: ti.types.ndarray(ndim=2),  # [B,Dx]
    a: ti.types.ndarray(ndim=2),  # [Dy,Dx]
):
    for b, i in y:
        v = -tm.inf
        for j in range(a.shape[-1]):
            v = tm.max(v, x[b, j] + a[i, j])
        y[b, i] = v


@ti.kernel
def maxplus_fw_kernel_v1(
    y: ti.types.ndarray(ndim=2),  # [B,Dy]
    hits: ti.types.ndarray(dtype=ti.i32, ndim=2),  # [B,Dx]
    x: ti.types.ndarray(ndim=2),  # [B,Dx]
    a: ti.types.ndarray(ndim=2),  # [Dy,Dx]
):
    for b, i in y:
        v = -tm.inf
        hit: ti.i32 = -1
        for j in range(a.shape[0]):
            w = x[b, j] + a[i, j]
            if w > v:
                v = w
                hit = j
        y[b, i] = v
        hits[b, i] = hit


@ti.kernel
def maxplus_bw_x_kernel_v1(
    gradx: ti.types.ndarray(ndim=2),  # [B,Dx]
    hits: ti.types.ndarray(dtype=ti.i32, ndim=2),  # [B,Dx]
    grady: ti.types.ndarray(ndim=2),  # [B,Dy]
):
    pass


@ti.kernel
def maxplus_bw_a_kernel_v1(
    grada: ti.types.ndarray(ndim=2),  # [Dy,Dx]
    hits: ti.types.ndarray(dtype=ti.i32, ndim=2),  # [B,Dx]
    grady: ti.types.ndarray(ndim=2),  # [B,Dy]
):
    pass


class MaxPlusFunction_v1(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, a):
        assert x.device == a.device, "inputs x and a should be on the same device but are on f{x.device} resp. f{a.device}"
        x = x.contiguous()
        a = a.contiguous()

        y = torch.empty((*x.shape[0:-1], a.shape[0]), device=x.device, dtype=x.dtype)

        if (x.requires_grad or a.requires_grad) and torch.is_grad_enabled():
            hits = torch.empty_like(x, dtype=torch.int32)
            maxplus_fw_kernel_v1(y, hits, x, a)
            ctx.save_for_backward(hits)
        else:
            maxplus_kernel_v1(y, x, a)

        x.device.type == 'cuda' and torch.cuda.synchronize()
        return y

    @staticmethod
    def backward(ctx, grad_y):
        (hits,) = ctx.saved_tensors

        grad_y.device.type == 'cuda' and torch.cuda.synchronize()


def maxplus_v1(x, a):
    return MaxPlusFunction_v1.apply(x, a)


with CUDATimer():
    with torch.no_grad():
        y_v1 = maxplus_v1(x, a)

# with CUDATimer():
#     y_v1 = maxplus_v1(x, a)
#     y_v1.backward(grad_y)

torch.allclose(y_v1, y_v0)

## Version 2: Naive Hidet

In [None]:
def maxplus_fw_v2(nbatch, nin, nout):
    from hidet.lang import f32, attr
    from hidet.lang.cuda import threadIdx, blockIdx, blockDim

    with hidet.script_module() as script_module:
        @hidet.script
        def kernel(
            y: f32[nbatch, nout],
            x: f32[nbatch, nout],
            a: f32[nout, nin]
        ):
            attr.cuda_grid_dim = ((nbatch + 31) // 32, (nout + 31) // 32)
            attr.cuda_block_dim = (32, 32)
            i = threadIdx.x + blockIdx.x * blockDim.x
            j = threadIdx.y + blockIdx.y * blockDim.y
            if i < nbatch and j < nout:
                val = -1e6
                for k in range(nin):
                    val = max(x[i, k] + a[j, k], val)
                y[i, j] = val

    ir_module = script_module.ir_module()
    func = hidet.driver.build_ir_module(ir_module)
    return func


class MaxPlusFunction_v2(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, a):
        assert x.device == a.device, "inputs x and a should be on the same device but are on f{x.device} resp. f{a.device}"
        x = x.contiguous()
        a = a.contiguous()

        y = torch.empty((*x.shape[0:-1], a.shape[0]), device=x.device, dtype=x.dtype)

        if (x.requires_grad or a.requires_grad) and torch.is_grad_enabled():
            hits = torch.empty_like(x, dtype=torch.int32)
            # maxplus_fw_kernel_v1(y, hits, x, a)
            ctx.save_for_backward(hits)
        else:
            nbatch, nin = x.shape
            nout, nin = a.shape
            f = maxplus_fw_v2(nbatch, nin, nout)
            f(y.detach(), x.detach(), a.detach())

        x.device.type == 'cuda' and torch.cuda.synchronize()
        return y

    @staticmethod
    def backward(ctx, grad_y):
        pass


def maxplus_v2(x, a):
    return MaxPlusFunction_v2.apply(x, a)


with CUDATimer():
    with torch.no_grad():
        y_v2 = maxplus_v2(x, a)

# with CUDATimer():
#     y_v1 = maxplus_v1(x, a)
#     y_v1.backward(grad_y)

torch.allclose(y_v2, y_v0)

In [None]:
y_v2 - y_v0

## Version 3: Rule-based Hidet

In [43]:
from hidet.ir.compute import TensorNode, compute, reduce, arg_reduce
from hidet.ir.task import Task
from hidet.graph import Operator, Tensor
from hidet.graph.ops.definitions.utils import input_like

class MaxPlusNoGradTask(Task):
    def __init__(self, x: TensorNode, a: TensorNode):
        # get the input sizes
        batch_size, in_size = x.const_shape()
        out_size, in_size = a.const_shape()

        # define the computation
        y = compute(
            name='y',
            shape=[batch_size, out_size],
            fcompute=lambda b, i: reduce(
                shape=[in_size],
                fcompute=lambda k: x[b, k] + a[i, k],
                reduce_type='max',
            ),
        )

        # call the parent class constructor to initialize the task
        super().__init__(
            name='maxplus',  # the name of the task
            inputs=[x, a],  # the input tensor nodes
            outputs=[y],  # the output tensor nodes
        )

class MaxPlusNoGradOp(Operator):
    def __init__(self, x, a):
        # call the parent class constructor to initialize the operator
        super().__init__(
            inputs=[x, a],  # the input tensors
            attributes={},
            task=MaxPlusNoGradTask(  # the task of the operator
                # create tensor nodes (TensorNode) with the same shape and dtype as the tensors (Tensor)
                input_like(x, 'x'),
                input_like(a, 'a'),
            ),
        )

class MaxPlusTask(Task):
    def __init__(self, x: TensorNode, a: TensorNode):
        # get the input sizes
        batch_size, in_size = x.const_shape()
        out_size, in_size = a.const_shape()

        # define the computation
        y = compute(
            name='y',
            shape=[batch_size, out_size],
            fcompute=lambda b, i: reduce(
                shape=[in_size],
                fcompute=lambda k: x[b, k] + a[i, k],
                reduce_type='max',
            ),
        )

        hits = compute(
            name='hits',
            shape=[batch_size, out_size],
            fcompute=lambda b, i: arg_reduce(
                in_size,
                fcompute=lambda k: x[b, k] + a[i, k],
                reduce_type='max',
            ),
        )

        # call the parent class constructor to initialize the task
        super().__init__(
            name='maxplus',  # the name of the task
            inputs=[x, a],  # the input tensor nodes
            outputs=[y, hits],  # the output tensor nodes
        )

class MaxPlusOp(Operator):
    def __init__(self, x, a):
        # call the parent class constructor to initialize the operator
        super().__init__(
            inputs=[x, a],  # the input tensors
            attributes={},
            task=MaxPlusTask(  # the task of the operator
                # create tensor nodes (TensorNode) with the same shape and dtype as the tensors (Tensor)
                input_like(x, 'x'),
                input_like(a, 'a'),
            ),
        )


def maxplus_bw_x_v3(nbatch, nin, nout):
    from hidet.lang import f32, i64, attr
    from hidet.lang.cuda import threadIdx, blockIdx, blockDim

    with hidet.script_module() as script_module:
        @hidet.script
        def kernel(
            grad_x: f32[nbatch, nin],
            grad_y: f32[nbatch, nout],
            hits: i64[nbatch, nout]
        ):
            attr.cuda_grid_dim = ((nbatch + 31) // 32, (nin + 31) // 32)
            attr.cuda_block_dim = (32, 32)
            i = threadIdx.x + blockIdx.x * blockDim.x
            j = threadIdx.y + blockIdx.y * blockDim.y
            if i < nbatch and j < nin:
                for k in range(nout):
                    if hits[i, k] == j:
                        grad_x[i, j] += grad_y[i, k]

    ir_module = script_module.ir_module()
    func = hidet.driver.build_ir_module(ir_module)
    return func

def maxplus_bw_a_v3(nbatch, nin, nout):
    from hidet.lang import f32, i64, attr
    from hidet.lang.cuda import threadIdx, blockIdx, blockDim

    with hidet.script_module() as script_module:
        @hidet.script
        def kernel(
            grad_a: f32[nout, nin],
            grad_y: f32[nbatch, nout],
            hits: i64[nbatch, nout]
        ):
            attr.cuda_grid_dim = ((nout + 31) // 32, (nin + 31) // 32)
            attr.cuda_block_dim = (32, 32)
            i = threadIdx.x + blockIdx.x * blockDim.x
            j = threadIdx.y + blockIdx.y * blockDim.y
            if i < nout and j < nin:
                for k in range(nbatch):
                    if hits[k, i] == j:
                        grad_a[i, j] += grad_y[k, i]

    ir_module = script_module.ir_module()
    func = hidet.driver.build_ir_module(ir_module)
    return func


class MaxPlusFunction_v3(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, a, grad_enabled: bool):
        assert x.device == a.device, "inputs x and a should be on the same device but are on f{x.device} resp. f{a.device}"
        x = x.contiguous()
        a = a.contiguous()

        y = torch.empty((*x.shape[0:-1], a.shape[0]), device=x.device, dtype=x.dtype)

        if grad_enabled:
            op = MaxPlusOp(hidet.from_torch(x.detach()), hidet.from_torch(a.detach()))
            y = op.get_output(0).torch()
            hits = op.get_output(1).torch()
            ctx.save_for_backward(hits)
            ctx.in_features = x.shape[-1]
        else:
            op = MaxPlusNoGradOp(hidet.from_torch(x.detach()), hidet.from_torch(a.detach()))
            y = op.get_output(0).torch()
            
        x.device.type == 'cuda' and torch.cuda.synchronize()
        return y

    @staticmethod
    def backward(ctx, grad_y):
        (hits,) = ctx.saved_tensors
        grad_y = grad_y.contiguous()
        hits.to(grad_y.device)

        grad_x = torch.zeros(grad_y.shape[0], ctx.in_features, dtype=grad_y.dtype, device=grad_y.device)
        grad_a = torch.zeros(grad_y.shape[1], ctx.in_features, dtype=grad_y.dtype, device=grad_y.device)

        nbatch, nout = grad_y.shape
        nin = ctx.in_features

        backward_x = maxplus_bw_x_v3(nbatch, nin, nout)
        backward_a = maxplus_bw_a_v3(nbatch, nin, nout)

        backward_x(grad_x, grad_y, hits)
        backward_a(grad_a, grad_y, hits)

        grad_y.device.type == 'cuda' and torch.cuda.synchronize()
        return grad_x, grad_a, None


def maxplus_v3(x, a):
    return MaxPlusFunction_v3.apply(x, a, torch.is_grad_enabled())


with CUDATimer():
    with torch.no_grad():
        y_v3 = maxplus_v3(x, a)

with CUDATimer():
    y_v3 = maxplus_v3(x, a)

with CUDATimer():
    x.grad, a.grad = None, None
    y_v3.backward(grad_y)

torch.allclose(y_v3, y_v0), torch.allclose(x.grad, grad_x, atol=1e-6), torch.allclose(a.grad, grad_a, atol=1e-6)

Elapsed: 34.90 ms
Elapsed: 65.18 ms
Elapsed: 3.37 s


(True, True, False)

In [45]:
torch.min(grad_a - a.grad), torch.numel(grad_a - a.grad)

(tensor(-298.7048, device='cuda:0'), 524288)