# power_scale vs simple power demo
This notebook compares the `power_scale` function, which divides by the maximum in log space before exponentiating (a max-subtraction trick of the same kind used in log-sum-exp), with a simple power transformation. It also explores how `power_scale` behaves for large `scale` values, where the output approaches a one-hot vector at the argmax.

# Function definitions

In [1]:
# %%

import jax.numpy as jnp
from jax import lax

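def power_scale(value: jnp.ndarray, scale: float) -> jnp.ndarray:
    """Returns (value / max(value)) ** scale, computed in log space.

    Subtracting the max log keeps every exponent <= 0, so exp() cannot overflow.
    Returns value unchanged when scale == 1 or value is all zeros.
    Entries are expected to be non-negative (log of a negative is NaN).
    """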
    log_activation = jnp.log(value)
    return lax.cond(
        jnp.logical_and(jnp.any(value != 0), scale != 1),
        lambda _: jnp.exp(scale * (log_activation - jnp.max(log_activation))),
        lambda _: value,
        None,
    )

def simple_power(value: jnp.ndarray, scale: float) -> jnp.ndarray:
    """Returns value raised to the specified power."""
    return value ** scale

# Example usage

In [2]:
# %%

vector = jnp.array([1.0, 2.0, 3.0, 4.0])
scale = 2.0

ps = power_scale(vector, scale)
sp = simple_power(vector, scale)

print("Original:", vector)
print("power_scale:", ps)
print("simple_power:", sp)

Original: [1. 2. 3. 4.]
power_scale: [0.0625 0.25   0.5625 1.    ]
simple_power: [ 1.  4.  9. 16.]
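
The printed `power_scale` row can be checked by hand. For strictly positive entries, the exp and log in the scaled branch cancel, leaving a ratio raised to a power. The symbols $v_i$ and $s$ are notation introduced here for the entries of `value` and for `scale`:

$$
\exp\bigl(s\,(\log v_i - \max_j \log v_j)\bigr) = \left(\frac{v_i}{\max_j v_j}\right)^{s}
$$

With `vector = [1, 2, 3, 4]` and `scale = 2` this gives $(1/4)^2 = 0.0625$, $(2/4)^2 = 0.25$, $(3/4)^2 = 0.5625$, and $(4/4)^2 = 1$, which matches the output above.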


# Advantage over simple power
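Whenever its scaling branch runs (`scale != 1` and the input is not all zeros), `power_scale` maps the maximum entry to exactly 1 and every other entry to `(value / max(value)) ** scale`, a number between 0 and 1. The ratios between entries are the same as those produced by `value ** scale`, but the result can never overflow, whereas `value ** scale` can. Small ratios can still underflow to 0 at large `scale`; the maximum entry is unaffected by this.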
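
A minimal sketch of the overflow case, assuming JAX's default float32 precision. The helper `scaled_ratio` is a hypothetical stand-in for the scaled branch of `power_scale` (without the `lax.cond` guard), and the input values are chosen only for illustration:

```python
import jax.numpy as jnp

def scaled_ratio(value, scale):
    # Hypothetical stand-in for the scaled branch of power_scale:
    # (value / max(value)) ** scale, evaluated in log space.
    log_v = jnp.log(value)
    return jnp.exp(scale * (log_v - jnp.max(log_v)))

big = jnp.array([1e10, 2e10], dtype=jnp.float32)

naive = big ** 5.0               # 1e50 and 3.2e51 exceed the float32 range
stable = scaled_ratio(big, 5.0)  # (1/2)**5 and 1, both representable

print("value ** scale:", naive)
print("scaled ratio:  ", stable)
```

The naive power loses all information about the relative sizes once both entries overflow, while the log-space version still reports that the first entry is 1/32 of the maximum.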

# High-scale behavior and one-hot approximation
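As `scale` grows, `power_scale` drives every entry except the maximum toward zero, so the output approximates a one-hot encoding at the position of the argmax. At `scale=1` the function takes its pass-through branch and returns the input unnormalized, which is why the first row below is not rescaled to a maximum of 1. At `scale=100` the two smallest ratios print as exactly 0 because they underflow in float32.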

In [3]:
# %%

vector = jnp.array([1.0, 2.0, 5.0, 3.0])
for s in [1, 5, 20, 100]:
    ps = power_scale(vector, s)
    print(f"scale={s} ->", ps)

scale=1 -> [1. 2. 5. 3.]
scale=5 -> [3.1999993e-04 1.0239999e-02 1.0000000e+00 7.7759996e-02]
scale=20 -> [1.0485754e-14 1.0995110e-08 1.0000000e+00 3.6561578e-05]
scale=100 -> [0.000000e+00 0.000000e+00 1.000000e+00 6.533167e-23]
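
One way to quantify the approach to one-hot is to measure the largest elementwise gap to an explicit one-hot vector. This is a sketch under stated assumptions: `scaled_ratio` is a hypothetical stand-in for the scaled branch of `power_scale`, and `jax.nn.one_hot` is used only to build the reference vector.

```python
import jax
import jax.numpy as jnp

def scaled_ratio(value, scale):
    # Hypothetical stand-in for the scaled branch of power_scale:
    # (value / max(value)) ** scale, evaluated in log space.
    log_v = jnp.log(value)
    return jnp.exp(scale * (log_v - jnp.max(log_v)))

vector = jnp.array([1.0, 2.0, 5.0, 3.0])
# Reference one-hot vector at the argmax of the input.
one_hot = jax.nn.one_hot(jnp.argmax(vector), vector.shape[0])

gaps = []
for s in [5, 20, 100]:
    # Largest absolute deviation from the one-hot reference.
    gap = float(jnp.max(jnp.abs(scaled_ratio(vector, s) - one_hot)))
    gaps.append(gap)
    print(f"scale={s}: max |power_scale - one_hot| = {gap:.3e}")
```

The gap is set by the runner-up entry, here `(3/5) ** scale`, so it shrinks geometrically as `scale` increases.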


# Mask-then-scale vs. Scale-then-mask
This cell shows how the order of masking (zeroing out already-recalled items)
and `power_scale` changes the final activations, especially at a large `scale`,
where numerical underflow can wipe out every surviving entry.

In [5]:
# %% [markdown]
# # Including the usual normalization step
# After masking and/or scaling we typically renormalize the surviving activations
# so they sum to 1. This makes the ordering difference even clearer.

# %%
import jax.numpy as jnp

def normalize(v):
    s = jnp.sum(v)
    return jnp.where(s == 0, v, v / s)

def mask_then_scale_normalized(vals, recalled_mask, scale):
    masked  = jnp.where(recalled_mask, 0.0, vals)
    scaled  = power_scale(masked, scale)
    return normalize(scaled)

def scale_then_mask_normalized(vals, recalled_mask, scale):
    scaled  = power_scale(vals, scale)
    masked  = jnp.where(recalled_mask, 0.0, scaled)
    return normalize(masked)

# toy data
activations = jnp.array([0.05, 1e8, 3.0, 2.8, 2.5])   # 1e8 is already-recalled
recalled    = jnp.array([False, True, False, False, False])
scale       = 90.0          # large enough that exp() underflows to 0 in float32

mts_norm = mask_then_scale_normalized(activations, recalled, scale)
stm_norm = scale_then_mask_normalized(activations, recalled, scale)

print("mask → scale → norm :", mts_norm)
print("scale → mask → norm :", stm_norm)

mask → scale → norm : [0.0000000e+00 0.0000000e+00 9.9799335e-01 2.0064958e-03 7.4613226e-08]
scale → mask → norm : [0. 0. 0. 0. 0.]
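
The all-zero row can be explained by looking at the arguments passed to `exp()` in each ordering. The sketch below assumes float32 and uses a hypothetical helper, `log_exponents`, that reproduces only the exponent computed inside the scaled branch of `power_scale`:

```python
import jax.numpy as jnp

activations = jnp.array([0.05, 1e8, 3.0, 2.8, 2.5])
recalled = jnp.array([False, True, False, False, False])
scale = 90.0

def log_exponents(vals, scale):
    # Hypothetical helper: the argument of exp() in the scaled branch.
    log_v = jnp.log(vals)
    return scale * (log_v - jnp.max(log_v))

# exp(x) drops below the smallest normal float32 once x < log(tiny), about -87.3.
log_tiny = float(jnp.log(jnp.finfo(jnp.float32).tiny))

# Scale first: the max is the recalled 1e8, so every survivor is far below the edge.
scale_first = log_exponents(activations, scale)
# Mask first: the max is 3.0, so the top survivor has exponent 0.
mask_first = log_exponents(jnp.where(recalled, 0.0, activations), scale)

print("log(tiny):            ", log_tiny)
print("scale-first exponents:", scale_first)
print("mask-first exponents: ", mask_first)
```

When scaling comes first, the recalled item sets the reference maximum, and the best survivor's exponent is roughly `90 * (log 3 - log 1e8)`, about -1559, so every survivor underflows to 0 before the mask is applied and `normalize` has nothing left to rescale. Masking first removes the recalled item from the maximum, which keeps the surviving exponents near 0.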
