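# Fused softmax + matmul

Checks a fused CUDA kernel computing `softmax(Q @ K)` (softmax along dim 1, i.e. per row) against a PyTorch reference implementation, then profiles it.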



In [1]:
import sys, os
from pathlib import Path

# Make helper modules (e.g. `utils`) importable from either this notebook's
# directory or its parent: both are appended to sys.path, parent first.
cur_dir = Path().resolve()
parent_dir = cur_dir.parent
sys.path.extend(str(d) for d in (parent_dir, cur_dir))

# Debugging aid: per the PyTorch CUDA notes this makes kernel launches synchronous,
# so a CUDA error is reported at the call that caused it (at the cost of speed).
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [2]:
import torch
from utils import cdiv, get_sig, load_cuda, profile_kernel
from collections import namedtuple

def test_allclose(kernels, sizes=(32, 16, 64)):
    """Check every compiled kernel against the PyTorch reference on fresh random inputs.

    Args:
        kernels: mapping kernel name -> dict holding at least "module" and "fname";
            the entry named "torch" (the reference itself) is skipped.
        sizes: (N, L, M); inputs are float32 CUDA tensors Q of shape (N, L) and K of shape (L, M).

    Raises:
        ValueError: on the first kernel whose output differs from the reference (atol=1e-4).
            Otherwise a confirmation line is printed per kernel.
    """
    N, L, M = sizes
    Q_small = torch.randn(N, L, dtype=torch.float32).contiguous().cuda()
    K_small = torch.randn(L, M, dtype=torch.float32).contiguous().cuda()
    # NOTE: the names O / O_torch / SIZE appear verbatim in the error text via f-string `=`.
    O_torch = TorchNaiveSoftmaxAndMatMul().softmax_and_matmul(Q_small, K_small)
    candidates = ((name, data) for name, data in kernels.items() if name != "torch")
    for kernel_name, kernel_data in candidates:
        O = getattr(kernel_data["module"], kernel_data["fname"])(Q_small, K_small)
        if torch.allclose(O, O_torch, atol=1e-4):
            print(f"{kernel_name=} agrees with torch softmax")
            continue
        SIZE = 3
        raise ValueError(f"{kernel_name=} failed:\n\n {O[:SIZE, :SIZE]=}\n\n, {O_torch[:SIZE, :SIZE]=}")
        


def profile_kernels(kernels):
    """Validate all kernels against the torch reference, then profile each one in insertion order.

    Each entry needs "module", "fname", "args" and "kwargs" (the latter two are
    populated by add_args_kwargs).
    """
    test_allclose(kernels)
    for name, spec in kernels.items():
        print(f"Profiling: {name}")
        fn_module, fn_name = spec["module"], spec["fname"]
        profile_kernel(fn_module, fn_name, *spec["args"], **spec["kwargs"])


## PyTorch reference implementation

In [3]:

class TorchNaiveSoftmaxAndMatMul(torch.nn.Module):
    """Unfused reference: a matrix product followed by a softmax along dim 1.

    Serves as the ground truth the CUDA kernels are compared against.
    """

    def __init__(self):
        super(TorchNaiveSoftmaxAndMatMul, self).__init__()

    def softmax_and_matmul(self, Q, K):
        """Return ``softmax(Q @ K)`` normalized along dim 1.

        For 2-D inputs Q of shape (N, L) and K of shape (L, M) the result has
        shape (N, M) and each row sums to 1.
        """
        logits = torch.matmul(Q, K)
        return logits.softmax(dim=1)

# N, L, M = 32, 16, 64
# Q_small = torch.randn(N, L, dtype=torch.float32).contiguous().cuda()
# K_small = torch.randn(L, M, dtype=torch.float32).contiguous().cuda()
# model = TorchNaiveSoftmaxAndMatMul()
# output = model.softmax_and_matmul(Q_small, K_small)

# print("Input Q shape:", Q_small.shape)
# print("Input K shape:", K_small.shape)
# print("Output shape :", output.shape)
# print("Output:", output)

## CUDA kernels

In [4]:
def get_modules(kernels):
    """Compile each kernel's CUDA source and store the loaded extension under ``"module"``.

    Every entry must provide "fname" (the function to expose) and "cuda_source_path".
    Entries are mutated in place; nothing is returned.
    """
    for spec in kernels.values():
        entry_point = spec["fname"]
        source = Path(spec["cuda_source_path"]).read_text()
        # get_sig presumably derives the C++ declaration for `entry_point` from the CUDA source
        # (project helper from utils — verify).
        spec["module"] = load_cuda(source, get_sig(entry_point, source), funcs=[entry_point])


def get_softmax_modules(kernels):
    """Compile the CUDA kernels in ``kernels``, then register the PyTorch reference as "torch".

    Mutates ``kernels`` in place. The "torch" entry gets no "args" here; those are
    attached afterwards (e.g. by add_args_kwargs).
    """
    get_modules(kernels)
    reference = dict(
        module=TorchNaiveSoftmaxAndMatMul(),
        fname="softmax_and_matmul",
        kwargs={},
    )
    kernels["torch"] = reference

def add_args_kwargs(kernels, *args, **kwargs):
    """Attach the same call arguments to every kernel entry, in place.

    Each entry receives ``args`` (a tuple) and ``kwargs`` (one dict object shared by all entries).
    """
    for spec in kernels.values():
        spec.update(args=args, kwargs=kwargs)
        


In [5]:
# Benchmark problem: Q is (N, L), K is (L, M); the output softmax(Q @ K) is (N, M).
# M = 768 + 1 is deliberately not a round size — presumably to exercise the kernel's
# boundary handling (TODO confirm).
N = 1024
L = 256
M = 768 + 1
Q = torch.randn(N, L, dtype=torch.float32).contiguous().cuda()
K = torch.randn(L, M, dtype=torch.float32).contiguous().cuda()

kernels = {
    "fused_softmax_matmul": {
        "cuda_source_path": "./fused_softmax_matmul.cu",
        "fname": "fused_softmax_matmul",
    },
}
get_softmax_modules(kernels)
add_args_kwargs(kernels, Q, K)


If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


## Profile

In [6]:
# Check every kernel against torch on small random inputs, then profile each one on (Q, K).
profile_kernels(kernels)

kernel_name='fused_softmax_matmul' agrees with torch softmax
Profiling: fused_softmax_matmul
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
fused_softmax_matmul_kernel(float*, float*, float*, ...         0.00%       0.000us         0.00%       0.000us       0.000us       4.307ms        99.90%       4.307ms       4.307ms             1  
                                            aten::zeros         0.66%      39.750us         2.38%     143.801us   

In [7]:
def test_allclose_new(kernels, not_raise=False, sizes=(32, 16, 64), size_show=3, Q_small=None, K_small=None):
    """Compare every non-torch kernel in ``kernels`` against the PyTorch reference.

    Args:
        kernels: mapping kernel name -> dict holding "module" and "fname";
            the entry named "torch" is skipped.
        not_raise: if True, print mismatches instead of raising.
        sizes: (N, L, M) used to build random float32 CUDA inputs when
            Q_small / K_small are not supplied.
        size_show: leading rows/cols shown in mismatch messages (None shows everything).
        Q_small, K_small: optional explicit inputs of shapes (N, L) and (L, M).

    Returns:
        dict mapping "torch" and each kernel's fname to its output tensor.

    Raises:
        ValueError: when a kernel disagrees with torch (atol=1e-4) and not_raise is False.
    """
    N, L, M = sizes
    if Q_small is None:
        Q_small = torch.randn(N, L, dtype=torch.float32).contiguous().cuda()
    if K_small is None:
        K_small = torch.randn(L, M, dtype=torch.float32).contiguous().cuda()
    O_torch = TorchNaiveSoftmaxAndMatMul().softmax_and_matmul(Q_small, K_small)
    results = {"torch": O_torch}
    for kernel_name, kernel_data in kernels.items():
        if kernel_name == "torch":
            continue
        module, fname = kernel_data["module"], kernel_data["fname"]
        O = getattr(module, fname)(Q_small, K_small)
        results[fname] = O
        if torch.allclose(O, O_torch, atol=1e-4):
            print(f"{kernel_name=} agrees with torch softmax")
            continue
        # Mismatch: report it and never fall through to the "agrees" message.
        msg = f"{kernel_name=} failed:\n\n {O[:size_show, :size_show]=}\n\n {O_torch[:size_show, :size_show]=}"
        if not_raise:
            print(msg)
        else:
            # The exception must actually be raised, otherwise callers cannot detect failures.
            raise ValueError(msg)
    return results

# Sanity check with all-ones inputs: every logit in a row is equal, so each output
# row should be uniform (1/M) and sum to 1. Mismatches are printed, not raised.
# Other sizes tried: (32, 16, 64), (90, 2, 3), (1024, 256, 768).
sizes = 1024, 256, 56
N, L, M = sizes
results = test_allclose_new(
    kernels,
    not_raise=True,
    size_show=None,
    Q_small=torch.ones(N, L).contiguous().cuda(),
    K_small=torch.ones(L, M).contiguous().cuda(),
)
O = results["fused_softmax_matmul"]
O_torch = results["torch"]
        

kernel_name='fused_softmax_matmul' agrees with torch softmax


In [8]:
# PyTorch reference output for the all-ones inputs (expected uniform rows of 1/M).
O_torch

tensor([[0.0179, 0.0179, 0.0179,  ..., 0.0179, 0.0179, 0.0179],
        [0.0179, 0.0179, 0.0179,  ..., 0.0179, 0.0179, 0.0179],
        [0.0179, 0.0179, 0.0179,  ..., 0.0179, 0.0179, 0.0179],
        ...,
        [0.0179, 0.0179, 0.0179,  ..., 0.0179, 0.0179, 0.0179],
        [0.0179, 0.0179, 0.0179,  ..., 0.0179, 0.0179, 0.0179],
        [0.0179, 0.0179, 0.0179,  ..., 0.0179, 0.0179, 0.0179]],
       device='cuda:0')

In [9]:
# Softmax rows should each sum to 1: first 5 row sums of the reference.
O_torch.sum(dim=1)[:5]

tensor([1., 1., 1., 1., 1.], device='cuda:0')

In [10]:
# Same row-sum check for the fused kernel's output.
O.sum(dim=1)[:5]

tensor([1., 1., 1., 1., 1.], device='cuda:0')

In [11]:
# Element-wise comparison with the same tolerance used by test_allclose_new.
torch.allclose(O, O_torch, atol=1e-4)

True

In [12]:
# Row sums over all rows of the fused kernel's output.
O.sum(dim=1)

tensor([1., 1., 1.,  ..., 1., 1., 1.], device='cuda:0')

In [13]:
# Top-left 3x3 corner of the fused kernel's output.
O[:3, :3]

tensor([[0.0179, 0.0179, 0.0179],
        [0.0179, 0.0179, 0.0179],
        [0.0179, 0.0179, 0.0179]], device='cuda:0')

In [14]:
# Same corner of the reference output, for side-by-side comparison.
O_torch[:3, :3]

tensor([[0.0179, 0.0179, 0.0179],
        [0.0179, 0.0179, 0.0179],
        [0.0179, 0.0179, 0.0179]], device='cuda:0')

In [15]:
# Random-input check at the small (default) size (N, L, M) = (32, 16, 64).
test_allclose(kernels, sizes=(32, 16, 64))

kernel_name='fused_softmax_matmul' agrees with torch softmax


### Test with all-ones matrices

In [16]:
# All-ones inputs over several shapes (including sizes that are not powers of two):
# each logit in a row is equal, so the expected output row is uniform (1/M).
sizes_list = [
    (1024, 256, 56),
    (90, 2, 56),
    (90, 2, 32),
    (90, 2, 31),
    (90, 2, 3),
    (1024, 256, 768),
    (32, 16, 64),
    (32, 16, 65),
    (3, 16, 65),
]
results = {}
for sizes in sizes_list:
    N, L, M = sizes
    try:
        test_allclose_new(
            kernels,
            not_raise=False,
            size_show=None,
            Q_small=torch.ones(N, L).contiguous().cuda(),
            K_small=torch.ones(L, M).contiguous().cuda(),
        )
        results[sizes] = True
    except Exception as exc:
        # Catch only real errors (mismatch ValueError, CUDA/shape RuntimeError), not
        # KeyboardInterrupt/SystemExit, and say why the case failed.
        print(f"{sizes=} failed with {type(exc).__name__}")
        results[sizes] = False
from pprint import pprint
pprint(results)

kernel_name='fused_softmax_matmul' agrees with torch softmax
kernel_name='fused_softmax_matmul' agrees with torch softmax
kernel_name='fused_softmax_matmul' agrees with torch softmax
kernel_name='fused_softmax_matmul' agrees with torch softmax
kernel_name='fused_softmax_matmul' agrees with torch softmax
kernel_name='fused_softmax_matmul' agrees with torch softmax
kernel_name='fused_softmax_matmul' agrees with torch softmax
kernel_name='fused_softmax_matmul' agrees with torch softmax
kernel_name='fused_softmax_matmul' agrees with torch softmax
{(3, 16, 65): True,
 (32, 16, 64): True,
 (32, 16, 65): True,
 (90, 2, 3): True,
 (90, 2, 31): True,
 (90, 2, 32): True,
 (90, 2, 56): True,
 (1024, 256, 56): True,
 (1024, 256, 768): True}


In [17]:
# Random inputs over the same shapes; seeded so a failing case is reproducible.
torch.manual_seed(0)
sizes_list = [
    (1024, 256, 56),
    (90, 2, 56),
    (90, 2, 32),
    (90, 2, 31),
    (90, 2, 3),
    (1024, 256, 768),
    (32, 16, 64),
    (32, 16, 65),
    (3, 16, 65),
]
results = {}
for sizes in sizes_list:
    try:
        test_allclose_new(
            kernels,
            not_raise=False,
            size_show=3,
            sizes=sizes,
        )
        results[sizes] = True
    except Exception as exc:
        # Catch only real errors (mismatch ValueError, CUDA/shape RuntimeError), not
        # KeyboardInterrupt/SystemExit, and say why the case failed.
        print(f"{sizes=} failed with {type(exc).__name__}")
        results[sizes] = False
from pprint import pprint
pprint(results)

kernel_name='fused_softmax_matmul' agrees with torch softmax
kernel_name='fused_softmax_matmul' agrees with torch softmax
kernel_name='fused_softmax_matmul' agrees with torch softmax
kernel_name='fused_softmax_matmul' agrees with torch softmax
kernel_name='fused_softmax_matmul' agrees with torch softmax
kernel_name='fused_softmax_matmul' agrees with torch softmax
kernel_name='fused_softmax_matmul' agrees with torch softmax
kernel_name='fused_softmax_matmul' agrees with torch softmax
kernel_name='fused_softmax_matmul' agrees with torch softmax
{(3, 16, 65): True,
 (32, 16, 64): True,
 (32, 16, 65): True,
 (90, 2, 3): True,
 (90, 2, 31): True,
 (90, 2, 32): True,
 (90, 2, 56): True,
 (1024, 256, 56): True,
 (1024, 256, 768): True}
