# RepVGG: Multi-Branch at Training Time, a Single 3×3 at Inference (Re-parameterization)

In this notebook we will learn the **RepVGG** idea from scratch, with hands-on code.

## By the end of this notebook you will clearly know
1) RepVGG is not a "type of conv" but an **architecture/deployment idea**
2) How the training-time block (3×3 + 1×1 + identity) works
3) What BatchNorm (BN) folding is
4) The multi-branch → **single 3×3 Conv** conversion (re-parameterization)
5) How to verify the conversion (are the outputs nearly identical?)
6) A mini model built from RepVGG blocks

Note: the point of RepVGG is the "learn easily during training, run fast at inference" approach.


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
print('torch:', torch.__version__)


torch: 2.9.1+cpu


---
## 1) What Is RepVGG?

The core idea of RepVGG:
- **At training time**, use several branches inside one block.
- **At inference time**, merge these branches mathematically and collapse them into a single 3×3 conv.

### Why?
- During training, a multi-branch structure with BN generally makes optimization easier.
- At inference, the extra branches and separate BN operations hurt speed/latency: every branch is an additional operator call, and its intermediate output has to be kept in memory until the sum.

RepVGG: "Make training easy, make deployment fast."


---
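## 2) Training-time RepVGG Block Structure

A typical RepVGG block (when stride=1 and the input and output channel counts match):

```text
             ┌─ Conv3×3 + BN  ─┐
x ───────────┼─ Conv1×1 + BN  ─┼─ (+) ─ ReLU
             └─ Identity + BN ─┘
```

With stride=2, or when the channel count changes, there is no identity branch (the shapes would not match).

At inference time all of these branches are **linear** (strictly speaking, affine) operations: a conv is linear, and BN with fixed running statistics is just a per-channel scale and shift.
That is why they can be merged into **a single Conv**. The ReLU is applied after the sum, so it does not interfere with the merge.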
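The linearity argument can be checked numerically. The following is a minimal sketch with made-up shapes and random kernels (the names `W_a`, `W_b`, `b_a`, `b_b` are illustrative, not part of the notebook): summing two parallel convolutions should match one convolution whose kernel and bias are the sums.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

x = torch.randn(1, 4, 8, 8)
# Two hypothetical parallel branches with the same kernel size.
W_a, b_a = torch.randn(4, 4, 3, 3), torch.randn(4)
W_b, b_b = torch.randn(4, 4, 3, 3), torch.randn(4)

# Sum of the two branch outputs...
y_branches = F.conv2d(x, W_a, b_a, padding=1) + F.conv2d(x, W_b, b_b, padding=1)
# ...versus a single conv with summed kernels and summed biases.
y_merged = F.conv2d(x, W_a + W_b, b_a + b_b, padding=1)

max_diff = (y_branches - y_merged).abs().max().item()
print('max abs diff:', max_diff)  # expected to be at float32 rounding level
```

The two results agree only up to floating-point rounding, which is why the later verification cells report a small nonzero difference rather than exactly zero.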


---
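## 3) BatchNorm Folding (Folding BN into the Conv)

Inference form of BN, applied per channel with the running mean $\mu$ and running variance $\sigma^2$:

$y = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$

If $x = \mathrm{Conv}(z, W) + b$, the conv and the BN combine into a new conv:

$W' = W \cdot \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}$

$b' = (b - \mu) \cdot \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}} + \beta$

The scale factor is applied per output channel. This way BN disappears at inference.
When the conv has no bias ($b = 0$), as assumed by `fuse_conv_bn` in section 6, the bias reduces to $b' = \beta - \mu \cdot \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}$.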
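As a sanity check of the two formulas in the general case where the conv does have a bias, here is a small sketch. The layer sizes and the randomized BN statistics are assumptions chosen only so that the test is not trivial; the variable `fused` is a hypothetical name for the folded layer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

conv = nn.Conv2d(3, 8, 3, padding=1, bias=True)
bn = nn.BatchNorm2d(8)

# Give BN non-trivial statistics and affine parameters (illustrative values).
with torch.no_grad():
    bn.running_mean.normal_()
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.normal_()
    bn.bias.normal_()
bn.eval()  # use running statistics, as in the inference formula

with torch.no_grad():
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    fused = nn.Conv2d(3, 8, 3, padding=1, bias=True)
    # W' = W * gamma / sqrt(var + eps), broadcast over each output channel
    fused.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
    # b' = (b - mu) * gamma / sqrt(var + eps) + beta
    fused.bias.copy_((conv.bias - bn.running_mean) * scale + bn.bias)

    x = torch.randn(2, 3, 16, 16)
    max_diff = (bn(conv(x)) - fused(x)).abs().max().item()

print('max abs diff:', max_diff)  # expected to be at float32 rounding level
```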


---
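## 4) Converting a 1×1 Conv to 3×3 (Padding)

We place the 1×1 kernel at the center of a 3×3 kernel and set every other position to 0.
With padding=1 on the 3×3 conv (and padding=0 on the 1×1 conv), both produce the same output, so the 1×1 branch can be represented as a 3×3 kernel.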
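A quick way to convince yourself of this equivalence is to compare both convolutions directly. The sketch below uses `F.pad` to zero-pad the kernel (an alternative to writing into a zero tensor) and checks stride 1 and stride 2; the tensor shapes are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

x = torch.randn(1, 4, 8, 8)
W_1x1 = torch.randn(6, 4, 1, 1)

# Zero-pad one cell on each side of the last two dims: (6,4,1,1) -> (6,4,3,3).
W_3x3 = F.pad(W_1x1, [1, 1, 1, 1])

max_diff = 0.0
for stride in (1, 2):
    y_1x1 = F.conv2d(x, W_1x1, stride=stride, padding=0)
    y_3x3 = F.conv2d(x, W_3x3, stride=stride, padding=1)
    assert y_1x1.shape == y_3x3.shape
    max_diff = max(max_diff, (y_1x1 - y_3x3).abs().max().item())

print('padded kernel shape:', tuple(W_3x3.shape))
print('max abs diff:', max_diff)
```

The padding values matter: the 3×3 version needs `padding=1` so that its center tap lands on the same pixel the 1×1 kernel reads.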


---
## 5) Writing the Identity Branch as a 3×3 Conv

Identity (stride=1, cin=cout) can be expressed as a special 3×3 kernel:
- Only the diagonal channel pairs are nonzero (output channel i reads input channel i)
- Center (1,1) = 1
- Everywhere else is 0
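The three rules above can be sketched in a few lines. The channel count `C` and input shape here are arbitrary; the check confirms that convolving with this kernel (padding=1) returns the input.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

C = 5
x = torch.randn(2, C, 7, 7)

# Identity kernel: W_id[i, i, 1, 1] = 1, everything else 0.
W_id = torch.zeros(C, C, 3, 3)
idx = torch.arange(C)
W_id[idx, idx, 1, 1] = 1.0

y = F.conv2d(x, W_id, padding=1)
max_diff = (y - x).abs().max().item()
print('nonzero entries:', int(W_id.sum().item()))  # one per channel
print('max abs diff:', max_diff)
```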


---
## 6) Code: BN Folding and Kernel Merging Helpers


In [2]:
def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d):
    """Conv + BN -> (W_fused, b_fused). Assumes conv.bias is None."""
    assert conv.bias is None
    W = conv.weight
    gamma = bn.weight
    beta = bn.bias
    mean = bn.running_mean
    var = bn.running_var
    eps = bn.eps

    scale = gamma / torch.sqrt(var + eps)
    W_fused = W * scale.reshape(-1, 1, 1, 1)
    b_fused = beta - mean * scale
    return W_fused, b_fused


def pad_1x1_to_3x3(W_1x1: torch.Tensor):
    """(C_out,C_in,1,1) -> (C_out,C_in,3,3)"""
    if W_1x1.size(-1) == 3:
        return W_1x1
    assert W_1x1.size(-1) == 1
    W_3x3 = torch.zeros((W_1x1.size(0), W_1x1.size(1), 3, 3), device=W_1x1.device, dtype=W_1x1.dtype)
    W_3x3[:, :, 1:2, 1:2] = W_1x1
    return W_3x3


def fuse_identity_bn(num_channels: int, bn: nn.BatchNorm2d, device=None, dtype=None):
    """Identity + BN -> (W_fused, b_fused), expressed as a 3x3 conv kernel."""
    if device is None:
        device = bn.weight.device
    if dtype is None:
        dtype = bn.weight.dtype

    W = torch.zeros((num_channels, num_channels, 3, 3), device=device, dtype=dtype)
    # Set the center tap of every diagonal (i, i) channel pair to 1.
    idx = torch.arange(num_channels, device=device)
    W[idx, idx, 1, 1] = 1.0

    gamma = bn.weight
    beta = bn.bias
    mean = bn.running_mean
    var = bn.running_var
    eps = bn.eps

    scale = gamma / torch.sqrt(var + eps)
    W_fused = W * scale.reshape(-1, 1, 1, 1)
    b_fused = beta - mean * scale
    return W_fused, b_fused


---
## 7) RepVGGBlock (Training-time) + Conversion to Deploy


In [3]:
class RepVGGBlock(nn.Module):
    def __init__(self, cin, cout, stride=1, deploy=False):
        super().__init__()
        self.cin = cin
        self.cout = cout
        self.stride = stride
        self.deploy = deploy
        self.act = nn.ReLU(inplace=True)

        if deploy:
            self.rbr_reparam = nn.Conv2d(cin, cout, 3, stride=stride, padding=1, bias=True)
        else:
            self.rbr_3x3 = nn.Sequential(
                nn.Conv2d(cin, cout, 3, stride=stride, padding=1, bias=False),
                nn.BatchNorm2d(cout)
            )
            self.rbr_1x1 = nn.Sequential(
                nn.Conv2d(cin, cout, 1, stride=stride, padding=0, bias=False),
                nn.BatchNorm2d(cout)
            )
            if cout == cin and stride == 1:
                self.rbr_identity = nn.BatchNorm2d(cout)
            else:
                self.rbr_identity = None

    def forward(self, x):
        if self.deploy:
            return self.act(self.rbr_reparam(x))

        out = self.rbr_3x3(x) + self.rbr_1x1(x)
        if self.rbr_identity is not None:
            out = out + self.rbr_identity(x)
        return self.act(out)

    def get_equivalent_kernel_bias(self):
        conv3, bn3 = self.rbr_3x3[0], self.rbr_3x3[1]
        W3, b3 = fuse_conv_bn(conv3, bn3)

        conv1, bn1 = self.rbr_1x1[0], self.rbr_1x1[1]
        W1, b1 = fuse_conv_bn(conv1, bn1)
        W1 = pad_1x1_to_3x3(W1)

        if self.rbr_identity is not None:
            Wid, bid = fuse_identity_bn(self.cout, self.rbr_identity, device=W3.device, dtype=W3.dtype)
        else:
            Wid = torch.zeros_like(W3)
            bid = torch.zeros_like(b3)

        W_eq = W3 + W1 + Wid
        b_eq = b3 + b1 + bid
        return W_eq, b_eq

    def switch_to_deploy(self):
        if self.deploy:
            return
        W, b = self.get_equivalent_kernel_bias()
        self.rbr_reparam = nn.Conv2d(self.cin, self.cout, 3, stride=self.stride, padding=1, bias=True)
        self.rbr_reparam.weight.data = W
        self.rbr_reparam.bias.data = b

        del self.rbr_3x3
        del self.rbr_1x1
        if hasattr(self, 'rbr_identity'):
            del self.rbr_identity
        self.deploy = True


---
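## 8) Verifying the Conversion (Training vs Deploy Output)

Important: `eval()` is required to verify BN folding. In `train()` mode BN normalizes with the statistics of the current batch, not with the running statistics that the folding uses.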


In [4]:
blk = RepVGGBlock(cin=16, cout=16, stride=1, deploy=False)
blk.eval()

x = torch.randn(2, 16, 32, 32)
with torch.no_grad():
    y_train = blk(x)

blk.switch_to_deploy()
blk.eval()
with torch.no_grad():
    y_deploy = blk(x)

print('max abs diff:', (y_train - y_deploy).abs().max().item())
print('shapes:', y_train.shape, y_deploy.shape)


max abs diff: 3.337860107421875e-06
shapes: torch.Size([2, 16, 32, 32]) torch.Size([2, 16, 32, 32])
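To illustrate why `eval()` matters for this check, the sketch below folds a freshly initialized Conv + BN pair using the running statistics and compares against the unfused pair in both modes. The layer sizes are arbitrary, and this is a standalone illustration rather than part of the RepVGG block itself.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

conv = nn.Conv2d(3, 8, 3, padding=1, bias=False)
bn = nn.BatchNorm2d(8)
x = torch.randn(4, 3, 16, 16)

with torch.no_grad():
    # Fold with the running statistics (fresh BN: mean 0, var 1).
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    W_fused = conv.weight * scale.reshape(-1, 1, 1, 1)
    b_fused = bn.bias - bn.running_mean * scale
    y_fused = F.conv2d(x, W_fused, b_fused, padding=1)

    bn.eval()   # BN uses running statistics: matches the folded conv
    diff_eval = (bn(conv(x)) - y_fused).abs().max().item()

    bn.train()  # BN uses batch statistics: no longer matches
    diff_train = (bn(conv(x)) - y_fused).abs().max().item()

print('eval  mode max abs diff:', diff_eval)
print('train mode max abs diff:', diff_train)
```

The eval-mode difference should sit at rounding level, while the train-mode difference is orders of magnitude larger, because the batch mean and variance generally differ from the running values.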


---
## 9) Mini RepVGG Model


In [5]:
class MiniRepVGG(nn.Module):
    def __init__(self, in_channels=3, num_classes=10, deploy=False):
        super().__init__()
        self.deploy = deploy

        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
        )

        self.stage1 = nn.Sequential(
            RepVGGBlock(32, 64, stride=2, deploy=deploy),
            RepVGGBlock(64, 64, stride=1, deploy=deploy),
        )
        self.stage2 = nn.Sequential(
            RepVGGBlock(64, 128, stride=2, deploy=deploy),
            RepVGGBlock(128, 128, stride=1, deploy=deploy),
        )
        self.stage3 = nn.Sequential(
            RepVGGBlock(128, 256, stride=2, deploy=deploy),
            RepVGGBlock(256, 256, stride=1, deploy=deploy),
        )

        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256, num_classes)

    def forward(self, x, verbose=False):
        if verbose: print('input:', x.shape)
        x = self.stem(x)
        if verbose: print('stem :', x.shape)
        x = self.stage1(x)
        if verbose: print('s1   :', x.shape)
        x = self.stage2(x)
        if verbose: print('s2   :', x.shape)
        x = self.stage3(x)
        if verbose: print('s3   :', x.shape)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        out = self.fc(x)
        if verbose: print('out  :', out.shape)
        return out

    def switch_to_deploy(self):
        for m in self.modules():
            if isinstance(m, RepVGGBlock):
                m.switch_to_deploy()
        self.deploy = True


In [6]:
model = MiniRepVGG(in_channels=3, num_classes=10, deploy=False)
x = torch.randn(2, 3, 64, 64)

model.eval()
with torch.no_grad():
    y1 = model(x, verbose=True)

model.switch_to_deploy()
model.eval()
with torch.no_grad():
    y2 = model(x, verbose=True)

print('\nmax abs diff (model):', (y1 - y2).abs().max().item())


input: torch.Size([2, 3, 64, 64])
stem : torch.Size([2, 32, 64, 64])
s1   : torch.Size([2, 64, 32, 32])
s2   : torch.Size([2, 128, 16, 16])
s3   : torch.Size([2, 256, 8, 8])
out  : torch.Size([2, 10])
input: torch.Size([2, 3, 64, 64])
stem : torch.Size([2, 32, 64, 64])
s1   : torch.Size([2, 64, 32, 32])
s2   : torch.Size([2, 128, 16, 16])
s3   : torch.Size([2, 256, 8, 8])
out  : torch.Size([2, 10])

max abs diff (model): 2.9802322387695312e-08


---
## 10) When Does RepVGG Make Sense?

**Makes sense:** deployment latency is critical and a fast backbone is needed.

**Caveat:** the extra branches cost compute and memory during training, but at inference each block collapses to a single conv.
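The training-versus-deploy trade-off can be made concrete by counting learnable parameters per block. The helper functions below are hypothetical (not part of the notebook) and assume the block layout used in `RepVGGBlock`: bias-free convs, one BN per branch with `gamma` and `beta`, and a single biased 3×3 conv after conversion. BN running statistics are buffers and are not counted.

```python
def train_block_params(c_in, c_out, stride=1):
    """Learnable parameters of a training-time block."""
    conv3 = c_out * c_in * 9          # 3x3 conv, no bias
    conv1 = c_out * c_in              # 1x1 conv, no bias
    # The identity branch (BN only) exists when shapes match.
    n_bn = 3 if (c_in == c_out and stride == 1) else 2
    bn = n_bn * 2 * c_out             # gamma and beta per BN
    return conv3 + conv1 + bn


def deploy_block_params(c_in, c_out):
    """Learnable parameters of the fused block: one 3x3 conv with bias."""
    return c_out * c_in * 9 + c_out


train_64 = train_block_params(64, 64)
deploy_64 = deploy_block_params(64, 64)
print('train :', train_64)   # 41344
print('deploy:', deploy_64)  # 36928
```

Parameter count is only a rough proxy: the main inference benefit comes from running one operator per block instead of several, so actual latency should be measured on the target hardware.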
