In [None]:
# ---------- 7) Learnable EMHSA layer ----------
class EnhancedMultiHeadSelfAttention(tf.keras.layers.Layer):
    """Multi-head self-attention with a learnable per-head score scale (EMHSA).

    Standard attention divides the raw scores by the constant ``sqrt(depth)``.
    Here every head ``h`` owns a trainable, strictly positive divisor instead:

        alpha_h = softplus(beta_h) + eps
        scores  = (Q K^T) / alpha_h

    ``beta`` is initialised so that ``softplus(beta_h) == sqrt(depth)``, i.e.
    the layer starts out equivalent to standard scaled dot-product attention
    (up to the tiny ``eps`` offset).

    Args:
        units: Width of the Q/K/V/output projections. Must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        eps: Small positive constant added to ``softplus(beta)`` so the
            divisor can never reach zero.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``units``.
    """

    def __init__(self, units, num_heads=4, eps=1e-6, **kwargs):
        super().__init__(**kwargs)
        # Explicit check instead of `assert`: asserts vanish under `python -O`,
        # and num_heads == 0 would otherwise surface as a ZeroDivisionError.
        if num_heads <= 0 or units % num_heads != 0:
            raise ValueError(
                "units must be divisible by num_heads "
                f"(got units={units}, num_heads={num_heads})"
            )
        self.units = units
        self.num_heads = num_heads
        self.depth = units // num_heads
        self.eps = eps

        self.W_q = Dense(units)
        self.W_k = Dense(units)
        self.W_v = Dense(units)
        self.W_o = Dense(units)
        self.softmax = tf.keras.layers.Softmax(axis=-1)

        # Initialise beta so that softplus(beta) == sqrt(depth) (standard scaling).
        # Exact inverse of softplus, log(exp(a) - 1), written in the
        # overflow-safe form a + log(1 - exp(-a)).
        init_alpha = float(np.sqrt(self.depth))
        init_beta = float(init_alpha + np.log(-np.expm1(-init_alpha)))

        # Register through add_weight so the variable is tracked as a trainable
        # weight of the layer (trained, checkpointed, listed in
        # `trainable_weights`) on every Keras version; a bare tf.Variable
        # attribute is not tracked by Keras 3.
        self.beta = self.add_weight(
            name="beta_head",
            shape=(self.num_heads,),
            initializer=tf.keras.initializers.Constant(init_beta),
            trainable=True,
        )

    def split_heads(self, x, batch_size):
        """Reshape ``(B, T, units)`` into per-head form ``(B, H, T, D)``."""
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])  # (B,H,T,D)

    def call(self, inputs):
        """Apply self-attention over ``inputs``.

        Args:
            inputs: Tensor whose leading axis is the batch; the layer treats
                it as ``(B, T, features)``.

        Returns:
            Tensor of shape ``(B, T, units)``.
        """
        batch_size = tf.shape(inputs)[0]

        query = self.split_heads(self.W_q(inputs), batch_size)
        key = self.split_heads(self.W_k(inputs), batch_size)
        value = self.split_heads(self.W_v(inputs), batch_size)

        scores = tf.matmul(query, key, transpose_b=True)  # (B,H,T,T)

        # Per-head positive divisor, shaped to broadcast over (B,H,T,T).
        alpha = tf.nn.softplus(self.beta) + self.eps  # (H,)
        alpha = tf.reshape(alpha, (1, self.num_heads, 1, 1))
        # Match the scores' dtype so the division cannot hit a dtype mismatch
        # when the compute dtype differs from the variable dtype.
        scores = scores / tf.cast(alpha, scores.dtype)

        weights = self.softmax(scores)
        context = tf.matmul(weights, value)  # (B,H,T,D)

        # Merge the heads back together.
        context = tf.transpose(context, perm=[0, 2, 1, 3])  # (B,T,H,D)
        context = tf.reshape(context, (batch_size, -1, self.units))  # (B,T,units)

        return self.W_o(context)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update(
            {
                "units": self.units,
                "num_heads": self.num_heads,
                "eps": self.eps,
            }
        )
        return config
