# CBAM Channel + Coordinate Attention: Moving from Step 1 to Step 2

## Table of Contents
1. Purpose and scope
2. Step 1: Baseline architecture
3. Step 2 goals
4. Step 2 changes (module by module)
5. Checklist
6. Step 1 code
7. Step 2 code
8. Quick test
9. Notes on reading the outputs


## 1) Purpose and scope

This notebook takes the CBAM Channel + Coordinate Attention block from **Step 1** and explains the additions you made in **Step 2** to make it more stable in YOLO training, **section by section** and following the **step-by-step logic of the code**.

Rule: the goal here is to keep the block's aggressiveness under control and to have it pull itself back to the safe side when it starts to "over-suppress".


## 2) Step 1: Baseline architecture

The Step 1 chain:
- CA: avg/max pooling → shared MLP → fusion (plain sum or softmax router) → `ca_map = gate(z / T)` → `y = x * ca_map`
- Coord: H/W profiles → dw + dilated dw → channel mixer → shared bottleneck → attn_h/attn_w → scale → `out = x * scale`
- Combination: `out = x + alpha*(y - x)` (when `residual=True`), where `y = coord(ca(x))`

This version works, but in YOLO it can, in some cases, over-suppress the features (visible as a drop in the output std).
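Why can this chain over-suppress? With `y = coord(ca(x))`, each stage multiplies `x` elementwise, so (ignoring how the gates vary over space) the residual output is `x * (1 - alpha + alpha * g)`, where `g = ca_map * scale` is the combined gate. The plain-Python sketch below uses made-up gate values, not numbers measured from a trained model; `residual_multiplier` and `std_ratio` are hypothetical helper names.

```python
import random
import statistics


def residual_multiplier(alpha: float, gate: float) -> float:
    """Per-element factor of out = x + alpha * (x * gate - x) for a constant gate."""
    return 1.0 - alpha + alpha * gate


def std_ratio(alpha: float, gate: float, n: int = 10_000, seed: int = 0) -> float:
    """Empirical std(out) / std(x) for Gaussian x and a constant combined gate."""
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0) for _ in range(n)]
    out = [xi + alpha * (xi * gate - xi) for xi in x]
    return statistics.pstdev(out) / statistics.pstdev(x)


# Hypothetical gate values; alpha = 0.75 is the Step 1 default alpha_init
ca_map, coord_scale, alpha = 0.5, 0.8, 0.75
g = ca_map * coord_scale                         # combined gate, about 0.4
print(round(residual_multiplier(alpha, g), 3))   # 0.55
print(round(std_ratio(alpha, g), 3))             # 0.55
```

Two moderate gates (0.5 and 0.8) already combine into 0.4, and even with the residual path the output keeps only about 55% of the input std in this toy setting. Step 2 addresses both sides: it softens the individual gates and lowers the effective `alpha` when the measured ratio drops.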


## 3) Step 2 goals

In Step 2 the goals are to:
- Soften the CA multiplication (`beta_ca`)
- Keep the learnable T from running away (T clamp)
- Keep the router softmax from locking in too early (`Tr`)
- Keep the Coord scale from running away (scale clamp)
- Keep the Coord heads from starting aggressively (small-std init + bias=0)
- Measure suppression during training and, when needed, automatically soften the residual mix (EMA + `alpha_eff`)
- Run the monitor computations only in training mode and only when `monitor=True`


## 4) Step 2 changes (module by module)

### 4A) ChannelAttentionFusionT (CA)
**New parameters:**
- `t_min, t_max`: safety bounds on the temperature (applied in `get_T`, for both learnable and fixed T)
- `router_temperature` (`Tr`): softens the router softmax
- `beta_ca`: applies CA as a softened scale instead of a raw product
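A scalar sketch of the Step 2 temperature path: `t_raw` is initialized with `softplus_inverse(temperature)`, read back as `softplus(t_raw) + eps`, and clamped to `[t_min, t_max]` in `get_T`. The `math`-based functions below are stand-ins for the torch ops, not the notebook's code. One thing worth checking during training: `clamp` passes zero gradient outside its bounds, so if a learnable T is pushed past `t_min` or `t_max`, `t_raw` can stop receiving gradient from the loss and stay stuck there.

```python
import math


def softplus(v: float) -> float:
    # Numerically stable log(1 + exp(v))
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))


def softplus_inverse(y: float, eps: float = 1e-6) -> float:
    # Same formula as the notebook helper: log(max(exp(y) - 1, eps))
    return math.log(max(math.exp(y) - 1.0, eps))


def get_T(t_raw: float, t_min: float = 0.5, t_max: float = 3.0, eps: float = 1e-6) -> float:
    # Step 2 read-out: softplus + eps, then clamp to [t_min, t_max]
    T = softplus(t_raw) + eps
    return min(max(T, t_min), t_max)


t_raw = softplus_inverse(0.9)   # initialize from temperature=0.9
print(get_T(t_raw))             # ~0.9 (inside the bounds)
print(get_T(t_raw + 10.0))      # 3.0: a runaway t_raw is capped at t_max
print(get_T(-10.0))             # 0.5: a collapsing t_raw is floored at t_min
```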

**What changed?**
- Step 1: `y = x * ca`
- Step 2: `scale_ca = 1 + beta_ca*(ca-1)` and `y = x * scale_ca`

On the router side, the logits are divided by `Tr` before the softmax: `softmax(logits / Tr)`.
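Two quick numerical illustrations in plain Python, with made-up inputs. With a sigmoid gate, `ca` lies in (0, 1), so `scale_ca = 1 + beta_ca*(ca - 1)` lies in (1 - beta_ca, 1): with `beta_ca=0.35` a channel can be attenuated to about 0.65 at most, instead of all the way down to 0. Dividing the router logits by `Tr > 1` pulls the two fusion weights toward 0.5/0.5. Since the last router layer is zero-initialized in both steps, the weights start at exactly 0.5/0.5, and `Tr` only matters once the logits move away from zero. `soft_ca_scale` and `softmax2` are hypothetical helper names.

```python
import math
from typing import Tuple


def soft_ca_scale(ca: float, beta_ca: float = 0.35) -> float:
    # Step 2 form: scale_ca = 1 + beta_ca * (ca - 1); beta_ca = 1 gives back Step 1 (y = x * ca)
    return 1.0 + beta_ca * (ca - 1.0)


def softmax2(l0: float, l1: float, Tr: float = 1.0) -> Tuple[float, float]:
    # Two-way softmax over router logits divided by the router temperature Tr
    a, b = l0 / Tr, l1 / Tr
    m = max(a, b)                      # subtract the max for numerical stability
    e0, e1 = math.exp(a - m), math.exp(b - m)
    return e0 / (e0 + e1), e1 / (e0 + e1)


print(round(soft_ca_scale(0.0), 4))               # 0.65: strongest possible attenuation
print(soft_ca_scale(1.0))                         # 1.0: the channel passes unchanged
print(round(soft_ca_scale(0.2, beta_ca=1.0), 4))  # 0.2: same as the Step 1 product

w_sharp = softmax2(2.0, 0.0, Tr=1.0)   # roughly (0.881, 0.119)
w_soft = softmax2(2.0, 0.0, Tr=1.5)    # roughly (0.791, 0.209): closer to 0.5/0.5
print([round(w, 3) for w in w_sharp], [round(w, 3) for w in w_soft])
```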

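### 4B) CoordinateAttPlus (Coord)
**New parameters:**
- `scale_min, scale_max`: clamp bounds for `scale`
- `head_init_std`: std of the attention-head weight init (softens the start)

**What changed?**
- `scale` is clamped to `[scale_min, scale_max]` after the `beta` blend.
- The head weights are initialized from a normal distribution with a small std, and the biases are set to 0.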
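With the small-std init and zero bias, the head outputs start close to 0, and `hardsigmoid(0) = 0.5`, so both attention maps start near 0.5 rather than near 1: at init the block is a roughly uniform attenuation, not an identity. The scalar sketch below evaluates the Step 2 formula (defaults `init_alpha=0.7`, `beta=0.35`) at that starting point and at the extremes of `attn` in [0, 1]; `coord_scale` is a hypothetical helper and the optional spatial gate is left out. Under these assumptions the pre-clamp scale always lies in `[1 - beta, 1]` for `0 <= beta <= 1`, i.e. `[0.65, 1]` by default, so the default clamp `[0.6, 1.6]` can only bind when `beta > 0.4` (lower bound) or `beta < -0.6` (upper bound). It is worth confirming this on real features with `last_mask_stats()`.

```python
from typing import Tuple


def hardsigmoid(v: float) -> float:
    # Same piecewise-linear curve as F.hardsigmoid: relu6(v + 3) / 6
    return min(max(v + 3.0, 0.0), 6.0) / 6.0


def coord_scale(attn_h: float, attn_w: float, alpha: float = 0.7, beta: float = 0.35,
                scale_min: float = 0.6, scale_max: float = 1.6) -> Tuple[float, float]:
    """Step 2 Coord scale at one position; returns (pre-clamp, clamped)."""
    scale_h = (1.0 - alpha) + alpha * attn_h
    scale_w = (1.0 - alpha) + alpha * attn_w
    raw = 1.0 + beta * (scale_h * scale_w - 1.0)
    return raw, min(max(raw, scale_min), scale_max)


a0 = hardsigmoid(0.0)                                  # head output ~0 at init -> attn ~0.5
print(coord_scale(a0, a0))                             # raw ~0.798, clamp inactive
print(coord_scale(0.0, 0.0))                           # raw ~0.6815, still above scale_min
print(coord_scale(1.0, 1.0))                           # (1.0, 1.0): upper end, far below scale_max
print(coord_scale(0.0, 0.0, alpha=0.999, beta=0.6))    # raw ~0.4 -> clamped to 0.6
```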

### 4C) CBAMChannelPlusCoord (Combination)
**New mechanism: monitor + rescue**
- `r_out = std(out_tmp) / std(x)` is measured.
- An EMA of it is kept in the `r_ema` buffer.
- `alpha_eff` is computed according to `rescue_mode`.

**rescue_mode:**
- `ratio_floor`: `ratio = clamp(r_ema / r_min, 0, 1)` is floored at `min_rescue_ratio`, then `alpha_eff = alpha * ratio`
- `alpha_floor`: `alpha_eff = alpha * ratio` is clamped directly to `[alpha_eff_min, 1]` (more aggressive)
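The monitor + rescue step reduced to scalars, in plain Python. The EMA update and `ratio = clamp(r_ema / r_min, 0, 1)` follow `_update_r_ema` and `_compute_alpha_eff` in the Step 2 code, with the same defaults (`ema_momentum=0.95`, `r_min=0.45`, `min_rescue_ratio=0.2`, `alpha_eff_min=0.2`); the `alpha_floor` branch is an assumption based on the class docstring. `update_ema` and `alpha_eff` are hypothetical helper names, and the `r_out` values fed in are made up.

```python
def update_ema(r_ema: float, r_out: float, momentum: float = 0.95) -> float:
    # Same update as _update_r_ema: r_ema <- m * r_ema + (1 - m) * r_out
    return momentum * r_ema + (1.0 - momentum) * r_out


def alpha_eff(alpha: float, r_ema: float, r_min: float = 0.45, mode: str = "ratio_floor",
              min_rescue_ratio: float = 0.2, alpha_eff_min: float = 0.2) -> float:
    # ratio = clamp(r_ema / r_min, 0, 1): 1 means "no suppression detected"
    ratio = min(max(r_ema / max(r_min, 1e-12), 0.0), 1.0)
    if mode == "ratio_floor":
        return alpha * max(ratio, min_rescue_ratio)
    # "alpha_floor", assumed from the class docstring: clamp alpha * ratio to [alpha_eff_min, 1]
    return min(max(alpha * ratio, alpha_eff_min), 1.0)


# Hypothetical run: the block keeps returning only 30% of the input std
r = 1.0                                   # the r_ema buffer starts at 1.0
for _ in range(50):
    r = update_ema(r, 0.30)
print(round(r, 3))                        # ~0.354, drifting toward 0.30
print(alpha_eff(0.75, 1.0))               # 0.75: no suppression, alpha unchanged
print(round(alpha_eff(0.75, r), 3))       # ~0.59: residual mix softened
print(round(alpha_eff(0.75, 0.0), 3))     # ratio_floor: 0.75 * 0.2 = 0.15
print(alpha_eff(0.75, 0.0, mode="alpha_floor"))  # alpha_floor: 0.2
```

Lowering `alpha_eff` moves `out = x + alpha_eff*(y - x)` back toward `x`, which is how the rescue is meant to recover the output std. Because the decision is driven by `r_ema`, it reacts to a sustained drop over many steps rather than to a single batch.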


## 5) Checklist

- [ ] CA: `clamp(t_min, t_max)` inside `get_T(x)`
- [ ] Router: `softmax(logits / Tr)`
- [ ] CA: `y = x * (1 + beta_ca*(ca-1))`
- [ ] Coord: `scale.clamp(scale_min, scale_max)`
- [ ] Coord head init: small std + bias=0
- [ ] Monitor: only when `training and monitor=True`
- [ ] EMA: `r_ema` buffer + update
- [ ] Rescue: `alpha_eff` computation and `rescue_mode`


## 6) Step 1 code (baseline)
This cell is the reference version of Step 1.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F


def softplus_inverse(y: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    return torch.log(torch.clamp(torch.exp(y) - 1.0, min=eps))


def _get_gate(gate: str):
    g = gate.lower()
    if g == "sigmoid":
        return torch.sigmoid
    if g == "hardsigmoid":
        return F.hardsigmoid
    raise ValueError("gate 'sigmoid' veya 'hardsigmoid' olmalı.")


def _get_act(act: str):
    a = act.lower()
    if a == "relu":
        return nn.ReLU(inplace=True)
    if a == "silu":
        return nn.SiLU(inplace=True)
    raise ValueError("act 'relu' veya 'silu' olmalı.")


class ChannelAttentionFusionT(nn.Module):
    def __init__(
        self,
        channels: int,
        reduction: int = 16,
        min_hidden: int = 4,
        fusion: str = "softmax",
        gate: str = "sigmoid",
        temperature: float = 0.9,
        learnable_temperature: bool = False,
        eps: float = 1e-6,
        act: str = "relu",
        bias: bool = True,
        fusion_router_hidden: int = 16,
        return_fusion_weights: bool = False,
    ):
        super().__init__()

        if channels < 1:
            raise ValueError("channels >= 1 olmalı.")
        if reduction < 1:
            raise ValueError("reduction >= 1 olmalı.")
        if fusion not in ("sum", "softmax"):
            raise ValueError("fusion 'sum' veya 'softmax' olmalı.")
        if temperature <= 0:
            raise ValueError("temperature pozitif olmalı.")
        if fusion == "softmax" and fusion_router_hidden < 1:
            raise ValueError("fusion_router_hidden >= 1 olmalı.")

        self.eps = float(eps)
        self.fusion = fusion
        self.return_fusion_weights = bool(return_fusion_weights)
        self.gate_fn = _get_gate(gate)

        hidden = max(int(min_hidden), int(channels) // int(reduction))
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Conv2d(channels, hidden, kernel_size=1, bias=bias)
        self.act = _get_act(act)
        self.fc2 = nn.Conv2d(hidden, channels, kernel_size=1, bias=bias)

        if self.fusion == "softmax":
            self.fusion_router = nn.Sequential(
                nn.Conv2d(2 * channels, fusion_router_hidden, kernel_size=1, bias=True),
                nn.ReLU(inplace=True),
                nn.Conv2d(fusion_router_hidden, 2, kernel_size=1, bias=True),
            )
            last = self.fusion_router[-1]
            nn.init.zeros_(last.weight)
            nn.init.zeros_(last.bias)
        else:
            self.fusion_router = None

        self.learnable_temperature = bool(learnable_temperature)
        if self.learnable_temperature:
            t0 = torch.tensor(float(temperature))
            t_inv = softplus_inverse(t0, eps=self.eps)
            self.t_raw = nn.Parameter(t_inv)
        else:
            self.register_buffer("T", torch.tensor(float(temperature)))

    def get_T(self) -> torch.Tensor:
        if self.learnable_temperature:
            return F.softplus(self.t_raw) + self.eps
        return self.T

    def mlp(self, s: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.act(self.fc1(s)))

    def forward(self, x: torch.Tensor):
        avg_s = self.avg_pool(x)
        max_s = self.max_pool(x)

        a = self.mlp(avg_s)
        m = self.mlp(max_s)

        fusion_w = None
        if self.fusion == "sum":
            z = a + m
        else:
            s_cat = torch.cat([avg_s, max_s], dim=1)
            logits = self.fusion_router(s_cat).flatten(1)  # (B, 2)
            fusion_w = torch.softmax(logits, dim=1)
            w0 = fusion_w[:, 0].view(-1, 1, 1, 1)
            w1 = fusion_w[:, 1].view(-1, 1, 1, 1)
            z = w0 * a + w1 * m

        T = self.get_T().to(device=x.device, dtype=x.dtype)
        ca = self.gate_fn(z / T)
        y = x * ca

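        # Always return a 3-tuple when fusion weights are requested: fusion_w is None for
        # fusion="sum", and CBAMChannelPlusCoord(return_maps=True) unpacks three values.
        if self.return_fusion_weights:
            return y, ca, fusion_w
        return y, ca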


class HSwish(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * F.relu6(x + 3.0, inplace=True) / 6.0


def make_norm(norm: str, ch: int):
    norm = norm.lower()
    if norm == "bn":
        return nn.BatchNorm2d(ch)
    if norm == "gn":
        g = min(32, ch)
        while ch % g != 0 and g > 2:
            g //= 2
        if ch % g != 0:
            g = 2 if (ch % 2 == 0) else 1
        return nn.GroupNorm(g, ch)
    if norm == "none":
        return nn.Identity()
    raise ValueError("norm 'none', 'bn', 'gn' dışında olamaz.")


class CoordinateAttPlus(nn.Module):
    def __init__(
        self,
        in_channels: int,
        reduction: int = 32,
        min_mid_channels: int = 8,
        act: str = "hswish",
        init_alpha: float = 0.7,
        learnable_alpha: bool = True,
        beta: float = 0.35,
        dilation: int = 2,
        norm: str = "gn",
        use_spatial_gate: bool = False,
        spatial_gate_beta: float = 0.35,
    ):
        super().__init__()

        if in_channels < 1:
            raise ValueError("in_channels >= 1 olmalı.")
        if reduction < 1:
            raise ValueError("reduction >= 1 olmalı.")
        if dilation < 1:
            raise ValueError("dilation >= 1 olmalı.")

        mid_floor = max(8, min(32, int(in_channels) // 4))
        mid = max(int(min_mid_channels), int(in_channels) // int(reduction))
        mid = max(mid, int(mid_floor))

        act_l = act.lower()
        if act_l == "hswish":
            self.act = HSwish()
        elif act_l == "relu":
            self.act = nn.ReLU(inplace=True)
        elif act_l == "silu":
            self.act = nn.SiLU(inplace=True)
        else:
            raise ValueError("act 'hswish', 'relu', 'silu' olmalı.")

        self.shared_bottleneck_proj = nn.Conv2d(in_channels, mid, 1, bias=False)
        self.shared_bottleneck_norm = make_norm(norm, mid)
        self.shared_bottleneck_refine = nn.Conv2d(mid, mid, 1, bias=False)
        self.shared_bottleneck_refine_norm = make_norm(norm, mid)

        self.h_local_dw = nn.Conv2d(in_channels, in_channels, kernel_size=(3, 1), padding=(1, 0), groups=in_channels, bias=False)
        self.w_local_dw = nn.Conv2d(in_channels, in_channels, kernel_size=(1, 3), padding=(0, 1), groups=in_channels, bias=False)

        d = int(dilation)
        self.h_dilated_dw = nn.Conv2d(in_channels, in_channels, kernel_size=(3, 1), padding=(d, 0), dilation=(d, 1), groups=in_channels, bias=False)
        self.w_dilated_dw = nn.Conv2d(in_channels, in_channels, kernel_size=(1, 3), padding=(0, d), dilation=(1, d), groups=in_channels, bias=False)

        self.h_channel_mixer = nn.Conv2d(in_channels, in_channels, 1, bias=True)
        self.w_channel_mixer = nn.Conv2d(in_channels, in_channels, 1, bias=True)

        self.h_attention_head = nn.Conv2d(mid, in_channels, 1, bias=True)
        self.w_attention_head = nn.Conv2d(mid, in_channels, 1, bias=True)

        self.beta = float(beta)

        eps = 1e-6
        a0 = float(init_alpha)
        a0 = min(max(a0, eps), 1.0 - eps)
        raw0 = torch.logit(torch.tensor(a0), eps=eps)

        if learnable_alpha:
            self.alpha_h_raw = nn.Parameter(raw0.clone())
            self.alpha_w_raw = nn.Parameter(raw0.clone())
        else:
            self.register_buffer("alpha_h_raw", raw0.clone())
            self.register_buffer("alpha_w_raw", raw0.clone())

        self.use_spatial_gate = bool(use_spatial_gate)
        self.spatial_gate_beta = float(spatial_gate_beta)
        if self.use_spatial_gate:
            self.spatial_gate_dw = nn.Conv2d(in_channels, in_channels, 3, padding=1, groups=in_channels, bias=False)
            self.spatial_gate_pw = nn.Conv2d(in_channels, in_channels, 1, bias=True)

        self._last_ah = None
        self._last_aw = None

    def forward(self, x: torch.Tensor):
        _, _, H, W = x.shape

        h_profile = 0.5 * (x.mean(dim=3, keepdim=True) + x.amax(dim=3, keepdim=True))
        w_profile = 0.5 * (x.mean(dim=2, keepdim=True) + x.amax(dim=2, keepdim=True))

        h_ms = self.h_channel_mixer(self.h_local_dw(h_profile) + self.h_dilated_dw(h_profile))
        w_ms = self.w_channel_mixer(self.w_local_dw(w_profile) + self.w_dilated_dw(w_profile))
        w_ms = w_ms.permute(0, 1, 3, 2)

        hw = torch.cat([h_ms, w_ms], dim=2)

        mid = self.act(self.shared_bottleneck_norm(self.shared_bottleneck_proj(hw)))
        mid = self.act(self.shared_bottleneck_refine_norm(self.shared_bottleneck_refine(mid)))

        mid_h, mid_w = torch.split(mid, [H, W], dim=2)
        mid_w = mid_w.permute(0, 1, 3, 2)

        attn_h = F.hardsigmoid(self.h_attention_head(mid_h), inplace=False)
        attn_w = F.hardsigmoid(self.w_attention_head(mid_w), inplace=False)

        self._last_ah = attn_h.detach()
        self._last_aw = attn_w.detach()

        alpha_h = torch.sigmoid(self.alpha_h_raw).to(device=x.device, dtype=x.dtype)
        alpha_w = torch.sigmoid(self.alpha_w_raw).to(device=x.device, dtype=x.dtype)

        scale_h = (1.0 - alpha_h) + alpha_h * attn_h
        scale_w = (1.0 - alpha_w) + alpha_w * attn_w

        scale = scale_h * scale_w
        scale = 1.0 + self.beta * (scale - 1.0)

        out = x * scale

        if self.use_spatial_gate:
            sg = self.spatial_gate_pw(self.spatial_gate_dw(x))
            sg = F.hardsigmoid(sg, inplace=False)
            sg = 1.0 + self.spatial_gate_beta * (sg - 1.0)
            out = out * sg

        return out

    @torch.no_grad()
    def last_mask_stats(self):
        if (self._last_ah is None) or (self._last_aw is None):
            return None
        ah = self._last_ah
        aw = self._last_aw
        return {
            "a_h": {"min": float(ah.min()), "mean": float(ah.mean()), "max": float(ah.max()), "std": float(ah.std())},
            "a_w": {"min": float(aw.min()), "mean": float(aw.mean()), "max": float(aw.max()), "std": float(aw.std())},
        }


class CBAMChannelPlusCoord(nn.Module):
    def __init__(
        self,
        channels: int,
        ca_reduction: int = 16,
        ca_min_hidden: int = 4,
        ca_fusion: str = "softmax",
        ca_gate: str = "sigmoid",
        ca_temperature: float = 0.9,
        ca_act: str = "relu",
        ca_fusion_router_hidden: int = 16,
        learnable_temperature: bool = False,
        coord_reduction: int = 32,
        coord_min_mid: int = 8,
        coord_act: str = "hswish",
        coord_init_alpha: float = 0.7,
        coord_learnable_alpha: bool = True,
        coord_beta: float = 0.35,
        coord_dilation: int = 2,
        coord_norm: str = "gn",
        coord_use_spatial_gate: bool = False,
        coord_spatial_gate_beta: float = 0.35,
        residual: bool = True,
        alpha_init: float = 0.75,
        learnable_alpha: bool = False,
        return_maps: bool = False,
    ):
        super().__init__()
        if channels < 1:
            raise ValueError("channels >= 1 olmalı.")

        self.return_maps = bool(return_maps)
        self.residual = bool(residual)

        self.ca = ChannelAttentionFusionT(
            channels=channels,
            reduction=ca_reduction,
            min_hidden=ca_min_hidden,
            fusion=ca_fusion,
            gate=ca_gate,
            temperature=ca_temperature,
            learnable_temperature=learnable_temperature,
            eps=1e-6,
            act=ca_act,
            bias=True,
            fusion_router_hidden=ca_fusion_router_hidden,
            return_fusion_weights=self.return_maps,
        )

        self.coord = CoordinateAttPlus(
            in_channels=channels,
            reduction=coord_reduction,
            min_mid_channels=coord_min_mid,
            act=coord_act,
            init_alpha=coord_init_alpha,
            learnable_alpha=coord_learnable_alpha,
            beta=coord_beta,
            dilation=coord_dilation,
            norm=coord_norm,
            use_spatial_gate=coord_use_spatial_gate,
            spatial_gate_beta=coord_spatial_gate_beta,
        )

        if self.residual:
            eps = 1e-6
            a0 = float(alpha_init)
            a0 = min(max(a0, eps), 1.0 - eps)
            raw0 = torch.logit(torch.tensor(a0), eps=eps)
            if learnable_alpha:
                self.alpha_raw = nn.Parameter(raw0)
            else:
                self.register_buffer("alpha_raw", raw0)

    def _alpha(self, x: torch.Tensor) -> torch.Tensor:
        if not hasattr(self, "alpha_raw"):
            return x.new_tensor(1.0)
        return torch.sigmoid(self.alpha_raw).to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor):
        if self.return_maps:
            y, ca_map, fusion_w = self.ca(x)
            y = self.coord(y)
            out = x + self._alpha(x) * (y - x) if self.residual else y
            coord_stats = self.coord.last_mask_stats()
            return out, ca_map, fusion_w, coord_stats

        y, _ = self.ca(x)
        y = self.coord(y)
        out = x + self._alpha(x) * (y - x) if self.residual else y
        return out


## 7) Step 2 code (with the additions)
This cell is the complete Step 2 version.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F


def softplus_inverse(y: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    return torch.log(torch.clamp(torch.exp(y) - 1.0, min=eps))


def _get_gate(gate: str):
    g = gate.lower()
    if g == "sigmoid":
        return torch.sigmoid
    if g == "hardsigmoid":
        return F.hardsigmoid
    raise ValueError("gate 'sigmoid' veya 'hardsigmoid' olmalı.")


def _get_act(act: str):
    a = act.lower()
    if a == "relu":
        return nn.ReLU(inplace=True)
    if a == "silu":
        return nn.SiLU(inplace=True)
    raise ValueError("act 'relu' veya 'silu' olmalı.")


class ChannelAttentionFusionT(nn.Module):
    def __init__(
        self,
        channels: int,
        reduction: int = 16,
        min_hidden: int = 4,
        fusion: str = "softmax",
        gate: str = "sigmoid",
        temperature: float = 0.9,
        learnable_temperature: bool = False,
        eps: float = 1e-6,
        act: str = "relu",
        bias: bool = True,
        fusion_router_hidden: int = 16,
        return_fusion_weights: bool = False,
        t_min: float = 0.5,
        t_max: float = 3.0,
        router_temperature: float = 1.5,
        beta_ca: float = 0.35,
    ):
        super().__init__()

        if channels < 1:
            raise ValueError("channels >= 1 olmalı.")
        if reduction < 1:
            raise ValueError("reduction >= 1 olmalı.")
        if fusion not in ("sum", "softmax"):
            raise ValueError("fusion 'sum' veya 'softmax' olmalı.")
        if temperature <= 0:
            raise ValueError("temperature pozitif olmalı.")
        if fusion == "softmax" and fusion_router_hidden < 1:
            raise ValueError("fusion_router_hidden >= 1 olmalı.")
        if t_min <= 0 or t_max <= 0 or t_min > t_max:
            raise ValueError("T clamp aralığı hatalı.")
        if router_temperature <= 0:
            raise ValueError("router_temperature pozitif olmalı.")
        if beta_ca < 0:
            raise ValueError("beta_ca >= 0 olmalı.")

        self.eps = float(eps)
        self.fusion = fusion
        self.return_fusion_weights = bool(return_fusion_weights)
        self.gate_fn = _get_gate(gate)

        self.t_min = float(t_min)
        self.t_max = float(t_max)
        self.Tr = float(router_temperature)
        self.beta_ca = float(beta_ca)

        hidden = max(int(min_hidden), int(channels) // int(reduction))
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Conv2d(channels, hidden, kernel_size=1, bias=bias)
        self.act = _get_act(act)
        self.fc2 = nn.Conv2d(hidden, channels, kernel_size=1, bias=bias)

        if self.fusion == "softmax":
            self.fusion_router = nn.Sequential(
                nn.Conv2d(2 * channels, fusion_router_hidden, kernel_size=1, bias=True),
                nn.ReLU(inplace=True),
                nn.Conv2d(fusion_router_hidden, 2, kernel_size=1, bias=True),
            )
            last = self.fusion_router[-1]
            nn.init.zeros_(last.weight)
            nn.init.zeros_(last.bias)
        else:
            self.fusion_router = None

        self.learnable_temperature = bool(learnable_temperature)

        if self.learnable_temperature:
            # Optional strict check (enable it if you want)
            # if not (self.t_min <= float(temperature) <= self.t_max):
            #     raise ValueError(f"temperature must be in [{self.t_min}, {self.t_max}] (learnable_temperature=True).")
            # Stable start: clamp the temperature and initialize t_raw from it.
            t0 = float(temperature)
            # Small safety margin so the initial value does not sit exactly on a bound
            lo = self.t_min + self.eps
            hi = self.t_max - self.eps
            # An inverted range is already rejected above; this is just defensive:
            if lo >= hi:
                lo = self.t_min
                hi = self.t_max
            t0 = min(max(t0, lo), hi)
            t_inv = softplus_inverse(torch.tensor(t0), eps=self.eps)
            self.t_raw = nn.Parameter(t_inv)
        else:
            # Optional: also clamp the initial value for a fixed temperature:
            # t0 = float(temperature)
            # t0 = min(max(t0, self.t_min), self.t_max)
            # self.register_buffer("T", torch.tensor(t0))
            self.register_buffer("T", torch.tensor(float(temperature)))

    def get_T(self, x: torch.Tensor) -> torch.Tensor:
        if self.learnable_temperature:
            T = F.softplus(self.t_raw) + self.eps
        else:
            T = self.T
        T = T.to(device=x.device, dtype=x.dtype)
        return T.clamp(self.t_min, self.t_max)

    def mlp(self, s: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.act(self.fc1(s)))

    def forward(self, x: torch.Tensor):
        avg_s = self.avg_pool(x)
        max_s = self.max_pool(x)

        a = self.mlp(avg_s)
        m = self.mlp(max_s)

        fusion_w = None
        if self.fusion == "sum":
            z = a + m
        else:
            s_cat = torch.cat([avg_s, max_s], dim=1)
            logits = self.fusion_router(s_cat).flatten(1)  # (B,2)
            fusion_w = torch.softmax(logits / self.Tr, dim=1)
            w0 = fusion_w[:, 0].view(-1, 1, 1, 1)
            w1 = fusion_w[:, 1].view(-1, 1, 1, 1)
            z = w0 * a + w1 * m

        T = self.get_T(x)
        ca = self.gate_fn(z / T)

        scale_ca = 1.0 + self.beta_ca * (ca - 1.0)
        y = x * scale_ca

        if self.return_fusion_weights:
            return y, ca, fusion_w
        return y, ca


class HSwish(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * F.relu6(x + 3.0, inplace=True) / 6.0


def make_norm(norm: str, ch: int):
    norm = norm.lower()
    if norm == "bn":
        return nn.BatchNorm2d(ch)
    if norm == "gn":
        g = min(32, ch)
        while ch % g != 0 and g > 2:
            g //= 2
        if ch % g != 0:
            g = 2 if (ch % 2 == 0) else 1
        return nn.GroupNorm(g, ch)
    if norm == "none":
        return nn.Identity()
    raise ValueError("norm 'none', 'bn', 'gn' dışında olamaz.")


class CoordinateAttPlus(nn.Module):
    def __init__(
        self,
        in_channels: int,
        reduction: int = 32,
        min_mid_channels: int = 8,
        act: str = "hswish",
        init_alpha: float = 0.7,
        learnable_alpha: bool = True,
        beta: float = 0.35,
        dilation: int = 2,
        norm: str = "gn",
        use_spatial_gate: bool = False,
        spatial_gate_beta: float = 0.35,
        scale_min: float = 0.6,
        scale_max: float = 1.6,
        head_init_std: float = 0.01,
    ):
        super().__init__()

        if in_channels < 1:
            raise ValueError("in_channels >= 1 olmalı.")
        if reduction < 1:
            raise ValueError("reduction >= 1 olmalı.")
        if dilation < 1:
            raise ValueError("dilation >= 1 olmalı.")
        if scale_min <= 0 or scale_max <= 0 or scale_min > scale_max:
            raise ValueError("scale clamp aralığı hatalı.")
        if head_init_std <= 0:
            raise ValueError("head_init_std pozitif olmalı.")

        self.beta = float(beta)
        self.scale_min = float(scale_min)
        self.scale_max = float(scale_max)

        mid_floor = max(8, min(32, int(in_channels) // 4))
        mid = max(int(min_mid_channels), int(in_channels) // int(reduction))
        mid = max(mid, int(mid_floor))

        act_l = act.lower()
        if act_l == "hswish":
            self.act = HSwish()
        elif act_l == "relu":
            self.act = nn.ReLU(inplace=True)
        elif act_l == "silu":
            self.act = nn.SiLU(inplace=True)
        else:
            raise ValueError("act 'hswish', 'relu', 'silu' olmalı.")

        self.shared_bottleneck_proj = nn.Conv2d(in_channels, mid, 1, bias=False)
        self.shared_bottleneck_norm = make_norm(norm, mid)
        self.shared_bottleneck_refine = nn.Conv2d(mid, mid, 1, bias=False)
        self.shared_bottleneck_refine_norm = make_norm(norm, mid)

        self.h_local_dw = nn.Conv2d(
            in_channels, in_channels, kernel_size=(3, 1), padding=(1, 0),
            groups=in_channels, bias=False
        )
        self.w_local_dw = nn.Conv2d(
            in_channels, in_channels, kernel_size=(1, 3), padding=(0, 1),
            groups=in_channels, bias=False
        )

        d = int(dilation)
        self.h_dilated_dw = nn.Conv2d(
            in_channels, in_channels, kernel_size=(3, 1), padding=(d, 0),
            dilation=(d, 1), groups=in_channels, bias=False
        )
        self.w_dilated_dw = nn.Conv2d(
            in_channels, in_channels, kernel_size=(1, 3), padding=(0, d),
            dilation=(1, d), groups=in_channels, bias=False
        )

        self.h_channel_mixer = nn.Conv2d(in_channels, in_channels, 1, bias=True)
        self.w_channel_mixer = nn.Conv2d(in_channels, in_channels, 1, bias=True)

        self.h_attention_head = nn.Conv2d(mid, in_channels, 1, bias=True)
        self.w_attention_head = nn.Conv2d(mid, in_channels, 1, bias=True)

        nn.init.normal_(self.h_attention_head.weight, mean=0.0, std=float(head_init_std))
        nn.init.normal_(self.w_attention_head.weight, mean=0.0, std=float(head_init_std))
        if self.h_attention_head.bias is not None:
            nn.init.zeros_(self.h_attention_head.bias)
        if self.w_attention_head.bias is not None:
            nn.init.zeros_(self.w_attention_head.bias)

        eps = 1e-6
        a0 = float(init_alpha)
        a0 = min(max(a0, eps), 1.0 - eps)
        raw0 = torch.logit(torch.tensor(a0), eps=eps)

        if learnable_alpha:
            self.alpha_h_raw = nn.Parameter(raw0.clone())
            self.alpha_w_raw = nn.Parameter(raw0.clone())
        else:
            self.register_buffer("alpha_h_raw", raw0.clone())
            self.register_buffer("alpha_w_raw", raw0.clone())

        self.use_spatial_gate = bool(use_spatial_gate)
        self.spatial_gate_beta = float(spatial_gate_beta)
        if self.use_spatial_gate:
            self.spatial_gate_dw = nn.Conv2d(in_channels, in_channels, 3, padding=1, groups=in_channels, bias=False)
            self.spatial_gate_pw = nn.Conv2d(in_channels, in_channels, 1, bias=True)

        self._last_ah = None
        self._last_aw = None

    def forward(self, x: torch.Tensor):
        _, _, H, W = x.shape

        h_profile = 0.5 * (x.mean(dim=3, keepdim=True) + x.amax(dim=3, keepdim=True))
        w_profile = 0.5 * (x.mean(dim=2, keepdim=True) + x.amax(dim=2, keepdim=True))

        h_ms = self.h_channel_mixer(self.h_local_dw(h_profile) + self.h_dilated_dw(h_profile))
        w_ms = self.w_channel_mixer(self.w_local_dw(w_profile) + self.w_dilated_dw(w_profile))
        w_ms = w_ms.permute(0, 1, 3, 2)

        hw = torch.cat([h_ms, w_ms], dim=2)

        mid = self.act(self.shared_bottleneck_norm(self.shared_bottleneck_proj(hw)))
        mid = self.act(self.shared_bottleneck_refine_norm(self.shared_bottleneck_refine(mid)))

        mid_h, mid_w = torch.split(mid, [H, W], dim=2)
        mid_w = mid_w.permute(0, 1, 3, 2)

        attn_h = F.hardsigmoid(self.h_attention_head(mid_h), inplace=False)
        attn_w = F.hardsigmoid(self.w_attention_head(mid_w), inplace=False)

        self._last_ah = attn_h.detach()
        self._last_aw = attn_w.detach()

        alpha_h = torch.sigmoid(self.alpha_h_raw).to(device=x.device, dtype=x.dtype)
        alpha_w = torch.sigmoid(self.alpha_w_raw).to(device=x.device, dtype=x.dtype)

        scale_h = (1.0 - alpha_h) + alpha_h * attn_h
        scale_w = (1.0 - alpha_w) + alpha_w * attn_w

        scale = scale_h * scale_w
        scale = 1.0 + self.beta * (scale - 1.0)
        scale = scale.clamp(self.scale_min, self.scale_max)

        out = x * scale

        if self.use_spatial_gate:
            sg = self.spatial_gate_pw(self.spatial_gate_dw(x))
            sg = F.hardsigmoid(sg, inplace=False)
            sg = 1.0 + self.spatial_gate_beta * (sg - 1.0)
            out = out * sg

        return out

    @torch.no_grad()
    def last_mask_stats(self):
        if (self._last_ah is None) or (self._last_aw is None):
            return None
        ah = self._last_ah
        aw = self._last_aw
        return {
            "a_h": {"min": float(ah.min()), "mean": float(ah.mean()), "max": float(ah.max()), "std": float(ah.std())},
            "a_w": {"min": float(aw.min()), "mean": float(aw.mean()), "max": float(aw.max()), "std": float(aw.std())},
        }


class CBAMChannelPlusCoord(nn.Module):
    """
    rescue_mode:
      - "ratio_floor": keeps ratio within [min_rescue_ratio, 1]; alpha_eff = alpha * ratio
      - "alpha_floor": clamps alpha_eff directly to [alpha_eff_min, 1] (more aggressive)
    """
    def __init__(
        self,
        channels: int,
        ca_reduction: int = 16,
        ca_min_hidden: int = 4,
        ca_fusion: str = "softmax",
        ca_gate: str = "sigmoid",
        ca_temperature: float = 0.9,
        ca_act: str = "relu",
        ca_fusion_router_hidden: int = 16,
        learnable_temperature: bool = False,
        ca_t_min: float = 0.5,
        ca_t_max: float = 3.0,
        ca_router_temperature: float = 1.5,
        beta_ca: float = 0.35,
        coord_reduction: int = 32,
        coord_min_mid: int = 8,
        coord_act: str = "hswish",
        coord_init_alpha: float = 0.7,
        coord_learnable_alpha: bool = True,
        coord_beta: float = 0.35,
        coord_dilation: int = 2,
        coord_norm: str = "gn",
        coord_use_spatial_gate: bool = False,
        coord_spatial_gate_beta: float = 0.35,
        coord_scale_min: float = 0.6,
        coord_scale_max: float = 1.6,
        coord_head_init_std: float = 0.01,
        residual: bool = True,
        alpha_init: float = 0.75,
        learnable_alpha: bool = False,
        monitor: bool = False,
        r_min: float = 0.45,
        ema_momentum: float = 0.95,
        min_rescue_ratio: float = 0.2,
        alpha_eff_min: float = 0.2,
        rescue_mode: str = "ratio_floor",
        return_maps: bool = False,
    ):
        super().__init__()
        if channels < 1:
            raise ValueError("channels >= 1 olmalı.")
        if not (0.0 < ema_momentum < 1.0):
            raise ValueError("ema_momentum (0,1) aralığında olmalı.")
        if r_min <= 0:
            raise ValueError("r_min pozitif olmalı.")
        if not (0.0 <= min_rescue_ratio <= 1.0):
            raise ValueError("min_rescue_ratio [0,1] aralığında olmalı.")
        if not (0.0 <= alpha_eff_min <= 1.0):
            raise ValueError("alpha_eff_min [0,1] aralığında olmalı.")
        if rescue_mode not in ("ratio_floor", "alpha_floor"):
            raise ValueError("rescue_mode 'ratio_floor' veya 'alpha_floor' olmalı.")

        self.return_maps = bool(return_maps)
        self.residual = bool(residual)

        self.monitor = bool(monitor)
        self.r_min = float(r_min)
        self.ema_m = float(ema_momentum)
        self.min_rescue_ratio = float(min_rescue_ratio)
        self.alpha_eff_min = float(alpha_eff_min)
        self.rescue_mode = str(rescue_mode)

        self.ca = ChannelAttentionFusionT(
            channels=channels,
            reduction=ca_reduction,
            min_hidden=ca_min_hidden,
            fusion=ca_fusion,
            gate=ca_gate,
            temperature=ca_temperature,
            learnable_temperature=learnable_temperature,
            eps=1e-6,
            act=ca_act,
            bias=True,
            fusion_router_hidden=ca_fusion_router_hidden,
            return_fusion_weights=self.return_maps,
            t_min=ca_t_min,
            t_max=ca_t_max,
            router_temperature=ca_router_temperature,
            beta_ca=beta_ca,
        )

        self.coord = CoordinateAttPlus(
            in_channels=channels,
            reduction=coord_reduction,
            min_mid_channels=coord_min_mid,
            act=coord_act,
            init_alpha=coord_init_alpha,
            learnable_alpha=coord_learnable_alpha,
            beta=coord_beta,
            dilation=coord_dilation,
            norm=coord_norm,
            use_spatial_gate=coord_use_spatial_gate,
            spatial_gate_beta=coord_spatial_gate_beta,
            scale_min=coord_scale_min,
            scale_max=coord_scale_max,
            head_init_std=coord_head_init_std,
        )

        if self.residual:
            eps = 1e-6
            a0 = float(alpha_init)
            a0 = min(max(a0, eps), 1.0 - eps)
            raw0 = torch.logit(torch.tensor(a0), eps=eps)
            if learnable_alpha:
                self.alpha_raw = nn.Parameter(raw0)
            else:
                self.register_buffer("alpha_raw", raw0)

        self.register_buffer("r_ema", torch.tensor(1.0))

    def _alpha(self, x: torch.Tensor) -> torch.Tensor:
        if (not self.residual) or (not hasattr(self, "alpha_raw")):
            return x.new_tensor(1.0)
        return torch.sigmoid(self.alpha_raw).to(device=x.device, dtype=x.dtype)

    @staticmethod
    def _std_per_sample(x: torch.Tensor) -> torch.Tensor:
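        # Population std of each flattened sample, averaged over the batch.
        # unbiased=False (correction=0) is safer: with a single element per sample,
        # the unbiased (N-1) estimator would return NaN.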
        return x.float().flatten(1).std(dim=1, unbiased=False).mean()

    @torch.no_grad()
    def _update_r_ema(self, r_out: torch.Tensor):
        r_det = r_out.detach().to(device=self.r_ema.device, dtype=self.r_ema.dtype)
        self.r_ema.mul_(self.ema_m).add_((1.0 - self.ema_m) * r_det)

    def _compute_alpha_eff(self, x: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
        # ratio < 1 only once r_ema drops below r_min; otherwise alpha is left unchanged
        ratio = (self.r_ema.detach() / max(self.r_min, 1e-12)).clamp(0.0, 1.0)
        ratio = ratio.to(device=x.device, dtype=x.dtype)

        if self.rescue_mode == "ratio_floor":
            ratio = ratio.clamp(self.min_rescue_ratio, 1.0)
            return alpha * ratio

        alpha_eff = alpha * ratio
        return alpha_eff.clamp(self.alpha_eff_min, 1.0)

    def forward(self, x: torch.Tensor):
        monitor_stats = None

        if self.return_maps:
            y_ca, ca_map, fusion_w = self.ca(x)
            y = self.coord(y_ca)

            if not self.residual:
                coord_stats = self.coord.last_mask_stats()
                return y, ca_map, fusion_w, coord_stats, None

            alpha = self._alpha(x)
            alpha_eff = alpha

            # Training + monitor only: measure std suppression, update r_ema, then rescue alpha
            if self.training and self.monitor:
                x_std = self._std_per_sample(x)
                y_std = self._std_per_sample(y)

                out_tmp = x + alpha * (y - x)
                out_std = self._std_per_sample(out_tmp)

                r_block = (y_std / (x_std + 1e-12)).clamp(0.0, 10.0)
                r_out = (out_std / (x_std + 1e-12)).clamp(0.0, 10.0)

                self._update_r_ema(r_out)
                alpha_eff = self._compute_alpha_eff(x, alpha)

                monitor_stats = {
                    "x_std": float(x_std.detach()),
                    "y_std": float(y_std.detach()),
                    "out_std_pre": float(out_std.detach()),
                    "r_block": float(r_block.detach()),
                    "r_out_pre": float(r_out.detach()),
                    "r_ema": float(self.r_ema.detach()),
                    "alpha": float(alpha.detach()),
                    "alpha_eff": float(alpha_eff.detach()),
                    "rescue_mode": self.rescue_mode,
                }

            out = x + alpha_eff * (y - x)
            coord_stats = self.coord.last_mask_stats()
            return out, ca_map, fusion_w, coord_stats, monitor_stats

        y_ca, _ = self.ca(x)
        y = self.coord(y_ca)

        if not self.residual:
            return y

        alpha = self._alpha(x)
        alpha_eff = alpha

        # Training + monitor only: measure std suppression, update r_ema, then rescue alpha
        if self.training and self.monitor:
            x_std = self._std_per_sample(x)
            out_tmp = x + alpha * (y - x)
            out_std = self._std_per_sample(out_tmp)
            r_out = (out_std / (x_std + 1e-12)).clamp(0.0, 10.0)

            self._update_r_ema(r_out)
            alpha_eff = self._compute_alpha_eff(x, alpha)

        out = x + alpha_eff * (y - x)
        return out
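The residual weight is stored as a logit: `__init__` clips `alpha_init` to `[eps, 1 - eps]`, stores `alpha_raw = logit(alpha_init)`, and `_alpha` maps it back with a sigmoid, so `alpha` stays inside (0, 1) whatever the optimizer does to `alpha_raw`. A minimal pure-Python sketch of that round trip, assuming `eps = 1e-6` as in `__init__`; `alpha_to_raw` and `raw_to_alpha` are hypothetical helper names, not part of the block:

```python
import math

EPS = 1e-6  # same eps that __init__ uses for the clip and for torch.logit


def alpha_to_raw(alpha_init: float, eps: float = EPS) -> float:
    """Clip alpha_init to [eps, 1 - eps], then take the logit (hypothetical helper)."""
    a0 = min(max(float(alpha_init), eps), 1.0 - eps)
    return math.log(a0 / (1.0 - a0))


def raw_to_alpha(alpha_raw: float) -> float:
    """Numerically stable sigmoid; mirrors torch.sigmoid(self.alpha_raw) in _alpha."""
    if alpha_raw >= 0.0:
        return 1.0 / (1.0 + math.exp(-alpha_raw))
    e = math.exp(alpha_raw)
    return e / (1.0 + e)


raw = alpha_to_raw(0.75)
print(round(raw, 4))                          # 1.0986  (= ln 3)
print(round(raw_to_alpha(raw), 6))            # 0.75
print(raw_to_alpha(alpha_to_raw(1.0)) < 1.0)  # True: alpha_init=1.0 is clipped to 1 - eps
```

The clip matters because `logit(0)` and `logit(1)` are infinite; in floating point the sigmoid still saturates to exactly 0.0 or 1.0 for extreme logits, so the open interval is a mathematical guarantee rather than a numerical one.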

## 8) Quick test

The cell below:
- builds the Step 2 block
- runs a forward pass on a dummy input
- prints the returned shapes and the `monitor_stats` keys and values


import torch

x = torch.randn(2, 64, 56, 56)

m = CBAMChannelPlusCoord(
    channels=64,
    return_maps=True,
    residual=True,
    learnable_temperature=True,
    monitor=True,
    rescue_mode="ratio_floor",
    min_rescue_ratio=0.2,
    alpha_eff_min=0.2,
    r_min=0.45,
    ema_momentum=0.95,
    beta_ca=0.35,
    ca_router_temperature=1.5,
    ca_t_min=0.5,
    ca_t_max=3.0,
    coord_scale_min=0.6,
    coord_scale_max=1.6,
)

m.train()
out, ca_map, fusion_w, coord_stats, monitor_stats = m(x)

print("x:", x.shape)
print("out:", out.shape)
print("ca_map:", ca_map.shape)
print("fusion_w:", fusion_w.shape)
print("coord_stats:", coord_stats)
print("monitor_stats keys:", None if monitor_stats is None else list(monitor_stats.keys()))
print("monitor_stats:", monitor_stats)


x: torch.Size([2, 64, 56, 56])
out: torch.Size([2, 64, 56, 56])
ca_map: torch.Size([2, 64, 1, 1])
fusion_w: torch.Size([2, 2])
coord_stats: {'a_h': {'min': 0.4740104675292969, 'mean': 0.49960988759994507, 'max': 0.5296449065208435, 'std': 0.004075914621353149}, 'a_w': {'min': 0.48295947909355164, 'mean': 0.5000577569007874, 'max': 0.5220022201538086, 'std': 0.0036030977498739958}}
monitor_stats keys: ['x_std', 'y_std', 'out_std_pre', 'r_block', 'r_out_pre', 'r_ema', 'alpha', 'alpha_eff', 'rescue_mode']
monitor_stats: {'x_std': 1.0005626678466797, 'y_std': 0.6620503664016724, 'out_std_pre': 0.7463032603263855, 'r_block': 0.6616780757904053, 'r_out_pre': 0.745883584022522, 'r_ema': 0.9872941970825195, 'alpha': 0.7500000596046448, 'alpha_eff': 0.7500000596046448, 'rescue_mode': 'ratio_floor'}
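The `r_ema` and `alpha_eff` values printed above can be reproduced by hand. A small arithmetic sketch, assuming this was the first monitored training step (so `r_ema` starts from its buffer value 1.0) and that `alpha` is 0.75 as printed; `r_out_pre` is copied from the output, rounded:

```python
# Values copied from the quick-test output (rounded)
r_ema0 = 1.0            # initial buffer value: register_buffer("r_ema", torch.tensor(1.0))
m = 0.95                # ema_momentum
r_out_pre = 0.745884    # printed r_out_pre
r_min = 0.45
alpha = 0.75
min_rescue_ratio = 0.2

# EMA update, as in _update_r_ema: r_ema <- m * r_ema + (1 - m) * r_out
r_ema1 = m * r_ema0 + (1.0 - m) * r_out_pre

# ratio_floor rescue, as in _compute_alpha_eff
ratio = min(max(r_ema1 / r_min, 0.0), 1.0)
ratio = max(ratio, min_rescue_ratio)
alpha_eff = alpha * ratio

print(round(r_ema1, 4))  # 0.9873
print(alpha_eff)         # 0.75
```

Because the EMA is updated before `_compute_alpha_eff` runs, the current step's `r_out_pre` already influences `alpha_eff`; here `r_ema / r_min` is well above 1, so the ratio clamps to 1 and `alpha_eff` equals `alpha`.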


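## 9) Notes on reading the outputs

- `ca_map (B,C,1,1)`: channel mask (CA)
- `fusion_w (B,2)`: avg/max fusion weights
- `coord_stats`: distribution of the Coord H/W masks (min/mean/max/std)
- `monitor_stats` (only when training with `monitor=True`):
  - `r_block`: `std(y) / std(x)`, the block's own suppression before residual mixing
  - `r_out_pre`: `std(out_tmp) / std(x)`, where `out_tmp = x + alpha*(y - x)` is the residual mix with the raw `alpha`, before the rescue is applied
  - `r_ema`: the EMA of `r_out_pre`
  - `alpha_eff`: the effective mixing weight after the rescue; outside training + monitor mode, the forward pass uses `alpha` directly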
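`r_out_pre` is a ratio of statistics from `_std_per_sample`: flatten each sample, take its population std (`unbiased=False`, dividing by N rather than N - 1), then average over the batch. A pure-Python sketch of the same computation on nested lists, using `statistics.pstdev` as the population std; `std_per_sample` and `r_out_pre` here are hypothetical helpers that only mirror the torch code:

```python
import statistics


def std_per_sample(batch):
    """Mean over samples of each sample's population std (like std(unbiased=False))."""
    return statistics.fmean(statistics.pstdev(sample) for sample in batch)


def r_out_pre(x, y, alpha, eps=1e-12):
    """std(out_tmp) / std(x), where out_tmp = x + alpha * (y - x), clamped to [0, 10]."""
    out_tmp = [[xi + alpha * (yi - xi) for xi, yi in zip(xs, ys)] for xs, ys in zip(x, y)]
    r = std_per_sample(out_tmp) / (std_per_sample(x) + eps)
    return min(max(r, 0.0), 10.0)


# Two "samples" of flattened features; y is x scaled by 0.5 (a block that suppresses)
x = [[1.0, -1.0, 2.0, -2.0], [0.5, -0.5, 1.5, -1.5]]
y = [[v * 0.5 for v in s] for s in x]
print(round(r_out_pre(x, y, alpha=0.75), 4))  # 0.625
print(statistics.pstdev([3.0]))               # 0.0 (statistics.stdev([3.0]) would raise)
```

With `y = 0.5 * x` and `alpha = 0.75`, `out_tmp = 0.625 * x`, so the ratio is 0.625: the residual path recovers part of the std that the block removed (`r_block` would be 0.5). The single-element line shows why the population estimator is the safer choice; in torch, the unbiased estimator on one element yields NaN.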

In practice, on the YOLO side:
- If `r_out_pre` drops sharply (e.g. below roughly 0.35 to 0.45), the block may be over-suppressing.
- If `alpha_eff` falls below `alpha`, the rescue has kicked in, which happens once `r_ema` drops under `r_min`.
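In `_compute_alpha_eff`, `ratio = clamp(r_ema / r_min, 0, 1)`, so nothing changes while `r_ema >= r_min`. Below that, `ratio_floor` scales `alpha` by the ratio floored at `min_rescue_ratio`, while `alpha_floor` scales first and then floors the product at `alpha_eff_min`. A scalar pure-Python sketch of the same rule, with defaults copied from the constructor; the function name `alpha_eff` is hypothetical:

```python
def alpha_eff(alpha, r_ema, r_min=0.45, mode="ratio_floor",
              min_rescue_ratio=0.2, alpha_eff_min=0.2):
    """Scalar mirror of _compute_alpha_eff (defaults match the constructor)."""
    ratio = min(max(r_ema / max(r_min, 1e-12), 0.0), 1.0)
    if mode == "ratio_floor":
        # Floor the ratio, so alpha_eff >= alpha * min_rescue_ratio
        return alpha * max(ratio, min_rescue_ratio)
    # "alpha_floor": floor the product itself at alpha_eff_min, cap at 1
    return min(max(alpha * ratio, alpha_eff_min), 1.0)


for r in (0.60, 0.45, 0.30, 0.05):
    print(r, round(alpha_eff(0.75, r), 4), round(alpha_eff(0.75, r, mode="alpha_floor"), 4))
```

With `alpha = 0.75`, both modes return 0.75 at `r_ema = 0.60` and `0.45`, and 0.5 at `0.30`. They diverge only under deep suppression: at `r_ema = 0.05`, `ratio_floor` gives 0.15 and `alpha_floor` gives 0.2. Note that with `alpha_floor`, an `alpha` already below `alpha_eff_min` is raised to that floor even when no suppression is measured.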
