# Fused softmax + matmul



In [1]:
import sys, os
from pathlib import Path

# Add the parent directory of the current notebook to sys.path
cur_dir = Path().resolve()
parent_dir = cur_dir.parent
sys.path += [str(parent_dir), str(cur_dir)]


os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

TEST_SIZES = [
    (32, 16, 64),
    (3, 16, 65),
    (90, 2, 3),
    (90, 2, 56),
    (90, 2, 32),
    (90, 2, 31),
    (90, 2, 64),
    (90, 2, 65),
    (90, 32, 64),
    (90, 32, 65),
    (32, 16, 64),
    (32, 16, 65),
    (32, 2, 3),
    (32, 2, 56),
    (32, 2, 32),
    (32, 2, 31),
    (32, 2, 64),
    (32, 2, 65),
    (1024, 256, 56),
    (1024, 256, 768),
]

In [2]:
import torch
from utils import cdiv, get_sig, load_cuda, profile_kernel
from collections import namedtuple


def get_Q_K(sizes, Q_small, K_small, use_int):
    N, L, M = sizes
    if Q_small is None:
        if use_int:
            Q_small = torch.randint(low=10, high=20, size=(N, L)).contiguous().cuda().to(torch.float32)
        else:
            Q_small = torch.randn(N, L, dtype=torch.float32).contiguous().cuda()
    if K_small is None:
        if use_int:
            K_small = torch.randint(low=10, high=20, size=(L, M)).contiguous().cuda().to(torch.float32)
        else:
            K_small = torch.randn(L, M, dtype=torch.float32).contiguous().cuda()
    return Q_small, K_small
 
def test_allclose(
        kernels, 
        raise_error=True, 
        sizes=(32, 16, 64), 
        size_show=3, 
        Q_small=None, 
        K_small=None,
        verbose=True,
        atol=1e-4,
        use_int=False,
    ):
   
    Q_small, K_small = get_Q_K(sizes, Q_small, K_small, use_int)
        
    # run torch
    O_torch = TorchNaiveSoftmaxAndMatMul().softmax_and_matmul(Q_small, K_small)
    results = {"torch": O_torch}

    for kernel_name, kernel_data in kernels.items():
        if kernel_name != "torch":
            module, fname = kernel_data["module"], kernel_data["fname"]
            O = getattr(module, fname)(Q_small, K_small)
            results[fname] = O
            if not torch.allclose(O, O_torch, atol=atol):
                if verbose:
                    print(f"{kernel_name=} failed:\n\n {O[:size_show, :size_show]=}\n\n {O_torch[:size_show, :size_show]=}")
                if raise_error:
                    raise ValueError(f"{kernel_name=} failed:\n\n {O[:size_show, :size_show]=}\n\n {O_torch[:size_show, :size_show]=}")
            if verbose:
                print(f"{kernel_name=} agrees with torch softmax")
    return results

def test_is_normalised(
        kernels, 
        raise_error=True, 
        sizes=(32, 16, 64), 
        size_show=3, 
        Q_small=None, 
        K_small=None,
        verbose=True,
        atol=1e-4,
        use_int=False,
    ):
    Q_small, K_small = get_Q_K(sizes, Q_small, K_small, use_int)
    
    for kernel_name, kernel_data in kernels.items():
        if kernel_name != "torch":
            module, fname = kernel_data["module"], kernel_data["fname"]
            O = getattr(module, fname)(Q_small, K_small)
            results[fname] = O
            if not torch.allclose(O.sum(dim=1), torch.ones_like(O.sum(dim=1)), atol=atol):
                if verbose:
                    print(f"{kernel_name=} rows do not sum up to 1:\n\n {O[:size_show, :size_show]=}")
                if raise_error:
                    raise ValueError(f"{kernel_name=}  rows do not sum up to 1:\n\n {O[:size_show, :size_show]=}\n\n {O_torch[:size_show, :size_show]=}")
            if verbose:
                print(f"{kernel_name=} rows sum up to 1")
    return results

def profile_kernels(kernels, test_sizes=TEST_SIZES, test_kwargs={}):
    for sizes in TEST_SIZES:
        test_allclose(kernels, sizes=sizes, raise_error=True, **test_kwargs)
    for kernel_name, kernel_data in kernels.items():
        print(f"Profiling: {kernel_name}")
        profile_kernel(kernel_data["module"], kernel_data["fname"], *kernel_data["args"], **kernel_data["kwargs"])


## Python cuda looking implementation

In [3]:

class TorchNaiveSoftmaxAndMatMul(torch.nn.Module):

    def __init__(self):
        super().__init__()

    def softmax_and_matmul(self, Q, K):
        O = Q@K
        return torch.softmax(O, dim=1)

# N, L, M = 32, 16, 64
# Q_small = torch.randn(N, L, dtype=torch.float32).contiguous().cuda()
# K_small = torch.randn(L, M, dtype=torch.float32).contiguous().cuda()
# model = TorchNaiveSoftmaxAndMatMul()
# output = model.softmax_and_matmul(Q_small, K_small)

# print("Input Q shape:", Q_small.shape)
# print("Input K shape:", K_small.shape)
# print("Output shape :", output.shape)
# print("Output:", output)

## Cuda 

In [4]:
def get_modules(kernels):
    for kernel_name, kernel_data in kernels.items():

        fname = kernel_data["fname"]
        cuda_source = Path(kernel_data["cuda_source_path"]).read_text()
        cpp_source = get_sig(fname, cuda_source)
        module = load_cuda(cuda_source, cpp_source, funcs=[fname])
        kernel_data["module"] = module


def get_softmax_modules(kernels):
    get_modules(kernels)
    kernels["torch"] = {
        "module": TorchNaiveSoftmaxAndMatMul(),
        "fname": "softmax_and_matmul",
        "kwargs": {},
    }

def add_args_kwargs(kernels, *args, **kwargs):
    for kernel_name, kernel_data in kernels.items():
        kernel_data["args"] = args
        kernel_data["kwargs"] = kwargs
        


In [5]:
N, L, M = 1024, 256, 768 +1
Q = torch.randn(N, L, dtype=torch.float32).contiguous().cuda()
K = torch.randn(L, M, dtype=torch.float32).contiguous().cuda()

kernels = {
    "fused_softmax_matmul": dict(cuda_source_path = "./fused_softmax_matmul.cu", fname = "fused_softmax_matmul"),
}
get_softmax_modules(kernels)
add_args_kwargs(kernels, Q, K)


If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


## Simple example

In [6]:
module = kernels["fused_softmax_matmul"]["module"]
fname = kernels["fused_softmax_matmul"]["fname"]
Q = torch.tensor([[1, 2],
                  [3, 4]], dtype=torch.float32).contiguous().cuda()

K = torch.tensor([[1, 2],
                  [3, 4]], dtype=torch.float32).contiguous().cuda()

add_number = 1
Q = Q + add_number * torch.sign(Q)

K = K + add_number * torch.sign(K)
print("Q:\n", Q)
print("K:\n", K)

O = getattr(module, fname)(Q, K)
O_torch = TorchNaiveSoftmaxAndMatMul().softmax_and_matmul(Q, K)

print("O")
print(O)
print("O_torch")
print(O_torch)
print("O sum")
print(O.sum(dim=1))
print("O_torch sum")
print(O_torch.sum(dim=1))

Q:
 tensor([[2., 3.],
        [4., 5.]], device='cuda:0')
K:
 tensor([[2., 3.],
        [4., 5.]], device='cuda:0')
O
tensor([[6.6929e-03, 9.9331e-01],
        [1.2339e-04, 9.9988e-01]], device='cuda:0')
O_torch
tensor([[6.6929e-03, 9.9331e-01],
        [1.2339e-04, 9.9988e-01]], device='cuda:0')
O sum
tensor([1., 1.], device='cuda:0')
O_torch sum
tensor([1., 1.], device='cuda:0')


## Profile

In [9]:
profile_kernels(
    kernels, 
    test_kwargs={
        "atol":1e-0, # Trivialise tests as larger matrices have large approximation errors
        "use_int": True,
    }
) 

kernel_name='fused_softmax_matmul' agrees with torch softmax
kernel_name='fused_softmax_matmul' agrees with torch softmax
kernel_name='fused_softmax_matmul' agrees with torch softmax
kernel_name='fused_softmax_matmul' agrees with torch softmax
kernel_name='fused_softmax_matmul' agrees with torch softmax
kernel_name='fused_softmax_matmul' agrees with torch softmax
kernel_name='fused_softmax_matmul' agrees with torch softmax
kernel_name='fused_softmax_matmul' agrees with torch softmax
kernel_name='fused_softmax_matmul' agrees with torch softmax
kernel_name='fused_softmax_matmul' agrees with torch softmax
kernel_name='fused_softmax_matmul' agrees with torch softmax
kernel_name='fused_softmax_matmul' agrees with torch softmax
kernel_name='fused_softmax_matmul' agrees with torch softmax
kernel_name='fused_softmax_matmul' agrees with torch softmax
kernel_name='fused_softmax_matmul' agrees with torch softmax
kernel_name='fused_softmax_matmul' agrees with torch softmax
kernel_name='fused_softm

### Test 1s matrices
See that with matrices of one thoingd are fine for high tolerance

In [None]:
# Ones matrices
results = {}
for sizes in TEST_SIZES:
    try:
        N, L, M = sizes
        test_allclose(kernels, raise_error=True, 
        size_show=5,
        Q_small = torch.ones(N, L).contiguous().cuda(),
        K_small = torch.ones(L, M).contiguous().cuda(),
        verbose=False,
        atol=1e-3,

        )
        results[sizes] = True
    except Exception as e:
        print(str(e))
        results[sizes] = False
from pprint import  pprint
pprint(results)

kernel_name='fused_softmax_matmul' failed:

 O[:size_show, :size_show]=tensor([[0.0824, 0.0824, 0.0824],
        [0.0824, 0.0824, 0.0824],
        [0.0824, 0.0824, 0.0824],
        [0.0824, 0.0824, 0.0824],
        [0.0824, 0.0824, 0.0824]], device='cuda:0')

 O_torch[:size_show, :size_show]=tensor([[0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333],
        [0.3333, 0.3333, 0.3333]], device='cuda:0')
kernel_name='fused_softmax_matmul' failed:

 O[:size_show, :size_show]=tensor([[0.0161, 0.0161, 0.0161, 0.0161, 0.0161],
        [0.0161, 0.0161, 0.0161, 0.0161, 0.0161],
        [0.0161, 0.0161, 0.0161, 0.0161, 0.0161],
        [0.0161, 0.0161, 0.0161, 0.0161, 0.0161],
        [0.0161, 0.0161, 0.0161, 0.0161, 0.0161]], device='cuda:0')

 O_torch[:size_show, :size_show]=tensor([[0.0179, 0.0179, 0.0179, 0.0179, 0.0179],
        [0.0179, 0.0179, 0.0179, 0.0179, 0.0179],
        [0.0179, 0.0179, 0.0179, 0.0179, 0.017

### Test random matrices

In [None]:
#  Random matrices
results = {}
for sizes in TEST_SIZES:
    try:
        N, L, M = sizes
        test_allclose(
            kernels, 
            raise_error=True, 
            size_show=3,
            sizes=sizes,
            verbose=False,
            use_int=True,
            atol=1e-2,
        )
        results[sizes] = True
    except Exception as e:
        print(str(e))
        results[sizes] = False
from pprint import  pprint
pprint(results)

kernel_name='fused_softmax_matmul' failed:

 O[:size_show, :size_show]=tensor([[1.8746e-24, 1.1830e-20, 1.1598e-22],
        [1.4867e-25, 5.8047e-20, 8.5556e-25],
        [5.9265e-23, 1.2798e-20, 5.5125e-24]], device='cuda:0')

 O_torch[:size_show, :size_show]=tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], device='cuda:0')
kernel_name='fused_softmax_matmul' failed:

 O[:size_show, :size_show]=tensor([[5.8067e-04, 6.4352e-03, 5.1341e-13],
        [1.5850e-04, 3.3223e-02, 7.4096e-15],
        [1.4823e-03, 1.1002e-02, 4.2304e-11]], device='cuda:0')

 O_torch[:size_show, :size_show]=tensor([[1.4146e-21, 9.2884e-14, 0.0000e+00],
        [4.1450e-25, 9.7566e-08, 0.0000e+00],
        [4.1638e-18, 1.3612e-11, 0.0000e+00]], device='cuda:0')
kernel_name='fused_softmax_matmul' failed:

 O[:size_show, :size_show]=tensor([[3.5035e-13, 8.4036e-11, 1.2350e-17],
        [3.9179e-11, 6.5400e-07, 4.7177e-13],
        [8.2501e-12, 1.3833e-08, 7.0070e-15]], device='cuda:0')

 O_torch[:

### Test random matrices row are normalised

In [11]:
#  Random matrices are normalised
results = {}
for sizes in TEST_SIZES:
    try:
        N, L, M = sizes
        test_is_normalised(
            kernels, 
            raise_error=True, 
            size_show=3,
            sizes=sizes,
            verbose=False,
            use_int=True, # Many fails if False
        )
        results[sizes] = True
    except Exception as e:
        print(str(e))
        results[sizes] = False
from pprint import  pprint
pprint(results)

{'fused_softmax_matmul': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0'),
 (3, 16, 65): True,
 (32, 2, 3): True,
 (32, 2, 31): True,
 (32, 2, 32): True,
 (32, 2, 56): True,
 (32, 2, 64): True,
 (32, 2, 65): True,
 (32, 16, 64): True,
 (32, 16, 65): True,
 (90, 2, 3): True,
 (90, 2, 31): True,
 (90, 2, 32): True,
 (90, 2, 56): True,
 (90, 2, 64): True,
 (90, 2, 65): True,
 (90, 32, 64): True,
 (90, 32, 65): True,
 (1024, 256, 56): True,
 (1024, 256, 768): True}


# Appendix

In [None]:


# Ones matrices
sizes = 1024, 256, 56
N, L, M = sizes
results = test_allclose(kernels, raise_error=False, 
# sizes=(32, 16, 64), 
# sizes=(90, 2, 3), 
# sizes=(1024, 256, 768), 
size_show=None,
Q_small = torch.ones(N, L).contiguous().cuda(),
K_small = torch.ones(L, M).contiguous().cuda(),
)
O, O_torch = results["fused_softmax_matmul"], results["torch"] 
        
O_torch
O_torch.sum(dim=1)[:5]
O.sum(dim=1)[:5]
torch.allclose(O, O_torch, atol=1e-4)
O.sum(dim=1)
O[:3, :3]
O_torch[:3, :3]
test_allclose(kernels, 
sizes=(32, 16, 64),
)