In [1]:
import cutlass
import torch
dtype = torch.float16


In [2]:
plan = cutlass.op.GroupedGemm(element=dtype, layout=cutlass.LayoutType.RowMajor)
op = plan.construct()
grouped_gemm = cutlass.emit.pytorch(op, name='grouped_gemm', cc=plan.cc, sourcedir='out', jit=False)

In [None]:
!cd out && TORCH_CUDA_ARCH_LIST="8.0" python setup.py install --user

In [3]:
import sys
sys.path.append("~/.local/lib/python3.11/site-packages")

In [13]:
# Function to generate a batch of input matrices for the first GEMM
def generate_batch(batch_size, sequence_length, hidden_size):
    M = batch_size * sequence_length
    As = []
    for _ in range(batch_size):
        A = torch.randn(M, hidden_size, device='cuda').to(dtype)
        As.append(A)
    return As

# Parameters
batch_size = 1 # 32
sequence_length = 128
hidden_size = 4096
ffn_dim = 14336

# Generate the batch of input matrices
As = generate_batch(batch_size, sequence_length, hidden_size)

# Create the weight matrix using nn.Linear
w1 = torch.nn.Linear(hidden_size, ffn_dim, bias=False).to('cuda').to(dtype)

# Create a list of weight matrices, one for each input matrix
Bs = [w1.weight] * batch_size

# Check the shapes of the generated matrices
print(f"Shape of each input matrix: {As[0].shape}")
print(f"Shape of each weight matrix: {Bs[0].shape}")

Shape of each input matrix: torch.Size([128, 4096])
Shape of each weight matrix: torch.Size([14336, 4096])


In [14]:
import grouped_gemm
grouped_output = grouped_gemm.run(As, Bs)

In [15]:
# Create an nn.Linear layer with the same weight as w1
linear = torch.nn.Linear(hidden_size, ffn_dim, bias=False).to('cuda').to(dtype)
linear.weight = torch.nn.Parameter(w1.weight)

# Apply the nn.Linear layer to each input matrix
linear_output = [linear(A) for A in As]

In [16]:
# Compare the outputs
for go, lo in zip(grouped_output, linear_output):
    print(f"Grouped GEMM output shape: {go.shape}, nn.Linear output shape: {lo.shape}")
    print(f"Max difference: {torch.max(torch.abs(go - lo))}")

Grouped GEMM output shape: torch.Size([4096, 4096]), nn.Linear output shape: torch.Size([128, 14336])


RuntimeError: The size of tensor a (4096) must match the size of tensor b (14336) at non-singleton dimension 1

In [None]:
Ds = grouped_gemm.run(As, Bs)
Ds_torch = [a @ b for a, b in zip(As, Bs)]
for d, d_torch in zip(Ds, Ds_torch):
    assert torch.allclose(d, d_torch)

In [None]:
num_warmup = 20
num_profile = 100

# Warmup iterations
for _ in range(num_warmup):
    Ds = grouped_gemm.run(As, Bs)
    Ds_torch = [a @ b for a, b in zip(As, Bs)]
    torch.cuda.synchronize()

# Timing iterations
import time
grouped = 0
nongrouped = 0
for _ in range(num_profile):
    start = time.time()
    Ds = grouped_gemm.run(As, Bs)
    torch.cuda.synchronize()
    grouped += time.time() - start

    start = time.time()
    Ds_torch = [a @ b for a, b in zip(As, Bs)]
    torch.cuda.synchronize()
    nongrouped += time.time() - start

print('Grouped:     {:.3f} us'.format(grouped * 1e6/num_profile))
print('Non-Grouped: {:.3f} us'.format(nongrouped * 1e6/num_profile))
print('Speedup: {:.3f}'.format(nongrouped / grouped))