# Fused softmax + matmul



In [1]:
import sys, os
from pathlib import Path

# Add the parent directory of the current notebook to sys.path
cur_dir = Path().resolve()
parent_dir = cur_dir.parent
sys.path += [str(parent_dir), str(cur_dir)]


os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

TEST_SIZES = [
    (32, 16, 64),
    (3, 16, 65),
    (90, 2, 3),
    (90, 2, 56),
    (90, 2, 32),
    (90, 2, 31),
    (90, 2, 64),
    (90, 2, 65),
    (90, 32, 64),
    (90, 32, 65),
    (32, 16, 64),
    (32, 16, 65),
    (32, 2, 3),
    (32, 2, 56),
    (32, 2, 32),
    (32, 2, 31),
    (32, 2, 64),
    (32, 2, 65),
    (1024, 256, 56),
    (1024, 256, 768),
]

In [2]:
import torch
from utils import cdiv, get_sig, load_cuda, profile_kernel
from collections import namedtuple

# def test_allclose_old(kernels, sizes=(32, 16, 64)):
#     N, L, M = sizes
#     Q_small = torch.randn(N, L, dtype=torch.float32).contiguous().cuda()
#     K_small = torch.randn(L, M, dtype=torch.float32).contiguous().cuda()
#     O_torch = TorchNaiveSoftmaxAndMatMul().softmax_and_matmul(Q_small, K_small)
#     for kernel_name, kernel_data in kernels.items():
#         if kernel_name!="torch":
#             module, fname = kernel_data["module"], kernel_data["fname"]
#             O = getattr(module, fname)(Q_small, K_small)
#             if not torch.allclose(O, O_torch, atol=1e-4):
#                 SIZE = 3
#                 raise ValueError(f"{kernel_name=} failed:\n\n {O[:SIZE, :SIZE]=}\n\n, {O_torch[:SIZE, :SIZE]=}")
#             print(f"{kernel_name=} agrees with torch softmax")


def get_Q_K(sizes, Q_small, K_small, use_int):
    N, L, M = sizes
    if Q_small is None:
        if use_int:
            Q_small = torch.randint(low=10, high=20, size=(N, L)).contiguous().cuda().to(torch.float32)
        else:
            Q_small = torch.randn(N, L, dtype=torch.float32).contiguous().cuda()
    if K_small is None:
        if use_int:
            K_small = torch.randint(low=10, high=20, size=(L, M)).contiguous().cuda().to(torch.float32)
        else:
            K_small = torch.randn(L, M, dtype=torch.float32).contiguous().cuda()
    return Q_small, K_small
 
def test_allclose(
        kernels, 
        raise_error=True, 
        sizes=(32, 16, 64), 
        size_show=3, 
        Q_small=None, 
        K_small=None,
        verbose=True,
        atol=1e-4,
        use_int=False,
    ):
   
    Q_small, K_small = get_Q_K(sizes, Q_small, K_small, use_int)
        
    # run torch
    O_torch = TorchNaiveSoftmaxAndMatMul().softmax_and_matmul(Q_small, K_small)
    results = {"torch": O_torch}

    for kernel_name, kernel_data in kernels.items():
        if kernel_name != "torch":
            module, fname = kernel_data["module"], kernel_data["fname"]
            O = getattr(module, fname)(Q_small, K_small)
            results[fname] = O
            if not torch.allclose(O, O_torch, atol=atol):
                if verbose:
                    print(f"{kernel_name=} failed:\n\n {O[:size_show, :size_show]=}\n\n {O_torch[:size_show, :size_show]=}")
                if raise_error:
                    raise ValueError(f"{kernel_name=} failed:\n\n {O[:size_show, :size_show]=}\n\n {O_torch[:size_show, :size_show]=}")
            if verbose:
                print(f"{kernel_name=} agrees with torch softmax")
    return results

def test_is_normalised(
        kernels, 
        raise_error=True, 
        sizes=(32, 16, 64), 
        size_show=3, 
        Q_small=None, 
        K_small=None,
        verbose=True,
        atol=1e-4,
        use_int=False,
    ):
    Q_small, K_small = get_Q_K(sizes, Q_small, K_small, use_int)
    
    for kernel_name, kernel_data in kernels.items():
        if kernel_name != "torch":
            module, fname = kernel_data["module"], kernel_data["fname"]
            O = getattr(module, fname)(Q_small, K_small)
            results[fname] = O
            if not torch.allclose(O.sum(dim=1), torch.ones_like(O.sum(dim=1)), atol=atol):
                if verbose:
                    print(f"{kernel_name=} rows do not sum up to 1:\n\n {O[:size_show, :size_show]=}")
                if raise_error:
                    raise ValueError(f"{kernel_name=}  rows do not sum up to 1:\n\n {O[:size_show, :size_show]=}\n\n {O_torch[:size_show, :size_show]=}")
            if verbose:
                print(f"{kernel_name=} rows sum up to 1")
    return results

def profile_kernels(kernels, test_sizes=TEST_SIZES):
    for sizes in TEST_SIZES:
        test_allclose(kernels, sizes=sizes, raise_error=True)
    for kernel_name, kernel_data in kernels.items():
        print(f"Profiling: {kernel_name}")
        profile_kernel(kernel_data["module"], kernel_data["fname"], *kernel_data["args"], **kernel_data["kwargs"])


## Python cuda looking implementation

In [3]:

class TorchNaiveSoftmaxAndMatMul(torch.nn.Module):

    def __init__(self):
        super().__init__()

    def softmax_and_matmul(self, Q, K):
        O = Q@K
        return torch.softmax(O, dim=1)

# N, L, M = 32, 16, 64
# Q_small = torch.randn(N, L, dtype=torch.float32).contiguous().cuda()
# K_small = torch.randn(L, M, dtype=torch.float32).contiguous().cuda()
# model = TorchNaiveSoftmaxAndMatMul()
# output = model.softmax_and_matmul(Q_small, K_small)

# print("Input Q shape:", Q_small.shape)
# print("Input K shape:", K_small.shape)
# print("Output shape :", output.shape)
# print("Output:", output)

## Cuda 

In [4]:
def get_modules(kernels):
    for kernel_name, kernel_data in kernels.items():

        fname = kernel_data["fname"]
        cuda_source = Path(kernel_data["cuda_source_path"]).read_text()
        cpp_source = get_sig(fname, cuda_source)
        module = load_cuda(cuda_source, cpp_source, funcs=[fname])
        kernel_data["module"] = module


def get_softmax_modules(kernels):
    get_modules(kernels)
    kernels["torch"] = {
        "module": TorchNaiveSoftmaxAndMatMul(),
        "fname": "softmax_and_matmul",
        "kwargs": {},
    }

def add_args_kwargs(kernels, *args, **kwargs):
    for kernel_name, kernel_data in kernels.items():
        kernel_data["args"] = args
        kernel_data["kwargs"] = kwargs
        


In [5]:
N, L, M = 1024, 256, 768 +1
Q = torch.randn(N, L, dtype=torch.float32).contiguous().cuda()
K = torch.randn(L, M, dtype=torch.float32).contiguous().cuda()

kernels = {
    "fused_softmax_matmul": dict(cuda_source_path = "./fused_softmax_matmul.cu", fname = "fused_softmax_matmul"),
}
get_softmax_modules(kernels)
add_args_kwargs(kernels, Q, K)


If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


## Simple example

In [None]:
module = kernels["fused_softmax_matmul"]["module"]
fname = kernels["fused_softmax_matmul"]["fname"]
Q = torch.tensor([[1, 2],
                  [3, 4]], dtype=torch.float32).contiguous().cuda()

K = torch.tensor([[1, 2],
                  [3, 4]], dtype=torch.float32).contiguous().cuda()

add_number = 1
Q = Q + add_number * torch.sign(Q)

K = K + add_number * torch.sign(K)
print("Q:\n", Q)
print("K:\n", K)

O = getattr(module, fname)(Q, K)
O_torch = TorchNaiveSoftmaxAndMatMul().softmax_and_matmul(Q, K)

print("O")
print(O)
print("O_torch")
print(O_torch)
print("O sum")
print(O.sum(dim=1))
print("O_torch sum")
print(O_torch.sum(dim=1))

Q:
 tensor([[2., 3.],
        [4., 5.]], device='cuda:0')
K:
 tensor([[2., 3.],
        [4., 5.]], device='cuda:0')
O
tensor([[6.6929e-03, 9.9331e-01],
        [1.2339e-04, 9.9988e-01]], device='cuda:0')
O_torch
tensor([[6.6929e-03, 9.9331e-01],
        [1.2339e-04, 9.9988e-01]], device='cuda:0')
O sum
tensor([1., 1.], device='cuda:0')
O_torch sum
tensor([1., 1.], device='cuda:0')


## Profile

In [None]:
profile_kernels(kernels)

kernel_name='fused_softmax_matmul' agrees with torch softmax
kernel_name='fused_softmax_matmul' failed:

 O[:size_show, :size_show]=tensor([[7.6237e-08, 5.1181e-07, 2.3512e-06],
        [2.2461e-03, 4.7643e-05, 1.4328e-05],
        [5.5389e-08, 9.9053e-02, 1.5366e-08]], device='cuda:0')

 O_torch[:size_show, :size_show]=tensor([[7.6252e-08, 5.1191e-07, 2.3516e-06],
        [2.2581e-03, 4.7899e-05, 1.4405e-05],
        [5.5434e-08, 9.9132e-02, 1.5379e-08]], device='cuda:0')


ValueError: kernel_name='fused_softmax_matmul' failed:

 O[:size_show, :size_show]=tensor([[7.6237e-08, 5.1181e-07, 2.3512e-06],
        [2.2461e-03, 4.7643e-05, 1.4328e-05],
        [5.5389e-08, 9.9053e-02, 1.5366e-08]], device='cuda:0')

 O_torch[:size_show, :size_show]=tensor([[7.6252e-08, 5.1191e-07, 2.3516e-06],
        [2.2581e-03, 4.7899e-05, 1.4405e-05],
        [5.5434e-08, 9.9132e-02, 1.5379e-08]], device='cuda:0')

In [None]:


# Ones matrices
sizes = 1024, 256, 56
N, L, M = sizes
results = test_allclose(kernels, raise_error=False, 
# sizes=(32, 16, 64), 
# sizes=(90, 2, 3), 
# sizes=(1024, 256, 768), 
size_show=None,
Q_small = torch.ones(N, L).contiguous().cuda(),
K_small = torch.ones(L, M).contiguous().cuda(),
)
O, O_torch = results["fused_softmax_matmul"], results["torch"] 
        

kernel_name='fused_softmax_matmul' agrees with torch softmax


In [None]:
O_torch

tensor([[0.0179, 0.0179, 0.0179,  ..., 0.0179, 0.0179, 0.0179],
        [0.0179, 0.0179, 0.0179,  ..., 0.0179, 0.0179, 0.0179],
        [0.0179, 0.0179, 0.0179,  ..., 0.0179, 0.0179, 0.0179],
        ...,
        [0.0179, 0.0179, 0.0179,  ..., 0.0179, 0.0179, 0.0179],
        [0.0179, 0.0179, 0.0179,  ..., 0.0179, 0.0179, 0.0179],
        [0.0179, 0.0179, 0.0179,  ..., 0.0179, 0.0179, 0.0179]],
       device='cuda:0')

In [None]:
O_torch.sum(dim=1)[:5]

tensor([1., 1., 1., 1., 1.], device='cuda:0')

In [None]:
O.sum(dim=1)[:5]

tensor([1., 1., 1., 1., 1.], device='cuda:0')

In [None]:
torch.allclose(O, O_torch, atol=1e-4)

True

In [None]:
O.sum(dim=1)

tensor([1., 1., 1.,  ..., 1., 1., 1.], device='cuda:0')

In [None]:
O[:3, :3]

tensor([[0.0179, 0.0179, 0.0179],
        [0.0179, 0.0179, 0.0179],
        [0.0179, 0.0179, 0.0179]], device='cuda:0')

In [None]:
O_torch[:3, :3]

tensor([[0.0179, 0.0179, 0.0179],
        [0.0179, 0.0179, 0.0179],
        [0.0179, 0.0179, 0.0179]], device='cuda:0')

In [None]:
test_allclose(kernels, 
sizes=(32, 16, 64),
)

kernel_name='fused_softmax_matmul' agrees with torch softmax


{'torch': tensor([[8.7599e-08, 1.1568e-04, 8.5495e-05,  ..., 3.2460e-10, 3.2980e-08,
          4.0416e-11],
         [1.2988e-03, 8.0773e-06, 1.1501e-01,  ..., 7.0932e-04, 1.0986e-04,
          9.1397e-04],
         [6.8656e-04, 1.0944e-05, 1.6643e-04,  ..., 5.3404e-05, 6.4530e-06,
          1.3815e-06],
         ...,
         [5.9146e-07, 3.2172e-05, 1.4822e-05,  ..., 2.2011e-03, 1.3360e-07,
          3.7801e-05],
         [4.4553e-02, 1.6476e-04, 3.6103e-05,  ..., 5.0072e-01, 1.7543e-03,
          4.9246e-06],
         [1.3393e-06, 1.4504e-01, 8.2162e-04,  ..., 7.9244e-07, 1.7903e-05,
          3.0058e-03]], device='cuda:0'),
 'fused_softmax_matmul': tensor([[8.7599e-08, 1.1568e-04, 8.5495e-05,  ..., 3.2460e-10, 3.2980e-08,
          4.0416e-11],
         [1.2988e-03, 8.0773e-06, 1.1501e-01,  ..., 7.0932e-04, 1.0986e-04,
          9.1397e-04],
         [6.8656e-04, 1.0944e-05, 1.6643e-04,  ..., 5.3404e-05, 6.4530e-06,
          1.3815e-06],
         ...,
         [5.9146e-07, 3.2172e

### Test 1s matrices
See that with matrices of one thoingd are fine for high tolerance

In [None]:
# Ones matrices
results = {}
for sizes in TEST_SIZES:
    try:
        N, L, M = sizes
        test_allclose(kernels, raise_error=True, 
        size_show=5,
        Q_small = torch.ones(N, L).contiguous().cuda(),
        K_small = torch.ones(L, M).contiguous().cuda(),
        verbose=False,
        atol=1e-3,

        )
        results[sizes] = True
    except Exception as e:
        print(str(e))
        results[sizes] = False
from pprint import  pprint
pprint(results)

kernel_name='fused_softmax_matmul' failed:

 O[:size_show, :size_show]=tensor([[0.1444, 0.1444, 0.1444],
        [0.1444, 0.1444, 0.1444],
        [0.1444, 0.1444, 0.1444],
        [0.1444, 0.1444, 0.1444],
        [0.1444, 0.1444, 0.1444]], device='cuda:0')

 O_torch[:size_show, :size_show]=tensor([[0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333]], device='cuda:0')
kernel_name='fused_softmax_matmul' failed:

 O[:size_show, :size_show]=tensor([[0.1444, 0.1444, 0.1444],
        [0.1444, 0.1444, 0.1444],
        [0.1444, 0.1444, 0.1444],
        [0.1444, 0.1444, 0.1444],
        [0.1444, 0.1444, 0.1444]], device='cuda:0')

 O_torch[:size_show, :size_show]=tensor([[0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333]], device='cuda:0')
{(3, 16, 65): True,
 (32, 2, 3): False

### Test random matrices

In [None]:
#  Random matrices
results = {}
for sizes in TEST_SIZES:
    try:
        N, L, M = sizes
        test_allclose(
            kernels, 
            raise_error=True, 
            size_show=3,
            sizes=sizes,
            verbose=False,
            use_int=True,
            atol=1e-2,
        )
        results[sizes] = True
    except Exception as e:
        print(str(e))
        results[sizes] = False
from pprint import  pprint
pprint(results)

kernel_name='fused_softmax_matmul' failed:

 O[:size_show, :size_show]=tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], device='cuda:0')

 O_torch[:size_show, :size_show]=tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], device='cuda:0')
kernel_name='fused_softmax_matmul' failed:

 O[:size_show, :size_show]=tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], device='cuda:0')

 O_torch[:size_show, :size_show]=tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], device='cuda:0')
{(3, 16, 65): True,
 (32, 2, 3): True,
 (32, 2, 31): True,
 (32, 2, 32): True,
 (32, 2, 56): True,
 (32, 2, 64): True,
 (32, 2, 65): True,
 (32, 16, 64): True,
 (32, 16, 65): True,
 (90, 2, 3): True,
 (90, 2, 31): True,
 (90, 2, 32): True,
 (90, 2, 56): True,
 (90, 2, 64): True,
 (90, 2, 65): True,
 (90, 32, 64): True,
 (90, 32, 65): True,
 (1024, 256, 56): False,
 (1024, 256, 768): False}


In [None]:
#  Random matrices
results = {}
for sizes in TEST_SIZES:
    try:
        N, L, M = sizes
        test_is_normalised(
            kernels, 
            raise_error=True, 
            size_show=3,
            sizes=sizes,
            verbose=False,
            use_int=False,
        )
        results[sizes] = True
    except Exception as e:
        print(str(e))
        results[sizes] = False
from pprint import  pprint
pprint(results)

kernel_name='fused_softmax_matmul'  rows do not sum up to 1:

 O[:size_show, :size_show]=tensor([[2.8097e-06, 1.7695e-06, 2.0064e-07],
        [2.4593e-06, 8.6276e-06, 6.6717e-05],
        [5.7646e-07, 1.5800e-06, 7.6471e-06]], device='cuda:0')

 O_torch[:size_show, :size_show]=tensor([[0.0179, 0.0179, 0.0179],
        [0.0179, 0.0179, 0.0179],
        [0.0179, 0.0179, 0.0179]], device='cuda:0')
kernel_name='fused_softmax_matmul'  rows do not sum up to 1:

 O[:size_show, :size_show]=tensor([[0.0038, 0.0762, 0.0170],
        [0.0011, 0.0651, 0.0190],
        [0.0404, 0.0292, 0.0323]], device='cuda:0')

 O_torch[:size_show, :size_show]=tensor([[0.0179, 0.0179, 0.0179],
        [0.0179, 0.0179, 0.0179],
        [0.0179, 0.0179, 0.0179]], device='cuda:0')
kernel_name='fused_softmax_matmul'  rows do not sum up to 1:

 O[:size_show, :size_show]=tensor([[0.0005, 0.0179, 0.2536],
        [0.0324, 0.0333, 0.1303],
        [0.0094, 0.0424, 0.1927]], device='cuda:0')

 O_torch[:size_show, :size_s