## Imports

In [1]:
import torchvision
import torch
import numpy as np
import matplotlib.pyplot as plt
torch.set_warn_always(False)
%load_ext autoreload
%autoreload 2

from datasets import load_dataset, DatasetDict, concatenate_datasets


## Load Data

In [4]:
from datasets import DatasetDict, concatenate_datasets, Features, Value, Sequence

# Preprocess SQuAD
def preprocess_squad(examples):
    return {
        "id": examples["id"],
        "question": examples["question"],
        "context": examples["context"],
        "answers": 
            {"text": examples["answers"]["text"], "answer_start": [int(start) for start in examples["answers"]["answer_start"]]}
    }

# Preprocess WikiQA
def preprocess_wikiqa(examples):
    # answers = {"text": examples["answer"], "answer_start": [0] * len(examples["answer"])}
    answers = {"text": [examples["answer"]] if isinstance(examples["answer"], str) else examples["answer"],
           "answer_start": [0] * len(examples["answer"])}
    return {
        "id": examples["question_id"],
        "question": examples["question"],
        "context": examples["document_title"],
        "answers": answers,
    }

# Preprocess TriviaQA
def preprocess_triviaqa(examples):
    context = ", ".join(examples["entity_pages"]["wiki_context"]) or ", ".join(examples["search_results"]["search_context"])
    answer_text = examples["answer"]["normalized_value"]
    answer_start = context.find(answer_text) if answer_text in context else -1
    # answers = {"text": answer_text, "answer_start": [int(answer_start)]}  # Ensure answer_start is int32
    answers = {"text": [answer_text], "answer_start": [int(answer_start)]}
    return {
        "id": examples["question_id"],
        "question": examples["question"],
        "context": context,
        "answers": answers,
    }

# Load datasets and preprocess
squad = load_dataset("squad")
wikiqa = load_dataset("wiki_qa")
triviaqa = load_dataset("trivia_qa", "rc")

squad = squad.map(preprocess_squad)
wikiqa = wikiqa.map(preprocess_wikiqa)
triviaqa = triviaqa.map(preprocess_triviaqa)

squad = squad.remove_columns(["title"])
wikiqa = wikiqa.remove_columns(["question_id", "document_title", "answer", "label"])
triviaqa = triviaqa.remove_columns(['question_id', 'question_source', 'entity_pages', 'search_results', 'answer'])

# Ensure all datasets share the same features schema
common_features = Features({
    "id": Value("string"),
    "question": Value("string"),
    "context": Value("string"),
    "answers": Sequence(
        {
            "text": Value("string"),
            "answer_start": Value("int32")
        }
    )
})

print(squad["train"].features)
print(wikiqa["train"].features)
print(triviaqa["train"].features)


# Cast datasets to common schema
squad = squad.cast(common_features)
wikiqa = wikiqa.cast(common_features)
triviaqa = triviaqa.cast(common_features)

# Combine datasets
qa_datasets = DatasetDict({
    "train": concatenate_datasets([squad["train"], wikiqa["train"], triviaqa["train"]]),
    "validation": concatenate_datasets([squad["validation"], wikiqa["validation"], triviaqa["validation"]]),
})

print(qa_datasets)


Resolving data files:   0%|          | 0/26 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/26 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/24 [00:00<?, ?it/s]

Map:   0%|          | 0/6165 [00:00<?, ? examples/s]

Map:   0%|          | 0/2733 [00:00<?, ? examples/s]

Map:   0%|          | 0/20360 [00:00<?, ? examples/s]

Map:   0%|          | 0/138384 [00:00<?, ? examples/s]

Map:   0%|          | 0/17944 [00:00<?, ? examples/s]

Map:   0%|          | 0/17210 [00:00<?, ? examples/s]

{'id': Value(dtype='string', id=None), 'context': Value(dtype='string', id=None), 'question': Value(dtype='string', id=None), 'answers': Sequence(feature={'text': Value(dtype='string', id=None), 'answer_start': Value(dtype='int32', id=None)}, length=-1, id=None)}
{'question': Value(dtype='string', id=None), 'id': Value(dtype='string', id=None), 'context': Value(dtype='string', id=None), 'answers': {'answer_start': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), 'text': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None)}}
{'question': Value(dtype='string', id=None), 'id': Value(dtype='string', id=None), 'context': Value(dtype='string', id=None), 'answers': {'answer_start': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), 'text': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None)}}


Casting the dataset:   0%|          | 0/6165 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/2733 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/20360 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/138384 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/17944 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/17210 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'question', 'context', 'answers'],
        num_rows: 246343
    })
    validation: Dataset({
        features: ['id', 'question', 'context', 'answers'],
        num_rows: 31247
    })
})


In [6]:
print(qa_datasets["train"].features)

{'id': Value(dtype='string', id=None), 'question': Value(dtype='string', id=None), 'context': Value(dtype='string', id=None), 'answers': Sequence(feature={'text': Value(dtype='string', id=None), 'answer_start': Value(dtype='int32', id=None)}, length=-1, id=None)}


In [19]:
print(squad["train"][0])

{'id': '5733be284776f41900661182', 'title': 'University_of_Notre_Dame', 'context': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.', 'question': 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?', 'answers': {'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]}}


## Run BO

In [None]:
from BO import iterative_loop, get_BO_plots
from botorch.optim import optimize_acqf
from botorch.acquisition import UpperConfidenceBound

def run_BO(all_loaders, validaton_dataloader, iterations, num_epochs=20, printout=False):
    print("running BO...")
    X, observations, gp = iterative_loop(all_loaders, validaton_dataloader, num_epochs=num_epochs, iterations=iterations, printout=printout)
    BO_to_plot = get_BO_plots(observations) # BO results
    naive_combine = BO_to_plot[0] # naive mixing result is the first iteration result of BO

    # plot model performance as BO progresses...
    plt.plot(range(len(BO_to_plot)), BO_to_plot, c="blue", alpha=0.3, label="BO on mixing ratio")
    plt.axhline(naive_combine, linestyle="--", c="red", label="sample from each data source equally")
    plt.xlabel("BO iterations")
    plt.ylabel("accuracy on evaluation task")
    plt.legend()
    plt.show()

    # plot posterior
    # posterior_acc = []
    # for x in np.linspace(0,1,100):
    #     posterior_acc.append(gp.posterior(torch.Tensor([[x,1-x]])).mean.item())
        
    # plt.plot(np.linspace(0,1,100), posterior_acc)
    # plt.xlabel("mixing ratio (percentage on cats and dogs)")
    # plt.ylabel("accuracy")
    # plt.title("evaluation ratio : 1.0 cats and dogs")
    # plt.show()

    def get_optimal_mixture_from_GP_posterior():
        UCB = UpperConfidenceBound(gp, beta=0.0)
        bounds = torch.stack([torch.zeros(len(all_loaders)), torch.ones(len(all_loaders))]) # need to change the bounds for parameters
        A = [1.0] * len(all_loaders)
        x = list(range(len(all_loaders)))
        candidate, acq_value = optimize_acqf(
            UCB, bounds=bounds, q=1, num_restarts=20, raw_samples=30,
            equality_constraints = [(torch.tensor(x), torch.tensor(A), 1)]
        )
        return candidate
    

    def get_best_observation_mixture():
        
        # Find the index in list B that has the highest value
        highest_index = observations.index(max(observations))
        
        # Return the corresponding item in list A
        return X[highest_index]

    
    print("best mixture found in BO iterations is: ", get_best_observation_mixture())
    
    return X, observations, gp