In [1]:
#!/usr/bin/env python3
import argparse
import random
import numpy as np
import pandas as pd
from pathlib import Path
from pkg_resources import load_entry_point
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

from torchvision import transforms
from torchvision.models import resnet18, ResNet18_Weights
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    confusion_matrix, classification_report,
    roc_auc_score, roc_curve, ConfusionMatrixDisplay
)
from tqdm import tqdm
import matplotlib.pyplot as plt

# ──────────────────────────────────────────────────────────────────────────────
# 0) CLI arguments
# ──────────────────────────────────────────────────────────────────────────────
parser = argparse.ArgumentParser(description="Stego Classification Ablation")
parser.add_argument(
    "--model", choices=["baseline","hp","twostream"],
    default="baseline", help="Which architecture to train"
)
args = parser.parse_args()

# ──────────────────────────────────────────────────────────────────────────────
# 1) Configuration & reproducibility
# ──────────────────────────────────────────────────────────────────────────────
STEGO_CSV   = Path("csv/stego_final.csv")
IMAGES_DIR  = Path("Images")
BATCH_SIZE  = 16
EPOCHS      = 15
LR          = 1e-3
RANDOM_SEED = 42
STEP_SIZE   = 5
GAMMA       = 0.1

random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ──────────────────────────────────────────────────────────────────────────────
# 2) DataFrame building & splitting
# ──────────────────────────────────────────────────────────────────────────────
df = pd.read_csv(STEGO_CSV)
df = df[df.method.isin(["lsb","iwt"])].copy()
df["binary_method"] = "stego"
df["img_path"]      = df["stego_path"]

n_stego   = len(df)
all_imgs  = list(IMAGES_DIR.glob("*"))
used      = set(Path(p).name for p in df["img_path"])
candidates= [str(p) for p in all_imgs if p.name not in used]
none_samp = random.sample(candidates, n_stego)

df_none = pd.DataFrame({
    "binary_method": ["none"]*n_stego,
    "img_path":      none_samp
})
df = pd.concat([df[["binary_method","img_path"]], df_none], ignore_index=True)
df = df.sample(frac=1, random_state=RANDOM_SEED).reset_index(drop=True)

df_train, df_test = train_test_split(
    df, test_size=0.30, stratify=df["binary_method"], random_state=RANDOM_SEED
)
df_train, df_val = train_test_split(
    df_train, test_size=0.20, stratify=df_train["binary_method"], random_state=RANDOM_SEED
)

# ──────────────────────────────────────────────────────────────────────────────
# 3) Dataset & DataLoader
# ──────────────────────────────────────────────────────────────────────────────
label_map = {"none":0, "stego":1}
train_tf = transforms.Compose([
    transforms.Resize((256,256)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])
val_tf = transforms.Compose([
    transforms.Resize((256,256)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])

class StegoDataset(Dataset):
    def __init__(self, df, tf): 
        self.df, self.tf = df.reset_index(drop=True), tf
    def __len__(self): return len(self.df)
    def __getitem__(self, i):
        row = self.df.iloc[i]
        img = Image.open(row.img_path).convert("RGB")
        if self.tf: img = self.tf(img)
        lbl = label_map[row.binary_method]
        return img, lbl

train_loader = DataLoader(
    StegoDataset(df_train, train_tf),
    batch_size=BATCH_SIZE, shuffle=True,  num_workers=4, pin_memory=True
)
val_loader = DataLoader(
    StegoDataset(df_val, val_tf),
    batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True
)
test_loader = DataLoader(
    StegoDataset(df_test, val_tf),
    batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True
)



  from pkg_resources import load_entry_point
2025-05-01 21:35:50.280887: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-05-01 21:35:50.291641: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746135350.304808   55524 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746135350.308883   55524 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-05-01 21:35:50.322359: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow b

SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [None]:
# ──────────────────────────────────────────────────────────────────────────────
# 4) Model definitions
# ──────────────────────────────────────────────────────────────────────────────
# 4.1) Baseline
def make_baseline():
    m = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    m.fc = nn.Linear(m.fc.in_features, 2)
    return m

# 4.2) HPNet
class HPNet(nn.Module):
    def __init__(self):
        super().__init__()
        # high-pass kernel
        hp = torch.tensor([[-1,2,-1],[2,-4,2],[-1,2,-1]],dtype=torch.float32)
        hp = hp.view(1,1,3,3)
        self.hp = nn.Conv2d(3,1,3,1,bias=False)
        with torch.no_grad():
            self.hp.weight[:] = hp.repeat(1,3,1,1)
            self.hp.weight.requires_grad = False
        base = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        self.backbone = base
        self.backbone.fc = nn.Linear(base.fc.in_features, 2)
    def forward(self, x):
        r = self.hp(x)             # → [B,1,H,W]
        r = r.repeat(1,3,1,1)      # → [B,3,H,W]
        return self.backbone(r)

# 4.3) Two-Stream
class TwoStreamNet(nn.Module):
    def __init__(self):
        super().__init__()
        # hp filter
        hp = torch.tensor([[-1,2,-1],[2,-4,2],[-1,2,-1]],dtype=torch.float32)
        hp = hp.view(1,1,3,3)
        self.hp_f = nn.Conv2d(3,1,3,1,bias=False)
        with torch.no_grad():
            self.hp_f.weight[:] = hp.repeat(1,3,1,1)
            self.hp_f.weight.requires_grad=False
        # raw backbone
        b1 = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        feat_dim = b1.fc.in_features
        b1.fc = nn.Identity()
        # hp backbone
        b2 = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        b2.fc = nn.Identity()
        self.raw_b = b1
        self.hp_b  = b2
        self.cls   = nn.Linear(feat_dim*2, 2)
    def forward(self, x):
        raw_feat = self.raw_b(x)
        hp_img   = self.hp_f(x).repeat(1,3,1,1)
        hp_feat  = self.hp_b(hp_img)
        cat      = torch.cat([raw_feat, hp_feat], dim=1)
        return self.cls(cat)

# instantiate chosen model
if args.model=="baseline":
    model = make_baseline().to(DEVICE)
elif args.model=="hp":
    model = HPNet().to(DEVICE)
else:
    model = TwoStreamNet().to(DEVICE)

# ──────────────────────────────────────────────────────────────────────────────
# 5) Loss, optimizer, scheduler, logging
# ──────────────────────────────────────────────────────────────────────────────
weights   = torch.tensor([1.0,1.5],device=DEVICE)
criterion = nn.CrossEntropyLoss(weight=weights)
optimizer = Adam(model.parameters(), lr=LR)
scheduler = StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA)

writer       = SummaryWriter(f"runs/stego_{args.model}")
train_losses = []; val_losses = []
train_accs   = []; val_accs   = []

# ──────────────────────────────────────────────────────────────────────────────
# 6) Train / Val loops
# ──────────────────────────────────────────────────────────────────────────────
def train_epoch():
    model.train()
    ls=correct=total=0
    for imgs,lbls in tqdm(train_loader, desc="Train", leave=False):
        imgs,lbls = imgs.to(DEVICE), lbls.to(DEVICE)
        optimizer.zero_grad()
        out = model(imgs)
        loss = criterion(out,lbls)
        loss.backward(); optimizer.step()
        ls     += loss.item()*lbls.size(0)
        preds   = out.argmax(1)
        correct+= (preds==lbls).sum().item()
        total  += lbls.size(0)
    return ls/total, correct/total

@torch.no_grad()
def validate():
    model.eval()
    ls=correct=total=0
    for imgs,lbls in tqdm(val_loader, desc="Val  ", leave=False):
        imgs,lbls = imgs.to(DEVICE), lbls.to(DEVICE)
        out = model(imgs)
        loss=criterion(out,lbls)
        ls      += loss.item()*lbls.size(0)
        preds   = out.argmax(1)
        correct+= (preds==lbls).sum().item()
        total  += lbls.size(0)
    return ls/total, correct/total

best_acc=0.0
for e in range(1, EPOCHS+1):
    tr_loss, tr_acc = train_epoch()
    val_loss, val_acc = validate()
    scheduler.step()

    train_losses.append(tr_loss); val_losses.append(val_loss)
    train_accs.append(tr_acc);   val_accs.append(val_acc)

    writer.add_scalar("Loss/train", tr_loss, e)
    writer.add_scalar("Loss/val",   val_loss, e)
    writer.add_scalar("Acc/train",  tr_acc, e)
    writer.add_scalar("Acc/val",    val_acc, e)

    print(f"[{args.model}] Epoch {e}/{EPOCHS}  "
          f"Tr: {tr_loss:.4f}/{tr_acc:.3f}  "
          f"Val: {val_loss:.4f}/{val_acc:.3f}")

    if val_acc>best_acc:
        best_acc=val_acc
        torch.save(model.state_dict(), f"best_{args.model}.pth")
        print(" → new best!")

writer.close()

# ──────────────────────────────────────────────────────────────────────────────
# 7) Plot curves
# ──────────────────────────────────────────────────────────────────────────────
plt.figure()
plt.plot(range(1,EPOCHS+1), train_losses, label="Train Loss", marker='o')
plt.plot(range(1,EPOCHS+1), val_losses,   label="Val   Loss", marker='o')
plt.title("Loss"); plt.xlabel("Epoch"); plt.legend(); plt.show()

plt.figure()
plt.plot(range(1,EPOCHS+1), train_accs, label="Train Acc", marker='o')
plt.plot(range(1,EPOCHS+1), val_accs,   label="Val   Acc", marker='o')
plt.title("Accuracy"); plt.xlabel("Epoch"); plt.legend(); plt.show()

# ──────────────────────────────────────────────────────────────────────────────
# 8) Test inference & metrics
# ──────────────────────────────────────────────────────────────────────────────
model.load_state_dict(torch.load(f"best_{args.model}.pth"))
model.eval()

ALL_L, ALL_P, ALL_PROB = [], [], []
with torch.no_grad():
    for imgs,lbls in test_loader:
        imgs = imgs.to(DEVICE)
        out  = model(imgs)
        prob = torch.softmax(out,dim=1)[:,1].cpu().numpy()
        pred = out.argmax(1).cpu().numpy()

        ALL_L.extend(lbls.numpy())
        ALL_P.extend(pred)
        ALL_PROB.extend(prob)

# optimal threshold
fpr,tpr,thr = roc_curve(ALL_L, ALL_PROB)
opt = thr[np.argmax(tpr-fpr)]
print("Optimal thr:", opt)

# confusion @ 0.5
cm = confusion_matrix(ALL_L, ALL_P)
ConfusionMatrixDisplay(cm, display_labels=["none","stego"]).plot(cmap="Blues")
plt.title("Confusion Matrix (0.5)"); plt.show()

print(classification_report(ALL_L, ALL_P, target_names=["none","stego"]))

# ROC
auc = roc_auc_score(ALL_L, ALL_PROB)
plt.figure()
plt.plot(fpr,tpr, label=f"AUC={auc:.3f}")
plt.plot([0,1],[0,1],"k--")
plt.title("ROC Curve"); plt.xlabel("FPR"); plt.ylabel("TPR"); plt.legend(); plt.show()
