In [1]:
import sys
import os
sys.path.insert(0, os.path.join(os.getcwd(), '../core'))

import cupy as cp
from cupy.cuda import runtime
import pack_cuda

# Initialize pack_cuda to compile the kernel
pack_cuda.USE_FLOAT32 = False
pack_cuda._ensure_initialized()

# Get the compiled kernel
kernel = pack_cuda._multi_overlap_list_total_kernel

# 2. Inspect kernel attributes (per-thread stack/local memory, registers, etc.)
attrs = kernel.attributes  # dict
print("=" * 70)
print("OVERLAP KERNEL ATTRIBUTES")
print("=" * 70)
print(f"  num_regs             = {attrs.get('num_regs')}")
print(f"  local_size_bytes     = {attrs.get('local_size_bytes')}  # per-thread stack/local")
print(f"  shared_size_bytes    = {attrs.get('shared_size_bytes')}")
print(f"  max_threads_per_block= {attrs.get('max_threads_per_block')}")

# 3. Get device and device properties
dev = cp.cuda.Device()        # current device
dev_id = dev.id
props = runtime.getDeviceProperties(dev_id)

name = props["name"].decode("utf-8", errors="ignore")
multi_processor_count = props["multiProcessorCount"]
max_threads_per_block = props["maxThreadsPerBlock"]
warp_size = props["warpSize"]
max_threads_per_sm = props.get("maxThreadsPerMultiProcessor", 1536)
max_blocks_per_sm = props.get("maxBlocksPerMultiProcessor", 24)
regs_per_sm = props.get("regsPerMultiprocessor", 65536)

print("\n" + "=" * 70)
print("DEVICE PROPERTIES")
print("=" * 70)
print(f"  name                       = {name}")
print(f"  multiprocessor_count       = {multi_processor_count}")
print(f"  maxThreadsPerBlock         = {max_threads_per_block}")
print(f"  maxThreadsPerSM            = {max_threads_per_sm}")
print(f"  maxBlocksPerSM             = {max_blocks_per_sm}")
print(f"  regsPerSM                  = {regs_per_sm}")
print(f"  warpSize                   = {warp_size}")

# 4. Occupancy analysis for various block sizes
print("\n" + "=" * 70)
print("OCCUPANCY ANALYSIS")
print("=" * 70)

num_regs = attrs.get('num_regs', 64)
local_mem = attrs.get('local_size_bytes', 1104)

print(f"\nPer-thread resources:")
print(f"  Registers: {num_regs}")
print(f"  Local memory: {local_mem} bytes")

test_block_sizes = [20, 40, 80, 128, 256, 512, 1024]

print(f"\n{'Threads/Block':<15} {'Blocks/SM':<15} {'Active Threads/SM':<20} {'Limiting Factor'}")
print("-" * 80)

for block_size in test_block_sizes:
    # Calculate limits
    regs_per_block = num_regs * block_size
    blocks_by_regs = regs_per_sm // regs_per_block if regs_per_block > 0 else 999
    blocks_by_threads = max_threads_per_sm // block_size
    blocks_by_hw = max_blocks_per_sm
    
    max_blocks = min(blocks_by_regs, blocks_by_threads, blocks_by_hw)
    active_threads = max_blocks * block_size
    
    # Determine limiting factor
    if max_blocks == blocks_by_regs:
        limiting = "registers"
    elif max_blocks == blocks_by_threads:
        limiting = "thread count"
    else:
        limiting = "HW block limit"
    
    print(f"{block_size:<15} {max_blocks:<15} {active_threads:<20} {limiting}")

# 5. Test with actual kernel configuration from overlap_multi_ensemble
print("\n" + "=" * 70)
print("ACTUAL USAGE IN overlap_multi_ensemble")
print("=" * 70)

# Typical usage: 5 trees per ensemble, 4 threads per tree = 20 threads
n_trees = 5
threads_per_block_actual = n_trees * 4
regs_per_block_actual = num_regs * threads_per_block_actual
local_mem_per_block = local_mem * threads_per_block_actual

blocks_by_regs_actual = regs_per_sm // regs_per_block_actual
blocks_by_threads_actual = max_threads_per_sm // threads_per_block_actual
blocks_by_hw_actual = max_blocks_per_sm

max_blocks_actual = min(blocks_by_regs_actual, blocks_by_threads_actual, blocks_by_hw_actual)

print(f"\nConfiguration: {n_trees} trees/ensemble × 4 threads/tree = {threads_per_block_actual} threads/block")
print(f"\nPer-block resource usage:")
print(f"  Registers: {regs_per_block_actual} ({100*regs_per_block_actual/regs_per_sm:.1f}% of SM)")
print(f"  Local memory: {local_mem_per_block:,} bytes ({local_mem_per_block/1024:.1f} KB)")
print(f"\nOccupancy limits:")
print(f"  Blocks limited by registers: {blocks_by_regs_actual}")
print(f"  Blocks limited by threads: {blocks_by_threads_actual}")
print(f"  Blocks limited by HW: {blocks_by_hw_actual}")
print(f"  ACTUAL max blocks per SM: {max_blocks_actual}")
print(f"  ACTUAL active threads per SM: {max_blocks_actual * threads_per_block_actual}")
print(f"  Occupancy: {100 * (max_blocks_actual * threads_per_block_actual) / max_threads_per_sm:.1f}%")

print("\n" + "=" * 70)
print("CONCLUSION")
print("=" * 70)
if max_blocks_actual >= 2:
    print(f"✓ GPU SHOULD support {max_blocks_actual} concurrent blocks per SM")
    print(f"  This means {max_blocks_actual * multi_processor_count} blocks across all {multi_processor_count} SMs")
    print(f"  But your tests show only 1 block executing at a time...")
    print(f"  → The bottleneck is NOT resource constraints!")
else:
    print(f"✗ GPU can only run 1 block per SM due to resource limits")
    print(f"  Limited by: ", end="")
    if max_blocks_actual == blocks_by_regs_actual:
        print("registers")
    elif max_blocks_actual == blocks_by_threads_actual:
        print("thread count")
    else:
        print("HW block limit")

local
OVERLAP KERNEL ATTRIBUTES
  num_regs             = 122
  local_size_bytes     = 1824  # per-thread stack/local
  shared_size_bytes    = 0
  max_threads_per_block= 512

DEVICE PROPERTIES
  name                       = NVIDIA GeForce RTX 4070 Ti
  multiprocessor_count       = 60
  maxThreadsPerBlock         = 1024
  maxThreadsPerSM            = 1536
  maxBlocksPerSM             = 24
  regsPerSM                  = 65536
  warpSize                   = 32

OCCUPANCY ANALYSIS

Per-thread resources:
  Registers: 122
  Local memory: 1824 bytes

Threads/Block   Blocks/SM       Active Threads/SM    Limiting Factor
--------------------------------------------------------------------------------
20              24              480                  HW block limit
40              13              520                  registers
80              6               480                  registers
128             4               512                  registers
256             2               512         