# functions to create the dataset to be used by the fine-tuned BERT classifier

## Questions and thoughts
- Tutorial: https://huggingface.co/docs/transformers/custom_datasets
- Context texts must be limited to 512 tokens (Limit for BERT model)
- When labeling the dataset, should the labels be start, end, or start and inside? In other projects (with answer extraction) it seems they use start, end..
- Another option is to insert a higlight token around the sentence containing the answer, and then append the answers after a [SEP] token. As in: 
- There are multiple answer spans in the same context text.. Should those be labeled jointly? / should I have multiple instances of the same texts?
- My idea is to use the original text, no stopword removal or lemmatization.

In [106]:
labels = [
    "0",
    "B-answer",
    "I-answer",
]

In [107]:
# necessary library imports
import pandas as pd
import numpy as np
import json

In [108]:
# data imports, to be combined into the final datastructure
df_train_cleaned = pd.read_pickle("../data_frames/df_train_cleaned.pkl")
df_test = pd.read_pickle("../data_frames/df_test.pkl")

In [109]:
def find_answer_start(answer, sent):
    for idx, word in enumerate(sent):
        if answer[0] in word:
            is_match = True
            for i in range(len(answer)):
                if answer[i] not in sent[idx+i] :
                    is_match = False
            if is_match:
                return idx

    return None

In [110]:
def get_tokens_and_labels(sentences, answer, sent_with_ans_id):
    context_text = []
    all_labels = []

    # get the labels and tokens for current sentence
    for idx, sent in enumerate(sentences):
        context_text += sent # concatenate all sentences to a list of consecutive tokens
        labels = np.zeros(len(sent))
        if idx == sent_with_ans_id:
            # the answer is contained in this sentence!
            idx_s = find_answer_start(answer, sent)
            labels[idx_s] = 1
            for i in range(len(answer)-1):
                labels[idx_s + i + 1] = 2
        all_labels.append(labels)
    l = np.concatenate(all_labels).ravel()
    return context_text, l


In [111]:
# create the dataset with the corresponding labels
def label_data(df):
    data_map = {}
    num_removed = 0
    for index, row in df.iterrows():
        sentences = row['context_raw']
        sent_with_ans_id = row['answer_location']
        answer = row['correct_answer_raw']

        context_text, labels = get_tokens_and_labels(sentences, answer, sent_with_ans_id)

        # check if the current text is in the data map, and update the labels accordingly!
        if row['context'] in data_map:
            old_point = data_map[row['context']]
            o_labels = old_point['labels'].copy()
            add_answer_label = True
            for idx, label in enumerate(labels):
                if label > 0:
                    if o_labels[idx] == 0:
                        o_labels[idx] = label
                    elif label != o_labels[idx]:
                        print('setting class 3 ...')
                        print('answer: ', answer)
                        print('existing answer: ', old_point['answers'])
                        print('context: ', context_text)
                        o_labels[idx] = 3
                        add_answer_label = False # this means the answers are overlapping, but not equal! -> don't want this
                        num_removed += 1
                    else:
                        print('labels are a match! ')
                        add_answer_label = False
                        num_removed += 1
            if add_answer_label:
                old_point['labels'] = o_labels
                old_point['answers'].append(answer)
                data_map[row['context']] = old_point
                
        else:
            data_point = { 'id': index, 'labels': labels, 'tokens': context_text, 'answers': [answer] }
            data_map[row['context']] = data_point
    
    print('num removed: ', num_removed)
    # make labels josn compatible..
    for v in data_map.values():
        v['labels'] = [ int(x) for x in v['labels']]
    labeled_data = list(data_map.values())
    print('num data points: ', len(labeled_data))
    return labeled_data


In [112]:
labeled_data = label_data(df_train_cleaned)
labeled_df = pd.DataFrame(labeled_data)
labeled_df.to_pickle("labeled_training_data.pkl")
json_string = json.dumps(labeled_data, ensure_ascii=False)
with open('labeled_training_data.json', 'w') as outfile:
    json.dump(json_string, outfile)


labels are a match! 
labels are a match! 
labels are a match! 
setting class 3 ...
answer:  ['verkstäder']
existing answer:  [['ljuskronor'], ['i', 'små', 'verkstäder']]
context:  ['Smed/', 'Stålbyggare', 'Smeder', 'och', 'stålbyggare', 'arbetar', 'på', 'byggarbetsplatser', ',', 'verkstäder', ',', 'stålverk', 'eller', 'plåtslagerier', '.', 'Konstsmeder', 'arbetar', 'i', 'små', 'verkstäder', '.', 'Som', 'konstsmed', 'är', 'man', 'oftast', 'egen', 'företagare', '.', 'Arbetsuppgifter', 'Smeder', 'och', 'stålbyggare', 'utför', 'allt', 'från', 'stommar', 'till', 'inrednings-', 'och', 'utsmyckningsdetaljer', 'i', 'samband', 'med', 'ny-', 'eller', 'ombyggnation', '.', 'Stålbyggnad', 'används', 'alltmer', 'när', 'det', 'gäller', 'industrialiserat', 'byggande', '.', 'Stommarna', 'tillverkas', 'på', 'en', 'verkstad', ',', 'kontrolleras', 'och', 'CE-märks', 'och', 'skickas', 'sedan', 'ut', 'till', 'byggarbetsplatsen', '.', 'I', 'verkstaden', 'kapas', 'balkar', 'och', 'pelare', 'till', 'rätt', 'lä

In [113]:
# add labels to the test data
labeled_test_data = label_data(df_test)
labeled_test_df = pd.DataFrame(labeled_test_data)
labeled_test_df.to_pickle("labeled_test_data.pkl")
json_test_string = json.dumps(labeled_test_data, ensure_ascii=False)
with open('labeled_test_data.json', 'w') as outfile:
    json.dump(json_test_string, outfile)

labels are a match! 
labels are a match! 
num removed:  2
num data points:  45


In [96]:
with open('labeled_training_data.json') as json_file:
    data = json.load(json_file)
    print(data)

[{"id": 0, "labels": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 