In [23]:
from datasets import load_dataset
# train_compare_resnet_triton.py
import os, time, copy, argparse, math
from contextlib import nullcontext

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from conv_gemm.layers.triton_conv2d import TritonConv2d

# === заглушечный датасет (замени на свой) ===
from torchvision.datasets import FakeData
from torchvision import transforms
from torchvision.models import resnet18

torch.backends.cudnn.benchmark = True
device = "cuda" if torch.cuda.is_available() else "cpu"

In [24]:
# DATASET

In [20]:
def to_rgb(img):
    # HF Datasets обычно отдают PIL.Image, но подстрахуемся
    if isinstance(img, Image.Image):
        return img.convert("RGB")
    if torch.is_tensor(img):
        # tensor [C,H,W]
        if img.ndim == 3 and img.size(0) == 1:
            return img.expand(3, -1, -1)  # уже тензор -> вернём тензор
        return img
    # numpy -> PIL -> RGB
    return Image.fromarray(img).convert("RGB")

In [21]:
ds = load_dataset("zh-plus/tiny-imagenet")
train_tfms = transforms.Compose([
    transforms.Lambda(to_rgb),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),  # теперь гарантированно [3,H,W]
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

valid_tfms = transforms.Compose([
    transforms.Lambda(to_rgb),
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),  # [3,H,W]
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
def make_collate(tfms):
    def collate(batch):
        imgs = [tfms(sample["image"]) for sample in batch]
        labels = torch.tensor([int(sample["label"]) for sample in batch], dtype=torch.long)
        return torch.stack(imgs, 0), labels
    return collate

train_loader = DataLoader(
    ds["train"], batch_size=32, shuffle=True,
    num_workers=4, pin_memory=True,
    collate_fn=make_collate(train_tfms),
)
valid_loader = DataLoader(
    ds["valid"], batch_size=32, shuffle=False,
    num_workers=4, pin_memory=True,
    collate_fn=make_collate(valid_tfms),
)

In [25]:
# MODELING

In [26]:
def set_seed(seed=0):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def count_params(m):
    return sum(p.numel() for p in m.parameters())

def topk(output, target, ks=(1,)):
    maxk = max(ks)
    B = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in ks:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append((correct_k.mul_(100.0 / B)).item())
    return res


In [27]:
# --------------------------
# Conv2d -> TritonConv2d swap
# --------------------------
def _make_triton_from_conv(conv: nn.Conv2d,
                           precision_mode="fp32",
                           use_weight_shadow=True,
                           triton_blocks=(64, 64, 32),
                           triton_launch=(4, 2)) -> TritonConv2d:
    tm = TritonConv2d(
        in_channels=conv.in_channels,
        out_channels=conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        bias=(conv.bias is not None),
        BLOCK_M=triton_blocks[0], BLOCK_N=triton_blocks[1], BLOCK_K=triton_blocks[2],
        NUM_WARPS=triton_launch[0], NUM_STAGES=triton_launch[1],
        precision_mode=precision_mode,
        use_weight_shadow=use_weight_shadow,
    ).to(conv.weight.device).to(conv.weight.dtype)

    # копируем веса/биас 1-в-1
    with torch.no_grad():
        tm.weight.copy_(conv.weight)
        if conv.bias is not None and tm.bias is not None:
            tm.bias.copy_(conv.bias)
    return tm


def replace_all_convs_with_triton(model: nn.Module,
                                  precision_mode="fp32",
                                  use_weight_shadow=True,
                                  triton_blocks=(64, 64, 32),
                                  triton_launch=(4, 2)):
    """
    Рекурсивно меняет все nn.Conv2d -> TritonConv2d, сохраняя веса/биас и гиперпараметры.
    """
    for name, module in list(model.named_children()):
        # Глубже
        replace_all_convs_with_triton(module, precision_mode, use_weight_shadow, triton_blocks, triton_launch)
        # Замена на этом уровне
        if isinstance(module, nn.Conv2d):
            tm = _make_triton_from_conv(
                module,
                precision_mode=precision_mode,
                use_weight_shadow=use_weight_shadow,
                triton_blocks=triton_blocks,
                triton_launch=triton_launch,
            )
            setattr(model, name, tm)
    return model

In [28]:
set_seed(0)
device = "cuda"
base = resnet18(weights=None, num_classes=200).to(device)
tri = copy.deepcopy(base)
precision_mode = 'fp32'
use_weight_shadow = False
triton_blocks=(64, 64, 32)
triton_launch=(4, 2)

In [29]:
replace_all_convs_with_triton(
        tri,
        precision_mode=precision_mode,           # 'fp32' | 'fp16_runtime' | 'fp16_infer'
        use_weight_shadow=use_weight_shadow > 0,
        triton_blocks=triton_blocks,
        triton_launch=triton_launch,
    )
tri.to(device)

ResNet(
  (conv1): TritonConv2d()
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): TritonConv2d()
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): TritonConv2d()
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): TritonConv2d()
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): TritonConv2d()
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): TritonConv2d()
      (bn1): BatchNorm2d(128, eps=1e-05, mome

In [30]:
# train

In [36]:
@torch.no_grad()
def evaluate(model, loader, device="cuda", amp=False):
    model.eval()
    scaler_ctx = torch.autocast(device_type="cuda", dtype=torch.float16) if amp else nullcontext()
    total, top1_sum, top5_sum, loss_sum = 0, 0.0, 0.0, 0.0
    criterion = nn.CrossEntropyLoss()

    t0 = time.perf_counter()
    for imgs, labels in loader:
        imgs, labels = imgs.to(device, non_blocking=True), labels.to(device, non_blocking=True)
        with scaler_ctx:
            logits = model(imgs)
            loss = criterion(logits, labels)

        bsz = labels.size(0)
        top1, top5 = topk(logits, labels, ks=(1, 5))
        top1_sum += top1 * bsz / 100.0
        top5_sum += top5 * bsz / 100.0
        loss_sum += loss.item() * bsz
        total += bsz

    dt = time.perf_counter() - t0
    return {
        "loss": loss_sum / total,
        "top1": 100.0 * top1_sum / total,
        "top5": 100.0 * top5_sum / total,
        "time_s": dt,
        "img_s": total / dt if dt > 0 else float("inf"),
    }

In [50]:
def train_one_epoch(model, loader, optimizer, device="cuda", amp=False, grad_clip=None):
    model.train()
    scaler = torch.cuda.amp.GradScaler(enabled=amp)
    criterion = nn.CrossEntropyLoss()

    start = time.perf_counter()
    torch.cuda.reset_peak_memory_stats(device)

    for imgs, labels in loader:
        imgs, labels = imgs.to(device, non_blocking=True), labels.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=amp):
            logits = model(imgs)
            loss = criterion(logits, labels)

        scaler.scale(loss).backward()
        if grad_clip is not None:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        scaler.step(optimizer)
        scaler.update()

    dt = time.perf_counter() - start
    max_mem = torch.cuda.max_memory_allocated(device) / (1024**2)
    return {"time_s": dt, "img_s": (len(loader.dataset) / dt), "max_mem_MB": max_mem}

In [51]:
lr = 0.003
epochs = 3
amp = False
from PIL import Image

In [55]:
opt_base = torch.optim.SGD(base.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)
opt_tri  = torch.optim.SGD(tri.parameters(),  lr=lr, momentum=0.9, weight_decay=1e-4)

In [56]:
for epoch in range(1, epochs + 1):
    # ---- Torch ----
    print('Train: ')
    train_stats_base = train_one_epoch(base, train_loader, opt_base, device=device, amp=amp, grad_clip=None)
    eval_base = evaluate(base, valid_loader, device=device, amp=amp)
    # ---- Triton ----
    print('Val: ')
    train_stats_tri = train_one_epoch(tri, train_loader, opt_tri, device=device, amp=amp, grad_clip=None)
    eval_tri = evaluate(tri, valid_loader, device=device, amp=amp)

Train: 


  scaler = torch.cuda.amp.GradScaler(enabled=amp)
  with torch.cuda.amp.autocast(enabled=amp):


Val: 


OutOfResources: out of resource: shared memory, Required: 196608, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.

In [5]:
def replace_conv2d_with_gem2col(module: nn.Module):
    for name, child in list(module.named_children()):
        if isinstance(child, nn.Conv2d) and child.groups == 1:
            new = Gem2ColConv2d(
                in_channels=child.in_channels,
                out_channels=child.out_channels,
                kernel_size=child.kernel_size,
                stride=child.stride,
                padding=child.padding,
                dilation=child.dilation,
                bias=(child.bias is not None),
            ).to(next(module.parameters(), torch.tensor(0, device=device)).device)
            with torch.no_grad():
                new.weight.copy_(child.weight)         # <-- ключевой момент
                if child.bias is not None and new.bias is not None:
                    new.bias.copy_(child.bias)
            setattr(module, name, new)
        else:
            replace_conv2d_with_gem2col(child)

In [6]:

base = resnet18(weights=None, num_classes=200).to(device)
model_base  = deepcopy(base) 
model_gemm = deepcopy(base)

In [5]:
replace_conv2d_with_gem2col(model_gemm)
model_base = model_base.eval()
model_gemm = model_gemm.eval()

# Трейн

In [9]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_base.parameters(), lr=1e-3)

In [12]:
for epoch in range(5):
    model_base.train()
    running_loss = 0.0
    for imgs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        out = model_base(imgs)
        loss = criterion(out, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * imgs.size(0)
    print(f"Train loss: {running_loss / len(train_loader.dataset):.4f}")

    # === Валидация ===
    model_base.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for imgs, labels in valid_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            preds = model_base(imgs).argmax(1)
            correct += (preds == labels).sum().item()
            total += labels.numel()
    print(f"Val acc: {correct / total * 100:.2f}%")

Epoch 1: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1563/1563 [02:16<00:00, 11.47it/s]

Train loss: 4.4698





Val acc: 14.11%


Epoch 2: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1563/1563 [02:14<00:00, 11.62it/s]

Train loss: 3.5515





Val acc: 22.65%


Epoch 3: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1563/1563 [02:20<00:00, 11.12it/s]

Train loss: 3.0252





Val acc: 31.38%


Epoch 4: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1563/1563 [02:18<00:00, 11.26it/s]

Train loss: 2.6760





Val acc: 37.37%


Epoch 5: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1563/1563 [02:15<00:00, 11.51it/s]

Train loss: 2.3983





Val acc: 40.45%


In [9]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_gemm.parameters(), lr=1e-3)

In [10]:
for epoch in range(5):
    model_gemm.train()
    running_loss = 0.0
    for imgs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        out = model_gemm(imgs)
        loss = criterion(out, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * imgs.size(0)
    print(f"Train loss: {running_loss / len(train_loader.dataset):.4f}")

    # === Валидация ===
    model_gemm.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for imgs, labels in valid_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            preds = model_gemm(imgs).argmax(1)
            correct += (preds == labels).sum().item()
            total += labels.numel()
    print(f"Val acc: {correct / total * 100:.2f}%")

Epoch 1: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3125/3125 [04:41<00:00, 11.09it/s]

Train loss: 4.5089





Val acc: 14.97%


Epoch 2: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3125/3125 [04:38<00:00, 11.23it/s]

Train loss: 3.5733





Val acc: 24.50%


Epoch 3: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3125/3125 [04:37<00:00, 11.25it/s]

Train loss: 3.0637





Val acc: 30.00%


Epoch 4: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3125/3125 [04:39<00:00, 11.20it/s]

Train loss: 2.7205





Val acc: 35.48%


Epoch 5: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3125/3125 [04:35<00:00, 11.36it/s]

Train loss: 2.4424





Val acc: 39.73%
