In [1]:
import cutlass
import torch

dtype = torch.float16
plan = cutlass.op.GroupedGemm(element=dtype, layout=cutlass.LayoutType.RowMajor)

In [2]:
import random
random.seed(2023)

# Utility function to initialize A, B, C, and D matrices corresponding to dimensions M, N, and K
def initialize(dtype, M, N, K):
    sizes = [(M, K), (K, N), (M, N), (M, N)]
    return [torch.randint(-3, 3, size, device='cuda').to(dtype) for size in sizes]

# Utility function to generate `problems` GEMMs of random sizes
def generate_problems(problems):
    valid_sizes = [128, 256, 512, 1024]
    As, Bs, Cs, Ds = [], [], [], []
    for _ in range(problems):
        M, N, K = [random.choice(valid_sizes) for _ in range(3)]
        A, B, C, D = initialize(dtype, M, N, K)
        As.append(A)
        Bs.append(B)
        Cs.append(C)
        Ds.append(D)
    return As, Bs, Cs, Ds

As, Bs, Cs, Ds, = generate_problems(20)

In [3]:
op = plan.construct()
grouped_gemm = cutlass.emit.pytorch(op, name='grouped_gemm', cc=plan.cc, sourcedir='out', jit=False)

cd out
TORCH_CUDA_ARCH_LIST="8.0" python setup.py install --user

In [3]:
import sys
sys.path.append("~/.local/lib/python3.11/site-packages")

In [4]:
import torch
import grouped_gemm

grouped_gemm.run(As, Bs)

[tensor([[260., 152., 203.,  ..., 275., 356., 233.],
         [174., 135., 336.,  ..., 196., 521., 215.],
         [299., 239., 220.,  ..., 180., 355., 261.],
         ...,
         [286., 293., 145.,  ..., 198., 371., 264.],
         [120., 214., 238.,  ...,  93., 404.,  69.],
         [354., 190., 242.,  ...,  68., 357., 309.]], device='cuda:0',
        dtype=torch.float16),
 tensor([[ -4., -30., -18.,  ...,   1.,  41., -42.],
         [ 33.,  40.,   7.,  ..., -18.,  27.,  12.],
         [ 46.,  27.,  26.,  ...,  44.,  16.,  60.],
         ...,
         [ 78.,   2.,  68.,  ...,  73., -10.,  64.],
         [ -2., -10.,  62.,  ...,  27.,  32.,  17.],
         [ 33.,  67.,  59.,  ...,  57.,  88.,  25.]], device='cuda:0',
        dtype=torch.float16),
 tensor([[118.,  96., 217.,  ..., 194., 204., 288.],
         [176., 214., 208.,  ..., 139., 211., 164.],
         [ 65., 188., 115.,  ..., 117., 133.,  36.],
         ...,
         [165., 165.,  73.,  ..., 222., 148., 244.],
         [ 67.

In [5]:
Ds = grouped_gemm.run(As, Bs)
Ds_torch = [a @ b for a, b in zip(As, Bs)]
for d, d_torch in zip(Ds, Ds_torch):
    assert torch.allclose(d, d_torch)

In [6]:
num_warmup = 20
num_profile = 100

# Warmup iterations
for _ in range(num_warmup):
    Ds = grouped_gemm.run(As, Bs)
    Ds_torch = [a @ b for a, b in zip(As, Bs)]
    torch.cuda.synchronize()

# Timing iterations
import time
grouped = 0
nongrouped = 0
for _ in range(num_profile):
    start = time.time()
    Ds = grouped_gemm.run(As, Bs)
    torch.cuda.synchronize()
    grouped += time.time() - start

    start = time.time()
    Ds_torch = [a @ b for a, b in zip(As, Bs)]
    torch.cuda.synchronize()
    nongrouped += time.time() - start

print('Grouped:     {:.3f} us'.format(grouped * 1e6/num_profile))
print('Non-Grouped: {:.3f} us'.format(nongrouped * 1e6/num_profile))
print('Speedup: {:.3f}'.format(nongrouped / grouped))

Grouped:     146.914 us
Non-Grouped: 233.657 us
Speedup: 1.590
