In [1]:
import cutlass
import torch

# Element type used for the CUTLASS plan below; the operand tensors generated
# later in the notebook are cast to this same dtype.
dtype = torch.float32
# Grouped-GEMM plan built with a single element type and row-major layout.
# In the notebook as it stands, `plan` is referenced only by the commented-out
# construct/emit lines in the cell that imports `grouped_gemm`.
plan = cutlass.op.GroupedGemm(element=dtype, layout=cutlass.LayoutType.RowMajor)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import os
# CUDA_LAUNCH_BLOCKING=1 is the CUDA setting that makes kernel launches
# synchronous; it is normally a debugging aid.
# NOTE(review): this assignment runs after `torch` and `cutlass` were imported
# (and a plan was created) in the first cell -- confirm the variable is still
# picked up, i.e. that no CUDA context was initialised before this point.
# NOTE(review): synchronous launches presumably skew the grouped vs.
# non-grouped timing comparison at the end of the notebook -- confirm this is
# intended for the reported numbers.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [3]:
# The two commented-out lines below construct the operation from `plan` and
# emit it as a JIT-compiled PyTorch extension named 'grouped_gemm' into the
# 'out_32' source directory.
# NOTE(review): presumably they were run once and then disabled so that the
# already-built module can simply be imported -- confirm that `grouped_gemm`
# is importable in a fresh environment.
#op = plan.construct()
#grouped_gemm = cutlass.emit.pytorch(op, name='grouped_gemm', cc=plan.cc, sourcedir='out_32', jit=True)
import grouped_gemm

In [4]:
import random
import torch
# Read the .txt problem-size file and build random A, B, C and D tensors for each listed GEMM
def read_file_and_generate_matrices(file_path, device='cuda', element_dtype=None):
    """Build random GEMM operands for every problem size listed in a text file.

    Each non-blank line of the file must contain exactly two
    whitespace-separated fields: an index (ignored) and a problem size written
    as ``MxKxN`` (for example ``0 128x64x256``).  Blank lines are skipped.
    Only the first three ``x``-separated values of the size field are used.

    For every problem, four tensors are created with ``torch.randint(-3, 3, ...)``
    and cast to the element dtype:

    * A with shape ``(M, K)``
    * B with shape ``(K, N)``
    * C with shape ``(M, N)``
    * D with shape ``(M, N)``

    Args:
        file_path: Path of the text file to read.
        device: Device on which the tensors are created.  Defaults to
            ``'cuda'``, which is what the function always used before this
            parameter existed.
        element_dtype: dtype the tensors are cast to.  When ``None`` the
            notebook-level ``dtype`` variable is used.

    Returns:
        A tuple ``(matrices_a, matrices_b, matrices_c, matrices_d)`` of four
        lists of tensors, one entry per problem, in file order.

    Raises:
        ValueError: If a non-blank line does not consist of exactly two
            fields, or if a dimension is not an integer.
        IndexError: If the size field has fewer than three dimensions.
    """
    if element_dtype is None:
        # Fall back to the notebook-wide element type defined in the first cell.
        element_dtype = dtype

    def random_operand(shape):
        # Integer-valued entries, generated on `device`, then cast to the element dtype.
        return torch.randint(-3, 3, shape, device=device).to(element_dtype)

    matrices_a = []  # A operands, shape (M, K)
    matrices_b = []  # B operands, shape (K, N)
    matrices_c = []  # C operands, shape (M, N)
    matrices_d = []  # D operands, shape (M, N)

    with open(file_path, 'r') as file:
        for line in file:
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing empty line) instead of
                # failing on the two-way unpack below.
                continue

            # First field is the problem index (unused); second is "MxKxN".
            _index, dimensions = fields
            dimensions = dimensions.split('x')
            m, k, n = int(dimensions[0]), int(dimensions[1]), int(dimensions[2])

            # Creation order (A, B, C, D) is kept so that a fixed RNG seed
            # produces the same tensors as before.
            matrices_a.append(random_operand((m, k)))
            matrices_b.append(random_operand((k, n)))
            matrices_c.append(random_operand((m, n)))
            matrices_d.append(random_operand((m, n)))

    return matrices_a, matrices_b, matrices_c, matrices_d

# Path of the .txt file that lists the problem sizes
file_path = 'test.txt'  # replace with the actual file path

# One list of tensors per operand, with one entry per line of the file.
# Note: `Ds` is rebound to the output of grouped_gemm.run in the next cell.
As, Bs, Cs, Ds = read_file_and_generate_matrices(file_path)

In [5]:
# Run the grouped GEMM extension once and compare every output with a
# per-problem PyTorch matmul reference.
Ds = grouped_gemm.run(As, Bs)
print(Ds[0])
Ds_torch = [a @ b for a, b in zip(As, Bs)]

# zip() stops at the shorter sequence, so check the counts first; otherwise
# outputs missing from the extension's result would go unnoticed.
assert len(Ds) == len(Ds_torch), (
    'grouped_gemm.run returned {} outputs for {} problems'.format(len(Ds), len(Ds_torch)))
for i, (d, d_torch) in enumerate(zip(Ds, Ds_torch)):
    assert torch.allclose(d, d_torch), (
        'grouped GEMM output {} differs from the torch reference'.format(i))

tensor([[8.]], device='cuda:0')


Finally, we can profile our grouped GEMM extension:

In [6]:
import time

num_warmup = 20
num_profile = 10000

# Warmup iterations: run both code paths untimed before measuring.
for _ in range(num_warmup):
    Ds = grouped_gemm.run(As, Bs)
    Ds_torch = [a @ b for a, b in zip(As, Bs)]
    torch.cuda.synchronize()

# Make sure no earlier GPU work is still in flight when the first timer
# starts (matters when num_warmup is 0).
torch.cuda.synchronize()

# Timing iterations.  The two variants are interleaved within each iteration,
# and every timed region ends with a synchronize so that the GPU work is
# included in the measured interval.  time.perf_counter() is used because it
# is a monotonic, high-resolution clock meant for measuring short durations.
grouped = 0.0
nongrouped = 0.0
for _ in range(num_profile):
    start = time.perf_counter()
    Ds = grouped_gemm.run(As, Bs)
    torch.cuda.synchronize()
    grouped += time.perf_counter() - start

    start = time.perf_counter()
    Ds_torch = [a @ b for a, b in zip(As, Bs)]
    torch.cuda.synchronize()
    nongrouped += time.perf_counter() - start

# Average time per iteration in microseconds, and the resulting speedup.
print('Grouped:     {:.3f} us'.format(grouped * 1e6/num_profile))
print('Non-Grouped: {:.3f} us'.format(nongrouped * 1e6/num_profile))
print('Speedup: {:.3f}'.format(nongrouped / grouped))

Grouped:     121.922 us
Non-Grouped: 216.060 us
Speedup: 1.772
