# In [None]:
"""
UNIVERSAL TRANSFORMER ATTENTION HEAD PRUNING + MASKED FINE-TUNING
(FINAL, GUARANTEED-RUNNING VERSION)

✔ Any CSV or JSON Lines dataset
✔ Label = last column fallback
✔ Labels aligned to model output
✔ Binary + multiclass supported
✔ Fresh TextVectorization
✔ Soft attention-head pruning
✔ Masked fine-tuning (graph-safe)
✔ Correct FLOPs / Effective FLOPs
✔ Handles nested Transformer blocks
✔ Legacy .h5 compatible
"""

# ==========================================================
# IMPORTS
# ==========================================================

import json
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from tensorflow.keras.utils import to_categorical

# ==========================================================
# CONFIG
# ==========================================================

# Token sequence length: used as the vectorizer's output_sequence_length,
# as the fallback Input shape in strip_transformer, and as the sequence
# length in the FLOPs estimates.
MAX_LEN     = 200
# Passed as max_tokens to the TextVectorization layer.
VOCAB_SIZE = 20000
# Batch size for both the attention-statistics pass and fine-tuning.
BATCH_SIZE = 64
# Number of masked fine-tuning epochs.
EPOCHS     = 5
# Fraction of heads per encoder block that the importance mask keeps.
KEEP_RATIO = 0.7

# ==========================================================
# CUSTOM LAYERS (LEGACY SAFE)
# ==========================================================

@tf.keras.utils.register_keras_serializable(package="Custom")
class PositionalEmbedding(layers.Layer):
    """Sum of a token embedding and a learned position embedding.

    Args:
        max_len: Number of rows in the position-embedding table.
        vocab_size: Number of rows in the token-embedding table.
        embed_dim: Width of both embedding tables.
        maxlen: Legacy spelling of ``max_len``; used only when ``max_len``
            is not given.
        **kwargs: Forwarded to ``layers.Layer``.

    Raises:
        ValueError: If neither ``max_len`` nor ``maxlen`` is given, or if
            ``vocab_size`` or ``embed_dim`` is missing.
    """

    def __init__(
        self,
        max_len=None,
        vocab_size=None,
        embed_dim=None,
        maxlen=None,          # legacy support
        **kwargs
    ):
        super().__init__(**kwargs)

        if max_len is None and maxlen is not None:
            max_len = maxlen

        if max_len is None:
            raise ValueError("max_len or maxlen must be provided")

        # Fail with the same explicit error style as max_len instead of the
        # opaque TypeError that int(None) would raise below.
        if vocab_size is None or embed_dim is None:
            raise ValueError("vocab_size and embed_dim must be provided")

        self.max_len = int(max_len)
        self.vocab_size = int(vocab_size)
        self.embed_dim = int(embed_dim)

        # Creation order is kept as-is (token table first, then positions).
        self.token_emb = layers.Embedding(self.vocab_size, self.embed_dim)
        self.pos_emb   = layers.Embedding(self.max_len, self.embed_dim)

    def call(self, x):
        """Embed token ids ``x`` and add one position vector per step.

        Positions are ``0 .. length-1`` where ``length`` is the size of the
        last axis of ``x``; the position embeddings broadcast over the
        leading axes.
        """
        # NOTE(review): the position table has max_len rows, so inputs longer
        # than max_len would index past it - callers must keep length <= max_len.
        pos = tf.range(start=0, limit=tf.shape(x)[-1])
        return self.token_emb(x) + self.pos_emb(pos)

    def get_config(self):
        """Return the layer config; the length is stored under ``maxlen``."""
        cfg = super().get_config()
        cfg.update({
            # Written under the legacy key, which __init__ accepts as a
            # fallback for max_len.
            "maxlen": self.max_len,
            "vocab_size": self.vocab_size,
            "embed_dim": self.embed_dim
        })
        return cfg


@tf.keras.utils.register_keras_serializable(package="Custom")
class TransformerEncoder(layers.Layer):
    """Encoder block: self-attention and a two-layer feed-forward network,
    each followed by dropout, a residual add and layer normalization.

    Args:
        embed_dim: Width of the block's output; also passed as ``key_dim``
            to the attention sub-layer.
        num_heads: Number of attention heads.
        ff_dim: Width of the hidden feed-forward layer.
        rate: Dropout rate used after attention and after the feed-forward
            network.
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1, **kwargs):
        super().__init__(**kwargs)
        # Kept so get_config() can reproduce the constructor arguments.
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.rate = rate

        # Note that key_dim is set to embed_dim, not embed_dim // num_heads.
        self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = tf.keras.Sequential([
            layers.Dense(ff_dim, activation="relu"),
            layers.Dense(embed_dim),
        ])
        self.ln1 = layers.LayerNormalization(epsilon=1e-6)
        self.ln2 = layers.LayerNormalization(epsilon=1e-6)
        self.drop1 = layers.Dropout(rate)
        self.drop2 = layers.Dropout(rate)

        # Output of the attention sub-layer from the most recent call();
        # read by compute_attention_head_stats.
        self.last_attn = None

    def call(self, x, training=None):
        """Run the block on ``x`` and return a tensor from the final layer norm."""
        # Self-attention: x is passed as both query and value. This is the
        # sub-layer's output tensor; attention scores are not requested.
        attn = self.att(x, x, training=training)
        # Side effect: remember the attention output for later inspection.
        self.last_attn = attn
        x = self.ln1(x + self.drop1(attn, training=training))
        ffn = self.ffn(x, training=training)
        return self.ln2(x + self.drop2(ffn, training=training))

    def get_config(self):
        """Return the layer config including the four constructor arguments."""
        cfg = super().get_config()
        cfg.update({
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "ff_dim": self.ff_dim,
            "rate": self.rate
        })
        return cfg


# Alias so code that refers to the block as "TransformerBlock" gets the same
# class; the loader below also registers that name in custom_objects.
TransformerBlock = TransformerEncoder

# ==========================================================
# MASKED TRANSFORMER
# ==========================================================

class MaskedTransformer(tf.keras.Model):
    """Wrap a base model and zero masked slices of each encoder's attention output.

    The base model's top-level layers are applied one after another, so the
    base model is treated as a linear stack. For every ``TransformerEncoder``
    present in ``masks`` the attention output is multiplied by that block's
    mask before the residual connection.

    Args:
        base_model: Model whose ``layers`` are replayed in order.
        masks: Mapping from ``TransformerEncoder`` instance to a 1-D 0/1
            array with one entry per head.

    NOTE(review): only ``base_model.layers`` is scanned. An encoder block that
    sits inside a nested sub-model is executed through that sub-model's own
    call and is therefore never masked here.
    NOTE(review): the mask is applied to equal-width slices of the attention
    sub-layer's *output* features; whether those slices line up with the
    individual heads depends on MultiHeadAttention's internals - confirm.
    """

    def __init__(self, base_model, masks):
        super().__init__()
        self.base = base_model
        self.masks = masks

    @staticmethod
    def _apply_head_mask(attn, num_heads, mask):
        """Scale ``num_heads`` equal slices of the last axis of ``attn`` by ``mask``."""
        feature_dim = attn.shape[-1]
        slice_width = feature_dim // num_heads
        seq_len = tf.shape(attn)[1]

        # Cast to the activation dtype so the multiply also works when the
        # attention output is not float32.
        mask = tf.reshape(tf.cast(mask, attn.dtype), (1, 1, num_heads, 1))

        grouped = tf.reshape(attn, (-1, seq_len, num_heads, slice_width))
        return tf.reshape(grouped * mask, (-1, seq_len, feature_dim))

    def call(self, inputs, training=None):
        """Replay the base model's layers on ``inputs`` with masks applied."""
        x = inputs
        for layer in self.base.layers:
            if isinstance(layer, tf.keras.layers.InputLayer):
                continue

            if not isinstance(layer, TransformerEncoder):
                x = layer(x, training=training)
                continue

            # Re-implements TransformerEncoder.call so the mask can be
            # inserted between attention and the first residual add.
            attn = layer.att(x, x, training=training)
            if layer in self.masks:
                attn = self._apply_head_mask(
                    attn, layer.att.num_heads, self.masks[layer]
                )

            x = layer.ln1(x + layer.drop1(attn, training=training))
            # Pass the training flag to the FFN exactly as
            # TransformerEncoder.call does.
            ffn_out = layer.ffn(x, training=training)
            x = layer.ln2(x + layer.drop2(ffn_out, training=training))

        return x

# ==========================================================
# DATASET LOADER (ALIGNED TO MODEL)
# ==========================================================

def load_text_dataset(path, model_num_outputs):
    """Load a text-classification dataset and shape its labels for the model.

    The last column is the label. Every other string-typed column is joined
    with single spaces into one text per row.

    Args:
        path: Path to a ``.csv`` file, or to a ``.json``/``.jsonl`` file with
            one JSON object per line.
        model_num_outputs: Number of output units of the model. ``1`` yields
            float labels for binary cross-entropy; any other value yields
            one-hot labels with that many classes.

    Returns:
        Tuple ``(texts, y, loss)``: an object array of strings, the label
        array, and the Keras loss name matching the label format.

    Raises:
        ValueError: If the extension is unsupported, the dataset has no rows,
            or there is no text column besides the label.
    """
    lowered = str(path).lower()
    if lowered.endswith(".csv"):
        df = pd.read_csv(path)
    elif lowered.endswith((".json", ".jsonl")):
        with open(path, "r", encoding="utf-8") as f:
            # Skip blank lines (e.g. a trailing newline) instead of letting
            # json.loads fail on them.
            rows = [json.loads(line) for line in f if line.strip()]
        df = pd.DataFrame(rows)
    else:
        raise ValueError("Only CSV and JSON supported")

    if df.empty:
        raise ValueError("Dataset is empty")

    print("[INFO] Dataset columns:", list(df.columns))

    label_col = df.columns[-1]
    print(f"[INFO] Using label column: {label_col}")

    # Rows without a label cannot be trained on and would break the integer
    # cast / label encoding below.
    missing = df[label_col].isna()
    if missing.any():
        print(f"[WARN] Dropping {int(missing.sum())} rows with missing labels")
        df = df[~missing]
        if df.empty:
            raise ValueError("Dataset is empty")

    # The label column is excluded: when labels are strings, including it
    # would append the answer to every input text.
    text_cols = [
        c for c in df.columns
        if c != label_col
        and (df[c].dtype == object or isinstance(df[c].dtype, pd.StringDtype))
    ]
    if not text_cols:
        raise ValueError("No text columns found besides the label column")

    text_frame = df[text_cols].fillna("").astype(str)
    texts = np.array(
        [" ".join(row) for row in text_frame.itertuples(index=False, name=None)],
        dtype=object,
    )

    labels = df[label_col].values

    if pd.api.types.is_numeric_dtype(labels):
        y_int = labels.astype(int)
    else:
        y_int = LabelEncoder().fit_transform(labels)

    # ======================================================
    # CRITICAL: BINARY vs MULTICLASS HANDLING
    # ======================================================

    if model_num_outputs == 1:
        # Dense(1) -> sigmoid -> binary
        y = y_int.astype("float32")
        loss = "binary_crossentropy"

    else:
        # Dense(N) -> softmax
        out_of_range = (y_int < 0) | (y_int >= model_num_outputs)
        if out_of_range.any():
            # Clipping merges these rows into the first/last class; make
            # that visible instead of doing it silently.
            print(
                f"[WARN] {int(out_of_range.sum())} labels fall outside "
                f"0..{model_num_outputs - 1} and were clipped"
            )
        y_int = np.clip(y_int, 0, model_num_outputs - 1)
        y = to_categorical(y_int, num_classes=model_num_outputs)
        loss = "categorical_crossentropy"

    return texts, y, loss

# ==========================================================
# TEXT VECTORIZATION
# ==========================================================

def build_vectorizer(texts):
    """Create a TextVectorization layer and adapt it to ``texts``.

    The layer emits integer token ids, uses at most ``VOCAB_SIZE`` tokens and
    fixes the output length to ``MAX_LEN``.
    """
    settings = {
        "max_tokens": VOCAB_SIZE,
        "output_mode": "int",
        "output_sequence_length": MAX_LEN,
    }
    vectorizer = layers.TextVectorization(**settings)
    # Build the vocabulary from the given texts.
    vectorizer.adapt(texts)
    return vectorizer

# ==========================================================
# STRIP TRANSFORMER
# ==========================================================

def strip_transformer(model):
    """Return a model that accepts integer token ids.

    Tried in order:
      1. ``model`` itself, when its input dtype is already int32/int64.
      2. A sub-model starting at the first layer whose input is int32/int64
         and ending at ``model.output``.
      3. ``model`` wrapped behind a fresh int32 ``Input`` of length
         ``MAX_LEN``.
    """
    int_dtypes = (tf.int32, tf.int64)

    if model.input.dtype in int_dtypes:
        return model

    for layer in model.layers:
        try:
            if layer.input.dtype in int_dtypes:
                return tf.keras.Model(layer.input, model.output)
        except Exception:
            # Best-effort probe: a layer without a single usable input, or
            # one from which no sub-model can be built, is skipped. Unlike a
            # bare ``except``, KeyboardInterrupt/SystemExit still propagate.
            continue

    token_input = tf.keras.Input(shape=(MAX_LEN,), dtype=tf.int32)
    return tf.keras.Model(token_input, model(token_input))

# ==========================================================
# UTILITIES
# ==========================================================

def get_all_layers(model):
    """Return the layers of ``model`` in pre-order, descending into nested models.

    Each layer is listed before the layers it contains; only layers that are
    themselves ``tf.keras.Model`` instances are expanded.
    """
    def _walk(container):
        for child in container.layers:
            yield child
            if isinstance(child, tf.keras.Model):
                yield from _walk(child)

    return list(_walk(model))

# ==========================================================
# ATTENTION STATS + FLOPs
# ==========================================================

def compute_attention_head_stats(model, vectorizer, texts, max_batches=20):
    """Estimate a per-head activity score for every encoder block.

    Runs up to ``max_batches`` batches of ``texts`` through ``model`` in
    inference mode. After each batch, every block's recorded attention output
    is split into ``num_heads`` equal slices along its last axis and the mean
    absolute value of each slice is taken.

    Args:
        model: Model that accepts int32 token ids.
        vectorizer: Callable mapping a batch of strings to token ids.
        texts: Texts to sample from.
        max_batches: Maximum number of batches to process.

    Returns:
        Dict mapping each ``TransformerEncoder`` that produced at least one
        measurement to an array with one score per head (averaged over
        batches). Blocks that were never executed are left out.
    """
    blocks = [l for l in get_all_layers(model) if isinstance(l, TransformerEncoder)]
    stats = {b: [] for b in blocks}

    ds = tf.data.Dataset.from_tensor_slices(texts).batch(BATCH_SIZE)
    for i, xb in enumerate(ds):
        if i >= max_batches:
            break

        # Clear earlier recordings so a block that does not run in this
        # forward pass is not scored from stale data.
        for block in blocks:
            block.last_attn = None

        tokens = tf.cast(vectorizer(xb), tf.int32)
        model(tokens, training=False)

        for block in blocks:
            attn = block.last_attn
            if attn is None:
                continue
            H = block.att.num_heads
            D = attn.shape[-1]
            Hd = D // H
            reshaped = tf.reshape(attn, (-1, tf.shape(attn)[1], H, Hd))
            # Average over batch, sequence and within-slice axes, leaving
            # one value per head.
            score = tf.reduce_mean(tf.abs(reshaped), axis=[0, 1, 3]).numpy()
            stats[block].append(score)

    # Averaging an empty list would give a NaN scalar; drop such blocks.
    return {k: np.mean(v, axis=0) for k, v in stats.items() if v}


def compute_importance_mask(stats, keep_ratio):
    """Turn per-head scores into 0/1 keep masks.

    For each entry, ``k = int(n_heads * keep_ratio)`` is clamped to
    ``1..n_heads`` and every head whose score is at least the k-th largest
    score is kept. Ties at the threshold are all kept, so more than ``k``
    heads may survive.

    Args:
        stats: Mapping from block to a 1-D array of per-head scores.
        keep_ratio: Fraction of heads to keep.

    Returns:
        Dict mapping each block to a float32 array of 0.0/1.0 values.
        Entries with no scores or with non-finite scores are left out, since
        they cannot be ranked.
    """
    masks = {}
    for block, score in stats.items():
        scores = np.atleast_1d(np.asarray(score, dtype=float))
        n_heads = len(scores)

        # A NaN threshold would compare False everywhere and prune every
        # head, so unrankable entries are skipped rather than masked.
        if n_heads == 0 or not np.isfinite(scores).all():
            continue

        # Clamp so keep_ratio > 1 cannot ask np.partition for an
        # out-of-range position, and at least one head is always kept.
        k = min(n_heads, max(1, int(n_heads * keep_ratio)))
        threshold = np.partition(scores, -k)[-k]
        masks[block] = (scores >= threshold).astype(np.float32)
    return masks


def attention_flops(seq_len, embed_dim):
    """Return ``4*S*D*D + 2*S*S*D`` for sequence length S and width D.

    Presumably the first term stands for the four dense projections and the
    second for the two sequence-by-sequence products of an attention layer.
    """
    projection_term = 4 * seq_len * embed_dim * embed_dim
    pairwise_term = 2 * seq_len * seq_len * embed_dim
    return projection_term + pairwise_term


def transformer_model_flops(model):
    """Sum the attention FLOPs estimate over every encoder block in ``model``.

    Each block is costed at sequence length ``MAX_LEN`` with a width of
    ``key_dim * num_heads`` taken from its attention sub-layer.
    """
    encoders = (
        layer for layer in get_all_layers(model)
        if isinstance(layer, TransformerEncoder)
    )
    return sum(
        attention_flops(MAX_LEN, enc.att.key_dim * enc.att.num_heads)
        for enc in encoders
    )


def effective_transformer_flops(model, masks):
    """Sum the attention FLOPs estimate scaled by each block's kept-head share.

    A block without an entry in ``masks`` counts as having all heads kept.
    """
    total = 0
    for layer in get_all_layers(model):
        if not isinstance(layer, TransformerEncoder):
            continue

        head_count = layer.att.num_heads
        width = layer.att.key_dim * head_count
        block_mask = masks.get(layer, np.ones(head_count))
        kept_heads = int(np.sum(block_mask))

        # Scale the full-block estimate by the fraction of heads kept.
        total += attention_flops(MAX_LEN, width) * (kept_heads / head_count)
    return total

# ==========================================================
# MAIN PIPELINE
# ==========================================================

def universal_transformer_pruning(model_path, dataset_path):
    """Load a model, mask low-scoring attention heads, fine-tune and report.

    Steps: load the model, load and split the dataset (80/20, stratified),
    adapt a new vectorizer on the training texts, score and mask heads using
    ``KEEP_RATIO``, fine-tune the masked model for ``EPOCHS`` epochs, print
    validation accuracy and FLOPs estimates, and save the masked model to
    ``soft_pruned_transformer_imbd.keras`` in the working directory.

    Args:
        model_path: Path of the saved Keras model to load.
        dataset_path: Path of the CSV / JSON Lines dataset.
    """
    model = tf.keras.models.load_model(
        model_path,
        custom_objects={
            "PositionalEmbedding": PositionalEmbedding,
            "TransformerEncoder": TransformerEncoder,
            "TransformerBlock": TransformerEncoder,
        },
        compile=False
    )

    model_num_outputs = model.output_shape[-1]
    print(f"[INFO] Model output units: {model_num_outputs}")

    texts, y, loss_fn = load_text_dataset(dataset_path, model_num_outputs)

    Xtr, Xv, ytr, yv = train_test_split(
        texts, y,
        test_size=0.2,
        random_state=42,
        # One-hot labels are reduced to class ids for stratification.
        stratify=y if model_num_outputs == 1 else y.argmax(axis=1)
    )

    # NOTE(review): the vocabulary is built from this dataset, so token ids
    # are new - confirm they are meaningful for the loaded model's embedding.
    vectorizer = build_vectorizer(Xtr)
    transformer = strip_transformer(model)

    stats = compute_attention_head_stats(transformer, vectorizer, Xtr)
    masks = compute_importance_mask(stats, KEEP_RATIO)

    # Vectorize each split once and reuse it for fit and evaluate.
    train_tokens = tf.cast(vectorizer(Xtr), tf.int32)
    val_tokens = tf.cast(vectorizer(Xv), tf.int32)

    masked_model = MaskedTransformer(transformer, masks)
    masked_model.compile(
        optimizer=tf.keras.optimizers.Adam(1e-4),
        loss=loss_fn,
        metrics=["accuracy"],
        run_eagerly=True
    )

    masked_model.fit(
        train_tokens, ytr,
        validation_data=(val_tokens, yv),
        epochs=EPOCHS,
        batch_size=BATCH_SIZE
    )

    # evaluate() returns [loss, accuracy] for the compiled metrics.
    acc = masked_model.evaluate(val_tokens, yv, verbose=0)[1]

    orig = transformer_model_flops(transformer)
    eff  = effective_transformer_flops(transformer, masks)
    # With no encoder blocks both totals are 0; report 0% instead of
    # dividing by zero.
    reduction = (1 - eff / orig) * 100 if orig else 0.0

    print("\n=========== FINAL RESULTS ===========")
    print(f"Masked Accuracy     : {acc:.4f}")
    print(f"Original GFLOPs     : {orig / 1e9:.3f}")
    print(f"Effective GFLOPs    : {eff / 1e9:.3f}")
    print(f"FLOPs Reduction (%) : {reduction:.2f}%")

    # One variable for the path so the log line always names the file that
    # was actually written.
    save_path = "soft_pruned_transformer_imbd.keras"
    masked_model.save(save_path)
    print(f"[INFO] Saved pruned model: {save_path}")

# ==========================================================
# RUN
# ==========================================================

if __name__ == "__main__":
    # Script entry point: runs the full pipeline on a hard-coded model file
    # and dataset file (absolute Windows paths - edit for your machine).
    universal_transformer_pruning(
        r"D:\college\sem-8\final\custom_transformer_imdb.h5",
        r"D:\college\sem-8\dataset\tranformer dataset\IMDB Dataset.csv"
    )
