In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer

In [3]:
base_model_name = "Qwen/Qwen3-4B-Thinking-2507"

print('... started loading...')
# load the tokenizer and the model
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    torch_dtype="auto",
    device_map="auto"
)

... started loading...


`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 32.44it/s]
Some parameters are on the meta device because they were offloaded to the cpu and disk.


In [4]:
base_model.config

Qwen3Config {
  "architectures": [
    "Qwen3ForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "dtype": "bfloat16",
  "eos_token_id": 151645,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 2560,
  "initializer_range": 0.02,
  "intermediate_size": 9728,
  "layer_types": [
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    "full_attention",
    

In [7]:
base_model.config.intermediate_size

9728

In [1]:
import torch

In [4]:
cache_position = torch.arange(5, 5 + 3, device='cpu')

In [5]:
cache_position

tensor([5, 6, 7])

In [6]:
def sample_tokens(logits, temperature=0.6, top_p=0.95, top_k=20, min_p=0.0):
    """Sample from logits with temperature, min-p, top-k, and top-p filtering.

    Args:
        logits: (batch_size, vocab_size) raw logits for the next token.
        temperature: Scaling factor applied before softmax.
        top_p: Nucleus sampling threshold.
        top_k: Keep only top-k logits before sampling.
        min_p: Minimum probability relative to the max probability token.

    Returns:
        (batch_size, 1) sampled token ids.
    """
    # 1. Temperature
    if temperature == 0: # greedy
        return torch.argmax(logits, dim=-1, keepdim=True)
    
    logits = logits / temperature
    probs = torch.softmax(logits, dim=-1)

    # 2. Min-p filtering
    if min_p > 0.0:
        max_prob = probs.max(dim=-1, keepdim=True).values
        probs = probs.masked_fill(probs < min_p * max_prob, 0.0)
        # Re-normalize
        probs = probs / probs.sum(dim=-1, keepdim=True)

    # 3. Top-k filtering
    if top_k > 0:
        top_k_values, _ = torch.topk(probs, top_k, dim=-1)
        min_top_k = top_k_values[:, -1, None]
        probs = probs.masked_fill(probs < min_top_k, 0.0)
        # Re-normalize
        probs = probs / probs.sum(dim=-1, keepdim=True)

    # 4. Top-p (nucleus) filtering
    if top_p < 1.0:
        sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
        mask = cumulative_probs - sorted_probs > top_p
        sorted_probs[mask] = 0.0
        probs = torch.zeros_like(probs).scatter(-1, sorted_indices, sorted_probs)
        # Re-normalize
        probs = probs / probs.sum(dim=-1, keepdim=True)

    next_tokens = torch.multinomial(probs, num_samples=1)
    return next_tokens

In [7]:
logits = torch.randn(2, 3)
logits

tensor([[-0.6125,  0.4183, -1.7656],
        [-2.6121, -0.6260,  0.2224]])

In [8]:
temperature = 0.6
logits = logits/temperature
logits

tensor([[-1.0209,  0.6972, -2.9427],
        [-4.3535, -1.0433,  0.3707]])

In [9]:
probs = torch.softmax(logits, dim=-1)
probs

tensor([[0.1488, 0.8294, 0.0218],
        [0.0071, 0.1942, 0.7987]])

#### min p

In [10]:
max_prob = probs.max(dim=-1, keepdim=True).values
max_prob

tensor([[0.8294],
        [0.7987]])

In [11]:
min_p = 0.1
probs.masked_fill(probs < min_p * max_prob, 0.0)

tensor([[0.1488, 0.8294, 0.0000],
        [0.0000, 0.1942, 0.7987]])

#### Top K

In [14]:
top_k = 2
top_k_values, top_k_indices = torch.topk(probs, top_k, dim=-1)
top_k_values, top_k_indices

(tensor([[0.8294, 0.1488],
         [0.7987, 0.1942]]),
 tensor([[1, 0],
         [2, 1]]))

In [21]:
min_top_k = top_k_values[:, -1, None]
min_top_k

tensor([[0.1488],
        [0.1942]])

In [22]:
probs.masked_fill(probs < min_top_k, 0.0)

tensor([[0.1488, 0.8294, 0.0000],
        [0.0000, 0.1942, 0.7987]])

#### Top P

In [25]:
top_p=0.90
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
sorted_probs, sorted_indices

(tensor([[0.8294, 0.1488, 0.0218],
         [0.7987, 0.1942, 0.0071]]),
 tensor([[1, 0, 2],
         [2, 1, 0]]))

In [24]:
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
cumulative_probs

tensor([[0.8294, 0.9782, 1.0000],
        [0.7987, 0.9929, 1.0000]])

In [44]:
mask = cumulative_probs - sorted_probs > top_p
mask

tensor([[False, False,  True],
        [False, False,  True]])

In [42]:
sorted_probs[mask] = 0.0
sorted_probs

tensor([[0.8294, 0.1488, 0.0000],
        [0.7987, 0.1942, 0.0000]])

In [43]:
torch.zeros_like(probs).scatter(1, sorted_indices, sorted_probs)

tensor([[0.1488, 0.8294, 0.0000],
        [0.0000, 0.1942, 0.7987]])