In [1]:
import os
import torch
from transformers import AutoTokenizer, AutoModel
from torch.nn.functional import cosine_similarity

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
project_cache_dir = "./hf_cache"
os.makedirs(project_cache_dir, exist_ok=True)
os.environ["TRANSFORMERS_CACHE"] = project_cache_dir
os.environ["HF_DATASETS_CACHE"] = project_cache_dir

In [5]:
# 1. Load SpanBERT directly from Hugging Face
# SpanBERT is architecturally just BERT, so AutoModel loads it perfectly.
model_name = "SpanBERT/spanbert-base-cased"
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=project_cache_dir)
model = AutoModel.from_pretrained(model_name, cache_dir=project_cache_dir)


Some weights of BertModel were not initialized from the model checkpoint at SpanBERT/spanbert-base-cased and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
# Move to GPU (8GB VRAM is plenty for this)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(28996, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False

In [7]:
# 2. Your Input
text = "Lionel Messi scored a brilliant goal. The Argentine forward celebrated with his teammates. He is considered the best."
# 3. Prepare Inputs
inputs = tokenizer(text, return_tensors="pt").to(device)


In [8]:
# 4. Run the Model (Forward Pass)
# This runs the self-attention layers. The output 'last_hidden_state'
# contains the vectors that "know" the context.
with torch.no_grad():
    outputs = model(**inputs)

# Shape: [Batch_Size, Sequence_Length, Hidden_Size (768)]
embeddings = outputs.last_hidden_state[0]


In [10]:
# --- CRITICAL STEP: Mapping Entities to Token Indices ---
# In a real pipeline, your BIO tagger gives you these offsets.
# Let's find where the tokens are manually for this test.
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

# Helper to find token range for a substring
def find_span(substring, all_tokens):
    # This is a naive finder for demonstration
    sub_tokens = tokenizer.tokenize(substring)
    length = len(sub_tokens)
    for i in range(len(all_tokens) - length):
        if all_tokens[i:i+length] == sub_tokens:
            return i, i + length # Start, End (exclusive)
    return None, None

# Find positions
messi_start, messi_end = find_span("Lionel Messi", tokens)
he_start, he_end = find_span("He", tokens)
goal_start, goal_end = find_span("goal", tokens)

print(f"Indices -> Messi: {messi_start}-{messi_end}, He: {he_start}-{he_end}, Goal: {goal_start}-{goal_end}")


Indices -> Messi: 1-5, He: 20-21, Goal: 8-9


In [11]:
# 5. Extract Span Representations (Average Pooling)
# We average the embeddings of the tokens in the span to get one vector per entity.
messi_vec = torch.mean(embeddings[messi_start:messi_end], dim=0).unsqueeze(0)
he_vec    = torch.mean(embeddings[he_start:he_end], dim=0).unsqueeze(0)
goal_vec  = torch.mean(embeddings[goal_start:goal_end], dim=0).unsqueeze(0)


In [12]:
# 6. Calculate Similarity
# High similarity means the Self-Attention mechanism contextualized them similarly.
score_messi_he = cosine_similarity(messi_vec, he_vec).item()
score_messi_goal = cosine_similarity(messi_vec, goal_vec).item()

print("\n--- Results ---")
print(f"Similarity ('Lionel Messi' vs 'He'):   {score_messi_he:.4f}")
print(f"Similarity ('Lionel Messi' vs 'goal'): {score_messi_goal:.4f}")

if score_messi_he > score_messi_goal:
    print(">> 'He' refers to 'Lionel Messi'")
else:
    print(">> 'He' refers to 'goal'")


--- Results ---
Similarity ('Lionel Messi' vs 'He'):   0.7052
Similarity ('Lionel Messi' vs 'goal'): 0.5806
>> 'He' refers to 'Lionel Messi'
