## Change some Images from RGBA to RGB

In [16]:
import os
from PIL import Image

# Root directories of the train / validation / test splits
# (relative paths, resolved against the current working directory).
train_dir = 'stegoimagesdataset/train/train'
val_dir = 'stegoimagesdataset/val/val'
test_dir = 'stegoimagesdataset/test/test'

def find_non_rgb_images(directories):
    """Scan directory trees for PNG files that are not stored in RGB mode.

    Every file whose name ends in ``.png`` (case-insensitive) below each
    given directory is opened with PIL. Files that cannot be opened are
    reported alongside the non-RGB ones, with the error text taking the
    place of the mode.

    Args:
        directories: Iterable of root directories, each walked recursively.

    Returns:
        A ``(rgb_count, non_rgb)`` tuple: the number of PNG files in RGB
        mode, and a list of ``(path, mode_or_error)`` tuples for all others.
    """
    flagged = []
    n_rgb = 0

    for top in directories:
        for folder, _subdirs, names in os.walk(top):
            for name in names:
                if not name.lower().endswith('.png'):
                    continue

                full_path = os.path.join(folder, name)
                try:
                    with Image.open(full_path) as image:
                        mode = image.mode
                except Exception as exc:
                    # Unreadable files are reported rather than aborting the scan.
                    flagged.append((full_path, f'Error opening: {exc}'))
                    continue

                if mode == 'RGB':
                    n_rgb += 1
                else:
                    flagged.append((full_path, mode))

    return n_rgb, flagged

# Run the check over the three dataset splits.
dirs = [train_dir, val_dir, test_dir]
rgb_count, non_rgb_images = find_non_rgb_images(dirs)

# Summary counts first, then one line per non-RGB or unreadable file
# (the second tuple element is either the PIL mode or the error text).
print(f"Total RGB images: {rgb_count}")
print(f"Images with non-RGB modes or errors: {len(non_rgb_images)}")
for path, mode in non_rgb_images:
    print(f"  {mode}: {path}")



Total RGB images: 32000
Images with non-RGB modes or errors: 0


In [None]:
import os
from PIL import Image

def convert_clean_rgba_to_rgb(root_dirs):
    """Convert RGBA PNG images in each root's ``clean`` subfolder to RGB, in place.

    Only the immediate ``clean`` subdirectory of every root is listed (no
    recursion); roots that have no such subdirectory are skipped. A
    converted image is written back to its original path. A failure on one
    file is printed and the scan moves on to the next file.

    Args:
        root_dirs: Iterable of dataset root directories.
    """
    clean_dirs = (os.path.join(root, 'clean') for root in root_dirs)
    for folder in filter(os.path.isdir, clean_dirs):
        for name in os.listdir(folder):
            if name.lower().endswith('.png'):
                target = os.path.join(folder, name)
                try:
                    with Image.open(target) as image:
                        if image.mode != 'RGBA':
                            continue
                        # Overwrites the source file with the RGB version.
                        image.convert('RGB').save(target)
                        print(f"Converted RGBA → RGB: {target}")
                except Exception as err:
                    print(f"Error processing {target}: {err}")

# Example usage: convert the 'clean' subfolder of each dataset split.
train_dir = 'stegoimagesdataset/train/train'
val_dir   = 'stegoimagesdataset/val/val'
test_dir  = 'stegoimagesdataset/test/test'

convert_clean_rgba_to_rgb([train_dir, val_dir, test_dir])


## Start

### Custom Models

In [1]:
import torch
import torch.nn as nn
import torchvision.models as models


class CustomCNNFusion(nn.Module):
    """Small CNN whose flattened feature map is fused with statistical features.

    Three ``Conv2d -> ReLU -> MaxPool2d(2)`` stages produce a 128-channel
    feature map. It is flattened, concatenated with the per-image
    statistical feature vector and classified by a two-layer MLP.

    Args:
        num_stat_features: Length of the statistical feature vector that is
            concatenated to the CNN features (default 8).
        num_classes: Number of output logits (default 2: normal / stego).
        image_size: Height and width of the square input images
            (default 224). The first linear layer is sized from this value,
            so inputs of any other size will not fit.

    Raises:
        ValueError: If ``image_size`` is too small to survive the three
            pooling stages.
    """

    def __init__(self, num_stat_features=8, num_classes=2, image_size=224):
        super().__init__()

        # Each MaxPool2d(2) halves the spatial size (rounding down), so the
        # three stages reduce image_size to image_size // 8.
        pooled_size = image_size // 8
        if pooled_size < 1:
            raise ValueError(
                f"image_size must be at least 8, got {image_size}"
            )

        # Shape comments are for the default 224x224 input.
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),   # [B, 32, 224, 224]
            nn.ReLU(),
            nn.MaxPool2d(2),                              # [B, 32, 112, 112]

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),                              # [B, 64, 56, 56]

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),                              # [B, 128, 28, 28]
        )

        self.flatten = nn.Flatten()

        # The classifier input is the flattened feature map plus the
        # statistical features appended in forward().
        flat_features = 128 * pooled_size * pooled_size
        self.fc = nn.Sequential(
            nn.Linear(flat_features + num_stat_features, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x_img, x_stats):
        """Return class logits for a batch of images and their statistical features.

        Args:
            x_img: Image batch of shape ``[B, 3, image_size, image_size]``.
            x_stats: Statistical features of shape ``[B, num_stat_features]``.

        Returns:
            Logits of shape ``[B, num_classes]``.
        """
        x = self.cnn(x_img)
        x = self.flatten(x)                 # [B, 128 * (image_size // 8) ** 2]
        x = torch.cat((x, x_stats), dim=1)  # fuse with the statistical features
        return self.fc(x)
    



class ResStatFusion(nn.Module):
    """Two-branch classifier: ResNet-18 image features fused with an MLP over statistical features.

    The image branch is a torchvision ResNet-18 with its final fully
    connected layer (the 1000-class ImageNet head) removed, so it yields one
    512-dimensional vector per image. The statistics branch maps the 8
    input features to 64 dimensions. Both vectors are concatenated
    (576 values) and classified into 2 logits.

    Args:
        pretrained: When True (default) the backbone starts from the
            ``IMAGENET1K_V1`` weights. Pass False to build the backbone
            without pretrained weights, e.g. when a saved state dict is
            loaded right afterwards.
    """

    def __init__(self, pretrained=True):
        super().__init__()

        backbone_weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
        base_model = models.resnet18(weights=backbone_weights)

        # Keep every child except the last one (the fully connected ImageNet
        # classifier); what remains is the feature-extraction pipeline.
        #
        # Why 512? The last convolutional stage of ResNet-18 has 512 filters,
        # each producing one feature map. The average-pooling layer that
        # follows reduces every map to a single value, so the output here has
        # shape [B, 512, 1, 1], i.e. 512 features per image once flattened.
        self.cnn = nn.Sequential(*list(base_model.children())[:-1])

        # MLP branch for the statistical features.
        self.stat_fc = nn.Sequential(
            nn.Linear(8, 64),   # [B, 8] -> [B, 64]
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU()
        )

        # Fusion head applied to the concatenated branches.
        self.final_fc = nn.Sequential(
            nn.Linear(512 + 64, 128),
            nn.ReLU(),
            nn.Linear(128, 2)
        )

    def forward(self, image, stat_feats):
        """Return class logits for a batch of images and their statistical features.

        Args:
            image: Image batch fed to the ResNet-18 backbone.
            stat_feats: Statistical features of shape ``[B, 8]``.

        Returns:
            Logits of shape ``[B, 2]``.
        """
        # CNN branch. Flatten from dim 1 so the batch axis is always kept;
        # a bare squeeze() would also drop it when the batch size is 1.
        cnn_feat = torch.flatten(self.cnn(image), start_dim=1)  # [B, 512]

        # MLP branch.
        stat_feat = self.stat_fc(stat_feats)                    # [B, 64]

        fusion = torch.cat((cnn_feat, stat_feat), dim=1)        # [B, 576]
        return self.final_fc(fusion)                            # [B, 2]




### Initialization of the sets and loaders

In [2]:
import random
from torchvision import transforms
import torchvision.transforms.functional as TF
from custom_dataset import FusionFeatureDataset
# Turn on cuDNN's benchmark mode (auto-selection of convolution algorithms);
# inputs here are resized to a fixed 224x224 by the transform defined in this cell.
torch.backends.cudnn.benchmark = True 

class RandomRotation:
    """Callable transform that rotates an image by a random multiple of 90 degrees."""

    def __call__(self, img):
        # Pick one of the four right-angle rotations uniformly at random.
        chosen_angle = random.choice([0, 90, 180, 270])
        return TF.rotate(img, chosen_angle)


# Preprocessing applied to every image: resize to 224x224, convert to a
# tensor, then normalize each channel with the mean/std below.
# NOTE(review): these look like the usual ImageNet statistics — confirm.
transform = transforms.Compose([
    #RandomRotation(),  # rotation augmentation, currently disabled
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406], 
        std=[0.229, 0.224, 0.225]
    )
])

# Root directories of the three dataset splits.
train_dir = 'stegoimagesdataset/train/train/'
val_dir = 'stegoimagesdataset/val/val/'
test_dir = 'stegoimagesdataset/test/test/'

# Training loader: shuffled, pinned memory, 4 persistent worker processes.
train_dataset = FusionFeatureDataset(root_dir=train_dir, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, pin_memory=True, num_workers=4, persistent_workers=True, prefetch_factor=2)

# Validation and test loaders keep the dataset order and use the DataLoader
# defaults for worker and memory settings.
val_dataset = FusionFeatureDataset(root_dir=val_dir, transform=transform)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=64, shuffle=False)

test_dataset = FusionFeatureDataset(root_dir=test_dir, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

# Sanity check: each batch unpacks to (image, statistical features, label);
# print the tensor shapes of the first training batch only.
for i, (img, stat_feats, label) in enumerate(train_loader):
    print(f"Batch {i+1}:")
    print(f"  Image shape: {img.shape}")
    print(f"  Stat features shape: {stat_feats.shape}")
    print(f"  Label shape: {label.shape}")
    break  # Just to show the first batch

Batch 1:
  Image shape: torch.Size([64, 3, 224, 224])
  Stat features shape: torch.Size([64, 8])
  Label shape: torch.Size([64])


In [3]:
import torch
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [8]:
import time
from torch.amp import GradScaler, autocast

# Hard-coded number of samples per class, indexed by class label.
# NOTE(review): presumably training-set counts with index 0 = normal and
# index 1 = stego — confirm against the dataset.
class_counts = [4000, 12000]
total = sum(class_counts)
# Inverse-frequency weighting: the rarer class gets the larger weight.
class_weights = [total / c for c in class_counts]  # inverse frequency

# Weight tensor on `device`; train_model reads it for its CrossEntropyLoss.
weights = torch.tensor(class_weights, dtype=torch.float32).to(device)

def train_model(model, train_loader, val_loader, num_epochs, device, patience, save_path):
    model = model.to(device)
    criterion = nn.CrossEntropyLoss(weight=weights)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    best_val_acc = 0.0
    epochs_without_improvement = 0

    scaler = GradScaler() 
    print(f"Starting training for model: {model.__class__.__name__} at {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())} for {num_epochs} epochs")
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for images, stat_feats, labels in train_loader:
            images, stat_feats, labels = images.to(device, non_blocking=True), stat_feats.to(device, non_blocking=True), labels.to(device, non_blocking=True)

            optimizer.zero_grad()
            with autocast(device_type='cuda'):
                outputs = model(images, stat_feats)
                loss = criterion(outputs, labels)
            #loss.backward()
            #optimizer.step()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item() * labels.size(0)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
            print(total)

        epoch_loss = running_loss / total
        epoch_acc = correct / total * 100
        print(f"[Train] Epoch {epoch+1}/{num_epochs} | Loss: {epoch_loss:.4f} | Accuracy: {epoch_acc:.2f}%")

        model.eval()
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for val_images, val_feats, val_labels in val_loader:
                val_images, val_feats, val_labels = val_images.to(device), val_feats.to(device), val_labels.to(device)
                val_outputs = model(val_images, val_feats)
                _, val_pred = torch.max(val_outputs, 1)
                val_correct += (val_pred == val_labels).sum().item()
                val_total += val_labels.size(0)

        val_acc = val_correct / val_total * 100
        print(f"[Validation] Accuracy: {val_acc:.2f}%")

        # === Early Stopping & Best Model Saving ===
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            epochs_without_improvement = 0
            torch.save(model.state_dict(), save_path)
            print("Best model saved.")
        else:
            epochs_without_improvement += 1
            print(f"No improvement. ({epochs_without_improvement}/{patience})")

        if epochs_without_improvement >= patience:
            print("Early stopping triggered.")
            break

    print(f"Training finished. Best Validation Accuracy: {best_val_acc:.2f}%")
    return model


In [9]:
# Instantiate both fusion models, then train them one after the other on the
# same loaders; each run saves its best checkpoint to its own file.
model = CustomCNNFusion()
model2 = ResStatFusion()
trained_model = train_model(model, train_loader, val_loader, num_epochs=30, device=device, patience=5, save_path='best_model_CNNFUSION.pth')
trained_model2 = train_model(model2, train_loader, val_loader, num_epochs=30, device=device, patience=5, save_path='best_model_ResStatFusion.pth')

Starting training for model: CustomCNNFusion at 2025-04-25 17:03:17 for 30 epochs
64
128
192
256
320
384
448
512
576
640
704
768
832
896
960
1024
1088
1152
1216
1280
1344
1408
1472
1536
1600
1664
1728
1792
1856
1920
1984
2048
2112
2176
2240
2304
2368
2432
2496
2560
2624
2688
2752
2816
2880
2944
3008
3072
3136
3200
3264
3328
3392
3456
3520
3584
3648
3712
3776
3840
3904
3968
4032
4096
4160
4224
4288
4352
4416
4480
4544
4608
4672
4736
4800
4864
4928
4992
5056
5120
5184
5248
5312
5376
5440
5504
5568
5632
5696
5760
5824
5888
5952
6016
6080
6144
6208
6272
6336
6400
6464
6528
6592
6656
6720
6784
6848
6912
6976
7040
7104
7168
7232
7296
7360
7424
7488
7552
7616
7680
7744
7808
7872
7936
8000
8064
8128
8192
8256
8320
8384
8448
8512
8576
8640
8704
8768
8832
8896
8960
9024
9088
9152
9216
9280
9344
9408
9472
9536
9600
9664
9728
9792
9856
9920
9984
10048
10112
10176
10240
10304
10368
10432
10496
10560
10624
10688
10752
10816
10880
10944
11008
11072
11136
11200
11264
11328
11392
11456
11520
11584
1164

KeyboardInterrupt: 

In [None]:
# Rebuild the test dataset and loader, this time with batch_size=32.
# NOTE(review): `device=device` is passed here but not in the earlier
# FusionFeatureDataset calls — confirm the dataset accepts this argument.
test_dataset = FusionFeatureDataset(root_dir=test_dir, transform=transform, device=device)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)

In [None]:
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import matplotlib.pyplot as plt

def test_model(model_class, test_loader, model_path, device):
    """Evaluate a saved model on the test set and display the results.

    Builds ``model_class()``, loads the state dict stored at ``model_path``,
    runs inference over ``test_loader``, then shows a confusion-matrix
    heatmap and prints a classification report.

    Args:
        model_class: Model class, instantiated with no arguments; instances
            are called as ``model(images, stat_feats)``.
        test_loader: Iterable of ``(images, stat_feats, labels)`` batches.
        model_path: Path to a state dict saved with ``torch.save``.
        device: Device to run inference on.
    """
    model = model_class().to(device)
    # map_location lets a checkpoint saved on one device (e.g. a GPU) be
    # loaded on another (e.g. a CPU-only machine).
    model.load_state_dict(torch.load(model_path, map_location=device))
    # Evaluation mode: dropout is disabled and batch-norm layers use their
    # running statistics instead of per-batch statistics.
    model.eval()

    y_true = []
    y_pred = []

    with torch.no_grad():  # no gradients are needed for inference
        for images, stat_feats, labels in test_loader:
            images, stat_feats = images.to(device), stat_feats.to(device)
            outputs = model(images, stat_feats)
            predicted = outputs.argmax(dim=1)  # index of the largest logit

            y_true.extend(labels.cpu().tolist())
            y_pred.extend(predicted.cpu().tolist())

    # Pass the class ids explicitly so the matrix is always 2x2 and the names
    # line up even when a class is absent from the labels or the predictions.
    class_ids = [0, 1]
    labels_names = ['Normal (0)', 'Stego (1)']
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)

    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=labels_names, yticklabels=labels_names)
    plt.xlabel('Prédit')
    plt.ylabel('Vrai')
    plt.title('Matrice de confusion')
    plt.show()

    print("Rapport de classification :")
    print(classification_report(y_true, y_pred, labels=class_ids, target_names=labels_names))

# Evaluate each architecture with the checkpoint file its training run saved.
test_model(CustomCNNFusion, test_loader, 'best_model_CNNFUSION.pth', device)
test_model(ResStatFusion, test_loader, 'best_model_ResStatFusion.pth', device)