This notebook looks at the output from the snippet repository
and how to use it to train NER, classification, and mlm models.

In [1]:
from functools import partial
from importlib import reload
from typing import List

import torch
from spacy import displacy
import src.data.snippet_repository as sr

### Named Enitiy Recognition Models (NER)

In [2]:
ner_repo = sr.SnippetRepository(sr.SnippetRepositoryMode.NER)

In [3]:
ner_data = ner_repo.get_training_data(batch_size=10)
detected = False
while not detected:
    ner_df = next(ner_data)
    detected = any(ner_df['ner_tags'].apply(lambda ner_tags: any(map(lambda t: t!="O", ner_tags))))

In [4]:
text = ner_df.iloc[9].text
ner_tags = ner_df.iloc[9].ner_tags
sr.visualize_ner_tags(text, ner_tags)

In [5]:
mlm_repo = sr.SnippetRepository(sr.SnippetRepositoryMode.MASKED_LM)

In [45]:
mlm_data = mlm_repo.get_training_data(batch_size=10, balance_labels=True)
detected = False
while not detected:
    mlm_df = next(mlm_data)
    detected = any(mlm_df['mask'].apply(lambda token_masks: any(token_masks)))

In [46]:
mlm_df

Unnamed: 0,text,pos_tags,mask,label
0,"[However, ,, the, percentages, of, blacks, ear...","[RB, ,, DT, NNS, IN, NNS, VBG, NNP, NN, POS, N...","[False, False, False, False, False, False, Fal...",0
1,"[5b, shows, DeKalb, County, ,, AL, ,, which, i...","[NNP, VBZ, NNP, NNP, ,, NNP, ,, WDT, IN, DT, J...","[False, False, False, False, False, False, Fal...",0
2,"[TIMSS, =, Trends, in, International, Mathemat...","[NNP, IN, NNS, IN, NNP, NNP, CC, NNP, NNP, .]","[False, False, True, True, True, True, True, T...",1
3,"[There, is, only, one, accuracy, for, the, fif...","[EX, VBZ, RB, CD, NN, IN, DT, JJ, NN, IN, DT, ...","[False, False, False, False, False, False, Fal...",0
4,"[x, G, :, Figure, 1, summarizes, the, relation...","[NFP, NN, :, NN, CD, VBZ, DT, NNS, IN, NNS, .]","[False, False, False, False, False, False, Fal...",0
5,"[Percentage, distribution, of, ever, married, ...","[NN, NN, IN, RB, VBN, CD, SYM, CD, NN, JJ, NNS...","[False, False, False, False, False, False, Fal...",1
6,"[A, variety, of, different, mixed, -, effects,...","[DT, NN, IN, JJ, JJ, HYPH, NNS, NNS, ,, DT, NN...","[False, False, False, False, False, False, Fal...",0
7,"[SOURCES, :, National, Science, Foundation, ,,...","[NNS, :, NNP, NNP, NNP, ,, NNP, NNP, IN, NNP, ...","[False, False, False, False, False, False, Fal...",1
8,"[Data, come, from, a, positively, selected, sa...","[NNS, VBP, IN, DT, RB, VBN, NN, IN, JJ, NNS, W...","[False, False, False, False, False, False, Fal...",0
9,"[Table, S6, Data, used, in, preparation, of, t...","[NN, NNP, NNP, VBN, IN, NN, IN, DT, NN, VBD, V...","[False, False, False, False, False, False, Fal...",1


In [47]:
query = mlm_df.sample(3)

support = mlm_df.drop(query.index)
support

Unnamed: 0,text,pos_tags,mask,label
0,"[However, ,, the, percentages, of, blacks, ear...","[RB, ,, DT, NNS, IN, NNS, VBG, NNP, NN, POS, N...","[False, False, False, False, False, False, Fal...",0
1,"[5b, shows, DeKalb, County, ,, AL, ,, which, i...","[NNP, VBZ, NNP, NNP, ,, NNP, ,, WDT, IN, DT, J...","[False, False, False, False, False, False, Fal...",0
2,"[TIMSS, =, Trends, in, International, Mathemat...","[NNP, IN, NNS, IN, NNP, NNP, CC, NNP, NNP, .]","[False, False, True, True, True, True, True, T...",1
3,"[There, is, only, one, accuracy, for, the, fif...","[EX, VBZ, RB, CD, NN, IN, DT, JJ, NN, IN, DT, ...","[False, False, False, False, False, False, Fal...",0
4,"[x, G, :, Figure, 1, summarizes, the, relation...","[NFP, NN, :, NN, CD, VBZ, DT, NNS, IN, NNS, .]","[False, False, False, False, False, False, Fal...",0
7,"[SOURCES, :, National, Science, Foundation, ,,...","[NNS, :, NNP, NNP, NNP, ,, NNP, NNP, IN, NNP, ...","[False, False, False, False, False, False, Fal...",1
8,"[Data, come, from, a, positively, selected, sa...","[NNS, VBP, IN, DT, RB, VBN, NN, IN, JJ, NNS, W...","[False, False, False, False, False, False, Fal...",0


In [48]:
query

Unnamed: 0,text,pos_tags,mask,label
9,"[Table, S6, Data, used, in, preparation, of, t...","[NN, NNP, NNP, VBN, IN, NN, IN, DT, NN, VBD, V...","[False, False, False, False, False, False, Fal...",1
6,"[A, variety, of, different, mixed, -, effects,...","[DT, NN, IN, JJ, JJ, HYPH, NNS, NNS, ,, DT, NN...","[False, False, False, False, False, False, Fal...",0
5,"[Percentage, distribution, of, ever, married, ...","[NN, NN, IN, RB, VBN, CD, SYM, CD, NN, JJ, NNS...","[False, False, False, False, False, False, Fal...",1


In [20]:
from typing import Any, Dict

import datasets


def make_labels_dataset(data:Dict[str, Any]) -> Dict[str, Any]:

    convert_to_float = lambda vals: [float(x) for x in vals]

    return {
        "mask_token_indicator" : list(map(convert_to_float, data["mask"])),
    }

In [53]:
query

Unnamed: 0,text,pos_tags,mask,label
9,"[Table, S6, Data, used, in, preparation, of, t...","[NN, NNP, NNP, VBN, IN, NN, IN, DT, NN, VBD, V...","[False, False, False, False, False, False, Fal...",1
6,"[A, variety, of, different, mixed, -, effects,...","[DT, NN, IN, JJ, JJ, HYPH, NNS, NNS, ,, DT, NN...","[False, False, False, False, False, False, Fal...",0
5,"[Percentage, distribution, of, ever, married, ...","[NN, NN, IN, RB, VBN, CD, SYM, CD, NN, JJ, NNS...","[False, False, False, False, False, False, Fal...",1


In [140]:
from itertools import starmap
from typing import Callable, Tuple


def apply_mask_sample(tokens:List[str], mask_token_indicator:List[float]) -> List[str]:
    tokens = list(map(
        lambda t, m: "[MASK]" if m else t, 
        tokens, 
        mask_token_indicator
    ))
    return tokens

def group_mask_sample(tokens:List[str], mask_token_indicator:List[float]) -> List[str]:
   
    # group the masks
    grouped_text_masks = [tokens[0]]
    grouped_mask_token_indicator = [mask_token_indicator[0]]
    
    for index in range(1, len(tokens)):
        if not (mask_token_indicator[index] == 1 and mask_token_indicator[index-1] == 1):
            grouped_text_masks.append(tokens[index])
            grouped_mask_token_indicator.append(mask_token_indicator[index])

    return grouped_text_masks, grouped_mask_token_indicator

def apply_mask_batched(dataset:Dict[str, Any]) -> Dict[str, Any]:
    # inintially every token is masked, however, we want to group them
    # so that a single token represents an entire dataset
    ungrouped_masks = list(starmap(
        apply_mask_sample,
        zip(dataset["text"], dataset["mask"]),
    ))

    text_mask = list(zip(*list(starmap(
        group_mask_sample,
        zip(ungrouped_masks, dataset["mask"]),
    ))))

    dataset["text"], dataset["mask"] = list(text_mask[0]), list(text_mask[1])

    return dataset


def tokenize_and_align_labels(tokenizer_f:Callable, examples:Dict[str, Any]) -> Dict[str, Any]:
    tokenized_inputs = tokenizer_f(examples["text"])

    labels = []
    for i, label in enumerate(examples["mask_token_indicator"]):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        label_ids = [-100] * len(word_ids) # assume all tokens are special
        top_word_id = max(map(lambda x: x if x else -1, word_ids))
        for word_idx in range(top_word_id + 1):
            label_ids[word_ids.index(word_idx)] = label[word_idx]
        labels.append(label_ids)

    tokenized_inputs["mask_token_indicator"] = labels
    return tokenized_inputs


def convert_to_T(T:type, vals:List[str]) -> List[float]:
    return [T(x) for x in vals]

def convert_dataset(
    tokenizer_f:Callable, 
    collator:Callable,
    dataset:datasets.Dataset
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:

    convert_f = partial(convert_to_T, int)

    dataset = dataset.map(
        lambda dset: { "mask_token_indicator" : list(map(convert_f, dset["mask"]))},
        batched=True,
    ).map(
        partial(tokenize_and_align_labels, tokenizer_f),
        batched=True,
    ).remove_columns(
        ["text", "mask"]
    # we rename mask_token_indicator to labels, because that is what the
    # so that the data collator will pad it.
    ).rename_column(
        "mask_token_indicator", "labels"
    ).rename_column(
        "label", "seq_labels"
    )

    # the collator doesn't know what to do with the seq_labels, so we
    # remove it, and then add it back in after the collator is done.
    # The collator also changes our type from Dataset to dict, so we
    # we are now working with a dictionary of tensors.
    tmp_seq_labels = dataset["seq_labels"]
    dataset = collator(list(dataset.remove_columns(["seq_labels"])))

    dataset_inputs = dict(
        input_ids=dataset["input_ids"],
        attention_mask=dataset["attention_mask"],
    )

    dataset_labels = dict(
        mask_token_indicator = dataset["labels"],
        seq_labels = tmp_seq_labels,
    )
    # At this point, the dataset is a dictionary of tensors, where the
    # tensors are all the same length.

    return dataset_inputs, dataset_labels


tokenizer = AutoTokenizer.from_pretrained(
    "distilbert-base-cased",
    do_lower_case = False,
)
tokenizer_f = partial(tokenizer, is_split_into_words=True, truncation=True)

collator = tfs.data.data_collator.DataCollatorForTokenClassification(
    tokenizer,
    return_tensors="pt",
    label_pad_token_id=0,
)

data = datasets.Dataset.from_pandas(
    mlm_df.drop(columns=["pos_tags"])
).train_test_split(train_size=3)

# these have the columns:
#  text, mask, label
ds_query, ds_support = data["train"], data["test"]

# the suport set has the labels masked out
ds_support = ds_support.map(apply_mask_batched, batched=True)


dset, lbl = convert_dataset(tokenizer_f, collator, ds_support)

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [141]:
dset

{'input_ids': tensor([[  101,   138,  2783,  ...,     0,     0,     0],
         [  101,   193,   144,  ...,     0,     0,     0],
         [  101,   126,  1830,  ...,     0,     0,     0],
         ...,
         [  101,  1247,  1110,  ...,     0,     0,     0],
         [  101,   156,  2346,  ...,     0,     0,     0],
         [  101, 14286,  8298,  ...,   125,  1830,   102]]),
 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 1, 1, 1]])}

In [137]:
lbl

{'mask_token_indicator': tensor([[-100,    0, -100,  ...,    0,    0,    0],
         [-100,    0, -100,  ..., -100, -100, -100],
         [-100,    0,    0,  ...,    0,    0,    0],
         ...,
         [-100,    0,    0,  ...,    0,    0,    0],
         [-100,    0,    0,  ...,    0,    0,    0],
         [-100,    0, -100,  ...,    0,    0,    0]]),
 'seq_labels': [1, 1, 0, 0, 0, 0, 1]}

In [147]:
mask_special = ((lbl["mask_token_indicator"][0,:] == -100).float() - 1).abs()

In [148]:
dset["attention_mask"][0,:] * mask_special

tensor([0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0.])

In [98]:
ds_support[2]["text"]

'[MASK]'

In [86]:
import transformers as tfs




data = datasets.Dataset.from_pandas(
    mlm_df.drop(columns=["pos_tags"])
).train_test_split(train_size=3)

ds_query, ds_support = data["train"], data["test"]

ds_query = ds_query.map(
    lambda row: { "mask_token_indicator" : list(map(convert_to_float, row["mask"]))},
    batched=True,
    remove_columns=["mask"]
).map(
    partial(tokenize_and_align_labels, tokenizer_f), 
    batched=True,
).rename_column("label", "seq_label").remove_columns(["text"])
dq_query_labels = ds_query.remove_columns(["input_ids"])
ds_query = ds_query.remove_columns(["seq_label", "mask_token_indicator"])

print(
    "labels {}\n".format(dq_query_labels.column_names),
    "features {}".format(ds_query.column_names),
)

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

labels ['seq_label', 'mask_token_indicator', 'attention_mask']
 features ['input_ids', 'attention_mask']


In [88]:
def apply_mask_sample(tokens:List[str], mask_token_indicator:List[float]) -> List[str]:
    return list(map(
        lambda t, m: "[MASK]" if m else t, 
        tokens, 
        mask_token_indicator
    ))

def apply_mask_batched(sample:Dict[str, Any]) -> Dict[str, Any]:
    sample["text"] = apply_mask_sample(sample["text"], sample["mask_token_indicator"])
    return sample


SyntaxError: incomplete input (4031507676.py, line 15)

In [87]:
collator(list(dq_query_labels))

ValueError: You should supply an encoding or a list of encodings to this method that includes input_ids, but you provided ['seq_label', 'mask_token_indicator', 'attention_mask']

In [84]:
collator(list(ds_query))

You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'input_ids': tensor([[  101,  1438,   117,  1103,  6556,  1116,  1104, 14892,  6957,   156,
           111,   142,  8091,   112,   188,  4842,  1121,   145,  9428,  2591,
          1116,  1105,  1104,  6098,  1116,  6957,   156,   111,   142,  8091,
           112,   188,  4842,  1121,   145,  3048, 27485,  1138,  5799,  1290,
          1630,   119,   102,     0],
        [  101,   157, 13371, 12480,   134,   157,  5123,  3680,  1107,  1570,
          9833,  1105,  2444,  8690,   119,   102,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0],
        [  101, 11389,   156,  1545,  7154,  1215,  1107,  7288,  1104,  1142,
          3342,  1127,  3836,  1121,  1103, 24278,   112,   188, 20012,   151,
          8816,  8136, 26772, 13508,   113,  5844, 27451,   114,  8539,   113,
          8050,  2605,   119, 25338,  260

In [83]:
from transformers import AutoTokenizer

def apply_mask_sample(tokens:List[str], mask_token_indicator:List[float]) -> List[str]:
    return list(map(
        lambda t, m: "[MASK]" if m else t, 
        tokens, 
        mask_token_indicator
    ))

def apply_mask_batched(sample:Dict[str, Any]) -> Dict[str, Any]:
    sample["text"] = apply_mask_sample(sample["text"], sample["mask_token_indicator"])
    return sample

def tokenize_and_align_labels(tokenizer_f, examples):
    tokenized_inputs = tokenizer_f(examples["text"])

    labels = []
    for i, label in enumerate(examples["mask_token_indicator"]):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        label_ids = [-100] * len(word_ids) # assume all tokens are special
        top_word_id = max(map(lambda x: x if x else -1, word_ids))
        for word_idx in range(top_word_id + 1):
            label_ids[word_ids.index(word_idx)] = label[word_idx]
        labels.append(label_ids)

    tokenized_inputs["mask_token_indicator"] = labels
    return tokenized_inputs

def convert_to_float(vals:List[str]) -> List[float]:
    return [float(x) for x in vals]

keep_cols = ['seq_label', 'mask_token_indicator', 'input_ids', 'attention_mask']

tokenizer = AutoTokenizer.from_pretrained(
    "distilbert-base-cased",
    do_lower_case = False,
)
tokenizer_f = partial(tokenizer, is_split_into_words=True, truncation=True)

collator = tfs.data.data_collator.DataCollatorForTokenClassification(
    tokenizer,
    return_tensors="pt",
)


q_dset = datasets.Dataset.from_pandas(
    query.drop(columns=["pos_tags"])
).map(
    lambda row: { "mask_token_indicator" : list(map(convert_to_float, row["mask"]))},
    batched=True,
    remove_columns=["mask"]
).map(
    partial(tokenize_and_align_labels, tokenizer_f), 
    batched=True,
).rename_column("label", "seq_label")

q_dset.remove_columns(list(filter(lambda c: c not in keep_cols, q_dset.column_names)))

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

Dataset({
    features: ['seq_label', 'mask_token_indicator', 'input_ids', 'attention_mask'],
    num_rows: 3
})

In [None]:
support = datasets.Dataset.from_pandas(
    support.drop(columns=["pos_tags"]).rename(columns={"mask_token_indicator": "mask"})
).map(

In [34]:
import datasets as ds
import pandas as pd 

def make_labels_dataset(pandas:pd.DataFrame) -> ds.Dataset:
    keep_cols = ["text", "mask", "label"]

    convert_to_float = lambda vals: [float(x) for x in vals]

    labelset = ds.Dataset.from_pandas(
        pandas.rename(columns={"mask_token_indicator": "mask"})
    ).map(
        lambda row: { "mask_token_indicator" : list(map(convert_to_float, row["mask"]))},
        batched=True,
        remove_columns=["mask"]
    )

    labelset = labelset.remove_columns(list(filter(
        lambda c: c not in keep_cols,
        labelset.column_names
    ))) 

    return labelset






make_labels_dataset(query)

  0%|          | 0/1 [00:00<?, ?ba/s]

Dataset({
    features: ['text', 'label'],
    num_rows: 3
})