In [7]:
import torch
import torch.nn as nn
import torchvision.models as models


class CustomCNNFusion(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),  # [B, 32, 224, 224]
            nn.ReLU(),
            nn.MaxPool2d(2),                             # [B, 32, 112, 112]
            
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),                             # [B, 64, 56, 56]

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),                             # [B, 128, 28, 28]
        )
        
        self.flatten = nn.Flatten()

        self.fc = nn.Sequential(
            nn.Linear(128 * 28 * 28 + 8, 128),  # +8 pour les features statistiques
            nn.ReLU(),
            nn.Linear(128, 2)  # 2 classes : normal / steg
        )

    def forward(self, x_img, x_stats):
        x = self.cnn(x_img)
        x = self.flatten(x)              # [B, 128 * 28 * 28]
        x = torch.cat((x, x_stats), dim=1)  # fusion avec features stats
        return self.fc(x)
    



class ResStatFusion(nn.Module):
    def __init__(self):
        super().__init__()

        base_model = models.resnet18(pretrained=True)
        self.cnn = nn.Sequential(*list(base_model.children())[:-1])  # [B, 512, 1, 1]

        self.stat_fc = nn.Sequential(
            nn.Linear(8, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU()
        )

        self.final_fc = nn.Sequential(
            nn.Linear(512 + 64, 128),
            nn.ReLU(),
            nn.Linear(128, 2)
        )

    def forward(self, image, stat_feats):
        # CNN branch
        cnn_feat = self.cnn(image).squeeze()  # [B, 512]
        if cnn_feat.dim() == 1:
            cnn_feat = cnn_feat.unsqueeze(0)

        # MLP branch
        stat_feat = self.stat_fc(stat_feats)  # [B, 64]

        fusion = torch.cat((cnn_feat, stat_feat), dim=1)  # [B, 576]
        out = self.final_fc(fusion)
        return out




Ce dataset retourne trois éléments :

📷 L’image transformée (Tensor)

📊 Les features statistiques extraits de cette image

🏷️ Le label (0 = normal, 1 = steg)

In [8]:
import os
import torch
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
from scipy.stats import kurtosis, skew
from torchvision import transforms

class FusionFeatureDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        """
        root_dir: dossier contenant deux sous-dossiers 'normal' et 'stego'
        transform: transformations à appliquer à l'image (ex: resize, normalize)
        """
        self.image_paths = []
        self.labels = []
        self.transform = transform

        for label_name in os.listdir(root_dir):
            label_dir = os.path.join(root_dir, label_name)
            label = 0 if label_name.lower() == 'normal' else 1

            for img_name in os.listdir(label_dir):
                img_path = os.path.join(label_dir, img_name)
                self.image_paths.append(img_path)
                self.labels.append(label)

    def extract_stat_features(self, img):
        """
        Extrait 8 features statistiques d'une image grayscale
        """
        img_gray = img.convert('L')
        data = np.asarray(img_gray).flatten()

        std = np.std(data)
        range_val = np.max(data) - np.min(data)
        median = np.median(data)
        geo_median = np.exp(np.mean(np.log(data + 1e-5)))  # +eps pour log(0)
        skewness = skew(data)
        kurt = kurtosis(data)
        d1 = np.diff(data)
        d2 = np.diff(d1)
        var0 = np.var(data)
        var1 = np.var(d1)
        var2 = np.var(d2)
        mobility = np.sqrt(var1 / var0)
        complexity = np.sqrt((var2 / var1) - (var1 / var0))

        return torch.tensor([std, range_val, median, geo_median, skewness, kurt, mobility, complexity], dtype=torch.float32)

    def __getitem__(self, index):
        img_path = self.image_paths[index]
        label = self.labels[index]
        img = Image.open(img_path).convert('RGB')

        stat_features = self.extract_stat_features(img)

        if self.transform:
            img = self.transform(img)

        return img, stat_features, torch.tensor(label, dtype=torch.long)

    def __len__(self):
        return len(self.image_paths)


In [9]:
import random
from torchvision import transforms
import torchvision.transforms.functional as TF

class RandomRotation:
    def __call__(self, img):
        angles = [0, 90, 180, 270]
        angle = random.choice(angles)
        return TF.rotate(img, angle)


transform = transforms.Compose([
    #RandomRotation(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406], 
        std=[0.229, 0.224, 0.225]
    )
])

train_dir = 'stegoimagesdataset/train/train/'
val_dir = 'stegoimagesdataset/val/val/'
test_dir = 'stegoimagesdataset/test/test/'

train_dataset = FusionFeatureDataset(root_dir=train_dir, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

val_dataset = FusionFeatureDataset(root_dir=val_dir, transform=transform)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False)

test_dataset = FusionFeatureDataset(root_dir=test_dir, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)

# Exemple de batch :
for img_batch, stat_batch, label_batch in train_loader:
    print(img_batch.shape)       # [32, 3, 224, 224]
    print(stat_batch.shape)      # [32, 8]
    print(label_batch.shape)     # [32]
    break


  complexity = np.sqrt((var2 / var1) - (var1 / var0))


torch.Size([32, 3, 224, 224])
torch.Size([32, 8])
torch.Size([32])


In [10]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [11]:
def evaluate_model(model, val_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, stat_feats, labels in val_loader:
            images, stat_feats, labels = images.to(device), stat_feats.to(device), labels.to(device)
            outputs = model(images, stat_feats)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    acc = correct / total * 100
    print(f"Validation Accuracy: {acc:.2f}%\n")


In [None]:
def train_model(model, train_loader, val_loader, num_epochs, device, patience, save_path):
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    best_val_acc = 0.0
    epochs_without_improvement = 0

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for images, stat_feats, labels in train_loader:
            images, stat_feats, labels = images.to(device), stat_feats.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images, stat_feats)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * labels.size(0)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

        epoch_loss = running_loss / total
        epoch_acc = correct / total * 100
        print(f"[Train] Epoch {epoch+1}/{num_epochs} | Loss: {epoch_loss:.4f} | Accuracy: {epoch_acc:.2f}%")

        model.eval()
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for val_images, val_feats, val_labels in val_loader:
                val_images, val_feats, val_labels = val_images.to(device), val_feats.to(device), val_labels.to(device)
                val_outputs = model(val_images, val_feats)
                _, val_pred = torch.max(val_outputs, 1)
                val_correct += (val_pred == val_labels).sum().item()
                val_total += val_labels.size(0)

        val_acc = val_correct / val_total * 100
        print(f"[Validation] Accuracy: {val_acc:.2f}%")

        # === Early Stopping & Best Model Saving ===
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            epochs_without_improvement = 0
            torch.save(model.state_dict(), save_path)
            print("Best model saved.")
        else:
            epochs_without_improvement += 1
            print(f"No improvement. ({epochs_without_improvement}/{patience})")

        if epochs_without_improvement >= patience:
            print("Early stopping triggered.")
            break

    print(f"Training finished. Best Validation Accuracy: {best_val_acc:.2f}%")
    return model


In [None]:
model = CustomCNNFusion()
model2 = ResStatFusion()
trained_model = train_model(model, train_loader, val_loader, num_epochs=30, device=device, patience=5, save_path='best_model_CNNFUSION.pth')
trained_model2 = train_model(model2, train_loader, val_loader, num_epochs=30, device=device, patience=5, save_path='best_model_ResStatFusion.pth')

In [None]:
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import matplotlib.pyplot as plt

def test_model(model_class, test_loader, model_path, device):
    model = model_class().to(device)
    model.load_state_dict(torch.load(model_path))
    model.eval()

    y_true = []
    y_pred = []

    with torch.no_grad():
        for images, stat_feats, labels in test_loader:
            images, stat_feats, labels = images.to(device), stat_feats.to(device), labels.to(device)
            outputs = model(images, stat_feats)
            _, predicted = torch.max(outputs, 1)

            y_true.extend(labels.cpu().numpy())
            y_pred.extend(predicted.cpu().numpy())

    cm = confusion_matrix(y_true, y_pred)
    labels_names = ['Normal (0)', 'Stego (1)']

    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=labels_names, yticklabels=labels_names)
    plt.xlabel('Prédit')
    plt.ylabel('Vrai')
    plt.title('Matrice de confusion')
    plt.show()

    print("Rapport de classification :")
    print(classification_report(y_true, y_pred, target_names=labels_names))

#test_model(CustomCNNFusion, test_loader, 'best_model_CNNFUSION.pth', device)
#test_model(ResStatFusion, test_loader, 'best_model_ResStatFusion.pth', device)