In [1]:
!pip install transformers torch


Collecting transformers
  Downloading transformers-4.57.3-py3-none-any.whl.metadata (43 kB)
Collecting torch
  Downloading torch-2.9.1-cp312-none-macosx_11_0_arm64.whl.metadata (30 kB)
Collecting filelock (from transformers)
  Using cached filelock-3.20.0-py3-none-any.whl.metadata (2.1 kB)
Collecting huggingface-hub<1.0,>=0.34.0 (from transformers)
  Downloading huggingface_hub-0.36.0-py3-none-any.whl.metadata (14 kB)
Collecting pyyaml>=5.1 (from transformers)
  Using cached pyyaml-6.0.3-cp312-cp312-macosx_11_0_arm64.whl.metadata (2.4 kB)
Collecting regex!=2019.12.17 (from transformers)
  Downloading regex-2025.11.3-cp312-cp312-macosx_11_0_arm64.whl.metadata (40 kB)
Collecting requests (from transformers)
  Using cached requests-2.32.5-py3-none-any.whl.metadata (4.9 kB)
Collecting tokenizers<=0.23.0,>=0.22.0 (from transformers)
  Downloading tokenizers-0.22.1-cp39-abi3-macosx_11_0_arm64.whl.metadata (6.8 kB)
Collecting safetensors>=0.4.3 (from transformers)
  Downloading safetensors-0.

In [3]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import torch.nn.functional as F

MODEL_NAME = "distilbert/distilgpt2"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
model.eval()

def next_token_distribution(prompt: str, top_k: int = 10):
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
        # logits shape: [batch, seq_len, vocab_size]
        logits = outputs.logits[:, -1, :]  # last position
        probs = F.softmax(logits, dim=-1).squeeze(0)  # [vocab_size]

    # Get top-k tokens
    topk_probs, topk_indices = torch.topk(probs, k=top_k)
    tokens = [tokenizer.decode([idx.item()]) for idx in topk_indices]
    return list(zip(tokens, topk_probs.tolist()))

prompt = "Machine learning is"
for tok, p in next_token_distribution(prompt, top_k=100):
    print(f"{repr(tok):>12}  {p:.4f}")


        ' a'  0.1674
      ' the'  0.0713
      ' not'  0.0525
       ' an'  0.0443
      ' one'  0.0256
     ' very'  0.0158
     ' also'  0.0126
     ' more'  0.0122
     ' just'  0.0113
     ' what'  0.0097
    ' about'  0.0095
      ' now'  0.0084
 ' becoming'  0.0083
' important'  0.0080
      ' all'  0.0071
    ' based'  0.0071
    ' still'  0.0070
     ' like'  0.0066
' something'  0.0066
       ' in'  0.0063
     ' part'  0.0062
   ' really'  0.0061
      ' how'  0.0058
     ' much'  0.0058
    ' going'  0.0056
    ' often'  0.0056
  ' another'  0.0056
' essential'  0.0054
      ' key'  0.0049
 ' critical'  0.0047
     ' hard'  0.0046
     ' done'  0.0044
' extremely'  0.0044
   ' simple'  0.0043
     ' easy'  0.0040
     ' only'  0.0038
       ' as'  0.0038
     ' that'  0.0037
       ' no'  0.0036
       ' at'  0.0036
      ' our'  0.0036
       ' so'  0.0034
       ' to'  0.0033
    ' great'  0.0033
  ' complex'  0.0032
    ' being'  0.0031
' difficult'  0.0031
     ' used' 

In [None]:
question = ["the", "life", "is"]

ans_so_far = ' '.join(question)
for i in range(50):
    nextword, p = next_token_distribution(ans_so_far, top_k=1)[0]
    ans_so_far = ans_so_far + ' ' + nextword

print(ans_so_far)

the life  of  the  world .  .  .                                                                                        
