In [1]:
from conv_gemm.layers.triton_conv2d import TritonConv2d 
import time
import torch

In [2]:
def test_configs():
    device = "cuda"

    # === РАЗМЕРЫ ПОД ТВОЮ СЕТКУ ===
    B = 8
    Cin = 64
    Cout = 128
    H = 256
    W = 256
    ks = 3
    stride = 1
    padding = 1
    dilation = 1

    x = torch.randn(B, Cin, H, W, device=device, dtype=torch.float32)

    # Набор конфигов для перебора
    configs = [
        dict(BLOCK_M=64, BLOCK_N=64, BLOCK_K=32, NUM_WARPS=4, NUM_STAGES=2),
        dict(BLOCK_M=64, BLOCK_N=32, BLOCK_K=32, NUM_WARPS=4, NUM_STAGES=2),
        dict(BLOCK_M=32, BLOCK_N=64, BLOCK_K=32, NUM_WARPS=4, NUM_STAGES=2),
        dict(BLOCK_M=32, BLOCK_N=32, BLOCK_K=32, NUM_WARPS=4, NUM_STAGES=2),

        dict(BLOCK_M=64, BLOCK_N=64, BLOCK_K=16, NUM_WARPS=4, NUM_STAGES=2),
        dict(BLOCK_M=64, BLOCK_N=32, BLOCK_K=16, NUM_WARPS=4, NUM_STAGES=2),
        dict(BLOCK_M=32, BLOCK_N=64, BLOCK_K=16, NUM_WARPS=4, NUM_STAGES=2),

        dict(BLOCK_M=64, BLOCK_N=64, BLOCK_K=32, NUM_WARPS=4, NUM_STAGES=1),
        dict(BLOCK_M=64, BLOCK_N=64, BLOCK_K=16, NUM_WARPS=4, NUM_STAGES=1),

        dict(BLOCK_M=64, BLOCK_N=64, BLOCK_K=32, NUM_WARPS=2, NUM_STAGES=2),
    ]

    results = []

    # глобальная синхра перед всеми замерами
    torch.cuda.synchronize()

    for cfg in configs:
        print("\n=== Testing config:", cfg, "===")
        try:
            conv = TritonConv2d(
                in_channels=Cin,
                out_channels=Cout,
                kernel_size=ks,
                stride=stride,
                padding=padding,
                dilation=dilation,
                BLOCK_M=cfg["BLOCK_M"],
                BLOCK_N=cfg["BLOCK_N"],
                BLOCK_K=cfg["BLOCK_K"],
                NUM_WARPS=cfg["NUM_WARPS"],
                NUM_STAGES=cfg["NUM_STAGES"],
                precision_mode="fp16_runtime",  # можно "fp16" для начала
            ).to(device)

            conv.train()
            opt = torch.optim.SGD(conv.parameters(), lr=1e-3)

            # прогрев 2 итерации (без тайминга)
            for _ in range(2):
                opt.zero_grad(set_to_none=True)
                out = conv(x)
                loss = out.float().pow(2).mean()
                loss.backward()
                opt.step()

            torch.cuda.synchronize()

            # измеряем несколько раз, усредняем
            iters = 10
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

            total_ms = 0.0
            last_loss = None

            for _ in range(iters):
                opt.zero_grad(set_to_none=True)
                start_event.record()

                out = conv(x)
                loss = out.float().pow(2).mean()
                loss.backward()
                opt.step()

                end_event.record()
                torch.cuda.synchronize()  # гарантируем, что итерация закончилась
                iter_ms = start_event.elapsed_time(end_event)  # ms
                total_ms += iter_ms
                last_loss = loss.item()

            avg_ms = total_ms / iters
            print(f"OK, avg_time = {avg_ms:.3f} ms, loss={last_loss:.4f}")
            results.append((avg_ms / 1000.0, cfg, "OK"))  # храним в секундах

        except RuntimeError as e:
            msg = str(e)
            if "OutOfResources" in msg or "out of resource: shared memory" in msg:
                print("!! OutOfResources / shared memory for config:", cfg)
                results.append((None, cfg, "OOR"))
            else:
                print("!! RuntimeError other:", msg)
                results.append((None, cfg, "ERR"))

    print("\n\n========= SUMMARY =========")
    for dt, cfg, status in sorted(results, key=lambda x: (x[0] if x[0] is not None else 1e9)):
        t_str = f"{dt*1000:.3f} ms" if dt is not None else "   -   "
        print(status, "|", t_str, "|", cfg)

In [3]:
test_configs()


=== Testing config: {'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32, 'NUM_WARPS': 4, 'NUM_STAGES': 2} ===
!! RuntimeError other: CUDA out of memory. Tried to allocate 256.00 MiB. GPU 0 has a total capacity of 7.68 GiB of which 307.00 MiB is free. Process 41506 has 4.70 GiB memory in use. Including non-PyTorch memory, this process has 1.43 GiB memory in use. Of the allocated memory 1.19 GiB is allocated by PyTorch, and 65.58 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

=== Testing config: {'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'NUM_WARPS': 4, 'NUM_STAGES': 2} ===
!! RuntimeError other: CUDA out of memory. Tried to allocate 256.00 MiB. GPU 0 has a total capacity of 7.68 GiB of which 295.31 MiB is free. Process 41506 has 4.70 GiB memory in use. 