In [None]:
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np

# ---------------- CONFIG ----------------
# Inputs: the saved Keras model to prune, and the image folder used for
# fine-tuning (split into train/validation by load_ds).
MODEL_PATH = "/content/mobilenet_f.keras"
DATASET_PATH = (
    "/root/.cache/kagglehub/datasets/puneet6060/"
    "intel-image-classification/versions/2/seg_train/seg_train"
)

# Data loading.
BATCH = 32                # images per batch, both splits

# Soft pruning.
KEEP_RATIO = 0.85         # fraction of each 1x1 conv's output channels kept
ALPHA = 0.1               # gain for "pruned" channels before mean-normalization

# Fine-tuning.
EPOCHS = 10               # upper bound on training epochs
PATIENCE = 3              # early-stopping patience (epochs) on val_loss
UNFREEZE_LAST = 20        # trailing layers of the masked model left trainable

# NOTE(review): IMG_SIZE is not referenced in this cell; run() derives the
# image size from the loaded model's input shape instead.
IMG_SIZE = (128, 128)

# ---------------- DATA ----------------
def load_ds(path, img_size, batch_size=None, validation_split=0.2, seed=42):
    """Load an image directory as rescaled, prefetched train/val datasets.

    Both splits come from the same directory via
    ``image_dataset_from_directory`` with the same ``seed`` and
    ``validation_split``, so they are disjoint. Labels use that function's
    default (integer) label mode.

    Args:
        path: root directory with one sub-directory per class.
        img_size: (height, width) images are resized to.
        batch_size: images per batch; defaults to the module-level BATCH.
        validation_split: fraction of files held out for validation.
        seed: shuffling/split seed shared by both subsets.

    Returns:
        (train, val) ``tf.data.Dataset`` objects yielding (images, labels)
        batches, with pixel values multiplied by 1/255.
    """
    if batch_size is None:
        batch_size = BATCH

    def _subset(name):
        return tf.keras.utils.image_dataset_from_directory(
            path,
            validation_split=validation_split,
            subset=name,
            seed=seed,
            image_size=img_size,
            batch_size=batch_size,
        )

    train = _subset("training")
    val = _subset("validation")

    # Map pixel values from [0, 255] to [0, 1].
    # NOTE(review): confirm this matches how the saved model was trained;
    # Keras' own MobileNet preprocess_input scales to [-1, 1] instead.
    scale = layers.Rescaling(1.0 / 255)

    def _rescale(images, labels):
        return scale(images), labels

    # Parallel map so rescaling does not serialize the input pipeline.
    train = train.map(_rescale, num_parallel_calls=tf.data.AUTOTUNE).prefetch(tf.data.AUTOTUNE)
    val = val.map(_rescale, num_parallel_calls=tf.data.AUTOTUNE).prefetch(tf.data.AUTOTUNE)
    return train, val


# ---------------- FIND PRUNABLE CONVS ----------------
# MobileNet rule:
# ✔ Conv2D kernel=(1,1)  (pointwise)
# ✘ DepthwiseConv2D
def find_prunable_convs(model):
    """Return the pointwise (1x1) Conv2D layers of ``model`` in layer order.

    Layers that are themselves ``tf.keras.Model`` instances are searched
    recursively. ``DepthwiseConv2D`` is excluded explicitly: in some tf.keras
    versions it subclasses ``Conv2D``, so the ``isinstance`` test alone
    would admit a 1x1 depthwise layer.

    Args:
        model: a Keras layer or model to search.

    Returns:
        list of matching layer objects (may be empty).
    """
    convs = []

    def walk(layer):
        if (
            isinstance(layer, layers.Conv2D)
            and not isinstance(layer, layers.DepthwiseConv2D)
            and tuple(layer.kernel_size) == (1, 1)
        ):
            convs.append(layer)
        if isinstance(layer, tf.keras.Model):
            for sub in layer.layers:
                walk(sub)

    walk(model)
    return convs


# ---------------- SOFT CHANNEL MASK ----------------
@tf.keras.utils.register_keras_serializable()
class ChannelMask(layers.Layer):
    """Non-trainable per-channel gain that softly "prunes" channels.

    Channels whose ``mask`` entry is > 0 get gain 1.0, the others get gain
    ``alpha``. The gains are then divided by their mean, so the average gain
    over channels is 1. The gains multiply the input along its last axis.

    Args:
        mask: 1-D array-like with one entry per input channel (last axis);
            entries > 0 mark kept channels.
        alpha: gain of pruned channels before mean-normalization.
    """

    def __init__(self, mask, alpha=0.1, **kwargs):
        super().__init__(**kwargs)
        keep = (np.asarray(mask, dtype="float32") > 0).astype("float32")
        gains = np.where(keep > 0, 1.0, alpha)
        mean = float(np.mean(gains)) if gains.size else 0.0
        # Mean-normalize for activation scale stability; skip when the mean
        # is 0 (alpha == 0 and nothing kept) to avoid dividing by zero.
        self.mask_init = gains / mean if mean > 0 else gains
        # Binary keep pattern, serialized by get_config(). The normalized
        # gains cannot be serialized instead: with alpha > 0 they are all
        # > 0, so __init__ would re-binarize them to an all-ones mask.
        self.keep_mask = keep
        self.alpha = alpha

    def build(self, input_shape):
        channels = input_shape[-1]
        if channels is not None and int(channels) != self.mask_init.shape[0]:
            raise ValueError(
                f"ChannelMask '{self.name}': mask has {self.mask_init.shape[0]} "
                f"entries but the input has {channels} channels."
            )
        self.mask = self.add_weight(
            name="mask",
            shape=(input_shape[-1],),
            initializer=tf.constant_initializer(self.mask_init),
            trainable=False,
        )

    def call(self, x):
        # Broadcasting multiplies x[..., c] by mask[c] for any input rank.
        return x * self.mask

    def get_config(self):
        cfg = super().get_config()
        cfg.update({
            "mask": self.keep_mask.tolist(),
            "alpha": self.alpha,
        })
        return cfg

    @classmethod
    def from_config(cls, config):
        mask = config.pop("mask")
        return cls(mask=mask, **config)


# ---------------- MASK CREATION ----------------
def make_masks(convs, keep_ratio=None):
    """Build a binary keep-mask per layer from filter L1 magnitudes.

    Each output channel (last kernel axis) is scored by the mean absolute
    kernel value over all other axes. The ``k = max(1, int(C * keep_ratio))``
    highest-scoring channels are kept, with ``k`` clamped to ``C``. Ties are
    broken by lower channel index, so exactly ``k`` channels are kept. A
    threshold test (``scores >= k-th score``) would keep every tied channel,
    e.g. all channels of an all-zero kernel.

    Args:
        convs: layers exposing ``.name`` and ``.kernel.numpy()``.
        keep_ratio: fraction of channels to keep; defaults to KEEP_RATIO.

    Returns:
        dict mapping layer name -> float32 vector of 0.0/1.0, one per channel.
    """
    if keep_ratio is None:
        keep_ratio = KEEP_RATIO
    masks = {}
    for l in convs:
        W = l.kernel.numpy()
        scores = np.mean(np.abs(W), axis=tuple(range(W.ndim - 1)))
        n = len(scores)
        k = min(n, max(1, int(n * keep_ratio)))
        # Stable sort on negated scores: highest first, ties by index.
        keep_idx = np.argsort(-scores, kind="stable")[:k]
        mask = np.zeros(n, dtype="float32")
        mask[keep_idx] = 1.0
        masks[l.name] = mask
        print(f"[PRUNE] {l.name}: keep {int(mask.sum())}/{len(mask)}")
    return masks


# ---------------- GRAPH-SAFE MASK INSERT ----------------
def build_masked_mobilenet(base_model, masks):
    """Rebuild ``base_model`` with a ChannelMask after masked pointwise convs.

    Every layer of ``base_model`` (the same layer objects, so weights are
    shared) is re-applied in ``base_model.layers`` order to remapped inputs.
    When a BatchNormalization layer's input comes from a 1x1 Conv2D whose name
    is in ``masks``, a ``ChannelMask`` named ``<conv name>_mask`` (gain
    ``ALPHA`` for pruned channels) is applied to the BN output. Downstream
    layers then consume the masked tensor.

    Assumes each layer has a single inbound node in the original graph
    (``layer.input`` / ``layer.output`` refer to node 0). NOTE(review): a
    nested sub-model is called as one layer, so convs inside it are not masked.

    Args:
        base_model: a built functional ``tf.keras.Model``.
        masks: dict of conv layer name -> binary keep vector.

    Returns:
        A new ``tf.keras.Model`` named ``<base name>_masked`` with the same
        input tensors and the remapped counterparts of ``base_model.outputs``.
    """
    # Original graph tensor -> rebuilt tensor, keyed by id() so lookups
    # never depend on symbolic tensors' __eq__/__hash__ overloads.
    tensor_map = {id(t): t for t in base_model.inputs}

    def remap(t):
        # Tensors with no rebuilt counterpart fall back to the original.
        return tensor_map.get(id(t), t)

    for layer in base_model.layers:
        if isinstance(layer, layers.InputLayer):
            continue

        inbound = layer.input
        if isinstance(inbound, (list, tuple)):
            mapped = [remap(t) for t in inbound]
        else:
            mapped = remap(inbound)

        x = layer(mapped)

        # Insert mask AFTER BatchNorm if previous layer was Conv2D(1x1)
        if isinstance(layer, layers.BatchNormalization):
            prev = layer.input._keras_history[0]
            if (
                isinstance(prev, layers.Conv2D)
                and tuple(prev.kernel_size) == (1, 1)
                and prev.name in masks
            ):
                x = ChannelMask(
                    masks[prev.name],
                    alpha=ALPHA,
                    name=prev.name + "_mask",
                )(x)

        tensor_map[id(layer.output)] = x

    # Wire the model to the remapped declared outputs, not to whatever the
    # last layer in iteration order produced.
    outputs = [remap(t) for t in base_model.outputs]
    return tf.keras.Model(
        inputs=base_model.inputs,
        outputs=outputs[0] if len(outputs) == 1 else outputs,
        name=base_model.name + "_masked",
    )


# ---------------- FLOPs ----------------
# Count ONLY pointwise convs
def baseline_flops(model):
    """Return the dense FLOP count of the model's connected 1x1 convs.

    Each pointwise conv costs ``H_out * W_out * C_in * C_out`` multiply-adds,
    counted as 2 FLOPs each. Convs with no inbound nodes are skipped. No other
    layer type contributes.
    """
    def conv_cost(conv):
        _, height, width, c_out = conv.output.shape
        c_in = conv.input.shape[-1]
        return 2 * height * width * c_in * c_out

    connected = (conv for conv in find_prunable_convs(model) if conv._inbound_nodes)
    return int(sum((conv_cost(conv) for conv in connected), 0))


def effective_flops_from_masks(model, masks):
    """Return 1x1-conv FLOPs as if masked-out output channels were removed.

    Uses the same cost model as ``baseline_flops``
    (``2 * H_out * W_out * C_in * C_out``), with ``C_out`` replaced by the
    number of kept channels in ``masks[conv.name]``. Convs with no entry in
    ``masks`` are counted at their full dense cost, so the result equals
    ``baseline_flops(model)`` when ``masks`` is empty.

    NOTE(review): this is a what-if estimate.
      * ChannelMask is a soft mask (pruned channels keep gain alpha), so the
        convs still compute every channel at run time.
      * Only each conv's own output channels are discounted. The reduced C_in
        of downstream layers is not.
      * Any conv named in ``masks`` is discounted, even though
        build_masked_mobilenet inserts a mask only where a 1x1 conv feeds a
        BatchNormalization layer.
    """
    total = 0
    for conv in find_prunable_convs(model):
        if not conv._inbound_nodes:
            continue
        _, h, w, c_out = conv.output.shape
        c_in = conv.input.shape[-1]
        if conv.name in masks:
            c_out = int(np.sum(masks[conv.name]))
        total += h * w * c_in * c_out * 2
    return int(total)


# ---------------- MAIN ----------------
def run():
    """Prune, fine-tune and report on the MobileNet at MODEL_PATH.

    Steps:
      1. Load the model (uncompiled) and build it on a zero batch.
      2. Load train/val data at the model's input resolution.
      3. Score 1x1 convs, build binary masks and insert ChannelMask layers.
      4. Freeze all but the last UNFREEZE_LAST layers.
      5. Fine-tune with early stopping on val_loss.
      6. Save the result to ``masked_finetuned_mobilenet_Intel.keras`` in
         the current directory.
      7. Print baseline vs. mask-discounted pointwise-conv FLOPs.
    """
    print("[INFO] Loading base MobileNet...")
    base = tf.keras.models.load_model(MODEL_PATH, compile=False)

    # Force graph build. NOTE(review): requires a fully static input shape;
    # tf.zeros fails if any non-batch dimension is None.
    dummy = tf.zeros((1,) + base.input_shape[1:])
    _ = base(dummy)

    train, val = load_ds(DATASET_PATH, base.input_shape[1:3])

    convs = find_prunable_convs(base)
    masks = make_masks(convs)

    masked = build_masked_mobilenet(base, masks)

    # force masked graph build
    _ = masked(dummy)

    # Freeze all but the last UNFREEZE_LAST layers. An explicit stop index is
    # used because [:-0] is an empty slice: UNFREEZE_LAST = 0 would otherwise
    # leave every layer trainable instead of freezing them all. The layer
    # objects are shared with `base`, so their flags change there too.
    n_frozen = max(0, len(masked.layers) - UNFREEZE_LAST)
    for layer in masked.layers[:n_frozen]:
        layer.trainable = False

    masked.compile(
        optimizer=tf.keras.optimizers.Adam(1e-4),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )

    early = tf.keras.callbacks.EarlyStopping(
        patience=PATIENCE,
        restore_best_weights=True,
        monitor="val_loss",
    )

    masked.fit(
        train,
        validation_data=val,
        epochs=EPOCHS,
        callbacks=[early],
        verbose=1,
    )

    masked.save("masked_finetuned_mobilenet_Intel.keras")

    b = baseline_flops(base)
    e = effective_flops_from_masks(base, masks)

    print("\n=========== EFFECTIVE FLOPs ANALYSIS ===========")
    print(f"Baseline FLOPs : {b:,}")
    print(f"Effective FLOPs: {e:,}")
    if b:
        print(f"Reduction (%)  : {(b - e) / b * 100:.2f}%")
    else:
        # No connected pointwise convs were found; avoid ZeroDivisionError.
        print("Reduction (%)  : n/a")
    print("===============================================")


# ---------------- RUN ----------------
if __name__ == "__main__":
    # Run the full pipeline when executed as a script. In a notebook cell,
    # __name__ is also "__main__", so run() executes there as well.
    run()
