# Introduction:

* In this notebook I investigate the **`SWAG`** dataset.
* The goal is to understand how to handle multiple-choice datasets and how to prepare them for the next step.
* Multiple choice is a frequent problem in the field of LLMs and in NLP in general.
* The way the data is preprocessed therefore has a large effect on the success of any proposed solution.

In [None]:
# load the dataset
from datasets import load_dataset
dataset = load_dataset('swag', 'regular')


Generating train split:   0%|          | 0/73546 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/20006 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/20005 [00:00<?, ? examples/s]

In [None]:
# let's grab a sample
dataset['train'][0]

{'video-id': 'anetv_jkn6uvmqwh4',
 'fold-ind': '3416',
 'startphrase': 'Members of the procession walk down the street holding small horn brass instruments. A drum line',
 'sent1': 'Members of the procession walk down the street holding small horn brass instruments.',
 'sent2': 'A drum line',
 'gold-source': 'gold',
 'ending0': 'passes by walking down the street playing their instruments.',
 'ending1': 'has heard approaching them.',
 'ending2': "arrives and they're outside dancing and asleep.",
 'ending3': 'turns the lead singer watches the performance.',
 'label': 0}

* These fields capture the idea behind this dataset:
   - each record is a situation for which we have to predict the right ending
   - `sent1` and `sent2` describe the situation; joined by a space, they form `startphrase`
   - `ending0` to `ending3` are the candidate endings for that situation; only one is correct
   - `label` is the index of the correct ending
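
To make the field layout concrete, here is a minimal sketch that rebuilds the four full candidate texts from one record and picks the correct one using `label`. The record is copied from the sample above; `build_candidates` is a hypothetical helper name used only for illustration, not part of the dataset or of any library.

```python
# One record, copied from the sample shown above (only the fields used here).
sample = {
    "startphrase": "Members of the procession walk down the street holding small horn brass instruments. A drum line",
    "sent1": "Members of the procession walk down the street holding small horn brass instruments.",
    "sent2": "A drum line",
    "ending0": "passes by walking down the street playing their instruments.",
    "ending1": "has heard approaching them.",
    "ending2": "arrives and they're outside dancing and asleep.",
    "ending3": "turns the lead singer watches the performance.",
    "label": 0,
}


def build_candidates(example):
    """Return the four full candidate texts for one record (hypothetical helper)."""
    candidates = []
    for i in range(4):
        ending = example[f"ending{i}"]
        # context sentence + start of the second sentence + one candidate ending
        candidates.append(f"{example['sent1']} {example['sent2']} {ending}")
    return candidates


candidates = build_candidates(sample)
correct = candidates[sample["label"]]  # label indexes the right ending
print(correct)
```

Every candidate shares the same `startphrase` prefix, so the four texts differ only in their ending.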

In [None]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

In [None]:
train_ds = dataset['train']

In [None]:
ending_names = ["ending0", "ending1", "ending2", "ending3"]


def preprocess_function(examples):
    first_sentences = [[context] * 4 for context in examples["sent1"]]
    question_headers = examples["sent2"]
    second_sentences = [
        [f"{header} {examples[end][i]}" for end in ending_names]
        for i, header in enumerate(question_headers)
    ]

    first_sentences = sum(first_sentences, [])
    second_sentences = sum(second_sentences, [])

    tokenized_examples = tokenizer(first_sentences, second_sentences, truncation=True)
    return {k: [v[i : i + 4] for i in range(0, len(v), 4)] for k, v in tokenized_examples.items()}
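
The flatten-then-regroup pattern in `preprocess_function` is easier to follow on a tiny batch. The sketch below runs the same steps without downloading anything. The two-example batch is made up, and `toy_tokenizer` and `preprocess_toy` are stand-ins I am assuming here for illustration: the toy tokenizer only splits on whitespace, so it shows the shape of the output, not real BERT token ids.

```python
from itertools import chain

ending_names = ["ending0", "ending1", "ending2", "ending3"]

# A made-up batch in the column-oriented layout that a batched map passes in:
# one list per field, one entry per example.
batch = {
    "sent1": ["A man sits down.", "A dog runs."],
    "sent2": ["He", "It"],
    "ending0": ["eats.", "barks."],
    "ending1": ["sleeps.", "jumps."],
    "ending2": ["reads.", "digs."],
    "ending3": ["leaves.", "rests."],
}


def toy_tokenizer(first, second):
    """Stand-in for the real tokenizer (assumption): whitespace split per sentence pair."""
    return {"input_ids": [f"{a} {b}".split() for a, b in zip(first, second)]}


def preprocess_toy(examples):
    # Repeat each context four times, once per candidate ending.
    first = [[context] * 4 for context in examples["sent1"]]
    # Prepend the start of the second sentence to each of its four endings.
    second = [
        [f"{header} {examples[end][i]}" for end in ending_names]
        for i, header in enumerate(examples["sent2"])
    ]
    # Flatten to plain sentence pairs; sum(list, []) gives the same result.
    flat_first = list(chain.from_iterable(first))
    flat_second = list(chain.from_iterable(second))
    tokens = toy_tokenizer(flat_first, flat_second)
    # Regroup the flat output into chunks of four, one chunk per example.
    return {k: [v[i : i + 4] for i in range(0, len(v), 4)] for k, v in tokens.items()}


out = preprocess_toy(batch)
print(len(out["input_ids"]), len(out["input_ids"][0]))
```

The result has one entry per example, and each entry holds four tokenized choices. I used `chain.from_iterable` here only because it flattens in linear time; on batches of this size the `sum(..., [])` form works just as well.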

In [None]:
tokenized_ds = dataset.map(preprocess_function, batched=True)

In [None]:
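# same preprocessing as above, rewritten from scratch as an exercise
ends = ['ending0', 'ending1', 'ending2', 'ending3']
def func(example):
  # repeat each context four times, once per candidate ending
  first = [[context] * 4 for context in example['sent1']]
  heads = example['sent2']
  # prepend the start of the second sentence to each of its four endings
  second = [[f'{head} {example[end][i]}' for end in ends] for i, head in enumerate(heads)]
  # flatten both nested lists so the tokenizer receives plain sentence pairs
  flat_first = sum(first, [])
  flat_second = sum(second, [])
  toks = tokenizer(flat_first, flat_second, truncation=True)
  # regroup the flat tokenizer output into chunks of four, one chunk per example
  return {k:[v[i:i+4] for i in range(0, len(v), 4)] for k, v in toks.items()}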
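
Both functions above tokenize with `truncation=True` and no padding, so the four choices of an example usually have different lengths. Before they can be stacked into one rectangular batch, they need to be padded to a common length. Below is a minimal pure-Python sketch of that idea. `pad_choices` is a hypothetical helper, the token ids are made up, and `pad_id=0` is an assumption that should be checked against `tokenizer.pad_token_id` for the tokenizer actually in use.

```python
def pad_choices(batch_input_ids, pad_id=0):
    """Pad every choice in the batch to the longest choice (illustrative helper).

    Returns the padded ids and a matching attention mask (1 = token, 0 = padding).
    """
    max_len = max(len(ids) for example in batch_input_ids for ids in example)
    padded, mask = [], []
    for example in batch_input_ids:
        padded.append([ids + [pad_id] * (max_len - len(ids)) for ids in example])
        mask.append([[1] * len(ids) + [0] * (max_len - len(ids)) for ids in example])
    return padded, mask


# Two examples, four choices each, with made-up token ids of unequal length.
toy = [
    [[101, 7, 102], [101, 7, 8, 102], [101, 102], [101, 9, 102]],
    [[101, 5, 6, 7, 102], [101, 102], [101, 4, 102], [101, 3, 2, 102]],
]
padded, mask = pad_choices(toy)
print(len(padded), len(padded[0]), len(padded[0][0]))
```

After padding, the nested lists have a regular shape of examples × 4 choices × sequence length, which is the layout a multiple-choice model would typically expect.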