In [2]:
from collections import namedtuple
import torch
dim3 = namedtuple('dim3', ['x','y','z'], defaults=(1,1))

def softmax_kernel(blockId: dim3, blockDim: dim3, threadId, V: torch.Tensor, O: torch.Tensor, s: int):
    """
    Compute softmax of V using O as the output tensor.
    """
    res = 0.0
    tot = 0.0
    location = blockId.x*blockDim.x + threadId.x

    if location >= s:
        return

    for i in range(s):
        cur = V[i] # s reads * s threads
        cur = torch.exp(cur) # s exp * s threads
        tot += cur # s sum * s threads
        if i == location:
            res = cur
        

    res = res/tot  # 1 div * s threads

    O[location] = res # 1 write * s threads


def cdiv(a, b):
    return (a + b - 1) // b

def blk_kernel1d(f, blocks, threads, *args):
        for i1 in range(blocks.x):
                for j1 in range(threads.x): 
                    f(dim3(i1), dim3(j1), threads, *args)
     
def softmax(V): 
    s = V.shape[0]
    tpb = 32  # threads per block
    threads = dim3(tpb, 1, 1)
    blocks = dim3(cdiv(s, tpb), 1, 1)
    O = torch.zeros_like(V, dtype=torch.float32)

    blk_kernel1d(
         softmax_kernel, blocks, threads, V, O, s)

    return O

V = torch.randn(32, dtype=torch.float32)

O = softmax(V)
# print("Softmax output:", O)

O_torch = torch.softmax(V, dim=0)
torch.allclose(O, O_torch, atol=1e-4)


False

In [3]:
O_torch

tensor([0.0175, 0.0661, 0.0389, 0.0158, 0.0126, 0.0118, 0.0034, 0.0843, 0.0199,
        0.0745, 0.0192, 0.0728, 0.0136, 0.0289, 0.0175, 0.0152, 0.0425, 0.0418,
        0.0217, 0.0847, 0.0103, 0.0698, 0.0412, 0.0042, 0.0117, 0.0195, 0.0075,
        0.0091, 0.0493, 0.0424, 0.0044, 0.0279])

In [4]:
O

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0.])