In [17]:
import os
project_cache_dir = "./hf_cache"
os.makedirs(project_cache_dir, exist_ok=True)

In [18]:
import torch
from transformers import AutoTokenizer, AutoModel

# models
model_name = "SpanBERT/spanbert-base-cased"
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    cache_dir=project_cache_dir,
    lcoal_files_only=False
)
encoder = AutoModel.from_pretrained(
    model_name,
    cache_dir=project_cache_dir,
    local_files_only=False
)
encoder.eval()


Some weights of BertModel were not initialized from the model checkpoint at SpanBERT/spanbert-base-cased and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(28996, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False

In [19]:
# gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder.to(device)
print("[Stage 1] Loaded SpanBERT-base model on device:", device)

[Stage 1] Loaded SpanBERT-base model on device: cuda


In [20]:
# test input
text = ("Lionel Messi scored a brilliant goal. "
        "The Argentine forward celebrated with his teammates. "
        "He is considered one of the best players in the world.")
print("[Stage 2] Input text:\n", text)

[Stage 2] Input text:
 Lionel Messi scored a brilliant goal. The Argentine forward celebrated with his teammates. He is considered one of the best players in the world.


In [21]:
# inference
tokens = tokenizer(text, return_tensors="pt")
tokens = {k:v.to(device) for k,v in tokens.items()}

with torch.no_grad():
    outputs = encoder(**tokens)

token_embeddings = outputs.last_hidden_state.squeeze(0)  # [seq_len, hidden_size]
print("\n[Stage 3] Token embeddings shape:", token_embeddings.shape)


[Stage 3] Token embeddings shape: torch.Size([33, 768])


In [22]:
max_span_width = 2
span_indices = []
for i in range(token_embeddings.size(0)):
    for w in range(1, max_span_width+1):
        if i + w <= token_embeddings.size(0):
            span_indices.append((i, i + w - 1))
print("\n[Stage 4] Candidate spans (token indices):\n", span_indices)


[Stage 4] Candidate spans (token indices):
 [(0, 0), (0, 1), (1, 1), (1, 2), (2, 2), (2, 3), (3, 3), (3, 4), (4, 4), (4, 5), (5, 5), (5, 6), (6, 6), (6, 7), (7, 7), (7, 8), (8, 8), (8, 9), (9, 9), (9, 10), (10, 10), (10, 11), (11, 11), (11, 12), (12, 12), (12, 13), (13, 13), (13, 14), (14, 14), (14, 15), (15, 15), (15, 16), (16, 16), (16, 17), (17, 17), (17, 18), (18, 18), (18, 19), (19, 19), (19, 20), (20, 20), (20, 21), (21, 21), (21, 22), (22, 22), (22, 23), (23, 23), (23, 24), (24, 24), (24, 25), (25, 25), (25, 26), (26, 26), (26, 27), (27, 27), (27, 28), (28, 28), (28, 29), (29, 29), (29, 30), (30, 30), (30, 31), (31, 31), (31, 32), (32, 32)]


In [23]:
span_embeddings = torch.stack([token_embeddings[start] + token_embeddings[end] for start, end in span_indices])
print("\n[Stage 5] Span embeddings shape:", span_embeddings.shape)


[Stage 5] Span embeddings shape: torch.Size([65, 768])


In [24]:
scores = torch.matmul(span_embeddings, span_embeddings.T)
scores = torch.tril(scores, diagonal=-1)
print("\n[Stage 6] Antecedent scores matrix shape:", scores.shape)


[Stage 6] Antecedent scores matrix shape: torch.Size([65, 65])


In [25]:
# clustering 
clusters = []
span_to_cluster = {}
for i, (start, end) in enumerate(span_indices):
    if i == 0:
        span_to_cluster[i] = len(clusters)
        clusters.append([(start, end)])
        continue
    antecedents = scores[i,:i]
    if antecedents.max() > .2:
        best = antecedents.argmax().item()
        span_to_cluster[i] = span_to_cluster[best]
        clusters[span_to_cluster[i]].append((start, end))
    else:
        span_to_cluster[i] = len(clusters)
        clusters.append([(start, end)])
        
print("\n[Stage 7] Clusters (token indices):\n", clusters)


[Stage 7] Clusters (token indices):
 [[(0, 0), (0, 1), (1, 1), (1, 2), (2, 2), (2, 3), (3, 3), (3, 4), (4, 4), (4, 5), (5, 5), (5, 6), (6, 6), (6, 7), (7, 7), (7, 8), (8, 8), (8, 9), (9, 9), (9, 10), (10, 10), (10, 11), (11, 11), (11, 12), (12, 12), (12, 13), (13, 13), (13, 14), (14, 14), (14, 15), (15, 15), (15, 16), (16, 16), (16, 17), (17, 17), (17, 18), (18, 18), (18, 19), (19, 19), (19, 20), (20, 20), (20, 21), (21, 21), (21, 22), (22, 22), (22, 23), (23, 23), (23, 24), (24, 24), (24, 25), (25, 25), (25, 26), (26, 26), (26, 27), (27, 27), (27, 28), (28, 28), (28, 29), (29, 29), (29, 30), (30, 30), (30, 31), (31, 31), (31, 32), (32, 32)]]


In [26]:
# conversion back to text
token_list = tokenizer.convert_ids_to_tokens(tokens["input_ids"].squeeze())
clusters_text = []
for cluster in clusters:
    clusters_text.append([" ".join(token_list[start:end+1]) for start, end in cluster])


In [27]:
# output
print("Candidate clusters (resolved mentions):")
for c in clusters_text:
    print(c)


Candidate clusters (resolved mentions):
['[CLS]', '[CLS] lion', 'lion', 'lion ##el', '##el', '##el mess', 'mess', 'mess ##i', '##i', '##i scored', 'scored', 'scored a', 'a', 'a brilliant', 'brilliant', 'brilliant goal', 'goal', 'goal .', '.', '. the', 'the', 'the a', 'a', 'a ##rgent', '##rgent', '##rgent ##ine', '##ine', '##ine forward', 'forward', 'forward celebrated', 'celebrated', 'celebrated with', 'with', 'with his', 'his', 'his teammates', 'teammates', 'teammates .', '.', '. he', 'he', 'he is', 'is', 'is considered', 'considered', 'considered one', 'one', 'one of', 'of', 'of the', 'the', 'the best', 'best', 'best players', 'players', 'players in', 'in', 'in the', 'the', 'the world', 'world', 'world .', '.', '. [SEP]', '[SEP]']


In [28]:
for i, (start, end) in enumerate(span_indices):
    span_text = " ".join(token_list[start:end+1])
    if i == 0:
        print(f"{span_text:20} -> None")
        continue
    antecedents = scores[i,:i]
    top_score, top_idx = antecedents.max(0)
    top_span = " ".join(token_list[span_indices[top_idx][0]:span_indices[top_idx][1]+1])
    print(f"{span_text:20} -> {top_span:20} (score={top_score.item():.2f})")

[CLS]                -> None
[CLS] lion           -> [CLS]                (score=48.06)
lion                 -> [CLS] lion           (score=50.67)
lion ##el            -> lion                 (score=53.70)
##el                 -> lion ##el            (score=72.69)
##el mess            -> ##el                 (score=80.84)
mess                 -> ##el mess            (score=60.57)
mess ##i             -> ##el                 (score=61.99)
##i                  -> ##el                 (score=70.94)
##i scored           -> ##i                  (score=80.64)
scored               -> ##i scored           (score=90.88)
scored a             -> scored               (score=99.10)
a                    -> scored a             (score=107.03)
a brilliant          -> a                    (score=99.97)
brilliant            -> a brilliant          (score=87.02)
brilliant goal       -> brilliant            (score=84.16)
goal                 -> brilliant goal       (score=86.85)
goal .               -> go

In [31]:
# clusters_text
clusters_text = []
for cluster in clusters:
    clusters_text.append([" ".join(token_list[start:end+1]) for start, end in cluster])
print("\n[Final Output] Candidate clusters (resolved mentions):", clusters_text)

# Only clusters with â‰¥2 mentions
relevant_clusters = [c for c in clusters_text if len(c) > 1]

print("\n[Relevant Entity Clusters]:")
for c in relevant_clusters:
    print(c)



[Final Output] Candidate clusters (resolved mentions): [['[CLS]', '[CLS] lion', 'lion', 'lion ##el', '##el', '##el mess', 'mess', 'mess ##i', '##i', '##i scored', 'scored', 'scored a', 'a', 'a brilliant', 'brilliant', 'brilliant goal', 'goal', 'goal .', '.', '. the', 'the', 'the a', 'a', 'a ##rgent', '##rgent', '##rgent ##ine', '##ine', '##ine forward', 'forward', 'forward celebrated', 'celebrated', 'celebrated with', 'with', 'with his', 'his', 'his teammates', 'teammates', 'teammates .', '.', '. he', 'he', 'he is', 'is', 'is considered', 'considered', 'considered one', 'one', 'one of', 'of', 'of the', 'the', 'the best', 'best', 'best players', 'players', 'players in', 'in', 'in the', 'the', 'the world', 'world', 'world .', '.', '. [SEP]', '[SEP]']]

[Relevant Entity Clusters]:
['[CLS]', '[CLS] lion', 'lion', 'lion ##el', '##el', '##el mess', 'mess', 'mess ##i', '##i', '##i scored', 'scored', 'scored a', 'a', 'a brilliant', 'brilliant', 'brilliant goal', 'goal', 'goal .', '.', '. the'