# In [None]:
"""
sequential_cnn_pruning_full_fixed.py
Full Sequential CNN pruning pipeline (single-file)

- Loads user Sequential CNN (.h5/.keras / SavedModel)
- Sanitizes layer names in .h5 if they contain '/'
- Loads folder dataset (image_dataset_from_directory)
- Computes activation/gradient/variance importance
- Creates masks (keep_ratio)
- Builds masked model (gating layer)
- Fine-tunes masked model
- Structurally prunes model (conv filter removal, prune Dense outputs only)
- Computes FLOPS & timings, evaluates models
- Saves models and masks

Fixes:
- Ensures final layer matches dataset classes (automatically rebuilds final layer if needed)
- Uses appropriate loss function (binary vs sparse categorical) during stat collection and training
- Adds GFLOPS, accuracy reduction, inference timing
"""

import os
import json
import math
import time
import h5py
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras import layers, models

# ---------------------------
# USER CONFIG (edit paths)
# ---------------------------
# Input model: an .h5 / .keras file or a SavedModel directory.
MODEL_PATH = r"D:\college\sem-8\models\garbage_cnn_model.h5"
# Image dataset root: one sub-folder per class.
DATASET_PATH = r"D:\college\sem-8\dataset\Garbage classification\Garbage classification"
# Output directory for masks, plots and saved models.
SAVE_DIR = r"pruning_output"
# Model file stem (no directory, no extension); prefixes the saved model files.
base_name = os.path.splitext(os.path.basename(MODEL_PATH))[0]

# Pruning / fine-tuning hyper-parameters.
KEEP_RATIO = 0.83          # fraction of channels/units to keep per pruned layer
ALPHA = 0.4                # weight of mean |activation| in the importance score
BETA = 0.3                 # weight of mean |gradient| in the importance score
GAMMA = 0.3                # weight of activation variance in the importance score
BATCH_SIZE = 64
CALIB_BATCHES = 30         # batches used to collect activation/gradient statistics
FT_EPOCHS = 3              # fine-tuning epochs for the masked model
FT_BATCHES_TO_USE = 150    # training batches taken per masked fine-tuning epoch
PLOT_RESULTS = True
VERBOSE = True

# Output directory and reproducibility seeds.
os.makedirs(SAVE_DIR, exist_ok=True)
np.random.seed(42)
tf.random.set_seed(42)

def log(*args):
    """Print ``args`` after an ``[INFO]`` tag, but only when ``VERBOSE`` is truthy."""
    if not VERBOSE:
        return
    print("[INFO]", *args)

# ---------------------------
# Safe load model (handles '/' in h5 layer names)
# ---------------------------
def safe_load_model(model_path):
    """
    Load a Keras model, working around '/' characters in legacy HDF5 layer names.

    Attempts, in order (the first one that succeeds is returned):
      1. ``tf.keras.models.load_model(model_path, compile=False)``.
      2. HDF5 files only: read the architecture JSON (``model_config``, which
         Keras' legacy H5 writer stores as a file *attribute*), replace '/'
         with '_' in layer names, rebuild the model with ``model_from_json``
         and load the weights from the same file. Skipped when no layer name
         needed fixing.
      3. ``load_model(..., compile=False, safe_mode=False)``. ``safe_mode`` is
         only understood by newer Keras versions; elsewhere this attempt fails.

    Args:
        model_path: path to a .h5 / .keras file or a SavedModel directory.

    Returns:
        The loaded Keras model (not compiled).

    Raises:
        RuntimeError: if every attempt fails (the last error is chained as the cause).
    """
    log("Loading model:", model_path)

    # 1) Plain load.
    try:
        m = tf.keras.models.load_model(model_path, compile=False)
        log("Loaded model directly.")
        return m
    except Exception as e:
        log("Direct load failed:", e)

    # 2) HDF5: sanitize layer names in the stored architecture JSON.
    try:
        raw = None
        if os.path.isfile(model_path) and h5py.is_hdf5(model_path):
            with h5py.File(model_path, "r") as f:
                # Keras stores the architecture as a file attribute; a dataset
                # with the same name is accepted as well.
                raw = f.attrs.get("model_config")
                if raw is None and "model_config" in f:
                    raw = f["model_config"][()]
        # The file is closed again before Keras re-opens it for the weights.
        if raw is not None:
            if isinstance(raw, bytes):
                raw = raw.decode("utf-8")
            cfg_json = json.loads(raw)
            inner = cfg_json.get("config", {})
            # Very old Sequential configs are a bare list of layer configs.
            layer_cfgs = inner.get("layers", []) if isinstance(inner, dict) else inner
            changed = False
            for layer in layer_cfgs:
                cfg = layer.get("config", {})
                name = cfg.get("name")
                if isinstance(name, str) and "/" in name:
                    new_name = name.replace("/", "_")
                    cfg["name"] = new_name
                    changed = True
                    log(f"[FIX] layer name: {name} -> {new_name}")
            if changed:
                model = tf.keras.models.model_from_json(json.dumps(cfg_json))
                model.load_weights(model_path)
                log("Loaded model from sanitized JSON + weights.")
                return model
    except Exception as e2:
        log("H5 sanitization attempt failed:", e2)

    # 3) Last resort: disable Keras' safe mode (newer Keras only).
    try:
        m = tf.keras.models.load_model(model_path, compile=False, safe_mode=False)
        log("Loaded model with safe_mode=False.")
        return m
    except Exception as e3:
        log("All load attempts failed:", e3)
        raise RuntimeError("Failed to load model. Ensure path and format are correct.") from e3

# ---------------------------
# Rebuild final layer to match num_classes (safe, best-effort)
# ---------------------------
def ensure_output_matches_dataset(orig_model, num_classes):
    """
    Return a model whose output layer fits ``num_classes``, plus a matching loss.

    The model is returned unchanged when its last output dimension is already
    compatible:
      * ``num_classes == 2`` with 1 output unit  -> ``BinaryCrossentropy``
      * ``num_classes == 2`` with 2 output units -> ``SparseCategoricalCrossentropy``
      * ``num_classes > 2`` with ``num_classes`` units -> ``SparseCategoricalCrossentropy``

    Otherwise a new ``Sequential`` model is built. Every layer except the last
    is cloned from its config, with weights copied when shapes allow; failures
    are logged. A fresh head is appended: ``Dense(1, sigmoid)`` for 2 classes,
    else ``Dense(num_classes, softmax)``. Only the single last layer is
    replaced, so this is meant for Sequential-style networks. The returned
    model is NOT compiled.

    Args:
        orig_model: Keras model with a single input and a single output.
        num_classes: number of classes in the dataset.

    Returns:
        ``(model, loss_fn)``; the loss expects probabilities (``from_logits=False``).
    """
    out_shape = tuple(orig_model.output_shape) if orig_model.output_shape is not None else None
    if out_shape is not None:
        out_dim = out_shape[-1]
        if num_classes == 2 and out_dim in (1, 2):
            # Dense(1) -> sigmoid/BCE; Dense(2) -> softmax with integer labels 0/1.
            log("Model output seems compatible with binary classification.")
            if out_dim == 1:
                loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=False)
            else:
                loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
            return orig_model, loss_fn
        if num_classes > 2 and out_dim == num_classes:
            log("Model output matches dataset classes.")
            return orig_model, tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)

    # Rebuild: clone every layer except the last, then append a new head.
    # from_config avoids sharing layer objects between two models.
    log("Rebuilding model final layer to match num_classes:", num_classes)
    new_seq = models.Sequential(name=(orig_model.name or "rebuilt_model") + "_rebuilt")
    new_seq.add(layers.InputLayer(input_shape=tuple(orig_model.input_shape[1:])))

    for layer in orig_model.layers[:-1]:
        try:
            cloned = layer.__class__.from_config(layer.get_config())
            new_seq.add(cloned)
        except Exception:
            # Fallback: share the original layer object (its weights come with it).
            try:
                new_seq.add(layer)
            except Exception as add_err:
                log("Warning: couldn't clone or reuse layer", layer.name, "- skipping layer:", add_err)
            continue
        try:
            w = layer.get_weights()
            if w:
                new_seq.layers[-1].set_weights(w)
        except Exception as w_err:
            # The cloned layer keeps its fresh initialisation; make that visible.
            log("Warning: couldn't copy weights for layer", layer.name, ":", w_err)

    if num_classes == 2:
        new_seq.add(layers.Dense(1, activation="sigmoid", name="output_rebuilt"))
        loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=False)
    else:
        new_seq.add(layers.Dense(num_classes, activation="softmax", name="output_rebuilt"))
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)

    return new_seq, loss_fn

# ---------------------------
# Dataset loader (folder-based)
# ---------------------------
def load_image_folder_dataset(path, image_size, batch_size=BATCH_SIZE, rescale=1.0 / 255):
    """
    Load an image-folder dataset and split it 80/20 into train/validation.

    Uses ``tf.keras.utils.image_dataset_from_directory`` with its default
    integer labels (one sub-folder per class). The fixed seed 42 makes both
    subsets come from the same split. Pixel values are multiplied by
    ``rescale`` (default 1/255); pass ``rescale=None`` when the model already
    contains its own rescaling layer.

    Args:
        path: dataset root directory.
        image_size: ``(height, width)`` that images are resized to.
        batch_size: batch size of both datasets.
        rescale: multiplicative pixel scale, or None to keep raw values.

    Returns:
        ``(train_ds, val_ds)``: batched, prefetched ``tf.data.Dataset`` objects.
        When available, the loader's ``class_names`` attribute is copied onto
        both, because ``map``/``prefetch`` return new datasets without it.
    """
    log("Loading dataset folder:", path, "image_size:", image_size)
    common = dict(validation_split=0.2, seed=42, image_size=image_size, batch_size=batch_size)
    train_raw = tf.keras.utils.image_dataset_from_directory(path, subset="training", **common)
    val_raw = tf.keras.utils.image_dataset_from_directory(path, subset="validation", **common)
    class_names = getattr(train_raw, "class_names", None)

    if rescale is not None:
        rescaler = layers.Rescaling(rescale)

        def _rescale(x, y):
            return rescaler(x), y

        train_ds = train_raw.map(_rescale, num_parallel_calls=tf.data.AUTOTUNE)
        val_ds = val_raw.map(_rescale, num_parallel_calls=tf.data.AUTOTUNE)
    else:
        train_ds, val_ds = train_raw, val_raw

    train_ds = train_ds.prefetch(tf.data.AUTOTUNE)
    val_ds = val_ds.prefetch(tf.data.AUTOTUNE)
    if class_names is not None:
        train_ds.class_names = class_names
        val_ds.class_names = class_names
    return train_ds, val_ds

# ---------------------------
# Activation & gradient stats (uses provided loss function)
# ---------------------------
def compute_activation_grad_stats(model, layer_names, dataset, loss_fn, max_batches=CALIB_BATCHES):
    """
    Collect per-channel activation/gradient statistics for selected layers.

    The model is run layer by layer (``model.layers`` in order, so a single
    sequential chain is assumed) on up to ``max_batches`` batches. For every
    layer in ``layer_names`` and every channel (last axis), it computes the
    following per batch and averages over batches:
      - A: mean |activation|
      - G: mean |d loss / d activation|  (zeros when the gradient is None)
      - V: variance of the activation
    Reductions run over every axis except the last (batch and spatial axes).

    Args:
        model: Keras model whose ``layers`` form a single chain.
        layer_names: names of the layers to profile.
        dataset: iterable of ``(x_batch, y_batch)`` pairs.
        loss_fn: Keras loss called as ``loss_fn(y_true, y_pred)``; labels are
            cast to the prediction dtype for ``BinaryCrossentropy``.
        max_batches: maximum number of batches to consume.

    Returns:
        dict layer name -> ``(A, G, V)`` numpy arrays, one value per channel.

    Raises:
        ValueError: if a name is not a layer of ``model`` or ``dataset`` yields no batches.
    """
    log("Computing activation & gradient stats...")
    layer_names = list(layer_names)
    wanted = set(layer_names)
    unknown = wanted - {layer.name for layer in model.layers}
    if unknown:
        raise ValueError(f"Layers not found in model: {sorted(unknown)}")

    act_abs = {n: [] for n in layer_names}
    grad_abs = {n: [] for n in layer_names}
    act_var = {n: [] for n in layer_names}
    is_binary = isinstance(loss_fn, tf.keras.losses.BinaryCrossentropy)

    batch_count = 0
    for x_batch, y_batch in dataset:
        if batch_count >= max_batches:
            break
        batch_count += 1
        layer_acts = {}
        with tf.GradientTape() as tape:
            a = x_batch
            for layer in model.layers:
                if isinstance(layer, layers.InputLayer):
                    continue  # functional models list their input layer
                a = layer(a)
                if layer.name in wanted:
                    tape.watch(a)
                    layer_acts[layer.name] = a
            preds = a
            # BCE wants float targets; sparse CE takes the integer labels as-is.
            y_true = tf.cast(y_batch, preds.dtype) if is_binary else y_batch
            loss = tf.reduce_mean(loss_fn(y_true, preds))

        # One backward pass for all watched activations (non-persistent tape).
        acts = [layer_acts[n] for n in layer_names]
        grads = tape.gradient(loss, acts)
        for name, act, grad in zip(layer_names, acts, grads):
            axes = tuple(range(len(act.shape) - 1))  # all but the channel axis
            A = tf.reduce_mean(tf.abs(act), axis=axes).numpy()
            V = tf.math.reduce_variance(act, axis=axes).numpy()
            if grad is None:
                G = np.zeros_like(A)
            else:
                G = tf.reduce_mean(tf.abs(grad), axis=axes).numpy()
            act_abs[name].append(A)
            act_var[name].append(V)
            grad_abs[name].append(G)

    if batch_count == 0:
        raise ValueError("dataset yielded no batches; cannot compute activation statistics")

    stats = {}
    for name in layer_names:
        A = np.mean(act_abs[name], axis=0)
        V = np.mean(act_var[name], axis=0)
        G = np.mean(grad_abs[name], axis=0)
        stats[name] = (A, G, V)
        log(f"{name}: len={len(A)} meanA={A.mean():.6e} meanG={G.mean():.6e} meanV={V.mean():.6e}")
    return stats

# ---------------------------
# Importance scores & masks
# ---------------------------
def compute_importance_scores(stats, alpha=ALPHA, beta=BETA, gamma=GAMMA):
    """
    Combine per-channel statistics into a single importance score per channel.

    Each of A, G and V is min-max normalised: the minimum is shifted to 0 and
    the result is divided by its maximum only when that maximum is positive,
    so a constant vector becomes all zeros. The score is
    ``alpha * A' + beta * G' + gamma * V'``.

    Args:
        stats: mapping layer name -> ``(A, G, V)`` numpy arrays.
        alpha, beta, gamma: weights of the three normalised terms.

    Returns:
        dict layer name -> score array (same length as the inputs).
    """
    def _minmax(values):
        shifted = values - values.min()
        peak = shifted.max()
        return shifted / peak if peak > 0 else shifted

    return {
        layer_name: alpha * _minmax(act) + beta * _minmax(grad) + gamma * _minmax(var)
        for layer_name, (act, grad, var) in stats.items()
    }

def make_masks_from_scores(score_map, keep_ratio=KEEP_RATIO):
    """
    Turn per-channel importance scores into binary keep-masks.

    For each layer, exactly ``k = max(1, int(len(score) * keep_ratio))``
    channels (clamped to the number of channels) with the highest scores are
    marked 1.0 and all others 0.0. Ties are broken deterministically in favour
    of the lower channel index, so channels tied at the threshold (e.g. many
    dead units scoring 0) can no longer push the layer above ``k`` kept channels.

    Args:
        score_map: dict layer name -> 1-D score array.
        keep_ratio: fraction of channels to keep per layer.

    Returns:
        dict layer name -> float32 0/1 mask with the same length as the scores.
    """
    masks = {}
    for name, score in score_map.items():
        score = np.asarray(score)
        n = len(score)
        mask = np.zeros(n, dtype=np.float32)
        if n:
            k = min(n, max(1, int(n * keep_ratio)))
            # Stable sort on the negated scores: highest first, ties by index.
            top = np.argsort(-score, kind="stable")[:k]
            mask[top] = 1.0
        masks[name] = mask
        log(f"{name}: keep {int(mask.sum())}/{len(mask)}")
    return masks

# ---------------------------
# Mask gate layer & masked model
# ---------------------------
class CNNGate(tf.keras.layers.Layer):
    """
    Fixed (non-trainable) per-channel multiplier used to emulate pruning.

    The input is multiplied by a ``gate`` vector of length ``channels`` that
    broadcasts along the LAST axis. A 0/1 mask therefore zeroes the pruned
    channels of a (batch, H, W, C) conv output or a (batch, C) dense output,
    and passes kept channels through unchanged.

    Only ``channels`` is stored in the layer config; the gate values are a
    layer weight and are saved/restored together with the model weights. The
    class is not registered as serializable, so loading a saved model that
    contains it needs ``custom_objects={"CNNGate": CNNGate}``.
    """

    def __init__(self, channels, mask=None, **kwargs):
        """
        Args:
            channels: number of channels (size of the last input axis) gated.
            mask: optional 1-D array-like of length ``channels`` (0/1 values
                or floats in [0, 1]) used as the initial gate; defaults to ones.
            **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).

        Raises:
            ValueError: if ``mask`` does not have shape ``(channels,)``.
        """
        super().__init__(**kwargs)
        self.channels = int(channels)
        if mask is None:
            self._init_mask = None
            init_val = np.ones((self.channels,), dtype=np.float32)
        else:
            # Copy so later changes to the caller's array don't matter.
            self._init_mask = np.array(mask, dtype=np.float32)
            if self._init_mask.shape != (self.channels,):
                raise ValueError(
                    f"mask shape {self._init_mask.shape} does not match channels={self.channels}"
                )
            init_val = self._init_mask
        self.gate = self.add_weight(
            name="gate",
            shape=(self.channels,),
            initializer=tf.keras.initializers.Constant(init_val),
            trainable=False,
            dtype=tf.float32,
        )

    def call(self, inputs):
        # A (channels,) vector broadcasts against the last axis for any input
        # rank, e.g. (B, H, W, C) or (B, C).
        return inputs * self.gate

    def get_config(self):
        cfg = super().get_config()
        # Gate values are a weight (saved by Keras), so only the size goes here.
        cfg.update({"channels": self.channels})
        return cfg

    @classmethod
    def from_config(cls, config):
        config = dict(config)  # don't mutate the caller's dict
        channels = config.pop("channels")
        return cls(channels=channels, **config)


def build_masked_model(orig_model, masks):
    """
    Build a functional copy of ``orig_model`` with a CNNGate after masked layers.

    Each layer is cloned from its config; if cloning fails, the original layer
    object is reused. The clone is applied in order and the original weights
    are copied into it. After every layer whose name is a key of ``masks``, a
    ``CNNGate`` named ``<layer>_gate`` is inserted and initialised from that
    mask. InputLayers of ``orig_model`` are skipped, because a fresh
    ``tf.keras.Input`` takes their place. Clone and weight-copy failures are
    logged rather than silently ignored.

    Args:
        orig_model: model whose ``layers`` form a single chain.
        masks: dict layer name -> 1-D 0/1 mask (length = that layer's channels).

    Returns:
        Uncompiled ``tf.keras.Model`` named ``<orig name>_masked``.
    """
    log("Building masked model with gates (cloning layers)...")
    inp = tf.keras.Input(shape=orig_model.input_shape[1:])
    x = inp

    for layer in orig_model.layers:
        if isinstance(layer, layers.InputLayer):
            continue

        try:
            new_layer = layer.__class__.from_config(layer.get_config())
        except Exception as e:
            log("Warning: couldn't clone layer", layer.name, "- reusing original object:", e)
            new_layer = layer

        x = new_layer(x)  # builds the clone so it can receive weights

        if new_layer is not layer:
            try:
                w = layer.get_weights()
                if w:
                    new_layer.set_weights(w)
            except Exception as e:
                # The clone keeps its fresh initialisation; make that visible.
                log("Warning: couldn't copy weights into clone of", layer.name, ":", e)

        if layer.name in masks:
            mask = np.array(masks[layer.name], dtype=np.float32)
            gate_layer = CNNGate(channels=int(mask.shape[0]), mask=mask, name=layer.name + "_gate")
            x = gate_layer(x)  # gate weight is initialised from the mask

    masked = tf.keras.Model(inputs=inp, outputs=x, name=(orig_model.name or "model") + "_masked")
    log("Masked model created:", masked.name)
    return masked

# ---------------------------
# Structural pruning (safe)
# ---------------------------
def _kept_indices(mask, size):
    """Return indices whose mask value is 1 (all of ``range(size)`` if mask is None); never empty."""
    if mask is None:
        return np.arange(size)
    mask = np.asarray(mask)
    idx = np.flatnonzero(mask == 1)
    if idx.size == 0:
        # Keep at least one unit so the layer stays valid.
        idx = np.array([int(np.argmax(mask))])
    return idx


def _feature_rows(keep, orig_ch, flattened, full_len):
    """Map kept channel indices onto a last-axis feature vector of length ``full_len``.

    Without a Flatten, features are channels. After a Flatten of a
    channels-last (H, W, C) map, feature ``p * C + c`` belongs to channel ``c``
    at spatial position ``p``. Returns None (= don't slice) when nothing was
    pruned upstream or the length does not fit that layout.
    """
    if keep is None or not orig_ch:
        return None
    if not flattened:
        return keep if full_len == orig_ch else None
    if full_len % orig_ch != 0:
        return None
    positions = full_len // orig_ch
    return (np.arange(positions)[:, None] * orig_ch + keep[None, :]).ravel()


def _clone_with_weights(layer):
    """Clone ``layer`` from its config (reusing it if that fails); return ``(layer_obj, weights_or_None)``."""
    weights = layer.get_weights() if hasattr(layer, "get_weights") else None
    try:
        obj = layer.__class__.from_config(layer.get_config())
    except Exception:
        obj = layer
    return obj, (weights or None)


def _pruned_config(layer, **overrides):
    """Return ``layer``'s config with ``overrides`` applied and input-shape hints removed."""
    cfg = dict(layer.get_config())
    cfg.pop("batch_input_shape", None)
    cfg.pop("input_shape", None)
    cfg.update(overrides)
    return cfg


def prune_structural_sequential(orig_model, masks, input_shape):
    """
    Physically remove pruned channels/units and return a new Sequential model.

    Walks ``orig_model.layers`` in order (a single chain is assumed):
      - Conv2D / Dense: keep the output filters/units whose mask entry is 1,
        and slice the kernel's input axis to match the previous producer's kept
        outputs. After a Flatten, each spatial position is expanded to its kept
        channels (channels-last layout). Layers without a mask (e.g. the
        classifier head) keep all outputs. The layer config is copied (dilation,
        regularizers, use_bias, ...) with only filters/units and name changed.
      - BatchNormalization: every per-channel weight is sliced with the same
        kept indices.
      - InputLayer and CNNGate layers are dropped: the input is recreated, and
        the gates are redundant once pruned channels no longer exist.
      - Anything else is cloned with its weights and assumed to preserve the
        channel layout.
    Weight-assignment failures are logged, not raised.

    Args:
        orig_model: original model or the gated model from ``build_masked_model``.
        masks: dict layer name -> 0/1 array; layers missing from it are not pruned.
        input_shape: per-sample input shape (no batch dimension).

    Returns:
        New uncompiled Sequential model named ``<orig name>_struct_pruned``.
    """
    log("Structural pruning (safe) start...")
    new_layers = []
    keep = None        # kept indices (original numbering) of the last producer's outputs
    orig_ch = None     # that producer's output count before pruning
    flattened = False  # a Flatten since that producer merged spatial dims into features

    for layer in orig_model.layers:
        if isinstance(layer, (layers.InputLayer, CNNGate)):
            continue

        if isinstance(layer, (layers.Conv2D, layers.Dense)):
            weights = layer.get_weights()
            kernel = weights[0]  # Conv2D: (kh, kw, in, out); Dense: (in, out)
            in_axis = kernel.ndim - 2
            rows = _feature_rows(keep, orig_ch, flattened, kernel.shape[in_axis])
            if rows is not None:
                kernel = np.take(kernel, rows, axis=in_axis)
            out_total = kernel.shape[-1]
            out_idx = _kept_indices(masks.get(layer.name), out_total)
            new_weights = [kernel[..., out_idx]] + [w[out_idx] for w in weights[1:]]
            if isinstance(layer, layers.Conv2D):
                cfg = _pruned_config(layer, filters=int(out_idx.size), name=layer.name + "_pruned")
                new_layer = layers.Conv2D.from_config(cfg)
            else:
                cfg = _pruned_config(layer, units=int(out_idx.size), name=layer.name + "_pruned")
                new_layer = layers.Dense.from_config(cfg)
            new_layers.append((new_layer, new_weights))
            keep, orig_ch, flattened = out_idx, out_total, False
            continue

        if isinstance(layer, layers.BatchNormalization):
            obj, weights = _clone_with_weights(layer)
            if weights:
                rows = _feature_rows(keep, orig_ch, flattened, weights[0].shape[0])
                if rows is not None:
                    weights = [w[rows] for w in weights]
            new_layers.append((obj, weights))
            continue

        # Pooling / activation / dropout / flatten / anything else.
        new_layers.append(_clone_with_weights(layer))
        if isinstance(layer, layers.Flatten):
            flattened = True

    seq = models.Sequential(name=orig_model.name + "_struct_pruned")
    seq.add(layers.InputLayer(input_shape=tuple(input_shape)))
    for lyr_obj, w in new_layers:
        seq.add(lyr_obj)
        if w is not None:
            try:
                seq.layers[-1].set_weights(w)
            except Exception as e:
                log("Warning: couldn't set weights for", seq.layers[-1].name, ":", e)
    log("Structural pruning complete. New model summary:")
    seq.summary()
    return seq

# ---------------------------
# FLOPS and timing helpers
# ---------------------------
def calculate_conv_flops(input_shape, kernel_shape, strides=(1,1), padding='same'):
    """
    Count FLOPs (2 per multiply-accumulate) of a channels-last 2-D convolution.

    Args:
        input_shape: ``(H, W, C_in)`` of the input (no batch dimension).
        kernel_shape: ``(kh, kw, C_in, C_out)`` kernel shape.
        strides: ``(sh, sw)`` strides.
        padding: ``'same'`` keeps ``ceil(H / s)``; any other value is treated
            as ``'valid'`` (the kernel must fit inside the input).

    Returns:
        ``(flops, (H_out, W_out, C_out))``; bias additions are not counted.
    """
    in_h, in_w, in_c = input_shape
    k_h, k_w, _unused, out_c = kernel_shape
    stride_h, stride_w = strides[0], strides[1]
    if padding != 'same':
        in_h, in_w = in_h - k_h + 1, in_w - k_w + 1
    out_h = math.ceil(in_h / stride_h)
    out_w = math.ceil(in_w / stride_w)
    macs_per_output = k_h * k_w * in_c
    return out_h * out_w * macs_per_output * out_c * 2, (out_h, out_w, out_c)

def calculate_dense_flops(in_size, out_size):
    """Return FLOPs (2 per multiply-accumulate) of a dense layer; bias adds are not counted."""
    macs = in_size * out_size
    return macs * 2

def _as_pair(value):
    """Return ``value`` as a 2-tuple (a scalar is duplicated)."""
    return (value[0], value[1]) if hasattr(value, "__getitem__") else (value, value)


def calculate_model_flops(model, input_shape):
    """
    Estimate forward-pass FLOPs (2 per multiply-accumulate) of a simple CNN.

    Counted: Conv2D (via ``calculate_conv_flops``) and Dense layers. A Dense on
    a non-flat input is applied at every leading position. The channels-last
    shape (no batch dimension) is propagated through Flatten, MaxPooling2D /
    AveragePooling2D (honouring pool_size, strides and padding) and
    GlobalAverage/GlobalMaxPooling2D. Other layers are assumed to keep the
    shape and cost nothing. Biases, activations, BatchNorm and conv dilation
    are not modelled. Conv2D layers without weights (unbuilt) are skipped.

    Args:
        model: Keras model whose ``layers`` form a single chain.
        input_shape: per-sample input shape, e.g. ``(H, W, C)``.

    Returns:
        Total FLOPs as an int.
    """
    total = 0
    current_shape = tuple(input_shape)
    for layer in model.layers:
        if isinstance(layer, layers.Conv2D):
            weights = layer.get_weights()
            if not weights:
                continue
            kernel_shape = weights[0].shape  # (kh, kw, in_c, out_c)
            layer_flops, current_shape = calculate_conv_flops(
                current_shape, kernel_shape, strides=layer.strides, padding=layer.padding
            )
            total += layer_flops
        elif isinstance(layer, layers.Flatten):
            current_shape = (int(np.prod(current_shape)),)
        elif isinstance(layer, layers.Dense):
            # Dense acts on the last axis; any leading axes repeat the product.
            positions = int(np.prod(current_shape[:-1]))
            total += positions * calculate_dense_flops(current_shape[-1], layer.units)
            current_shape = tuple(current_shape[:-1]) + (layer.units,)
        elif isinstance(layer, (layers.MaxPooling2D, layers.AveragePooling2D)):
            h, w, c = current_shape
            pool_h, pool_w = _as_pair(layer.pool_size)
            stride_h, stride_w = _as_pair(layer.strides)
            if layer.padding == "same":
                h, w = math.ceil(h / stride_h), math.ceil(w / stride_w)
            else:
                h, w = (h - pool_h) // stride_h + 1, (w - pool_w) // stride_w + 1
            current_shape = (h, w, c)
        elif isinstance(layer, (layers.GlobalAveragePooling2D, layers.GlobalMaxPooling2D)):
            channels = current_shape[-1]
            current_shape = (1, 1, channels) if getattr(layer, "keepdims", False) else (channels,)
    return total

def measure_inference_time(model, sample_batch, steps=20):
    """
    Return the mean wall-clock seconds of one ``model.predict(sample_batch)``.

    One untimed warm-up call is made first (graph tracing, allocation). Then
    ``steps`` timed calls are averaged using ``time.perf_counter``, which is
    monotonic and high-resolution, unlike ``time.time``.

    Args:
        model: object with a Keras-style ``predict(x, verbose=0)`` method.
        sample_batch: input passed to every ``predict`` call.
        steps: number of timed calls (must be >= 1).

    Returns:
        Average seconds per ``predict`` call.

    Raises:
        ValueError: if ``steps`` < 1 (the average would be undefined).
    """
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")
    model.predict(sample_batch, verbose=0)  # warmup
    t0 = time.perf_counter()
    for _ in range(steps):
        model.predict(sample_batch, verbose=0)
    return (time.perf_counter() - t0) / steps

# ---------------------------
# Save masks helper
# ---------------------------
def save_masks(masks, path):
    """
    Write ``masks`` (layer name -> 1-D mask) to ``path`` as indented JSON.

    Masks may be numpy arrays or any array-like (lists, tuples); each is
    stored as a JSON list of numbers.
    """
    serial = {name: np.asarray(mask).tolist() for name, mask in masks.items()}
    with open(path, "w", encoding="utf-8") as f:
        json.dump(serial, f, indent=2)
    log("Saved masks to", path)

# ---------------------------
# Plot helper
# ---------------------------
def plot_mask_histograms(masks, outdir=SAVE_DIR):
    """
    Save one kept/pruned histogram per layer as ``<outdir>/mask_<layer>.png``.

    Does nothing when ``PLOT_RESULTS`` is False. ``outdir`` is created if
    missing. Path separators in layer names are replaced by '_' so every file
    lands directly in ``outdir``. Each figure is closed even if saving fails.

    Args:
        masks: dict layer name -> 0/1 mask.
        outdir: directory for the PNG files.
    """
    if not PLOT_RESULTS:
        return
    os.makedirs(outdir, exist_ok=True)
    for name, mask in masks.items():
        safe_name = name.replace("/", "_").replace("\\", "_")
        fig = plt.figure(figsize=(5, 2))
        try:
            plt.title(name)
            # Fixed bin edges keep 0 and 1 in separate bars even if a mask is all ones.
            plt.hist(np.asarray(mask), bins=[-0.5, 0.5, 1.5])
            plt.xticks([0, 1])
            plt.xlabel("0=pruned, 1=kept")
            plt.tight_layout()
            plt.savefig(os.path.join(outdir, f"mask_{safe_name}.png"))
        finally:
            plt.close(fig)

# ---------------------------
# MAIN pipeline
# ---------------------------
def full_pipeline(model_path, dataset_path):
    """
    Run the complete prune-and-fine-tune workflow and return its artifacts.

    Steps:
      1. Load the model.
      2. Load the image-folder dataset, resized to the model's input size.
      3. Adapt the output layer to the dataset's class count.
      4. Compute importance statistics on ``CALIB_BATCHES`` training batches.
      5. Build keep-masks (``KEEP_RATIO``).
      6. Build and fine-tune a gated (masked) model.
      7. Structurally prune it (falling back to the original model, with a
         logged reason).
      8. Fine-tune the pruned model with early stopping.
      9. Evaluate all three models on the validation split, measure FLOPs and
         inference time, and save masks, plots and models to ``SAVE_DIR``.

    Args:
        model_path: model file or SavedModel directory (see ``safe_load_model``).
        dataset_path: directory with one sub-folder per class.

    Returns:
        dict with the three models, the masks, FLOPs, timings and accuracies.
        The masked/pruned accuracies are None if their evaluation failed.
    """
    # 1) load model (safe)
    model = safe_load_model(model_path)
    log("Model loaded. Summary:")
    model.summary()

    # 2) load dataset (images resized to the model's input size) and infer num_classes
    input_shape = model.input_shape[1:]
    log("Inferred input shape:", input_shape)

    train_ds, val_ds = load_image_folder_dataset(dataset_path, image_size=input_shape[:2], batch_size=BATCH_SIZE)

    # Prefer class names carried by the dataset. Otherwise rescan the
    # directory, and as a last resort infer the count from observed labels.
    class_names = getattr(train_ds, "class_names", None)
    if class_names:
        num_classes = len(class_names)
    else:
        try:
            tmp = tf.keras.utils.image_dataset_from_directory(dataset_path, image_size=input_shape[:2], batch_size=1)
            num_classes = len(tmp.class_names)
            del tmp
        except Exception:
            classes = set()
            for _, y in train_ds.take(10):
                classes.update(y.numpy().tolist())
            num_classes = max(classes) + 1 if classes else 2

    log("Detected dataset classes (num_classes):", num_classes)

    # 3) ensure model output matches dataset classes; compile so evaluate() works
    model, loss_fn = ensure_output_matches_dataset(model, num_classes)
    model.compile(optimizer=tf.keras.optimizers.Adam(), loss=loss_fn, metrics=["accuracy"])

    log("Final model used (after potential rebuild). Summary:")
    model.summary()

    # calibration subset for stats
    calib_ds = train_ds.take(CALIB_BATCHES)

    # sample for timing
    try:
        sample_x, _ = next(iter(val_ds))
    except Exception:
        sample_x, _ = next(iter(train_ds))
    sample_x_small = sample_x[:min(16, sample_x.shape[0])]

    # 4) choose layers to prune (conv + hidden dense only, not final output)
    dense_layers = [lyr for lyr in model.layers if isinstance(lyr, layers.Dense)]
    last_dense = dense_layers[-1] if dense_layers else None

    prune_layer_names = []
    for lyr in model.layers:
        if isinstance(lyr, layers.Conv2D):
            prune_layer_names.append(lyr.name)
        elif isinstance(lyr, layers.Dense) and lyr is not last_dense:
            prune_layer_names.append(lyr.name)
    log("Layers considered for pruning:", prune_layer_names)

    # 5) compute stats (pass loss_fn)
    stats = compute_activation_grad_stats(model, prune_layer_names, calib_ds, loss_fn=loss_fn, max_batches=CALIB_BATCHES)

    # 6) importance & masks
    score_map = compute_importance_scores(stats)
    masks = make_masks_from_scores(score_map, keep_ratio=KEEP_RATIO)
    save_masks(masks, os.path.join(SAVE_DIR, "masks.json"))
    plot_mask_histograms(masks)

    # 7) build masked model and compile (use same loss)
    masked_model = build_masked_model(model, masks)
    masked_model.compile(optimizer=tf.keras.optimizers.Adam(1e-4), loss=loss_fn, metrics=["accuracy"])
    log("Masked model built.")

    # quick eval before FT (small subset)
    try:
        loss0, acc0 = masked_model.evaluate(val_ds.take(5), verbose=0)
        log("Masked model pre-FT acc:", acc0)
    except Exception as e:
        log("Masked model pre-eval failed:", e)
        acc0 = None

    # 8) measure baseline flops & time
    baseline_flops = calculate_model_flops(model, input_shape)
    baseline_time = measure_inference_time(model, sample_x_small, steps=10)
    log(f"Baseline FLOPS: {baseline_flops:,}, baseline time (avg batch): {baseline_time:.4f}s")

    # 9) fine-tune masked model
    try:
        masked_model.fit(
            train_ds.take(FT_BATCHES_TO_USE),
            validation_data=val_ds.take(5),
            epochs=FT_EPOCHS,
            verbose=2
        )
    except Exception as e:
        log("Masked fine-tune failed/partial:", e)

    # 10) structural prune. Pruning the masked model preserves the gating
    #     decisions and carries its fine-tuned weights over.
    try:
        pruned_model = prune_structural_sequential(masked_model, masks, input_shape)
    except Exception as e:
        log("Structural pruning of masked model failed:", e,
            "- pruning the original model instead (masked fine-tuning is discarded).")
        pruned_model = prune_structural_sequential(model, masks, input_shape)

    # compile pruned model for evaluation / training with same loss
    pruned_model.compile(
        optimizer=tf.keras.optimizers.Adam(1e-4),
        loss=loss_fn,
        metrics=["accuracy"]
    )

    log("Starting fine-tuning of PRUNED model (EarlyStopping)...")
    early_cb = tf.keras.callbacks.EarlyStopping(
        monitor="val_accuracy",
        patience=2,
        restore_best_weights=True
    )
    try:
        # No batch_size here: the tf.data datasets are already batched.
        pruned_model.fit(
            train_ds,
            validation_data=val_ds,
            epochs=20,
            callbacks=[early_cb],
            verbose=2
        )
    except Exception as e:
        log("Pruned model fine-tune failed/partial:", e)

    # 11) evaluate models
    log("Evaluating Original model:")
    orig_loss, orig_acc = model.evaluate(val_ds, verbose=0)
    log("Original acc:", orig_acc)

    log("Evaluating Masked model (after FT):")
    try:
        mask_loss, mask_acc = masked_model.evaluate(val_ds, verbose=0)
        log("Masked acc:", mask_acc)
    except Exception as e:
        log("Masked evaluate failed:", e)
        mask_acc = None

    log("Evaluating Pruned model:")
    try:
        pruned_loss, pruned_acc = pruned_model.evaluate(val_ds, verbose=0)
        log("Pruned acc:", pruned_acc)
    except Exception as e:
        log("Pruned evaluate failed:", e)
        pruned_acc = None

    # 12) FLOPS & timing after prune
    pruned_flops = calculate_model_flops(pruned_model, input_shape)
    pruned_time = measure_inference_time(pruned_model, sample_x_small, steps=10)
    log(f"Pruned FLOPS: {pruned_flops:,}, pruned time: {pruned_time:.4f}s")

    # 13) summary & save
    reduction = 1.0 - (pruned_flops / baseline_flops) if baseline_flops > 0 else 0.0
    log("="*60)
    log("SUMMARY:")
    log(f"Baseline FLOPS: {baseline_flops:,}")
    log(f"Pruned FLOPS: {pruned_flops:,}")
    log(f"FLOPS reduction: {reduction:.2%}")
    log(f"Original acc: {orig_acc}, Masked acc: {mask_acc}, Pruned acc: {pruned_acc}")
    log("="*60)

    # ---- Extra Metrics: GFLOPS + Accuracy Reduction ----
    baseline_gflops = baseline_flops / 1e9
    pruned_gflops = pruned_flops / 1e9
    gflops_reduction = 1.0 - (pruned_gflops / baseline_gflops) if baseline_gflops > 0 else 0.0
    acc_reduction = (orig_acc - pruned_acc) if (orig_acc is not None and pruned_acc is not None) else None

    log(f"Baseline GFLOPS: {baseline_gflops:.4f}")
    log(f"Pruned GFLOPS: {pruned_gflops:.4f}")
    log(f"GFLOPS reduction: {gflops_reduction:.2%}")
    if acc_reduction is not None:
        log(f"Accuracy reduction: {acc_reduction:.4f}")
    else:
        log("Accuracy reduction: N/A")

    # save artifacts
    try:
        # use proper file extensions to avoid Keras save errors
        baseline_name = base_name + "_baseline.keras"
        masked_name   = base_name + "_masked.keras"
        pruned_name   = base_name + "_pruned.keras"

        model.save(os.path.join(SAVE_DIR, baseline_name))
        masked_model.save(os.path.join(SAVE_DIR, masked_name))
        pruned_model.save(os.path.join(SAVE_DIR, pruned_name))

        log(f"Saved models under names: {baseline_name}, {masked_name}, {pruned_name}")

        log("Saved models to", SAVE_DIR)
    except Exception as e:
        log("Save models failed:", e)

    return {
        "model": model,
        "masked_model": masked_model,
        "pruned_model": pruned_model,
        "masks": masks,
        "baseline_flops": baseline_flops,
        "pruned_flops": pruned_flops,
        "baseline_time": baseline_time,
        "pruned_time": pruned_time,
        "orig_acc": orig_acc,
        "mask_acc": mask_acc,
        "pruned_acc": pruned_acc
    }

# ---------------------------
# Run
# ---------------------------
if __name__ == "__main__":
    # Script / notebook entry point: run the pipeline on the configured paths.
    # `out` is the dict returned by full_pipeline (models, masks, FLOPs, timings, accuracies).
    out = full_pipeline(MODEL_PATH, DATASET_PATH)
    log("Pipeline finished. Outputs:", out.keys())
