-----

### First the feature map (x) goes into attention, the filtered output of attention is fed into F(·), and finally the raw x is added back through the skip path.

-----

# Pattern 3 — Pre-Residual Attention (Input Conditioning / Gated Input)

This notebook walks through **Pattern-3** end to end and contrasts it with Pattern-1 / Pattern-2 in terms of **where the attention is placed**.

---

## Quick Summary (the 3 patterns side by side)

- **Pattern-1 (Inside residual):**  \(y = x + (A(F(x)) \odot F(x))\)  
- **Pattern-2 (Post-addition):**  \(z = x + F(x),\; y = A(z)\odot z\)  
- **Pattern-3 (Pre-residual):**  \(y = x + F(A(x) \odot x)\)

In Pattern-3, attention filters **the signal that enters the residual branch**; the skip path bypasses the attention and carries the unmodified x.
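The three placements differ only in where the mask multiplies. A minimal sketch to compare them side by side; `A` and `F` here are hypothetical stand-ins (a sigmoid mask from a 1x1 conv and a plain 1x1 conv), not the CBAM and ResNet modules used later in the notebook.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
C = 8

# Hypothetical stand-ins: F is a 1x1 conv, A returns a sigmoid mask in (0, 1).
F = nn.Conv2d(C, C, kernel_size=1, bias=False)
A_logits = nn.Conv2d(C, C, kernel_size=1)

def A(t: torch.Tensor) -> torch.Tensor:
    return torch.sigmoid(A_logits(t))  # mask with the same shape as t

def pattern1(x):  # inside residual: y = x + A(F(x)) * F(x)
    f = F(x)
    return x + A(f) * f

def pattern2(x):  # post-addition: z = x + F(x), y = A(z) * z
    z = x + F(x)
    return A(z) * z

def pattern3(x):  # pre-residual: y = x + F(A(x) * x)
    return x + F(A(x) * x)

x = torch.randn(2, C, 4, 4)
with torch.no_grad():
    for fn in (pattern1, pattern2, pattern3):
        print(fn.__name__, tuple(fn(x).shape))  # each prints (2, 8, 4, 4)
```

Only in Pattern-2 does the mask also touch the skip contribution; in Pattern-1 and Pattern-3 the `x +` term passes through unchanged.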

---

## ASCII Flow

```text
      ┌─────────────── skip(x) ───────────────┐
      │                                       ▼
x ────┴──► A(x) ⊙ x ──► x_tilde ──► F(.) ──► (+) ──► y
```


## 1) Residual basics (recap)

A classic residual block:

\[ y = x + F(x) \]

- `x` (skip): the reference information plus a stable gradient path  
- `F(x)` (residual): the learned contribution/correction
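One consequence of this split: if `F` outputs zeros, the block is exactly the identity and the skip path alone carries `x`. The sketch below assumes a hypothetical `TinyResidual` block whose last conv is zero-initialized (similar in spirit to the common practice of zero-initializing the last layer of each residual branch; the notebook's own blocks do not do this).

```python
import torch
import torch.nn as nn

class TinyResidual(nn.Module):
    """Hypothetical minimal residual block: y = x + F(x)."""
    def __init__(self, channels: int):
        super().__init__()
        self.F = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3, padding=1),
        )
        # Zero-init the last conv so that F(x) == 0 at initialization.
        nn.init.zeros_(self.F[-1].weight)
        nn.init.zeros_(self.F[-1].bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.F(x)

x = torch.randn(2, 16, 8, 8)
blk = TinyResidual(16)
y = blk(x)
print(torch.equal(y, x))  # True: F contributes nothing yet, the skip carries x
```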


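## 2) What changes in Pattern-3?

In Pattern-3, the signal entering the residual branch is **not x** but the attention-filtered signal:

\[ \tilde{x} = A(x) \odot x \]

Residual branch:

\[ f = F(\tilde{x}) \]

Output:

\[ y = \text{skip}(x) + f \]

Here \(\text{skip}(x)\) is the identity, or a 1x1 convolution + BatchNorm projection when the channel count or stride changes (`make_skip` in the code below).

In words: "first make x selective, then let F learn from the selected information."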
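Because the mask sits only on the residual branch, the skip path is never gated. A minimal sketch with a hypothetical worst-case mask of all zeros (the mask is passed in directly instead of being computed by an attention module, and `F` is a stand-in conv): `F` then sees only zeros, yet the gradient of \(y\) with respect to \(x\) is still exactly 1 through the skip.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
C = 4
F = nn.Conv2d(C, C, 3, padding=1)  # hypothetical stand-in for the residual branch F

def pattern3(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    x_tilde = mask * x        # x_tilde = A(x) ⊙ x, with the mask given directly
    return x + F(x_tilde)     # y = skip(x) + F(x_tilde), identity skip

x = torch.randn(1, C, 5, 5, requires_grad=True)
zero_mask = torch.zeros_like(x)   # worst case: the gate is fully closed

y = pattern3(x, zero_mask)
y.sum().backward()

# F only ever saw zeros, but dy/dx is still exactly 1 through the skip path.
print(torch.equal(x.grad, torch.ones_like(x)))  # True
```

So training does not stall, but the residual branch learns nothing useful while the mask is closed, which is the risk discussed next.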


## 3) Goal and risk

### Goal
- Suppress noise early
- Add selectivity early in lightweight backbones
- Condition F's learning on a "cleaner" input

### Critical risk
If `A(x)` learns a poor mask, the residual branch sees less information, so `F` may learn poorly or only partially.

This is why Pattern-3 is often applied in a **controlled** way (input mixing with lambda).
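How much λ-mixing softens the risk can be seen with a collapsed mask. The sketch assumes a hypothetical mask stuck at 1e-3 everywhere and compares how much of `x` (by L2 norm) reaches `F` with plain gating versus the mixed input \(\tilde{x} = (1-\lambda)x + \lambda\,A(x)\odot x\) used by the controlled variant in this notebook, with λ = 0.1.

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 64, 8, 8)

# Hypothetical "collapsed" attention mask: nearly zero everywhere.
mask = torch.full_like(x, 1e-3)

x_att = mask * x                       # plain Pattern-3 input to F
lam = 0.1
x_mix = (1 - lam) * x + lam * x_att    # controlled (lambda-mixed) input to F

def energy_ratio(t: torch.Tensor) -> float:
    # Fraction of the input's L2 norm that survives into F's input.
    return (t.norm() / x.norm()).item()

print(f"plain gating : {energy_ratio(x_att):.4f}")   # 0.0010
print(f"lambda mixing: {energy_ratio(x_mix):.4f}")   # 0.9001
```

With a uniform mask the ratio is just \(m\) for plain gating and \((1-\lambda) + \lambda m\) for mixing, so the mixed input can never drop below \(1-\lambda\) of the original signal.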


## 4) Pattern-3 with CBAM

CBAM applies two masks (shapes given per sample):

1) Channel attention: \(\alpha_c\in\mathbb{R}^{C\times1\times1}\)  
2) Spatial attention: \(\alpha_s\in\mathbb{R}^{1\times H\times W}\)

Order (the spatial mask is computed from the already channel-gated map):

\[ x' = \alpha_c(x)\odot x \]  
\[ \tilde{x} = \alpha_s(x')\odot x' \]

Then:

\[ y = x + F(\tilde{x}) \]
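Both masks rely on broadcasting: the channel mask holds one weight per channel, the spatial mask one weight per pixel. A minimal shape check, using random sigmoid tensors as hypothetical stand-ins for \(\alpha_c\) and \(\alpha_s\) with the batch dimension made explicit (the real CBAM computes \(\alpha_s\) from \(x'\); here it is random only to show the shapes).

```python
import torch

B, C, H, W = 2, 64, 16, 16
x = torch.randn(B, C, H, W)

# Hypothetical masks with the shapes from the equations above (plus batch dim B).
alpha_c = torch.sigmoid(torch.randn(B, C, 1, 1))   # one weight per channel
alpha_s = torch.sigmoid(torch.randn(B, 1, H, W))   # one weight per spatial location

x_prime = alpha_c * x        # broadcasts over H and W
x_tilde = alpha_s * x_prime  # broadcasts over C

print(x_prime.shape, x_tilde.shape)  # both torch.Size([2, 64, 16, 16])
# Channel gating scales an entire feature map by a single factor:
print(torch.allclose(x_prime[0, 5], alpha_c[0, 5, 0, 0] * x[0, 5]))  # True
```

Since both masks lie in (0, 1), every element of \(\tilde{x}\) is no larger in magnitude than the matching element of \(x\): CBAM can only attenuate, never amplify.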


## 5) Code: CBAM (PyTorch)

```python
import torch
import torch.nn as nn

class ChannelAttention(nn.Module):
    def __init__(self, channels: int, reduction: int = 16):
        super().__init__()
        hidden = max(channels // reduction, 4)
        self.avg = nn.AdaptiveAvgPool2d(1)
        self.max = nn.AdaptiveMaxPool2d(1)
        self.mlp = nn.Sequential(
            nn.Conv2d(channels, hidden, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.sigmoid(self.mlp(self.avg(x)) + self.mlp(self.max(x)))  # (B,C,1,1)

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size: int = 7):
        super().__init__()
        assert kernel_size in (3, 7)
        padding = 3 if kernel_size == 7 else 1
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        avg_map = torch.mean(x, dim=1, keepdim=True)          # (B,1,H,W)
        max_map, _ = torch.max(x, dim=1, keepdim=True)        # (B,1,H,W)
        cat = torch.cat([avg_map, max_map], dim=1)            # (B,2,H,W)
        return self.sigmoid(self.conv(cat))                   # (B,1,H,W)

class CBAM(nn.Module):
    def __init__(self, channels: int, reduction: int = 16, spatial_kernel: int = 7):
        super().__init__()
        self.ca = ChannelAttention(channels, reduction=reduction)
        self.sa = SpatialAttention(kernel_size=spatial_kernel)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.ca(x) * x
        x = self.sa(x) * x
        return x
```


## 6) Code: Pattern-3 Residual Block (Pre-Residual Attention)

Flow:

1) `x_att = CBAM(x)`  (CBAM produces its masks internally and multiplies them in, i.e. `A(x) ⊙ x`)  
2) `f = F(x_att)`  
3) `y = skip(x) + f`

\[ y = \text{skip}(x) + F(\text{CBAM}(x)) \]

As in a ResNet basic block, the code also applies a ReLU after the addition.


```python
import torch
import torch.nn as nn

def make_skip(in_ch: int, out_ch: int, stride: int):
    # 1x1 conv + BN projection when the shape changes; identity otherwise
    if stride != 1 or in_ch != out_ch:
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 1, stride=stride, bias=False),
            nn.BatchNorm2d(out_ch),
        )
    return nn.Identity()

class FxConv(nn.Module):
    # Simple F(): 3x3 -> 3x3 (ResNet basic block)
    def __init__(self, in_ch: int, out_ch: int, stride: int = 1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False)
        self.bn1   = nn.BatchNorm2d(out_ch)
        self.act   = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False)
        self.bn2   = nn.BatchNorm2d(out_ch)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        f = self.act(self.bn1(self.conv1(x)))
        f = self.bn2(self.conv2(f))
        return f

class PreResidualAttentionBlock_P3(nn.Module):
    def __init__(self, in_ch: int, out_ch: int, stride: int = 1,
                 reduction: int = 16, spatial_kernel: int = 7):
        super().__init__()
        self.attn = CBAM(in_ch, reduction=reduction, spatial_kernel=spatial_kernel)
        self.F = FxConv(in_ch, out_ch, stride=stride)
        self.skip = make_skip(in_ch, out_ch, stride)
        self.out_act = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = self.skip(x)   # skip path sees the raw x
        x_att = self.attn(x)      # A(x) ⊙ x
        f = self.F(x_att)         # residual branch sees the gated input
        y = identity + f
        return self.out_act(y)

# Shape sanity check
x = torch.randn(2, 64, 32, 32)
blk = PreResidualAttentionBlock_P3(64, 64, stride=1)
print("P3 block output:", blk(x).shape)
```

```text
P3 block output: torch.Size([2, 64, 32, 32])
```


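## 7) Controlled Pattern-3 (recommended): input mixing with lambda

A Pattern-3 variant that avoids "closing the gate completely":

- `x_att = A(x) ⊙ x`
- `x_tilde = (1-λ)·x + λ·x_att`
- `y = skip(x) + F(x_tilde)`

\[ x_{att} = A(x)\odot x \]  
\[ \tilde{x} = (1-\lambda)x + \lambda x_{att} \]  
\[ y = \text{skip}(x) + F(\tilde{x}) \]

`λ` is initialized small (e.g. 0.1) and made learnable. In the code it is stored as a logit and passed through a sigmoid, so it stays in (0, 1). With a small initial λ, `F` sees mostly the raw x at the start of training, and the block can shift toward the gated input if the attention turns out to help.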
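The logit/sigmoid parametrization and the two endpoints of the mixing formula can be checked in pure Python. The numbers below are illustrative; `logit`, `sigmoid` and `mix` are hypothetical helpers, not part of the notebook's code.

```python
import math

def logit(p: float) -> float:
    return math.log(p / (1.0 - p))

def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))

lam_init = 0.1
z = logit(lam_init)          # what the block stores as lam_logit
print(round(z, 4))           # -2.1972
print(round(sigmoid(z), 6))  # 0.1

# Even after large (hypothetical) updates to z, lambda stays inside (0, 1)
# (in floating point it only saturates for extreme logits).
for step in (-10.0, 10.0):
    lam = sigmoid(z + step)
    assert 0.0 < lam < 1.0

def mix(x: float, x_att: float, lam: float) -> float:
    # x_tilde = (1 - lam) * x + lam * x_att
    return (1.0 - lam) * x + lam * x_att

print(mix(2.0, 0.5, 0.0), mix(2.0, 0.5, 1.0))  # 2.0 0.5 (pure x vs. pure gated input)
```

At λ → 0 the block degenerates to a plain residual block on x; at λ → 1 it becomes the uncontrolled Pattern-3 block.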


```python
import torch
import torch.nn as nn

class ControlledPreResidualAttentionBlock_P3(nn.Module):
    def __init__(self, in_ch: int, out_ch: int, stride: int = 1,
                 lam_init: float = 0.1, lam_learnable: bool = True,
                 reduction: int = 16, spatial_kernel: int = 7):
        super().__init__()
        self.attn = CBAM(in_ch, reduction=reduction, spatial_kernel=spatial_kernel)
        self.F = FxConv(in_ch, out_ch, stride=stride)
        self.skip = make_skip(in_ch, out_ch, stride)
        self.out_act = nn.ReLU(inplace=True)

        # Store lambda as a logit so that sigmoid(lam_logit) always lies in (0, 1).
        lam_init = float(lam_init)
        lam_init = min(max(lam_init, 1e-4), 1 - 1e-4)  # clamp to avoid infinite logits
        lam_logit = torch.log(torch.tensor(lam_init) / (1 - torch.tensor(lam_init)))
        if lam_learnable:
            self.lam_logit = nn.Parameter(lam_logit)
        else:
            self.register_buffer("lam_logit", lam_logit)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = self.skip(x)

        x_att = self.attn(x)
        lam = torch.sigmoid(self.lam_logit)          # scalar in (0,1)
        x_tilde = (1.0 - lam) * x + lam * x_att      # controlled input

        f = self.F(x_tilde)
        y = identity + f
        return self.out_act(y)

x = torch.randn(2, 64, 32, 32)
blk = ControlledPreResidualAttentionBlock_P3(64, 64, lam_init=0.1)
print("Controlled P3:", blk(x).shape, "lambda:", torch.sigmoid(blk.lam_logit).item())
```

```text
Controlled P3: torch.Size([2, 64, 32, 32]) lambda: 0.10000000149011612
```


## 8) Integration into a plain model (no stages, straight flow)

The model below: `stem -> conv1 -> (P3 blocks) -> conv2 -> head`.

For Pattern-3, the controlled version is recommended.
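In this model the only shape change happens in `p3_2` (64 → 128 channels, stride 2), which is where `make_skip` inserts the 1x1 projection. The standard convolution output-size formula (dilation 1) shows why the residual and skip paths still agree spatially; `conv_out` is a hypothetical helper for illustration, not part of the model.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    # Output size of nn.Conv2d along one spatial dimension (dilation = 1).
    return (size + 2 * padding - kernel) // stride + 1

H = 32  # input resolution used in the demo below (stem and conv1 keep it at 32)

# Residual branch of p3_2: the first 3x3 conv carries the stride, the second keeps size.
f_h = conv_out(conv_out(H, kernel=3, stride=2, padding=1), kernel=3, stride=1, padding=1)
# Skip branch of p3_2: 1x1 conv with stride 2 and no padding.
skip_h = conv_out(H, kernel=1, stride=2, padding=0)

print(f_h, skip_h)  # 16 16
```

For any input size both paths reduce to \(\lfloor (H-1)/2 \rfloor + 1\), so odd resolutions match as well. Note that the CBAM inside `p3_2` runs on the 64-channel input before any downsampling.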


```python
import torch
import torch.nn as nn

class SimpleCNN_P3_Controlled(nn.Module):
    def __init__(self, num_classes: int = 10, lam_init: float = 0.1):
        super().__init__()

        self.stem = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
        )

        self.conv1 = nn.Sequential(
            nn.Conv2d(32, 64, 3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )

        self.p3_1 = ControlledPreResidualAttentionBlock_P3(64, 64, stride=1, lam_init=lam_init)
        self.p3_2 = ControlledPreResidualAttentionBlock_P3(64, 128, stride=2, lam_init=lam_init)

        self.conv2 = nn.Sequential(
            nn.Conv2d(128, 128, 3, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
        )

        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.stem(x)
        x = self.conv1(x)
        x = self.p3_1(x)
        x = self.p3_2(x)
        x = self.conv2(x)
        x = self.head(x)
        return x

if __name__ == "__main__":
    m = SimpleCNN_P3_Controlled(num_classes=10, lam_init=0.1)
    x = torch.randn(4, 3, 32, 32)
    y = m(x)
    print("Output:", y.shape)
    print("lambda p3_1:", torch.sigmoid(m.p3_1.lam_logit).item())
    print("lambda p3_2:", torch.sigmoid(m.p3_2.lam_logit).item())
```

```text
Output: torch.Size([4, 10])
lambda p3_1: 0.10000000149011612
lambda p3_2: 0.10000000149011612
```


## 9) Mini checklist (Pattern-3)

- [ ] Is the attention at the input? (`x -> A(x) -> x_tilde`)
- [ ] Does `F(.)` operate on x_tilde?
- [ ] Is the skip path added at the very end?
- [ ] Is the skip matched (projected) when channels or resolution change?
- [ ] Is input mixing (lambda) in place to mitigate the risk?
