# AdaptSRNet

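This notebook trains the **enhanced AdaptSRNet** model for steganalysis of images embedded with the WOW algorithm at a 0.4 bpp payload, using **PyTorch Lightning**. The design originally targeted 1–2M parameters, but the default channel plan defined below comes to roughly 15M; the model-definition cell prints the exact count.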

In [None]:
import multiprocessing

# Force the "spawn" start method for any child processes created later in the
# notebook (force=True overrides a method that may already have been chosen).
multiprocessing.set_start_method(method="spawn", force=True)


In [None]:
# ===================== CONFIG CELL =====================
import os

# The CUDA caching-allocator options are picked up when CUDA is initialised, so
# set them before touching torch.cuda at all (get_device_properties() below
# initialises CUDA). Setting the variable only after that point is not
# guaranteed to affect the already-initialised allocator.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch

#  Paths (manual setup)
COVER_DIR = "/kaggle/input/bossbase-bows2/GBRASNET/BOSSbase-1.01/cover"
STEGO_DIR = "/kaggle/input/bossbase-bows2/GBRASNET/BOSSbase-1.01/stego/WOW/0.4bpp/stego"

#  Training parameters
BATCH_SIZE = 4       # Increase if you have more GPU memory (e.g., 8 or 16)
IMG_SIZE = 256
EPOCHS = 50          # Can increase to 100 for better convergence
NUM_WORKERS = 4
SEED = 42
LR = 1e-4           # Initial LR; CosineAnnealingWarmRestarts adjusts it per epoch

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print("="*60)
print(f"Using device(s): {torch.cuda.device_count()} GPU(s)")
print(f"Device type: {DEVICE}")
if torch.cuda.is_available():
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
print("="*60)
print(f"Cover directory: {COVER_DIR}")
print(f"Stego directory: {STEGO_DIR}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Image size: {IMG_SIZE}x{IMG_SIZE}")
print(f"Max epochs: {EPOCHS}")
print(f"Learning rate: {LR}")
print("="*60)

# Release any cached blocks left over from a previous run in this kernel
torch.cuda.empty_cache()
# ========================================================


In [None]:
!pip install --upgrade pip setuptools --quiet
!pip install pytorch-lightning --quiet
!pip install tokenizers==0.13.3 --quiet
!pip install transformers==4.28.1 --quiet

In [None]:
import os
import numpy as np
from glob import glob
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from torchmetrics.classification import Accuracy, F1Score

# Seed the global RNGs through Lightning (SEED comes from the config cell).
pl.seed_everything(seed=SEED)

In [None]:
from glob import glob
import os
from pathlib import Path
import numpy as np

from glob import glob
from pathlib import Path
import os, re

def normalize_name(name):
    """
    Map a cover or stego file path to the key used to pair them.

    Steps:
    - take the file stem (directory and final extension dropped)
    - remove a known embedding prefix (WOW_, HILL_, HUGO_, MiPOD_, S-UNIWARD_),
      case-insensitively
    - remove a trailing payload suffix such as ``_0.4bpp``, ``_.4bpp`` or ``_04bpp``
    - strip leading zeros from the numeric part (``00012`` -> ``12``),
      keeping a single ``0`` for an all-zero name

    Args:
        name: file name or path (str or path-like).

    Returns:
        The normalized key (str).
    """
    name = Path(name).stem
    # Remove common stego prefixes
    name = re.sub(r'^(WOW_|HILL_|HUGO_|MiPOD_|S-UNIWARD_)', '', name, flags=re.IGNORECASE)
    # Single backslashes: in a raw string r'\\.' would mean "literal backslash,
    # then any character", which never matches a real payload suffix.
    name = re.sub(r'(_0?\.\d+bpp|_\d+bpp)$', '', name, flags=re.IGNORECASE)
    # Strip leading zeros only while a digit follows, so "0000" -> "0", not "".
    name = re.sub(r'^0+(?=\d)', '', name)
    return name

def collect_pairs(cover_dir, stego_dir, extensions=('.pgm', '.png', '.jpg', '.jpeg', '.bmp')):
    """
    Match cover and stego images whose paths normalize to the same key.

    Args:
        cover_dir: directory containing cover images.
        stego_dir: directory containing stego images.
        extensions: file extensions to accept (compared case-insensitively).

    Returns:
        List of ``(cover_path, stego_path, key)`` tuples sorted by ``key``,
        where ``key = normalize_name(path)``. Empty when nothing matches.
    """
    exts = tuple(ext.lower() for ext in extensions)

    def _list_images(directory):
        # Sorted so that duplicate-key resolution in _index is deterministic.
        return sorted(p for p in glob(os.path.join(directory, '*')) if p.lower().endswith(exts))

    def _index(paths, label):
        # key -> path; when several files share a key the later path (in
        # sorted order) wins, and the number of such collisions is reported.
        mapping, collisions = {}, 0
        for path in paths:
            key = normalize_name(path)
            collisions += key in mapping
            mapping[key] = path
        if collisions:
            print(f"Warning: {collisions} {label} file(s) share a normalized name with another file; "
                  f"only the last one (in sorted order) is kept")
        return mapping

    cover_map = _index(_list_images(cover_dir), "cover")
    stego_map = _index(_list_images(stego_dir), "stego")

    common = sorted(cover_map.keys() & stego_map.keys())
    pairs = [(cover_map[k], stego_map[k], k) for k in common]

    print(f"Found {len(pairs)} matching cover/stego pairs")
    unmatched_cover = len(cover_map) - len(common)
    unmatched_stego = len(stego_map) - len(common)
    if unmatched_cover or unmatched_stego:
        print(f"Unmatched files: {unmatched_cover} cover, {unmatched_stego} stego")
    if len(pairs) == 0:
        print("No matches found — check filename formats or folder paths.")
    return pairs



def make_splits_from_pairs(pairs, val_frac=0.1, test_frac=0.2, seed=42):
    """
    Randomly split cover/stego pairs into train/val/test subsets.

    Splitting happens at the pair level, so a cover image and its stego
    counterpart always end up in the same subset.

    Args:
        pairs: sequence of pairs (e.g. from ``collect_pairs``); not modified.
        val_frac: fraction of pairs for validation (count rounded down).
        test_frac: fraction of pairs for testing (count rounded down).
        seed: seed for ``numpy.random.RandomState``; the same seed and input
            order always give the same split.

    Returns:
        dict with keys 'train', 'val', 'test' mapping to lists of pairs.

    Raises:
        ValueError: if a fraction is negative or val_frac + test_frac > 1.
    """
    if val_frac < 0 or test_frac < 0 or val_frac + test_frac > 1:
        raise ValueError(
            f"invalid split fractions: val_frac={val_frac}, test_frac={test_frac}"
        )
    # Shuffle a copy so the caller's list is left untouched.
    shuffled = list(pairs)
    np.random.RandomState(seed).shuffle(shuffled)
    n = len(shuffled)
    ntest, nval = int(n * test_frac), int(n * val_frac)
    return {
        'train': shuffled[ntest + nval:],
        'val': shuffled[ntest:ntest + nval],
        'test': shuffled[:ntest],
    }


pairs = collect_pairs(COVER_DIR, STEGO_DIR)
# Fail here with a clear message rather than later inside the DataLoader or
# sampler, which cannot work with an empty dataset.
if not pairs:
    raise RuntimeError(
        f"No cover/stego pairs found in {COVER_DIR!r} and {STEGO_DIR!r}; check the paths in the config cell."
    )
splits = make_splits_from_pairs(pairs)
print("Splits:", {k: len(v) for k, v in splits.items()})


In [None]:
class PairedStegoDataset(Dataset):
    """
    Flatten ``(cover_path, stego_path, key)`` pairs into a labelled image dataset.

    Each pair contributes two consecutive samples: the cover image with
    label 0 and the stego image with label 1, both tagged with the shared key.

    Args:
        pairs: iterable of ``(cover_path, stego_path, key)`` tuples.
        transform: optional callable applied to the grayscale PIL image.
    """

    def __init__(self, pairs, transform=None):
        self.samples = []
        for cover_path, stego_path, key in pairs:
            self.samples.append((cover_path, 0, key))
            self.samples.append((stego_path, 1, key))
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label, key)``; the image is transformed when a transform is set."""
        path, label, key = self.samples[idx]
        # Close the file deterministically; convert() returns a converted copy
        # that stays valid after the source image is closed.
        with Image.open(path) as raw:
            img = raw.convert('L')
        if self.transform:
            img = self.transform(img)
        return img, label, key

def _build_transform(augment):
    """Resize to IMG_SIZE, optionally random-flip horizontally, then map pixels to [-1, 1]."""
    # NOTE(review): an interpolating Resize averages neighbouring pixels and can
    # wash out +/-1 embedding changes; if the source images are not already
    # IMG_SIZE x IMG_SIZE, consider cropping instead — confirm image sizes.
    steps = [transforms.Resize((IMG_SIZE, IMG_SIZE))]
    if augment:
        steps.append(transforms.RandomHorizontalFlip())
    steps.extend([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    return transforms.Compose(steps)


train_tf = _build_transform(augment=True)
val_tf = _build_transform(augment=False)

# num_workers stays 0: with the "spawn" start method chosen at the top of the
# notebook, worker processes must unpickle PairedStegoDataset, and classes
# defined inside a notebook generally cannot be resolved by spawned workers
# (presumably why NUM_WORKERS from the config cell is not used here).
# drop_last=True on the training loader: the classifier head uses BatchNorm1d,
# which raises in training mode on a 1-sample batch. That can happen with an
# odd BATCH_SIZE, because the dataset always holds an even number of images.
train_loader = DataLoader(
    PairedStegoDataset(splits['train'], transform=train_tf),
    batch_size=BATCH_SIZE, shuffle=True, num_workers=0, drop_last=True,
)
val_loader = DataLoader(
    PairedStegoDataset(splits['val'], transform=val_tf),
    batch_size=BATCH_SIZE, shuffle=False, num_workers=0,
)
print("DataLoaders ready")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# ---------- Learnable SRM-style high-pass preprocessing ----------
class LearnableSRMLayer(nn.Module):
    """
    Trainable bank of 5x5 high-pass kernels followed by a learnable-scale tanh.

    The bank is initialised from hand-crafted residual kernels and then
    optimised together with the rest of the network.

    Args:
        num_filters: number of kernels, i.e. output channels (default 64).

    Shapes:
        input  [B, 1, H, W]
        output [B, num_filters, H, W]  (padding=2 keeps H and W for 5x5 kernels)
    """

    def __init__(self, num_filters=64):
        super().__init__()
        self.num_filters = num_filters
        initial_bank = self._initialize_srm_filters()
        self.filters = nn.Parameter(initial_bank, requires_grad=True)
        self.scale = nn.Parameter(torch.tensor(1.0), requires_grad=True)

    def _initialize_srm_filters(self):
        """
        Build the initial bank as a [num_filters, 1, 5, 5] float32 tensor.

        Index layout:
          0-3   KV kernel and its 90-degree rotations
          4-7   horizontal/vertical edge kernels (3x3 zero-padded to 5x5), rotated
          8-11  Laplacian (3x3 zero-padded to 5x5), rotated
          12+   KV kernel plus Gaussian noise (std 0.02) from torch's global RNG
        Each kernel is then shifted to zero mean and scaled to unit L2 norm.

        Note: the KV and Laplacian kernels are invariant under 90-degree
        rotation, so indices 1-3 start identical to 0, and 9-11 identical to 8.
        """
        kv = torch.tensor([
            [-1,  2,  -2,  2, -1],
            [ 2, -6,   8, -6,  2],
            [-2,  8, -12,  8, -2],
            [ 2, -6,   8, -6,  2],
            [-1,  2,  -2,  2, -1]
        ], dtype=torch.float32) / 12.0
        edge_h = torch.tensor([[-1, -1, -1], [0, 0, 0], [1, 1, 1]], dtype=torch.float32)
        edge_v = torch.tensor([[-1, 0, 1], [-1, 0, 1], [-1, 0, 1]], dtype=torch.float32)
        laplacian = torch.tensor([[0, 1, 0], [1, -4, 1], [0, 1, 0]], dtype=torch.float32)

        def rotated(kernel, quarter_turns):
            return torch.rot90(kernel, k=quarter_turns, dims=[0, 1])

        def pad_to_5x5(kernel):
            return F.pad(kernel, (1, 1, 1, 1), mode='constant', value=0)

        bank = []
        for idx in range(self.num_filters):
            turns = idx % 4
            if idx == 0:
                kernel = kv.clone()
            elif idx < 4:
                kernel = rotated(kv, turns)
            elif idx < 8:
                kernel = rotated(pad_to_5x5(edge_v if idx % 2 else edge_h), turns)
            elif idx < 12:
                kernel = rotated(pad_to_5x5(laplacian), turns)
            else:
                kernel = kv + torch.randn_like(kv) * 0.02

            kernel = kernel - kernel.mean()
            length = torch.norm(kernel)
            if length > 0:
                kernel = kernel / length
            bank.append(kernel)

        return torch.stack(bank, dim=0).unsqueeze(1)

    def forward(self, x):
        """Return tanh(scale * conv2d(x, filters)); assumes x is [B, 1, H, W]."""
        response = F.conv2d(x, self.filters, padding=2)
        return torch.tanh(response * self.scale)


# ---------- Squeeze-and-Excitation (channel attention) ----------
class SqueezeExcitation(nn.Module):
    """
    Channel attention: global-average-pool each channel, pass the [B, C]
    descriptor through a bottleneck MLP ending in a sigmoid, and rescale
    the input channels by the resulting weights.

    Args:
        in_channels: channel count C of the [B, C, H, W] input.
        reduction_ratio: bottleneck width is max(1, C // reduction_ratio).
    """

    def __init__(self, in_channels, reduction_ratio=16):
        super().__init__()
        hidden = max(1, in_channels // reduction_ratio)
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        self.excitation = nn.Sequential(
            nn.Linear(in_channels, hidden, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, in_channels, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, channels, _height, _width = x.shape
        descriptor = self.squeeze(x).view(batch, channels)
        gates = self.excitation(descriptor).view(batch, channels, 1, 1)
        return x * gates


# ---------- Conv -> BatchNorm -> ReLU building block ----------
class ConvBlock(nn.Module):
    """
    Bias-free Conv2d followed by BatchNorm2d and an in-place ReLU.

    Args:
        in_ch, out_ch: input/output channel counts.
        kernel, stride, padding: forwarded to nn.Conv2d (defaults: 3, 1, 1).
    """

    def __init__(self, in_ch, out_ch, kernel=3, stride=1, padding=1):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel, stride=stride, padding=padding, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        features = self.conv(x)
        features = self.bn(features)
        return self.act(features)


# ---------- Residual block: two 3x3 convs + optional SE ----------
class EnhancedResidualBlock(nn.Module):
    """
    Residual block: ConvBlock (3x3, may downsample) -> 3x3 Conv + BN ->
    optional Squeeze-and-Excitation, added to the shortcut, then ReLU.

    The shortcut is a 1x1 conv + BN projection when the stride or channel
    count changes, and the identity otherwise.

    Args:
        in_channels, out_channels: channel counts.
        stride: stride of the first conv (and of the projection shortcut).
        use_se: whether to apply Squeeze-and-Excitation to the residual branch.
    """

    def __init__(self, in_channels, out_channels, stride=1, use_se=True):
        super().__init__()
        self.conv1 = ConvBlock(in_channels, out_channels, stride=stride)
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        if use_se:
            self.se = SqueezeExcitation(out_channels, reduction_ratio=16)
        else:
            self.se = nn.Identity()

        needs_projection = stride != 1 or in_channels != out_channels
        self.skip = (
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
            if needs_projection
            else nn.Identity()
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        shortcut = self.skip(x)
        residual = self.se(self.conv2(self.conv1(x)))
        return self.relu(residual + shortcut)


# ---------- Multi-scale (Inception-style) feature extractor ----------
class MultiScaleFeatureExtractor(nn.Module):
    """
    Four parallel branches — 1x1 conv, 3x3 conv, 5x5 conv, and 3x3 max-pool
    followed by a 1x1 conv — each producing out_channels // 4 channels.
    The branches are concatenated and passed through Squeeze-and-Excitation.
    The spatial size is preserved.

    Args:
        in_channels: input channel count.
        out_channels: output channel count; must be divisible by 4.

    Raises:
        ValueError: if out_channels is not divisible by 4.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if out_channels % 4 != 0:
            raise ValueError(f"out_channels must be divisible by 4, got {out_channels}")
        quarter = out_channels // 4

        self.branch_1x1 = nn.Sequential(
            nn.Conv2d(in_channels, quarter, 1, bias=False),
            nn.BatchNorm2d(quarter), nn.ReLU(inplace=True)
        )
        self.branch_3x3 = nn.Sequential(
            nn.Conv2d(in_channels, quarter, 3, padding=1, bias=False),
            nn.BatchNorm2d(quarter), nn.ReLU(inplace=True)
        )
        self.branch_5x5 = nn.Sequential(
            nn.Conv2d(in_channels, quarter, 5, padding=2, bias=False),
            nn.BatchNorm2d(quarter), nn.ReLU(inplace=True)
        )
        self.branch_pool = nn.Sequential(
            nn.MaxPool2d(3, stride=1, padding=1),
            nn.Conv2d(in_channels, quarter, 1, bias=False),
            nn.BatchNorm2d(quarter), nn.ReLU(inplace=True)
        )
        self.se = SqueezeExcitation(out_channels, reduction_ratio=16)

    def forward(self, x):
        b1 = self.branch_1x1(x)
        b2 = self.branch_3x3(x)
        b3 = self.branch_5x5(x)
        b4 = self.branch_pool(x)
        concat = torch.cat([b1, b2, b3, b4], dim=1)
        return self.se(concat)


# ---------- Spatial attention ----------
class SpatialAttention(nn.Module):
    """
    Spatial gate: stack the per-pixel channel max and channel mean, convolve
    them to one map, squash with a sigmoid, and multiply the input by it.

    Args:
        kernel_size: size of the 2->1 conv (padding kernel_size // 2 keeps H, W).
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        channel_max = x.max(dim=1, keepdim=True).values
        channel_mean = x.mean(dim=1, keepdim=True)
        stacked = torch.cat((channel_max, channel_mean), dim=1)
        gate = self.sigmoid(self.conv(stacked))
        return x * gate


# ---------- CBAM-style attention (SE channel gate, then spatial gate) ----------
class CBAM(nn.Module):
    """
    Sequential attention: a Squeeze-and-Excitation channel gate (average-pool
    based) followed by a SpatialAttention gate. Output shape equals input shape.

    Args:
        in_channels: channel count of the input.
        reduction_ratio: SE bottleneck reduction ratio.
    """

    def __init__(self, in_channels, reduction_ratio=16):
        super().__init__()
        self.channel_attention = SqueezeExcitation(in_channels, reduction_ratio)
        self.spatial_attention = SpatialAttention()

    def forward(self, x):
        return self.spatial_attention(self.channel_attention(x))


# ---------- Enhanced AdaptSRNet (configurable channel plan) ----------
class AdaptSRNetEnhanced(nn.Module):
    """
    SRM-preprocessed residual CNN for binary (cover vs. stego) steganalysis.

    Pipeline: LearnableSRMLayer -> two-conv 3x3 stem -> multi-scale block ->
    four residual stages (layer2..layer4 each halve H and W) -> CBAM ->
    concatenated global average + max pooling -> MLP classifier.

    NOTE: the default channel plan is far above the originally targeted 1-2M
    parameters (roughly 15M; the 384-channel layer4 3x3 convs alone hold about
    7.5M weights). Use ``count_parameters`` for the exact figure, and pass a
    narrower ``channels`` plan and/or smaller ``blocks`` to shrink the model.

    Args:
        num_classes: number of output logits (1 for BCEWithLogitsLoss).
        srm_filters: number of SRM kernels, i.e. channels entering the stem.
        channels: widths (init, multiscale, layer1, layer2, layer3, layer4);
            the multiscale width must be divisible by 4.
        blocks: number of residual blocks in layer1..layer4.

    Raises:
        ValueError: if ``channels`` does not have 6 entries or ``blocks`` 4.
    """

    def __init__(self, num_classes=1, srm_filters=64,
                 channels=(64, 128, 128, 192, 256, 384), blocks=(3, 3, 4, 3)):
        super().__init__()
        if len(channels) != 6:
            raise ValueError(
                f"channels must list 6 widths (init, multiscale, layer1..layer4), got {len(channels)}"
            )
        if len(blocks) != 4:
            raise ValueError(f"blocks must list 4 block counts (layer1..layer4), got {len(blocks)}")
        init_ch, ms_ch, l1_ch, l2_ch, l3_ch, l4_ch = channels
        n1, n2, n3, n4 = blocks

        # 1) Learnable SRM preprocessing: 1 -> srm_filters residual channels
        self.srm_layer = LearnableSRMLayer(num_filters=srm_filters)

        # 2) Two-conv stem
        self.init_conv = nn.Sequential(
            nn.Conv2d(srm_filters, init_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(init_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(init_ch, init_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(init_ch),
            nn.ReLU(inplace=True)
        )

        # 3) Multi-scale feature extraction
        self.multiscale = MultiScaleFeatureExtractor(init_ch, ms_ch)

        # 4) Residual stages (stride 2 at the start of layer2..layer4)
        self.layer1 = self._make_layer(ms_ch, l1_ch, num_blocks=n1, stride=1)
        self.layer2 = self._make_layer(l1_ch, l2_ch, num_blocks=n2, stride=2)
        self.layer3 = self._make_layer(l2_ch, l3_ch, num_blocks=n3, stride=2)
        self.layer4 = self._make_layer(l3_ch, l4_ch, num_blocks=n4, stride=2)

        # 5) Attention after layer4
        self.cbam = CBAM(l4_ch, reduction_ratio=16)

        # 6) Classifier head on concatenated avg + max pooled features
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.global_max_pool = nn.AdaptiveMaxPool2d((1, 1))

        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(l4_ch * 2, 256),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(256),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(128),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes)
        )

        # Re-initialise conv/BN/linear weights (the SRM bank is an nn.Parameter,
        # not an nn.Conv2d, so it keeps its hand-crafted initialisation).
        self._initialize_weights()

    def _make_layer(self, in_c, out_c, num_blocks, stride):
        """Stack residual blocks; only the first changes channels/stride (always at least one block)."""
        layers = [EnhancedResidualBlock(in_c, out_c, stride=stride, use_se=True)]
        for _ in range(1, num_blocks):
            layers.append(EnhancedResidualBlock(out_c, out_c, stride=1, use_se=True))
        return nn.Sequential(*layers)

    def _initialize_weights(self):
        """Kaiming-normal convs, unit/zero BatchNorm, N(0, 0.01) linears, zero biases."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if getattr(m, "weight", None) is not None:
                    nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if getattr(m, "bias", None) is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                if getattr(m, "weight", None) is not None:
                    nn.init.ones_(m.weight)
                if getattr(m, "bias", None) is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                if getattr(m, "weight", None) is not None:
                    nn.init.normal_(m.weight, 0, 0.01)
                if getattr(m, "bias", None) is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        # x: [B, 1, H, W]
        x = self.srm_layer(x)           # -> [B, srm_filters, H, W]
        x = self.init_conv(x)           # -> [B, init_ch, H, W]
        x = self.multiscale(x)          # -> [B, ms_ch, H, W]
        x = self.layer1(x)              # -> [B, l1_ch, H, W]
        x = self.layer2(x)              # -> [B, l2_ch, H/2, W/2]
        x = self.layer3(x)              # -> [B, l3_ch, H/4, W/4]
        x = self.layer4(x)              # -> [B, l4_ch, H/8, W/8]
        x = self.cbam(x)

        # Dual pooling: concatenate average- and max-pooled descriptors
        avg_pool = self.global_pool(x)
        max_pool = self.global_max_pool(x)
        x = torch.cat([avg_pool, max_pool], dim=1)

        x = torch.flatten(x, 1)         # -> [B, 2 * l4_ch]
        x = self.classifier(x)          # -> [B, num_classes]
        return x


# -------------------- Utility: parameter counts --------------------
def count_parameters(model):
    """Return ``(total, trainable)`` element counts over ``model.parameters()``."""
    total = trainable = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    return total, trainable

if __name__ == "__main__":
    # Smoke test: build the network, report its size and push one dummy batch through it.
    model = AdaptSRNetEnhanced(num_classes=1, srm_filters=64)
    total, trainable = count_parameters(model)
    for line in (
        f"Total params: {total:,}; Trainable: {trainable:,}",
        f"Model size: {total / 1e6:.2f}M parameters",
    ):
        print(line)

    x = torch.randn(2, 1, 256, 256)
    with torch.no_grad():
        y = model(x)
    print(f"Output shape: {y.shape}")
    print("Model ready for training!")


In [None]:
class StegoLightningModule(pl.LightningModule):
    """
    LightningModule for a single-logit cover (0) / stego (1) classifier.

    Trains with BCEWithLogitsLoss and AdamW + CosineAnnealingWarmRestarts.
    The torchmetrics objects accumulate over an epoch and are computed,
    logged and reset in the epoch-end hooks. As a result "val_acc"/"val_f1"
    (monitored by the callbacks) are true per-epoch scores: not averages of
    per-batch scores, and not accumulated across epochs.

    Args:
        model: network returning [B, 1] logits for a batch of images
            (assumed, as for AdaptSRNetEnhanced(num_classes=1)).
        lr: initial learning rate for AdamW.
    """

    def __init__(self, model, lr=1e-3):
        super().__init__()
        self.model = model
        self.criterion = nn.BCEWithLogitsLoss()
        self.lr = lr

        # Metrics for binary classification
        self.train_acc = Accuracy(task="binary")
        self.train_f1 = F1Score(task="binary")
        self.val_acc = Accuracy(task="binary")
        self.val_f1 = F1Score(task="binary")

        # Best epoch-level validation accuracy seen so far (plain float)
        self.best_val_acc = 0.0

    def forward(self, x):
        return self.model(x)

    def _shared_step(self, batch):
        """Return (loss, hard 0/1 predictions, integer targets) for a batch of (x, y, key)."""
        x, y, _ = batch
        target = y.float().unsqueeze(1)
        logits = self(x)
        loss = self.criterion(logits, target)
        preds = (torch.sigmoid(logits) > 0.5).long()
        return loss, preds, target.long()

    def training_step(self, batch, batch_idx):
        loss, preds, target = self._shared_step(batch)
        self.train_acc.update(preds, target)
        self.train_f1.update(preds, target)
        self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=True)
        return loss

    def on_train_epoch_end(self):
        # Epoch-level training metrics, then reset so the next epoch starts fresh.
        self.log("train_acc", self.train_acc.compute(), prog_bar=True)
        self.log("train_f1", self.train_f1.compute())
        self.train_acc.reset()
        self.train_f1.reset()

    def validation_step(self, batch, batch_idx):
        loss, preds, target = self._shared_step(batch)
        self.val_acc.update(preds, target)
        self.val_f1.update(preds, target)
        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return {"val_loss": loss}

    def on_validation_epoch_end(self):
        val_acc = self.val_acc.compute()
        val_f1 = self.val_f1.compute()
        # Reset here (also after the sanity check) so no batches leak into the next epoch.
        self.val_acc.reset()
        self.val_f1.reset()
        self.log("val_acc", val_acc, prog_bar=True)
        self.log("val_f1", val_f1, prog_bar=True)

        # The pre-training sanity check only runs a couple of batches; don't
        # let it set the "best" score.
        if self.trainer.sanity_checking:
            return
        acc = float(val_acc)
        if acc > self.best_val_acc:
            self.best_val_acc = acc
            print(f"\nNew best validation accuracy: {acc:.4f}")

    def configure_optimizers(self):
        # AdamW with weight decay for regularization
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.lr,
            weight_decay=1e-4,
            betas=(0.9, 0.999)
        )

        # Cosine annealing with warm restarts, stepped once per epoch
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            T_0=10,  # First restart after 10 epochs
            T_mult=2,  # Each subsequent cycle is twice as long
            eta_min=1e-6
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",
                "frequency": 1
            }
        }


In [None]:
import os

# CUDA_LAUNCH_BLOCKING=1 makes kernel launches synchronous. That helps pinpoint
# CUDA errors, but it slows training, and it must be set before the first CUDA
# call (i.e. at the very top of the notebook) to be reliable. It is therefore
# deliberately not forced here.

# Initialize enhanced model
model = AdaptSRNetEnhanced(num_classes=1, srm_filters=64)
lightning_model = StegoLightningModule(model, lr=LR)

# Print model statistics
total, trainable = count_parameters(model)
print(f"\n{'='*60}")
print("Model: AdaptSRNetEnhanced")
print(f"Total Parameters: {total:,}")
print(f"Trainable Parameters: {trainable:,}")
print(f"Model Size: {total / 1e6:.2f}M parameters")
print(f"{'='*60}\n")

# Keep the 3 best checkpoints by validation accuracy
checkpoint_cb = ModelCheckpoint(
    monitor="val_acc",
    mode="max",
    save_top_k=3,
    filename="adaptsrnet-{epoch:02d}-{val_acc:.4f}-{val_f1:.4f}",
    verbose=True
)

# NOTE(review): configure_optimizers uses CosineAnnealingWarmRestarts(T_0=10), so
# the learning rate jumps back up after epoch 10. A post-restart dip can use up
# this patience window — consider a larger patience if training stops early.
early_stop_cb = EarlyStopping(
    monitor="val_acc",
    mode="max",
    patience=10,
    verbose=True,
    min_delta=0.001  # Minimum improvement threshold
)

# Learning rate monitor
from pytorch_lightning.callbacks import LearningRateMonitor
lr_monitor = LearningRateMonitor(logging_interval='epoch')

# Use the device detected in the config cell instead of hard-coding "gpu", and
# fall back to full precision on CPU (fp16 mixed precision is meant for GPUs).
use_gpu = DEVICE == "cuda"
trainer = Trainer(
    accelerator="gpu" if use_gpu else "cpu",
    devices=1,
    precision="16-mixed" if use_gpu else "32-true",
    max_epochs=EPOCHS,
    callbacks=[checkpoint_cb, early_stop_cb, lr_monitor],
    log_every_n_steps=10,
    gradient_clip_val=1.0,  # Gradient clipping to prevent exploding gradients
    accumulate_grad_batches=1,  # Can increase for larger effective batch size
    deterministic=False,  # Set to True for reproducibility (slower)
    enable_progress_bar=True
)

print("\nStarting training...")
print(f"Batch size: {BATCH_SIZE}")
print(f"Max epochs: {EPOCHS}")
print(f"Initial learning rate: {LR}")
# Each split entry is a cover/stego pair, i.e. two images in the dataset.
print(f"Training pairs: {len(splits['train'])} ({len(train_loader.dataset)} images)")
print(f"Validation pairs: {len(splits['val'])} ({len(val_loader.dataset)} images)\n")

trainer.fit(lightning_model, train_loader, val_loader)

print("\nTraining completed!")
print(f"Best validation accuracy: {float(lightning_model.best_val_acc):.4f}")