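# Fused softmax + matmul

Check a custom CUDA kernel against a PyTorch reference that computes `softmax(Q @ K, dim=0)`, then profile both.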



In [1]:
import sys, os
from pathlib import Path

# Add the parent directory of the current notebook to sys.path
# Make modules in the current working directory (normally the notebook's folder when
# launched from Jupyter) and in its parent directory importable, e.g. the local `utils`.
# Paths are only appended when missing so re-running this cell does not grow sys.path.
cur_dir = Path().resolve()
parent_dir = cur_dir.parent
for _path in (str(parent_dir), str(cur_dir)):
    if _path not in sys.path:
        sys.path.append(_path)

# Make CUDA kernel launches synchronous (presumably for debugging the custom kernel);
# set before torch initialises CUDA so the setting takes effect.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [10]:
import torch
from utils import cdiv, get_sig, load_cuda, profile_kernel
from collections import namedtuple

def test_allclose(kernels, N=32, L=16, M=64, atol=1e-4):
    """Check every custom kernel against the PyTorch reference on small random CUDA inputs.

    Args:
        kernels: mapping of kernel name -> dict holding at least "module" and "fname";
            `getattr(module, fname)(Q, K)` must return the kernel's output. The entry
            named "torch" is the reference itself and is skipped.
        N, L, M: problem size; Q is (N, L), K is (L, M), so the output is (N, M).
        atol: absolute tolerance passed to `torch.allclose`.

    Raises:
        ValueError: if a kernel's output has the wrong shape or does not match the
            reference within `atol`.
    """
    Q_small = torch.randn(N, L, dtype=torch.float32).contiguous().cuda()
    K_small = torch.randn(L, M, dtype=torch.float32).contiguous().cuda()
    O_torch = TorchNaiveSoftmaxAndMatMul().softmax_and_matmul(Q_small, K_small)
    for kernel_name, kernel_data in kernels.items():
        if kernel_name == "torch":
            continue  # the reference, nothing to compare against
        module, fname = kernel_data["module"], kernel_data["fname"]
        O = getattr(module, fname)(Q_small, K_small)
        # Compare shapes explicitly: a mismatch would otherwise either broadcast or
        # raise an unhelpful error from inside allclose.
        if O.shape != O_torch.shape:
            raise ValueError(
                f"{kernel_name=} returned shape {tuple(O.shape)}, expected {tuple(O_torch.shape)}"
            )
        if not torch.allclose(O, O_torch, atol=atol):
            raise ValueError(f"{kernel_name=} failed:\n\n {O[:10]=}, {O_torch[:10]=}")
        print(f"{kernel_name=} agrees with torch softmax")
        


def profile_kernels(kernels):
    """Verify all kernels against the reference, then profile each one.

    Args:
        kernels: mapping of kernel name -> dict holding "module", "fname", "args"
            (positional inputs) and optionally "kwargs" (keyword inputs), all passed
            through to `profile_kernel`.

    Raises:
        ValueError: propagated from `test_allclose` when a kernel is incorrect;
            nothing is profiled in that case.
    """
    test_allclose(kernels)
    for kernel_name, kernel_data in kernels.items():
        print(f"Profiling: {kernel_name}")
        # Keyword inputs are optional; an entry without "kwargs" is profiled with none.
        kwargs = kernel_data.get("kwargs", {})
        profile_kernel(kernel_data["module"], kernel_data["fname"], *kernel_data["args"], **kwargs)


## Reference implementation in PyTorch

In [3]:
# Seed once, early, so that this cell and every later random draw (the correctness
# check and the profiling inputs) are reproducible under Restart & Run All.
torch.manual_seed(0)

# Small inputs for a quick look at the reference: Q is (N, L), K is (L, M).
N, L, M = 32, 16, 64
Q_small = torch.randn(N, L, dtype=torch.float32).contiguous().cuda()
K_small = torch.randn(L, M, dtype=torch.float32).contiguous().cuda()

class TorchNaiveSoftmaxAndMatMul(torch.nn.Module):
    """Unfused PyTorch reference: a matmul followed by a softmax.

    Serves as the ground truth that custom CUDA kernels are compared against.
    """

    def __init__(self):
        super().__init__()

    def softmax_and_matmul(self, Q, K, dim=0):
        """Return ``softmax(Q @ K, dim=dim)``.

        Args:
            Q: left matmul operand, e.g. shape (N, L).
            K: right matmul operand, e.g. shape (L, M).
            dim: dimension the softmax normalises over. The default 0 makes each
                column of a 2-D (N, M) product sum to 1.
                NOTE(review): attention-style kernels usually normalise over the last
                dim; confirm dim=0 is what the CUDA kernel implements.
        """
        O = Q @ K
        return torch.softmax(O, dim=dim)

    def forward(self, Q, K, dim=0):
        """Alias for `softmax_and_matmul` so the module can be called as ``model(Q, K)``."""
        return self.softmax_and_matmul(Q, K, dim=dim)

model = TorchNaiveSoftmaxAndMatMul()
output = model.softmax_and_matmul(Q_small, K_small)

# The reference is the ground truth for the CUDA kernels, so sanity-check it:
# the product is (N, M), and a softmax over dim=0 makes every column sum to 1.
assert output.shape == (Q_small.shape[0], K_small.shape[1])
assert torch.allclose(output.sum(dim=0), torch.ones_like(output[0]), atol=1e-5)

print("Input Q shape:", Q_small.shape)
print("Input K shape:", K_small.shape)
print("Output shape :", output.shape)
print("Output:", output)

Input Q shape: torch.Size([32, 16])
Input K shape: torch.Size([16, 64])
Output shape : torch.Size([32, 64])
Output: tensor([[3.2414e-02, 2.9322e-05, 2.3462e-05,  ..., 2.6289e-06, 2.4236e-03,
         3.9168e-07],
        [1.9082e-03, 3.4305e-05, 6.3105e-05,  ..., 1.5615e-03, 3.3092e-03,
         1.2458e-06],
        [1.3428e-02, 2.9178e-04, 8.2823e-06,  ..., 3.0067e-03, 1.0007e-03,
         1.9285e-02],
        ...,
        [4.8037e-01, 2.4345e-04, 2.5271e-10,  ..., 5.0420e-05, 5.3735e-04,
         3.4535e-07],
        [1.3780e-02, 7.5145e-05, 1.8623e-03,  ..., 2.1507e-03, 3.6669e-05,
         8.2915e-02],
        [1.5153e-05, 5.6218e-04, 5.1192e-05,  ..., 1.2755e-02, 7.1740e-04,
         3.6581e-04]], device='cuda:0')


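## CUDA kernels

Compile each kernel's `.cu` source into an extension module and register the PyTorch reference alongside it.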

In [4]:
def get_modules(kernels):
    """Compile every kernel's CUDA source and attach the loaded extension module.

    For each entry, reads the file named by ``"cuda_source_path"`` and derives the C++
    source for ``"fname"`` with `get_sig`. It then builds both with `load_cuda` and
    stores the result under ``"module"``. The entry dicts are mutated in place.
    """
    for spec in kernels.values():
        func_name = spec["fname"]
        source = Path(spec["cuda_source_path"]).read_text()
        spec["module"] = load_cuda(source, get_sig(func_name, source), funcs=[func_name])


def get_softmax_modules(kernels):
    """Build the CUDA kernels in `kernels` and register the PyTorch reference as "torch".

    The "torch" entry has no ``"cuda_source_path"``, so it is excluded from compilation.
    This keeps the function safe to call again on a dict it has already populated.
    """
    # The inner dicts are shared with `kernels`, so the modules get_modules attaches
    # are visible through the original mapping.
    get_modules({name: data for name, data in kernels.items() if name != "torch"})
    kernels["torch"] = {
        "module": TorchNaiveSoftmaxAndMatMul(),
        "fname": "softmax_and_matmul",
    }

def add_args_kwargs(kernels, *args, **kwargs):
    """Store the inputs every kernel should be profiled with.

    Sets ``"args"`` (the positional inputs) and ``"kwargs"`` (the keyword inputs) on
    each entry of `kernels`. `profile_kernels` unpacks both when calling `profile_kernel`.
    """
    for kernel_data in kernels.values():
        kernel_data["args"] = args
        # Give each entry its own dict so mutating one entry's kwargs cannot affect another.
        kernel_data["kwargs"] = dict(kwargs)
        


In [8]:
# Profiling problem size: Q is (N, L), K is (L, M), so the output is (N, M).
N, L, M = 1024, 256, 768
Q = torch.randn((N, L), dtype=torch.float32).contiguous().cuda()
K = torch.randn((L, M), dtype=torch.float32).contiguous().cuda()

# Kernels to build and profile: each entry names its .cu file and exported function.
kernels = {
    "fused_softmax_matmul": {
        "cuda_source_path": "./fused_softmax_matmul.cu",
        "fname": "fused_softmax_matmul",
    },
}
get_softmax_modules(kernels)
add_args_kwargs(kernels, Q, K)


## Profile

In [11]:
# Checks each kernel against the PyTorch reference first; profiling only runs if all agree.
profile_kernels(kernels=kernels)

ValueError: kernel_name='fused_softmax_matmul' failed:

 O[:10]=tensor([[inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
         inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
         inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf],
        [inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
         inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
         inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf],
        [inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
         inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
         inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf],
        [inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
         inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
         inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf],
        [inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
         inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
         inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf],
        [inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
         inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
         inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf],
        [inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
         inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
         inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf],
        [inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
         inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
         inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf],
        [inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
         inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
         inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf],
        [inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
         inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf,
         inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf]],
       device='cuda:0'), O_torch[:10]=tensor([[1.3848e-02, 9.2309e-08, 1.4082e-02, 6.6766e-08, 2.1426e-02, 1.7599e-07,
         4.8200e-08, 4.1218e-05, 3.9299e-05, 1.2663e-03, 1.1633e-03, 4.6437e-03,
         1.0946e-03, 6.1957e-09, 3.0093e-04, 7.3129e-04, 1.4905e-03, 2.7493e-08,
         3.4428e-05, 1.8017e-05, 1.9827e-04, 1.6011e-05, 1.1305e-03, 1.3530e-05,
         7.1084e-03, 2.9943e-01, 1.2937e-02, 6.2817e-04, 9.9044e-07, 1.4306e-04,
         5.4234e-04, 1.5883e-04, 6.9838e-05, 6.3806e-05, 4.4358e-04, 3.7384e-05,
         1.4716e-07, 1.4325e-02, 4.1881e-04, 6.1232e-06, 6.5569e-05, 3.7534e-03,
         7.4224e-05, 2.3357e-03, 5.0607e-02, 2.8963e-07, 5.1588e-03, 2.0085e-03,
         3.1143e-06, 1.7450e-02, 3.4432e-10, 2.0151e-04, 2.8077e-02, 1.0061e-06,
         1.5180e-02, 5.0359e-04, 4.1825e-06, 2.8901e-05, 1.1917e-01, 1.1289e-03,
         1.2304e-03, 2.2805e-03, 6.1624e-05, 1.7455e-03],
        [3.4477e-05, 1.1892e-02, 6.3960e-02, 9.1675e-04, 1.8499e-02, 2.7449e-03,
         5.3653e-08, 1.7011e-03, 2.2993e-05, 2.1797e-03, 6.8534e-04, 4.1630e-01,
         1.1115e-05, 2.4375e-04, 8.9878e-02, 3.5817e-05, 2.2939e-04, 5.8403e-06,
         2.7999e-06, 2.0935e-03, 5.0186e-08, 2.3663e-06, 4.2241e-07, 2.8398e-02,
         3.5287e-04, 2.2383e-02, 1.0213e-04, 1.5922e-04, 1.1735e-04, 4.6045e-06,
         4.2159e-07, 6.8795e-01, 1.1442e-04, 1.4884e-01, 2.3573e-04, 3.0594e-05,
         9.2424e-03, 5.2015e-07, 1.6887e-02, 6.3290e-10, 1.3004e-02, 1.2037e-04,
         1.6257e-04, 8.6546e-01, 1.0762e-03, 1.4725e-04, 6.1158e-06, 2.1322e-06,
         5.7387e-03, 3.9459e-06, 9.9533e-01, 3.7285e-03, 1.0351e-04, 8.4881e-01,
         1.9983e-06, 6.7160e-03, 1.0793e-05, 9.2652e-06, 2.4227e-05, 1.0289e-03,
         1.4050e-03, 1.9985e-06, 7.8181e-05, 1.3898e-01],
        [4.0799e-06, 1.0813e-05, 2.2190e-06, 8.6692e-04, 2.5399e-04, 7.2376e-05,
         3.6124e-04, 1.9775e-03, 3.2944e-03, 3.4289e-04, 1.9370e-05, 1.3468e-05,
         1.5480e-04, 2.6280e-04, 4.2907e-06, 7.9591e-07, 2.1680e-01, 4.3165e-04,
         7.0241e-05, 2.2905e-04, 4.9280e-03, 1.2749e-06, 2.6912e-02, 2.5465e-04,
         7.2938e-06, 4.9129e-03, 6.7242e-04, 1.2395e-03, 5.1837e-05, 2.5736e-02,
         9.7124e-04, 7.5761e-04, 1.9517e-05, 3.9498e-04, 1.0354e-04, 7.6489e-06,
         6.8711e-04, 3.0927e-02, 2.3562e-06, 4.8962e-05, 1.9990e-05, 1.0076e-04,
         2.3195e-02, 2.4266e-03, 4.4802e-05, 1.0181e-01, 2.2458e-06, 1.1191e-04,
         9.8977e-03, 1.8372e-02, 6.8697e-07, 6.7790e-03, 4.0994e-03, 1.3750e-06,
         8.3412e-06, 1.6489e-03, 7.6286e-07, 1.5406e-01, 7.7590e-04, 9.5875e-03,
         3.0071e-05, 1.4486e-03, 5.0153e-03, 8.9697e-05],
        [2.2192e-03, 6.5134e-03, 7.7420e-03, 1.7386e-05, 8.4836e-01, 1.3822e-04,
         3.1881e-03, 3.2554e-06, 1.7297e-02, 1.7554e-02, 3.2716e-03, 3.0114e-02,
         7.5313e-07, 4.2874e-07, 8.9372e-01, 5.1958e-03, 3.3836e-05, 4.7052e-05,
         3.0259e-07, 9.7355e-03, 3.4403e-06, 9.2574e-05, 3.3610e-02, 1.7696e-03,
         7.9522e-02, 3.9379e-04, 4.7234e-04, 3.1010e-07, 1.4170e-05, 6.4973e-08,
         1.3906e-02, 5.6619e-02, 9.6572e-04, 3.0668e-06, 4.0881e-01, 2.2998e-04,
         3.7455e-04, 2.6496e-04, 1.7254e-03, 2.2180e-07, 2.1545e-03, 1.9935e-07,
         2.0658e-03, 3.3735e-03, 1.2788e-02, 2.3951e-05, 2.4781e-07, 5.3650e-06,
         7.8195e-05, 4.8844e-07, 1.9748e-04, 7.1993e-04, 3.9835e-03, 8.3656e-03,
         1.1839e-06, 4.8270e-02, 1.3831e-05, 6.8456e-04, 8.7170e-04, 2.8028e-03,
         1.1776e-03, 2.3000e-05, 6.9898e-03, 1.3753e-03],
        [5.1018e-04, 1.1703e-05, 1.9512e-05, 1.9313e-03, 2.0194e-04, 5.2502e-05,
         2.5148e-03, 6.7274e-06, 3.1654e-05, 3.2488e-02, 4.9159e-06, 1.7114e-04,
         3.4730e-02, 3.2583e-04, 2.2715e-04, 7.9360e-07, 2.0046e-02, 3.1065e-04,
         2.4675e-06, 1.7868e-02, 6.2885e-02, 9.7762e-01, 4.0722e-01, 2.4601e-04,
         5.0015e-02, 5.2747e-03, 6.6516e-06, 8.4387e-04, 2.5518e-04, 9.9232e-04,
         1.4643e-02, 2.6552e-03, 3.5464e-04, 3.0013e-04, 2.8887e-04, 2.0012e-03,
         1.7211e-05, 8.8844e-03, 1.2866e-04, 5.6194e-05, 5.6593e-03, 1.3736e-03,
         8.7162e-05, 2.2386e-04, 4.7286e-05, 2.3847e-03, 1.2433e-02, 3.9617e-02,
         2.3118e-05, 9.8405e-02, 7.7095e-08, 4.3169e-02, 4.6502e-04, 5.7544e-07,
         6.4688e-05, 5.4826e-02, 1.0420e-03, 1.8859e-02, 2.6921e-03, 1.2840e-02,
         9.3266e-05, 3.4283e-05, 1.1877e-02, 2.1764e-03],
        [1.9033e-03, 8.6364e-04, 7.0604e-01, 1.7623e-03, 2.3934e-08, 7.8307e-07,
         1.9183e-06, 5.3706e-06, 5.0685e-06, 1.1296e-06, 9.8575e-04, 6.1020e-02,
         3.2013e-05, 1.0793e-06, 8.4287e-07, 3.6913e-04, 3.8749e-04, 3.4684e-03,
         6.7285e-03, 1.0346e-02, 3.2454e-04, 3.5863e-07, 8.9820e-07, 1.2108e-06,
         4.9298e-06, 3.6808e-04, 2.9548e-04, 2.7306e-05, 1.4219e-06, 1.4769e-05,
         3.7621e-06, 7.0708e-09, 1.0439e-03, 4.0405e-06, 2.7082e-05, 1.5984e-04,
         5.4959e-01, 4.2063e-03, 2.5836e-06, 2.2183e-04, 5.3957e-03, 5.3863e-04,
         1.2057e-02, 4.4711e-06, 2.2858e-04, 1.3382e-04, 2.0807e-08, 9.3878e-04,
         1.2241e-01, 8.4464e-05, 1.3779e-07, 7.7436e-06, 1.1400e-04, 7.8239e-07,
         3.1977e-08, 2.4119e-06, 8.5134e-01, 2.3974e-03, 1.7961e-05, 4.7211e-06,
         5.5406e-02, 5.2910e-08, 2.1020e-03, 1.0387e-07],
        [2.5372e-05, 2.6015e-06, 2.1917e-03, 7.8042e-06, 5.6758e-03, 1.7486e-05,
         4.5292e-05, 3.0542e-05, 1.9277e-03, 1.9676e-03, 2.0486e-04, 4.7120e-05,
         3.2208e-07, 1.9733e-07, 7.0444e-04, 9.1454e-04, 2.3876e-03, 9.6085e-05,
         1.0436e-07, 1.0330e-02, 2.7340e-05, 1.9038e-06, 4.4794e-02, 3.3270e-04,
         1.5247e-04, 5.8693e-03, 2.2182e-04, 9.6233e-07, 5.3692e-05, 1.4159e-04,
         1.5653e-02, 2.6189e-04, 5.1739e-03, 8.9599e-06, 4.1650e-03, 4.8791e-06,
         1.2413e-04, 1.2686e-01, 6.2877e-06, 1.0958e-06, 1.9051e-03, 1.3343e-05,
         5.1964e-02, 1.4150e-04, 1.7438e-01, 2.6778e-05, 2.5301e-05, 1.1421e-04,
         6.6276e-05, 2.6854e-06, 1.8061e-06, 2.2179e-03, 3.2775e-03, 8.8604e-06,
         6.8815e-06, 1.4606e-04, 1.3811e-05, 1.9119e-03, 3.4502e-04, 3.1242e-03,
         2.6225e-05, 2.4538e-04, 5.7930e-04, 2.4001e-04],
        [2.2119e-04, 9.5870e-01, 1.7076e-06, 5.6993e-04, 6.3541e-06, 3.8325e-03,
         9.0994e-06, 3.4873e-07, 9.5970e-01, 1.7698e-04, 1.0298e-03, 4.2998e-04,
         5.2534e-08, 3.5537e-02, 7.6569e-05, 2.1877e-05, 9.3674e-05, 7.3667e-02,
         1.7720e-01, 6.4727e-03, 1.2857e-05, 3.8948e-11, 4.5782e-03, 3.0018e-04,
         2.3888e-11, 3.1833e-04, 2.6551e-01, 4.2102e-02, 4.3055e-01, 2.5891e-04,
         3.5907e-06, 3.3616e-07, 1.3208e-09, 5.2503e-06, 1.4317e-08, 2.3779e-05,
         3.2018e-08, 7.8659e-02, 7.7465e-06, 2.0961e-07, 4.7021e-06, 2.6526e-04,
         2.0584e-01, 2.2251e-03, 2.7214e-01, 1.3088e-01, 2.1886e-06, 3.8614e-09,
         3.7427e-02, 2.8105e-05, 2.4319e-06, 4.4496e-01, 1.0335e-02, 2.8585e-03,
         3.3717e-09, 1.5676e-04, 5.1613e-04, 8.6213e-03, 4.2484e-06, 5.5350e-01,
         8.7261e-01, 2.3639e-02, 2.7915e-02, 2.9300e-08],
        [2.0950e-03, 1.5393e-04, 4.0569e-04, 2.7118e-04, 5.8681e-03, 1.2620e-01,
         7.7352e-05, 2.0713e-02, 8.9288e-06, 1.1902e-07, 9.1514e-06, 7.3798e-04,
         3.6123e-05, 3.9253e-06, 4.6786e-06, 2.7256e-04, 1.2061e-02, 6.7030e-03,
         2.8421e-05, 5.4552e-05, 1.6009e-06, 8.4361e-04, 3.7110e-04, 1.2760e-02,
         1.1246e-05, 2.9493e-04, 1.9265e-04, 1.6903e-04, 2.0500e-05, 5.9485e-04,
         4.7923e-04, 4.2913e-02, 3.6341e-04, 2.0404e-01, 1.6038e-02, 1.4815e-06,
         4.2168e-04, 3.7979e-06, 1.3568e-03, 4.9429e-06, 1.0615e-04, 3.0701e-05,
         4.1337e-04, 2.2353e-02, 1.7406e-05, 1.9230e-03, 3.0698e-08, 7.5778e-05,
         5.9374e-05, 3.8118e-05, 1.4005e-03, 1.1135e-04, 2.5562e-05, 1.5568e-04,
         9.5294e-05, 1.1410e-02, 4.3791e-05, 2.1598e-01, 4.5862e-04, 1.9607e-03,
         6.4996e-05, 5.2839e-05, 7.7264e-06, 1.1088e-01],
        [2.3274e-06, 7.7974e-08, 3.9888e-03, 3.1559e-06, 4.9264e-02, 2.9401e-04,
         7.9885e-05, 1.4244e-05, 6.5525e-04, 7.5062e-02, 1.5372e-01, 3.4110e-03,
         4.1790e-05, 1.7817e-06, 5.9880e-03, 6.0912e-06, 4.2883e-05, 4.1997e-04,
         1.6934e-06, 4.5703e-04, 1.7232e-07, 1.2751e-06, 8.1659e-04, 7.1552e-02,
         7.0041e-02, 1.4263e-01, 4.3865e-04, 7.2260e-06, 1.1050e-02, 1.0747e-03,
         7.5811e-02, 1.9274e-03, 4.0284e-05, 1.6440e-05, 1.9569e-04, 2.8281e-06,
         9.4464e-07, 1.1192e-03, 1.0642e-03, 9.2037e-05, 4.2807e-06, 1.1707e-04,
         4.4009e-02, 1.4201e-03, 4.4712e-05, 7.1106e-04, 9.1164e-07, 1.1083e-05,
         7.9712e-04, 4.1480e-04, 1.4194e-04, 1.4134e-03, 1.7248e-02, 3.0186e-06,
         2.0215e-07, 6.7467e-02, 4.6402e-07, 2.9839e-05, 3.1294e-02, 4.3763e-03,
         1.4224e-07, 4.3131e-03, 1.2563e-05, 5.6531e-01]], device='cuda:0')