# Dense Max-Plus

In [1]:
import torch
import torchvision
import matplotlib.pyplot as plt
from PIL import Image
from mpl_toolkits.axes_grid1 import ImageGrid
import ipywidgets as widgets
import numpy as np
from semitorch.utils import Timer, CUDATimer, ntuple, mnistplot
from typing import Optional, Union, Tuple, TypeVar

import taichi as ti
import taichi.math as tm

device = torch.device('cuda')
ti.init(arch=ti.gpu)

[Taichi] version 1.5.0, llvm 15.0.4, commit 7b885c28, linux, python 3.10.10


[I 04/21/23 17:20:33.297 169584] [shell.py:_shell_pop_print@23] Graphical python shell detected, using wrapped sys.stdout


[Taichi] Starting on arch=cuda


In [34]:
torch.manual_seed(0)
B, Dx, Dy = 4096, 1024, 512
x = torch.rand(B, Dx, dtype=torch.float32, device=device, requires_grad=False)
a = torch.randn(Dy, Dx, dtype=torch.float32, device=device, requires_grad=False)
grad_y = torch.randn(B, Dy, dtype=torch.float32, device=device, requires_grad=False)

## Baseline (v0)

In [35]:
def maxplus_v0(x, a):
    assert a.ndim == 2 and x.ndim >= 1
    assert x.shape[-1] == a.shape[-1]
    x = x.unsqueeze(-2)
    return torch.max(x + a, dim=-1)[0]


with CUDATimer():
    y_v0 = maxplus_v0(x, a)

Elapsed: 68.72 ms


## Taichi v1

In [36]:
@ti.kernel
def maxplus_kernel_v1(
    y: ti.types.ndarray(ndim=2),  # [B,Dy]
    x: ti.types.ndarray(ndim=2),  # [B,Dx]
    a: ti.types.ndarray(ndim=2),  # [Dy,Dx]
):
    for b, i in y:
        v = -tm.inf
        for j in range(a.shape[-1]):
            v = tm.max(v, x[b, j] + a[i, j])
        y[b, i] = v


@ti.kernel
def maxplus_fw_kernel_v1(
    y: ti.types.ndarray(ndim=2),  # [B,Dy]
    hits: ti.types.ndarray(dtype=ti.i32, ndim=2),  # [B,Dx]
    x: ti.types.ndarray(ndim=2),  # [B,Dx]
    a: ti.types.ndarray(ndim=2),  # [Dy,Dx]
):
    for b, i in y:
        v = -tm.inf
        hit = -1
        for j in range(a.shape[0]):
            w = x[b, j] + a[i, j]
            if w > v:
                v = w
                hit = j
        y[b, i] = v
        hits[b, i] = hit


@ti.kernel
def maxplus_bw_x_kernel_v1(
    gradx: ti.types.ndarray(ndim=2),  # [B,Dx]
    hits: ti.types.ndarray(dtype=ti.i32, ndim=2),  # [B,Dx]
    grady: ti.types.ndarray(ndim=2),  # [B,Dy]
):
    pass


@ti.kernel
def maxplus_bw_a_kernel_v1(
    grada: ti.types.ndarray(ndim=2),  # [Dy,Dx]
    hits: ti.types.ndarray(dtype=ti.i32, ndim=2),  # [B,Dx]
    grady: ti.types.ndarray(ndim=2),  # [B,Dy]
):
    pass


class MaxPlusFunction_v1(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, a):
        x = x.contiguous()
        a = a.contiguous()

        y = torch.empty((*x.shape[0:-1], a.shape[0]), device=x.device, dtype=x.dtype)

        if x.requires_grad or a.requires_grad:
            hits = torch.empty_like(x, dtype=torch.int32)
            maxplus_fw_kernel_v1(y, hits, x, a)
            ctx.save_for_backward(hits)
        else:
            maxplus_kernel_v1(y, x, a)

        torch.cuda.synchronize()
        return y

    @staticmethod
    def backward(ctx, grad_y):
        (hits,) = ctx.saved_tensors

        ti.sync()


def maxplus_v1(x, a):
    return MaxPlusFunction_v1.apply(x, a)


with CUDATimer():
    y_v1 = maxplus_v1(x, a)

torch.allclose(y_v1, y_v0)

Elapsed: 190.36 ms


True