<a href="https://colab.research.google.com/github/iRahulPandey/medium-articles/blob/master/Logit_Masking_Tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
"""
Logit Masking Tutorial: Understanding and Controlling LLM Text Generation

This notebook demonstrates:
1. How LLMs generate text (deterministic computation + random sampling)
2. Traditional control methods (temperature, top-k, top-p)
3. Their limitations for guaranteed compliance
4. Logit masking as a solution for hard constraints

Author: Rahul Pandey
Date: 2025
Model: Llama 3.2 3B Instruct
"""

In [None]:
# ============================================================================
# SETUP AND IMPORTS
# ============================================================================

# Install required packages
!pip install --upgrade --quiet transformers torch accelerate bitsandbytes

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m41.4/41.4 kB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.0/12.0 MB[0m [31m100.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.1/60.1 MB[0m [31m14.1 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
# Import packages
from transformers import pipeline, LogitsProcessor
import torch
from google.colab import userdata
import warnings
warnings.filterwarnings('ignore')

In [None]:
# ============================================================================
# MODEL INITIALIZATION
# ============================================================================

def setup_model_pipeline():
    """
    Initialize Llama 3.2 3B Instruct model with automatic device detection.

    Purpose:
        - Detect available hardware (GPU/CPU)
        - Configure optimal settings for available hardware
        - Load model with appropriate quantization

    Returns:
        tuple: (pipeline, device_info)
            - pipeline: Hugging Face text generation pipeline
            - device_info: dict with hardware information

    Expected Output:
        Prints device information and model loading status
        Returns configured pipeline ready for inference

    Notes:
        - GPU: Uses 4-bit quantization for memory efficiency
        - CPU: Uses float32 (slower but functional)
        - Requires HF_TOKEN in Colab secrets for gated models
    """

    MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"

    # Check available hardware
    device_info = {
        'has_cuda': torch.cuda.is_available(),
        'device_count': torch.cuda.device_count() if torch.cuda.is_available() else 0,
        'current_device': torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
    }

    # Display hardware information
    print(f"Device Detection:")
    print(f"   CUDA Available: {device_info['has_cuda']}")
    if device_info['has_cuda']:
        print(f"   GPU: {torch.cuda.get_device_name()}")
        print(f"   GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    else:
        print(f"   Running on CPU (will be slower)")

    # Get HuggingFace authentication token
    try:
        hf_token = userdata.get('HF_TOKEN')
        print("HuggingFace token loaded")
    except:
        print("WARNING: HF_TOKEN not found in secrets")
        print("Add your token in Colab: Secrets -> Add Secret -> Name: HF_TOKEN")
        raise ValueError("HF_TOKEN required for gated models")

    # Base pipeline configuration
    pipeline_config = {
        "task": "text-generation",
        "model": MODEL_ID,
        "token": hf_token,
        "trust_remote_code": True
    }

    # Hardware-specific optimizations
    if device_info['has_cuda']:
        # GPU configuration: 4-bit quantization for efficiency
        pipeline_config.update({
            "torch_dtype": torch.bfloat16,
            "device_map": "auto",
            "model_kwargs": {
                "quantization_config": {
                    "load_in_4bit": True,
                    "bnb_4bit_use_double_quant": True,
                    "bnb_4bit_quant_type": "nf4",
                    "bnb_4bit_compute_dtype": torch.bfloat16
                }
            }
        })
        print("Optimizing for GPU with 4-bit quantization")
    else:
        # CPU configuration: Standard float32
        pipeline_config.update({
            "torch_dtype": torch.float32,
            "device_map": "cpu"
        })
        print("Optimizing for CPU (slower but functional)")

    # Load the model
    print("Loading model... (this may take 2-3 minutes)")
    pipe = pipeline(**pipeline_config)

    print(f"Model loaded successfully!")
    print(f"Model device: {pipe.model.device}")

    return pipe, device_info

def get_model_device():
    """
    Get the device where the model is currently loaded.

    Purpose:
        Dynamically determine device for tensor operations to avoid
        device mismatch errors when moving tensors between CPU/GPU.

    Returns:
        torch.device: Device where model resides (cuda:0 or cpu)

    Notes:
        - Checks multiple attributes for compatibility
        - Handles device_map="auto" configuration
        - Falls back to CUDA check if attributes missing
    """
    if hasattr(pipe.model, 'device'):
        return pipe.model.device
    elif hasattr(pipe.model, 'hf_device_map'):
        # For device_map="auto", get device of first layer
        first_device = next(iter(pipe.model.hf_device_map.values()))
        return torch.device(first_device)
    else:
        # Fallback to CUDA availability check
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize the pipeline (global variable for notebook)
pipe, device_info = setup_model_pipeline()

Device Detection:
   CUDA Available: True
   GPU: Tesla T4
   GPU Memory: 15.8 GB
HuggingFace token loaded
Optimizing for GPU with 4-bit quantization
Loading model... (this may take 2-3 minutes)


config.json:   0%|          | 0.00/878 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json:   0%|          | 0.00/20.9k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/1.46G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/189 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/54.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

Device set to use cuda:0


Model loaded successfully!
Model device: cuda:0


In [None]:
# ============================================================================
# PART 1: UNDERSTANDING DETERMINISM
# ============================================================================

def peek_inside_ai_brain(prompt):
    """
    Reveal the deterministic probability distribution computed by the model.

    Purpose:
        Demonstrate that language models compute the exact same probabilities
        for the same input every time. The "intelligence" is in these fixed
        probabilities, not in random choices.

    Args:
        prompt (str): Input text to analyze

    Returns:
        torch.Tensor: Full probability distribution over vocabulary
                     (shape: [vocab_size])

    Expected Output:
        Prints top 10 most probable next tokens with their probabilities.
        Returns complete probability distribution for further analysis.

    Example:
        >>> peek_inside_ai_brain("This luxury watch is")
        Deterministic probability distribution:
           1. ' a'            : 18.2%
           2. ' the'          : 15.7%
           ...

    Notes:
        - Forward pass is pure function: same input -> same output
        - Probabilities sum to 1.0
        - Distribution does NOT change between runs
    """

    print(f"Analyzing AI thoughts for: '{prompt}'")

    # Get model's device for tensor operations
    model_device = get_model_device()

    # Tokenize input and move to model's device
    inputs = pipe.tokenizer.encode(prompt, return_tensors='pt')
    inputs = inputs.to(model_device)

    # Forward pass through transformer (deterministic computation)
    with torch.no_grad():
        try:
            outputs = pipe.model(inputs)
            # Get logits for the last position (next token prediction)
            logits = outputs.logits[0, -1]
        except Exception as e:
            print(f"Error in forward pass: {e}")
            return None

    # Convert logits to probabilities using softmax (still deterministic)
    probabilities = torch.softmax(logits, dim=-1)

    # Get top 10 most probable tokens
    top_probs, top_indices = torch.topk(probabilities, 10)

    # Display results
    print("Deterministic probability distribution:")
    print("-" * 50)
    for i, (prob, token_id) in enumerate(zip(top_probs, top_indices)):
        token = pipe.tokenizer.decode([token_id])
        confidence = prob.item() * 100
        # Visual bar chart (each █ = 5%)
        bar = "█" * max(1, int(confidence / 5))
        print(f"  {i+1:2d}. '{token:15s}': {confidence:5.1f}% {bar}")

    return probabilities

def test_determinism(prompt="This luxury watch is"):
    """
    Prove that model computations are 100% deterministic.

    Purpose:
        Experimentally verify that the model produces identical probability
        distributions for the same input across multiple runs.

    Args:
        prompt (str): Test input text

    Returns:
        bool: True if calculations are identical, False otherwise

    Expected Output:
        Prints two probability distributions and verifies they are identical
        within numerical precision (1e-6).

    Key Insight:
        If this returns True (it will), it proves that all "intelligence"
        is in the deterministic computation, NOT in random selection.
    """

    print("TESTING AI DETERMINISM")
    print("=" * 60)

    try:
        # First calculation
        print("First calculation:")
        probs1 = peek_inside_ai_brain(prompt)

        if probs1 is None:
            print("Failed to get probabilities")
            return False

        print(f"\n{'='*50}")

        # Second calculation (should be identical)
        print("Second calculation (should be identical):")
        probs2 = peek_inside_ai_brain(prompt)

        if probs2 is None:
            print("Failed to get probabilities")
            return False

        # Verify identical calculations (within floating point precision)
        identical = torch.allclose(probs1, probs2, atol=1e-6)

        print(f"\nRESULT:")
        if identical:
            print(f"   Calculations identical: {identical} ✅")
            print("   This proves AI 'thinking' is 100% deterministic!")
            print("   Randomness only comes from the sampling step")
        else:
            print(f"   Calculations identical: {identical} ❌")
            print("   Unexpected! Check for hardware issues or precision errors")

        return identical

    except Exception as e:
        print(f"Error in determinism test: {e}")
        return False

# Run the determinism test
determinism_verified = test_determinism()

TESTING AI DETERMINISM
First calculation:
Analyzing AI thoughts for: 'This luxury watch is'
Deterministic probability distribution:
--------------------------------------------------
   1. ' a             ':  37.5% ███████
   2. ' designed      ':   9.5% █
   3. ' made          ':   8.4% █
   4. ' crafted       ':   5.4% █
   5. ' inspired      ':   3.7% █
   6. ' part          ':   3.1% █
   7. ' adorned       ':   2.4% █
   8. ' an            ':   2.0% █
   9. ' the           ':   2.0% █
  10. ' not           ':   1.9% █

Second calculation (should be identical):
Analyzing AI thoughts for: 'This luxury watch is'
Deterministic probability distribution:
--------------------------------------------------
   1. ' a             ':  37.5% ███████
   2. ' designed      ':   9.5% █
   3. ' made          ':   8.4% █
   4. ' crafted       ':   5.4% █
   5. ' inspired      ':   3.7% █
   6. ' part          ':   3.1% █
   7. ' adorned       ':   2.4% █
   8. ' an            ':   2.0% █
   9. ' t

In [None]:
# ============================================================================
# PART 2: WHERE RANDOMNESS ENTERS
# ============================================================================

def demonstrate_sampling_randomness(prompt, num_samples=5):
    """
    Show that randomness exists ONLY in the sampling step, not computation.

    Purpose:
        Demonstrate that the same probability distribution can produce
        different outputs through random sampling. This is where variety
        and creativity come from.

    Args:
        prompt (str): Input text to continue
        num_samples (int): Number of different samples to generate

    Expected Output:
        Prints multiple completions that are different despite using
        the exact same probability distribution.

    Key Insight:
        Same input -> Same probabilities (deterministic)
        Same probabilities -> Different samples (random)

    Example Output:
        Sample 1: 'This luxury watch is a testament to precision'
        Sample 2: 'This luxury watch is the epitome of craftsmanship'
        Sample 3: 'This luxury watch is designed for collectors'
    """

    print(f"SAMPLING RANDOMNESS DEMONSTRATION")
    print(f"Prompt: '{prompt}'")
    print("=" * 60)

    # Get model device
    model_device = get_model_device()

    # Tokenize and move to model device
    inputs = pipe.tokenizer.encode(prompt, return_tensors='pt')
    inputs = inputs.to(model_device)

    print("Sampling from the SAME deterministic probability distribution:")
    print("-" * 50)

    for i in range(num_samples):
        try:
            # Same deterministic forward pass every time
            with torch.no_grad():
                output = pipe.model.generate(
                    inputs,
                    max_new_tokens=8,
                    do_sample=True,  # THIS is where randomness enters
                    temperature=1.0,
                    pad_token_id=pipe.tokenizer.eos_token_id,
                    eos_token_id=pipe.tokenizer.eos_token_id
                )

            # Decode only the newly generated tokens
            generated = pipe.tokenizer.decode(
                output[0][len(inputs[0]):],
                skip_special_tokens=True
            )

            print(f"  Sample {i+1}: '{prompt}{generated.strip()}'")

        except Exception as e:
            print(f"  Sample {i+1}: Generation failed: {e}")

    print(f"\nKey Insight:")
    print(f"   • Same input → Same probabilities (deterministic)")
    print(f"   • Same probabilities → Different samples (random)")
    print(f"   • Randomness exists ONLY in the sampling step!")

# Demonstrate sampling randomness
demonstrate_sampling_randomness("This luxury watch is ")

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


SAMPLING RANDOMNESS DEMONSTRATION
Prompt: 'This luxury watch is '
Sampling from the SAME deterministic probability distribution:
--------------------------------------------------
  Sample 1: 'This luxury watch is 40 mm in diameter. It features a'
  Sample 2: 'This luxury watch is 42 mm in diameter, 12.'
  Sample 3: 'This luxury watch is 18k gold and engraved with 3'
  Sample 4: 'This luxury watch is 40mm in size, made of platinum'
  Sample 5: 'This luxury watch is 18K rose gold and 18K'

Key Insight:
   • Same input → Same probabilities (deterministic)
   • Same probabilities → Different samples (random)
   • Randomness exists ONLY in the sampling step!


In [None]:
# ============================================================================
# PART 3: TRADITIONAL CONTROL METHOD 1 - TEMPERATURE
# ============================================================================

def demonstrate_temperature_effects(prompt, temperatures=[0.1, 1.0, 2.0]):
    """
    Show how temperature reshapes the probability distribution.

    Purpose:
        Demonstrate that temperature is a style control that changes the
        "sharpness" of the probability distribution before sampling.

    Args:
        prompt (str): Input text
        temperatures (list): Temperature values to test

    Temperature Effects:
        - T < 1.0: Sharper distribution (more confident/conservative)
        - T = 1.0: Original distribution (balanced)
        - T > 1.0: Flatter distribution (more creative/risky)

    Expected Output:
        Shows how the same logits produce different probability distributions
        when divided by different temperature values.

    Mathematical Formula:
        probs = softmax(logits / temperature)

    Limitation:
        Even at very low temperature, unwanted tokens still have
        non-zero probability!
    """

    print(f"TEMPERATURE EFFECTS DEMONSTRATION")
    print(f"Prompt: '{prompt}'")
    print("=" * 70)

    # Get model device and prepare inputs
    model_device = get_model_device()
    inputs = pipe.tokenizer.encode(prompt, return_tensors='pt').to(model_device)

    try:
        # Get base logits (always identical for same input)
        with torch.no_grad():
            outputs = pipe.model(inputs)
            base_logits = outputs.logits[0, -1]

        for temp in temperatures:
            print(f"\nTemperature {temp:3.1f}:")

            # Apply temperature scaling: divide logits by temperature
            scaled_logits = base_logits / temp
            probs = torch.softmax(scaled_logits, dim=-1)
            top_probs, top_indices = torch.topk(probs, 5)

            # Calculate distribution entropy (measure of uncertainty)
            entropy = -(probs * torch.log(probs + 1e-10)).sum().item()

            print(f"   Distribution sharpness (entropy): {entropy:.2f}")
            print("   Top 5 tokens:")

            for i, (prob, token_id) in enumerate(zip(top_probs, top_indices)):
                token = pipe.tokenizer.decode([token_id])
                percentage = prob.item() * 100
                bar = "█" * max(1, int(percentage / 4))
                print(f"     {token:12s}: {percentage:5.1f}% {bar}")

            # Show concentration of probability mass
            top_5_mass = top_probs.sum().item() * 100
            print(f"   Top 5 probability mass: {top_5_mass:.1f}%")

    except Exception as e:
        print(f"Temperature demonstration failed: {e}")

def temperature_generation_comparison(prompt):
    """
    Compare actual text generation with different temperatures.

    Purpose:
        Show how temperature affects the style and variety of generated text
        in practice.

    Args:
        prompt (str): Input text to continue

    Expected Output:
        Low temp: More repetitive, safe, predictable
        Medium temp: Balanced variety and coherence
        High temp: More diverse, creative, potentially incoherent
    """

    print(f"\nTEMPERATURE GENERATION COMPARISON")
    print(f"Prompt: '{prompt}'")
    print("-" * 50)

    temperatures = [0.1, 1.0, 2.0]

    for temp in temperatures:
        print(f"\nTemperature {temp} outputs:")

        messages = [{"role": "user", "content": f"Continue this text: {prompt}"}]

        for i in range(3):
            try:
                result = pipe(
                    messages,
                    max_new_tokens=12,
                    temperature=temp,
                    do_sample=True,
                    pad_token_id=pipe.tokenizer.eos_token_id
                )[0]['generated_text']
                print(f"   Sample {i+1}: {result}")

            except Exception as e:
                print(f"   Sample {i+1}: Failed: {e}")

# Run temperature demonstrations
demonstrate_temperature_effects("This luxury watch is")
temperature_generation_comparison("This luxury watch is")

TEMPERATURE EFFECTS DEMONSTRATION
Prompt: 'This luxury watch is'

Temperature 0.1:
   Distribution sharpness (entropy): 0.00
   Top 5 tokens:
      a          : 100.0% █████████████████████████
      designed   :   0.0% █
      made       :   0.0% █
      crafted    :   0.0% █
      inspired   :   0.0% █
   Top 5 probability mass: 100.0%

Temperature 1.0:
   Distribution sharpness (entropy): 3.09
   Top 5 tokens:
      a          :  37.5% █████████
      designed   :   9.5% ██
      made       :   8.4% ██
      crafted    :   5.4% █
      inspired   :   3.7% █
   Top 5 probability mass: 64.5%

Temperature 2.0:
   Distribution sharpness (entropy): 9.44
   Top 5 tokens:
      a          :   2.4% █
      designed   :   1.2% █
      made       :   1.1% █
      crafted    :   0.9% █
      inspired   :   0.7% █
   Top 5 probability mass: 6.3%

TEMPERATURE GENERATION COMPARISON
Prompt: 'This luxury watch is'
--------------------------------------------------

Temperature 0.1 outputs:
   Sampl

In [None]:
# ============================================================================
# PART 4: TRADITIONAL CONTROL METHOD 2 - TOP-K
# ============================================================================

def demonstrate_top_k_sampling(prompt, k_values=[5, 20, 100]):
    """
    Show how top-k limits vocabulary choices to k most probable tokens.

    Purpose:
        Demonstrate that top-k sampling reduces vocabulary size by keeping
        only the k most probable tokens and zeroing out all others.

    Args:
        prompt (str): Input text
        k_values (list): Different k values to demonstrate

    How It Works:
        1. Compute probabilities for all tokens
        2. Keep only top-k tokens
        3. Set all other probabilities to 0
        4. Renormalize the k remaining probabilities
        5. Sample only from these k tokens

    Limitation:
        Fixed k doesn't adapt to distribution. Sometimes you want more
        options (flat distribution), sometimes fewer (sharp distribution).
    """

    print(f"\nTOP-K SAMPLING DEMONSTRATION")
    print(f"Prompt: '{prompt}'")
    print("=" * 70)

    # Get model device and prepare inputs
    model_device = get_model_device()
    inputs = pipe.tokenizer.encode(prompt, return_tensors='pt').to(model_device)

    try:
        with torch.no_grad():
            outputs = pipe.model(inputs)
            logits = outputs.logits[0, -1]

        probabilities = torch.softmax(logits, dim=-1)

        for k in k_values:
            print(f"\nTop-{k} sampling:")

            # Get top-k tokens
            top_k_probs, top_k_indices = torch.topk(
                probabilities,
                min(k, len(probabilities))
            )

            # Calculate vocabulary statistics
            total_vocab_size = len(probabilities)
            top_k_mass = top_k_probs.sum().item()

            print(f"   Original vocabulary: {total_vocab_size:,} tokens")
            print(f"   Filtered vocabulary: {k} tokens ({k/total_vocab_size*100:.1f}%)")
            print(f"   Probability mass captured: {top_k_mass:.3f}")

            # Show top 5 from the k-filtered distribution
            print("   Top 5 tokens in filtered vocab:")
            for i in range(min(5, k)):
                token = pipe.tokenizer.decode([top_k_indices[i]])
                prob = top_k_probs[i].item()
                print(f"     {token:12s}: {prob*100:5.1f}%")

    except Exception as e:
        print(f"Top-K demonstration failed: {e}")

def top_k_generation_comparison(prompt):
    """
    Compare generation with different k values.

    Purpose:
        Show how k affects output diversity in practice.

    Args:
        prompt (str): Input text to continue

    Expected Output:
        Small k: Very conservative, limited vocabulary
        Large k: More diverse, larger vocabulary
    """

    print(f"\nTOP-K GENERATION COMPARISON")
    print(f"Prompt: '{prompt}'")
    print("-" * 50)

    k_values = [5, 20, 100]
    messages = [{"role": "user", "content": f"Continue this text: {prompt}"}]

    for k in k_values:
        print(f"\nTop-K = {k}:")

        for i in range(3):
            try:
                result = pipe(
                    messages,
                    max_new_tokens=10,
                    do_sample=True,
                    top_k=k,
                    temperature=0.8,
                    pad_token_id=pipe.tokenizer.eos_token_id
                )[0]['generated_text']
                print(f"   Sample {i+1}: {result}")

            except Exception as e:
                print(f"   Sample {i+1}: Failed: {e}")

# Run top-k demonstrations
demonstrate_top_k_sampling("This luxury watch is")
top_k_generation_comparison("This luxury watch is")


TOP-K SAMPLING DEMONSTRATION
Prompt: 'This luxury watch is'

Top-5 sampling:
   Original vocabulary: 128,256 tokens
   Filtered vocabulary: 5 tokens (0.0%)
   Probability mass captured: 0.645
   Top 5 tokens in filtered vocab:
      a          :  37.5%
      designed   :   9.5%
      made       :   8.4%
      crafted    :   5.4%
      inspired   :   3.7%

Top-20 sampling:
   Original vocabulary: 128,256 tokens
   Filtered vocabulary: 20 tokens (0.0%)
   Probability mass captured: 0.855
   Top 5 tokens in filtered vocab:
      a          :  37.5%
      designed   :   9.5%
      made       :   8.4%
      crafted    :   5.4%
      inspired   :   3.7%

Top-100 sampling:
   Original vocabulary: 128,256 tokens
   Filtered vocabulary: 100 tokens (0.1%)
   Probability mass captured: 0.941
   Top 5 tokens in filtered vocab:
      a          :  37.5%
      designed   :   9.5%
      made       :   8.4%
      crafted    :   5.4%
      inspired   :   3.7%

TOP-K GENERATION COMPARISON
Prompt: 'This

You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset


   Sample 1: [{'role': 'user', 'content': 'Continue this text: This luxury watch is'}, {'role': 'assistant', 'content': '...a masterpiece of horological engineering, featuring a'}]
   Sample 2: [{'role': 'user', 'content': 'Continue this text: This luxury watch is'}, {'role': 'assistant', 'content': '...a masterpiece of horological engineering, boasting an'}]
   Sample 3: [{'role': 'user', 'content': 'Continue this text: This luxury watch is'}, {'role': 'assistant', 'content': '...a masterpiece of horology, featuring a sleek'}]

Top-K = 20:
   Sample 1: [{'role': 'user', 'content': 'Continue this text: This luxury watch is'}, {'role': 'assistant', 'content': 'a masterpiece of horological craftsmanship, combining cutting-edge'}]
   Sample 2: [{'role': 'user', 'content': 'Continue this text: This luxury watch is'}, {'role': 'assistant', 'content': '...a masterpiece of horological engineering, boasting a'}]
   Sample 3: [{'role': 'user', 'content': 'Continue this text: This luxury watch i

In [None]:
# ============================================================================
# PART 5: TRADITIONAL CONTROL METHOD 3 - TOP-P (NUCLEUS)
# ============================================================================

def demonstrate_nucleus_sampling(prompt, p_values=[0.5, 0.8, 0.95]):
    """
    Show how nucleus sampling adapts vocabulary size to distribution shape.

    Purpose:
        Demonstrate that top-p (nucleus sampling) automatically adjusts
        vocabulary size based on how concentrated the probability is.

    Args:
        prompt (str): Input text
        p_values (list): Different p values to demonstrate

    How It Works:
        1. Compute probabilities for all tokens
        2. Sort tokens by probability (descending)
        3. Keep smallest set where cumulative probability >= p
        4. This is the "nucleus" - sample only from these tokens

    Advantage over Top-K:
        Adapts to distribution shape:
        - Sharp distribution (clear winner): Keeps ~5-10 tokens
        - Flat distribution (many options): Keeps ~50-100 tokens

    Limitation:
        Still probabilistic. Can't guarantee specific tokens never appear.
    """

    print(f"\nNUCLEUS SAMPLING DEMONSTRATION")
    print(f"Prompt: '{prompt}'")
    print("=" * 70)

    # Get model device and prepare inputs
    model_device = get_model_device()
    inputs = pipe.tokenizer.encode(prompt, return_tensors='pt').to(model_device)

    try:
        with torch.no_grad():
            outputs = pipe.model(inputs)
            logits = outputs.logits[0, -1]

        probabilities = torch.softmax(logits, dim=-1)
        sorted_probs, sorted_indices = torch.sort(probabilities, descending=True)

        for p in p_values:
            print(f"\nTop-P = {p} (Nucleus sampling):")

            # Find nucleus cutoff: smallest set with cumulative prob >= p
            cumsum = torch.cumsum(sorted_probs, dim=0)
            nucleus_size = (cumsum <= p).sum().item() + 1  # +1 includes cutoff

            nucleus_probs = sorted_probs[:nucleus_size]
            nucleus_indices = sorted_indices[:nucleus_size]

            # Calculate nucleus statistics
            nucleus_mass = nucleus_probs.sum().item()
            total_vocab = len(probabilities)

            print(f"   Nucleus size: {nucleus_size} tokens "
                  f"({nucleus_size/total_vocab*100:.2f}% of vocab)")
            print(f"   Probability mass: {nucleus_mass:.3f}")

            # Characterize distribution
            if nucleus_size < 50:
                char = "Sharp distribution (confident prediction)"
            else:
                char = "Flat distribution (many plausible options)"
            print(f"   Adaptation: {char}")

            # Show top tokens in nucleus
            print("   Top 5 tokens in nucleus:")
            for i in range(min(5, nucleus_size)):
                token = pipe.tokenizer.decode([nucleus_indices[i]])
                prob = nucleus_probs[i].item()
                print(f"     {token:12s}: {prob*100:5.1f}%")

    except Exception as e:
        print(f"Nucleus sampling demonstration failed: {e}")

def compare_top_k_vs_top_p(prompt):
    """
    Direct comparison of top-k vs top-p sampling in generation.

    Purpose:
        Show the practical difference between fixed vocabulary (top-k)
        and adaptive vocabulary (top-p).

    Args:
        prompt (str): Input text to continue
    """

    print(f"\nTOP-K vs TOP-P COMPARISON")
    print(f"Prompt: '{prompt}'")
    print("-" * 60)

    messages = [{"role": "user", "content": f"Continue this text: {prompt}"}]

    # Top-K sampling (fixed vocabulary)
    print("\nTop-K = 20 (fixed vocabulary):")
    for i in range(3):
        try:
            result = pipe(
                messages,
                max_new_tokens=10,
                top_k=20,
                temperature=0.8,
                pad_token_id=pipe.tokenizer.eos_token_id
            )[0]['generated_text']
            print(f"   Sample {i+1}: {result}")
        except Exception as e:
            print(f"   Sample {i+1}: Failed: {e}")

    # Top-P sampling (adaptive vocabulary)
    print("\nTop-P = 0.8 (adaptive vocabulary):")
    for i in range(3):
        try:
            result = pipe(
                messages,
                max_new_tokens=10,
                top_p=0.8,
                temperature=0.8,
                pad_token_id=pipe.tokenizer.eos_token_id
            )[0]['generated_text']
            print(f"   Sample {i+1}: {result}")
        except Exception as e:
            print(f"   Sample {i+1}: Failed: {e}")

# Run nucleus sampling demonstrations
demonstrate_nucleus_sampling("This luxury watch is")
compare_top_k_vs_top_p("This luxury watch is")


NUCLEUS SAMPLING DEMONSTRATION
Prompt: 'This luxury watch is'

Top-P = 0.5 (Nucleus sampling):
   Nucleus size: 3 tokens (0.00% of vocab)
   Probability mass: 0.555
   Adaptation: Sharp distribution (confident prediction)
   Top 5 tokens in nucleus:
      a          :  37.5%
      designed   :   9.5%
      made       :   8.4%

Top-P = 0.8 (Nucleus sampling):
   Nucleus size: 13 tokens (0.01% of vocab)
   Probability mass: 0.805
   Adaptation: Sharp distribution (confident prediction)
   Top 5 tokens in nucleus:
      a          :  37.5%
      designed   :   9.5%
      made       :   8.4%
      crafted    :   5.4%
      inspired   :   3.7%

Top-P = 0.95 (Nucleus sampling):
   Nucleus size: 91 tokens (0.07% of vocab)
   Probability mass: 0.938
   Adaptation: Flat distribution (many plausible options)
   Top 5 tokens in nucleus:
      a          :  37.5%
      designed   :   9.5%
      made       :   8.4%
      crafted    :   5.4%
      inspired   :   3.7%

TOP-K vs TOP-P COMPARISON
Prom

In [None]:
# ============================================================================
# PART 6: THE LIMITATION - WHY TRADITIONAL METHODS FAIL
# ============================================================================

def demonstrate_control_limitations_simple():
    """
    Demonstrate the fundamental limitation of probabilistic controls.

    Purpose:
        Show that temperature/top-k/top-p control HOW we sample, but cannot
        guarantee WHAT gets sampled. If a token has any probability > 0,
        it can eventually be selected.

    Key Insight:
        Over thousands of generations, even 0.1% probability events
        will occur. You can reduce violations, but not eliminate them.

    Expected Output:
        Shows that problematic words have non-zero probability even
        in conservative settings.

    Business Impact:
        For brand compliance, legal requirements, or safety - you need
        guarantees, not just reduced probability.
    """

    print(f"\nTRADITIONAL CONTROL LIMITATIONS")
    print("=" * 70)

    problematic_words = ['cheap', 'affordable', 'budget', 'basic', 'standard']

    print("The core problem with traditional sampling methods:")
    print("-" * 50)

    # Explain the theoretical issue
    print("THEORETICAL ISSUE:")
    print("   • Temperature/Top-K/Top-P control HOW the AI chooses")
    print("   • But they can't control WHAT the AI can choose FROM")
    print("   • If 'cheap' has 5% probability, it might still get selected")
    print("   • Over thousands of generations, violations are inevitable")

    # Demonstrate with actual probabilities
    try:
        print(f"\nPRACTICAL DEMONSTRATION:")

        # Get raw probabilities to show the issue
        test_prompt = "This luxury watch is"
        model_device = get_model_device()
        inputs = pipe.tokenizer.encode(test_prompt, return_tensors='pt').to(model_device)

        with torch.no_grad():
            outputs = pipe.model(inputs)
            logits = outputs.logits[0, -1]

        probabilities = torch.softmax(logits, dim=-1)

        # Check if problematic words have non-zero probability
        problematic_found = False
        print(f"   Checking if problematic words have non-zero probability...")

        for word in problematic_words:
            # Tokenize with leading space (important!)
            word_tokens = pipe.tokenizer.encode(f" {word}", add_special_tokens=False)
            if word_tokens:
                word_token_id = word_tokens[0]
                if word_token_id < len(probabilities):
                    word_prob = probabilities[word_token_id].item()
                    if word_prob > 1e-6:  # Non-negligible probability
                        print(f"   ❌ '{word}' has {word_prob*100:.3f}% probability")
                        problematic_found = True

        if problematic_found:
            print(f"\n   💡 Even with conservative sampling, these words could appear!")
            print(f"   💡 Over 1000 generations, you'd expect multiple violations")
        else:
            print(f"\n   ✅ No problematic words found in top probabilities")
            print(f"   (But this could change with different prompts)")

    except Exception as e:
        print(f"   Demonstration failed: {e}")

    print(f"\nSOLUTION PREVIEW:")
    print(f"   Logit masking will set problematic word probabilities to 0%")
    print(f"   Making it IMPOSSIBLE for the AI to select them")
    print(f"   This guarantees 100% compliance with business rules")

# Demonstrate the limitation
demonstrate_control_limitations_simple()


TRADITIONAL CONTROL LIMITATIONS
The core problem with traditional sampling methods:
--------------------------------------------------
THEORETICAL ISSUE:
   • Temperature/Top-K/Top-P control HOW the AI chooses
   • But they can't control WHAT the AI can choose FROM
   • If 'cheap' has 5% probability, it might still get selected
   • Over thousands of generations, violations are inevitable

PRACTICAL DEMONSTRATION:
   Checking if problematic words have non-zero probability...
   ❌ 'affordable' has 0.003% probability
   ❌ 'budget' has 0.000% probability
   ❌ 'basic' has 0.000% probability
   ❌ 'standard' has 0.001% probability

   💡 Even with conservative sampling, these words could appear!
   💡 Over 1000 generations, you'd expect multiple violations

SOLUTION PREVIEW:
   Logit masking will set problematic word probabilities to 0%
   Making it IMPOSSIBLE for the AI to select them
   This guarantees 100% compliance with business rules


In [None]:
# ============================================================================
# PART 7: THE SOLUTION - LOGIT MASKING
# ============================================================================

class RobustBrandAwareLogitsProcessor(LogitsProcessor):
    """
    Production-ready logits processor for guaranteed vocabulary compliance.

    Purpose:
        Enforce hard constraints on vocabulary by modifying logits before
        sampling. This makes violations literally impossible, not just unlikely.

    How It Works:
        1. Called at each generation step BEFORE sampling
        2. Decodes current beam sequences to readable text
        3. Scores each sequence: +1 for premium words, -1 for budget words
        4. Identifies best-scoring sequence(s)
        5. Heavily penalizes logits of lower-scoring sequences
        6. Result: Only best-scoring paths can continue

    Attributes:
        tokenizer: HuggingFace tokenizer for decoding
        premium_words (set): Words that increase score
        budget_words (set): Words that decrease score
        penalty_strength (float): Penalty magnitude (default: 10000)
        verbose (bool): Enable debug logging
        step_count (int): Track generation steps
        device (torch.device): Automatically detected device

    Use Cases:
        - Brand voice enforcement (premium vs budget language)
        - Legal compliance (required/forbidden terms)
        - Content safety (block inappropriate content)
        - Domain expertise (require technical terminology)
        - Format enforcement (structured output)

    Example:
        >>> processor = RobustBrandAwareLogitsProcessor(
        ...     tokenizer,
        ...     premium_words={'luxury', 'premium'},
        ...     budget_words={'cheap', 'affordable'}
        ... )
        >>> output = model.generate(
        ...     input_ids,
        ...     logits_processor=[processor],
        ...     num_beams=5
        ... )
    """

    def __init__(self, tokenizer, premium_words, budget_words,
                 penalty_strength=10000.0, verbose=False):
        """
        Initialize the logits processor with vocabulary constraints.

        Args:
            tokenizer: HuggingFace tokenizer
            premium_words (set/list): Words that increase compliance score
            budget_words (set/list): Words that decrease compliance score
            penalty_strength (float): Magnitude of penalty for low-scoring sequences
            verbose (bool): Enable detailed logging for debugging
        """
        self.tokenizer = tokenizer
        self.premium_words = set(word.lower() for word in premium_words)
        self.budget_words = set(word.lower() for word in budget_words)
        self.penalty_strength = penalty_strength
        self.verbose = verbose
        self.step_count = 0
        self.device = None  # Will be detected from input tensors

        # Print initialization summary
        print(f"RobustBrandAware Processor initialized:")
        print(f"   ✅ Premium words: {len(self.premium_words)}")
        print(f"   ❌ Budget words: {len(self.budget_words)}")
        print(f"   🔧 Penalty strength: {penalty_strength}")
        print(f"   📊 Verbose mode: {verbose}")

    def _detect_device(self, tensor):
        """
        Dynamically detect and cache device from input tensors.

        Purpose:
            Avoid device mismatch errors by detecting device once
            and reusing it for all tensor operations.

        Args:
            tensor: Any torch tensor from the generation process

        Returns:
            torch.device: Detected device (cuda:0 or cpu)
        """
        if self.device is None:
            self.device = tensor.device
            if self.verbose:
                print(f"   📍 Detected device: {self.device}")
        return self.device

    def evaluate_brand_compliance(self, text):
        """
        Score text based on premium vs budget word usage.

        Purpose:
            Quantify how well text aligns with brand guidelines.

        Args:
            text (str): Generated text to score

        Returns:
            int: Score = (premium words count) - (budget words count)
                 Higher is better

        Examples:
            "luxury premium watch" -> +2 (2 premium, 0 budget)
            "cheap affordable watch" -> -2 (0 premium, 2 budget)
            "luxury affordable watch" -> 0 (1 premium, 1 budget)
        """
        if not text:
            return 0

        text_lower = text.lower()

        # Count premium words
        premium_count = sum(1 for word in self.premium_words
                          if word in text_lower)

        # Count budget words
        budget_count = sum(1 for word in self.budget_words
                         if word in text_lower)

        # Score: premium is good, budget is bad
        return premium_count - budget_count

    def __call__(self, input_ids, scores):
        """
        Main logit masking logic - called at every generation step.

        Purpose:
            Modify logits to enforce vocabulary constraints before sampling.

        Args:
            input_ids: Current sequence IDs for all beams [batch_size, seq_len]
            scores: Current logits for next token [batch_size, vocab_size]

        Returns:
            torch.Tensor: Modified logits with penalties applied

        Process:
            1. Decode all beam sequences to text
            2. Score each sequence for brand compliance
            3. Find best score among all beams
            4. Penalize beams that don't match best score
            5. Return modified logits

        Result:
            Only highest-scoring sequences can continue generation.
            Lower-scoring sequences are effectively eliminated.
        """
        self.step_count += 1

        try:
            # Detect and adapt to device
            device = self._detect_device(scores)

            # Clone scores to avoid modifying original
            output_scores = scores.clone()

            # Evaluate each sequence in the current batch/beams
            sequence_scores = []
            sequence_texts = []

            for seq_idx, sequence in enumerate(input_ids):
                try:
                    # Move sequence to CPU for decoding (tokenizer works on CPU)
                    if sequence.device != torch.device('cpu'):
                        sequence_cpu = sequence.cpu()
                    else:
                        sequence_cpu = sequence

                    # Decode current sequence to readable text
                    decoded_text = self.tokenizer.decode(
                        sequence_cpu,
                        skip_special_tokens=True
                    )
                    sequence_texts.append(decoded_text)

                    # Score this sequence for brand compliance
                    compliance_score = self.evaluate_brand_compliance(decoded_text)
                    sequence_scores.append(compliance_score)

                except Exception as decode_error:
                    # Graceful handling of decode errors
                    if self.verbose:
                        print(f"   ⚠️ Decode error for sequence {seq_idx}: {decode_error}")
                    sequence_scores.append(-999)  # Very low score for errors
                    sequence_texts.append("[DECODE_ERROR]")

            # Find the best compliance score
            if sequence_scores:
                max_score = max(sequence_scores)
            else:
                if self.verbose:
                    print(f"   ⚠️ No valid sequences to evaluate")
                return output_scores

            # Mask out sequences that don't meet the maximum standard
            masked_count = 0
            for seq_idx, score in enumerate(sequence_scores):
                if score < max_score:
                    # Set all next-token logits for this sequence to very negative
                    # This makes these sequences impossible to select
                    penalty = torch.full_like(
                        output_scores[seq_idx],
                        -self.penalty_strength
                    )
                    output_scores[seq_idx] = penalty
                    masked_count += 1

            # Optional verbose logging (first few steps only to avoid spam)
            if self.verbose and self.step_count <= 3:
                print(f"\n📊 Logit Masking Step {self.step_count} (Device: {device}):")
                for i, (text, score) in enumerate(zip(sequence_texts, sequence_scores)):
                    status = "✅ KEPT" if score == max_score else "❌ MASKED"
                    # Show last 50 chars to see current generation progress
                    display_text = text[-50:] if len(text) > 50 else text
                    print(f"   Seq {i+1}: {status} (score: {score:+d}) - '...{display_text}'")

                kept_count = len(sequence_scores) - masked_count
                print(f"   📊 Result: Kept {kept_count}, masked {masked_count} sequences")

            return output_scores

        except Exception as e:
            # Ultimate fallback - return original scores if anything fails
            if self.verbose:
                print(f"❌ LogitsProcessor error: {e}")
            return scores

# ============================================================================
# VOCABULARY DEFINITIONS
# ============================================================================

# Define comprehensive business vocabulary
PREMIUM_WORDS = {
    # Luxury descriptors
    'luxury', 'premium', 'exceptional', 'sophisticated', 'exquisite',
    'masterpiece', 'artisan', 'handcrafted', 'precision', 'heritage',
    'exclusive', 'prestige', 'refined', 'elegant', 'superior',

    # Innovation and quality
    'innovative', 'advanced', 'professional', 'outstanding', 'remarkable',
    'distinguished', 'prestigious', 'world-class', 'finest', 'ultimate'
}

BUDGET_WORDS = {
    # Price-focused terms
    'cheap', 'affordable', 'budget', 'discount', 'value', 'deal', 'bargain',
    'low-cost', 'inexpensive', 'economical', 'reasonable',

    # Quality diminishers
    'basic', 'standard', 'ordinary', 'common', 'typical', 'average',
    'generic', 'simple', 'plain', 'entry-level', 'mass-produced'
}

# ============================================================================
# INITIALIZE BRAND GUARDIAN
# ============================================================================

print("\n🛡️ INITIALIZING ROBUST BRAND GUARDIAN")
print("=" * 60)

brand_guardian = RobustBrandAwareLogitsProcessor(
    pipe.tokenizer,
    PREMIUM_WORDS,
    BUDGET_WORDS,
    penalty_strength=8000.0,  # Slightly lower for stability
    verbose=True  # Enable detailed logging for first 3 steps
)

# ============================================================================
# HELPER FUNCTIONS FOR CLEAN DEMONSTRATION
# ============================================================================

def extract_assistant_content_only(pipeline_output):
    """
    Extract only the assistant's response from pipeline output.

    Purpose:
        Clean up pipeline output to get just the generated text without
        conversation formatting or metadata.

    Args:
        pipeline_output: Raw output from pipeline

    Returns:
        str: Clean generated text
    """
    try:
        if isinstance(pipeline_output, list) and len(pipeline_output) > 0:
            first_item = pipeline_output[0]
            if isinstance(first_item, dict) and 'generated_text' in first_item:
                conversation = first_item['generated_text']

                # If conversation is a list of messages, extract assistant's content
                if isinstance(conversation, list):
                    for message in conversation:
                        if isinstance(message, dict) and message.get('role') == 'assistant':
                            return message.get('content', '').strip()

                # Fallback to string conversion
                elif isinstance(conversation, str):
                    return conversation.strip()

        return str(pipeline_output).strip()

    except Exception as e:
        return str(pipeline_output).strip()

def analyze_word_usage(text, premium_words, budget_words):
    """
    Analyze premium and budget word usage in generated text.

    Purpose:
        Quantify vocabulary compliance for evaluation and reporting.

    Args:
        text (str): Generated text to analyze
        premium_words (set): Premium vocabulary
        budget_words (set): Budget vocabulary

    Returns:
        dict: Analysis with keys:
            - premium_words: List of premium words found
            - budget_words: List of budget words found
            - score: Net score (premium - budget)
            - total_premium: Count of premium words
            - total_budget: Count of budget words
    """
    text_lower = text.lower()

    # Find premium words used
    premium_found = []
    for word in premium_words:
        if word in text_lower:
            count = text_lower.count(word)
            premium_found.append(f"{word} ({count}x)" if count > 1 else word)

    # Find budget words used
    budget_found = []
    for word in budget_words:
        if word in text_lower:
            count = text_lower.count(word)
            budget_found.append(f"{word} ({count}x)" if count > 1 else word)

    # Calculate score
    score = len([w for w in premium_words if w in text_lower]) - \
            len([w for w in budget_words if w in text_lower])

    return {
        'premium_words': premium_found,
        'budget_words': budget_found,
        'score': score,
        'total_premium': len(premium_found),
        'total_budget': len(budget_found)
    }


🛡️ INITIALIZING ROBUST BRAND GUARDIAN
RobustBrandAware Processor initialized:
   ✅ Premium words: 25
   ❌ Budget words: 22
   🔧 Penalty strength: 8000.0
   📊 Verbose mode: True


In [None]:
# ============================================================================
# PART 8: DEMONSTRATION - BEFORE AND AFTER
# ============================================================================

def clean_logit_masking_demonstration(product_name):
    """
    Clean demonstration of logit masking effectiveness.

    Purpose:
        Show clear before/after comparison of generation with and without
        logit masking constraints.

    Args:
        product_name (str): Product to generate description for

    Returns:
        dict: Results with baseline and controlled analysis

    Expected Output:
        Baseline: May contain budget words, lower brand score
        Controlled: Only premium words, higher brand score

    Key Metrics:
        - Brand Score: (premium words) - (budget words)
        - Premium Word Count
        - Budget Word Count
        - Improvement: Change in score
    """

    print(f"\n" + "="*80)
    print(f"LOGIT MASKING DEMONSTRATION: {product_name.upper()}")
    print("="*80)

    # Create professional prompt
    messages = [
        {
            "role": "system",
            "content": "You are a professional product marketing expert. Write compelling product descriptions."
        },
        {
            "role": "user",
            "content": f"Write a professional product description for this {product_name}. Focus on premium quality and craftsmanship."
        }
    ]

    # Standard generation parameters
    gen_params = {
        "max_new_tokens": 80,
        "do_sample": True,
        "temperature": 0.8,
        "num_beams": 5,
        "early_stopping": True,
        "pad_token_id": pipe.tokenizer.eos_token_id
    }

    # ========================================
    # BASELINE: Generate WITHOUT control
    # ========================================
    print(f"\n📝 BASELINE GENERATION (No Control)")
    print("-" * 50)

    try:
        baseline_output = pipe(messages, **gen_params)
        baseline_text = extract_assistant_content_only(baseline_output)
        baseline_analysis = analyze_word_usage(baseline_text, PREMIUM_WORDS, BUDGET_WORDS)

        print(f"Generated Text:")
        print(f'"{baseline_text}"')
        print(f"\nWord Analysis:")
        print(f"  ✅ Premium words: {', '.join(baseline_analysis['premium_words']) if baseline_analysis['premium_words'] else 'None'}")
        print(f"  ❌ Budget words:  {', '.join(baseline_analysis['budget_words']) if baseline_analysis['budget_words'] else 'None'}")
        print(f"  📊 Brand Score: {baseline_analysis['score']:+d} "
              f"({baseline_analysis['total_premium']} premium - "
              f"{baseline_analysis['total_budget']} budget)")

    except Exception as e:
        print(f"❌ Baseline generation failed: {e}")
        return None

    # ========================================
    # CONTROLLED: Generate WITH logit masking
    # ========================================
    print(f"\n🎛️ CONTROLLED GENERATION (With Logit Masking)")
    print("-" * 50)

    try:
        # Create processor (non-verbose for clean output)
        processor = RobustBrandAwareLogitsProcessor(
            pipe.tokenizer,
            PREMIUM_WORDS,
            BUDGET_WORDS,
            verbose=False  # Disable logging for clean demo
        )

        # Add processor to generation parameters
        controlled_params = gen_params.copy()
        controlled_params["logits_processor"] = [processor]

        controlled_output = pipe(messages, **controlled_params)
        controlled_text = extract_assistant_content_only(controlled_output)
        controlled_analysis = analyze_word_usage(controlled_text, PREMIUM_WORDS, BUDGET_WORDS)

        print(f"Generated Text:")
        print(f'"{controlled_text}"')
        print(f"\nWord Analysis:")
        print(f"  ✅ Premium words: {', '.join(controlled_analysis['premium_words']) if controlled_analysis['premium_words'] else 'None'}")
        print(f"  ❌ Budget words:  {', '.join(controlled_analysis['budget_words']) if controlled_analysis['budget_words'] else 'None'}")
        print(f"  📊 Brand Score: {controlled_analysis['score']:+d} "
              f"({controlled_analysis['total_premium']} premium - "
              f"{controlled_analysis['total_budget']} budget)")

    except Exception as e:
        print(f"❌ Controlled generation failed: {e}")
        return None

    # ========================================
    # IMPACT ANALYSIS
    # ========================================
    improvement = controlled_analysis['score'] - baseline_analysis['score']
    premium_increase = controlled_analysis['total_premium'] - baseline_analysis['total_premium']
    budget_decrease = baseline_analysis['total_budget'] - controlled_analysis['total_budget']

    print(f"\n📈 IMPACT ANALYSIS")
    print("-" * 50)
    print(f"Brand Score Change:    {baseline_analysis['score']:+d} → "
          f"{controlled_analysis['score']:+d} ({improvement:+d} points)")
    print(f"Premium Words:         {baseline_analysis['total_premium']} → "
          f"{controlled_analysis['total_premium']} ({premium_increase:+d} words)")
    print(f"Budget Words:          {baseline_analysis['total_budget']} → "
          f"{controlled_analysis['total_budget']} ({budget_decrease:+d} words)")

    if improvement > 0:
        improvement_pct = ((controlled_analysis['score'] - baseline_analysis['score']) /
                          max(1, abs(baseline_analysis['score']))) * 100
        print(f"Overall Improvement:   {improvement_pct:+.0f}% more brand-compliant")

    # Key insights
    print(f"\n🔍 KEY INSIGHTS")
    print("-" * 50)

    # What was promoted
    new_premium_words = set(controlled_analysis['premium_words']) - set(baseline_analysis['premium_words'])
    if new_premium_words:
        print(f"✅ Promoted words: {', '.join(new_premium_words)}")

    # What was suppressed
    suppressed_budget_words = set(baseline_analysis['budget_words']) - set(controlled_analysis['budget_words'])
    if suppressed_budget_words:
        print(f"❌ Suppressed words: {', '.join(suppressed_budget_words)}")

    if not new_premium_words and not suppressed_budget_words and improvement > 0:
        print(f"🎯 Enhanced premium word density without suppressing specific terms")
    elif improvement == 0:
        print(f"📊 Baseline was already well-optimized for this product")

    return {
        'baseline': baseline_analysis,
        'controlled': controlled_analysis,
        'improvement': improvement,
        'baseline_text': baseline_text,
        'controlled_text': controlled_text
    }

# ============================================================================
# RUN COMPREHENSIVE DEMONSTRATIONS
# ============================================================================

def run_comprehensive_clean_demo():
    """
    Run multiple product demonstrations to show consistency.

    Purpose:
        Demonstrate that logit masking works consistently across
        different product categories and contexts.

    Returns:
        list: Results from all test cases

    Expected Output:
        Consistent improvement in brand compliance scores across
        all product categories.
    """

    print(f"\n" + "="*80)
    print(f"COMPREHENSIVE LOGIT MASKING ANALYSIS")
    print("="*80)
    print(f"Testing controlled text generation across multiple luxury product categories")

    test_products = [
        "Swiss luxury watch with leather strap",
        "Italian leather handbag with gold accents",
        "premium wireless headphones",
        "artisan fountain pen with gold nib"
    ]

    results = []
    total_improvement = 0

    for i, product in enumerate(test_products, 1):
        print(f"\n{'='*20} TEST {i}/{len(test_products)} {'='*20}")

        result = clean_logit_masking_demonstration(product)

        if result:
            results.append({
                'product': product,
                'improvement': result['improvement'],
                'baseline_score': result['baseline']['score'],
                'controlled_score': result['controlled']['score']
            })
            total_improvement += result['improvement']

        print(f"{'='*60}")

    # Summary analysis
    if results:
        avg_improvement = total_improvement / len(results)
        successful_tests = len(results)

        print(f"\n📊 EXECUTIVE SUMMARY")
        print("="*60)
        print(f"Tests Completed:       {successful_tests}/{len(test_products)}")
        print(f"Average Improvement:   {avg_improvement:+.1f} points")
        print(f"Total Improvement:     {total_improvement:+d} points")

        # Performance by product
        print(f"\nPerformance by Product:")
        for result in results:
            improvement_indicator = "🚀" if result['improvement'] > 0 else "➖"
            print(f"  {improvement_indicator} {result['product']:<45} "
                  f"{result['baseline_score']:+2d} → "
                  f"{result['controlled_score']:+2d} "
                  f"({result['improvement']:+d})")

        # Key findings
        best_improvement = max(results, key=lambda x: x['improvement'])
        print(f"\n🏆 Best Performance: {best_improvement['product']} "
              f"({best_improvement['improvement']:+d} points)")

        print(f"\n✅ SUCCESS: Logit masking consistently improves brand compliance")
        print(f"   Average improvement of {avg_improvement:.1f} points across all categories")

    else:
        print(f"\n❌ No successful demonstrations completed")

    return results

# Run the comprehensive demonstration
print("\n🎯 RUNNING LOGIT MASKING DEMONSTRATIONS")
print("="*60)

comprehensive_results = run_comprehensive_clean_demo()


🎯 RUNNING LOGIT MASKING DEMONSTRATIONS

COMPREHENSIVE LOGIT MASKING ANALYSIS
Testing controlled text generation across multiple luxury product categories


LOGIT MASKING DEMONSTRATION: SWISS LUXURY WATCH WITH LEATHER STRAP

📝 BASELINE GENERATION (No Control)
--------------------------------------------------
Generated Text:
"**Introducing the Timeless Masterpiece: Swiss Luxury Watch with Leather Strap**

Elevate your style and sophistication with our exquisite Swiss luxury watch, expertly crafted to meet the highest standards of precision and elegance. This masterpiece is a testament to the art of watchmaking, where every detail is meticulously considered to create a truly exceptional timepiece.

**Crafted with Precision and Passion**

Our Swiss luxury"

Word Analysis:
  ✅ Premium words: exquisite, precision (2x), masterpiece (2x), luxury (3x), exceptional
  ❌ Budget words:  standard
  📊 Brand Score: +4 (5 premium - 1 budget)

🎛️ CONTROLLED GENERATION (With Logit Masking)
------------