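# Softmax is not Enough

Experiments with the adaptive-temperature softmax from the paper *Softmax is not Enough (for Sharp Out-of-distribution)*, swapped into the attention layers of a Gemma 3 model.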

In [2]:
import transformers
import torch
from torch import nn

from typing import Optional

# Hugging Face Hub checkpoint loaded by the model/tokenizer cell below.
MODEL_NAME = "google/gemma-3-1b-pt"
# Alternative checkpoint kept for quick switching (currently unused):
# MODEL_NAME = "google/gemma-3n-e2b"

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
def get_polynomial_value(x, c):
    """Evaluate a polynomial at ``x`` using Horner's method.

    Args:
        x: evaluation point; a Python number or a tensor (evaluated elementwise).
        c: coefficients ordered from highest degree to the constant term, e.g.
            ``[a, b, d]`` means ``a*x**2 + b*x + d``. May be a list or a 1-D tensor.

    Returns:
        The polynomial value (same kind as the arithmetic of ``x`` and ``c``
        produces). An empty coefficient sequence is the zero polynomial and
        returns ``0``.
    """
    if len(c) == 0:
        return 0
    result = c[0]
    for coeff in c[1:]:
        result = result * x + coeff
    return result

# Coefficients (highest degree first) of the polynomial mapping entropy -> inverse temperature beta.
POLY_FIT = torch.tensor([-0.037, 0.481, -2.3, 4.917, -1.791]) # see Figure 6

def softmax_adaptive_temperature(logits, dim, poly_fit=POLY_FIT, dtype=torch.float32, entropy_threshold=0.5):
    """
    Adaptive-temperature softmax from "Softmax is not Enough" Figure 4, adapted to PyTorch.

    A plain softmax is computed first and its Shannon entropy is measured along
    ``dim``. Where that entropy exceeds ``entropy_threshold``, the logits are
    rescaled by an inverse temperature ``beta = max(poly(entropy), 1)`` (``poly``
    given by ``poly_fit``) and the softmax is recomputed; elsewhere ``beta = 1``.
    Because ``beta >= 1``, the result is never flatter than the plain softmax.

    Args:
        logits: tensor of unnormalized scores.
        dim: dimension to normalize over; the entropy is computed along it too.
        poly_fit: polynomial coefficients, highest degree first, mapping entropy to beta.
        dtype: dtype the computation runs in (as for ``torch.softmax``); ``None``
            keeps the dtype of ``logits``.
        entropy_threshold: rows whose entropy is at or below this keep ``beta = 1``.

    Returns:
        Probabilities with the same shape as ``logits``.
    """
    if dtype is not None:
        logits = logits.to(dtype)

    original_probs = torch.softmax(logits, dim=dim)
    # Shannon entropy, reduced over `dim` (not a hard-coded -1) so it matches the normalized axis.
    entropy = torch.sum(-original_probs * torch.log(original_probs + 1e-9), dim=dim, keepdim=True)

    poly_beta = torch.clamp(get_polynomial_value(entropy, poly_fit), min=1.0)  # never increase entropy
    beta = torch.where(
        entropy > entropy_threshold,  # don't overcorrect low-entropy heads
        poly_beta,
        torch.ones_like(entropy),
    )

    # softmax is shift-invariant, so subtracting the per-row max leaves ordinary rows
    # unchanged. It matters for additive masks that use a large finite negative value
    # (e.g. torch.finfo(dtype).min) instead of -inf: in a fully masked row every entry
    # equals that value, and scaling by beta > 1 would overflow all of them to -inf and
    # produce NaN. After the shift such a row is all zeros, giving the same uniform
    # output as the plain softmax.
    shifted = logits - logits.amax(dim=dim, keepdim=True)
    return torch.softmax(shifted * beta, dim=dim, dtype=dtype)
    
class SoftmaxZero(nn.Module):
    """Degenerate stand-in for a softmax that returns all-zero weights.

    For testing only: ``forward`` ignores the values of its input and returns a
    zero tensor with the same shape, dtype and device.
    """

    def __init__(self):
        super().__init__()

    def forward(self, logits):
        zeros = torch.zeros_like(logits)
        return zeros
    
def adaptive_temperature_eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    softcap: Optional[float] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Gemma 3 style eager attention with the softmax replaced by ``softmax_adaptive_temperature``.

    Args:
        module: the attention layer; read for ``head_dim``, ``num_key_value_groups``
            and ``training``.
        query, key, value: projected attention states. key/value are expanded with
            Gemma 3's ``repeat_kv`` using ``module.num_key_value_groups``.
        attention_mask: optional additive mask; its last dimension is sliced to the
            key length before it is added to the scores.
        dropout: dropout probability on the attention weights (active only when
            ``module.training`` is true).
        scaling: multiplier for the raw scores; defaults to ``module.head_dim ** -0.5``.
        softcap: if given, scores become ``softcap * tanh(scores / softcap)``.
        **kwargs: accepted for interface compatibility and ignored.

    Returns:
        ``(attn_output, attn_weights)``; ``attn_output`` has dims 1 and 2 swapped
        relative to ``attn_weights @ value_states`` and is made contiguous.
    """
    if scaling is None:
        scaling = module.head_dim**-0.5

    # Borrow Gemma 3's KV-head expansion helper rather than reimplementing it.
    repeat_kv = transformers.models.gemma3.modeling_gemma3.repeat_kv
    groups = module.num_key_value_groups
    key_states = repeat_kv(key, groups)
    value_states = repeat_kv(value, groups)

    scores = torch.matmul(query, key_states.transpose(2, 3)) * scaling

    if softcap is not None:
        scores = torch.tanh(scores / softcap) * softcap

    if attention_mask is not None:
        # The mask can be longer than the current keys; keep only the matching columns.
        scores = scores + attention_mask[:, :, :, : key_states.shape[-2]]

    # Normalize in fp32, then cast back to the activation dtype.
    probs = softmax_adaptive_temperature(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    attn_output = torch.matmul(probs, value_states).transpose(1, 2).contiguous()
    return attn_output, probs

# Register the custom attention so it can be selected with
# `attn_implementation="adaptive_temperature_eager"` in `from_pretrained`.
# `AttentionInterface.register` is the public transformers API for this; it updates the
# same global registry exposed as `transformers.modeling_utils.ALL_ATTENTION_FUNCTIONS`,
# without reaching into that internal module path.
transformers.AttentionInterface.register("adaptive_temperature_eager", adaptive_temperature_eager_attention_forward)


In [4]:
# Load the tokenizer and the model; the model uses the custom attention registered above by name.
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
model = transformers.Gemma3ForCausalLM.from_pretrained(
    MODEL_NAME,
    attn_implementation="adaptive_temperature_eager",
)

In [5]:
# Seed all RNGs so the completion is reproducible across runs if the checkpoint's
# generation config enables sampling.
GENERATION_SEED = 0
transformers.set_seed(GENERATION_SEED)

pipe = transformers.pipeline("text-generation", model=model, 
                tokenizer=tokenizer, device="cpu")
output = pipe("The Ohio State University official mascot is ", max_new_tokens=50)
print(output[0]['generated_text'])

Device set to use cpu


The Ohio State University official mascot is <b>"The Ohio State Buckeye"</b>, the school's blue mascot. The school's mascot was first used in 1903 when the team won the National Championship. The mascot was designed by George H. Stover (1


# CLRS Dataset Evaluation

Next, we evaluate both variants of the model — standard softmax and adaptive-temperature softmax — on algorithmic tasks from the CLRS dataset, to see whether adaptive temperature improves reasoning performance.

In [14]:
# Import the CLRS text utilities used to generate algorithmic reasoning tasks.
# Assumes the CLRS repository is checked out as ./clrs-repo under the working directory.
import sys
import os
from datasets import IterableDataset, Value, Features

CLRS_REPO_DIR = os.path.join(os.getcwd(), 'clrs-repo')
# Guard so that re-running this cell does not append duplicate entries to sys.path.
if CLRS_REPO_DIR not in sys.path:
    sys.path.append(CLRS_REPO_DIR)

from clrs._src.clrs_text.huggingface_generators import clrs_generator


# Build an IterableDataset backed by `clrs_generator` with an explicit schema;
# examples are produced when the dataset is iterated.
# NOTE(review): no `gen_kwargs` are passed, so `clrs_generator` is called without
# arguments. If it requires `algos_and_lengths` (as the commented-out line suggests),
# iterating `ds` will fail. TODO: confirm and define `algos_and_lengths`.
ds = IterableDataset.from_generator(
        clrs_generator,
        features=Features(
            {
                "text": Value(dtype="string", id=None),
                "question": Value(dtype="string", id=None),
                "answer": Value(dtype="string", id=None),
                "algo_name": Value(dtype="string", id=None),
                "length": Value(dtype="int32", id=None),
                "use_hints": Value(dtype="bool_", id=None),
            }
        ),
        # gen_kwargs={"algos_and_lengths": algos_and_lengths},
    )