## Iteratively Improving MatMul Kernel

In [ ]:
import torch
import numpy as np
import torch.nn as nn
from torch.utils.cpp_extension import load_inline
import os
import time

### Inline CUDA Kernel System

In [ ]:
import matplotlib.pyplot as plt
import pandas as pd
from statistics import mean, stdev

class CudaKernelTester:
    """Compile inline CUDA matmul kernels and benchmark them against ``torch.mm``.

    State is kept in two plain dicts keyed by kernel name:
    ``compiled_kernels`` holds the module returned by ``load_inline`` and
    ``benchmark_results`` holds the list of per-size result dicts written by
    :meth:`test_matmul_kernel`.
    """

    # Immutable defaults (tuples) so the default argument objects can never be
    # mutated and shared between calls.
    _DEFAULT_TEST_SIZES = ((2, 3, 2), (32, 64, 32), (128, 128, 128), (512, 512, 512))
    _DEFAULT_COMPARE_SIZES = ((128, 128, 128), (512, 512, 512), (1024, 1024, 1024))

    def __init__(self):
        # kernel name -> extension module returned by load_inline
        self.compiled_kernels = {}
        # kernel name -> list of per-size result dicts (see test_matmul_kernel)
        self.benchmark_results = {}

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _safe_ratio(numerator, denominator):
        """Return ``numerator / denominator`` without raising on a zero denominator.

        A zero denominator yields ``nan`` when the numerator is also zero and
        a signed infinity otherwise, so report formatting keeps working when a
        measured time (or a degenerate problem size) is exactly zero.
        """
        if denominator == 0:
            if numerator == 0:
                return float('nan')
            return float('inf') if numerator > 0 else float('-inf')
        return numerator / denominator

    @staticmethod
    def _time_cuda_calls(fn, num_runs, warmup_runs):
        """Time ``fn()`` with CUDA events.

        Args:
            fn: Zero-argument callable to time.
            num_runs: Number of timed invocations; must be at least 1.
            warmup_runs: Number of untimed invocations executed first.

        Returns:
            ``(times, last_result)`` where ``times`` lists the elapsed time of
            each timed call in milliseconds and ``last_result`` is the return
            value of the final timed call.

        Raises:
            ValueError: If ``num_runs`` is smaller than 1.
        """
        # Validate before touching the GPU so a bad argument fails fast with a
        # clear message instead of a StatisticsError from mean([]).
        if num_runs < 1:
            raise ValueError(f"num_runs must be >= 1, got {num_runs}")

        for _ in range(warmup_runs):
            fn()
            torch.cuda.synchronize()

        times = []
        last_result = None
        for _ in range(num_runs):
            # Drain pending work so the interval only covers this call.
            torch.cuda.synchronize()
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

            start_event.record()
            last_result = fn()
            end_event.record()

            # elapsed_time needs both events to have completed.
            torch.cuda.synchronize()
            times.append(start_event.elapsed_time(end_event))  # ms
        return times, last_result

    @staticmethod
    def _summarize_times(times):
        """Build the statistics dict shared by both benchmark methods."""
        return {
            'mean_ms': mean(times),
            # stdev needs at least two samples.
            'std_ms': stdev(times) if len(times) > 1 else 0,
            'min_ms': min(times),
            'max_ms': max(times),
            'all_times': times
        }

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def compile_kernel(self, kernel_name, cuda_code, cpp_code):
        """Compile an inline CUDA kernel together with its C++ wrapper.

        Args:
            kernel_name: Extension name passed to ``load_inline``; also the key
                under which the module is stored in ``compiled_kernels``.
            cuda_code: Source text passed as the single ``cuda_sources`` entry.
            cpp_code: Source text passed as the single ``cpp_sources`` entry.

        Returns:
            The compiled module, or ``None`` if compilation raised.
        """
        try:
            module = load_inline(
                name=kernel_name,
                cpp_sources=[cpp_code],
                cuda_sources=[cuda_code],
                verbose=True,
                with_cuda=True
            )
        except Exception as e:
            # Deliberate best-effort boundary: report the failure and let the
            # caller continue with the kernels that did compile.
            print(f"✗ Failed to compile kernel {kernel_name}: {str(e)}")
            return None

        self.compiled_kernels[kernel_name] = module
        print(f"✓ Successfully compiled kernel: {kernel_name}")
        return module

    def benchmark_kernel(self, kernel_name, A, B, C, num_runs=10, warmup_runs=3):
        """Benchmark a compiled kernel over multiple runs and return statistics.

        Args:
            kernel_name: Key of a module in ``compiled_kernels``.
            A, B, C: Arguments forwarded unchanged to
                ``module.matmul_forward(A, B, C)``.
            num_runs: Number of timed calls (at least 1).
            warmup_runs: Number of untimed calls executed first.

        Returns:
            Dict with ``mean_ms``, ``std_ms``, ``min_ms``, ``max_ms`` and
            ``all_times``, or ``None`` if the kernel has not been compiled.

        Raises:
            ValueError: If ``num_runs`` is smaller than 1.
        """
        if kernel_name not in self.compiled_kernels:
            return None

        module = self.compiled_kernels[kernel_name]
        times, _ = self._time_cuda_calls(
            lambda: module.matmul_forward(A, B, C), num_runs, warmup_runs
        )
        return self._summarize_times(times)

    def benchmark_pytorch(self, A, B, num_runs=10, warmup_runs=3):
        """Benchmark ``torch.mm(A, B)`` over multiple runs.

        Returns:
            The same statistics dict as :meth:`benchmark_kernel` plus
            ``result``, the output of the last timed ``torch.mm`` call.

        Raises:
            ValueError: If ``num_runs`` is smaller than 1.
        """
        times, result = self._time_cuda_calls(
            lambda: torch.mm(A, B), num_runs, warmup_runs
        )
        stats = self._summarize_times(times)
        stats['result'] = result
        return stats

    def calculate_performance_metrics(self, m, k, n, time_ms):
        """Calculate FLOPS and memory bandwidth metrics.

        Args:
            m, k, n: Problem dimensions of ``[m×k] × [k×n]``.
            time_ms: Execution time in milliseconds.

        Returns:
            Dict with ``gflops``, ``bandwidth_gb_s`` and
            ``arithmetic_intensity`` (FLOPs per byte). A zero ``time_ms``
            yields ``inf`` (or ``nan`` for a zero-size problem) instead of
            raising ``ZeroDivisionError``.
        """
        # Matrix multiplication FLOPS: 2*M*N*K (multiply + add for each element)
        flops = 2 * m * n * k
        # Read A (M*K), Read B (K*N), Write C (M*N) - all float32 (4 bytes)
        bytes_transferred = (m * k + k * n + m * n) * 4
        seconds = time_ms * 1e-3

        return {
            'gflops': self._safe_ratio(flops, seconds) / 1e9,
            'bandwidth_gb_s': self._safe_ratio(bytes_transferred, seconds) / 1e9,
            'arithmetic_intensity': self._safe_ratio(flops, bytes_transferred)  # FLOPS per byte
        }

    def test_matmul_kernel(self, kernel_name, test_sizes=_DEFAULT_TEST_SIZES, num_runs=10):
        """Test a matmul kernel for correctness and performance against ``torch.mm``.

        For every ``(m, k, n)`` in ``test_sizes`` random float32 CUDA inputs
        are created, the kernel and ``torch.mm`` are benchmarked, the kernel
        output is compared with ``torch.allclose(..., atol=1e-4)`` and a
        report is printed. The per-size results are stored in
        ``benchmark_results[kernel_name]``.

        Returns:
            ``True`` if every size passed the correctness check, ``False`` if
            any failed or the kernel has not been compiled.
        """
        if kernel_name not in self.compiled_kernels:
            print(f"✗ Kernel {kernel_name} not found")
            return False

        all_passed = True
        results = []

        print(f"\n{'='*60}")
        print(f"BENCHMARKING KERNEL: {kernel_name}")
        print(f"{'='*60}")

        for m, k, n in test_sizes:
            print(f"\nTesting size [{m}×{k}] × [{k}×{n}]")
            print("-" * 40)

            # Create test matrices
            A = torch.randn(m, k, device='cuda', dtype=torch.float32).contiguous()
            B = torch.randn(k, n, device='cuda', dtype=torch.float32).contiguous()
            C = torch.zeros(m, n, device='cuda', dtype=torch.float32).contiguous()

            # The kernel is known to be compiled (checked above), so this
            # cannot return None.
            custom_stats = self.benchmark_kernel(kernel_name, A, B, C, num_runs)
            pytorch_stats = self.benchmark_pytorch(A, B, num_runs)
            reference = pytorch_stats['result']

            # C holds the output of the last timed kernel run.
            correctness_passed = torch.allclose(C, reference, atol=1e-4)

            custom_metrics = self.calculate_performance_metrics(m, k, n, custom_stats['mean_ms'])
            pytorch_metrics = self.calculate_performance_metrics(m, k, n, pytorch_stats['mean_ms'])

            print(f"Correctness: {'✓ PASSED' if correctness_passed else '✗ FAILED'}")
            if not correctness_passed:
                max_diff = torch.max(torch.abs(C - reference)).item()
                print(f"Max difference: {max_diff:.6f}")
                all_passed = False

            print(f"\n📊 PERFORMANCE COMPARISON:")
            print(f"{'Metric':<20} {'Custom':<15} {'PyTorch':<15} {'Speedup':<10}")
            print("-" * 65)

            # Timing comparison (ratios are zero-safe: a 0 ms mean gives inf/nan)
            speedup = self._safe_ratio(pytorch_stats['mean_ms'], custom_stats['mean_ms'])
            print(f"{'Time (ms)':<20} {custom_stats['mean_ms']:<15.3f} {pytorch_stats['mean_ms']:<15.3f} {speedup:<10.2f}x")
            print(f"{'±Std (ms)':<20} {custom_stats['std_ms']:<15.3f} {pytorch_stats['std_ms']:<15.3f}")

            # Performance metrics
            gflops_speedup = self._safe_ratio(custom_metrics['gflops'], pytorch_metrics['gflops'])
            bandwidth_speedup = self._safe_ratio(custom_metrics['bandwidth_gb_s'], pytorch_metrics['bandwidth_gb_s'])

            print(f"{'GFLOPS':<20} {custom_metrics['gflops']:<15.1f} {pytorch_metrics['gflops']:<15.1f} {gflops_speedup:<10.2f}x")
            print(f"{'Bandwidth (GB/s)':<20} {custom_metrics['bandwidth_gb_s']:<15.1f} {pytorch_metrics['bandwidth_gb_s']:<15.1f} {bandwidth_speedup:<10.2f}x")
            print(f"{'Arith. Intensity':<20} {custom_metrics['arithmetic_intensity']:<15.2f} {pytorch_metrics['arithmetic_intensity']:<15.2f}")

            # Store results for analysis
            results.append({
                'kernel': kernel_name,
                'size': f"{m}×{k}×{n}",
                'm': m, 'k': k, 'n': n,
                'custom_time_ms': custom_stats['mean_ms'],
                'pytorch_time_ms': pytorch_stats['mean_ms'],
                'speedup': speedup,
                'custom_gflops': custom_metrics['gflops'],
                'pytorch_gflops': pytorch_metrics['gflops'],
                'custom_bandwidth': custom_metrics['bandwidth_gb_s'],
                'pytorch_bandwidth': pytorch_metrics['bandwidth_gb_s'],
                'correctness': correctness_passed
            })

        # Store results for later analysis
        self.benchmark_results[kernel_name] = results

        return all_passed

    def compare_kernels(self, kernel_names, test_sizes=_DEFAULT_COMPARE_SIZES):
        """Compare multiple kernels across different sizes.

        ``torch.mm`` is timed as the baseline for each size; kernels that are
        not in ``compiled_kernels`` are skipped.

        Returns:
            One dict per size with ``size``, ``PyTorch`` (mean ms) and one
            entry per benchmarked kernel name (mean ms).
        """
        print(f"\n{'='*80}")
        print(f"KERNEL COMPARISON")
        print(f"{'='*80}")

        comparison_data = []

        for m, k, n in test_sizes:
            print(f"\nSize: [{m}×{k}] × [{k}×{n}]")
            print("-" * 50)

            A = torch.randn(m, k, device='cuda', dtype=torch.float32).contiguous()
            B = torch.randn(k, n, device='cuda', dtype=torch.float32).contiguous()

            # Benchmark PyTorch as baseline
            pytorch_stats = self.benchmark_pytorch(A, B, num_runs=10)
            pytorch_metrics = self.calculate_performance_metrics(m, k, n, pytorch_stats['mean_ms'])

            size_results = {'size': f"{m}×{k}×{n}", 'PyTorch': pytorch_stats['mean_ms']}

            print(f"{'Kernel':<20} {'Time (ms)':<12} {'GFLOPS':<10} {'Speedup':<10}")
            print("-" * 55)
            print(f"{'PyTorch (baseline)':<20} {pytorch_stats['mean_ms']:<12.3f} {pytorch_metrics['gflops']:<10.1f} {'1.00x':<10}")

            for kernel_name in kernel_names:
                if kernel_name not in self.compiled_kernels:
                    continue

                # Fresh output buffer per kernel so runs do not share state.
                C = torch.zeros(m, n, device='cuda', dtype=torch.float32).contiguous()
                custom_stats = self.benchmark_kernel(kernel_name, A, B, C, num_runs=10)
                if not custom_stats:
                    continue

                custom_metrics = self.calculate_performance_metrics(m, k, n, custom_stats['mean_ms'])
                speedup = self._safe_ratio(pytorch_stats['mean_ms'], custom_stats['mean_ms'])

                print(f"{kernel_name:<20} {custom_stats['mean_ms']:<12.3f} {custom_metrics['gflops']:<10.1f} {speedup:<10.2f}x")
                size_results[kernel_name] = custom_stats['mean_ms']

            comparison_data.append(size_results)

        return comparison_data

    def plot_performance_comparison(self, kernel_names=None):
        """Create a 2×2 grid of performance charts from ``benchmark_results``.

        Args:
            kernel_names: Kernels to plot; defaults to every kernel that has
                stored benchmark results.
        """
        if not self.benchmark_results:
            print("No benchmark results available. Run tests first.")
            return

        kernels_to_plot = kernel_names or list(self.benchmark_results.keys())

        # Collect all data
        all_data = [
            {**result, 'total_elements': result['m'] * result['k'] * result['n']}
            for kernel in kernels_to_plot
            if kernel in self.benchmark_results
            for result in self.benchmark_results[kernel]
        ]

        if not all_data:
            print("No data to plot")
            return

        df = pd.DataFrame(all_data)

        # Create the figure only once there is something to draw, so the
        # early returns above never leave an empty figure open.
        fig, axes = plt.subplots(2, 2, figsize=(15, 12))
        fig.suptitle('Kernel Performance Comparison', fontsize=16)

        # (axis, dataframe column, marker, y-label, title)
        panels = [
            (axes[0, 0], 'speedup', 'o', 'Speedup vs PyTorch', 'Speedup vs Matrix Size'),
            (axes[0, 1], 'custom_gflops', 's', 'GFLOPS', 'Computational Performance'),
            (axes[1, 0], 'custom_bandwidth', '^', 'Memory Bandwidth (GB/s)', 'Memory Bandwidth Utilization'),
            (axes[1, 1], 'custom_time_ms', 'd', 'Execution Time (ms)', 'Execution Time vs Matrix Size'),
        ]

        for ax, column, marker, ylabel, title in panels:
            for kernel in kernels_to_plot:
                kernel_data = df[df['kernel'] == kernel]
                if not kernel_data.empty:
                    ax.plot(kernel_data['total_elements'], kernel_data[column],
                            marker=marker, label=kernel, linewidth=2)
            ax.set_xlabel('Total Elements (M×K×N)')
            ax.set_ylabel(ylabel)
            ax.set_title(title)
            ax.set_xscale('log')
            if column == 'custom_time_ms':
                ax.set_yscale('log')
            if column == 'speedup':
                # Must be drawn before legend() so its label is picked up.
                ax.axhline(y=1, color='red', linestyle='--', alpha=0.5, label='PyTorch baseline')
            ax.grid(True, alpha=0.3)
            ax.legend()

        plt.tight_layout()
        plt.show()

# Initialize the enhanced tester: one shared instance whose dictionaries hold
# the compiled kernel modules and the per-kernel benchmark results.
kernel_tester = CudaKernelTester()

In [ ]:
# Example: Basic Matrix Multiplication Kernel
#
# load_inline compiles `cpp_sources` with the host C++ compiler and only
# `cuda_sources` with nvcc. The kernel AND its <<<...>>> launch therefore live
# in the CUDA source; the C++ source only declares the launcher and binds it.

cuda_kernel_v1 = """
#include <torch/extension.h>
#include <cuda_runtime.h>

// One thread per output element: C[row, col] = dot(A[row, :], B[:, col]).
__global__ void matmul_kernel(float* A, float* B, float* C, int M, int K, int N) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    
    if (row < M && col < N) {
        float sum = 0.0f;
        for (int k = 0; k < K; k++) {
            sum += A[row * K + k] * B[k * N + col];
        }
        C[row * N + col] = sum;
    }
}

// Host-side launcher. Writes A x B into the preallocated tensor C.
void matmul_forward(torch::Tensor A, torch::Tensor B, torch::Tensor C) {
    // The kernel indexes raw float pointers as dense row-major 2-D arrays,
    // so reject anything that does not match that layout.
    TORCH_CHECK(A.is_cuda() && B.is_cuda() && C.is_cuda(),
                "A, B and C must be CUDA tensors");
    TORCH_CHECK(A.scalar_type() == torch::kFloat32 &&
                B.scalar_type() == torch::kFloat32 &&
                C.scalar_type() == torch::kFloat32,
                "A, B and C must be float32");
    TORCH_CHECK(A.dim() == 2 && B.dim() == 2 && C.dim() == 2,
                "A, B and C must be 2-D");
    TORCH_CHECK(A.is_contiguous() && B.is_contiguous() && C.is_contiguous(),
                "A, B and C must be contiguous");

    const int M = A.size(0);
    const int K = A.size(1);
    const int N = B.size(1);

    TORCH_CHECK(B.size(0) == K, "inner dimensions of A and B must match");
    TORCH_CHECK(C.size(0) == M && C.size(1) == N, "C must have shape [M, N]");

    // A grid dimension of 0 is an invalid launch configuration.
    if (M == 0 || N == 0) {
        return;
    }
    
    const dim3 blockSize(16, 16);
    const dim3 gridSize((N + blockSize.x - 1) / blockSize.x, 
                       (M + blockSize.y - 1) / blockSize.y);
    
    matmul_kernel<<<gridSize, blockSize>>>(
        A.data_ptr<float>(),
        B.data_ptr<float>(),
        C.data_ptr<float>(),
        M, K, N
    );

    // Kernel launches do not report errors by themselves; check explicitly.
    cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "matmul_kernel launch failed: ", cudaGetErrorString(err));
    
    err = cudaDeviceSynchronize();
    TORCH_CHECK(err == cudaSuccess,
                "matmul_kernel execution failed: ", cudaGetErrorString(err));
}
"""

cpp_wrapper_v1 = """
#include <torch/extension.h>

// Defined in the CUDA source (compiled by nvcc); declared here for the binding.
void matmul_forward(torch::Tensor A, torch::Tensor B, torch::Tensor C);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("matmul_forward", &matmul_forward, "Matrix multiplication forward");
}
"""

# Compile and test the basic kernel
kernel_tester.compile_kernel("matmul_v1", cuda_kernel_v1, cpp_wrapper_v1)

In [ ]:
# Test the basic kernel: benchmark "matmul_v1" against torch.mm on the default
# test sizes and check its output. Evaluates to True only if every size passes;
# False if any size fails or the kernel was never compiled.
kernel_tester.test_matmul_kernel("matmul_v1")

### Optimized Version with Shared Memory