In [24]:
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
import numpy as np

In [26]:
# Directory holding the fine-tuned NER checkpoint; the tokenizer is read from the same place.
model_name = "/content/fine-tuned-ner-model"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# output_attentions=True asks the model to also return its attention maps on every forward pass.
model = AutoModelForTokenClassification.from_pretrained(
    model_name,
    output_attentions=True,
)

In [27]:
def get_attention_weights(model, tokenizer, text, layer=-1):
    """Run ``text`` through ``model`` and return one layer's attention maps.

    Args:
        model: a model whose forward output exposes ``.attentions`` (here it is
            loaded with ``output_attentions=True``).
        tokenizer: the tokenizer matching ``model``; called with
            ``return_tensors='pt'``.
        text: a single input string.
        layer: index into ``outputs.attentions``. Defaults to -1 (the last
            layer), which matches the previous hard-coded behaviour.

    Returns:
        ``(inputs, attentions)``: the tokenizer output, and the selected layer's
        attention tensor with its leading dimension of size 1 squeezed away.
        For a standard HF encoder this is presumably (num_heads, seq_len, seq_len)
        -- TODO confirm for this model.
    """
    inputs = tokenizer(text, return_tensors='pt')
    # Inference only: no_grad avoids building an autograd graph (saves memory),
    # and the returned attention tensor no longer requires grad.
    with torch.no_grad():
        outputs = model(**inputs)
    attentions = outputs.attentions[layer].squeeze(0)
    return inputs, attentions
# Ge'ez-script (Amharic) sentence used to probe which tokens the model attends to.
# NOTE(review): the saved output of the printing cell below lists almost every
# token as [UNK], which suggests this tokenizer's vocabulary does not cover this
# script -- TODO confirm before reading anything into the weights.
sample_text = "የገና በዓልን ምክንያት በማድረግ የባሕርዳር ዩኒቨርሲቲ ተማሪዎች ከዝግባ ሕጻናትና አረጋዊያን መርጃ  በጎ አድራጎት ድርጅት በመረዳት ላይ ለሚገኙ ወገኖች የምሳ ግብዣ በማድረግ በዓልን አሳልፈዋል።"
inputs, attentions = get_attention_weights(
    model=model,
    tokenizer=tokenizer,
    text=sample_text,
)

In [28]:
# Function to print attention weights in descending order
def print_attention_weights(attentions, inputs, tokenizer, show_position=False):
    """Print each token's average received attention, highest first.

    Args:
        attentions: one layer's attention for a single example with heads on
            dim 0, as returned by ``get_attention_weights`` -- presumably
            (num_heads, seq_len, seq_len); TODO confirm.
        inputs: tokenizer output holding ``"input_ids"`` with a leading batch
            dimension; only the first sequence is used.
        tokenizer: maps token ids back to token strings.
        show_position: if True, also print each token's index in the sequence,
            so repeated tokens (e.g. many ``[UNK]``) can be told apart.
    """
    # Index the first sequence explicitly: .squeeze() would collapse a
    # one-token sequence to a scalar, and .tolist() would then return an int.
    token_ids = inputs["input_ids"][0].tolist()
    tokens = tokenizer.convert_ids_to_tokens(token_ids)
    # Average over heads -> (seq_len, seq_len) matrix.
    head_avg = attentions.mean(dim=0).detach().cpu().numpy()
    # Mean down each column i: the average attention token i receives across all
    # rows (assuming the HF layout of rows = query positions, columns = keys).
    received = head_avg.mean(axis=0)
    # Stable sort, highest weight first; ties keep their original sequence order.
    ranked = sorted(
        enumerate(zip(tokens, received)),
        key=lambda item: item[1][1],
        reverse=True,
    )
    print("Attention Weights (Descending Order):")
    for pos, (token, weight) in ranked:
        if show_position:
            print(f"Token: {token} (pos {pos}), Weight: {weight:.4f}")
        else:
            print(f"Token: {token}, Weight: {weight:.4f}")
# Rank the sample sentence's tokens by how much attention they receive and print them.
print_attention_weights(
    attentions=attentions,
    inputs=inputs,
    tokenizer=tokenizer,
)

Attention Weights (Descending Order):
Token: [SEP], Weight: 0.0618
Token: [UNK], Weight: 0.0538
Token: [CLS], Weight: 0.0513
Token: [UNK], Weight: 0.0488
Token: [UNK], Weight: 0.0481
Token: [UNK], Weight: 0.0468
Token: [UNK], Weight: 0.0450
Token: [UNK], Weight: 0.0428
Token: [UNK], Weight: 0.0415
Token: [UNK], Weight: 0.0403
Token: [UNK], Weight: 0.0393
Token: [UNK], Weight: 0.0374
Token: [UNK], Weight: 0.0362
Token: [UNK], Weight: 0.0351
Token: [UNK], Weight: 0.0342
Token: [UNK], Weight: 0.0334
Token: [UNK], Weight: 0.0329
Token: [UNK], Weight: 0.0324
Token: [UNK], Weight: 0.0324
Token: [UNK], Weight: 0.0323
Token: [UNK], Weight: 0.0317
Token: [UNK], Weight: 0.0302
Token: [UNK], Weight: 0.0292
Token: [UNK], Weight: 0.0280
Token: [UNK], Weight: 0.0277
Token: [UNK], Weight: 0.0273
