## Imports

In [1]:
import torchvision
import torch
import numpy as np
import matplotlib.pyplot as plt
torch.set_warn_always(False)
%load_ext autoreload
%autoreload 2

from datasets import load_dataset, DatasetDict, concatenate_datasets


## Load Data

In this set of datasets, we are combining Stanford SQuAD, Wiki QA and Trivia QA datasets.

Common features have to be merged in order to ensure data consistency

In [7]:
from datasets import DatasetDict, concatenate_datasets, Features, Value, Sequence

# Preprocess SQuAD
def preprocess_squad(examples):
    return {
        "id": examples["id"],
        "question": examples["question"],
        "context": examples["context"],
        "answers": 
            {"text": examples["answers"]["text"], "answer_start": [int(start) for start in examples["answers"]["answer_start"]]}
    }

# Preprocess WikiQA
def preprocess_wikiqa(examples):
    # answers = {"text": examples["answer"], "answer_start": [0] * len(examples["answer"])}
    answers = {"text": [examples["answer"]] if isinstance(examples["answer"], str) else examples["answer"],
           "answer_start": [0] * len(examples["answer"])}
    return {
        "id": examples["question_id"],
        "question": examples["question"],
        "context": examples["document_title"],
        "answers": answers,
    }

# Preprocess TriviaQA
def preprocess_triviaqa(examples):
    context = ", ".join(examples["entity_pages"]["wiki_context"]) or ", ".join(examples["search_results"]["search_context"])
    answer_text = examples["answer"]["normalized_value"]
    answer_start = context.find(answer_text) if answer_text in context else -1
    # answers = {"text": answer_text, "answer_start": [int(answer_start)]}  # Ensure answer_start is int32
    answers = {"text": [answer_text], "answer_start": [int(answer_start)]}
    return {
        "id": examples["question_id"],
        "question": examples["question"],
        "context": context,
        "answers": answers,
    }

# Load datasets and preprocess
squad = load_dataset("squad")
wikiqa = load_dataset("wiki_qa")
triviaqa = load_dataset("trivia_qa", "rc")

squad = squad.map(preprocess_squad)
wikiqa = wikiqa.map(preprocess_wikiqa)
triviaqa = triviaqa.map(preprocess_triviaqa)

squad = squad.remove_columns(["title"])
wikiqa = wikiqa.remove_columns(["question_id", "document_title", "answer", "label"])
triviaqa = triviaqa.remove_columns(['question_id', 'question_source', 'entity_pages', 'search_results', 'answer'])

# Ensure all datasets share the same features schema
common_features = Features({
    "id": Value("string"),
    "question": Value("string"),
    "context": Value("string"),
    "answers": Sequence(
        {
            "text": Value("string"),
            "answer_start": Value("int32")
        }
    )
})

print(squad["train"].features)
print(wikiqa["train"].features)
print(triviaqa["train"].features)


# Cast datasets to common schema
squad = squad.cast(common_features)
wikiqa = wikiqa.cast(common_features)
triviaqa = triviaqa.cast(common_features)

# Combine datasets
# qa_datasets = DatasetDict({
#     "train": concatenate_datasets([squad["train"], wikiqa["train"], triviaqa["train"]]),
#     "validation": concatenate_datasets([squad["validation"], wikiqa["validation"], triviaqa["validation"]]),
# })

# print(qa_datasets)


Resolving data files:   0%|          | 0/26 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/26 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/24 [00:00<?, ?it/s]

{'id': Value(dtype='string', id=None), 'context': Value(dtype='string', id=None), 'question': Value(dtype='string', id=None), 'answers': Sequence(feature={'text': Value(dtype='string', id=None), 'answer_start': Value(dtype='int32', id=None)}, length=-1, id=None)}
{'question': Value(dtype='string', id=None), 'id': Value(dtype='string', id=None), 'context': Value(dtype='string', id=None), 'answers': {'answer_start': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), 'text': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None)}}
{'question': Value(dtype='string', id=None), 'id': Value(dtype='string', id=None), 'context': Value(dtype='string', id=None), 'answers': {'answer_start': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), 'text': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None)}}


In [None]:
# print(qa_datasets["train"].features)

{'id': Value(dtype='string', id=None), 'question': Value(dtype='string', id=None), 'context': Value(dtype='string', id=None), 'answers': Sequence(feature={'text': Value(dtype='string', id=None), 'answer_start': Value(dtype='int32', id=None)}, length=-1, id=None)}


## Run BO

In [10]:
from BO import iterative_loop, get_BO_plots
from botorch.optim import optimize_acqf
from botorch.acquisition import UpperConfidenceBound

def run_BO(all_datasets, validation_dataset, iterations, num_epochs=20, printout=False):
    print("running BO...")
    X, observations, gp = iterative_loop(all_datasets, validation_dataset, num_epochs=num_epochs, iterations=iterations, printout=printout)
    BO_to_plot = get_BO_plots(observations) # BO results
    naive_combine = BO_to_plot[0] # naive mixing result is the first iteration result of BO

    # plot model performance as BO progresses...
    plt.plot(range(len(BO_to_plot)), BO_to_plot, c="blue", alpha=0.3, label="BO on mixing ratio")
    plt.axhline(naive_combine, linestyle="--", c="red", label="sample from each data source equally")
    plt.xlabel("BO iterations")
    plt.ylabel("accuracy on evaluation task")
    plt.legend()
    plt.show()

    # plot posterior
    # posterior_acc = []
    # for x in np.linspace(0,1,100):
    #     posterior_acc.append(gp.posterior(torch.Tensor([[x,1-x]])).mean.item())
        
    # plt.plot(np.linspace(0,1,100), posterior_acc)
    # plt.xlabel("mixing ratio (percentage on cats and dogs)")
    # plt.ylabel("accuracy")
    # plt.title("evaluation ratio : 1.0 cats and dogs")
    # plt.show()

    def get_optimal_mixture_from_GP_posterior():
        UCB = UpperConfidenceBound(gp, beta=0.0)
        bounds = torch.stack([torch.zeros(len(all_datasets)), torch.ones(len(all_datasets))]) # need to change the bounds for parameters
        A = [1.0] * len(all_datasets)
        x = list(range(len(all_datasets)))
        candidate, acq_value = optimize_acqf(
            UCB, bounds=bounds, q=1, num_restarts=20, raw_samples=30,
            equality_constraints = [(torch.tensor(x), torch.tensor(A), 1)]
        )
        return candidate
    

    def get_best_observation_mixture():
        
        # Find the index in list B that has the highest value
        highest_index = observations.index(max(observations))
        
        # Return the corresponding item in list A
        return X[highest_index]

    
    print("best mixture found in BO iterations is: ", get_best_observation_mixture())
    
    return X, observations, gp

## Run BO

In [20]:
from helper import sample_from

evaluation_task_data_ratio = 0.7
print("train data: SQuAD, WikiQA, TriviaQA")
print("test data ratio: SQuAD {}%, WikiQA {}%, TriviaQA {}%".format(round(evaluation_task_data_ratio * 100), round((1-evaluation_task_data_ratio)*100), 0))

# data for training
train_datasets = [squad["train"], wikiqa["train"], triviaqa["train"]]

# data just for evaluation; in real life we do not know this mixture. We can change this composition ratio or use other loaders
# for example, this uses SQuAD and Wiki QA in a ratio of 70% and 30% in the evaluation task.
validation_datasets = [squad["validation"], wikiqa["validation"], triviaqa["validation"]]
validation_ratio = [evaluation_task_data_ratio, 1-evaluation_task_data_ratio, 0]
validation_dataset = sample_from(validation_datasets, validation_ratio, seed=2024)

iterations=10
# run BO on different mixtures of the training data. At each iteration, we train a model using the training data mixture and check the evaluation performance on validation dataset
run_BO(train_datasets, validation_dataset, iterations, printout=True) # should print a ratio that is close to evaluation_task_data_ratio



train data: SQuAD, WikiQA, TriviaQA
test data ratio: SQuAD 70%, WikiQA 30%, TriviaQA 0%
running BO...
mixing data...


Map:   0%|          | 0/73901 [00:00<?, ? examples/s]

Map:   0%|          | 0/8212 [00:00<?, ? examples/s]

Map:   0%|          | 0/8217 [00:00<?, ? examples/s]

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


  0%|          | 0/13857 [00:00<?, ?it/s]

KeyboardInterrupt: 