# Model.ipynb — Pattern-3 (Pre-Residual Attention / Input Conditioning)

This notebook walks through Pattern-3 **from start to finish**:
- What does it do, and what is the goal?
- How is it coded? (with a CBAM example)
- Tips for stable use
- How is it integrated into a plain model (no stages)?

---

## Pattern-3 in one sentence
**x first passes through attention → the filtered representation goes into F(·) → the raw x is added back at the very end as the skip.**

\[ y = x + F(A(x)) \]

ASCII:
```text
x ──► A(.) ──► F(.) ──► (+) ──► y
│                        ▲
└──── skip (x) ──────────┘
```


## 1) What does Pattern-3 do?

The simplest flow:
1) `x_att = A(x)`  → attention **scales/filters** (masks) x
2) `f = F(x_att)`  → the residual transform (conv/bn/act, etc.)
3) `y = skip(x) + f` → the skip (identity) is added at the very end

**Difference from Pattern-1:** in Pattern-1 attention is applied to `F(x)`; in Pattern-3 attention is applied **to x itself**.
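
The ordering difference can be sketched with toy stand-ins. This is only an illustration, not the blocks built later in this notebook: `gate` (a parameter-free sigmoid gate) and the single-conv `F` are hypothetical placeholders, and Pattern-1 is assumed to be `y = x + A(F(x))` based on the description above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the residual transform F(.)
F = nn.Conv2d(8, 8, kernel_size=3, padding=1, bias=False)

def gate(t: torch.Tensor) -> torch.Tensor:
    """Toy attention A(.): scale t by a per-channel sigmoid mask of its global mean."""
    mask = torch.sigmoid(t.mean(dim=(2, 3), keepdim=True))  # (B,C,1,1), values in (0,1)
    return mask * t

def pattern1(x: torch.Tensor) -> torch.Tensor:
    # Attention acts on the output of F: y = x + A(F(x))
    return x + gate(F(x))

def pattern3(x: torch.Tensor) -> torch.Tensor:
    # Attention acts on the input of F: y = x + F(A(x))
    return x + F(gate(x))

x = torch.randn(2, 8, 16, 16)
with torch.no_grad():
    y1, y3 = pattern1(x), pattern3(x)

print(y1.shape, y3.shape)        # both orderings keep the input shape
print(torch.allclose(y1, y3))    # the values generally differ between the two
```

Both variants keep the raw `x` on the skip path; the only thing that moves is where the mask is applied relative to `F`.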

## 2) Goal + Risk
**Goal:** Make the signal entering the residual branch selective from the start (early noise suppression, better efficiency in lightweight backbones).

**Risk:** If the attention learns poorly, less information reaches the residual branch → `F` may learn something wrong or incomplete.
For this reason Pattern-3 is usually used in a **controlled** form in practice (input mixing with λ).
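
The risk can be made concrete with an extreme case. The sketch below assumes a hypothetical attention whose mask has collapsed to zero (`collapsed_attention`); it is not CBAM, and `F` is a small stand-in transform. With a zero mask, `F` receives the same input for every `x`, so the residual branch carries no information about `x`. Mixing with a small λ keeps most of the raw input flowing into `F`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in residual transform (assumption: any bias-free conv + ReLU will do here)
F = nn.Sequential(
    nn.Conv2d(8, 8, 3, padding=1, bias=False),
    nn.ReLU(),
)

def collapsed_attention(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical worst case: the learned mask is 0 everywhere."""
    return torch.zeros_like(x) * x

def mixed_input(x: torch.Tensor, lam: float) -> torch.Tensor:
    """Input mixing: (1 - lam) * x + lam * A(x)."""
    return (1.0 - lam) * x + lam * collapsed_attention(x)

xa = torch.randn(1, 8, 16, 16)
xb = torch.randn(1, 8, 16, 16)

with torch.no_grad():
    # Uncontrolled: F sees zeros for every input, so its output ignores x.
    fa, fb = F(collapsed_attention(xa)), F(collapsed_attention(xb))
    same_without_mixing = torch.allclose(fa, fb)

    # Controlled with lam = 0.1: F still receives 90% of the raw input.
    ga, gb = F(mixed_input(xa, 0.1)), F(mixed_input(xb, 0.1))
    same_with_mixing = torch.allclose(ga, gb)

print(same_without_mixing, same_with_mixing)
```

A real mask will rarely collapse this completely; the point is only that λ bounds how much of the input the attention can remove.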


## 3) Pattern-3 with CBAM (A(x))

CBAM typically produces:
- a channel mask of shape (B,C,1,1)
- a spatial mask of shape (B,1,H,W)

In practice the module multiplies x by these masks, so its output `A(x)` is a scaled/filtered copy of x (mask ⊙ x).
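
How the two mask shapes combine with a feature map can be seen with plain broadcasting. The masks below are random placeholders (an assumption made to keep the sketch short); a real CBAM computes them from `x`.

```python
import torch

torch.manual_seed(0)

B, C, H, W = 2, 4, 8, 8
x = torch.randn(B, C, H, W)

# Placeholder masks with values in [0, 1); CBAM would derive these from x.
channel_mask = torch.rand(B, C, 1, 1)   # one weight per channel
spatial_mask = torch.rand(B, 1, H, W)   # one weight per pixel location

x_c = channel_mask * x       # broadcast over H and W   -> (B,C,H,W)
x_cs = spatial_mask * x_c    # broadcast over channels  -> (B,C,H,W)

print(x_c.shape, x_cs.shape)

# Masks in [0, 1) can only shrink magnitudes, never amplify them.
never_amplifies = bool((x_cs.abs() <= x.abs()).all())
print(never_amplifies)
```

Because sigmoid outputs also lie between 0 and 1, the same "shrink only" property holds for the CBAM masks, which is why the text describes attention as filtering x.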

Below, CBAM is written from scratch.


In [None]:
import torch
import torch.nn as nn

class ChannelAttention(nn.Module):
    """Channel mask from avg- and max-pooled descriptors passed through a shared MLP."""
    def __init__(self, channels: int, reduction: int = 16):
        super().__init__()
        hidden = max(channels // reduction, 4)
        self.avg = nn.AdaptiveAvgPool2d(1)
        self.max = nn.AdaptiveMaxPool2d(1)
        self.mlp = nn.Sequential(
            nn.Conv2d(channels, hidden, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.sigmoid(self.mlp(self.avg(x)) + self.mlp(self.max(x)))  # (B,C,1,1)

class SpatialAttention(nn.Module):
    """Spatial mask from the channel-wise mean and max maps passed through one conv."""
    def __init__(self, kernel_size: int = 7):
        super().__init__()
        assert kernel_size in (3, 7)
        padding = 3 if kernel_size == 7 else 1
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        avg_map = torch.mean(x, dim=1, keepdim=True)          # (B,1,H,W)
        max_map, _ = torch.max(x, dim=1, keepdim=True)        # (B,1,H,W)
        cat = torch.cat([avg_map, max_map], dim=1)            # (B,2,H,W)
        return self.sigmoid(self.conv(cat))                   # (B,1,H,W)

class CBAM(nn.Module):
    """Applies the channel mask, then the spatial mask, and returns the scaled x."""
    def __init__(self, channels: int, reduction: int = 16, spatial_kernel: int = 7):
        super().__init__()
        self.ca = ChannelAttention(channels, reduction=reduction)
        self.sa = SpatialAttention(kernel_size=spatial_kernel)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.ca(x) * x
        x = self.sa(x) * x
        return x

## 4) Residual transform F(·) and skip matching

- `F(·)`: a BasicBlock-style 3×3→3×3 residual transform
- Skip path: when the stride or channel count changes, it is matched with a 1×1 conv + BN
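
Why the skip needs matching can be shown with shapes alone. In this sketch `branch` is a single strided conv standing in for `F` (an assumption for brevity), and `proj` is a 1×1 conv + BN projection of the kind described above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

x = torch.randn(2, 64, 32, 32)

# Stand-in residual branch that halves the resolution and doubles the channels.
branch = nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False)
f = branch(x)                                   # (2,128,16,16)

# An identity skip no longer matches f, so x + f is not a valid sum.
try:
    _ = x + f
    identity_ok = True
except RuntimeError:
    identity_ok = False

# A 1x1 projection with the same stride brings x to the shape of f.
proj = nn.Sequential(
    nn.Conv2d(64, 128, 1, stride=2, bias=False),
    nn.BatchNorm2d(128),
)
y = proj(x) + f

print(identity_ok, f.shape, y.shape)
```

When neither the stride nor the channel count changes, the projection is unnecessary and a plain identity keeps the skip path free of extra parameters.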


In [None]:
class FxConv(nn.Module):
    """Residual transform F: conv3x3-BN-ReLU, then conv3x3-BN (no final activation)."""
    def __init__(self, in_ch: int, out_ch: int, stride: int = 1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False)
        self.bn1   = nn.BatchNorm2d(out_ch)
        self.act   = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False)
        self.bn2   = nn.BatchNorm2d(out_ch)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        f = self.act(self.bn1(self.conv1(x)))
        f = self.bn2(self.conv2(f))
        return f

def make_skip(in_ch: int, out_ch: int, stride: int):
    """Return a 1x1 conv + BN projection when shapes change, else nn.Identity."""
    if stride != 1 or in_ch != out_ch:
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 1, stride=stride, bias=False),
            nn.BatchNorm2d(out_ch),
        )
    return nn.Identity()

## 5) Pattern-3 Block (Basic)

Flow:
1) `x_att = CBAM(x)`
2) `f = F(x_att)`
3) `y = skip(x) + f`

\[ y = \text{skip}(x) + F(\text{CBAM}(x)) \]


In [3]:
class Pattern3_PreResidualCBAM(nn.Module):
    """Basic Pattern-3 block: y = ReLU(skip(x) + F(CBAM(x)))."""
    def __init__(self, in_ch: int, out_ch: int, stride: int = 1,
                 reduction: int = 16, spatial_kernel: int = 7):
        super().__init__()
        self.attn = CBAM(in_ch, reduction=reduction, spatial_kernel=spatial_kernel)
        self.F = FxConv(in_ch, out_ch, stride=stride)
        self.skip = make_skip(in_ch, out_ch, stride)
        self.out_act = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = self.skip(x)
        x_att = self.attn(x)
        f = self.F(x_att)
        y = identity + f
        return self.out_act(y)

# quick test
x = torch.randn(2, 64, 32, 32)
blk = Pattern3_PreResidualCBAM(64, 64)
print('P3 basic:', blk(x).shape)


P3 basic: torch.Size([2, 64, 32, 32])


## 6) Tips for stable use (the most important part)

The risk in Pattern-3: attention can cut too much of the information going into the residual branch.
The most practical fix: **input mixing with λ**.

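\[ x_{att}=A(x)=M(x)\odot x \]
\[ \tilde{x}=(1-\lambda)x+\lambda x_{att} \]
\[ y=\text{skip}(x)+F(\tilde{x}) \]

Here `M(x)` denotes the attention mask (for CBAM, the channel and spatial masks applied in sequence), so `x_att` is the output of the attention module, as in the earlier formulas.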

Recommendations:
- Start `λ` small: 0.05–0.2
- Make `λ` learnable (kept in the 0–1 range with a sigmoid)
- Without a large dataset, a warm-up works well
- Uncontrolled Pattern-3 in very early layers can be too aggressive
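
The notebook does not spell out what the warm-up looks like. One possible reading, sketched here as an assumption, is to ramp a fixed (non-learned) λ linearly from 0 to its target over the first training steps. `lambda_warmup` and its parameters are hypothetical names, not part of the blocks in this notebook.

```python
def lambda_warmup(step: int, warmup_steps: int, lam_target: float = 0.1) -> float:
    """Linearly ramp lambda from 0 to lam_target over warmup_steps, then hold it."""
    if warmup_steps <= 0:
        return lam_target
    progress = min(max(step / warmup_steps, 0.0), 1.0)
    return lam_target * progress

# Print the schedule at a few steps of a 1000-step warm-up.
for step in (0, 250, 500, 1000, 2000):
    print(step, lambda_warmup(step, warmup_steps=1000))
```

With λ at 0 the block starts as a plain residual block, and attention is phased in only after `F` has seen some unfiltered input. Wiring this into a block would mean converting the value to a logit before writing it into `lam_logit`; that step is left out of the sketch.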


In [None]:
class Pattern3_PreResidualCBAM_Controlled(nn.Module):
    """Controlled Pattern-3: F sees (1 - lam) * x + lam * CBAM(x), lam = sigmoid(lam_logit)."""
    def __init__(self, in_ch: int, out_ch: int, stride: int = 1,
                 lam_init: float = 0.1, lam_learnable: bool = True,
                 reduction: int = 16, spatial_kernel: int = 7):
        super().__init__()
        self.attn = CBAM(in_ch, reduction=reduction, spatial_kernel=spatial_kernel)
        self.F = FxConv(in_ch, out_ch, stride=stride)
        self.skip = make_skip(in_ch, out_ch, stride)
        self.out_act = nn.ReLU(inplace=True)

        lam_init = float(lam_init)
        lam_init = min(max(lam_init, 1e-4), 1 - 1e-4)
        lam_logit = torch.log(torch.tensor(lam_init) / (1 - torch.tensor(lam_init)))
        if lam_learnable:
            self.lam_logit = nn.Parameter(lam_logit)
        else:
            self.register_buffer('lam_logit', lam_logit)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = self.skip(x)
        x_att = self.attn(x)

        lam = torch.sigmoid(self.lam_logit)
        x_tilde = (1.0 - lam) * x + lam * x_att

        f = self.F(x_tilde)
        y = identity + f
        return self.out_act(y)

x = torch.randn(2, 64, 32, 32)
blk = Pattern3_PreResidualCBAM_Controlled(64, 64, lam_init=0.1)
print('P3 controlled:', blk(x).shape, 'lambda=', float(torch.sigmoid(blk.lam_logit)))

P3 controlled: torch.Size([2, 64, 32, 32]) lambda= 0.10000000149011612
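
The λ parameterization can be checked by hand with the standard library: the block stores `logit(lam_init)` and recovers λ with a sigmoid. The sketch below also rounds 0.1 to float32, on the assumption that the trailing digits in the printed λ come from the tensor's float32 dtype rather than from the block itself.

```python
import math
import struct

def logit(p: float) -> float:
    """Inverse of the sigmoid: log(p / (1 - p))."""
    return math.log(p / (1.0 - p))

def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))

lam_init = 0.1
z = logit(lam_init)      # value stored in lam_logit, about -2.197
lam = sigmoid(z)         # recovered lambda, about 0.1

# Round 0.1 to float32 and back to mimic the dtype of a default torch tensor.
lam_f32 = struct.unpack("f", struct.pack("f", lam_init))[0]

print(round(z, 3), round(lam, 6), lam_f32)
```

Storing the logit means the optimizer can move the parameter freely while λ itself always stays strictly between 0 and 1.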


## 7) Integrating into a plain model (no stages)

Flat architecture:
- stem (conv+bn+relu)
- conv1
- p3_1 (same resolution)
- p3_2 (downsample + more channels)
- conv2
- head (GAP + FC)
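
The spatial sizes through this layout can be traced with the usual convolution output-size formula. The sketch assumes a 32×32 input and uses the kernel sizes, strides, and channel widths of the model defined in this section; `conv_out` and `p3_out` are hypothetical helpers written for this trace.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial size after a conv: floor((size + 2*padding - kernel) / stride) + 1."""
    return (size + 2 * padding - kernel) // stride + 1

def p3_out(size: int, stride: int) -> int:
    """Spatial size after a Pattern-3 block; checks that F and the skip agree."""
    main = conv_out(conv_out(size, 3, stride, 1), 3, 1, 1)  # the two 3x3 convs in F
    skip = conv_out(size, 1, stride, 0)                     # 1x1 projection (or identity)
    assert main == skip, "residual branch and skip must match before the add"
    return main

# (name, output channels, function mapping input size -> output size)
layers = [
    ("stem", 32, lambda s: conv_out(s, 3, 1, 1)),
    ("conv1", 64, lambda s: conv_out(s, 3, 1, 1)),
    ("p3_1", 64, lambda s: p3_out(s, stride=1)),
    ("p3_2", 128, lambda s: p3_out(s, stride=2)),
    ("conv2", 128, lambda s: conv_out(s, 3, 1, 1)),
    ("gap", 128, lambda s: 1),  # global average pooling collapses HxW to 1x1
]

size = 32  # assumed input resolution
trace = [("input", 3, size)]
for name, channels, fn in layers:
    size = fn(size)
    trace.append((name, channels, size))

for name, channels, s in trace:
    print(f"{name:<6} C={channels:<4} HxW={s}x{s}")
```

Only `p3_2` changes the resolution, and it is also the only block where the skip path needs the 1×1 projection.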


In [5]:
class SimpleCNN_With_Pattern3(nn.Module):
    """Flat CNN: stem, conv1, two controlled Pattern-3 blocks, conv2, GAP + FC head."""
    def __init__(self, num_classes: int = 10, lam_init: float = 0.1):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
        )
        self.conv1 = nn.Sequential(
            nn.Conv2d(32, 64, 3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )

        self.p3_1 = Pattern3_PreResidualCBAM_Controlled(64, 64, stride=1, lam_init=lam_init)
        self.p3_2 = Pattern3_PreResidualCBAM_Controlled(64, 128, stride=2, lam_init=lam_init)

        self.conv2 = nn.Sequential(
            nn.Conv2d(128, 128, 3, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
        )

        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.stem(x)
        x = self.conv1(x)
        x = self.p3_1(x)
        x = self.p3_2(x)
        x = self.conv2(x)
        x = self.head(x)
        return x

m = SimpleCNN_With_Pattern3(num_classes=10, lam_init=0.1)
x = torch.randn(4, 3, 32, 32)
y = m(x)
print('model out:', y.shape)
print('lambda p3_1:', float(torch.sigmoid(m.p3_1.lam_logit)))
print('lambda p3_2:', float(torch.sigmoid(m.p3_2.lam_logit)))


model out: torch.Size([4, 10])
lambda p3_1: 0.10000000149011612
lambda p3_2: 0.10000000149011612


## 8) Final mini checklist

- [ ] Is attention applied **to x**? (Pattern-3 = yes)
- [ ] Does `F` receive the post-attention input?
- [ ] Is the skip matched? (when stride/channels change)
- [ ] Did you use the controlled (λ) variant for stability?
- [ ] Is Pattern-3 used aggressively in very early layers? (scale it back if not needed)
