In [None]:
# Prerequisites: a GPU runtime (T4 or better), a Hugging Face access token, and accepted Gemma 3 model terms on Hugging Face.
# Run the cells in order to install the requirements and load the model, then edit the starting sentence in the box to explore the next-token probabilities.
# "Peeking Inside the Black Box": an interactive educational tool that visualizes the decision-making process of the Gemma 3 1B model,
# allowing users to write stories by navigating the model's probability tree one branch at a time.

In [None]:
# @title 1. Install Libraries and Login
!pip install -q transformers torch accelerate bitsandbytes

from huggingface_hub import login
from IPython.display import clear_output

# Login to Hugging Face (Required for Gemma models)
print("Please paste your Hugging Face token when prompted:")
login()
print("Libraries installed and logged in successfully!")

In [None]:
# @title 2. Load Gemma 3 1B Model
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# We use Gemma 3 1B (instruction-tuned); it fits easily in the Colab free tier.
model_id = "google/gemma-3-1b-it"

# Gemma 3 was trained in bfloat16. float16 has a much smaller range and risks
# inf/NaN activations, which would turn the next-token probabilities into NaN.
# Use native bfloat16 on GPUs that support it (compute capability >= 8.0) and
# float32 otherwise (e.g. T4 or CPU); the 1B model is roughly 4 GB in float32.
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    model_dtype = torch.bfloat16
else:
    model_dtype = torch.float32

print("Loading model... this may take a minute...")
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    dtype=model_dtype,
)

print("Model loaded!")

In [None]:
# @title Interactive Probability Explorer
import torch.nn.functional as F
import json
from google.colab import output
from IPython.display import display, HTML

# --- 1. Python Backend (The Brain) ---
def _make_token_visible(token_text):
    """Return a menu-friendly rendering of a decoded token.

    Carriage returns, line breaks and tabs are shown as escape sequences and
    surrounding whitespace is trimmed. A token made only of whitespace is shown
    as one '\u2423' per character instead of collapsing to an empty string.
    """
    visible = (
        token_text.replace("\r", "\\r")
        .replace("\n", "\\n")
        .replace("\t", "\\t")
        .strip()
    )
    if not visible and token_text:
        visible = "\u2423" * len(token_text)
    return visible


def get_next_token_probs_material(current_text, top_k=5):
    """Return the most likely next tokens after ``current_text`` as a JSON string.

    Runs one forward pass of the global ``model`` (tokenized with the global
    ``tokenizer``) and reports the ``top_k`` highest-probability next tokens.

    Args:
        current_text: The text typed/built so far in the explorer.
        top_k: Number of candidates to return (capped at the vocabulary size).
            Defaults to 5, which is what the frontend shows.

    Returns:
        A JSON array string. Each element has ``raw_text`` (the decoded token,
        appended verbatim when the user picks it), ``display_text`` (a visible
        rendering of it for the menu) and ``prob`` (probability in percent,
        rounded to 2 decimals), ordered from most to least likely.
    """
    inputs = tokenizer(current_text, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Logits at the last input position predict the next token. Upcast to
    # float32 so a half-precision model still yields precise probabilities.
    next_token_logits = outputs.logits[0, -1, :].float()
    probs = F.softmax(next_token_logits, dim=-1)

    k = min(top_k, probs.shape[-1])
    top_probs, top_indices = torch.topk(probs, k)

    candidates = []
    for probability, token_id in zip(top_probs.tolist(), top_indices.tolist()):
        token_text = tokenizer.decode([token_id])
        candidates.append({
            "raw_text": token_text,
            "display_text": _make_token_visible(token_text),
            "prob": round(probability * 100, 2),
        })

    # ensure_ascii=False keeps characters such as 'è' literal instead of '\u00e8'.
    return json.dumps(candidates, ensure_ascii=False)

# Expose the backend to the page's JavaScript, which calls it through
# google.colab.kernel.invokeFunction('get_next_token_probs_material', ...).
output.register_callback(
    'get_next_token_probs_material',
    get_next_token_probs_material,
)

# --- 2. The Frontend (Material Design HTML/CSS/JS) ---
# Self-contained page for the explorer. It is a raw string so the backslashes
# used by the JavaScript (regex and escape table) reach the browser unchanged.
html_code = r"""
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<link href="https://fonts.googleapis.com/css2?family=Roboto:wght@300;400;500;700&display=swap" rel="stylesheet">
<style>
    :root {
        --google-blue: #4285F4;
        --google-grey-bg: #F8F9FA;
        --text-primary: #202124;
        --text-secondary: #5f6368;
        --card-shadow: 0 1px 3px rgba(0,0,0,0.12), 0 1px 2px rgba(0,0,0,0.24);
    }

    body {
        font-family: 'Roboto', sans-serif;
        background-color: #fff;
        margin: 0;
        padding: 20px;
        color: var(--text-primary);
    }

    .app-card {
        max-width: 800px;
        margin: 0 auto;
        background: white;
        border-radius: 12px;
        box-shadow: var(--card-shadow);
        padding: 40px;
        min-height: 300px;
        display: flex;
        flex-direction: column;
    }

    .header {
        font-size: 14px;
        text-transform: uppercase;
        letter-spacing: 1px;
        color: var(--text-secondary);
        margin-bottom: 30px;
        font-weight: 700;
        border-bottom: 1px solid #eee;
        padding-bottom: 15px;
    }

    /* --- INPUT SECTION STYLES --- */
    #input-section {
        display: flex;
        flex-direction: column;
        gap: 15px;
        align-items: flex-start;
    }

    .input-label {
        font-size: 18px;
        color: var(--text-primary);
    }

    .custom-input {
        width: 100%;
        padding: 15px;
        font-size: 20px;
        border: 2px solid #eee;
        border-radius: 8px;
        font-family: 'Roboto', sans-serif;
        outline: none;
        transition: border-color 0.2s;
        box-sizing: border-box;
    }

    .custom-input:focus {
        border-color: var(--google-blue);
    }

    .start-btn {
        background-color: var(--google-blue);
        color: white;
        border: none;
        padding: 10px 24px;
        font-size: 16px;
        border-radius: 4px;
        cursor: pointer;
        font-weight: 500;
        box-shadow: 0 1px 3px rgba(0,0,0,0.2);
        transition: background 0.2s;
    }
    .start-btn:hover {
        background-color: #3367d6;
    }

    /* --- DISPLAY SECTION STYLES --- */
    #display-section {
        display: none; /* Hidden initially */
    }

    #sentence-container {
        font-size: 28px;
        line-height: 1.5;
        font-weight: 300;
        color: var(--text-primary);
    }

    .word-token {
        transition: background 0.2s;
        border-radius: 4px;
        padding: 0 2px;
    }
    .word-token:hover { background-color: #e8f0fe; }

    /* Chip & Dropdown */
    .next-token-chip {
        display: inline-flex;
        align-items: center;
        background-color: var(--google-blue);
        color: white;
        padding: 6px 16px;
        border-radius: 24px;
        font-size: 18px;
        font-weight: 500;
        cursor: pointer;
        margin-left: 8px;
        position: relative;
        top: -3px;
        user-select: none;
    }
    .next-token-chip:hover { background-color: #3367d6; }

    .dropdown-menu {
        display: none;
        position: absolute;
        background-color: white;
        min-width: 320px;
        border-radius: 8px;
        box-shadow: 0 5px 20px rgba(0,0,0,0.2);
        z-index: 100;
        margin-top: 12px;
    }

    .option {
        padding: 12px 16px;
        border-bottom: 1px solid #f1f3f4;
        cursor: pointer;
    }
    .option:hover { background-color: #f8f9fa; }

    .option-row { display: flex; justify-content: space-between; margin-bottom: 6px; }
    .token-text { font-size: 16px; font-weight: 500; }
    .token-prob { font-size: 12px; color: var(--text-secondary); background: #f1f3f4; padding: 2px 6px; border-radius: 4px; }
    .progress-bg { height: 6px; background: #e0e0e0; border-radius: 3px; width: 100%; }
    .progress-fill { height: 100%; background: var(--google-blue); border-radius: 3px; }

</style>
</head>
<body>

<div class="app-card">
    <div class="header">Gemma 3 Probability Explorer</div>

    <div id="input-section">
        <label class="input-label" for="start-input">How should the sentence start?</label>
        <input type="text" id="start-input" class="custom-input" placeholder="Type here..." value="The future of AI is">
        <button class="start-btn" onclick="initializeExplorer()">Start Exploring</button>
    </div>

    <div id="display-section">
        <span id="sentence-container"></span>

        <div style="position:relative; display:inline-block;">
            <div id="chip" class="next-token-chip" onclick="toggleMenu()">
                Predict ▼
            </div>
            <div id="menu" class="dropdown-menu">
                <div style="padding:12px; background:#f1f3f4; font-size:12px; color:#555;">TOP PREDICTIONS</div>
                <div id="menu-content"></div>
            </div>
        </div>

        <div style="margin-top: 40px; border-top: 1px solid #eee; padding-top: 20px;">
             <button class="start-btn" style="background:#888; font-size:14px; padding: 6px 12px;" onclick="resetExplorer()">Reset</button>
        </div>
    </div>
</div>

<script>
    // Text of the story so far; sent to Python on every prediction request.
    let currentSentence = "";
    let isMenuOpen = false;
    // Bumped whenever the sentence changes or the explorer is reset, so that a
    // response computed for an outdated sentence can be recognised and dropped.
    let session = 0;
    // The in-flight prediction request ({session, promise}), if any.
    let activeRequest = null;

    // --- State Management ---
    function initializeExplorer() {
        const inputVal = document.getElementById('start-input').value;
        if (!inputVal.trim()) return;
        session += 1;
        currentSentence = inputVal;

        document.getElementById('input-section').style.display = 'none';
        document.getElementById('display-section').style.display = 'block';
        document.getElementById('sentence-container').innerText = currentSentence;
        document.getElementById('menu-content').innerHTML = '';

        fetchPredictions();
    }

    function resetExplorer() {
        session += 1;  // discard any response that is still on its way
        currentSentence = "";
        document.getElementById('display-section').style.display = 'none';
        document.getElementById('input-section').style.display = 'flex';
        document.getElementById('sentence-container').innerText = '';
        document.getElementById('menu-content').innerHTML = '';
        document.getElementById('chip').textContent = 'Predict ▼';
        closeMenu();
    }

    // --- Interaction Logic ---
    function closeMenu() {
        document.getElementById('menu').style.display = 'none';
        isMenuOpen = false;
    }

    async function toggleMenu() {
        if (isMenuOpen) {
            closeMenu();
            return;
        }
        const content = document.getElementById('menu-content');
        if (!content.hasChildNodes()) {
            await fetchPredictions();
        }
        // Only open when there is something to show: the request may have
        // failed, or the explorer may have been reset while waiting.
        if (content.hasChildNodes()) {
            document.getElementById('menu').style.display = 'block';
            isMenuOpen = true;
        }
    }

    // --- Decoding the Python result ---
    // Colab reports a str returned by the callback as its Python repr() under
    // 'text/plain': the JSON is wrapped in quotes, and quotes, backslashes and
    // non-printable characters are escaped. This table plus the \x, \u and \U
    // forms covers every escape repr() produces for a str.
    const PY_SIMPLE_ESCAPES = {'\\': '\\', "'": "'", '"': '"', 'n': '\n', 'r': '\r', 't': '\t'};

    function decodePythonStrRepr(text) {
        const quote = text.charAt(0);
        if (text.length < 2 || (quote !== "'" && quote !== '"') || text.charAt(text.length - 1) !== quote) {
            return text;  // not a quoted repr: assume it is already plain JSON
        }
        return text.slice(1, -1).replace(
            /\\(x[0-9a-fA-F]{2}|u[0-9a-fA-F]{4}|U[0-9a-fA-F]{8}|[\s\S])/g,
            (match, esc) => {
                if (esc.length > 1) {
                    return String.fromCodePoint(parseInt(esc.slice(1), 16));
                }
                return Object.prototype.hasOwnProperty.call(PY_SIMPLE_ESCAPES, esc)
                    ? PY_SIMPLE_ESCAPES[esc]
                    : match;
            });
    }

    function parseCandidates(result) {
        const data = (result && result.data) || {};
        // Also accept a structured payload, should the callback ever return
        // IPython.display.JSON instead of a string.
        if (data['application/json'] !== undefined) {
            const payload = data['application/json'];
            return typeof payload === 'string' ? JSON.parse(payload) : payload;
        }
        const text = data['text/plain'];
        if (typeof text !== 'string') {
            throw new Error('Unexpected result from get_next_token_probs_material');
        }
        return JSON.parse(decodePythonStrRepr(text));
    }

    // --- Fetching predictions ---
    function fetchPredictions() {
        // Reuse a request that is already running for the current sentence.
        if (activeRequest && activeRequest.session === session) {
            return activeRequest.promise;
        }
        const requestSession = session;
        const chip = document.getElementById('chip');
        chip.textContent = 'Thinking...';

        const request = { session: requestSession, promise: null };
        activeRequest = request;
        request.promise = (async () => {
            try {
                // Call Python
                const result = await google.colab.kernel.invokeFunction(
                    'get_next_token_probs_material', [currentSentence], {});
                if (requestSession !== session) return;  // sentence changed meanwhile
                renderMenu(parseCandidates(result));
                chip.textContent = 'Predict ▼';
            } catch (err) {
                console.error('Prediction request failed:', err);
                if (requestSession === session) {
                    chip.textContent = 'Error - click to retry';
                }
            } finally {
                if (activeRequest === request) activeRequest = null;
            }
        })();
        return request.promise;
    }

    function renderMenu(candidates) {
        const content = document.getElementById('menu-content');
        content.innerHTML = '';

        candidates.forEach(cand => {
            const item = document.createElement('div');
            item.className = 'option';

            const row = document.createElement('div');
            row.className = 'option-row';

            // textContent (not innerHTML): tokens such as "<eos>", "<" or "&"
            // are model output and must be shown literally, not parsed as markup.
            const tokenEl = document.createElement('span');
            tokenEl.className = 'token-text';
            tokenEl.textContent = '"' + cand.display_text + '"';

            const probEl = document.createElement('span');
            probEl.className = 'token-prob';
            probEl.textContent = cand.prob + '%';

            row.appendChild(tokenEl);
            row.appendChild(probEl);

            const bar = document.createElement('div');
            bar.className = 'progress-bg';
            const fill = document.createElement('div');
            fill.className = 'progress-fill';
            fill.style.width = Math.min(100, Math.max(0, Number(cand.prob))) + '%';
            bar.appendChild(fill);

            item.appendChild(row);
            item.appendChild(bar);
            item.onclick = () => selectWord(cand.raw_text);
            content.appendChild(item);
        });
    }

    async function selectWord(word) {
        session += 1;
        currentSentence += word;

        const container = document.getElementById('sentence-container');
        const span = document.createElement('span');
        span.className = 'word-token';
        span.innerText = word;  // innerText renders a "\n" token as a line break
        container.appendChild(span);

        closeMenu();
        document.getElementById('menu-content').innerHTML = '';

        await fetchPredictions();
    }
</script>

</body>
</html>
"""

# Render the explorer page, then fix the output iframe's height (presumably to
# leave room for the absolutely positioned dropdown menu).
display(HTML(data=html_code))
output.eval_js('google.colab.output.setIframeHeight("650px");')