In [1]:
!pip install resampy pandas scikit-learn matplotlib librosa tensorflow


Collecting resampy
  Downloading resampy-0.4.3-py3-none-any.whl.metadata (3.0 kB)
Downloading resampy-0.4.3-py3-none-any.whl (3.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m14.0 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: resampy
Successfully installed resampy-0.4.3
Note: you may need to restart the kernel to use updated packages.


In [2]:
import os
import numpy as np
import librosa
import librosa.display
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, GlobalAveragePooling2D, Dense, Dropout, BatchNormalization, Activation
from tensorflow.keras.utils import Sequence
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras import regularizers
from tensorflow.keras.metrics import Precision, Recall, AUC, BinaryAccuracy
from sklearn.metrics import precision_recall_curve, roc_curve, auc as sklearn_auc, f1_score, confusion_matrix, classification_report
import random
import wave
import shutil
import pandas as pd
# One shared seed for every random number generator the notebook relies on:
# Python's `random`, NumPy's global generator and TensorFlow.
GLOBAL_SEED = 2025
for _seed_rng in (random.seed, np.random.seed, tf.random.set_seed):
    _seed_rng(GLOBAL_SEED)



2025-10-29 17:25:30.298888: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1761758730.536537      79 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1761758730.599817      79 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
# --- GPU configuration ---
# When more than one GPU is present, only the first one is left visible to
# TensorFlow. Memory growth is then enabled on the devices that remain visible.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        if len(gpus) > 1:
            tf.config.set_visible_devices([gpus[0]], 'GPU')
            print(f'Dang gioi han su dung 1 GPU: {gpus[0].name}')
        # list_physical_devices() keeps returning GPUs that were just hidden, so
        # the visible set has to be queried explicitly.
        visible_gpus = tf.config.get_visible_devices('GPU')
        for gpu in visible_gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(len(visible_gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Per the TensorFlow GPU guide, visibility and memory growth can only be
        # changed before the GPUs have been initialized; report and carry on.
        print(e)
else:
    print("No GPU detected, running on CPU.")


2 Physical GPUs, 2 Logical GPUs


I0000 00:00:1761758743.950402      79 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13942 MB memory:  -> device: 0, name: Tesla T4, pci bus id: 0000:00:04.0, compute capability: 7.5
I0000 00:00:1761758743.951120      79 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 13942 MB memory:  -> device: 1, name: Tesla T4, pci bus id: 0000:00:05.0, compute capability: 7.5


In [4]:
# --- GENERAL CONFIGURATION ---
DATA_PARENT_DIR = '/kaggle/input/data-lsb/data'  # Kaggle dataset path (change it when running locally)
PROCESSED_DATA_FILE = 'audio_data.npz'  # File name of the processed-data cache (stored under WORKING_DIR)

# --- Audio parameters ---
SAMPLE_RATE = 22050
N_FFT = 2048
HOP_LENGTH = 512
N_MELS = 128
FIXED_DURATION_SEC = 4
# Number of spectrogram columns every clip is padded/cropped to.
EXPECTED_SPECTROGRAM_COLS = int(np.ceil(FIXED_DURATION_SEC * SAMPLE_RATE / HOP_LENGTH))

# --- Model parameters ---
IMG_HEIGHT = N_MELS
IMG_WIDTH = EXPECTED_SPECTROGRAM_COLS
INPUT_SHAPE = (IMG_HEIGHT, IMG_WIDTH, 3)
EPOCHS = 180
BATCH_SIZE = 16  # Per-GPU base batch size (GLOBAL_BATCH_SIZE = BATCH_SIZE * number of replicas)

# --- Output locations ---
# On Kaggle (detected through the KAGGLE_KERNEL_RUN_TYPE variable) outputs go to
# /kaggle/working/, otherwise to the current directory.
WORKING_DIR = '.' if 'KAGGLE_KERNEL_RUN_TYPE' not in os.environ else '/kaggle/working/'
MODEL_SAVE_DIR = os.path.join(WORKING_DIR, 'model_output')
RESULT_IMAGE_DIR = os.path.join(WORKING_DIR, 'result_images')
PROCESSED_DATA_FULL_PATH = os.path.join(WORKING_DIR, PROCESSED_DATA_FILE)

# --- Output file names ---
run_id = "audio_steganalysis_v1"
model_checkpoint_filename = f'best_model_{run_id}.keras'
training_history_plot_filename = f'training_history_{run_id}.png'
pr_curve_filename = f'precision_recall_curve_val_{run_id}.png'

# Create the output directories up front.
for _output_dir in (MODEL_SAVE_DIR, RESULT_IMAGE_DIR):
    os.makedirs(_output_dir, exist_ok=True)

print(
    "Các đường dẫn và tham số đã được cấu hình.",
    f"Dữ liệu xử lý sẽ được lưu/tải tại: {PROCESSED_DATA_FULL_PATH}",
    f"Model checkpoints sẽ được lưu tại: {MODEL_SAVE_DIR}",
    f"Hình ảnh kết quả sẽ được lưu tại: {RESULT_IMAGE_DIR}",
    sep="\n",
)


Các đường dẫn và tham số đã được cấu hình.
Dữ liệu xử lý sẽ được lưu/tải tại: /kaggle/working/audio_data.npz
Model checkpoints sẽ được lưu tại: /kaggle/working/model_output
Hình ảnh kết quả sẽ được lưu tại: /kaggle/working/result_images


In [5]:
# --- HÀM AUGMENTATION ---
def time_masking(spectrogram, T=20, num_masks=2, replace_with_zero=True, axis_to_mask=1):
    """Return a copy of ``spectrogram`` with random bands masked along one axis.

    Args:
        spectrogram: NumPy array to augment; it is copied, never modified.
        T: Upper bound of the mask width. Each width is drawn uniformly from
            ``[0, T]`` (both ends included).
        num_masks: Number of bands to mask.
        replace_with_zero: If true the band is filled with 0, otherwise with
            the mean of the (already partially masked) array.
        axis_to_mask: Axis along which the band is measured and written.
            Defaults to 1, the time axis used by the callers in this notebook.

    Returns:
        The masked copy. A width that is not smaller than the axis length
        masks the axis from position 0 onwards.
    """
    cloned = np.copy(spectrogram)
    len_frames = cloned.shape[axis_to_mask]
    for _ in range(num_masks):
        # The width is drawn before the empty-axis check so that the sequence
        # of random draws does not depend on the input shape.
        t = random.randint(0, T)
        if len_frames == 0:
            continue
        t0 = random.randint(0, len_frames - t) if len_frames > t else 0
        # Build the index so the band lands on the same axis whose length was
        # measured above (previously the write was hard-wired to axis 1).
        band = [slice(None)] * cloned.ndim
        band[axis_to_mask] = slice(t0, t0 + t)
        cloned[tuple(band)] = 0 if replace_with_zero else np.mean(cloned)
    return cloned


def frequency_masking(spectrogram, F=15, num_masks=2, replace_with_zero=True, axis_to_mask=0):
    """Return a copy of ``spectrogram`` with random bands masked along one axis.

    Args:
        spectrogram: NumPy array to augment; it is copied, never modified.
        F: Upper bound of the mask height. Each height is drawn uniformly from
            ``[0, F]`` (both ends included).
        num_masks: Number of bands to mask.
        replace_with_zero: If true the band is filled with 0, otherwise with
            the mean of the (already partially masked) array.
        axis_to_mask: Axis along which the band is measured and written.
            Defaults to 0, the mel/frequency axis used by the callers in this
            notebook.

    Returns:
        The masked copy. A height that is not smaller than the axis length
        masks the axis from position 0 onwards.
    """
    cloned = np.copy(spectrogram)
    num_mels = cloned.shape[axis_to_mask]
    for _ in range(num_masks):
        # The height is drawn before the empty-axis check so that the sequence
        # of random draws does not depend on the input shape.
        f = random.randint(0, F)
        if num_mels == 0:
            continue
        f0 = random.randint(0, num_mels - f) if num_mels > f else 0
        # Build the index so the band lands on the same axis whose length was
        # measured above (previously the write was hard-wired to axis 0).
        band = [slice(None)] * cloned.ndim
        band[axis_to_mask] = slice(f0, f0 + f)
        cloned[tuple(band)] = 0 if replace_with_zero else np.mean(cloned)
    return cloned


def additive_gaussian_noise(spectrogram, std=0.01):
    """Return ``spectrogram`` plus zero-mean Gaussian noise of the same shape.

    The noise is drawn from NumPy's global generator with standard deviation
    ``std``; the input array is not modified.
    """
    return spectrogram + np.random.normal(loc=0.0, scale=std, size=spectrogram.shape)


def random_gain(spectrogram, min_gain_db=-2.0, max_gain_db=2.0):
    """Scale ``spectrogram`` by a random gain.

    A gain in decibels is drawn uniformly between ``min_gain_db`` and
    ``max_gain_db`` from NumPy's global generator, converted to a linear
    factor ``10 ** (dB / 20)`` and multiplied into the input.
    """
    drawn_db = np.random.uniform(min_gain_db, max_gain_db)
    return spectrogram * (10.0 ** (drawn_db / 20.0))


print("Đã định nghĩa các hàm augmentation (SpecAugment + noise/gain).")


Đã định nghĩa các hàm augmentation (SpecAugment + noise/gain).


In [6]:

# --- Keras Sequence (SpecAugment + Noise/Gain) ---
class SpecAugmentSequence(Sequence):
    """Keras ``Sequence`` serving spectrogram batches with on-the-fly augmentation.

    Each sample is augmented with probability ``augment_prob``. An augmented
    sample goes through the enabled steps in this fixed order: random gain,
    additive Gaussian noise, time masking (axis 1), frequency masking (axis 0).
    Every returned batch is cast to float32, has NaN/Inf replaced by 0 and is
    clipped to ``[-6, 6]``.

    Args:
        x_set: Array of samples, indexed along axis 0.
        y_set: Array of labels aligned with ``x_set``.
        batch_size: Number of samples per batch (the last batch may be smaller).
        apply_time_mask / time_mask_param_T / num_time_masks: Enable flag,
            maximum width and count forwarded to ``time_masking``.
        apply_freq_mask / freq_mask_param_F / num_freq_masks: Enable flag,
            maximum height and count forwarded to ``frequency_masking``.
        apply_noise / noise_std: Enable flag and standard deviation forwarded
            to ``additive_gaussian_noise``.
        apply_gain / gain_db_range: Enable flag and ``(min_db, max_db)`` range
            forwarded to ``random_gain``.
        augment_prob: Per-sample probability of applying the augmentations.
        shuffle: Whether the sample order is reshuffled at construction time
            and at the end of every epoch.
    """

    def __init__(
        self,
        x_set,
        y_set,
        batch_size,
        apply_time_mask=True,
        time_mask_param_T=20,
        num_time_masks=2,
        apply_freq_mask=True,
        freq_mask_param_F=15,
        num_freq_masks=2,
        apply_noise=True,
        noise_std=0.01,
        apply_gain=True,
        gain_db_range=(-2.0, 2.0),
        augment_prob=0.4,
        shuffle=True
    ):
        super().__init__()
        self.x = x_set
        self.y = y_set
        self.batch_size = batch_size
        self.apply_time_mask = apply_time_mask
        self.time_mask_param_T = time_mask_param_T
        self.num_time_masks = num_time_masks
        self.apply_freq_mask = apply_freq_mask
        self.freq_mask_param_F = freq_mask_param_F
        self.num_freq_masks = num_freq_masks
        self.apply_noise = apply_noise
        self.noise_std = noise_std
        self.apply_gain = apply_gain
        self.gain_db_range = gain_db_range
        self.augment_prob = augment_prob
        self.shuffle = shuffle
        self.indices = np.arange(len(self.x))
        # Produces the initial (optionally shuffled) sample order.
        self.on_epoch_end()

    def __len__(self):
        """Number of batches per epoch, counting a final partial batch."""
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def _augment(self, spec):
        """Apply every enabled augmentation to one sample and return the result."""
        augmented_spec = spec
        if self.apply_gain:
            augmented_spec = random_gain(
                augmented_spec,
                min_gain_db=self.gain_db_range[0],
                max_gain_db=self.gain_db_range[1]
            )
        if self.apply_noise:
            augmented_spec = additive_gaussian_noise(augmented_spec, std=self.noise_std)
        if self.apply_time_mask:
            augmented_spec = time_masking(
                augmented_spec,
                T=self.time_mask_param_T,
                num_masks=self.num_time_masks,
                axis_to_mask=1
            )
        if self.apply_freq_mask:
            augmented_spec = frequency_masking(
                augmented_spec,
                F=self.freq_mask_param_F,
                num_masks=self.num_freq_masks,
                axis_to_mask=0
            )
        return augmented_spec

    def __getitem__(self, idx):
        """Return batch ``idx`` as ``(x, y)`` float32 arrays."""
        batch_x_indices = self.indices[idx * self.batch_size:(idx + 1) * self.batch_size]
        # Indexing with an index array yields fresh copies, so the stored
        # dataset is never modified by the augmentations below.
        batch_x = self.x[batch_x_indices]
        batch_y = self.y[batch_x_indices]

        augmented_batch_x = []
        for spec in batch_x:
            if random.random() < self.augment_prob:
                spec = self._augment(spec)
            augmented_batch_x.append(spec)

        batch_x_to_return = np.asarray(augmented_batch_x, dtype=np.float32)
        batch_y = batch_y.astype(np.float32, copy=False)

        # Report non-finite values BEFORE they are replaced; checking after
        # nan_to_num (as was done previously) can never trigger.
        if not np.isfinite(batch_x_to_return).all():
            print(f"*** LOI NGHIEM TRONG (Sequence): NaN/Inf tim thay TRUOC KHI RETURN batch_x index {idx} !!!")
        if not np.isfinite(batch_y).all():
            print(f"*** LOI NGHIEM TRONG (Sequence): NaN/Inf tim thay trong batch_y index {idx} !!!")

        batch_x_to_return = np.nan_to_num(batch_x_to_return, nan=0.0, posinf=0.0, neginf=0.0)
        np.clip(batch_x_to_return, -6.0, 6.0, out=batch_x_to_return)
        batch_y = np.nan_to_num(batch_y, nan=0.0, posinf=0.0, neginf=0.0)

        return batch_x_to_return, batch_y

    def on_epoch_end(self):
        """Reshuffle the sample order in place when shuffling is enabled."""
        if self.shuffle:
            np.random.shuffle(self.indices)

print('Da dinh nghia class SpecAugmentSequence (aug mo rong).')


Đã định nghĩa class SpecAugmentSequence (aug mở rộng).


In [7]:
# --- HAM XAY DUNG MO HINH CNN (Spectrogram 3 kenh) ---
def build_cnn_model_with_bn(input_shape):
    """Build the multi-stage CNN with BatchNorm and global average pooling.

    The network has four convolutional stages (32, 64, 128 and 192 filters).
    Each stage is two ``Conv2D -> BatchNormalization -> ReLU`` units followed
    by 2x2 max pooling and dropout. A ``GlobalAveragePooling2D`` layer, a
    256-unit dense block and a single sigmoid output unit complete the model.
    All convolution kernels and the 256-unit dense kernel use L2
    regularisation with rate 5e-4.

    Args:
        input_shape: Shape of one input sample, passed to the first Conv2D.

    Returns:
        The uncompiled ``Sequential`` model.
    """
    weight_decay = 5e-4
    # (filters, dropout rate) for each convolutional stage, in order.
    stage_plan = ((32, 0.2), (64, 0.25), (128, 0.3), (192, 0.35))

    net = Sequential(name="audio_steganalysis_cnn_bn_multichannel")

    def add_conv_unit(filters, **conv_kwargs):
        # One Conv2D -> BatchNorm -> ReLU unit with its own regulariser object.
        net.add(Conv2D(
            filters,
            (3, 3),
            padding='same',
            kernel_regularizer=regularizers.l2(weight_decay),
            **conv_kwargs
        ))
        net.add(BatchNormalization())
        net.add(Activation('relu'))

    for stage_index, (filters, drop_rate) in enumerate(stage_plan):
        # Only the very first layer of the network declares the input shape.
        first_unit_kwargs = {'input_shape': input_shape} if stage_index == 0 else {}
        add_conv_unit(filters, **first_unit_kwargs)
        add_conv_unit(filters)
        net.add(MaxPooling2D((2, 2)))
        net.add(Dropout(drop_rate))

    # Classification head.
    net.add(GlobalAveragePooling2D())
    net.add(Dense(256, kernel_regularizer=regularizers.l2(weight_decay)))
    net.add(BatchNormalization())
    net.add(Activation('relu'))
    net.add(Dropout(0.4))
    net.add(Dense(1, activation='sigmoid'))
    return net

print("Đã định nghĩa hàm build_cnn_model_with_bn (multi-channel).")


Đã định nghĩa hàm build_cnn_model_with_bn (multi-channel).


In [8]:
# --- Hàm tải và tiền xử lý dữ liệu ---
def load_and_preprocess_audio(
    file_path,
    sr=SAMPLE_RATE,
    duration=FIXED_DURATION_SEC,
    n_mels=N_MELS,
    n_fft=N_FFT,
    hop_length=HOP_LENGTH
):
    """Turn one audio file into a 3-channel log-mel feature array.

    Steps: load at most ``duration`` seconds at sample rate ``sr``, pad the
    signal to ``duration * sr`` samples if it is shorter, compute a mel power
    spectrogram, convert it to dB relative to its own maximum, force the
    number of columns to ``EXPECTED_SPECTROGRAM_COLS`` and stack the result
    with its first- and second-order deltas along a new last axis.

    Note that the target column count comes from the module-level
    ``EXPECTED_SPECTROGRAM_COLS``; it is not recomputed from the ``sr``,
    ``duration`` and ``hop_length`` arguments.

    Args:
        file_path: Path of the audio file to load.
        sr: Target sample rate.
        duration: Clip length in seconds.
        n_mels: Number of mel bands.
        n_fft: FFT window size.
        hop_length: Hop between successive frames, in samples.

    Returns:
        The stacked array ``[log-mel, delta, delta-delta]`` with NaN/Inf
        replaced by 0, or ``None`` if anything fails (the error is printed).
    """
    try:
        samples, _ = librosa.load(
            file_path,
            sr=sr,
            duration=duration,
            res_type='kaiser_fast'
        )
        target_length = int(duration * sr)
        if len(samples) < target_length:
            samples = librosa.util.fix_length(samples, size=target_length)

        mel_spectrogram = librosa.feature.melspectrogram(
            y=samples,
            sr=sr,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels
        )
        # Floor the power values so the dB conversion never sees zeros.
        mel_spectrogram = np.maximum(1e-10, mel_spectrogram)

        log_mel_spectrogram = librosa.power_to_db(mel_spectrogram, ref=np.max)

        if log_mel_spectrogram.shape[1] < EXPECTED_SPECTROGRAM_COLS:
            pad_width = EXPECTED_SPECTROGRAM_COLS - log_mel_spectrogram.shape[1]
            # With ref=np.max the loudest bin sits at 0 dB, so padding with 0
            # would append full-scale frames. Pad with the quietest value of
            # this spectrogram instead.
            log_mel_spectrogram = np.pad(
                log_mel_spectrogram,
                ((0, 0), (0, pad_width)),
                mode='constant',
                constant_values=log_mel_spectrogram.min()
            )
        elif log_mel_spectrogram.shape[1] > EXPECTED_SPECTROGRAM_COLS:
            log_mel_spectrogram = log_mel_spectrogram[:, :EXPECTED_SPECTROGRAM_COLS]

        delta_1 = librosa.feature.delta(log_mel_spectrogram, order=1, mode='nearest')
        delta_2 = librosa.feature.delta(log_mel_spectrogram, order=2, mode='nearest')

        stacked = np.stack([log_mel_spectrogram, delta_1, delta_2], axis=-1)
        stacked = np.nan_to_num(stacked, nan=0.0, posinf=0.0, neginf=0.0)
        return stacked
    except Exception as e:
        # Deliberate best effort: a broken file is reported and skipped by the
        # caller instead of aborting the whole dataset build.
        print(f"Loi xu ly file {file_path}: {e}")
        return None


def create_dataset_from_paths(data_paths_labels):
    """Build feature and label arrays from ``(file_path, label)`` pairs.

    Every file is passed through ``load_and_preprocess_audio``; files for
    which it returns ``None`` are dropped together with their label. Progress
    is printed every 500 files and after the last one.

    Returns:
        ``(features, labels)`` as NumPy arrays. Both are empty arrays when the
        input is empty or no file could be processed. A 3-D feature array gets
        a trailing channel axis.

    Raises:
        ValueError: If the stacked features are neither 3-D nor 4-D.
    """
    total_files = len(data_paths_labels)
    if total_files == 0:
        return np.array([]), np.array([])

    print(f"Bat dau tao dataset tu {total_files} file...")
    kept_specs = []
    kept_labels = []
    for done, (file_path, label) in enumerate(data_paths_labels, start=1):
        spectrogram = load_and_preprocess_audio(file_path)
        if spectrogram is not None:
            kept_specs.append(spectrogram)
            kept_labels.append(label)
        if done % 500 == 0 or done == total_files:
            print(f"  Da xu ly {done}/{total_files} file.")

    if not kept_specs:
        print("Khong co features nao duoc tao.")
        return np.array([]), np.array([])

    features = np.array(kept_specs)
    labels = np.array(kept_labels)

    if features.ndim not in (3, 4):
        raise ValueError(f"Hinh dang feature khong hop le: {features.shape}")
    if features.ndim == 3:
        # Single-channel features: append the channel axis.
        features = features[..., np.newaxis]

    print(f"Hoan tat tao dataset. Features: {features.shape}, Labels: {labels.shape}")
    return features, labels


def reconstruct_file_lists(base_dir):
    """Collect the labelled ``.wav`` files of the train/val/test splits.

    For each split, files are read from ``<base_dir>/clean/<split>`` (label 0)
    and ``<base_dir>/stego/<split>`` (label 1) in sorted name order; a missing
    directory only produces a warning. Each split is then shuffled with its
    own generator seeded with ``GLOBAL_SEED`` (train), ``GLOBAL_SEED + 1``
    (val) or ``GLOBAL_SEED + 2`` (test).

    Returns:
        Three lists of ``(file_path, label)`` tuples: train, val, test.
    """
    def list_wavs(folder, label, missing_message):
        # Sorted (path, label) pairs of the .wav files in `folder`, or an empty
        # list (after printing the warning) when the folder does not exist.
        if not os.path.exists(folder):
            print(missing_message)
            return []
        return [
            (os.path.join(folder, file_name), label)
            for file_name in sorted(os.listdir(folder))
            if file_name.lower().endswith('.wav')
        ]

    print('--- Tai tao danh sach file ---')
    per_split = {}
    for seed_shift, split in enumerate(('train', 'val', 'test')):
        clean_dir = os.path.join(base_dir, 'clean', split)
        stego_dir = os.path.join(base_dir, 'stego', split)
        entries = list_wavs(clean_dir, 0, f"CANH BAO: Khong tim thay thu muc sach: {clean_dir}")
        entries = entries + list_wavs(stego_dir, 1, f"CANH BAO: Khong tim thay thu muc stego: {stego_dir}")
        # Independent, reproducible shuffle per split.
        random.Random(GLOBAL_SEED + seed_shift).shuffle(entries)
        per_split[split] = entries

    print(f"Train: {len(per_split['train'])}, Val: {len(per_split['val'])}, Test: {len(per_split['test'])} file.")
    return per_split['train'], per_split['val'], per_split['test']


def prepare_and_save_data(data_parent_dir, output_filepath):
    """Build the train/val/test arrays from audio files and cache them as ``.npz``.

    Args:
        data_parent_dir: Dataset root handed to ``reconstruct_file_lists``.
        output_filepath: Destination of the compressed NumPy archive.

    Returns:
        ``X_train, y_train, X_val, y_val, X_test, y_test``. When there are no
        training files, or no training features could be created, a list of
        six empty arrays is returned and nothing is written. Validation and
        test arrays are stored in the archive only when they are non-empty.
    """
    print("\n--- Buoc 1 & 2: Tai, Tien xu ly Du lieu va Luu tru ---")
    train_items, val_items, test_items = reconstruct_file_lists(data_parent_dir)
    if not train_items:
        print("LOI: Khong co du lieu huan luyen.")
        return [np.array([])] * 6

    X_train, y_train = create_dataset_from_paths(train_items)
    X_val, y_val = create_dataset_from_paths(val_items)
    X_test, y_test = create_dataset_from_paths(test_items)

    if X_train.size == 0:
        print("LOI: Khong tao duoc features cho tap huan luyen.")
        return [np.array([])] * 6

    print(f"\n--- Dang luu du lieu da xu ly vao '{output_filepath}' ---")
    arrays_to_store = {'X_train': X_train, 'y_train': y_train}
    optional_splits = (
        ('X_val', X_val, 'y_val', y_val),
        ('X_test', X_test, 'y_test', y_test),
    )
    for x_key, split_x, y_key, split_y in optional_splits:
        if split_x.size > 0:
            arrays_to_store[x_key] = split_x
            arrays_to_store[y_key] = split_y
    np.savez_compressed(output_filepath, **arrays_to_store)
    print('--- Luu du lieu thanh cong ---')
    return X_train, y_train, X_val, y_val, X_test, y_test


print("Đã định nghĩa các hàm xử lý dữ liệu (multi-channel spectrogram).")



Đã định nghĩa các hàm xử lý dữ liệu (multi-channel spectrogram).


In [9]:
import os
import numpy as np
import librosa
import librosa.display
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, Activation
from tensorflow.keras.utils import Sequence
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras import regularizers
from tensorflow.keras.metrics import Precision, Recall, AUC, BinaryAccuracy
from sklearn.metrics import precision_recall_curve, roc_curve, auc as sklearn_auc, f1_score, confusion_matrix, classification_report
import random
import wave
import shutil
import pandas as pd




In [10]:
# --- DATA PREPARATION STEP (LOAD FROM CACHE OR BUILD FROM SCRATCH) ---

# Start from empty placeholders so every name exists even if loading fails.
X_train, y_train, X_val, y_val, X_test, y_test = (np.array([]),) * 6

# --- Load or create the data ---
expected_channels = INPUT_SHAPE[-1]
if os.path.exists(PROCESSED_DATA_FULL_PATH):
    print(f"--- Dang tai du lieu da xu ly tu '{PROCESSED_DATA_FULL_PATH}' ---")
    try:
        # Read inside a context manager so the archive's file handle is closed
        # before the cache file may be deleted in the except branch below.
        with np.load(PROCESSED_DATA_FULL_PATH) as data:
            X_train = data['X_train']
            y_train = data['y_train']
            # Validation/test arrays are only stored when they were non-empty.
            X_val = data['X_val'] if 'X_val' in data.files else np.array([])
            y_val = data['y_val'] if 'y_val' in data.files else np.array([])
            X_test = data['X_test'] if 'X_test' in data.files else np.array([])
            y_test = data['y_test'] if 'y_test' in data.files else np.array([])
        print('--- Tai du lieu thanh cong ---')
        if X_train.size == 0:
            raise ValueError('Du lieu huan luyen tai len bi rong.')
        # A cache with a different channel count is stale: force a rebuild.
        if X_train.ndim == 4 and X_train.shape[-1] != expected_channels:
            raise ValueError(
                f'Du lieu cu co {X_train.shape[-1]} kenh, nhung INPUT_SHAPE yeu cau {expected_channels}.'
            )
    except Exception as e:
        # Any problem with the cache (unreadable, empty, wrong shape) leads to
        # deleting it and regenerating the dataset from the audio files.
        print(f"Loi khi tai file: {e}. Se thu tao lai.")
        if os.path.exists(PROCESSED_DATA_FULL_PATH):
            os.remove(PROCESSED_DATA_FULL_PATH)
        X_train, y_train, X_val, y_val, X_test, y_test = prepare_and_save_data(
            DATA_PARENT_DIR,
            PROCESSED_DATA_FULL_PATH
        )
else:
    print(f"--- Khong tim thay du lieu da xu ly. Dang tao moi... ---")
    X_train, y_train, X_val, y_val, X_test, y_test = prepare_and_save_data(
        DATA_PARENT_DIR,
        PROCESSED_DATA_FULL_PATH
    )

# --- Validate, post-process and report on the data ---
if X_train is None or X_train.size == 0:
    print('LOI NGHIEM TRONG: Khong co du lieu huan luyen.')
    # Execution is not stopped here; later cells have to check for this case.
else:
    print(
        f"\nKich thuoc du lieu TRUOC xu ly: Train={X_train.shape}, "
        f"Val={X_val.shape if X_val.size > 0 else 'N/A'}, "
        f"Test={X_test.shape if X_test.size > 0 else 'N/A'}"
    )

    # Class distribution per split.
    train_labels, train_counts = np.unique(y_train.astype(int), return_counts=True)
    print(f'Phan bo y_train: {dict(zip(train_labels.tolist(), train_counts.tolist()))}')
    if X_val.size > 0:
        val_labels, val_counts = np.unique(y_val.astype(int), return_counts=True)
        print(f'Phan bo y_val: {dict(zip(val_labels.tolist(), val_counts.tolist()))}')
    if X_test.size > 0:
        test_labels, test_counts = np.unique(y_test.astype(int), return_counts=True)
        print(f'Phan bo y_test: {dict(zip(test_labels.tolist(), test_counts.tolist()))}')


    # Initial NaN/Inf check.
    if np.isnan(X_train).any() or np.isinf(X_train).any():
        print('!!! CANH BAO: NaN/Inf trong X_train TRUOC chuan hoa !!!')

    # Standardisation: per-channel statistics are computed on the training set
    # only and then applied to every split.
    print('\n--- CHUAN HOA DU LIEU ---')
    train_mean = np.mean(X_train, axis=(0, 1, 2), keepdims=True)
    train_std = np.std(X_train, axis=(0, 1, 2), keepdims=True)
    # Guard against division by zero for constant channels.
    train_std = np.where(train_std == 0, 1.0, train_std)
    train_mean_flat = train_mean.reshape(-1)
    train_std_flat = train_std.reshape(-1)
    mean_str = ', '.join(f"{v:.4f}" for v in train_mean_flat)
    std_str = ', '.join(f"{v:.4f}" for v in train_std_flat)
    print(f'Train mean (per channel): {mean_str}')
    print(f'Train std (per channel): {std_str}')
    X_train = (X_train - train_mean) / train_std
    if X_val.size > 0:
        X_val = (X_val - train_mean) / train_std
    if X_test.size > 0:
        X_test = (X_test - train_mean) / train_std
    print('Da ap dung chuan hoa.')

    # Min/max check after standardisation.
    print('\n--- KIEM TRA MIN/MAX (SAU CHUAN HOA) ---')
    print(f'X_train Min/Max: {np.min(X_train):.4f} / {np.max(X_train):.4f}')
    if X_val.size > 0:
        print(f'X_val Min/Max: {np.min(X_val):.4f} / {np.max(X_val):.4f}')
    if np.isnan(X_train).any() or np.isinf(X_train).any():
        print('!!! CANH BAO: NaN/Inf trong X_train SAU chuan hoa !!!')

    # Cast everything to float32.
    print('\n--- EP KIEU DU LIEU SANG float32 ---')
    X_train = X_train.astype(np.float32)
    y_train = y_train.astype(np.float32)
    if X_val.size > 0:
        X_val, y_val = X_val.astype(np.float32), y_val.astype(np.float32)
    if X_test.size > 0:
        X_test, y_test = X_test.astype(np.float32), y_test.astype(np.float32)
    print('Da ep kieu sang float32.')


--- Khong tim thay du lieu da xu ly. Dang tao moi... ---

--- Buoc 1 & 2: Tai, Tien xu ly Du lieu va Luu tru ---
--- Tai tao danh sach file ---
Train: 8130, Val: 1762, Test: 1710 file.
Bat dau tao dataset tu 8130 file...
  Da xu ly 500/8130 file.
  Da xu ly 1000/8130 file.
  Da xu ly 1500/8130 file.
  Da xu ly 2000/8130 file.
  Da xu ly 2500/8130 file.
  Da xu ly 3000/8130 file.
  Da xu ly 3500/8130 file.
  Da xu ly 4000/8130 file.
  Da xu ly 4500/8130 file.
  Da xu ly 5000/8130 file.
  Da xu ly 5500/8130 file.
  Da xu ly 6000/8130 file.
  Da xu ly 6500/8130 file.
  Da xu ly 7000/8130 file.
  Da xu ly 7500/8130 file.
  Da xu ly 8000/8130 file.
  Da xu ly 8130/8130 file.
Hoan tat tao dataset. Features: (8130, 128, 173, 3), Labels: (8130,)
Bat dau tao dataset tu 1762 file...
  Da xu ly 500/1762 file.
  Da xu ly 1000/1762 file.
  Da xu ly 1500/1762 file.
  Da xu ly 1762/1762 file.
Hoan tat tao dataset. Features: (1762, 128, 173, 3), Labels: (1762,)
Bat dau tao dataset tu 1710 file...
  Da

In [None]:
# --- BUOC HUAN LUYEN MODEL ---
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, TerminateOnNaN, ReduceLROnPlateau
from tensorflow.keras.metrics import AUC, BinaryAccuracy, Precision, Recall
from tensorflow.keras.losses import BinaryCrossentropy
import numpy as np
import os

# Turn XLA JIT compilation off explicitly.
tf.config.optimizer.set_jit(False)
print('--- Da tat toi uu hoa XLA JIT ---')
# Enable TensorFlow's numerics checking so NaN/Inf values are reported.
# NOTE(review): this is a debugging aid and presumably slows training down;
# confirm it should stay enabled for full-length runs.
tf.debugging.enable_check_numerics()
print('*** Da bat tf.debugging.enable_check_numerics() ***')

def _clean_array_in_place(name, arr):
    """Replace NaN/Inf in ``arr`` with 0 in place and print a short summary.

    Args:
        name: Label used in the printed messages.
        arr: Array to clean. Anything that is not a NumPy array (including
            ``None``) is ignored silently.

    A warning with the NaN/Inf counts is printed when any are found. The
    summary line shows shape, dtype and value range; for a zero-size array
    the range is reported as ``N/A`` because ``min``/``max`` are undefined.
    """
    if arr is None or not isinstance(arr, np.ndarray):
        return
    nan_count = int(np.isnan(arr).sum())
    inf_count = int(np.isinf(arr).sum())
    if nan_count or inf_count:
        print(f"CANH BAO: {name} co {nan_count} NaN va {inf_count} Inf. Thay the bang 0.")
    np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0, copy=False)
    if arr.size == 0:
        # Empty placeholders (e.g. a missing validation split) would make
        # arr.min()/arr.max() raise ValueError.
        print(f"  {name}: shape={arr.shape}, dtype={arr.dtype}, min=N/A, max=N/A")
        return
    print(f"  {name}: shape={arr.shape}, dtype={arr.dtype}, min={arr.min():.4f}, max={arr.max():.4f}")

EPOCHS = 180
print(f"--- So epochs: {EPOCHS} ---")

# Report the GPUs physically present on the machine.
available_gpus = tf.config.list_physical_devices('GPU')
if available_gpus:
    gpu_names = [gpu.name for gpu in available_gpus]
    print(f"Phat hien {len(available_gpus)} GPU vat ly: {gpu_names}")
else:
    print('Khong tim thay GPU vat ly. Huon luyen tren CPU hoac TPU neu co.')

# Only GPUs that are still visible to TensorFlow can host a replica, so the
# strategy is chosen from the visible set rather than from the physical count
# (an earlier cell may have hidden all but one GPU).
usable_gpus = tf.config.get_visible_devices('GPU')
if len(usable_gpus) > 1:
    print('Khoi tao tf.distribute.MirroredStrategy de tan dung nhieu GPU.')
    strategy = tf.distribute.MirroredStrategy()
else:
    strategy = tf.distribute.get_strategy()

print(f"Dang su dung chien luoc: {strategy.__class__.__name__}")
num_replicas = strategy.num_replicas_in_sync
print(f"So luong thiet bi (GPU logic): {num_replicas}")

# The generator yields global batches: per-replica size times replica count.
GLOBAL_BATCH_SIZE = BATCH_SIZE * max(1, num_replicas)
if num_replicas > 1:
    print(f"Batch size moi replica: {BATCH_SIZE}, batch size toan cuc: {GLOBAL_BATCH_SIZE}")
else:
    print(f"Batch size su dung: {GLOBAL_BATCH_SIZE}")

print('Khoi tao Data Generator (SpecAugmentSequence)')
# Augmentation settings used for the training generator.
TIME_MASK_PARAM_T_GEN = 12
NUM_TIME_MASKS_GEN = 1
FREQ_MASK_PARAM_F_GEN = 12
NUM_FREQ_MASKS_GEN = 1
AUGMENT_PROB_GEN = 0.12
NOISE_STD_GEN = 0.002
GAIN_RANGE_DB = (-1.0, 1.0)

train_data_generator = None
val_data_for_fit = None

# Replace NaN/Inf in every dataset array that exists and print its summary.
for name in ('X_train', 'X_val', 'X_test', 'y_train', 'y_val', 'y_test'):
    if name in globals():
        _clean_array_in_place(name, globals()[name])
    else:
        print(f"  {name}: khong ton tai trong bo nho.")

if 'X_train' in locals() and isinstance(X_train, np.ndarray) and X_train.size > 0:
    train_data_generator = SpecAugmentSequence(
        X_train,
        y_train,
        GLOBAL_BATCH_SIZE,
        apply_time_mask=True,
        time_mask_param_T=TIME_MASK_PARAM_T_GEN,
        num_time_masks=NUM_TIME_MASKS_GEN,
        apply_freq_mask=True,
        freq_mask_param_F=FREQ_MASK_PARAM_F_GEN,
        num_freq_masks=NUM_FREQ_MASKS_GEN,
        apply_noise=True,
        noise_std=NOISE_STD_GEN,
        apply_gain=True,
        gain_db_range=GAIN_RANGE_DB,
        augment_prob=AUGMENT_PROB_GEN,
        shuffle=True
    )
    print(f"Da tao Data Generator voi Batch Size = {GLOBAL_BATCH_SIZE}, augment_prob = {AUGMENT_PROB_GEN}")
    sample_x, sample_y = train_data_generator[0]
    # Random batch sanity check. The batch is inspected exactly as the
    # generator produced it: sanitising it first (as was done before) would
    # make the check impossible to fail.
    if len(train_data_generator) > 1:
        rand_idx = np.random.randint(0, len(train_data_generator))
        rand_x, _ = train_data_generator[rand_idx]
        if not np.isfinite(rand_x).all():
            raise ValueError('Batch nga u nhien co NaN/Inf sau augment (rand batch).')
    sample_x = sample_x.astype(np.float32, copy=False)
    # Share of finite values in the first batch, again measured on the raw batch.
    finite_ratio = np.mean(np.isfinite(sample_x))
    print(f'  Kiem tra batch dau tien: ti le gia tri huu han = {finite_ratio:.4f}, min = {np.min(sample_x):.4f}, max = {np.max(sample_x):.4f}')
    channel_means = sample_x.mean(axis=(0, 1, 2))
    channel_stds = sample_x.std(axis=(0, 1, 2))
    channel_report = ', '.join([f'ch{i}: mean={channel_means[i]:.4f}, std={channel_stds[i]:.4f}' for i in range(len(channel_means))])
    print(f'  {channel_report}')
    if finite_ratio < 1.0:
        raise ValueError('Batch dau tien co NaN/Inf sau augment')
    # Validation data is passed to fit() as plain arrays (no augmentation).
    if 'X_val' in locals() and isinstance(X_val, np.ndarray) and X_val.size > 0:
        val_data_for_fit = (X_val, y_val)
        print('Su dung (X_val, y_val) cho validation_data.')
else:
    print('LOI: Khong co du lieu X_train de tao Data Generator.')

# Print the class balance of every split that is available. Label 0 is the
# clean class, so the mean of (y == 0) is the share of clean samples.
if 'y_train' in locals() and isinstance(y_train, np.ndarray) and y_train.size > 0:
    train_clean_ratio = float(np.mean(y_train == 0))
    print(f"Ti le lop (train) - sach: {train_clean_ratio:.3f}, stego: {1 - train_clean_ratio:.3f}")
if 'y_val' in locals() and isinstance(y_val, np.ndarray) and y_val.size > 0:
    val_clean_ratio = float(np.mean(y_val == 0))
    print(f"Ti le lop (val)   - sach: {val_clean_ratio:.3f}, stego: {1 - val_clean_ratio:.3f}")
if 'y_test' in locals() and isinstance(y_test, np.ndarray) and y_test.size > 0:
    test_clean_ratio = float(np.mean(y_test == 0))
    print(f"Ti le lop (test)  - sach: {test_clean_ratio:.3f}, stego: {1 - test_clean_ratio:.3f}")

# Class weighting is disabled; None means fit() receives no class_weight.
class_weights = None
print(f"Su dung Class Weights: {class_weights}")
print('  -> Khong dung class weight (None) de tranh bias.')

# Model, optimizer and metrics are all created inside the strategy scope.
with strategy.scope():
    print('Xay dung Mo hinh CNN (relu, BN, multi-channel)')
    model = build_cnn_model_with_bn(INPUT_SHAPE)

    print('Cau hinh Optimizer (Adam) va Compile Model')
    initial_learning_rate = 2e-4
    # The optimizer is the same with or without training data; only the
    # message (and the steps_per_epoch bookkeeping) differs.
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=initial_learning_rate,
        clipnorm=1.0
    )
    if train_data_generator is None:
        print(f"CANH BAO: Khong co du lieu huan luyen, su dung Adam LR={initial_learning_rate}.")
    else:
        steps_per_epoch = len(train_data_generator)
        print(f"Su dung Adam voi learning_rate co dinh = {initial_learning_rate} (steps_per_epoch={steps_per_epoch}).")

    loss_fn = BinaryCrossentropy()
    print('Ap dung BinaryCrossentropy khong label smoothing.')

    # Threshold metrics use 0.5; two AUC variants (ROC and PR) are tracked.
    tracked_metrics = [
        BinaryAccuracy(name='accuracy', threshold=0.5),
        Precision(name='precision', thresholds=0.5),
        Recall(name='recall', thresholds=0.5),
        AUC(name='auc'),
        AUC(name='pr_auc', curve='PR'),
    ]
    model.compile(optimizer=optimizer, loss=loss_fn, metrics=tracked_metrics)
    print('Da compile model voi cac metric: accuracy, precision, recall, auc, pr_auc.')

model_checkpoint_path = os.path.join(MODEL_SAVE_DIR, model_checkpoint_filename)

# Validation metrics only exist when validation data is passed to fit().
# Without it, checkpointing on 'val_auc' would never save anything, so the
# training AUC (metric name 'auc') is monitored instead.
checkpoint_monitor = 'val_auc' if val_data_for_fit else 'auc'

# Callback order: NaN guard, early stopping, checkpointing, LR schedule.
callbacks_list = [TerminateOnNaN()]
if val_data_for_fit:
    callbacks_list.append(
        EarlyStopping(
            monitor='val_auc',
            mode='max',
            patience=28,
            restore_best_weights=True,
            verbose=1
        )
    )
    print("Da them EarlyStopping (monitor=val_auc, patience=28).")
else:
    print('Khong co du lieu validation, bo qua EarlyStopping.')

callbacks_list.append(
    ModelCheckpoint(
        model_checkpoint_path,
        monitor=checkpoint_monitor,
        mode='max',
        save_best_only=True,
        verbose=1
    )
)

if val_data_for_fit:
    # Halve the learning rate when validation AUC stops improving.
    callbacks_list.append(
        ReduceLROnPlateau(
            monitor='val_auc',
            mode='max',
            factor=0.5,
            patience=8,
            min_lr=5e-6,
            verbose=1
        )
    )

print(f"BAT DAU QUA TRINH HUAN LUYEN (multi-channel, Adam lr={initial_learning_rate}, augment={AUGMENT_PROB_GEN}, {EPOCHS} epochs, class_weights={class_weights})")
history = None
# class_weight is forwarded to fit() only when weights were actually defined.
fit_kwargs = {} if class_weights is None else {'class_weight': class_weights}
if not train_data_generator:
    print('HUAN LUYEN BI HUY DO KHONG TAO DUOC Data Generator')
else:
    try:
        history = model.fit(
            train_data_generator,
            epochs=EPOCHS,
            validation_data=val_data_for_fit,
            callbacks=callbacks_list,
            **fit_kwargs
        )
        print('--- HUAN LUYEN HOAN TAT ---')
    except tf.errors.InvalidArgumentError as e:
        # Only this TensorFlow error is reported instead of propagated.
        print(f"LOI InvalidArgumentError KHI HUAN LUYEN: {e}")

if history is not None:
    print('--- QUA TRINH HOAN TAT ---')











In [None]:
# --- BUOC DANH GIA VA VE BIEU DO ---
import matplotlib.pyplot as plt
from sklearn.metrics import (
    precision_recall_curve,
    roc_curve,
    auc as sklearn_auc,
    f1_score,
    confusion_matrix,
    classification_report,
    precision_score,
    recall_score,
)
import numpy as np
import tensorflow as tf
import os

# Manual decision threshold for the final test report. When it is None the
# threshold selected on the validation set (best macro F1) is used instead.
CUSTOM_DECISION_THRESHOLD = None  # Set to None to use the F1-optimal threshold
# Fixed thresholds that are always evaluated, both in the validation search
# and in the per-threshold table printed for the test set.
THRESHOLD_CANDIDATES = [0.10, 0.15, 0.20, 0.25, 0.30, 0.32, 0.35, 0.40, 0.45]
# File name of the validation ROC plot; `run_id` is defined earlier in the notebook.
ROC_PLOT_FILENAME = f'roc_curve_val_{run_id}.png'

def summary_probability_stats(name, probs):
    """Print summary statistics for predicted probabilities and return them flattened.

    Args:
        name: Label used in the printed header (e.g. the dataset split).
        probs: Array-like of predicted probabilities, any shape.

    Returns:
        A 1-D numpy array of the values with NaN and +/-inf replaced by 0.0.
        An empty input is returned (flattened) after printing a notice.
    """
    flat = np.asarray(probs).ravel()
    flat = np.nan_to_num(flat, nan=0.0, posinf=0.0, neginf=0.0)

    if flat.size == 0:
        print(f'{name}: khong co du lieu de thong ke.')
        return flat

    print(f'Thong ke xac suat ({name}):')
    print(f'  Mean={flat.mean():.4f}, Std={flat.std():.4f}, Min={flat.min():.4f}, Max={flat.max():.4f}')

    # Ten equal-width bins over [0, 1]; values outside that range are not counted.
    bin_counts, bin_edges = np.histogram(flat, bins=np.linspace(0.0, 1.0, 11))
    for lower, upper, count in zip(bin_edges[:-1], bin_edges[1:], bin_counts):
        print(f'  {lower:.1f}-{upper:.1f}: {int(count)}')
    return flat

def ensure_dir(path):
    """Create directory `path` (including parents) unless something already exists there."""
    if os.path.exists(path):
        # Anything already present at this path is left untouched.
        return
    os.makedirs(path, exist_ok=True)

# Make sure the folder for the result plots exists before anything is saved.
ensure_dir(RESULT_IMAGE_DIR)

print()
print(f"--- Tai model tu '{model_checkpoint_path}' de danh gia ---")

# Model used for evaluation: the saved checkpoint when it can be loaded,
# otherwise the model object still in memory from training, otherwise None.
best_model_to_evaluate = None
current_strategy = tf.distribute.get_strategy()

# Resolve the in-memory fallback once. A missing name and a `model` that is
# None are treated the same: there is nothing to fall back to.
in_memory_model = locals().get('model')

if os.path.exists(model_checkpoint_path):
    try:
        # Load inside the current distribution strategy scope.
        with current_strategy.scope():
            best_model_to_evaluate = tf.keras.models.load_model(model_checkpoint_path)
        print('Da load model tot nhat tu checkpoint.')
    except Exception as exc:  # pylint: disable=broad-except
        # Best effort: an unreadable checkpoint must not abort the evaluation.
        print(f'Loi load checkpoint: {exc}')
        if in_memory_model is not None:
            best_model_to_evaluate = in_memory_model
            print('Su dung model cuoi cung tren bo nho thay the.')
        else:
            # Previously this path ended silently with no model selected.
            print('Khong co model nao de danh gia.')
elif in_memory_model is not None:
    best_model_to_evaluate = in_memory_model
    print('Khong thay checkpoint, su dung model cuoi cung.')
else:
    print('Khong co model nao de danh gia.')

# Defaults for the decision thresholds. The later threshold table and the test
# evaluation read all three names unconditionally, so they must exist even
# when the validation branch below is skipped.
best_threshold_f1 = 0.5
best_threshold_macro = 0.5
best_threshold_positive = 0.5
val_probs = None

if best_model_to_evaluate is not None and 'X_val' in locals() and isinstance(X_val, np.ndarray) and X_val.size > 0:
    print()
    print('--- Danh gia tren tap Validation ---')
    val_probs = summary_probability_stats('Validation', best_model_to_evaluate.predict(X_val))
    precisions_pr, recalls_pr, thresholds_pr = precision_recall_curve(y_val, val_probs)

    # Thresholds to try: every PR-curve threshold, the fixed candidates and 0.5.
    # Everything is kept in float64 so that a selected fixed candidate (e.g. 0.30)
    # stays equal to its THRESHOLD_CANDIDATES entry instead of becoming a float32
    # rounding such as 0.30000001, which would show up as a near-duplicate row later.
    thresholds_from_pr = np.asarray(thresholds_pr, dtype=np.float64)
    candidate_thresholds_val = np.unique(
        np.concatenate([
            thresholds_from_pr,
            np.asarray(THRESHOLD_CANDIDATES, dtype=np.float64),
            np.array([0.5], dtype=np.float64)
        ])
    )
    candidate_thresholds_val = candidate_thresholds_val[np.isfinite(candidate_thresholds_val)]

    # Best threshold by macro F1 (mean of the two per-class F1 scores).
    best_macro_f1 = -1.0
    best_macro_details = None

    # Best threshold by F1 of the positive class (label 1).
    best_positive_f1 = -1.0
    best_positive_details = None

    for thr in candidate_thresholds_val:
        preds_thr = (val_probs >= thr).astype(int)
        # Per-class scores; index 0 is label 0, index 1 is label 1.
        precs_thr = precision_score(y_val, preds_thr, average=None, labels=[0, 1], zero_division=0)
        recs_thr = recall_score(y_val, preds_thr, average=None, labels=[0, 1], zero_division=0)
        f1s_thr = f1_score(y_val, preds_thr, average=None, labels=[0, 1], zero_division=0)
        macro_thr = float(np.mean(f1s_thr))
        # Strict '>' keeps the lowest threshold among ties.
        if macro_thr > best_macro_f1:
            best_macro_f1 = macro_thr
            best_threshold_macro = float(thr)
            best_macro_details = (precs_thr, recs_thr, f1s_thr)
        if f1s_thr[1] > best_positive_f1:
            best_positive_f1 = float(f1s_thr[1])
            best_threshold_positive = float(thr)
            best_positive_details = (precs_thr[1], recs_thr[1], f1s_thr[1])

    if best_macro_details is not None:
        precs_macro, recs_macro, f1s_macro = best_macro_details
        print(
            f"Nguong toi uu theo macro F1: {best_threshold_macro:.4f} "
            f"(F1_clean={f1s_macro[0]:.4f}, F1_stego={f1s_macro[1]:.4f}, macro={best_macro_f1:.4f})"
        )
    else:
        print('Khong tinh duoc macro F1 hop le, giu nguong 0.5.')
        best_threshold_macro = 0.5
        best_macro_f1 = 0.0

    if best_positive_details is not None:
        prec_pos, rec_pos, f1_pos = best_positive_details
        print(f"Nguong toi uu theo F1 lop stego: {best_threshold_positive:.4f} (precision={prec_pos:.4f}, recall={rec_pos:.4f}, F1={f1_pos:.4f})")
    else:
        print('Khong tinh duoc F1 cho lop stego, giu nguong 0.5.')
        best_threshold_positive = 0.5
        best_positive_f1 = 0.0

    # The macro-F1 threshold is the one carried forward as the default decision threshold.
    best_threshold_f1 = best_threshold_macro

    # --- Precision-Recall curve ---
    plt.figure(figsize=(7, 6))
    plt.plot(recalls_pr, precisions_pr, label='PR Curve')
    if thresholds_pr.size > 0 and len(recalls_pr) > 0:
        try:
            # precision_recall_curve aligns precision[i] / recall[i] with
            # thresholds[i] (the extra final element has no threshold), so the
            # marker uses the same index as the nearest threshold.
            close_idx = int(np.argmin(np.abs(thresholds_pr - best_threshold_f1)))
            plt.scatter(recalls_pr[close_idx], precisions_pr[close_idx], color='red', s=80, label=f'Macro F1 (thr={best_threshold_f1:.2f})')
        except Exception as exc:  # pylint: disable=broad-except
            # Marking the point is cosmetic; never fail the evaluation over it.
            print(f'Khong ve duoc diem nguong macro: {exc}')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Precision-Recall Curve (Validation)')
    plt.grid(True)
    plt.legend()
    pr_path = os.path.join(RESULT_IMAGE_DIR, pr_curve_filename)
    plt.savefig(pr_path)
    plt.close()
    print(f'Da luu PR curve: {pr_path}')

    # --- ROC curve ---
    fpr, tpr, _ = roc_curve(y_val, val_probs)
    roc_auc = sklearn_auc(fpr, tpr)
    plt.figure(figsize=(7, 6))
    plt.plot(fpr, tpr, label=f'ROC (AUC={roc_auc:.4f})')
    plt.plot([0, 1], [0, 1], linestyle='--', color='grey')  # chance diagonal
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curve (Validation)')
    plt.grid(True)
    plt.legend()
    roc_path = os.path.join(RESULT_IMAGE_DIR, ROC_PLOT_FILENAME)
    plt.savefig(roc_path)
    plt.close()
    print(f'Da luu ROC curve: {roc_path}')
else:
    print('Khong co du lieu Validation hoac model de tinh PR/ROC.')

# Thresholds reported on the test set: the fixed candidates, the two thresholds
# selected on validation and, when configured, the custom one - de-duplicated
# and in ascending order.
threshold_pool = set(THRESHOLD_CANDIDATES)
threshold_pool.update((best_threshold_macro, best_threshold_positive))
if CUSTOM_DECISION_THRESHOLD is not None:
    threshold_pool.add(CUSTOM_DECISION_THRESHOLD)
candidate_thresholds = sorted(threshold_pool)

if best_model_to_evaluate is not None and 'X_test' in locals() and isinstance(X_test, np.ndarray) and X_test.size > 0:
    print()
    print('--- Danh gia tren tap Test ---')
    test_probs = summary_probability_stats('Test', best_model_to_evaluate.predict(X_test))

    def summarize_threshold(threshold, probs, labels):
        """Score binary predictions obtained by thresholding `probs`.

        Args:
            threshold: Decision threshold; probabilities >= threshold map to label 1.
            probs: 1-D array of predicted probabilities.
            labels: True labels (0 or 1), aligned with `probs`.

        Returns:
            Tuple (precision, recall, f1, confusion_matrix). The first three are
            per-class arrays ordered [label 0, label 1]; the confusion matrix
            uses the same label order.
        """
        preds = (probs >= threshold).astype(int)
        precs = precision_score(labels, preds, average=None, labels=[0, 1], zero_division=0)
        recs = recall_score(labels, preds, average=None, labels=[0, 1], zero_division=0)
        f1s = f1_score(labels, preds, average=None, labels=[0, 1], zero_division=0)
        cm = confusion_matrix(labels, preds, labels=[0, 1])
        return precs, recs, f1s, cm

    # One table row per candidate threshold.
    print('Threshold | P_clean | R_clean | F1_clean | P_stego | R_stego | F1_stego')
    print('---------|---------|---------|---------|---------|---------|---------')
    threshold_summaries = {}
    for thr in candidate_thresholds:
        precs_thr, recs_thr, f1s_thr, cm_thr = summarize_threshold(thr, test_probs, y_test)
        threshold_summaries[thr] = {
            'precision': precs_thr,
            'recall': recs_thr,
            'f1': f1s_thr,
            'cm': cm_thr,
        }
        print(
            f'{thr:9.4f}| {precs_thr[0]:.4f} | {recs_thr[0]:.4f} | {f1s_thr[0]:.4f} | '
            f'{precs_thr[1]:.4f} | {recs_thr[1]:.4f} | {f1s_thr[1]:.4f}'
        )

    # A manually configured threshold takes priority over the validation-selected one.
    if CUSTOM_DECISION_THRESHOLD is not None:
        final_threshold = CUSTOM_DECISION_THRESHOLD
        print()
        print(f'Su dung nguong tuy chinh: {final_threshold:.4f}')
    else:
        final_threshold = best_threshold_macro
        print()
        print(f'Su dung nguong toi uu theo macro F1 tren Validation: {final_threshold:.4f}')
        print(f' (Nguong F1 lop stego: {best_threshold_positive:.4f})')

    final_preds = (test_probs >= final_threshold).astype(int)
    print()
    print('Classification Report (Test):')
    # labels=[0, 1] keeps the report aligned with the two target names even when
    # only one class occurs in y_test and the predictions; without it sklearn
    # raises a ValueError about the number of classes not matching target_names.
    print(
        classification_report(
            y_test,
            final_preds,
            labels=[0, 1],
            target_names=['Lop Sach (0)', 'Lop Stego (1)'],
            digits=4,
            zero_division=0,
        )
    )
    cm_final = confusion_matrix(y_test, final_preds, labels=[0, 1])
    print('Confusion Matrix (Test):')
    print(cm_final)
    # Rows are true labels, columns are predictions.
    fp = int(cm_final[0, 1])
    fn = int(cm_final[1, 0])
    print(f'False Positive (Sach->Stego): {fp}, False Negative (Stego->Sach): {fn}')
else:
    print('Khong co du lieu Test hoac model de danh gia.')

print()
print('--- Ve do thi lich su huan luyen ---')

# Per-epoch metric log of the training run; None when `history` is missing,
# is None, or has no `.history` attribute.
history_log = getattr(locals().get('history'), 'history', None)

if not history_log:
    print('Khong co history (co the do training bi huy truoc do).')
else:
    # Only metrics that were actually recorded are plotted, in this order.
    plottable = [
        name
        for name in ('accuracy', 'loss', 'precision', 'recall', 'auc', 'pr_auc')
        if name in history_log
    ]
    if not plottable:
        print('History khong co metric nao hop le de ve.')
    else:
        # Two panels per row; round up for an odd number of metrics.
        n_rows = (len(plottable) + 1) // 2
        plt.figure(figsize=(15, 5 * n_rows))
        for panel, name in enumerate(plottable, start=1):
            plt.subplot(n_rows, 2, panel)
            plt.plot(history_log[name], label=f'Train {name}')
            # Overlay the validation curve when it was logged.
            val_name = f'val_{name}'
            if val_name in history_log:
                plt.plot(history_log[val_name], label=f'Val {name}')
            pretty = name.capitalize()
            plt.title(pretty)
            plt.xlabel('Epoch')
            plt.ylabel(pretty)
            plt.grid(True)
            plt.legend()
        plt.tight_layout()
        hist_path = os.path.join(RESULT_IMAGE_DIR, training_history_plot_filename)
        plt.savefig(hist_path)
        plt.close()
        print(f'Da luu bieu do training: {hist_path}')

print()
print('--- HOAN TAT BUOC DANH GIA ---')

