In [2]:
import torch

import os

from transformers import AutoModelForMaskedLM, AutoTokenizer

# Read the Hugging Face access token from the environment rather than
# committing a secret to the notebook. When HF_TOKEN is unset this is None,
# which leaves credential resolution to the transformers library.
# NOTE(review): a literal token used to live on this line; it has been
# published with the file and should be revoked on the Hugging Face side.
access_token = os.environ.get("HF_TOKEN")
text = "Beijing is the capital of <mask>."

tokenizer = AutoTokenizer.from_pretrained("JackBAI/crate-base", token=access_token)
inputs = tokenizer(text, return_tensors="pt")
# torch.where on the [batch, seq] id matrix returns (row_idx, col_idx);
# element [1] is therefore the sequence position of every mask token.
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
if mask_token_index.numel() == 0:
    # Without this guard the top-k lookup below fails with an opaque IndexError.
    raise ValueError(f"text must contain the mask token {tokenizer.mask_token!r}")

model = AutoModelForMaskedLM.from_pretrained("JackBAI/crate-base", token=access_token)
# Inference only: skip autograd bookkeeping for the forward pass.
with torch.no_grad():
    logits = model(**inputs).logits
mask_token_logits = logits[0, mask_token_index, :]

# Vocabulary ids of the three highest-scoring candidates for the first mask.
top_3_tokens = torch.topk(mask_token_logits, 3, dim=1).indices[0].tolist()

# Show the sentence with each candidate substituted for the mask token.
for token in top_3_tokens:
    print(text.replace(tokenizer.mask_token, tokenizer.decode([token])))
    
inputs = tokenizer(text, return_tensors="pt")

input_id = inputs.input_ids
print("=== Tokenization ===")
# Decode every id of the single input row on its own so each sub-word
# piece can be inspected individually.
token_ids = input_id[0].tolist()
encoded = [tokenizer.decode(token_id) for token_id in token_ids]
print(encoded)

# Locate the mask by token id instead of by comparing decoded text to a
# hard-coded '<mask>' literal, so the lookup follows the tokenizer's own
# definition of its mask token.
mask_positions = [
    position
    for position, token_id in enumerate(token_ids)
    if token_id == tokenizer.mask_token_id
]
if not mask_positions:
    # Previously mask_id silently stayed 0 (the first token), which made the
    # later slicing around mask_id produce meaningless results.
    raise ValueError(f"text must contain the mask token {tokenizer.mask_token!r}")
# Keep the last occurrence, matching what the original loop ended up with.
mask_id = mask_positions[-1]

# Run the model again, this time also requesting per-layer hidden states and
# attention maps. no_grad: this is inference only.
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True, output_attentions=True)

print("=== Matrix Sizes ===")
hidden_states = outputs.hidden_states
# First entry of the hidden-state tuple (presumably the embedding-layer
# output, per the variable name -- confirm against the model's docs).
embedding_output = hidden_states[0]
# Final entry of the tuple: [1, seq_len, hidden], i.e. [1, 10, 768] for the
# sample sentence.
last_hidden_states = hidden_states[-1]

# Query vector: the last-layer state at the <mask> position. The name z_cls
# is kept for compatibility, but note this is NOT the first (<s>) token.
z_cls = last_hidden_states[0, mask_id, :]
print("z_cls:", z_cls.shape)  # [768]

# Key vectors: last-layer states of the remaining tokens. First drop the
# first and last positions (<s> and </s> in the sample tokenization) ...
z_i = last_hidden_states[0, 1:-1, :]
# ... then remove the mask row itself. Dropping the first position shifted
# every index down by one, so the mask now sits at row mask_id - 1.
mask_row = mask_id - 1
z_i = torch.cat((z_i[:mask_row, :], z_i[mask_row + 1:, :]))
print("z_i:", z_i.shape)  # [7, 768] for the sample sentence

# Projection weight of the last encoder layer's attention module.
# detach() gives a gradient-free view without going through the legacy
# .data attribute.
U = model.crate.encoder.layer[-1].attention.self.qkv.weight.detach()
print("U:", U.shape)  # [768, 768]

# Project keys and query with the same matrix U.
Uz_i = torch.matmul(U, z_i.T)  # [768, 7]
Uz_cls = torch.matmul(U, z_cls)  # [768]

# Unnormalised score per remaining token: dot product of each projected
# token (a column of Uz_i) with the projected mask vector.
attentions = torch.matmul(Uz_i.T, Uz_cls)  # [7]
    
from IPython.display import HTML

def colored_text_with_attention(tokens, attentions):
    """Render tokens as HTML spans shaded red in proportion to their score.

    Args:
        tokens: Iterable of token strings to display, in order.
        attentions: Iterable of numbers, one per token, used verbatim as the
            alpha channel of the span's ``rgba(255, 0, 0, alpha)`` background.
            Pairs are formed with ``zip``, so extra items in the longer
            iterable are ignored.

    Returns:
        An ``IPython.display.HTML`` object containing one ``<span>`` per
        token, each followed by a single space.
    """
    # Function-scope import keeps this block self-contained.
    from html import escape

    # Escape the token text: sub-word pieces and special tokens can contain
    # '<', '>' or '&' (e.g. "<s>", "<mask>"), which a browser would otherwise
    # parse as markup instead of showing as text.
    spans = [
        f'<span style="background-color:rgba(255, 0, 0, {attention})">'
        f'{escape(str(token))}</span> '
        for token, attention in zip(tokens, attentions)
    ]
    return HTML("".join(spans))

print("=== Attention ===")
# Normalise the raw scores into a distribution over the displayed tokens.
attentions = torch.softmax(attentions, dim=0)

# Build the token list that lines up with the scores: drop the first and
# last tokens, then drop the mask token, which sits one position earlier
# once the first token is gone.
inner_tokens = encoded[1:-1]
tokens = inner_tokens[:mask_id - 1] + inner_tokens[mask_id:]

# Left as a bare trailing expression so the notebook renders the HTML.
colored_text_with_attention(tokens, attentions.tolist())

Beijing is the capital of  it.
Beijing is the capital of  India.
Beijing is the capital of  China.
=== Tokenization ===
['<s>', 'Be', 'ijing', ' is', ' the', ' capital', ' of', '<mask>', '.', '</s>']
=== Matrix Sizes ===
z_cls: torch.Size([768])
z_i: torch.Size([7, 768])
U: torch.Size([768, 768])
=== Attention ===
