In [1]:
import os
import torch
from transformers import AutoTokenizer, AutoModel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

# ============================================================
# 0. LOCAL PROJECT CACHE
# ============================================================

project_cache_dir = "./hf_cache"
os.makedirs(project_cache_dir, exist_ok=True)
os.environ["TRANSFORMERS_CACHE"] = project_cache_dir
os.environ["HF_DATASETS_CACHE"] = project_cache_dir

In [3]:

# ============================================================
# 1. LOAD MODEL + TOKENIZER (SpanBERT-base)
# ============================================================

model_name = "SpanBERT/spanbert-base-cased"

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    cache_dir=project_cache_dir
)

model = AutoModel.from_pretrained(
    model_name,
    cache_dir=project_cache_dir
)

Some weights of BertModel were not initialized from the model checkpoint at SpanBERT/spanbert-base-cased and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

print("\n[INFO] Model + tokenizer loaded.\n")


[INFO] Model + tokenizer loaded.



In [5]:

# ============================================================
# 2. INPUT TEXT
# ============================================================

text = "Lionel Messi scored a brilliant goal. The Argentine forward celebrated with his teammates. He is considered one of the best players in the world."

# Tokenize
enc = tokenizer(text, return_tensors="pt", truncation=True)
input_ids = enc["input_ids"].to(device)
attention_mask = enc["attention_mask"].to(device)

tokens = tokenizer.convert_ids_to_tokens(input_ids[0].cpu().tolist())

print("[TOKENS]:")
print(tokens, "\n")


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


[TOKENS]:
['[CLS]', 'lion', '##el', 'mess', '##i', 'scored', 'a', 'brilliant', 'goal', '.', 'the', 'a', '##rgent', '##ine', 'forward', 'celebrated', 'with', 'his', 'teammates', '.', 'he', 'is', 'considered', 'one', 'of', 'the', 'best', 'players', 'in', 'the', 'world', '.', '[SEP]'] 



In [6]:
# Get embeddings
with torch.no_grad():
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    hidden_states = outputs.last_hidden_state.squeeze(0)   # [seq_len, hidden]
    hidden_states = hidden_states.cpu()


In [7]:
# ============================================================
# 3. CREATE CANDIDATE SPANS  (single tokens only for demo)
# ============================================================

spans = []
span_embeddings = []

for i in range(1, len(tokens)-1):  # skip CLS / SEP
    span = [tokens[i]]
    spans.append(span)
    span_embeddings.append(hidden_states[i])

print("[RAW SPANS]:")
for s in spans:
    print(s)
print("\n")


[RAW SPANS]:
['lion']
['##el']
['mess']
['##i']
['scored']
['a']
['brilliant']
['goal']
['.']
['the']
['a']
['##rgent']
['##ine']
['forward']
['celebrated']
['with']
['his']
['teammates']
['.']
['he']
['is']
['considered']
['one']
['of']
['the']
['best']
['players']
['in']
['the']
['world']
['.']




In [8]:
# ============================================================
# 4. BUILD NAIVE CLUSTERS BASED ON SIMILARITY
# ============================================================

clusters_text = []
clusters_embeds = []

for i, span_vec in enumerate(span_embeddings):
    if not clusters_text:
        clusters_text.append([spans[i][0]])
        clusters_embeds.append([span_vec])
        continue

    best_score = -1e9
    best_cluster = None

    for c_idx, c in enumerate(clusters_embeds):
        c_centroid = torch.stack(c).mean(dim=0)
        score = torch.dot(span_vec, c_centroid).item()
        if score > best_score:
            best_score = score
            best_cluster = c_idx

    clusters_text[best_cluster].append(spans[i][0])
    clusters_embeds[best_cluster].append(span_vec)

print("============================================")
print(" RAW CLUSTERS (UNMERGED)")
print("============================================")
for idx, cluster in enumerate(clusters_text):
    print(f"Cluster {idx+1}: {cluster}")
print("\n")


 RAW CLUSTERS (UNMERGED)
Cluster 1: ['lion', '##el', 'mess', '##i', 'scored', 'a', 'brilliant', 'goal', '.', 'the', 'a', '##rgent', '##ine', 'forward', 'celebrated', 'with', 'his', 'teammates', '.', 'he', 'is', 'considered', 'one', 'of', 'the', 'best', 'players', 'in', 'the', 'world', '.']




In [9]:
# ============================================================
# 5. MERGE SUBWORDS (##tokens)
# ============================================================

def merge_subwords(tokens):
    out = []
    current = ""
    for tok in tokens:
        if tok.startswith("##"):
            current += tok[2:]
        else:
            if current:
                out.append(current)
            current = tok
    if current:
        out.append(current)
    return out

clean_clusters = []

for cluster in clusters_text:
    merged = merge_subwords(cluster)
    if len(merged) >= 1:
        clean_clusters.append(merged)


In [10]:
# ============================================================
# 6. DISPLAY CLEAN CLUSTERS
# ============================================================

print("============================================")
print(" CLEANED & MERGED CLUSTERS")
print("============================================")

for idx, cluster in enumerate(clean_clusters):
    # choose canonical name:
    canonical = next((w for w in cluster if w[0].isupper()), cluster[0])
    print(f"Cluster {idx+1} (canonical: {canonical}): {cluster}")

print("\n[FINISHED]\n")


 CLEANED & MERGED CLUSTERS
Cluster 1 (canonical: lionel): ['lionel', 'messi', 'scored', 'a', 'brilliant', 'goal', '.', 'the', 'argentine', 'forward', 'celebrated', 'with', 'his', 'teammates', '.', 'he', 'is', 'considered', 'one', 'of', 'the', 'best', 'players', 'in', 'the', 'world', '.']

[FINISHED]

