# Fused softmax + matmul



In [1]:
import sys, os
from pathlib import Path

# Make both the notebook's working directory and its parent importable, so
# that project-local helper modules can be found by the imports further down.
cur_dir = Path().resolve()
parent_dir = cur_dir.parent
# Only append directories that are not already present: re-running this cell
# must not keep growing sys.path with duplicate entries.
sys.path += [
    p for p in dict.fromkeys((str(parent_dir), str(cur_dir))) if p not in sys.path
]

# Force synchronous CUDA kernel launches so that a failing kernel is reported
# at the call that launched it rather than at some later, unrelated call.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [2]:
import torch
from utils import cdiv, get_sig, load_cuda, profile_kernel
from collections import namedtuple

def test_allclose(kernels):
    D = 1024
    V_test = torch.randn(D).contiguous().cuda()
    O_torch = torch.softmax(V_test,  dim=0)
    for kernel_name, kernel_data in kernels.items():
        if kernel_name!="torch":
            module, fname = kernel_data["module"], kernel_data["fname"]
            O = getattr(module, fname)(V_test)
            if not torch.allclose(O, O_torch, atol=1e-4):
                raise ValueError(f"{kernel_name=} failed:\n\n {O[:10]=}, {O_torch[:10]=}")
            print(f"{kernel_name=} agrees with torch softmax")
        


def profile_kernels(kernels):
    """Validate the kernels, then profile each registered entry.

    Args:
        kernels: dict mapping a kernel name to a dict with the keys
            ``"module"`` and ``"fname"``, and optionally ``"args"`` (a
            sequence of positional arguments) and ``"kwargs"`` (a dict of
            keyword arguments) that are forwarded to ``profile_kernel``.
    """
    # Fail fast on incorrect kernels before spending time profiling them.
    test_allclose(kernels)
    for kernel_name, kernel_data in kernels.items():
        print(f"Profiling: {kernel_name}")
        # "args"/"kwargs" are optional: an entry that never had them attached
        # is profiled with no extra arguments instead of raising KeyError.
        profile_kernel(
            kernel_data["module"],
            kernel_data["fname"],
            *kernel_data.get("args", ()),
            **kernel_data.get("kwargs", {}),
        )


## PyTorch reference implementation (naive, unfused)

In [3]:
# Small problem size for the demonstration in this cell.
N, L, M = 32, 16, 64
# Random float32 operands moved to the GPU: Q_small is (N, L), K_small is (L, M).
Q_small = torch.randn(N, L, dtype=torch.float32).contiguous().cuda()
K_small = torch.randn(L, M, dtype=torch.float32).contiguous().cuda()

class TorchNaiveSoftmaxAndMatMul(torch.nn.Module):
    """Unfused PyTorch reference: a matrix product followed by a softmax.

    The module has no parameters. The computation is exposed both as
    ``softmax_and_matmul`` and through the standard ``forward`` entry point,
    so ``model(Q, K)`` works as well.
    """

    def softmax_and_matmul(self, Q, K):
        """Return ``softmax(Q @ K, dim=0)``.

        The softmax is taken along dimension 0 of the product, so for 2-D
        inputs each column of the result sums to 1.
        """
        return torch.softmax(Q @ K, dim=0)

    def forward(self, Q, K):
        """Standard ``nn.Module`` entry point; delegates to ``softmax_and_matmul``."""
        return self.softmax_and_matmul(Q, K)

# Run the reference implementation once on the small inputs and show the
# operand shapes next to the result.
model = TorchNaiveSoftmaxAndMatMul()
output = model.softmax_and_matmul(Q_small, K_small)

for label, value in (
    ("Input Q shape:", Q_small.shape),
    ("Input K shape:", K_small.shape),
    ("Output shape :", output.shape),
    ("Output:", output),
):
    print(label, value)

Input Q shape: torch.Size([32, 16])
Input K shape: torch.Size([16, 64])
Output shape : torch.Size([32, 64])
Output: tensor([[3.2414e-02, 2.9322e-05, 2.3462e-05,  ..., 2.6289e-06, 2.4236e-03,
         3.9168e-07],
        [1.9082e-03, 3.4305e-05, 6.3105e-05,  ..., 1.5615e-03, 3.3092e-03,
         1.2458e-06],
        [1.3428e-02, 2.9178e-04, 8.2823e-06,  ..., 3.0067e-03, 1.0007e-03,
         1.9285e-02],
        ...,
        [4.8037e-01, 2.4345e-04, 2.5271e-10,  ..., 5.0420e-05, 5.3735e-04,
         3.4535e-07],
        [1.3780e-02, 7.5145e-05, 1.8623e-03,  ..., 2.1507e-03, 3.6669e-05,
         8.2915e-02],
        [1.5153e-05, 5.6218e-04, 5.1192e-05,  ..., 1.2755e-02, 7.1740e-04,
         3.6581e-04]], device='cuda:0')


## CUDA kernels

In [4]:
def get_modules(kernels):
    """Load the CUDA source of each kernel entry and attach the built module.

    For every entry that has a ``"cuda_source_path"``, the file is read, a
    C++ source is obtained from ``get_sig(fname, cuda_source)``, and the
    result of ``load_cuda`` is stored under the entry's ``"module"`` key.

    Entries without a ``"cuda_source_path"`` (for example a pure-PyTorch
    reference that already carries its own ``"module"``) are left untouched,
    so the function can safely be called again on the same dict.

    Args:
        kernels: dict mapping a kernel name to a dict with the key ``"fname"``
            and, for CUDA kernels, ``"cuda_source_path"``. Modified in place.
    """
    for kernel_data in kernels.values():
        source_path = kernel_data.get("cuda_source_path")
        if source_path is None:
            # Nothing to compile for this entry.
            continue

        fname = kernel_data["fname"]
        cuda_source = Path(source_path).read_text()
        cpp_source = get_sig(fname, cuda_source)
        kernel_data["module"] = load_cuda(cuda_source, cpp_source, funcs=[fname])


def get_softmax_modules(kernels):
    """Load all kernel modules, then register the PyTorch reference as "torch".

    Args:
        kernels: kernel registry dict; modified in place.
    """
    get_modules(kernels)
    reference = TorchNaiveSoftmaxAndMatMul()
    kernels["torch"] = dict(module=reference, fname="softmax_and_matmul")

def add_args_kwargs(kernels, *args, **kwargs):
    """Attach the same call arguments to every kernel entry.

    Args:
        kernels: dict mapping a kernel name to its entry dict; each entry is
            modified in place.
        *args: positional arguments stored under each entry's ``"args"`` key.
        **kwargs: keyword arguments stored under each entry's ``"kwargs"``
            key (an empty dict when none are given).
    """
    for kernel_data in kernels.values():
        kernel_data["args"] = args
        # Store the keyword arguments as well (they used to be dropped), and
        # give each entry its own copy so entries do not share one dict.
        kernel_data["kwargs"] = dict(kwargs)
        


In [5]:
# Benchmark problem size: Q is (N, L) and K is (L, M).
N, L, M = 1024, 256, 768
Q = torch.randn(N, L, dtype=torch.float32).contiguous().cuda()
K = torch.randn(L, M, dtype=torch.float32).contiguous().cuda()

# Registry of kernels to build and profile, keyed by display name.
kernels = {
    "fused_softmax_matmul": {
        "cuda_source_path": "./fused_softmax_matmul.cu",
        "fname": "fused_softmax_matmul",
    },
}
get_softmax_modules(kernels)
add_args_kwargs(kernels, Q, K)


If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


## Profile

In [6]:
# Validate the kernels via test_allclose, then profile every registered entry.
profile_kernels(kernels)


TypeError: fused_softmax_matmul(): incompatible function arguments. The following argument types are supported:
    1. (arg0: torch.Tensor, arg1: torch.Tensor) -> torch.Tensor

Invoked with: tensor([ 0.4030, -0.5492,  0.0681,  ...,  1.1203, -0.3838,  2.1356],
       device='cuda:0')