# Plan

OBJECTIVE: We want to train an address NER model for German legal text.

We will do this by:
1. cleaning the existing address entity labels
2. building a semi-artificial dataset of real German legal documents, with machine-generated addresses labeled and randomly placed among the real text

We will truncate the legal documents so that the model sees short spans of text.

Truncation will serve two purposes:
1. greater training speed
2. the model is less likely to get conflicting signals from the artificial dataset if we only present a window where we know that there is an address, instead of a whole document where there may be unlabeled addresses

We need to design a strategy that places addresses in contexts where they might actually occur naturally:
- after line breaks
- after periods
- randomly inside sentences. This is harder. We can pick positions at random, or train a classifier to predict a good place to put an address, using the subset of addresses that we have labels for (this is biased).




# HF Datasets

In [None]:
from datasets import load_dataset

In [None]:
# `wnut` is used by every following cell but was never assigned, so a fresh
# top-to-bottom run failed with NameError. Load it here with the already
# imported `load_dataset`.
# NOTE(review): the dataset id "wnut_17" is inferred from the variable name and
# the `tokens` / `ner_tags` columns used below -- confirm it is the intended one.
wnut = load_dataset("wnut_17")
wnut["validation"]["tokens"]

In [None]:
# Check which container type slicing the first two rows of the train split returns.
type(wnut["train"][:2])

In [None]:
# Check which container type selecting the "tokens" column of the train split returns.
type(wnut["train"]["tokens"])

In [None]:
# Read the tag names declared on the `ner_tags` feature of the train split.
# The key is a plain string: the original f-string had no placeholders.
label_list = wnut["train"].features["ner_tags"].feature.names
label_list

In [None]:
from transformers import AutoTokenizer

# Load the pretrained tokenizer of the "distilbert-base-uncased" checkpoint.
# NOTE(review): the plan above targets German legal text, while this looks like
# an English, uncased checkpoint -- confirm it is only used for experimentation.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

In [None]:
# Take the first row of the train split as a sample to inspect.
example = wnut["train"][0]

In [None]:
# Tokenize the sample's words; `is_split_into_words=True` tells the tokenizer
# the input is already a list of words rather than one raw string.
tokenized_input = tokenizer(example["tokens"], is_split_into_words=True)
# Map the produced input ids back to token strings so they can be inspected.
tokens = tokenizer.convert_ids_to_tokens(tokenized_input["input_ids"])
tokens

In [None]:
def shift_label(label):
    """Return the label id to use for a continuation sub-word.

    Odd ids are bumped to the next id; even ids are returned unchanged.
    This relies on the tag scheme giving every B-XXX tag an odd id and the
    matching I-XXX tag the id right after it (presumably O = 0, B-X = 1,
    I-X = 2, ... -- verify against the dataset's label list).
    """
    is_odd = label % 2 == 1
    return label + 1 if is_odd else label

def align_labels_with_tokens(labels, word_ids, subword_strategy="label"):
    """Expand word-level labels into one label per tokenizer token.

    Args:
        labels: Word-level label ids, indexable by word index.
        word_ids: One entry per token: the index of the word the token came
            from, or None for tokens that belong to no word.
        subword_strategy: What the continuation tokens of a word receive.
            "label" gives them ``shift_label`` applied to the word's label;
            "skip" gives them -100.

    Returns:
        A list with one label per entry of ``word_ids``. Tokens whose word id
        is None always get -100 (presumably the value the loss ignores --
        confirm against the training setup).

    Raises:
        ValueError: If ``subword_strategy`` is neither "label" nor "skip".
    """
    # Validate with an explicit raise: an `assert` is stripped under
    # `python -O`, after which an unknown strategy would silently act as "skip".
    if subword_strategy not in ("label", "skip"):
        raise ValueError(
            f"subword_strategy must be 'label' or 'skip', got {subword_strategy!r}"
        )

    new_labels = []
    current_word = None

    for word_id in word_ids:
        if word_id is None:
            # token that maps to no word
            new_labels.append(-100)
        elif word_id != current_word:
            # first token of a new word: it carries the word's own label
            current_word = word_id
            new_labels.append(labels[word_id])
        elif subword_strategy == "label":
            # continuation token: reuse the word's label, shifted
            new_labels.append(shift_label(labels[word_id]))
        else:
            # continuation token under "skip"
            new_labels.append(-100)

    return new_labels

def tokenize_and_align_labels(samples):
    """Tokenize a batch of pre-split samples and attach token-level labels.

    ``samples`` is read through its "tokens" and "ner_tags" entries. The words
    are run through the module-level ``tokenizer`` (with truncation), and the
    labels of each sample are aligned to its tokens with
    ``align_labels_with_tokens``. The aligned labels are stored under
    "labels" on the tokenizer output, which is then returned.
    """
    encoded = tokenizer(
        samples["tokens"],
        truncation=True,
        is_split_into_words=True,
    )

    # word_ids(batch_index=idx) maps each token of sample `idx` to its word.
    encoded["labels"] = [
        align_labels_with_tokens(word_labels, encoded.word_ids(batch_index=idx))
        for idx, word_labels in enumerate(samples["ner_tags"])
    ]
    return encoded

In [None]:
# Run the tokenize-and-align step over `wnut`; with batched=True the function
# receives batches of rows rather than single rows.
tokenized_wnut = wnut.map(tokenize_and_align_labels, batched=True)


In [None]:
# Same mapping applied to the validation split only; the result is not
# assigned, just shown as the cell output.
wnut["validation"].map(tokenize_and_align_labels, batched=True)


In [None]:
# Show the mapped dataset as the cell output.
tokenized_wnut

In [None]:
import numpy as np
from datasets import Dataset 
from torch.utils.data import DataLoader
# Toy example: wrap random NumPy arrays in a Dataset and iterate over it with
# a PyTorch DataLoader.
data = np.random.rand(16)  # 16 random floats
label = np.random.randint(0, 2, size=16)  # 16 random ints, each 0 or 1
# Build the dataset from the two columns and set its output format to "torch".
ds = Dataset.from_dict({"data": data, "label": label}).with_format("torch")
dataloader = DataLoader(ds, batch_size=3)
# Print every batch to see what the loader yields.
for batch in dataloader:
    print(batch)

In [None]:
from datasets import Split

In [None]:
# IPython help syntax (not plain Python): the trailing `?` shows the
# documentation of Dataset.from_dict.
Dataset.from_dict?