In [69]:
import pandas as pd
import numpy as np
import time
import json
from transformers import AutoTokenizer
from datasets import Dataset
from tqdm import tqdm
from joblib import Parallel, delayed

def get_merged_preds_vectorized(predictions, token_idxs_mapping):
    # Ensure predictions is a 2D array for compatibility with DataFrame
    predictions = np.atleast_2d(predictions)
    
    # Create a DataFrame from token_idxs_mapping and predictions
    df = pd.DataFrame({
        'token_idx': token_idxs_mapping,
        'predictions': list(predictions)
    })
    
    # Exclude -1 from averaging, if necessary
    df = df[df['token_idx'] != -1]
    
    # Group by token_idx and average predictions
    averaged_df = df.groupby('token_idx')['predictions'].apply(lambda x: np.mean(np.vstack(x), axis=0)).reset_index()
    
    # Map averaged predictions back to the original order
    averaged_predictions = np.array(df['token_idx'].map(averaged_df.set_index('token_idx')['predictions']).tolist())
    
    # Handle -1 indices if necessary, assuming original predictions for -1 indices are kept
    minus_one_indices = np.where(token_idxs_mapping == -1)[0]
    if minus_one_indices.size > 0:
        for idx in minus_one_indices:
            averaged_predictions[idx] = predictions[idx]
    
    return np.array(averaged_predictions)


def get_merged_preds(predictions, token_idxs_mapping):
    if token_idxs_mapping is not None:
        # Initialize averaged_predictions with the same shape as predictions
        averaged_predictions = np.array(predictions)

        unique_token_idxs = set(token_idxs_mapping).difference(set([-1]))
        
        # Iterate over each unique token index to average predictions
        for token_idx in unique_token_idxs:
            # Find the indices in token_idxs_mapping that match the current token_idx
            indices = np.where(np.array(token_idxs_mapping) == token_idx)[0]
            
            # Average the predictions for these indices and assign to the correct positions
            averaged_predictions[indices] = np.mean(predictions[indices], axis=0)
        
        # Use the averaged predictions for further processing
        predictions = averaged_predictions
    
    return predictions, predictions



In [50]:
def tokenize(example, tokenizer, label2id, max_length):
    """tokenize the examples"""
    text = []
    labels = []
    token_map = [] # Character index to spacy token mapping

    token_map_idx = 0
    for t, l, ws in zip(example["tokens"], example["provided_labels"], example["trailing_whitespace"]):
        text.append(t)
        labels.extend([l]*len(t))
        token_map.extend([token_map_idx] * len(t))
        if ws:
            text.append(" ")
            labels.append("O")
            token_map.append(-1)

        token_map_idx += 1


    tokenized = tokenizer("".join(text), return_offsets_mapping=True, truncation = True, max_length=max_length, return_overflowing_tokens=True, stride = 256)
    
    labels = np.array(labels)
    
    text = "".join(text)
    token_labels = []
    token_idxs_mapping = [] # Represents the mapping of deberta token idx to spacy token idx. We can potentially merge the predictions of these tokens
    num_sequences = len(tokenized["input_ids"])
    for sequence_idx in range(num_sequences):
        offset_mapping_sequence = tokenized["offset_mapping"][sequence_idx]
        token_labels_sequence = []
        token_idxs_mapping_sequence = []
        for start_idx, end_idx in offset_mapping_sequence:
            
            # CLS token
            if start_idx == 0 and end_idx == 0: 
                token_idxs_mapping_sequence.append(-1)
                token_labels_sequence.append(label2id["O"])
                continue
            
            # case when token starts with whitespace
            if text[start_idx].isspace():
                start_idx += 1
            
            while start_idx >= len(labels):
                start_idx -= 1
                
            token_labels_sequence.append(label2id[labels[start_idx]])
            token_idxs_mapping_sequence.append(token_map[start_idx])
        
        token_labels.append(token_labels_sequence)
        token_idxs_mapping.append(token_idxs_mapping_sequence)
    
    token_map = [token_map for _ in range(num_sequences)]
    document = [example["document"] for _ in range(num_sequences)]
    fold = [example["fold"] for _ in range(num_sequences)]
    tokens = [example["tokens"] for _ in range(num_sequences)]
        
    return {
        **tokenized,
        "labels": token_labels,
        "token_map": token_map,
        "document": document,
        "fold": fold,
        "tokens": tokens,
        "token_idxs_mapping": token_idxs_mapping
    }

In [51]:
#Add labels as global variable
LABELS = ['B-EMAIL',
        'B-ID_NUM',
        'B-NAME_STUDENT',
        'B-PHONE_NUM',
        'B-STREET_ADDRESS',
        'B-URL_PERSONAL',
        'B-USERNAME',
        'I-ID_NUM',
        'I-NAME_STUDENT',
        'I-PHONE_NUM',
        'I-STREET_ADDRESS',
        'I-URL_PERSONAL',
        'O']

In [52]:
with open("../data/train.json") as f:
    data = json.load(f)

ds = Dataset.from_dict({
    "full_text": [x["full_text"] for x in data],
    "document": [x["document"] for x in data],
    "tokens": [x["tokens"] for x in data],
    "trailing_whitespace": [x["trailing_whitespace"] for x in data],
    "provided_labels": [x["labels"] for x in data],
    "fold": [x["document"] % 4 for x in data]
})

label2id = {label: i for i, label in enumerate(LABELS)}
tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-large")
ds = ds.map(
    tokenize, 
    fn_kwargs={"tokenizer": tokenizer, "label2id": label2id, "max_length": 1024},
    num_proc = 4
).remove_columns(["full_text", "trailing_whitespace", "provided_labels"])



Map (num_proc=4):   0%|          | 0/6807 [00:00<?, ? examples/s]

  block_group = [InMemoryTable(cls._concat_blocks(list(block_group), axis=axis))]
  table = cls._concat_blocks(blocks, axis=0)


In [53]:
def build_flatten_ds(ds):
    features = list(ds.features.keys())
    dataset_dict = {feature: [] for feature in features}

    for example in tqdm(ds, total=len(ds)):
        #Also make sure everything is a list
        for feature in features:
            assert isinstance(example[feature], list), f"Feature {feature} is not a list"
        for feature in features:
            dataset_dict[feature].extend(example[feature])

    return Dataset.from_dict(dataset_dict)

In [54]:
ds = build_flatten_ds(ds)

100%|██████████| 6807/6807 [00:27<00:00, 243.17it/s]


In [55]:
valid_ds = ds.filter(lambda x: x["fold"] == 0, num_proc = 4)

Filter (num_proc=4):   0%|          | 0/7605 [00:00<?, ? examples/s]

  block_group = [InMemoryTable(cls._concat_blocks(list(block_group), axis=axis))]
  table = cls._concat_blocks(blocks, axis=0)


In [56]:
valid_ds

Dataset({
    features: ['document', 'tokens', 'fold', 'input_ids', 'token_type_ids', 'attention_mask', 'offset_mapping', 'overflow_to_sample_mapping', 'labels', 'token_map', 'token_idxs_mapping'],
    num_rows: 1899
})

In [57]:
preds = [np.random.rand(len(offsets), 13) for offsets in valid_ds["offset_mapping"]]
token_idxs_mapping = valid_ds["token_idxs_mapping"]

In [59]:
len(preds), len(token_idxs_mapping)

(1899, 1899)

In [63]:
tic = time.perf_counter()
for i, pred in enumerate(preds):
    averaged_predictions = get_merged_preds(pred, token_idxs_mapping[i])
toc = time.perf_counter()
print(f"Non-vectorized method took {toc - tic:0.4f} seconds")

Non-vectorized method took 34.8940 seconds


In [70]:
tic = time.perf_counter()
new_preds = Parallel(n_jobs=8)(delayed(get_merged_preds)(preds[i], token_idxs_mapping[i]) for i in range(len(preds)))
toc = time.perf_counter()
print(f"Parallel method took {toc - tic:0.4f} seconds")

Parallel method took 8.6647 seconds


In [71]:
len(new_preds), len(new_preds[0])

(1899, 2)

In [60]:
tic = time.perf_counter()
averaged_predictions = get_merged_preds_vectorized(preds, token_idxs_mapping)
toc = time.perf_counter()
print(f"Vectorized method took {toc - tic:0.4f} seconds")

  ary = asanyarray(ary)


ValueError: All arrays must be of the same length

Non-vectorized method took 34.9694 seconds
