In [None]:
# ==================== DEPENDENCIES ====================
# Install required packages (uncomment if needed).
# Version specifiers must be quoted: unquoted, the shell treats ">" as output
# redirection and installs the latest release, writing pip's output to a file
# named "=2.13.0".
# !pip install "tensorflow>=2.13.0"
# !pip install numpy pandas scikit-learn
# !pip install matplotlib seaborn
# !pip install fairlearn

"""
Required Dependencies:
- tensorflow >= 2.13.0 (with GPU support recommended)
- numpy >= 1.23.0
- pandas >= 1.5.0
- scikit-learn >= 1.2.0
- matplotlib >= 3.6.0
- seaborn >= 0.12.0
- fairlearn >= 0.8.0

Hardware Requirements:
- GPU with CUDA support (recommended for training)
- Minimum 16GB RAM
- Distributed training across 2+ GPUs supported via tf.distribute.MirroredStrategy
"""

import importlib

# Import every listed dependency *before* announcing success, so a missing
# package fails here with ModuleNotFoundError instead of after a misleading
# success message. (scikit-learn's import name is "sklearn".)
_DEPENDENCY_MODULES = ("tensorflow", "numpy", "pandas", "sklearn",
                       "matplotlib", "seaborn", "fairlearn")
_loaded_deps = {name: importlib.import_module(name) for name in _DEPENDENCY_MODULES}

print("✅ All dependencies loaded successfully")
print(f"TensorFlow version: {_loaded_deps['tensorflow'].__version__}")
print(f"NumPy version: {_loaded_deps['numpy'].__version__}")
print(f"Pandas version: {_loaded_deps['pandas'].__version__}")


In [None]:
# complete_integration_with_fairness.py
# Single-file training + eval script with fairness CVaR loss integrated.
# Early stopping (patience=ES_PATIENCE) + best-weights checkpointing, up to EPOCHS=30.

# ====================== CONFIG ======================
DATASET_ROOT = "/kaggle/input/input-folder/dataset"
PREPROC_DIR  = "Celeb-DF Preprocessed"
IMAGE_ROOT   = f"{DATASET_ROOT}/{PREPROC_DIR}/train"

CSV_PATH          = "/kaggle/input/input-folder/labels.csv"
CSV_REL_PATH_COL  = "relative_path"
CSV_LABEL_COL     = "label"
CSV_AGE_COL       = "age_group"
CSV_GENDER_COL    = "gender"
CSV_SKIN_COL      = "skin_tone"

IMG_SIZE    = (128, 128)
LATENT_DIM  = 128
EPOCHS      = 30         # maximum epochs; early stopping may end training sooner
BATCH_SIZE  = 64         # per-replica batch size

VAL_SPLIT   = 0.20
LEARNING_RATE = 1e-4

AGE_NUM_CLASSES    = 3
GENDER_NUM_CLASSES = 2
SKIN_NUM_CLASSES   = 4

USE_MIXED_PRECISION = True
MAX_SHUFFLE_BUFFER  = 12000

# Plot behaviour flag (intended: show plots, don't save).
# NOTE(review): not read by any code visible here; plot_fairness calls
# plt.savefig unconditionally — confirm whether it should honour this flag.
SAVE_PLOTS = False

# Reconstruction loss components weights
RECON_IMG_WEIGHT  = 1.0
RECON_FRFT_WEIGHT = 1.0

# --- real/fake supervision + schedules ---
REALFAKE_LOSS_WEIGHT = 1.5
FAIRNESS_WARMUP_EPOCHS = 3   # fairness term ramps linearly to full weight over this many epochs

# --- Early stopping on custom loop ---
ES_PATIENCE = 7                    # epochs
CKPT_DIR = "/kaggle/working/ae_exports/best"   # best weights location
FINAL_SAVE_DIR = "/kaggle/working/ae_exports"  # final save dir

import os, re, time, math, numpy as np, pandas as pd, tensorflow as tf
from tensorflow.keras import mixed_precision
from tensorflow.data import AUTOTUNE
from tensorflow.keras.layers import (Input, Dense, Flatten, Reshape, Dropout,
                                     Conv2D, MaxPooling2D, UpSampling2D, LayerNormalization,
                                     LeakyReLU, GaussianNoise)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import AdamW, Adam
from tensorflow.keras.losses import CategoricalCrossentropy, BinaryCrossentropy
from tensorflow.keras.constraints import MaxNorm
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve, precision_recall_curve, accuracy_score, confusion_matrix
from fairlearn.metrics import (
    MetricFrame,
    selection_rate,
    false_positive_rate,
    false_negative_rate,
    equalized_odds_difference,
    demographic_parity_difference,
)
PREFETCH_BUFSIZE = AUTOTUNE

# ============== GPU OPTS ===================
tf.config.set_soft_device_placement(True)
if USE_MIXED_PRECISION:
    mixed_precision.set_global_policy('mixed_float16')

gpus = tf.config.list_physical_devices('GPU')
if gpus:
    for gpu in gpus:
        try:
            tf.config.experimental.set_memory_growth(gpu, True)
        except (RuntimeError, ValueError) as err:
            # RuntimeError: the GPU runtime is already initialised (e.g. the cell
            # was re-run in the same kernel), so memory growth can't change.
            # Keep going, but report it instead of silently swallowing it.
            print(f"Could not enable memory growth on {gpu.name}: {err}")
    print("GPUs:", gpus)
else:
    print("No GPU; running on CPU.")

strategy = tf.distribute.MirroredStrategy()
print("Replicas in sync:", strategy.num_replicas_in_sync)
# Each replica receives BATCH_SIZE samples per step.
GLOBAL_BATCH_SIZE = BATCH_SIZE * strategy.num_replicas_in_sync

# Current epoch (1-based), assigned by main() each epoch and read inside the
# train/val steps to compute the fairness warm-up scale.
CURRENT_EPOCH_TF = tf.Variable(1.0, dtype=tf.float32, trainable=False)

# ================== FRFT (single-angle, chirp–FFT–chirp) =======================
import tensorflow as tf
import numpy as np

class FRFTPrecomp:
    """Single-angle 2-D fractional Fourier transform (chirp–FFT–chirp form).

    The transform angle ``self.alpha`` (radians) is fixed at construction:

    * if ``alphas_row`` is a non-empty list/tuple/ndarray of angles, ``alpha``
      is their mean (the legacy multi-angle interface collapsed to one angle);
    * otherwise ``alpha = pi * order / 4``.

    Args:
        H, W: spatial size of the images the transform is applied to.
        alphas_row: optional sequence of angles in radians.
        alphas_col: accepted for backward compatibility; not used.
        order: fallback order, used only when ``alphas_row`` provides no angles.
    """

    # Floor for |sin(alpha)| in the 1/sqrt(|sin(alpha)|) amplitude factor.
    # Without it, alpha = k*pi (e.g. order=0 or order=4) divides by zero and
    # the output becomes inf/NaN.
    _SIN_EPS = 1e-6

    def __init__(self, H, W, alphas_row=None, alphas_col=None, order=5):
        self.H, self.W = H, W

        # Legacy alpha lists collapse to their mean. Empty sequences fall back
        # to ``order`` instead of producing NaN from np.mean([]).
        if isinstance(alphas_row, (list, tuple, np.ndarray)) and np.size(alphas_row) > 0:
            self.alpha = float(np.mean(alphas_row))
        else:
            self.alpha = np.pi * order / 4.0

        self.order = order

    def frft2d_vectorized(self, x):
        """Apply the 2-D FRFT to a batch of single-channel images (GPU-friendly).

        Args:
            x: tensor of shape (batch, H, W); cast to complex64.

        Returns:
            complex64 tensor of shape (batch, H, W).
        """
        H, W = self.H, self.W
        alpha = tf.constant(self.alpha, tf.float32)
        tan_half = tf.tan(alpha / 2.0)
        sin_a = tf.sin(alpha)

        # Centred coordinate grids.
        h = tf.cast(tf.range(H), tf.float32) - H / 2.0
        w = tf.cast(tf.range(W), tf.float32) - W / 2.0

        # Per-axis quadratic-phase chirp exp(-j*pi*t^2*tan(alpha/2)/N). The same
        # chirp is applied before and after the FFT, so compute it only once.
        neg_j_pi = tf.complex(0.0, -np.pi)
        chirp_h = tf.reshape(tf.exp(neg_j_pi * tf.cast((h ** 2) * tan_half / H, tf.complex64)), [1, H, 1])
        chirp_w = tf.reshape(tf.exp(neg_j_pi * tf.cast((w ** 2) * tan_half / W, tf.complex64)), [1, 1, W])

        X = tf.cast(x, tf.complex64) * chirp_h * chirp_w   # pre-chirp (broadcast over batch)
        X = tf.signal.fft2d(X)                              # FFT over the last two axes
        X = X * chirp_h * chirp_w                           # post-chirp

        # Global phase and amplitude normalisation (|sin| floored, see _SIN_EPS).
        phase = tf.exp(tf.complex(0.0, -alpha * np.pi / 4.0))
        amp = tf.sqrt(tf.cast(tf.maximum(tf.abs(sin_a), self._SIN_EPS), tf.complex64))
        return X * phase / amp

    def frft2_grid_mag01(self, img_hw1):
        """Return the standardized log-magnitude FRFT of a batch of images.

        Args:
            img_hw1: tensor of shape (batch, H, W, C); only channel 0 is used.

        Returns:
            float32 tensor of shape (batch, H, W, 1) holding log1p(|FRFT|),
            standardized with ONE mean/std computed over the whole batch
            (axes 0-2). Each sample's output therefore depends on the other
            samples in the batch, and values are not confined to [0, 1].
        """
        x = tf.cast(img_hw1[..., 0], tf.complex64)
        X = self.frft2d_vectorized(x)
        mag = tf.math.log1p(tf.abs(X))
        mean = tf.reduce_mean(mag, axis=[0, 1, 2], keepdims=True)
        std = tf.math.reduce_std(mag, axis=[0, 1, 2], keepdims=True)
        mag_std = (mag - mean) / (std + 1e-6)
        return tf.expand_dims(tf.cast(mag_std, tf.float32), -1)

# ============== Gradient Reversal ===================
@tf.custom_gradient
def grad_reverse(x, lam):
    """Gradient reversal: identity in the forward pass, ``-lam * dy`` backward.

    Args:
        x: input tensor.
        lam: reversal strength (scalar), cast to ``x.dtype``. No gradient is
            returned for ``lam``.
    """
    scale = tf.cast(lam, x.dtype)

    def _backward(upstream):
        # Flip and scale the gradient flowing into x; lam receives none.
        return tf.negative(upstream) * scale, None

    return tf.identity(x), _backward

# ==================== MODELS ========================
def build_shared_autoencoder(image_size=(128,128,1), latent_dim=16, input_noise_std=0.05, input_dropout=0.05, leaky_alpha=0.1):
    """Build the convolutional encoder/decoder pair.

    Encoder: GaussianNoise -> Dropout -> 3 x (Conv2D 32/64/128 + LeakyReLU +
    2x2 MaxPool) -> Flatten -> LayerNorm -> Dense(latent_dim) (float32 output).
    Decoder: Dense -> Reshape(H/8, W/8, 128) -> 3 x (Conv2D 128/64/32 +
    LeakyReLU + 2x UpSampling) -> linear Conv2D(C) (float32 output).

    Args:
        image_size: (H, W, C) of the encoder input; H and W must be multiples
            of 8 so the decoder output has the same spatial size.
        latent_dim: size of the latent code.
        input_noise_std: stddev of the GaussianNoise applied to encoder inputs.
        input_dropout: dropout rate applied to encoder inputs.
        leaky_alpha: LeakyReLU negative slope.

    Returns:
        (encoder, decoder) Keras models.

    Raises:
        ValueError: if H or W is not a positive multiple of 8.
    """
    h, w, c = image_size
    # Three 'same'-padded 2x2 pools give ceil(H/8) while the decoder reshapes
    # to H//8; if H or W is not a multiple of 8 the reconstruction would not
    # match the input shape.
    if h <= 0 or w <= 0 or h % 8 or w % 8:
        raise ValueError(
            f"image_size H and W must be positive multiples of 8, got {h}x{w}")

    enc_in = Input(shape=(h, w, c))
    x = GaussianNoise(input_noise_std)(enc_in)
    x = Dropout(input_dropout)(x)
    for filters in (32, 64, 128):
        x = Conv2D(filters, (3, 3), padding='same', kernel_initializer='he_normal')(x)
        x = LeakyReLU(leaky_alpha)(x)
        x = MaxPooling2D((2, 2), padding='same')(x)
    x = Flatten()(x)
    x = LayerNormalization(epsilon=1e-5)(x)
    z = Dense(latent_dim, activation=None, dtype='float32',
              kernel_initializer='he_normal')(x)
    encoder = Model(enc_in, z, name='encoder')

    dec_in = Input(shape=(latent_dim,), dtype='float32')
    x = Dense((h // 8) * (w // 8) * 128, activation=None, kernel_initializer='he_normal')(dec_in)
    x = LeakyReLU(leaky_alpha)(x)
    x = Reshape((h // 8, w // 8, 128))(x)
    for filters in (128, 64, 32):
        x = Conv2D(filters, (3, 3), padding='same', kernel_initializer='he_normal')(x)
        x = LeakyReLU(leaky_alpha)(x)
        x = UpSampling2D((2, 2))(x)
    out = Conv2D(c, (3, 3), activation=None, padding='same', dtype='float32')(x)
    decoder = Model(dec_in, out, name='decoder')
    return encoder, decoder

def build_classifier(latent_dim, leaky_alpha=0.1):
    """Adversarial demographic classifier operating on the latent code.

    Maps a (latent_dim,) float32 vector to three float32 logit tensors: age
    (AGE_NUM_CLASSES), gender (GENDER_NUM_CLASSES) and skin tone
    (SKIN_NUM_CLASSES). Hidden Dense kernels are MaxNorm(3.0)-constrained.
    """
    def _dense_block(t, units):
        t = Dense(units, activation=None, kernel_initializer='he_normal',
                  kernel_constraint=MaxNorm(3.0))(t)
        return LeakyReLU(leaky_alpha)(t)

    latent = Input(shape=(latent_dim,), dtype='float32')
    hidden = LayerNormalization(epsilon=1e-5)(latent)
    hidden = _dense_block(hidden, max(16, latent_dim // 2))
    hidden = Dropout(0.1)(hidden)
    hidden = _dense_block(hidden, max(8, latent_dim // 4))

    head_specs = (
        (AGE_NUM_CLASSES, 'age_output'),
        (GENDER_NUM_CLASSES, 'gender_output'),
        (SKIN_NUM_CLASSES, 'skin_tone_output'),
    )
    heads = [Dense(n_out, name=head_name, dtype='float32')(hidden)
             for n_out, head_name in head_specs]
    return Model(latent, heads, name='classifier')

def build_realfake_head(latent_dim, leaky_alpha=0.1):
    """Real/fake head: (latent_dim,) float32 vector -> one float32 logit (no sigmoid)."""
    width_1 = max(16, latent_dim // 2)
    width_2 = max(8, latent_dim // 4)

    latent = Input(shape=(latent_dim,), dtype='float32')
    hidden = LayerNormalization(epsilon=1e-5)(latent)
    hidden = Dense(width_1, activation=None, kernel_initializer='he_normal')(hidden)
    hidden = LeakyReLU(leaky_alpha)(hidden)
    hidden = Dropout(0.2)(hidden)
    hidden = Dense(width_2, activation=None, kernel_initializer='he_normal')(hidden)
    hidden = LeakyReLU(leaky_alpha)(hidden)
    logit = Dense(1, activation=None, name='real_fake_logit', dtype='float32')(hidden)
    return Model(latent, logit, name='realfake_head')

# ==================== LOSSES & AUX ========================
# Default categorical cross-entropy on logits with 5% label smoothing.
# main() declares `ce` global and rebinds it to an identical instance inside
# strategy.scope(); the train/val steps use whichever binding is current.
ce = CategoricalCrossentropy(label_smoothing=0.05, from_logits=True)

def latent_loss(z, z_hat, beta=0.5, gamma=0.5, eps=1e-6):
    """Latent-cycle consistency loss between ``z`` and its re-encoding ``z_hat``.

    ``z`` is treated as a fixed target (stop_gradient); gradients flow only
    through ``z_hat``. Both terms are averaged over the batch:

        beta  * mean(||z_hat - z||^2 / (||z||^2 + ||z_hat||^2 + eps))
      + gamma * mean(1 - cos(z, z_hat))
    """
    target = tf.stop_gradient(tf.cast(z, tf.float32))
    pred = tf.cast(z_hat, tf.float32)

    sq_err = tf.reduce_sum(tf.square(pred - target), axis=-1)
    scale = (tf.reduce_sum(tf.square(target), axis=-1)
             + tf.reduce_sum(tf.square(pred), axis=-1) + eps)
    relative = tf.reduce_mean(sq_err / scale)

    cos_sim = tf.reduce_sum(
        tf.math.l2_normalize(target, -1) * tf.math.l2_normalize(pred, -1), axis=-1)
    cosine = tf.reduce_mean(1.0 - cos_sim)

    return beta * relative + gamma * cosine

def logit_norm_penalty(age_logits, gen_logits, skin_logits, alpha=1e-4):
    """Return ``alpha * sum_h mean(logits_h ** 2)`` over the three heads (float32)."""
    per_head = [tf.reduce_mean(tf.square(tf.cast(logits, tf.float32)))
                for logits in (age_logits, gen_logits, skin_logits)]
    summed = tf.add_n(per_head)
    return alpha * summed

def calculate_total_loss(losses: dict, loss_weights: dict) -> tf.Tensor:
    """Weighted float32 sum of named losses.

    Each ``losses[name]`` is multiplied by ``loss_weights[name]``; names absent
    from ``loss_weights`` get weight 0.0 and so do not contribute.
    """
    weighted_sum = tf.constant(0.0, dtype=tf.float32)
    for name in losses:
        weight = tf.cast(loss_weights.get(name, 0.0), tf.float32)
        weighted_sum += weight * tf.cast(losses[name], tf.float32)
    return weighted_sum

# ==================== FAIRNESS CVaR LOSS ========================
def calculate_fairness_loss(reconstruction_losses_fake,
                            age_labels,
                            gender_labels,
                            skin_tone_labels,
                            cvar_alpha=0.2):
    """Mean per-group CVaR of per-sample losses over all demographic groups.

    For each attribute (age, gender, skin tone) and each class of it that is
    present in the batch, average the ``max(1, floor(n_group * cvar_alpha))``
    largest losses of that group. The result is the mean of these per-group
    tail averages (0.0 when no group is present).

    Args:
        reconstruction_losses_fake: 1-D tensor of per-sample losses.
        age_labels, gender_labels, skin_tone_labels: one-hot tensors of shape
            (batch, n_classes); a sample's group is the argmax.
        cvar_alpha: fraction of each group's worst losses to average.

    Returns:
        ``{'fairness_cvar_loss': float32 scalar}``.
    """
    cvar_sum = tf.constant(0.0, dtype=tf.float32)
    group_count = tf.constant(0, dtype=tf.int32)

    for onehot in (age_labels, gender_labels, skin_tone_labels):
        member_of = tf.argmax(onehot, axis=-1, output_type=tf.int32)
        n_classes = tf.shape(onehot, out_type=tf.int32)[1]
        for class_id in tf.range(n_classes, dtype=tf.int32):
            in_group = tf.equal(member_of, class_id)
            if not tf.reduce_any(in_group):
                continue  # group absent from this batch
            group_losses = tf.boolean_mask(reconstruction_losses_fake, in_group)
            group_size = tf.shape(group_losses, out_type=tf.int32)[0]
            tail_k = tf.maximum(
                tf.cast(tf.cast(group_size, tf.float32) * cvar_alpha, tf.int32), 1)
            tail = tf.sort(group_losses, direction='DESCENDING')[:tail_k]
            cvar_sum += tf.cast(tf.reduce_mean(tail), tf.float32)
            group_count += 1

    mean_cvar = tf.math.divide_no_nan(cvar_sum, tf.cast(group_count, tf.float32))
    return {'fairness_cvar_loss': mean_cvar}

# ==================== LOSS WEIGHTS ========================
# Per-term weights for calculate_total_loss; keys must match the loss names
# passed there (unknown names get weight 0). The real/fake loss is weighted
# separately by REALFAKE_LOSS_WEIGHT.
LOSS_WEIGHTS_BASE = dict(
    reconstruction_loss=1.0,
    latent_loss=1.0,
    age_adversarial_loss=0.1,
    gender_adversarial_loss=0.1,
    skin_tone_adversarial_loss=0.1,
    fairness_cvar_loss=0.5,
)

# ================== DATA HELPERS ====================
def _norm(s: str) -> str:
    return re.sub(r'[^a-z0-9]+', '', str(s).strip().lower())

def resolve_csv_path(rp: str) -> str:
    """Resolve a path from the labels CSV to a normalised on-disk image path.

    Backslashes are converted to "/" first. Then:
      * absolute paths are returned (normalised) as-is;
      * leading "./" segments are stripped;
      * "<PREPROC_DIR>/..." resolves under DATASET_ROOT;
      * "train/...", "val/..." or "test/..." resolves under DATASET_ROOT/PREPROC_DIR;
      * anything else resolves under IMAGE_ROOT (the train split).
    """
    rp = str(rp).strip().replace("\\", "/")
    if os.path.isabs(rp):
        return os.path.normpath(rp)
    # Strip "./" so prefixed forms follow the same rules as unprefixed ones
    # (previously "./test/x" resolved to ".../train/test/x" and
    # "./<PREPROC_DIR>/x" to ".../train/<PREPROC_DIR>/x").
    while rp.startswith("./"):
        rp = rp[2:]
    # Backslashes were already converted above, so only "/" needs checking.
    if rp.startswith(PREPROC_DIR + "/"):
        return os.path.normpath(os.path.join(DATASET_ROOT, rp))
    if rp.startswith(("train/", "val/", "test/")):
        return os.path.normpath(os.path.join(DATASET_ROOT, PREPROC_DIR, rp))
    return os.path.normpath(os.path.join(IMAGE_ROOT, rp))

def load_labels(csv_path):
    """Load, validate and resolve the labels CSV.

    Maps the label / age / gender / skin-tone columns to integer class ids
    (numbers or the textual aliases below). It also resolves image paths with
    ``resolve_csv_path`` and drops rows whose file does not exist (with a
    warning).

    Returns:
        DataFrame with columns ["abspath", "is_real", "age_group", "gender",
        "skin_tone"]; the four label columns hold ints.

    Raises:
        ValueError: if a required column is missing, a value cannot be mapped
            to a class id, or a mapped id is outside the allowed range.
    """
    df = pd.read_csv(csv_path)
    need = [CSV_REL_PATH_COL, CSV_LABEL_COL, CSV_AGE_COL, CSV_GENDER_COL, CSV_SKIN_COL]
    for col in need:
        if col not in df.columns:
            raise ValueError(f"CSV missing required column: {col}")

    # Textual aliases. Keys must be in _norm() form (lower-case, [a-z0-9] only)
    # because lookups use _norm(value): "middle-aged" is looked up as
    # "middleaged" (a 'middle-aged' key could never match).
    LABEL_MAP = {'0':0,'1':1,'fake':0,'real':1}
    AGE_MAP   = {'0':0,'1':1,'2':2,'young':0,'adult':1,'middleage':1,'middleaged':1,'senior':2,'old':2}
    GENDER_MAP= {'0':0,'1':1,'f':0,'female':0,'m':1,'male':1}
    SKIN_MAP  = {'0':0,'1':1,'2':2,'3':3,'verylight':0,'pale':0,'light':1,'fair':1,'medium':2,'olive':2,'tan':2,
                 'dark':3,'verydark':3,'brown':3,'black':3}

    def map_or_die(series, name, mapping, allowed):
        """Map a column to int class ids; raise ValueError on unmappable/out-of-range values."""
        def to_class_id(raw):
            text = raw.strip()
            # Parse numbers before alias lookup: _norm() would turn "-1" into
            # "1" (silently flipping e.g. a fake label to real) and "1.0" (a
            # float-typed column) into the unknown key "10".
            try:
                num = float(text)
            except ValueError:
                return mapping.get(_norm(text))
            return int(num) if num.is_integer() else None

        raw_vals = series.astype(str)
        mapped = raw_vals.map(to_class_id)
        unmapped = sorted(set(raw_vals[mapped.isna()]))
        if unmapped:
            raise ValueError(f"{name}: could not map values {unmapped[:10]} to class ids")
        mapped = mapped.astype(int)
        bad = sorted(set(int(x) for x in mapped.unique() if int(x) not in allowed))
        if bad: raise ValueError(f"{name}: values {bad} not in allowed set {allowed}")
        return mapped

    df[CSV_LABEL_COL]  = map_or_die(df[CSV_LABEL_COL],  "label",     LABEL_MAP, {0,1})
    df[CSV_AGE_COL]    = map_or_die(df[CSV_AGE_COL],    "age_group", AGE_MAP,   set(range(AGE_NUM_CLASSES)))
    df[CSV_GENDER_COL] = map_or_die(df[CSV_GENDER_COL], "gender",    GENDER_MAP,set(range(GENDER_NUM_CLASSES)))
    df[CSV_SKIN_COL]   = map_or_die(df[CSV_SKIN_COL],   "skin_tone", SKIN_MAP,  set(range(SKIN_NUM_CLASSES)))

    df['abspath'] = df[CSV_REL_PATH_COL].map(resolve_csv_path)
    exists_mask = df['abspath'].apply(os.path.exists)
    missing_count = int((~exists_mask).sum())
    if missing_count:
        print(f"Warning: {missing_count} files listed in CSV not found on disk — they will be skipped.")
        print(df.loc[~exists_mask, CSV_REL_PATH_COL].head(10).to_list())
    df = df[exists_mask].copy().reset_index(drop=True)
    df = df.rename(columns={
        CSV_LABEL_COL: "is_real",
        CSV_AGE_COL:   "age_group",
        CSV_GENDER_COL:"gender",
        CSV_SKIN_COL:  "skin_tone",
    })
    return df[["abspath","is_real","age_group","gender","skin_tone"]]

# --- ADD: compute class weight for imbalanced real/fake ---
def compute_real_fake_pos_weight(df):
    """Positive-class (real=1) weight for weighted BCE: max(n_fake / n_real, 1.0).

    A class with no samples is counted as 1 to avoid division by zero. The
    weight never drops below 1.0, so it only up-weights a minority real class.
    """
    labels = df["is_real"]
    n_fake = int((labels == 0).sum())
    n_real = int((labels == 1).sum())
    fake_count = float(n_fake) if n_fake else 1.0
    real_count = float(n_real) if n_real else 1.0
    ratio = fake_count / max(real_count, 1.0)
    return ratio if ratio > 1.0 else 1.0

# ==================== TF.DATA PIPELINE ====================
def decode_image_no_crash(path):
    """Read an image file as a float32 (IMG_SIZE[0], IMG_SIZE[1], 1) tensor in [0, 1].

    Uses ``tf.io.decode_image`` so JPEG, PNG, BMP and GIF (first frame) files
    are all accepted; ``tf.image.decode_jpeg`` raised on any non-JPEG file.
    Note: despite the name, corrupt or unreadable files still raise.
    """
    raw = tf.io.read_file(path)
    # expand_animations=False always yields a rank-3 tensor, which resize needs.
    img = tf.io.decode_image(raw, channels=1, expand_animations=False)
    img = tf.image.convert_image_dtype(img, tf.float32)
    img = tf.image.resize(img, IMG_SIZE)
    return img

def process_sample(path, is_real, age, gender, skin):
    """Turn one CSV row into ``(image, (is_real, age_oh, gender_oh, skin_oh))``.

    The image comes from ``decode_image_no_crash``; ``is_real`` is int32 and the
    three demographic labels are float32 one-hot vectors.
    """
    def _onehot(idx, depth):
        return tf.cast(tf.one_hot(tf.cast(idx, tf.int32), depth), tf.float32)

    image = decode_image_no_crash(path)
    targets = (
        tf.cast(is_real, tf.int32),
        _onehot(age, AGE_NUM_CLASSES),
        _onehot(gender, GENDER_NUM_CLASSES),
        _onehot(skin, SKIN_NUM_CLASSES),
    )
    return image, targets

def make_dataset(df, training=True, cache_dir=None, drop_remainder=True):
    """Build a batched, prefetched tf.data pipeline from a labels DataFrame.

    Args:
        df: DataFrame with columns abspath, is_real, age_group, gender, skin_tone.
        training: when True, shuffle and reshuffle every epoch.
        cache_dir: optional directory; decoded samples are cached under
            ``<cache_dir>/tfdata_cache`` and re-read on later epochs.
        drop_remainder: drop the final partial batch (default True, as before;
            this keeps every replica's per-step batch non-empty).

    Returns:
        Dataset of ``(image, (is_real, age_oh, gender_oh, skin_oh))`` batches of
        GLOBAL_BATCH_SIZE samples.
    """
    columns = (
        df['abspath'].values.astype(str),
        df['is_real'].values.astype(np.int32),
        df['age_group'].values.astype(np.int32),
        df['gender'].values.astype(np.int32),
        df['skin_tone'].values.astype(np.int32),
    )
    ds = tf.data.Dataset.from_tensor_slices(columns)
    # Buffer size must be >= 1 even for an empty DataFrame.
    shuffle_buf = max(1, min(len(df), MAX_SHUFFLE_BUFFER))

    if cache_dir:
        # Decode once and cache, THEN shuffle: shuffling before cache() would
        # freeze the first epoch's order into the cache for every later epoch.
        ds = ds.map(process_sample, num_parallel_calls=AUTOTUNE)
        os.makedirs(cache_dir, exist_ok=True)
        ds = ds.cache(os.path.join(cache_dir, "tfdata_cache"))  # filename prefix, not a dir
        if training:
            ds = ds.shuffle(shuffle_buf, reshuffle_each_iteration=True)
    else:
        # Shuffle the cheap (path, label) tuples before decoding, so the
        # shuffle buffer does not hold up to MAX_SHUFFLE_BUFFER decoded images.
        if training:
            ds = ds.shuffle(shuffle_buf, reshuffle_each_iteration=True)
        ds = ds.map(process_sample, num_parallel_calls=AUTOTUNE)

    return ds.batch(GLOBAL_BATCH_SIZE, drop_remainder=drop_remainder).prefetch(PREFETCH_BUFSIZE)

# ==================== FEATURE EXTRACTION & BINARY HEAD ================
def extract_features(dataset, encoder_model, latent_dim, frft_precomp):
    """Encode every batch of ``dataset`` into latent features.

    Each batch's images pass through ``frft_precomp.frft2_grid_mag01`` and then
    ``encoder_model`` in inference mode.

    Returns:
        (X, y): float32 array of latent codes and int32 array of is_real
        labels, concatenated over batches. When the dataset yields nothing,
        the arrays are empty with shapes (0, latent_dim) and (0,).
    """
    feats, targets = [], []
    for imgs, labels in dataset:
        z = encoder_model(frft_precomp.frft2_grid_mag01(imgs), training=False)
        feats.append(tf.cast(z, tf.float32).numpy())
        targets.append(tf.cast(labels[0], tf.int32).numpy())

    if len(feats) == 0:
        return (np.zeros((0, latent_dim), dtype=np.float32),
                np.zeros((0,), dtype=np.int32))
    X = np.concatenate(feats, axis=0)
    y = np.concatenate(targets, axis=0).astype(np.int32)
    return X, y

def build_binary_head(latent_dim, leaky_alpha=0.1):
    """Compiled sigmoid real/fake classifier trained on extracted latent features.

    Same layer stack as build_realfake_head but with a sigmoid output. It is
    compiled with Adam(1e-4), binary cross-entropy and an AUC metric named
    'auc'.
    """
    hidden_units = (max(16, latent_dim // 2), max(8, latent_dim // 4))

    features = Input(shape=(latent_dim,), dtype='float32')
    t = LayerNormalization(epsilon=1e-5)(features)
    t = Dense(hidden_units[0], activation=None, kernel_initializer='he_normal')(t)
    t = LeakyReLU(leaky_alpha)(t)
    t = Dropout(0.2)(t)
    t = Dense(hidden_units[1], activation=None, kernel_initializer='he_normal')(t)
    t = LeakyReLU(leaky_alpha)(t)
    prob = Dense(1, activation='sigmoid', dtype='float32')(t)

    head = Model(features, prob, name='binary_head')
    head.compile(
        optimizer=Adam(learning_rate=1e-4),
        loss=BinaryCrossentropy(),
        metrics=[tf.keras.metrics.AUC(name='auc')],
    )
    return head

# ==================== TRAIN/VAL STEPS ======================
def dist_train_step(batch, grl_lambda, loss_weights):
    """Run one optimisation step on a per-replica batch (call via strategy.run).

    Jointly updates encoder, decoder, the adversarial classifier (through
    gradient reversal) and the real/fake head. The objective sums:
    image+FRFT reconstruction, latent cycle, adversarial demographic CE,
    logit-norm penalty, fairness CVaR ramped in over FAIRNESS_WARMUP_EPOCHS,
    and weighted real/fake BCE.

    Reads module globals set by main(): encoder, decoder, classifier,
    realfake_head, opt, frft_pre, ce, POS_WEIGHT_TF, plus CURRENT_EPOCH_TF.

    Args:
        batch: ``(imgs, (is_real, age_oh, gender_oh, skin_oh))`` per-replica batch.
        grl_lambda: gradient-reversal strength for the adversarial classifier.
        loss_weights: per-term weights for calculate_total_loss; the
            'fairness_cvar_loss' weight is also the target of the warm-up.

    Returns:
        ``(recon_loss, lat_loss, age_ce, gen_ce, skin_ce, fairness_loss, total)``
        per-replica scalars; ``fairness_loss`` is reported before warm-up scaling.
    """
    imgs, labels = batch
    is_real_lab, age_lab, gender_lab, skin_lab = labels[0], labels[1], labels[2], labels[3]

    # FRFT log-magnitude of the input batch: this is what the encoder sees.
    imgs_frft = frft_pre.frft2_grid_mag01(imgs)

    with tf.GradientTape() as tape:
        z = encoder(imgs_frft, training=True)
        recon_img = decoder(z, training=True)  # spatial-domain reconstruction

        # Latent cycle: re-encode the FRFT of the reconstruction.
        recon_frft = frft_pre.frft2_grid_mag01(recon_img)
        z_hat = encoder(recon_frft, training=True)

        # Adversarial heads behind gradient reversal: the classifier learns to
        # predict demographics while the encoder is pushed to hide them.
        z_rev = grad_reverse(z, grl_lambda)
        age_pred, gen_pred, skin_pred = classifier(z_rev, training=True)  # logits

        # Reconstruction losses: image-space + FRFT-space (per-sample, then mean).
        per_ex_recon_img  = tf.reduce_mean(tf.square(tf.cast(imgs, tf.float32)      - tf.cast(recon_img, tf.float32)),  axis=[1, 2, 3])
        per_ex_recon_frft = tf.reduce_mean(tf.square(tf.cast(imgs_frft, tf.float32) - tf.cast(recon_frft, tf.float32)), axis=[1, 2, 3])
        recon_loss = RECON_IMG_WEIGHT * tf.reduce_mean(per_ex_recon_img) + RECON_FRFT_WEIGHT * tf.reduce_mean(per_ex_recon_frft)

        lat_loss = latent_loss(z, z_hat, beta=1.0)

        age_ce = ce(age_lab, tf.cast(age_pred, tf.float32))
        gen_ce = ce(gender_lab, tf.cast(gen_pred, tf.float32))
        skin_ce = ce(skin_lab, tf.cast(skin_pred, tf.float32))
        logit_pen = logit_norm_penalty(age_pred, gen_pred, skin_pred, alpha=1e-4)

        # Fairness CVaR over per-sample FRFT reconstruction errors, per group.
        fairness_loss = calculate_fairness_loss(
            per_ex_recon_frft, age_lab, gender_lab, skin_lab, cvar_alpha=0.2)['fairness_cvar_loss']

        # Real/fake head with class-balanced BCE on logits.
        rf_logit = tf.squeeze(tf.cast(realfake_head(z, training=True), tf.float32), axis=-1)
        real_fake_loss = tf.reduce_mean(tf.nn.weighted_cross_entropy_with_logits(
            labels=tf.cast(is_real_lab, tf.float32), logits=rf_logit, pos_weight=POS_WEIGHT_TF))

        # Linear fairness warm-up: scale = min(1, epoch / FAIRNESS_WARMUP_EPOCHS).
        # Scaling the term itself means it uses the weight from `loss_weights`;
        # the old correction hard-coded LOSS_WEIGHTS_BASE's fairness weight.
        fairness_scale = tf.minimum(1.0, CURRENT_EPOCH_TF / tf.cast(FAIRNESS_WARMUP_EPOCHS, tf.float32))

        total = calculate_total_loss({
            'reconstruction_loss': recon_loss,
            'latent_loss': lat_loss,
            'age_adversarial_loss': age_ce,
            'gender_adversarial_loss': gen_ce,
            'skin_tone_adversarial_loss': skin_ce,
            'fairness_cvar_loss': fairness_scale * fairness_loss,
        }, loss_weights) + logit_pen
        total += tf.cast(REALFAKE_LOSS_WEIGHT, tf.float32) * real_fake_loss

    # NOTE(review): with MirroredStrategy, apply_gradients sums the replicas'
    # gradients, so the effective step uses the sum of per-replica means.
    vars_all = (encoder.trainable_variables + decoder.trainable_variables
                + classifier.trainable_variables + realfake_head.trainable_variables)
    grads = tape.gradient(total, vars_all)
    grads, _ = tf.clip_by_global_norm(grads, 5.0)
    opt.apply_gradients(zip(grads, vars_all))

    return recon_loss, lat_loss, age_ce, gen_ce, skin_ce, fairness_loss, total

def dist_val_step(batch, grl_lambda, loss_weights):
    """Evaluate the training objective on a per-replica batch (no updates).

    Mirrors dist_train_step with every model in inference mode.
    ``grl_lambda`` is accepted for signature parity but unused here: gradient
    reversal is the identity in the forward pass, and no gradients are taken.

    Args:
        batch: ``(imgs, (is_real, age_oh, gender_oh, skin_oh))`` per-replica batch.
        grl_lambda: unused (see above).
        loss_weights: per-term weights for calculate_total_loss.

    Returns:
        ``(recon_loss, lat_loss, age_ce, gen_ce, skin_ce, fairness_loss, total)``
        per-replica scalars; ``fairness_loss`` is reported before warm-up scaling.
    """
    del grl_lambda  # forward pass of grad_reverse is the identity
    imgs, labels = batch
    is_real_lab, age_lab, gender_lab, skin_lab = labels[0], labels[1], labels[2], labels[3]
    imgs_frft = frft_pre.frft2_grid_mag01(imgs)

    z = encoder(imgs_frft, training=False)
    recon_img = decoder(z, training=False)
    recon_frft = frft_pre.frft2_grid_mag01(recon_img)
    z_hat = encoder(recon_frft, training=False)

    age_pred, gen_pred, skin_pred = classifier(z, training=False)  # logits

    per_ex_recon_img  = tf.reduce_mean(tf.square(tf.cast(imgs, tf.float32)      - tf.cast(recon_img, tf.float32)),  axis=[1, 2, 3])
    per_ex_recon_frft = tf.reduce_mean(tf.square(tf.cast(imgs_frft, tf.float32) - tf.cast(recon_frft, tf.float32)), axis=[1, 2, 3])
    recon_loss = RECON_IMG_WEIGHT * tf.reduce_mean(per_ex_recon_img) + RECON_FRFT_WEIGHT * tf.reduce_mean(per_ex_recon_frft)

    lat_loss = latent_loss(z, z_hat, beta=1.0)
    age_ce = ce(age_lab, tf.cast(age_pred, tf.float32))
    gen_ce = ce(gender_lab, tf.cast(gen_pred, tf.float32))
    skin_ce = ce(skin_lab, tf.cast(skin_pred, tf.float32))

    fairness_loss = calculate_fairness_loss(
        per_ex_recon_frft, age_lab, gender_lab, skin_lab, cvar_alpha=0.2)['fairness_cvar_loss']

    rf_logit = tf.squeeze(tf.cast(realfake_head(z, training=False), tf.float32), axis=-1)
    real_fake_loss = tf.reduce_mean(tf.nn.weighted_cross_entropy_with_logits(
        labels=tf.cast(is_real_lab, tf.float32), logits=rf_logit, pos_weight=POS_WEIGHT_TF))

    # Same fairness warm-up as training, applied to the term so the weight
    # comes from `loss_weights` rather than a hard-coded LOSS_WEIGHTS_BASE.
    fairness_scale = tf.minimum(1.0, CURRENT_EPOCH_TF / tf.cast(FAIRNESS_WARMUP_EPOCHS, tf.float32))

    total = calculate_total_loss({
        'reconstruction_loss': recon_loss,
        'latent_loss': lat_loss,
        'age_adversarial_loss': age_ce,
        'gender_adversarial_loss': gen_ce,
        'skin_tone_adversarial_loss': skin_ce,
        'fairness_cvar_loss': fairness_scale * fairness_loss,
    }, loss_weights) + logit_norm_penalty(age_pred, gen_pred, skin_pred, alpha=1e-4)
    total += tf.cast(REALFAKE_LOSS_WEIGHT, tf.float32) * real_fake_loss

    return recon_loss, lat_loss, age_ce, gen_ce, skin_ce, fairness_loss, total

# ================== MAIN ===========================
def _run_epoch(step_fn, dist_ds, grl_lambda, loss_weights, n_outputs=7):
    """Run ``step_fn`` on every batch of a distributed dataset and average its outputs.

    Each step's ``n_outputs`` per-replica scalars are averaged across
    replicas, then across batches.

    Returns:
        (float64 array of length ``n_outputs``, number of batches). The array
        is all zeros when ``dist_ds`` yields no batches.
    """
    sums = np.zeros(n_outputs, dtype=np.float64)
    n_batches = 0
    for batch in dist_ds:
        per_replica_outs = strategy.run(step_fn, args=(batch, grl_lambda, loss_weights))
        sums += np.array([
            float(tf.reduce_mean(tf.stack(strategy.experimental_local_results(per_replica_outs[i]))).numpy())
            for i in range(n_outputs)
        ])
        n_batches += 1
    return sums / max(1, n_batches), n_batches


def main():
    """Train the fairness-aware FRFT autoencoder, then fit a binary real/fake head.

    Steps:
      1. Load the labels CSV, shuffle it (seed 42) and split train/val.
      2. Build the models inside ``strategy.scope()``.
      3. Run the distributed training loop with fairness-aware early stopping,
         checkpointing the best weights to CKPT_DIR.
      4. Reload the best weights and extract latent features.
      5. Train ``bin_head`` on the features and save final weights to
         FINAL_SAVE_DIR.

    Sets the module globals read by dist_train_step / dist_val_step.

    Returns:
        (encoder, decoder, classifier, bin_head, frft_pre)

    Raises:
        RuntimeError: if no valid images are found, or if a split has fewer
            samples than one global batch.
    """
    global encoder, decoder, classifier, opt, frft_pre, ce, realfake_head, POS_WEIGHT_TF

    tf.random.set_seed(42); np.random.seed(42)
    H, W = IMG_SIZE

    # Four legacy angles (degrees -> radians); FRFTPrecomp collapses them to
    # their mean, i.e. a single 45-degree transform.
    ALPHA_GRID_RAD = [float(np.deg2rad(a)) for a in [0, 30, 60, 90]]
    frft_pre = FRFTPrecomp(H, W, ALPHA_GRID_RAD, ALPHA_GRID_RAD)

    df = load_labels(CSV_PATH)
    if df.empty:
        raise RuntimeError("No valid image files found after resolving CSV paths.")
    print(f"Found {len(df)} valid files after resolving paths.")

    # Shuffle before the head/tail split: label CSVs are often grouped by class
    # or source, and an unshuffled iloc split can give an unrepresentative
    # (even single-class) validation set.
    # NOTE(review): frames of one video can still land in both splits; a
    # group-wise split needs a video-id column, which load_labels does not return.
    df = df.sample(frac=1.0, random_state=42).reset_index(drop=True)

    val_n = int(len(df) * VAL_SPLIT)
    val_df = df.iloc[:val_n].reset_index(drop=True)
    train_df = df.iloc[val_n:].reset_index(drop=True)
    print(f"Train / Val sizes: {len(train_df)} / {len(val_df)}")

    # make_dataset drops the final partial batch. A split smaller than one global
    # batch would yield zero batches: no training at all, and early stopping
    # would compare all-zero validation losses.
    for split_name, split_df in (("train", train_df), ("validation", val_df)):
        if len(split_df) < GLOBAL_BATCH_SIZE:
            raise RuntimeError(
                f"The {split_name} split has {len(split_df)} samples, fewer than one "
                f"global batch ({GLOBAL_BATCH_SIZE}); reduce BATCH_SIZE or adjust VAL_SPLIT.")

    tr_ds = make_dataset(train_df, training=True, cache_dir=None)
    va_ds = make_dataset(val_df, training=False, cache_dir=None)

    tr_dist = strategy.experimental_distribute_dataset(tr_ds)
    va_dist = strategy.experimental_distribute_dataset(va_ds)

    with strategy.scope():
        encoder, decoder = build_shared_autoencoder((H, W, 1), LATENT_DIM)
        classifier = build_classifier(LATENT_DIM)
        realfake_head = build_realfake_head(LATENT_DIM)
        # NOTE(review): with USE_MIXED_PRECISION this custom loop does not use a
        # LossScaleOptimizer, so float16 gradients may underflow. Wrapping `opt`
        # also requires scaling the loss in dist_train_step, and that API
        # differs between Keras 2 and Keras 3.
        opt = AdamW(learning_rate=LEARNING_RATE, weight_decay=1e-4)
        ce  = CategoricalCrossentropy(from_logits=True, label_smoothing=0.05)

    # Checkpointed models, in save/load order; files are "<name>.best.weights.h5".
    ckpt_models = {
        "encoder": encoder,
        "decoder": decoder,
        "classifier": classifier,
        "realfake_head": realfake_head,
    }

    # --- class weight for weighted BCE ---
    pos_weight_scalar = compute_real_fake_pos_weight(train_df)
    POS_WEIGHT_TF = tf.constant(pos_weight_scalar, dtype=tf.float32)
    print(f"Real/Fake pos_weight = {pos_weight_scalar:.3f}")

    # --- fairness + reconstruction early stopping ---
    # An epoch counts as an improvement when EITHER the validation
    # reconstruction loss or the validation fairness CVaR falls below 99.5% of
    # its best value so far. Both bests are then updated and all models are
    # checkpointed.
    RECON_TOLERANCE = 0.995
    FAIRNESS_TOLERANCE = 0.995
    os.makedirs(CKPT_DIR, exist_ok=True)
    best_val_recon = np.inf
    best_val_fair  = np.inf
    es_wait = 0

    for epoch in range(1, EPOCHS + 1):
        t0 = time.time()
        CURRENT_EPOCH_TF.assign(float(epoch))
        grl_lambda = tf.constant(min(0.5, epoch / 10.0), tf.float32)

        tr_avg, _ = _run_epoch(dist_train_step, tr_dist, grl_lambda, LOSS_WEIGHTS_BASE)
        va_avg, _ = _run_epoch(dist_val_step, va_dist, grl_lambda, LOSS_WEIGHTS_BASE)
        t1 = time.time()

        print(f"[Epoch {epoch:02d}/{EPOCHS}] time/epoch={t1-t0:.1f}s")
        print(f"  Train: recon={tr_avg[0]:.4f} lat={tr_avg[1]:.4f} age={tr_avg[2]:.4f} gen={tr_avg[3]:.4f} skin={tr_avg[4]:.4f} fairness={tr_avg[5]:.6f} total={tr_avg[6]:.4f}")
        print(f"  Val  : recon={va_avg[0]:.4f} lat={va_avg[1]:.4f} age={va_avg[2]:.4f} gen={va_avg[3]:.4f} skin={va_avg[4]:.4f} fairness={va_avg[5]:.6f} total={va_avg[6]:.4f}")

        val_recon = va_avg[0]  # reconstruction loss (image + FRFT)
        val_fair  = va_avg[5]  # fairness CVaR loss

        improved_recon = val_recon < best_val_recon * RECON_TOLERANCE
        improved_fair  = val_fair  < best_val_fair * FAIRNESS_TOLERANCE

        if improved_recon or improved_fair:
            best_val_recon = min(best_val_recon, val_recon)
            best_val_fair  = min(best_val_fair, val_fair)
            es_wait = 0
            for name, model in ckpt_models.items():
                model.save_weights(os.path.join(CKPT_DIR, f"{name}.best.weights.h5"))
            print(f"  ✅ Improvement detected — Recon: {val_recon:.6f}, Fair: {val_fair:.6f} — checkpointed.")
        else:
            es_wait += 1
            print(f"  ⚠️  No improvement ({es_wait}/{ES_PATIENCE}) — Recon={val_recon:.6f}, Fair={val_fair:.6f}")
            if es_wait >= ES_PATIENCE:
                print("  ⛔ Early stopping triggered — recon + fairness plateau reached.")
                break

    # Best-effort restore of the best checkpoint; on failure keep current weights.
    try:
        for name, model in ckpt_models.items():
            model.load_weights(os.path.join(CKPT_DIR, f"{name}.best.weights.h5"))
        print("Loaded best checkpointed weights.")
    except Exception as e:
        print("Warning: failed to load best weights (continuing with current):", e)

    # ---- Feature extraction & binary head ----
    print("Extracting features for binary head training...")
    X_tr, y_tr = extract_features(tr_ds, encoder, LATENT_DIM, frft_pre)
    X_va, y_va = extract_features(va_ds, encoder, LATENT_DIM, frft_pre)
    print(f"Features shapes: X_tr {X_tr.shape}, y_tr {y_tr.shape} | X_va {X_va.shape}, y_va {y_va.shape}")

    bin_head = build_binary_head(LATENT_DIM)
    bin_ckpt = os.path.join(CKPT_DIR, "binary_head.best.weights.h5")

    if X_tr.shape[0] > 0:
        bin_head.fit(
            X_tr, y_tr.astype(np.float32),
            validation_data=(X_va, y_va.astype(np.float32)),
            epochs=50, batch_size=512,
            callbacks=[
                tf.keras.callbacks.EarlyStopping(monitor='val_auc', mode='max', patience=7, restore_best_weights=True),
                tf.keras.callbacks.ReduceLROnPlateau(monitor='val_auc', mode='max', factor=0.5, patience=3, verbose=1),
                tf.keras.callbacks.ModelCheckpoint(
                    filepath=bin_ckpt,
                    monitor='val_auc', mode='max', save_best_only=True, save_weights_only=True
                )
            ],
            verbose=1
        )
        try:
            bin_head.load_weights(bin_ckpt)
            print("Loaded best binary head weights.")
        except Exception as e:
            print("Warning loading best binary head:", e)
    else:
        print("No training examples found for binary head; skipping bin_head.fit")

    # Save final weights (best-effort).
    os.makedirs(FINAL_SAVE_DIR, exist_ok=True)
    try:
        for name, model in ckpt_models.items():
            model.save_weights(os.path.join(FINAL_SAVE_DIR, f"{name}.weights.h5"))
        bin_head.save_weights(os.path.join(FINAL_SAVE_DIR, "binary_head.weights.h5"))
        print("Saved final model weights to", FINAL_SAVE_DIR)
    except Exception as e:
        print("Warning: failed to save some weights:", e)

    return encoder, decoder, classifier, bin_head, frft_pre

def plot_fairness(y_true, y_pred_prob, sensitive_attr, groups, WEIGHTS_DIR,
                  alpha=0.9, threshold=0.5, attr_name=None, save_plots=None):
    """
    Plot per-group fairness diagnostics for a binary classifier.

    Three bar charts are drawn, with one bar (or bar pair) per group in
    ``groups`` that has at least one sample:
      * CVaR: the mean per-sample binary cross-entropy over the losses at or
        above the group's ``alpha``-quantile (``alpha=0.9`` -> worst 10% tail).
      * Equalized odds: TPR and FPR at ``threshold``, as side-by-side bars.
      * Demographic parity: the mean predicted probability.

    Args:
        y_true: 0/1 ground-truth labels, one per sample.
        y_pred_prob: predicted probability of class 1, aligned with ``y_true``.
        sensitive_attr: group value per sample, aligned with ``y_true``.
        groups: group values to report; groups with no samples are skipped.
        WEIGHTS_DIR: directory the PNG files are written to when saving.
        alpha: quantile at which the CVaR tail starts.
        threshold: decision threshold used for the TPR/FPR chart.
        attr_name: optional attribute name. It is appended to plot titles and
            used as a filename prefix, so plots for different attributes saved
            to the same directory do not overwrite each other.
        save_plots: save PNGs when True. ``None`` falls back to the
            module-level ``SAVE_PLOTS`` flag.
    """
    if save_plots is None:
        save_plots = SAVE_PLOTS

    y_true = np.asarray(y_true)
    y_pred_prob = np.asarray(y_pred_prob)
    sensitive_attr = np.asarray(sensitive_attr)

    title_suffix = f" [{attr_name}]" if attr_name else ""
    file_prefix = f"fairness_{attr_name}_" if attr_name else "fairness_"

    # Per-sample binary cross-entropy; the epsilon keeps log() finite at p=0 / p=1.
    losses = -(y_true * np.log(y_pred_prob + 1e-12)
               + (1 - y_true) * np.log(1 - y_pred_prob + 1e-12))
    y_hat = (y_pred_prob >= threshold).astype(int)

    # Single pass over the groups collects all three per-group statistics.
    cvar_dict, eo_dict, dp_dict = {}, {}, {}
    for g in groups:
        g_mask = (sensitive_attr == g)
        if not np.any(g_mask):
            continue

        g_losses = losses[g_mask]
        tail_start = np.quantile(g_losses, alpha)
        cvar_dict[g] = g_losses[g_losses >= tail_start].mean()

        tn, fp, fn, tp = confusion_matrix(
            y_true[g_mask], y_hat[g_mask], labels=[0, 1]
        ).ravel()
        eo_dict[g] = (tp / (tp + fn + 1e-12), fp / (fp + tn + 1e-12))

        dp_dict[g] = y_pred_prob[g_mask].mean()

    if not cvar_dict:
        print(f"plot_fairness: no samples for any requested group{title_suffix}; nothing to plot.")
        return

    def _finish(filename):
        # Shared tail for every figure: grid, layout, optional save, then show.
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        if save_plots:
            os.makedirs(WEIGHTS_DIR, exist_ok=True)
            plt.savefig(os.path.join(WEIGHTS_DIR, file_prefix + filename), dpi=150)
        plt.show()

    # CVaR (tail of the loss distribution) per group
    plt.figure(figsize=(6, 5))
    sns.barplot(x=list(cvar_dict.keys()), y=list(cvar_dict.values()))
    plt.title(f"CVaR @ {alpha*100:.0f}% Losses per Group{title_suffix}")
    plt.ylabel("CVaR Loss"); plt.xlabel("Group")
    _finish("cvar.png")

    # Equalized odds: side-by-side TPR/FPR bars (overlaid bars would hide one series)
    eo_keys = list(eo_dict.keys())
    x_pos = np.arange(len(eo_keys))
    width = 0.4
    plt.figure(figsize=(6, 5))
    plt.bar(x_pos - width / 2, [v[0] for v in eo_dict.values()], width, alpha=0.7, label="TPR")
    plt.bar(x_pos + width / 2, [v[1] for v in eo_dict.values()], width, alpha=0.7, label="FPR")
    plt.xticks(x_pos, [str(k) for k in eo_keys])
    plt.ylabel("Rate"); plt.xlabel("Group")
    plt.title(f"Equalized Odds: TPR / FPR per Group{title_suffix}")
    plt.legend()
    _finish("equalized_odds.png")

    # Demographic parity: mean predicted probability per group
    plt.figure(figsize=(6, 5))
    sns.barplot(x=list(dp_dict.keys()), y=list(dp_dict.values()))
    plt.title(f"Demographic Parity (mean predicted probability) per group{title_suffix}")
    plt.ylabel("Mean predicted probability")
    plt.xlabel("Group")
    _finish("demographic_parity.png")


# ================== EVAL ON TEST ==================
def _eval_roc_eer(y_true, y_prob):
    """
    Compute the ROC curve and its equal-error-rate (EER) operating point.

    The EER index is the ROC point where |FPR - FNR| is smallest. The reported
    EER is max(FPR, FNR) at that index, a conservative value when the curve
    never crosses exactly.

    Returns:
        (fpr, tpr, eer, thr_eer): the ROC arrays, the EER, and the score
        threshold at the EER index.
    """
    fpr, tpr, thr = roc_curve(y_true, y_prob)
    fnr = 1.0 - tpr
    idx = np.nanargmin(np.abs(fpr - fnr))
    return fpr, tpr, max(fpr[idx], fnr[idx]), thr[idx]


def _eval_group_fairness(y_true, y_pred, sensitive):
    """
    Per-group rates and cross-group gaps for one sensitive attribute.

    Args:
        y_true: 0/1 labels.
        y_pred: 0/1 hard predictions, aligned with ``y_true``.
        sensitive: group value per sample, aligned with ``y_true``.

    Returns:
        A tuple ``(groups, per_group, gaps)``:
          * ``groups``: the sorted unique values of ``sensitive``.
          * ``per_group``: maps each group to a dict with keys
            ``fpr``, ``tpr``, ``acc`` and ``pred_rate``.
          * ``gaps``: maps ``F_FPR``, ``F_MEO``, ``F_DP`` and ``F_OAE`` to
            max-minus-min gaps across groups. It is None when fewer than
            two groups have samples.
    """
    groups = np.unique(sensitive)
    per_group = {}
    for g in groups:
        g_mask = (sensitive == g)
        if g_mask.sum() == 0:
            continue
        y_g = y_true[g_mask]
        y_pred_g = y_pred[g_mask]
        tn_g, fp_g, fn_g, tp_g = confusion_matrix(y_g, y_pred_g, labels=[0, 1]).ravel()
        per_group[g] = {
            "fpr": fp_g / (fp_g + tn_g + 1e-12),
            "tpr": tp_g / (tp_g + fn_g + 1e-12),
            "acc": (tp_g + tn_g) / (tp_g + tn_g + fp_g + fn_g + 1e-12),
            "pred_rate": y_pred_g.mean(),  # positive (label 1) prediction rate
        }

    if len(per_group) < 2:
        return groups, per_group, None

    def _gap(key):
        vals = [m[key] for m in per_group.values()]
        return max(vals) - min(vals)

    fpr_gap = _gap("fpr")
    tpr_gap = _gap("tpr")
    gaps = {
        "F_FPR": fpr_gap,                        # max FPR gap
        "F_MEO": (tpr_gap + fpr_gap) / 2.0,      # mean equalized-odds gap
        "F_DP": _gap("pred_rate"),               # demographic-parity gap
        "F_OAE": _gap("acc"),                    # overall-accuracy-equality gap
    }
    return groups, per_group, gaps


def eval_on_test(encoder_model, bin_head_model, frft_precomp, use_frft=True):
    """
    Evaluate the trained encoder + binary head on the ``test`` split.

    Steps:
      1. Load the label CSV with the global ``IMAGE_ROOT`` temporarily pointed
         at the test folder, restoring it afterwards.
      2. Build an unshuffled tf.data pipeline (grayscale JPEG decode -> resize
         -> optional FRFT), encode each batch, and score the latents with
         ``bin_head_model``.
      3. Print AUC, AP, accuracy/FPR at 0.5 and the EER operating point. Plot
         ROC, PR and score histograms; PNGs are saved only when ``SAVE_PLOTS``.
      4. Plot and print group-fairness gaps for each sensitive attribute.
      5. Write ``test_metrics.csv`` and ``test_predictions.csv`` to
         ``FINAL_SAVE_DIR``.

    Label convention (from the printed/CSV labels): ``is_real`` 1 = real, 0 = fake.

    Args:
        encoder_model: model called on image batches; returns latent vectors.
        bin_head_model: model whose ``predict`` maps latents to P(real).
        frft_precomp: object exposing ``frft2_grid_mag01``. It is applied only
            when ``use_frft`` is true and this object is not None.
        use_frft: whether to apply the FRFT transform before encoding.
    """
    global IMAGE_ROOT
    WEIGHTS_DIR = FINAL_SAVE_DIR
    TEST_IMAGE_ROOT = f"{DATASET_ROOT}/{PREPROC_DIR}/test"
    # One list drives both the fairness plots and the printed fairness gaps.
    fairness_attrs = ("gender", "age_group", "skin_tone")

    # load_labels presumably resolves image paths against the global IMAGE_ROOT
    # (not visible here), so swap it for the call and always restore it.
    prev_root = IMAGE_ROOT
    IMAGE_ROOT = TEST_IMAGE_ROOT
    try:
        df_test = load_labels(CSV_PATH)
    finally:
        IMAGE_ROOT = prev_root

    print(f"Test rows: {len(df_test)} | real(1)/fake(0): {df_test['is_real'].value_counts().to_dict()}")

    def _decode_resize(path):
        img = tf.io.read_file(path)
        img = tf.image.decode_jpeg(img, channels=1)
        img = tf.image.convert_image_dtype(img, tf.float32)
        img = tf.image.resize(img, IMG_SIZE)
        return img

    def _apply_frft_py(img, frft_obj):
        return frft_obj.frft2_grid_mag01(tf.expand_dims(img, 0))[0]

    def make_plain_ds_from_df(df_sub):
        # No shuffle: row order must stay aligned with df_test for the CSV export.
        paths = df_sub["abspath"].values.astype(str)
        age   = df_sub["age_group"].values.astype(np.int32)
        gen   = df_sub["gender"].values.astype(np.int32)
        skin  = df_sub["skin_tone"].values.astype(np.int32)
        real  = df_sub["is_real"].values.astype(np.int32)
        ds = tf.data.Dataset.from_tensor_slices((paths, age, gen, skin, real))
        ds = ds.map(lambda p, a, g, s, r: (_decode_resize(p), a, g, s, r), num_parallel_calls=AUTOTUNE)
        if use_frft and frft_precomp is not None:
            ds = ds.map(lambda x, a, g, s, r: (_apply_frft_py(x, frft_precomp), a, g, s, r), num_parallel_calls=AUTOTUNE)
        ds = ds.batch(BATCH_SIZE).prefetch(AUTOTUNE)
        return ds

    def extract_features_eval(ds):
        feats, labels = [], []
        for batch in ds:
            x, *_, ri = batch
            z = encoder_model(tf.cast(x, tf.float32), training=False)
            feats.append(tf.cast(z, tf.float32).numpy())
            labels.append(tf.cast(ri, tf.int32).numpy())
        if not feats:
            return np.zeros((0, LATENT_DIM), dtype=np.float32), np.zeros((0,), dtype=np.int32)
        X = np.concatenate(feats, axis=0).astype(np.float32)
        y = np.concatenate(labels, axis=0).astype(np.int32)
        return X, y

    def _finish_plot(filename):
        # Shared figure tail: layout, optional save (SAVE_PLOTS), then show.
        plt.tight_layout()
        if SAVE_PLOTS:
            os.makedirs(WEIGHTS_DIR, exist_ok=True)
            plt.savefig(f"{WEIGHTS_DIR}/{filename}", dpi=150)
        plt.show()

    test_ds = make_plain_ds_from_df(df_test)
    X_te, y_te = extract_features_eval(test_ds)
    if X_te.shape[0] == 0:
        print("No test samples were found or extracted.")
        return

    y_prob = bin_head_model.predict(X_te, batch_size=2048, verbose=1).ravel()

    # Fixed-threshold metrics are defined for any label mix.
    y_pred05 = (y_prob >= 0.5).astype(int)
    acc05 = accuracy_score(y_te, y_pred05)
    tn, fp, fn, tp = confusion_matrix(y_te, y_pred05, labels=[0, 1]).ravel()
    fpr05 = fp / (fp + tn + 1e-12)

    # Ranking metrics need both classes: roc_auc_score raises ValueError on a
    # single-class y_true, so report NaN instead of aborting the evaluation.
    has_both_classes = np.unique(y_te).size == 2
    if has_both_classes:
        auc = roc_auc_score(y_te, y_prob)
        ap  = average_precision_score(y_te, y_prob)
        fpr, tpr, eer, thr_eer = _eval_roc_eer(y_te, y_prob)
        y_pred_eer = (y_prob >= thr_eer).astype(int)
        tn2, fp2, _, _ = confusion_matrix(y_te, y_pred_eer, labels=[0, 1]).ravel()
        fpr_eer = fp2 / (fp2 + tn2 + 1e-12)
    else:
        print("Warning: test labels contain a single class; AUC/AP/EER are undefined "
              "(reported as NaN) and ROC/PR plots are skipped.")
        auc = ap = eer = thr_eer = fpr_eer = float("nan")
        fpr = tpr = None

    print(f"\nAUC: {auc:.4f}")
    print(f"AP : {ap:.4f}")
    print(f"ACC @0.5: {acc05:.4f}")
    print(f"FPR @0.5: {fpr05:.4f}")
    print(f"EER: {eer:.4f} at threshold {thr_eer:.6f} (FPR at EER thr: {fpr_eer:.4f})")

    # ----- Plots -----
    if has_both_classes:
        # ROC
        plt.figure(figsize=(6, 5))
        plt.plot(fpr, tpr, label=f'ROC (AUC={auc:.3f})')
        plt.plot([0, 1], [0, 1], ls='--', lw=1)
        # At the EER point FPR == FNR == eer, hence TPR == 1 - eer.
        plt.scatter([eer], [1 - eer], s=30, label=f'EER={eer:.3f}')
        plt.xlabel('False Positive Rate'); plt.ylabel('True Positive Rate')
        plt.title('ROC Curve (Test)'); plt.legend(); plt.grid(True, alpha=0.3)
        _finish_plot("test_roc_curve.png")

        # PR
        prec, rec, _ = precision_recall_curve(y_te, y_prob)
        plt.figure(figsize=(6, 5))
        plt.plot(rec, prec, label=f'AP={ap:.3f}')
        plt.xlabel('Recall'); plt.ylabel('Precision')
        plt.title('Precision-Recall (Test)'); plt.legend(); plt.grid(True, alpha=0.3)
        _finish_plot("test_pr_curve.png")

    # Score histogram
    plt.figure(figsize=(6, 5))
    plt.hist(y_prob[y_te == 0], bins=50, alpha=0.6, label='fake (0)')
    plt.hist(y_prob[y_te == 1], bins=50, alpha=0.6, label='real (1)')
    plt.axvline(0.5, ls='--', lw=1)
    # thr_eer is NaN for single-class sets and can be +inf from roc_curve.
    if np.isfinite(thr_eer):
        plt.axvline(thr_eer, ls=':', lw=1)
    plt.xlabel('Predicted probability (real)'); plt.ylabel('Count')
    plt.title('Score distribution (Test)'); plt.legend()
    _finish_plot("test_score_hist.png")

    # Fairness plots: one subdirectory per attribute so saved files don't collide.
    for attr in fairness_attrs:
        attr_dir = os.path.join(WEIGHTS_DIR, f"fairness_{attr}")
        os.makedirs(attr_dir, exist_ok=True)
        print(f"\nFairness plots for attribute: {attr}")
        plot_fairness(
            y_true=y_te,
            y_pred_prob=y_prob,
            sensitive_attr=df_test[attr].values,
            groups=np.unique(df_test[attr]),
            WEIGHTS_DIR=attr_dir
        )

    # ========== FAIRNESS METRICS: FPR, MEO, DP, OAE ==========
    print("\n" + "="*60)
    print("FAIRNESS METRICS (Lower = Better)")
    print("="*60)

    fairness_results = {}
    for attr_name in fairness_attrs:
        print(f"\n--- Attribute: {attr_name.upper()} ---")
        groups, per_group, gaps = _eval_group_fairness(y_te, y_pred05, df_test[attr_name].values)
        if gaps is None:
            print(f"  Skipping (only {len(per_group)} group(s) found)")
            continue

        print(f"  F_FPR (FPR gap):        {gaps['F_FPR']:.4f}")
        print(f"  F_MEO (Eq. Odds gap):   {gaps['F_MEO']:.4f}")
        print(f"  F_DP  (Pred. rate gap): {gaps['F_DP']:.4f}")
        print(f"  F_OAE (Accuracy gap):   {gaps['F_OAE']:.4f}")

        # Store for CSV (column order: F_FPR, F_MEO, F_DP, F_OAE per attribute)
        for metric_name, value in gaps.items():
            fairness_results[f"{attr_name}_{metric_name}"] = value

        print("\n  Per-group breakdown:")
        for g in groups:
            if g not in per_group:
                continue
            m = per_group[g]
            print(f"    Group {g}: FPR={m['fpr']:.4f}, TPR={m['tpr']:.4f}, Acc={m['acc']:.4f}, PredRate={m['pred_rate']:.4f}")

    print("\n" + "="*60 + "\n")

    # Save metrics & predictions
    try:
        os.makedirs(WEIGHTS_DIR, exist_ok=True)
        all_metrics = {
            "auc": [auc], "ap": [ap], "acc@0.5": [acc05],
            "fpr@0.5": [fpr05], "eer": [eer], "thr_eer": [thr_eer], "fpr@thr_eer": [fpr_eer]
        }
        for k, v in fairness_results.items():
            all_metrics[k] = [v]

        pd.DataFrame(all_metrics).to_csv(f"{WEIGHTS_DIR}/test_metrics.csv", index=False)

        pred_df = pd.DataFrame({"abspath": df_test["abspath"].values, "label": df_test["is_real"].values, "p_real": y_prob})
        pred_df.to_csv(f"{WEIGHTS_DIR}/test_predictions.csv", index=False)
        print("\nSaved test evaluation outputs in", WEIGHTS_DIR)
    except Exception as e:
        print("Warning while saving metrics/predictions:", e)

# ================== RUN ====================
if __name__ == "__main__":
    # main() returns (encoder, decoder, classifier, bin_head, frft_pre).
    # Test evaluation needs only the encoder, the binary head and the FRFT object.
    trained_parts = main()
    encoder, decoder, classifier, bin_head, frft_pre = trained_parts
    eval_on_test(encoder_model=encoder, bin_head_model=bin_head, frft_precomp=frft_pre)