<a href="https://colab.research.google.com/github/kelanmail-create/colabs/blob/main/deocde_algorithm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
from IPython.display import display, clear_output
import ipywidgets as widgets
import matplotlib.pyplot as plt

In [None]:
# ---------- Load GPT Model ----------
# Pick the execution device first; the model is moved there after loading.
device = "cuda" if torch.cuda.is_available() else "cpu"

model_name = "gpt2"  # Hugging Face hub id of the causal LM to visualize
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.to(device)

In [None]:
# End-of-sequence id: prefer the model config, fall back to the tokenizer
# (GPT-2 50256).
EOS_ID = (
    model.config.eos_token_id
    if model.config.eos_token_id is not None
    else tokenizer.eos_token_id
)

# UI toggles: whether sampling EOS ends generation, and whether EOS gets a
# readable name in the plot.
stop_at_eos_cb = widgets.Checkbox(description="Stop when EOS is sampled", value=True)
show_special_in_labels_cb = widgets.Checkbox(description="Show EOS in plot labels", value=True)

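Temperature (equivalent to taking the softmax of the logits divided by $T$; $T < 1$ sharpens the distribution, $T > 1$ flattens it)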
$$
P'(t_i) = \frac{P(t_i)^{1/T}}{\sum_{j} P(t_j)^{1/T}}
$$

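Top-k

Keep only the $k$ most probable tokens, renormalize their probabilities, and sample from them:

$$
V_k = \text{TopK}\big(P(\cdot \mid \text{context}),\, k\big)
$$

$$
t_{\text{next}} \sim P(t \mid \text{context},\, t \in V_k) = \frac{P(t \mid \text{context})}{\sum_{t' \in V_k} P(t' \mid \text{context})}, \qquad t \in V_k
$$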

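Top-p (nucleus sampling)

Sort the tokens so that $P(t_1) \ge P(t_2) \ge \dots$, then keep the smallest prefix whose cumulative probability reaches $p$:

$$
V_p = \{\, t_1, \dots, t_m \,\}, \qquad m = \min\Big\{\, i \;\Big|\; \sum_{j=1}^{i} P(t_j) \ge p \,\Big\}
$$

$$
t_{\text{next}} \sim P(t \mid \text{context},\, t \in V_p)
$$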

In [None]:
@torch.no_grad()
def top_k_filtering(logits, top_k=0):
    """Mask every logit outside the ``top_k`` largest with ``-inf``.

    Args:
        logits: tensor of logits; filtering is applied along the last dim.
        top_k: number of highest-scoring entries to keep. ``0`` (or any
            non-positive value) disables filtering. Values larger than the
            last dimension are clamped, so every entry is kept.

    Returns:
        A tensor with the same shape as ``logits``. Entries tied with the
        k-th largest value are kept too, so more than ``top_k`` entries can
        survive when there are ties.
    """
    if top_k <= 0:
        return logits
    # torch.topk raises when k exceeds the dimension size; clamp instead.
    top_k = min(int(top_k), logits.size(-1))
    kth_largest = torch.topk(logits, top_k, dim=-1).values[..., -1:]
    return logits.masked_fill(logits < kth_largest, float("-inf"))

@torch.no_grad()
def top_p_filtering(logits, top_p=1.0):
    """Nucleus (top-p) filtering along the last dimension.

    Keeps the smallest set of highest-probability entries whose cumulative
    softmax probability reaches ``top_p`` and sets every other logit to
    ``-inf``. The most probable entry is always kept.

    Args:
        logits: tensor of logits; filtering is applied along the last dim.
        top_p: cumulative-probability threshold; ``>= 1.0`` disables filtering.

    Returns:
        A tensor with the same shape as ``logits``.
    """
    if top_p >= 1.0:
        return logits
    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
    cumprobs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    # Drop an entry only if the entries *before* it already reach top_p.
    # Shifting by one keeps the token that crosses the threshold; without the
    # shift the kept mass can fall short of top_p.
    remove = torch.zeros_like(cumprobs, dtype=torch.bool)
    remove[..., 1:] = cumprobs[..., :-1] >= top_p
    sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
    # Undo the sort: write each (possibly masked) logit back to its original slot.
    restored = torch.full_like(logits, float("-inf"))
    return restored.scatter(-1, sorted_idx, sorted_logits)

@torch.no_grad()
def sample_from_logits(logits):
    """Draw one token id per row from ``softmax(logits)``.

    Returns an ``(ids, probs)`` pair: ``ids`` holds one multinomial draw per
    row and ``probs`` is the softmax distribution it was drawn from.
    """
    distribution = logits.softmax(dim=-1)
    draw = torch.multinomial(distribution, num_samples=1)
    return draw, distribution

def greedy_strategy(logits, **kwargs):
    """Pick the highest-scoring token deterministically.

    Extra keyword arguments (temperature, top_k, ...) are accepted, so every
    strategy shares one call signature, and are ignored.

    Returns ``(ids, probs)``, where ``probs`` is the unmodified softmax of
    ``logits``.
    """
    best = logits.argmax(dim=-1, keepdim=True)
    return best, logits.softmax(dim=-1)

def temperature_strategy(logits, temperature=1.0, **kwargs):
    """Sample after dividing the logits by ``temperature``.

    The divisor is floored at 1e-8 to avoid division by zero. Values below 1
    sharpen the distribution; values above 1 flatten it.
    """
    safe_t = max(temperature, 1e-8)
    scaled = logits / safe_t
    return sample_from_logits(scaled)

def topk_strategy(logits, top_k=40, temperature=1.0, **kwargs):
    """Top-k sampling: keep the ``top_k`` best logits, then apply temperature.

    Dividing by a positive temperature preserves ordering, so masking before
    scaling selects the same candidate set (up to floating-point ties) as
    masking after.
    """
    filtered = top_k_filtering(logits, top_k)
    scaled = filtered / max(temperature, 1e-8)
    return sample_from_logits(scaled)

def topp_strategy(logits, top_p=0.9, temperature=1.0, **kwargs):
    """Nucleus (top-p) sampling with temperature.

    Temperature is applied *before* the nucleus is computed. The kept set is
    then the smallest one whose mass reaches ``top_p`` under the distribution
    actually sampled from, which is the order Hugging Face's logits warpers
    use.

    Unlike top-k, order matters here: a lower temperature sharpens the
    distribution and so shrinks the nucleus.
    """
    scaled = logits / max(temperature, 1e-8)
    return sample_from_logits(top_p_filtering(scaled, top_p))

# Strategy name (as offered in the UI dropdown) -> decoding function.
# Every function takes ``(logits, **kwargs)`` and returns ``(next_id, probs)``.
STRATEGY_MAP = dict(
    greedy=greedy_strategy,
    temperature=temperature_strategy,
    topk=topk_strategy,
    topp=topp_strategy,
)


class StepDecoder:
    """Generate text one token per call, reusing the model's KV cache.

    The prompt is encoded once in ``__init__``. Each :meth:`step` samples a
    token from the logits predicted for the current position. The token is
    fed through the model (extending the cache) at the start of the
    following call.
    """

    def __init__(self, model, tokenizer, prompt, device="cpu"):
        self.model = model.eval().to(device)
        self.tok = tokenizer
        self.device = device
        self.input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
        if self.input_ids.shape[-1] == 0:
            # An empty prompt gives the model nothing to condition on; start
            # from the tokenizer's BOS token instead.
            bos = tokenizer.bos_token_id
            if bos is None:
                raise ValueError("Prompt is empty and the tokenizer has no BOS token.")
            self.input_ids = torch.tensor([[bos]], device=device)
        with torch.no_grad():
            out = self.model(self.input_ids, use_cache=True)
        self.past = out.past_key_values
        # Logits for the first generated position come from the prompt pass.
        # The prompt's last token is already in ``self.past`` and must not be
        # fed through the model a second time.
        self._next_logits = out.logits[:, -1, :]
        # Token sampled by the previous step that has not yet been run through
        # the model (None until the first step).
        self._pending_ids = None
        self.generated = self.input_ids.clone()
        self.finished = False

    @torch.no_grad()
    def step(self, strategy="greedy", strategy_kwargs=None):
        """Sample one token and append it to ``self.generated``.

        Args:
            strategy: key of ``STRATEGY_MAP`` ("greedy", "temperature",
                "topk", "topp").
            strategy_kwargs: keyword arguments forwarded to the strategy.

        Returns:
            ``None`` once generation has finished. Otherwise a dict with:
            ``probs`` (the post-filtering distribution, as a CPU tensor),
            ``next_id`` (the sampled id), ``next_prob`` (its probability),
            ``next_token_str`` (its display string, "<eos>" for EOS) and
            ``text_so_far`` (the decoded text).
        """
        if self.finished:
            return None
        strategy_kwargs = strategy_kwargs or {}
        # Look the strategy up before touching the cache so an unknown name
        # cannot leave the decoder half-advanced.
        fn = STRATEGY_MAP[strategy]

        if self._pending_ids is not None:
            # Extend the cache with the token sampled on the previous call.
            out = self.model(input_ids=self._pending_ids, use_cache=True, past_key_values=self.past)
            self.past = out.past_key_values
            self._next_logits = out.logits[:, -1, :]         # [1, vocab]

        next_id, probs = fn(self._next_logits, **strategy_kwargs)  # next_id: [1,1], probs: [1, vocab]
        self._pending_ids = next_id
        self.generated = torch.cat([self.generated, next_id], dim=-1)

        next_id_int = int(next_id.item())
        next_prob = float(probs[0, next_id_int].item())

        if EOS_ID is not None and next_id_int == EOS_ID:
            next_tok = "<eos>"
            if stop_at_eos_cb.value:
                self.finished = True
        else:
            next_tok = self.tok.decode([next_id_int], skip_special_tokens=True)

        return {
            "probs": probs[0].detach().cpu(),      # [vocab]
            "next_id": next_id_int,
            "next_prob": next_prob,
            "next_token_str": next_tok,
            "text_so_far": self.tok.decode(self.generated[0], skip_special_tokens=True),
        }
# ---------- UI widgets ----------
prompt_box = widgets.Text(
    description="Prompt:",
    value="Q: What is the capital of France?\n A:",
    layout=widgets.Layout(width="100%"),
)

strategy_dd = widgets.Dropdown(
    description="Strategy:",
    options=["greedy", "temperature", "topk", "topp"],
    value="topp",
)

# Strategy hyper-parameters (which ones are visible depends on the strategy).
temperature_sl = widgets.FloatSlider(description="Temp", value=0.8, min=0.1, max=2.0, step=0.05)
topk_sl = widgets.IntSlider(description="Top-K", value=40, min=1, max=200, step=1)
topp_sl = widgets.FloatSlider(description="Top-P", value=0.9, min=0.1, max=1.0, step=0.01)
# How many of the most probable tokens the bar chart shows.
topN_sl = widgets.IntSlider(description="Plot Top-N", value=20, min=5, max=100, step=1)

init_btn = widgets.Button(description="Initialize / Reset", button_style="")
step_btn = widgets.Button(description="Step ▷ (1 token)", button_style="")

# Output areas: probability plot and running text log.
out_plot, out_text = widgets.Output(), widgets.Output()

def _toggle_param_visibility(*args):
    """Show only the sliders that the selected strategy uses.

    ``*args`` is accepted and ignored so the function can serve directly as an
    ``observe`` callback (which passes a change record).
    """
    for slider in (temperature_sl, topk_sl, topp_sl):
        slider.layout.display = "none"
    sliders_for = {
        "temperature": (temperature_sl,),
        "topk": (temperature_sl, topk_sl),
        "topp": (temperature_sl, topp_sl),
    }
    for slider in sliders_for.get(strategy_dd.value, ()):
        slider.layout.display = ""

# Re-apply slider visibility whenever the strategy changes, and once now for
# the initial selection.
strategy_dd.observe(_toggle_param_visibility, names="value")
_toggle_param_visibility()

# Mutable holder for the active StepDecoder, so widget callbacks can replace it.
decoder_state = dict(decoder=None)

def build_kwargs():
    """Collect the selected strategy's keyword arguments from the sliders.

    Greedy (and any unknown strategy) takes no arguments. Every sampling
    strategy gets ``temperature``, plus ``top_k`` or ``top_p`` where
    relevant.
    """
    strategy = strategy_dd.value
    if strategy not in ("temperature", "topk", "topp"):
        return {}
    kwargs = {"temperature": float(temperature_sl.value)}
    if strategy == "topk":
        kwargs["top_k"] = int(topk_sl.value)
    elif strategy == "topp":
        kwargs["top_p"] = float(topp_sl.value)
    return kwargs

def on_init_clicked(_):
    """(Re)build the StepDecoder from the current prompt and reset both outputs.

    The argument is the clicked button (or ``None`` when called directly) and
    is ignored.
    """
    prompt = prompt_box.value
    decoder_state["decoder"] = StepDecoder(model.to(device), tokenizer, prompt, device=device)

    with out_text:
        clear_output()
        print("Decoder is ready. Click 'Step ▷' to generate one token at a time.")
        print(f"Prompt: {prompt!r}")

    with out_plot:
        clear_output()
        # Initial empty plot
        plt.figure(figsize=(8, 4))
        plt.title("Token probability distribution (click 'Step ▷' to start)")
        plt.xlabel("token")
        plt.ylabel("probability")
        plt.xticks(rotation=60)
        plt.tight_layout()
        plt.show()

def _token_label(token_id):
    """Return the x-axis label used for ``token_id`` in the probability plot."""
    # EOS decodes to "" with skip_special_tokens=True. The "Show EOS in plot
    # labels" checkbox decides whether it gets a readable name or the generic
    # "<id>" fallback used for every other empty decode.
    if EOS_ID is not None and token_id == EOS_ID and show_special_in_labels_cb.value:
        return "<eos>"
    return tokenizer.decode([token_id], skip_special_tokens=True) or f"<{token_id}>"


def on_step_clicked(_):
    """Generate one token, then redraw the probability plot and append to the log.

    A decoder is created first if none exists. Once generation has finished
    (EOS sampled while "Stop when EOS is sampled" is checked), a hint is
    printed instead of the click silently doing nothing.
    """
    dec = decoder_state.get("decoder")
    if dec is None:
        on_init_clicked(None)
        dec = decoder_state["decoder"]

    result = dec.step(strategy=strategy_dd.value, strategy_kwargs=build_kwargs())
    if result is None:
        with out_text:
            print("Generation finished (EOS sampled). Click 'Initialize / Reset' to start again.")
        return

    probs = result["probs"]        # [vocab] on CPU
    next_id = result["next_id"]
    next_prob = result["next_prob"]
    next_tok = result["next_token_str"]
    text_so_far = result["text_so_far"]

    topN = int(topN_sl.value)
    # torch.topk raises if k exceeds the vocabulary size.
    top_vals, top_idx = torch.topk(probs, k=min(topN, probs.numel()))
    top_vals = top_vals.tolist()
    top_idx = top_idx.tolist()
    top_tokens = [_token_label(i) for i in top_idx]

    with out_plot:
        clear_output(wait=True)
        plt.figure(figsize=(10, 4))
        positions = range(len(top_vals))
        plt.bar(positions, top_vals)
        plt.xticks(positions, top_tokens, rotation=60)
        plt.xlabel("token")
        plt.ylabel("probability")
        plt.title(f"Top-{topN} probs | sampled: id={next_id}, token={repr(next_tok)}, p={next_prob:.4f}")
        # Mark the sampled token when it is among the plotted bars.
        if next_id in top_idx:
            j = top_idx.index(next_id)
            plt.annotate(f"◀ sampled ({next_prob:.3f})", xy=(j, top_vals[j]), xytext=(j, max(top_vals)*1.05),
                         ha="center", arrowprops=dict(arrowstyle="-"))
        plt.tight_layout()
        plt.show()

    with out_text:
        print(f"[{strategy_dd.value}] sampled: {repr(next_tok)} (id={next_id}, p={next_prob:.4f})")
        print("Text so far:")
        print(text_so_far)
        print("-" * 60)

# Hook the buttons up to their handlers.
step_btn.on_click(on_step_clicked)
init_btn.on_click(on_init_clicked)

# Layout: prompt row, parameter row, button row.
controls_row1 = widgets.HBox(children=[prompt_box])
controls_row2 = widgets.HBox(children=[
    strategy_dd, temperature_sl, topk_sl, topp_sl, topN_sl,
    stop_at_eos_cb, show_special_in_labels_cb,
])
controls_row3 = widgets.HBox(children=[init_btn, step_btn])

display(controls_row1, controls_row2, controls_row3, out_plot, out_text)

# Build the first decoder right away so "Step ▷" works immediately.
on_init_clicked(None)