# Scratch notebook: OntoNotes entity-completion labels and BERT [CLS] embeddings

In [29]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel,AutoModelForTokenClassification
import torch
import numpy as np
from tqdm import tqdm
import gc

In [30]:
# Load the OntoNotes v5 corpus (English, v12 configuration) and keep a handle on
# the training split, which the embedding cells below iterate over.
ontonotes = load_dataset("conll2012_ontonotesv5", "english_v12")
train_data = ontonotes["train"]

# Load the tokenizer and two views of the same checkpoint:
#   model     - bare encoder, used to read position-0 ([CLS]) hidden states
#   ner_model - encoder plus token-classification head, used for tag predictions
model_name = "dslim/bert-base-NER"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
ner_model = AutoModelForTokenClassification.from_pretrained(model_name)

# Pick an accelerator if one is available. `device` is always a torch.device so
# downstream code never has to distinguish a str from a device object.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Only the encoder is moved to the accelerator; ner_model stays where
# from_pretrained put it. Both are switched to inference mode.
model.to(device)
model.eval()
ner_model.eval()

# Tag names indexed by the integer ids stored in each sentence's
# `named_entities` field. Reading them from the dataset schema (instead of
# hand-writing the list) guarantees that ids and names cannot drift apart.
label_list = train_data.features["sentences"][0]["named_entities"].feature.names


Some weights of the model checkpoint at dslim/bert-base-NER were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Using device: cpu


In [31]:

def extract_entities(words, ner_tags):
    """
    Build one 0/1 label per word marking where a BIO-tagged entity is completed.

    A position is labelled 1 when it is the last token of an entity span
    (a "B-X" tag followed by zero or more "I-X" tags of the same type) AND the
    pair (span words, type) has not already been completed earlier in this
    call. Repeated mentions of the same entity therefore get 1 only once.

    An "I-X" tag that does not continue an open span of type X is treated like
    "O": it opens nothing and is labelled 0.

    words: list of word strings.
    ner_tags: list of BIO tag strings, indexed in parallel with `words`.

    Returns:
        list of ints (0 or 1), one per element of `words`.
    """
    emitted = set()      # (span words, type) pairs already labelled 1
    labels = []
    span_words = []      # words of the span currently being built
    span_type = None     # entity type of that span, or None when no span is open

    for pos, token in enumerate(words):
        tag = ner_tags[pos]
        # Past the end of the tag list the sentence is treated as ending in 'O'.
        upcoming = ner_tags[pos + 1] if pos + 1 < len(ner_tags) else 'O'

        if tag.startswith("B-"):
            span_words, span_type = [token], tag[2:]
        elif tag.startswith("I-") and span_type == tag[2:]:
            span_words.append(token)
        else:
            # 'O', or an I- tag with no matching open span: drop any state.
            span_words, span_type = [], None
            labels.append(0)
            continue

        # The span stays open only if the next tag is an I- tag of the same type.
        if upcoming.startswith("I-") and upcoming[2:] == span_type:
            labels.append(0)
            continue

        # Span ends here: label 1 only the first time this exact entity is seen.
        key = (tuple(span_words), span_type)
        if key in emitted:
            labels.append(0)
        else:
            emitted.add(key)
            labels.append(1)
        span_words, span_type = [], None

    return labels



def process_sentence(words, ner_ids):
    """
    Embed every prefix of one sentence and pair it with its entity-completion label.

    words: list of word strings for the sentence.
    ner_ids: list of integer tag ids, one per word; converted to BIO tag strings
        through the module-level `label_list`.

    Returns:
        tuple: (x_list, y_list). x_list[k] is the NumPy vector read from
        position 0 of the encoder's last hidden state for " ".join(words[:k + 1]);
        y_list[k] is the 0/1 label extract_entities assigns to word k.
    """
    ner_tags = [label_list[i] for i in ner_ids]

    # One 0/1 label per word: 1 where an entity is completed.
    timestep_labels = extract_entities(words, ner_tags)

    x_list, y_list = [], []
    for end, y in enumerate(timestep_labels, start=1):
        partial_text = " ".join(words[:end])

        # The encoder was moved to `device`, so its inputs must be there too.
        # truncation=True matches process_sentence_batch and keeps over-long
        # prefixes from exceeding the model's maximum input length.
        inputs = tokenizer(partial_text, truncation=True, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
            # .cpu() is required before .numpy() when running on an accelerator.
            cls = outputs.last_hidden_state[0][0].cpu().numpy()

        x_list.append(cls)
        y_list.append(y)

    return x_list, y_list

def process_sentence_evaluation(words, ner_tags):
    """
    Summarise one BIO-tagged sentence for evaluation.

    words: list of word strings.
    ner_tags: list of BIO tag strings, one per word.

    Returns a 2-tuple:
        (distinct tags in order of first appearance,
         list of sentence prefixes -- each a list of words -- that end at a
         position extract_entities labels 1)
    """
    # dict keys keep insertion order, so this de-duplicates the tags while
    # preserving the order in which each one first appears.
    distinct_tags = list(dict.fromkeys(ner_tags))

    # 0/1 flag per word; 1 marks the word on which an entity is completed.
    completion_flags = extract_entities(words, ner_tags)

    prefixes = [
        words[:end]
        for end in range(1, len(words) + 1)
        if completion_flags[end - 1] == 1
    ]

    return (distinct_tags, prefixes)


def process_sentence_batch(sentences_data, batch_size=32):
    """
    Embed every prefix of every sentence, running the encoder in batches.

    sentences_data: iterable of (words, ner_tags) pairs, where `words` is a list
        of word strings and `ner_tags` the matching list of BIO tag strings.
    batch_size: number of prefix texts sent through the encoder per forward pass.

    Returns:
        tuple: (X_batch, Y_batch). X_batch is a list with one NumPy vector per
        prefix (position 0 of the encoder's last hidden state); Y_batch is the
        parallel list of 0/1 labels produced by extract_entities.
    """
    all_partial_texts = []
    all_labels = []

    # Expand each sentence into its prefixes ("w1", "w1 w2", ...) together with
    # the label of the word that ends each prefix.
    for words, ner_tags in sentences_data:
        timestep_labels = extract_entities(words, ner_tags)
        for end, label in enumerate(timestep_labels, start=1):
            all_partial_texts.append(" ".join(words[:end]))
            all_labels.append(label)

    # `device` may be a str or a torch.device; normalise once to read its type.
    device_type = torch.device(device).type

    X_batch, Y_batch = [], []

    for start in tqdm(range(0, len(all_partial_texts), batch_size), desc="Processing batches"):
        batch_texts = all_partial_texts[start:start + batch_size]
        batch_labels = all_labels[start:start + batch_size]

        # Pad to the longest text in the batch and place tensors on the model's device.
        inputs = tokenizer(
            batch_texts,
            padding=True,
            truncation=True,
            return_tensors="pt"
        ).to(device)

        with torch.no_grad():
            outputs = model(**inputs)
            # Position 0 of every sequence, copied back to host memory.
            cls_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()

        X_batch.extend(cls_embeddings)
        Y_batch.extend(batch_labels)

        # Release this batch's tensors and hand cached accelerator memory back.
        del inputs, outputs
        if device_type == "cuda":
            torch.cuda.empty_cache()
        elif device_type == "mps":
            torch.mps.empty_cache()

    return X_batch, Y_batch

def collect_sentence_data(data_subset, max_docs=None, labels=None):
    """
    Flatten documents into a list of (words, ner_tags) sentence pairs.

    data_subset: iterable of documents; each document is a mapping whose
        "sentences" entry is a list of mappings with "words" (list of strings)
        and "named_entities" (list of integer tag ids).
    max_docs: maximum number of documents to read. None means no limit;
        0 yields an empty result.
    labels: optional sequence mapping tag id -> BIO tag string. When omitted,
        the module-level `label_list` is used.

    Returns:
        list of (words, ner_tags) tuples, one per sentence, in document order.
    """
    tag_names = label_list if labels is None else labels

    sentences_data = []
    for doc_count, ex in enumerate(data_subset):
        # Compare against None explicitly so that max_docs=0 is honoured
        # instead of being mistaken for "no limit".
        if max_docs is not None and doc_count >= max_docs:
            break

        for sentence in ex["sentences"]:
            words = sentence["words"]
            ner_tags = [tag_names[i] for i in sentence["named_entities"]]
            sentences_data.append((words, ner_tags))

    return sentences_data


def process_sentence_evaluation_version2(words, ner_pred_ids, label_list):
    """
    Same summary as process_sentence_evaluation, starting from predicted label ids.

    words: list of words in the sentence
    ner_pred_ids: list of predicted label ids (integers), one per word
    label_list: sequence mapping label id -> BIO tag string,
        e.g. ['O', 'B-PER', 'I-PER', ...]

    Returns:
        tuple: (unique_tags, list_of_partial_sentences_where_entity_ends)
    """
    # Decode ids into tag strings.
    decoded_tags = [label_list[pred_id] for pred_id in ner_pred_ids]

    # First-occurrence order of the distinct tags.
    ordered_unique = []
    already_listed = set()
    for tag in decoded_tags:
        if tag in already_listed:
            continue
        already_listed.add(tag)
        ordered_unique.append(tag)

    # 1 at every word on which an entity is completed, 0 elsewhere.
    end_flags = extract_entities(words, decoded_tags)

    # Keep the sentence prefix up to and including each flagged word.
    prefixes_at_entity_end = []
    for end in range(1, len(words) + 1):
        if end_flags[end - 1] != 1:
            continue
        prefixes_at_entity_end.append(words[:end])

    return (ordered_unique, prefixes_at_entity_end)


# Testing dataset oracle

In [26]:
# Peek at the test split: its summary, then the first sentence of the first document.
test_data = ontonotes["test"]
print(f"test_data: {test_data}")

# Index the row before the column: test_data["sentences"] can materialise the
# "sentences" field of every document just to look at one of them.
first_sentence = test_data[0]["sentences"][0]
print((first_sentence["words"]))
print(first_sentence)


test_data: Dataset({
    features: ['document_id', 'sentences'],
    num_rows: 1200
})
['--', 'basically', ',', 'it', 'was', 'unanimously', 'agreed', 'upon', 'by', 'the', 'various', 'relevant', 'parties', '.']
{'part_id': 0, 'words': ['--', 'basically', ',', 'it', 'was', 'unanimously', 'agreed', 'upon', 'by', 'the', 'various', 'relevant', 'parties', '.'], 'pos_tags': [9, 33, 5, 31, 41, 33, 43, 18, 18, 14, 19, 19, 28, 8], 'parse_tree': '(TOP(S (: --) (ADVP (RB basically) ) (, ,) (NP (PRP it) )(VP (VBD was) (ADVP (RB unanimously) )(VP (VBN agreed) (PP (IN upon) )(PP (IN by) (NP (DT the)  (JJ various)  (JJ relevant)  (NNS parties) )))) (. .) ))', 'predicate_lemmas': [None, None, None, None, 'be', None, 'agree', None, None, None, None, None, None, None], 'predicate_framenet_ids': [None, None, None, None, '03', None, '01', None, None, None, None, None, None, None], 'word_senses': [None, None, None, None, None, None, 1.0, None, None, None, None, None, None, None], 'speaker': 'speaker#1', 'na

In [23]:
# Hand-made example: three entities (one two-word PER, one LOC, one two-word LOC).
words = "Barack Obama was born in Hawaii and lives in Washington DC".split()
ner_tags = "B-PER I-PER O O O B-LOC O O O B-LOC I-LOC".split()

# Expect the distinct tags plus one sentence prefix per completed entity.
print(process_sentence_evaluation(words=words, ner_tags=ner_tags))


(['B-PER', 'I-PER', 'O', 'B-LOC', 'I-LOC'], [['Barack', 'Obama'], ['Barack', 'Obama', 'was', 'born', 'in', 'Hawaii'], ['Barack', 'Obama', 'was', 'born', 'in', 'Hawaii', 'and', 'lives', 'in', 'Washington', 'DC']])


In [33]:
# Tag names of the checkpoint's own label set, ordered by id. They are kept in
# their own variable so the dataset tag mapping in `label_list`, which
# process_sentence and collect_sentence_data read, is not overwritten.
id2label = ner_model.config.id2label  # dict {id: label}
ner_label_list = [id2label[i] for i in range(len(id2label))]  # list ordered by id

words = ["Barack", "Obama", "was", "born", "in", "Hawaii"]
inputs = tokenizer(words, is_split_into_words=True, return_tensors="pt", truncation=True)

with torch.no_grad():
    # Feed the tensors on whichever device ner_model currently lives on.
    outputs = ner_model(**inputs.to(ner_model.device))
    logits = outputs.logits  # shape: [batch_size, seq_len, num_labels]

# One sentence was tokenized, so take batch row 0 explicitly; unlike squeeze(),
# this cannot collapse the sequence axis as well.
predicted_token_ids = torch.argmax(logits, dim=-1)[0].tolist()

# Reduce token-level predictions to word level by keeping the prediction of
# each word's first sub-token; special tokens (word id None) and later
# sub-tokens of the same word are skipped.
word_ids = inputs.word_ids()  # list with word index for each token
word_pred_ids = []
current_word = None
for idx, word_id in enumerate(word_ids):
    if word_id is None:
        continue
    if word_id != current_word:
        word_pred_ids.append(predicted_token_ids[idx])
        current_word = word_id

# Summarise the predicted tags with the evaluation helper.
result = process_sentence_evaluation_version2(words, word_pred_ids, ner_label_list)

print("Unique tags in sentence:", result[0])
print("Partial sentences where entities end:")
for partial in result[1]:
    print(partial)

Unique tags in sentence: ['B-PER', 'I-PER', 'O', 'B-LOC']
Partial sentences where entities end:
['Barack', 'Obama']
['Barack', 'Obama', 'was', 'born', 'in', 'Hawaii']


# Deprecated? Embedding extraction runs (10-document test, then the full training split)

In [18]:
# Test with smaller subset first. The document count lives in one place so the
# messages below always agree with what was actually requested.
n_test_docs = 10
print(f"Testing with first {n_test_docs} documents...")
sentences_data = collect_sentence_data(train_data, max_docs=n_test_docs)
print(f"Collected {len(sentences_data)} sentences from {n_test_docs} documents")

# Process in batches
X_all, Y_all = process_sentence_batch(sentences_data, batch_size=256)

print(f"Collected {len(X_all)} samples from test run.")
print(f"Positive samples: {sum(Y_all)}")
print(f"Negative samples: {len(Y_all) - sum(Y_all)}")

# Save test results: X holds one embedding per row, Y the matching 0/1 labels.
X_array = np.array(X_all)
Y_array = np.array(Y_all)
np.savez("ontonotes_embeddings_test.npz", X=X_array, Y=Y_array)
print("Saved test results to ontonotes_embeddings_test.npz")

# Clean up memory. The trailing semicolon keeps gc.collect()'s return value
# from being echoed as the cell output.
del X_all, Y_all, X_array, Y_array
gc.collect();

Testing with first 10 documents...
Collected 5445 sentences from 10 documents


Processing batches: 100%|██████████| 285/285 [02:25<00:00,  1.96it/s]


Collected 72860 samples from test run.
Positive samples: 4591
Negative samples: 68269
Saved test results to ontonotes_embeddings_test.npz


0

In [5]:
# Embed the whole training split. Documents are handled in fixed-size slices so
# that only one slice's sentence data is alive at a time; the embeddings and
# labels themselves accumulate in X_all / Y_all across slices.
chunk_size = 1000  # Process 1000 documents at a time
total_docs = len(train_data)
X_all, Y_all = [], []

for chunk_index, chunk_start in enumerate(range(0, total_docs, chunk_size)):
    chunk_end = min(chunk_start + chunk_size, total_docs)
    print(f"\nProcessing documents {chunk_start} to {chunk_end-1}...")

    # Slice out this chunk's documents and flatten them into sentences.
    chunk_data = train_data.select(range(chunk_start, chunk_end))
    sentences_data = collect_sentence_data(chunk_data)

    # Embed every sentence prefix of the chunk and append to the running totals.
    X_chunk, Y_chunk = process_sentence_batch(sentences_data, batch_size=256)
    X_all += X_chunk
    Y_all += Y_chunk

    print(f"Chunk complete. Total samples so far: {len(X_all)}")

    # Drop the per-chunk references before the next iteration.
    del X_chunk, Y_chunk, sentences_data
    gc.collect()

    # Every fifth chunk, checkpoint everything accumulated so far.
    if chunk_index % 5 == 4:
        print("Saving intermediate results...")
        X_array = np.array(X_all)
        Y_array = np.array(Y_all)
        np.savez(f"ontonotes_embeddings_intermediate_{chunk_start}.npz", X=X_array, Y=Y_array)
        del X_array, Y_array

print(f"\nFinal: Collected {len(X_all)} samples.")
print(f"Positive samples: {sum(Y_all)}")
print(f"Negative samples: {len(Y_all) - sum(Y_all)}")

# Write the complete result: X is the embedding matrix, Y the 0/1 labels.
X_array = np.array(X_all)
Y_array = np.array(Y_all)
np.savez("ontonotes_embeddings_full.npz", X=X_array, Y=Y_array)
print("Saved full results to ontonotes_embeddings_full.npz")


Processing documents 0 to 999...


Processing batches: 100%|██████████| 1970/1970 [19:52<00:00,  1.65it/s]


Chunk complete. Total samples so far: 504191

Processing documents 1000 to 1999...


Processing batches: 100%|██████████| 1960/1960 [20:23<00:00,  1.60it/s]


Chunk complete. Total samples so far: 1005878

Processing documents 2000 to 2999...


Processing batches: 100%|██████████| 2246/2246 [21:46<00:00,  1.72it/s]


Chunk complete. Total samples so far: 1580775

Processing documents 3000 to 3999...


Processing batches: 100%|██████████| 1792/1792 [19:38<00:00,  1.52it/s]


Chunk complete. Total samples so far: 2039335

Processing documents 4000 to 4999...


Processing batches: 100%|██████████| 92/92 [01:07<00:00,  1.37it/s]


Chunk complete. Total samples so far: 2062809
Saving intermediate results...

Processing documents 5000 to 5999...


Processing batches: 100%|██████████| 98/98 [01:17<00:00,  1.27it/s]


Chunk complete. Total samples so far: 2087771

Processing documents 6000 to 6999...


Processing batches: 100%|██████████| 93/93 [01:20<00:00,  1.16it/s]


Chunk complete. Total samples so far: 2111389

Processing documents 7000 to 7999...


Processing batches: 100%|██████████| 104/104 [01:37<00:00,  1.07it/s]


Chunk complete. Total samples so far: 2137962

Processing documents 8000 to 8999...


Processing batches: 100%|██████████| 98/98 [01:23<00:00,  1.17it/s]


Chunk complete. Total samples so far: 2162815

Processing documents 9000 to 9999...


Processing batches: 100%|██████████| 96/96 [01:41<00:00,  1.05s/it]


Chunk complete. Total samples so far: 2187200
Saving intermediate results...

Processing documents 10000 to 10538...


Processing batches: 100%|██████████| 54/54 [01:02<00:00,  1.16s/it]


Chunk complete. Total samples so far: 2200865

Final: Collected 2200865 samples.
Positive samples: 125904
Negative samples: 2074961
Saved full results to ontonotes_embeddings_full.npz


In [6]:
# Note for readers: the generated .npz archives are not in the repository; this
# prints where they can be downloaded instead.
print("ontonotes_embeddings_*.npz are too large for git, therefore it's available online at https://drive.google.com/drive/folders/1ykTaDLdHIEmZQYN0b1Hr9hkOYjgMshSa?usp=sharing")

ontonotes_embeddings_*.npz are too large for git, therefore it's available online at https://drive.google.com/drive/folders/1ykTaDLdHIEmZQYN0b1Hr9hkOYjgMshSa?usp=sharing
