# QGEMM Benchmark Notebook

This notebook provides comprehensive benchmarking for QGEMM kernels including:
- Performance comparison against PyTorch (`torch.mm`, which dispatches to the vendor libraries cuBLAS/rocBLAS)
- Testing custom GEMM kernels (including CUTLASS integration)
- Visualization of results and performance analysis
- Correctness verification of custom kernels

## Requirements
- PyTorch with CUDA/ROCm support
- Build QGEMM with `-DBUILD_TESTS=ON`
- For test kernels: `-DBUILD_TEST_KERNELS=ON`

In [None]:
import sys
import os
import time
import numpy as np
import torch
import matplotlib.pyplot as plt
import pandas as pd
from typing import List, Dict, Tuple

# Add build directory to path for loading compiled modules.
# Guarded so that re-running this cell does not append duplicate entries.
_BUILD_LIB_DIR = '../build/lib'
if _BUILD_LIB_DIR not in sys.path:
    sys.path.append(_BUILD_LIB_DIR)

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name()}")
    print(f"CUDA version: {torch.version.cuda}")
    # On ROCm builds torch.version.cuda may be None; report the HIP version when present.
    hip_version = getattr(torch.version, 'hip', None)
    if hip_version:
        print(f"HIP version: {hip_version}")
    print(f"CUDA device count: {torch.cuda.device_count()}")

    # Memory info for the current device (the same device get_device_name() reports on).
    device_index = torch.cuda.current_device()
    memory_allocated = torch.cuda.memory_allocated(device_index) / 1024**3
    memory_reserved = torch.cuda.memory_reserved(device_index) / 1024**3
    memory_total = torch.cuda.get_device_properties(device_index).total_memory / 1024**3
    print(f"GPU Memory - Allocated: {memory_allocated:.2f} GB, Reserved: {memory_reserved:.2f} GB, Total: {memory_total:.2f} GB")

In [None]:
def _public_names(module):
    """Return the names in ``dir(module)`` that do not start with an underscore."""
    return [name for name in dir(module) if not name.startswith('_')]


# Try to import test kernels if available
test_kernels_available = False
try:
    import qgemm_test_kernels_python
    test_kernels_available = True
    print("✅ Test kernels loaded successfully")
    # dir() also lists dunder attributes (__doc__, __loader__, ...); show only the public API.
    print("   Available functions:", _public_names(qgemm_test_kernels_python))
except ImportError as e:
    print(f"⚠️  Test kernels not available: {e}")
    print("   Build with -DBUILD_TEST_KERNELS=ON to enable test kernels")

# Try to import main QGEMM module
qgemm_available = False
try:
    import qgemm_python
    qgemm_available = True
    print("✅ QGEMM Python module loaded successfully")
    print("   Available functions:", _public_names(qgemm_python))
except ImportError as e:
    print(f"⚠️  QGEMM Python module not available: {e}")
    print("   Make sure the project is built and the build/lib directory is in your path")

In [None]:
class GemmBenchmark:
    """GEMM benchmark utility class.

    Times GEMM callables using wall-clock time plus device synchronization and
    collects one result dict per (kernel, M, N, K, dtype) configuration.

    Args:
        device: torch device string for the test matrices ('cuda', 'cuda:1', 'cpu', ...).
        warmup_runs: number of untimed calls made before measuring.
        benchmark_runs: number of timed calls per measurement; must be >= 1.
    """

    def __init__(self, device='cuda', warmup_runs=5, benchmark_runs=10):
        self.device = device
        self.warmup_runs = warmup_runs
        self.benchmark_runs = benchmark_runs
        self.results = []

    def _synchronize(self):
        """Wait for queued GPU work so wall-clock timings include kernel execution.

        Keyed on the device *type*, so indexed devices such as 'cuda:1' are
        synchronized as well. No-op for non-CUDA devices.
        """
        if torch.device(self.device).type == 'cuda':
            torch.cuda.synchronize(self.device)

    def benchmark_function(self, func, *args, **kwargs):
        """Benchmark ``func(*args, **kwargs)`` with warmup and multiple timed runs.

        Returns:
            (mean_ms, std_ms, result): mean and standard deviation of the per-call
            time in milliseconds, and the return value of the last timed call.

        Raises:
            ValueError: if ``benchmark_runs`` < 1, since no timing can be produced.
        """
        if self.benchmark_runs < 1:
            raise ValueError(f"benchmark_runs must be >= 1, got {self.benchmark_runs}")

        # Warmup (untimed): absorbs one-time costs before measuring.
        for _ in range(self.warmup_runs):
            func(*args, **kwargs)
            self._synchronize()

        # Drain work still queued (e.g. input creation when warmup_runs == 0) so it is not timed.
        self._synchronize()

        times = []
        result = None
        for _ in range(self.benchmark_runs):
            start_time = time.perf_counter()
            result = func(*args, **kwargs)
            # Launches are asynchronous; wait for completion before reading the clock.
            self._synchronize()
            times.append((time.perf_counter() - start_time) * 1000)  # Convert to ms

        return np.mean(times), np.std(times), result

    def calculate_tflops(self, M, N, K, time_ms):
        """Return TFLOPS for an (M x K) @ (K x N) GEMM that took ``time_ms`` milliseconds.

        Counts 2*M*N*K floating-point operations (one multiply and one add per
        term). Returns ``inf`` for a non-positive time instead of dividing by zero.
        """
        flops = 2 * M * N * K  # 2 operations per multiply-add
        time_s = time_ms / 1000.0
        if time_s <= 0:
            return float('inf')
        return (flops / time_s) / 1e12

    def _make_result(self, kernel, M, N, K, dtype, avg_time, std_time):
        """Build the result row shared by all test_* methods (dtype stored as e.g. 'float16')."""
        return {
            'kernel': kernel,
            'M': M, 'N': N, 'K': K,
            'dtype': str(dtype).split('.')[-1],
            'time_ms': avg_time,
            'std_ms': std_time,
            'tflops': self.calculate_tflops(M, N, K, avg_time),
        }

    def test_pytorch_gemm(self, M, N, K, dtype=torch.float16):
        """Time ``torch.mm`` on random (M, K) @ (K, N) inputs and return a result row."""
        A = torch.randn(M, K, dtype=dtype, device=self.device)
        B = torch.randn(K, N, dtype=dtype, device=self.device)

        avg_time, std_time, _ = self.benchmark_function(torch.mm, A, B)
        return self._make_result('PyTorch', M, N, K, dtype, avg_time, std_time)

    def test_cutlass_gemm(self, M, N, K, dtype=torch.float16):
        """Time the CUTLASS test kernel; return a result row, or None if unavailable or failing."""
        # globals().get: a missing flag (import cell not run) means "not available".
        if not globals().get('test_kernels_available', False):
            return None

        A = torch.randn(M, K, dtype=dtype, device=self.device)
        B = torch.randn(K, N, dtype=dtype, device=self.device)
        C = torch.zeros(M, N, dtype=dtype, device=self.device)

        try:
            avg_time, std_time, _ = self.benchmark_function(
                qgemm_test_kernels_python.cutlass_gemm, A, B, C
            )
        except Exception as e:
            print(f"CUTLASS test failed: {e}")
            return None
        return self._make_result('CUTLASS', M, N, K, dtype, avg_time, std_time)

    def run_comprehensive_benchmark(self, sizes=None, dtypes=None):
        """Run the benchmark across problem sizes and precisions.

        Args:
            sizes: iterable of (M, N, K); defaults to square 512/1024/2048/4096.
            dtypes: iterable of torch dtypes; defaults to [float16, float32].

        Returns:
            The list of result rows (also stored in ``self.results``).
        """
        if sizes is None:
            sizes = [(512, 512, 512), (1024, 1024, 1024), (2048, 2048, 2048), (4096, 4096, 4096)]
        if dtypes is None:
            dtypes = [torch.float16, torch.float32]

        cutlass_enabled = globals().get('test_kernels_available', False)
        self.results = []

        for M, N, K in sizes:
            for dtype in dtypes:
                print(f"Testing M={M}, N={N}, K={K}, dtype={dtype}...")

                self.results.append(self.test_pytorch_gemm(M, N, K, dtype))

                # The CUTLASS kernel is only exercised for FP16 inputs.
                if cutlass_enabled and dtype == torch.float16:
                    cutlass_result = self.test_cutlass_gemm(M, N, K, dtype)
                    if cutlass_result:
                        self.results.append(cutlass_result)

        return self.results

    def get_results_dataframe(self):
        """Convert results to pandas DataFrame"""
        return pd.DataFrame(self.results)

# Initialize benchmark
benchmark = GemmBenchmark()
print("✅ Benchmark framework initialized")

In [None]:
# Kick off the full sweep. Any failure is reported and leaves `results` empty.
print("🚀 Starting comprehensive benchmark...")
print("This may take several minutes depending on your GPU...")

try:
    results = benchmark.run_comprehensive_benchmark()
except Exception as e:
    print(f"❌ Benchmark failed: {e}")
    results = []
else:
    print(f"\n✅ Completed {len(results)} benchmark tests")

In [None]:
# Display results table: every collected measurement, then per-kernel TFLOPS statistics.
df = benchmark.get_results_dataframe()
if df.empty:
    print("⚠️  No benchmark results available")
else:
    print("\n📊 === Benchmark Results ===")
    display_options = {
        'display.precision': 3,
        'display.max_columns': None,
        'display.width': None,
    }
    for option_name, option_value in display_options.items():
        pd.set_option(option_name, option_value)
    print(df.to_string(index=False))

    # Per-kernel summary, in the order kernels first appear in the table.
    print("\n📈 === Performance Summary ===")
    for kernel, kernel_data in df.groupby('kernel', sort=False):
        tflops_col = kernel_data['tflops']
        print(f"{kernel}: Avg {tflops_col.mean():.2f} TFLOPS, Peak {tflops_col.max():.2f} TFLOPS")

In [None]:
# Visualize results
if not df.empty:
    # Give every (M, N, K, dtype) configuration one x position shared by all
    # kernels. Kernels do not all run the same configurations (CUTLASS is only
    # benchmarked for float16), so plotting each kernel against its own row
    # index would put different configurations at the same x.
    config_labels = [
        f"{m}x{n}x{k}\n{dt}" for m, n, k, dt in zip(df['M'], df['N'], df['K'], df['dtype'])
    ]
    config_order = list(dict.fromkeys(config_labels))  # unique, first-seen order
    config_position = {label: i for i, label in enumerate(config_order)}
    config_x = np.array([config_position[label] for label in config_labels])

    # Performance comparison plot
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
    for kernel in df['kernel'].unique():
        mask = (df['kernel'] == kernel).to_numpy()
        kernel_data = df[mask]
        ax1.plot(config_x[mask], kernel_data['tflops'], 'o-', label=kernel, linewidth=2, markersize=8)
        ax2.plot(config_x[mask], kernel_data['time_ms'], 'o-', label=kernel, linewidth=2, markersize=8)

    for ax, ylabel, title in (
        (ax1, 'TFLOPS', 'Performance Comparison (TFLOPS)'),
        (ax2, 'Execution Time (ms)', 'Execution Time Comparison'),
    ):
        ax.set_xticks(range(len(config_order)))
        ax.set_xticklabels(config_order, fontsize=8)
        ax.set_xlabel('Test Configuration (MxNxK, dtype)')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)
    ax2.set_yscale('log')

    plt.tight_layout()
    plt.show()

    # Performance by matrix size (x axis uses M; the sweep uses square M=N=K shapes)
    plt.figure(figsize=(12, 8))
    for dtype_name, line_style, tag in (('float16', 'o-', 'FP16'), ('float32', 's--', 'FP32')):
        dtype_data = df[df['dtype'] == dtype_name]
        for kernel in dtype_data['kernel'].unique():
            kernel_data = dtype_data[dtype_data['kernel'] == kernel].sort_values('M')
            plt.plot(kernel_data['M'], kernel_data['tflops'], line_style,
                     label=f'{kernel} ({tag})', linewidth=2, markersize=8)

    plt.xlabel('Matrix Size (M=N=K)')
    plt.ylabel('TFLOPS')
    plt.title('Performance vs Matrix Size')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.xscale('log', base=2)
    plt.show()
else:
    print("📊 No data to visualize")

In [None]:
# Test custom kernel correctness
def _compare_to_reference(actual, expected, atol, rtol):
    """Compare ``actual`` with ``expected`` elementwise, in float32.

    An element passes when |actual - expected| <= atol + rtol * |expected|
    (the torch.allclose criterion). NaN differences never pass.

    Returns:
        (max_abs_diff, mean_abs_diff, passed) as Python float, float, bool.
        Empty inputs return (0.0, 0.0, True).
    """
    expected = expected.float()
    diff = torch.abs(actual.float() - expected)
    if diff.numel() == 0:
        return 0.0, 0.0, True
    # Relative term: FP16 rounding error grows with the magnitude of the value.
    allowed = atol + rtol * torch.abs(expected)
    passed = bool(torch.all(diff <= allowed))
    return diff.max().item(), diff.mean().item(), passed


def test_kernel_correctness(M=128, N=128, K=128, tolerance=1e-3, rtol=1e-2):
    """Test the CUTLASS test kernel against a float32 PyTorch reference.

    Per-element pass criterion: |C - C_ref| <= tolerance + rtol * |C_ref|.
    The relative term matters because the kernel writes FP16. With K >= 128
    random-normal inputs, outputs routinely exceed magnitude 8, where one FP16
    ulp is already larger than 1e-3. An absolute-only tolerance would therefore
    reject correctly rounded results.

    Args:
        M, N, K: problem size; A is (M, K), B is (K, N).
        tolerance: absolute tolerance (atol).
        rtol: relative tolerance, scaled by |C_ref|.

    Returns:
        True if the check passed. False if it failed, the kernel raised, or the
        test kernels are unavailable.
    """
    # globals().get: a missing flag (import cell not run) means "not available".
    if not globals().get('test_kernels_available', False):
        print("⚠️  Test kernels not available for correctness testing")
        return False

    print(f"\n🔍 Testing kernel correctness (M={M}, N={N}, K={K})...")

    # Create test matrices
    A = torch.randn(M, K, dtype=torch.float16, device='cuda')
    B = torch.randn(K, N, dtype=torch.float16, device='cuda')
    C = torch.zeros(M, N, dtype=torch.float16, device='cuda')

    # Float32 reference from the same FP16 inputs, so only the kernel's own error is measured.
    C_ref = torch.mm(A.float(), B.float())

    try:
        qgemm_test_kernels_python.cutlass_gemm(A, B, C)
    except Exception as e:
        print(f"   ❌ Kernel test failed with error: {e}")
        return False

    max_diff, mean_diff, passed = _compare_to_reference(C, C_ref, tolerance, rtol)

    print(f"   Max difference: {max_diff:.6f}")
    print(f"   Mean difference: {mean_diff:.6f}")
    print(f"   Tolerance: atol={tolerance}, rtol={rtol}")

    if passed:
        print("   ✅ Correctness test PASSED")
        return True
    print("   ❌ Correctness test FAILED")
    return False

# Run the correctness check over a few square problem sizes.
print("🧪 === Correctness Testing ===")
test_sizes = [(128, 128, 128), (256, 256, 256), (512, 512, 512)]
for M, N, K in test_sizes:
    test_kernel_correctness(M=M, N=N, K=K)

In [None]:
# Memory usage analysis
def analyze_memory_usage(M, N, K, dtype=torch.float16):
    """Estimate the memory held by the operands of an (M x K) @ (K x N) GEMM.

    Only the three matrices A (M x K), B (K x N) and C (M x N) are counted.
    Nothing else, such as library workspace, is included.

    Args:
        M, N, K: GEMM dimensions.
        dtype: torch dtype of all three matrices.

    Returns:
        Dict with per-matrix sizes in MiB ('Matrix A (MB)', 'Matrix B (MB)',
        'Matrix C (MB)') and the total in MiB and GiB (base 1024).
    """
    # Element size is taken from the dtype itself, so bfloat16 (2 B), int8 (1 B)
    # and float64 (8 B) are sized correctly rather than assumed to be 4 bytes.
    dtype_size = torch.empty((), dtype=dtype).element_size()  # bytes

    memory_A = M * K * dtype_size
    memory_B = K * N * dtype_size
    memory_C = M * N * dtype_size

    total_memory = memory_A + memory_B + memory_C

    return {
        'Matrix A (MB)': memory_A / (1024**2),
        'Matrix B (MB)': memory_B / (1024**2),
        'Matrix C (MB)': memory_C / (1024**2),
        'Total (MB)': total_memory / (1024**2),
        'Total (GB)': total_memory / (1024**3)
    }

# Analyze memory for different sizes
print("\n💾 === Memory Usage Analysis ===")
sizes = [512, 1024, 2048, 4096, 8192]
memory_analysis = []

for size in sizes:
    per_dtype = {
        label: analyze_memory_usage(size, size, size, dt)
        for label, dt in (('FP16', torch.float16), ('FP32', torch.float32))
    }
    # Column order: Size, FP16 (MB), FP32 (MB), FP16 (GB), FP32 (GB)
    row = {'Size': f'{size}x{size}x{size}'}
    for unit in ('MB', 'GB'):
        for label in ('FP16', 'FP32'):
            row[f'{label} ({unit})'] = per_dtype[label][f'Total ({unit})']
    memory_analysis.append(row)

memory_df = pd.DataFrame(memory_analysis)
print(memory_df.to_string(index=False))

# Export results to CSV and print a per-(kernel, dtype) summary
if df.empty:
    print("\n⚠️  No results to export")
else:
    output_file = 'qgemm_benchmark_results.csv'
    df.to_csv(output_file, index=False)
    print(f"\n💾 Results exported to {output_file}")

    print("\n🎯 === Final Summary ===")
    aggregations = {
        'tflops': ['mean', 'max', 'std'],
        'time_ms': ['mean', 'min', 'std'],
    }
    summary = df.groupby(['kernel', 'dtype']).agg(aggregations).round(3)
    print(summary)

print("\n🎉 Benchmark notebook completed!")
print("\n📋 Build Instructions:")
build_steps = (
    "1. cd .. && ./build.sh --tests --vendor=NVIDIA --arch=80,89",
    "2. For test kernels: ./build.sh --tests --test-kernels",
    "3. Run C++ tests: ./build/bin/qgemm_test",
    "4. Run this notebook for Python interface testing",
)
for step in build_steps:
    print(step)

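# GEMM Kernel Comparison: PyTorch vs. Triton

The cells below register GEMM kernels computing D = A·B + C and benchmark them across matrix shapes and precisions.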

In [None]:
import torch
import time
import numpy as np
from typing import Tuple, List, Dict, Optional, Callable
import triton

import matplotlib.pyplot as plt
import triton.language as tl

# Define precision types
class PrecisionType:
    """String tags naming the operand precision a GEMM kernel is asked to use.

    Plain ``str`` class attributes (not an Enum), so each tag compares equal to,
    and prints as, its raw string value.
    """
    # Floating-point formats
    FP32 = "fp32"
    FP16 = "fp16"
    BF16 = "bf16"
    FP8 = "fp8"  # Note: FP8 support in PyTorch is experimental
    # Integer formats
    INT8 = "int8"

class GEMMTester:
    """Registry and benchmark harness for GEMM kernels computing D = A @ B + C.

    A kernel is any callable ``kernel_fn(A, B, C, precision) -> D``, where
    ``precision`` is one of the ``PrecisionType`` strings. The PyTorch and
    Triton kernels are registered on construction. Others (e.g. pybind11 C++
    kernels) can be added with ``register_kernel``.
    """

    # torch dtype for each precision, used by the built-in kernels and by
    # benchmark() when pre-casting inputs. FP8 has no entry, so the built-in
    # kernels reject it with ValueError instead of running another dtype.
    _TORCH_DTYPES = {
        PrecisionType.FP32: torch.float32,
        PrecisionType.FP16: torch.float16,
        PrecisionType.BF16: torch.bfloat16,
        PrecisionType.INT8: torch.int8,
    }

    def __init__(self):
        self.kernels = {}
        self.register_default_kernels()

    def register_kernel(self, name: str, kernel_fn: Callable):
        """Register a GEMM kernel under ``name``, replacing any kernel already using it."""
        self.kernels[name] = kernel_fn

    def register_default_kernels(self):
        """Register the built-in PyTorch and Triton kernels."""
        self.register_kernel("pytorch", self._pytorch_gemm)
        self.register_kernel("triton", self._triton_gemm)
        # C++ kernels are expected to be registered separately via pybind.

    @classmethod
    def _torch_dtype(cls, precision: str):
        """Return the torch dtype for ``precision``.

        Raises:
            ValueError: if ``precision`` has no entry in ``_TORCH_DTYPES`` (e.g. FP8).
        """
        if precision not in cls._TORCH_DTYPES:
            raise ValueError(f"Unsupported precision: {precision}")
        return cls._TORCH_DTYPES[precision]

    def _pytorch_gemm(self, A, B, C, precision: str):
        """PyTorch native GEMM implementation: D = AB + C.

        Inputs are cast to the requested precision. ``.to`` returns the tensor
        itself when it already has that dtype, so pre-cast inputs are not copied.
        C may be any shape broadcastable to the (M, N) product.

        Raises:
            ValueError: for precisions without a torch dtype mapping (e.g. FP8).
        """
        # NOTE(review): INT8 matmul may not be implemented for CUDA tensors in
        # every PyTorch build; such failures surface as exceptions, which
        # compare_kernels reports.
        dtype = self._torch_dtype(precision)
        A, B, C = A.to(dtype), B.to(dtype), C.to(dtype)
        return torch.matmul(A, B) + C

    @staticmethod
    @triton.jit
    def _triton_matmul_kernel(
        a_ptr, b_ptr, c_ptr, d_ptr,
        M, N, K,
        stride_am, stride_ak,
        stride_bk, stride_bn,
        stride_cm, stride_cn,
        stride_dm, stride_dn,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr
    ):
        """Compute one BLOCK_M x BLOCK_N tile of D = A @ B + C.

        Program (pid_m, pid_n) owns rows [pid_m*BLOCK_M, pid_m*BLOCK_M + BLOCK_M)
        and columns [pid_n*BLOCK_N, pid_n*BLOCK_N + BLOCK_N) of D. Every load builds
        a 2-D block of element pointers from the strides. Out-of-range rows,
        columns and the K tail are masked, so shapes need not be multiples of the
        block sizes. Products accumulate in float32, and the result is cast to D's
        element type on store.
        """
        pid_m = tl.program_id(0)
        pid_n = tl.program_id(1)

        # Row / column / K indices covered by this program's tile.
        offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        offs_k = tl.arange(0, BLOCK_K)
        mask_m = offs_m < M
        mask_n = offs_n < N

        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

        for k in range(0, K, BLOCK_K):
            k_idx = k + offs_k
            mask_k = k_idx < K
            # (BLOCK_M, BLOCK_K) tile of A; masked elements read as 0 so they add nothing.
            a_tile = tl.load(
                a_ptr + offs_m[:, None] * stride_am + k_idx[None, :] * stride_ak,
                mask=mask_m[:, None] & mask_k[None, :],
                other=0.0,
            )
            # (BLOCK_K, BLOCK_N) tile of B.
            b_tile = tl.load(
                b_ptr + k_idx[:, None] * stride_bk + offs_n[None, :] * stride_bn,
                mask=mask_k[:, None] & mask_n[None, :],
                other=0.0,
            )
            acc += tl.dot(a_tile, b_tile)

        mask_mn = mask_m[:, None] & mask_n[None, :]
        c_tile = tl.load(
            c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn,
            mask=mask_mn,
            other=0.0,
        )
        result = acc + c_tile.to(tl.float32)

        tl.store(
            d_ptr + offs_m[:, None] * stride_dm + offs_n[None, :] * stride_dn,
            result.to(d_ptr.dtype.element_ty),
            mask=mask_mn,
        )

    def _triton_gemm(self, A, B, C, precision: str):
        """Use the Triton kernel for GEMM: D = AB + C.

        A is (M, K) and B is (K, N). C must be broadcastable to (M, N); it is
        expanded with zero strides, matching the PyTorch kernel's broadcasting.
        Returns a new (M, N) tensor in the requested precision.

        Raises:
            ValueError: for unsupported precisions or mismatched inner dimensions.
        """
        dtype = self._torch_dtype(precision)
        A = A.to(dtype)
        B = B.to(dtype)

        M, K = A.shape
        K_b, N = B.shape
        if K_b != K:
            raise ValueError(
                f"Inner dimensions do not match: A is {tuple(A.shape)}, B is {tuple(B.shape)}"
            )
        # expand() gives C a zero stride along broadcast dimensions; the kernel's
        # stride-based addressing handles that without materializing a copy.
        C = C.to(dtype).expand(M, N)

        # Create output tensor
        D = torch.empty((M, N), device=A.device, dtype=dtype)

        block = 32
        grid = (triton.cdiv(M, block), triton.cdiv(N, block))

        self._triton_matmul_kernel[grid](
            A, B, C, D,
            M, N, K,
            A.stride(0), A.stride(1),
            B.stride(0), B.stride(1),
            C.stride(0), C.stride(1),
            D.stride(0), D.stride(1),
            BLOCK_M=block, BLOCK_N=block, BLOCK_K=block,
        )

        return D

    def benchmark(self,
                  kernel_name: str,
                  M: int, N: int, K: int,
                  precision: str,
                  device: str = "cuda",
                  num_runs: int = 10):
        """Benchmark one registered kernel on random (M, K) @ (K, N) + (M, N) inputs.

        Inputs with a known precision are cast before timing, so the timed region
        covers the kernel call rather than dtype-conversion copies. Device
        synchronization only happens for CUDA devices, so ``device="cpu"`` works
        without a GPU.

        Returns:
            Dict with kernel, precision, M, N, K and mean/std/min/max time in ms.

        Raises:
            ValueError: for an unknown kernel name or ``num_runs`` < 1.
        """
        if kernel_name not in self.kernels:
            raise ValueError(f"Unknown kernel: {kernel_name}")
        if num_runs < 1:
            raise ValueError(f"num_runs must be >= 1, got {num_runs}")
        kernel_fn = self.kernels[kernel_name]

        # Create matrices
        A = torch.randn((M, K), device=device)
        B = torch.randn((K, N), device=device)
        C = torch.randn((M, N), device=device)
        # Pre-cast so the kernels' own .to(dtype) calls are no-ops in the timed loop.
        # Precisions without a mapping (e.g. FP8 for a custom kernel) are passed in
        # torch's default float dtype for the kernel to handle.
        dtype = self._TORCH_DTYPES.get(precision)
        if dtype is not None:
            A, B, C = A.to(dtype), B.to(dtype), C.to(dtype)

        is_cuda = torch.device(device).type == "cuda"

        def sync():
            # Kernel launches are asynchronous; wait for them before reading the clock.
            if is_cuda:
                torch.cuda.synchronize(device)

        # Warmup
        for _ in range(5):
            kernel_fn(A, B, C, precision)
            sync()

        # Benchmark
        times = []
        for _ in range(num_runs):
            start = time.perf_counter()
            kernel_fn(A, B, C, precision)
            sync()
            times.append((time.perf_counter() - start) * 1000)  # Convert to ms

        return {
            "kernel": kernel_name,
            "precision": precision,
            "M": M, "N": N, "K": K,
            "mean_time_ms": np.mean(times),
            "std_time_ms": np.std(times),
            "min_time_ms": np.min(times),
            "max_time_ms": np.max(times),
        }

    def compare_kernels(self,
                        kernels: List[str],
                        shapes: List[Tuple[int, int, int]],
                        precisions: List[str],
                        device: str = "cuda",
                        num_runs: int = 10):
        """Benchmark every (kernel, precision, shape) combination.

        Failures are printed and skipped, so one unsupported combination does
        not abort the sweep.

        Returns:
            List of result dicts from ``benchmark`` for the successful runs.
        """
        results = []

        for kernel in kernels:
            for precision in precisions:
                for M, N, K in shapes:
                    try:
                        result = self.benchmark(kernel, M, N, K, precision, device, num_runs)
                        results.append(result)
                        print(f"Kernel: {kernel}, Precision: {precision}, Shape: ({M}, {N}, {K}), "
                              f"Time: {result['mean_time_ms']:.2f} ± {result['std_time_ms']:.2f} ms")
                    except Exception as e:
                        print(f"Error benchmarking {kernel} with {precision} for shape ({M}, {N}, {K}): {e}")

        return results

    def plot_results(self, results, plot_type="bar"):
        """Plot benchmark results.

        ``"bar"``: one subplot per matrix shape, with precisions on the x axis and
        one bar per kernel. Each kernel keeps the same colour and appears once in
        each legend. Missing (kernel, precision) combinations are left blank.
        Other plot types are not implemented and do nothing.
        """
        if plot_type != "bar":
            # Other plot types can be implemented as needed
            return
        if not results:
            print("No results to plot")
            return

        # data[shape][precision][kernel] = mean time in ms
        data = {}
        for r in results:
            shape_key = f"({r['M']}, {r['N']}, {r['K']})"
            data.setdefault(shape_key, {}).setdefault(r['precision'], {})[r['kernel']] = r['mean_time_ms']

        kernel_names = list(dict.fromkeys(r['kernel'] for r in results))
        colors = {name: f"C{i % 10}" for i, name in enumerate(kernel_names)}
        width = 0.8 / len(kernel_names)

        # squeeze=False keeps `axes` 2-D even for a single shape.
        fig, axes = plt.subplots(len(data), 1, figsize=(12, 4 * len(data)), squeeze=False)

        for ax, (shape, by_precision) in zip(axes[:, 0], data.items()):
            ax.set_title(f"Matrix shape: {shape}")
            x = np.arange(len(by_precision))
            for k, name in enumerate(kernel_names):
                times_ms = [by_precision[prec].get(name, np.nan) for prec in by_precision]
                ax.bar(x + k * width, times_ms, width, label=name, color=colors[name])

            ax.set_ylabel("Time (ms)")
            # Centre each tick under its group of kernel bars.
            ax.set_xticks(x + width * (len(kernel_names) - 1) / 2)
            ax.set_xticklabels(list(by_precision.keys()))
            ax.legend()

        plt.tight_layout()
        plt.show()

# Example: registering a C++ GEMM kernel exposed through pybind11.
# This is a placeholder kept as an inert string literal (it is never executed);
# you'd need to actually implement the C++ module and binding.
# A registered kernel is called as kernel_fn(A, B, C, precision) by
# GEMMTester.benchmark and should return D = A @ B + C.
# NOTE(review): as written, cpp_gemm falls through and returns None for any
# precision other than FP32/BF16. Benchmark would then time a no-op. A real
# implementation should raise ValueError for unsupported precisions so that
# GEMMTester.compare_kernels reports the failure.
"""
import gemm_cpp  # This would be your pybind11 module

def cpp_gemm(A, B, C, precision):
    # Convert PyTorch tensors to the format your C++ code expects
    # Call your C++ implementation
    # Convert back to PyTorch tensor
    if precision == PrecisionType.FP32:
        return gemm_cpp.gemm_fp32(A, B, C)
    elif precision == PrecisionType.BF16:
        return gemm_cpp.gemm_bf16(A, B, C)
    # ... other precision types

# Register the C++ kernel
tester = GEMMTester()
tester.register_kernel("cpp", cpp_gemm)
"""

# Example usage:
def run_gemm_tests(kernels: Optional[List[str]] = None,
                   shapes: Optional[List[Tuple[int, int, int]]] = None,
                   precisions: Optional[List[str]] = None,
                   device: Optional[str] = None,
                   num_runs: int = 10):
    """Benchmark GEMM kernels over shapes and precisions, then plot the timings.

    Every argument is optional. The defaults run the "pytorch" and "triton"
    kernels on square shapes 128/512/1024 in FP32/FP16/BF16, using CUDA when
    available and CPU otherwise. With the default kernel list on a non-CUDA
    device, "triton" is skipped because the Triton kernel is launched on GPU
    tensors.

    Returns:
        List of result dicts from ``GEMMTester.compare_kernels``.
    """
    tester = GEMMTester()

    # For smaller tests, you can use CPU, but for real benchmarks, use CUDA
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Define test configurations
    if kernels is None:
        kernels = ["pytorch", "triton"]  # Add "cpp" when you have the C++ implementation
        if torch.device(device).type != "cuda":
            kernels.remove("triton")
    if shapes is None:
        shapes = [(128, 128, 128), (512, 512, 512), (1024, 1024, 1024)]
    if precisions is None:
        precisions = [PrecisionType.FP32, PrecisionType.FP16, PrecisionType.BF16]

    # Run benchmarks
    print(f"Running GEMM benchmarks on {device}...")
    results = tester.compare_kernels(kernels, shapes, precisions, device, num_runs)

    # Plot results (nothing to plot if every combination failed)
    if results:
        tester.plot_results(results)

    return results

# When running the notebook, call this function to execute the tests
# results = run_gemm_tests()