In [None]:
import torch
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from transformers import BertTokenizer, BertModel

In [None]:
# Tokenizer and encoder are loaded from the same pretrained checkpoint.
# output_attentions=True is required so that forward passes expose
# `outputs.attentions`, which the cells below read.
MODEL_NAME = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
model = BertModel.from_pretrained(MODEL_NAME, output_attentions=True)

In [None]:
def visualize_attention(text, layer=-1):
    """Run BERT on ``text`` and return the self-attention weights of one layer.

    Note: despite its name, this function does not draw anything. It only
    computes the attention tensor; plotting is left to the caller.

    Args:
        text: Input passed straight to the module-level ``tokenizer``
            (a string in every call visible in this notebook).
        layer: Index into ``outputs.attentions`` selecting which layer's
            attention to return. Defaults to ``-1`` (the last layer), which
            is what this function always returned before the parameter
            existed.

    Returns:
        tuple: ``(attention, tokens)`` where ``attention`` is the selected
        entry of ``outputs.attentions`` (per the Hugging Face docs, a tensor
        of shape ``(batch_size, num_heads, seq_len, seq_len)``) and
        ``tokens`` is the tokenizer output, including ``'input_ids'``.

    Raises:
        IndexError: If ``layer`` is outside the range of available layers.
    """
    tokens = tokenizer(text, return_tensors='pt', truncation=True, padding=True)
    # Inference only: skip autograd bookkeeping so the result can be
    # converted with .numpy() directly by callers.
    with torch.no_grad():
        outputs = model(**tokens)
    attention = outputs.attentions[layer]
    return attention, tokens

In [None]:
# Compute attention for a short sample sentence.
text = "Medical Analysis."
attention, tokens = visualize_attention(text)

# Take the first batch item and average over its leading axis (the
# attention-head axis per the Hugging Face docs) to get one 2-D matrix.
attn_scores = torch.mean(attention[0], dim=0).numpy()

# Map the input ids back to token strings; these label the heatmap axes.
tokenized_text = tokenizer.convert_ids_to_tokens(tokens['input_ids'][0])

In [None]:
# Draw the averaged attention matrix with token strings on both axes.
fig, ax = plt.subplots(figsize=(10, 5))
sns.heatmap(
    attn_scores,
    xticklabels=tokenized_text,
    yticklabels=tokenized_text,
    cmap='viridis',
    ax=ax,
)
ax.set_title("Attention Heatmap")
plt.show()

In [None]:
# Example clinical-style notes used as inputs for the attention cells below.
note1, note2, note3 = (
    "Patient reports persistent cough, high fever, and difficulty breathing for the past three days.",
    "Mild headache and occasional dizziness, but no fever or cough.",
    "Patient exhibits cyanosis and severe chest pain, indicative of a critical condition.",
)

In [None]:
# Plot one head-averaged attention heatmap per clinical note.
# Previously visualize_attention() was called three times with the results
# discarded: nothing was drawn, and only the last call's raw tensor tuple
# was echoed as the cell output.
for note_idx, note in enumerate((note1, note2, note3), start=1):
    # Loop-local names avoid clobbering `attention` / `tokens` / `attn_scores`
    # produced by earlier cells.
    note_attention, note_tokens = visualize_attention(note)
    note_scores = note_attention[0].mean(dim=0).numpy()
    note_labels = tokenizer.convert_ids_to_tokens(note_tokens['input_ids'][0])

    fig, ax = plt.subplots(figsize=(12, 8))
    sns.heatmap(
        note_scores,
        xticklabels=note_labels,
        yticklabels=note_labels,
        cmap='viridis',
        ax=ax,
    )
    ax.set_title(f"Attention Heatmap (note {note_idx})")
    plt.show()