In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
model.eval()


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm):

In [3]:
from torchinfo import summary

summary(model, depth=7)

Layer (type:depth-idx)                                  Param #
LlamaForCausalLM                                        --
├─LlamaModel: 1-1                                       --
│    └─Embedding: 2-1                                   262,668,288
│    └─ModuleList: 2-2                                  --
│    │    └─LlamaDecoderLayer: 3-1                      --
│    │    │    └─LlamaSdpaAttention: 4-1                --
│    │    │    │    └─Linear: 5-1                       4,194,304
│    │    │    │    └─Linear: 5-2                       1,048,576
│    │    │    │    └─Linear: 5-3                       1,048,576
│    │    │    │    └─Linear: 5-4                       4,194,304
│    │    │    │    └─LlamaRotaryEmbedding: 5-5         --
│    │    │    └─LlamaMLP: 4-2                          --
│    │    │    │    └─Linear: 5-6                       16,777,216
│    │    │    │    └─Linear: 5-7                       16,777,216
│    │    │    │    └─Linear: 5-8                       1

In [4]:
# 質問の入力
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

In [50]:
text = "This is a [MASK] great."

In [51]:
import torch
device='cuda'
inputs = tokenizer(text, return_tensors="pt")
inputs.to(device)
model.to(device)
token_logits = model(**inputs).logits
inputs['input_ids']

tensor([[128000,   2028,    374,    264,    510,  50963,     60,   2294,     13]],
       device='cuda:0')

In [52]:
mask_token=tokenizer('MASK', return_tensors="pt")
mask_token=mask_token['input_ids']
mask_token_id=mask_token[0]
mask_token_id=mask_token_id[1].item()
mask_token='MASK'
print(mask_token)
print(mask_token_id)

MASK
50963


In [53]:
# Find the location of [MASK] and extract its logits
mask_token_index = torch.where(inputs["input_ids"] == mask_token_id)[1]
mask_token_logits = token_logits[0, mask_token_index, :]
# Pick the [MASK] candidates with the highest logits
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

for token in top_5_tokens:
    token=tokenizer.decode([token])
    print(f"'>>> {text.replace(mask_token, token)}'")

'>>> This is a []] great.'
'>>> This is a [].] great.'
'>>> This is a []

] great.'
'>>> This is a []
] great.'
'>>> This is a [](] great.'


In [54]:
# Find the location of [MASK] and extract its logits
mask_token_index = torch.where(inputs["input_ids"] == mask_token_id)[1]
mask_token_logits = token_logits[0, mask_token_index, :]
# Pick the [MASK] candidates with the highest logits
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

for token in top_5_tokens:
    print(f"'>>> {tokenizer.decode([token])}'")

'>>> ]'
'>>> ].'
'>>> ]

'
'>>> ]
'
'>>> ]('
