This script adds more fields to the HotPotQA dataset and generates more data
for statement & verdict training:
  - choose a random direction: True or False
  - generate statement based on the direction and the QA
  - generate some new statements that have no corresponding context in the retriever, labeled as "Irrelevant"
  - save to file

TODO:
  - quality control automatically
  - more diverse irrelevant statements

In [10]:
import os

# --- Run configuration -------------------------------------------------------
select_dataset_total = 100_000  # upper bound on examples taken from each split (train / validation)
concurrency = 24  # number of worker threads issuing LLM calls in parallel
irrelevant_percent = 0.3  # irrelevant statements to create, as a fraction of the True/False statement count

# --- LLM endpoint ------------------------------------------------------------
# The key is presumably a dummy value for the local endpoint below; confirm
# before pointing this at a real provider.
os.environ["OPENAI_API_KEY"] = "aaaa"
OPENAI_BASE_URL = "http://127.0.0.1:8010/v1/"

In [11]:
import dspy

# Language-model client registered as the dspy default: completions are capped
# at 200 tokens and stop at the first blank line.
llm = dspy.OpenAI(
    model='google/gemma-2-9b-it',
    api_base=OPENAI_BASE_URL,
    max_tokens=200,
    stop='\n\n',
    # model_type="chat",
)
dspy.settings.configure(lm=llm)

In [3]:
import random
from datasets import load_dataset
from dspy.datasets.dataset import Dataset

class HotPotSV():
    """
    HotPotQA to statements + verdicts.

    Loads the `train` and `validation` splits of the `fullwiki` configuration
    of `hotpot_qa` and exposes a shuffled, trimmed sample of each as
    `self.train` and `self.validation` (lists of plain dicts).

    Args:
        train_size: maximum number of examples kept from the train split.
        validation_size: maximum number of examples kept from the validation split.
        keep_details: which fields to keep per example; forwarded to
            `process_dataset` (see there for the accepted values).
    """
    def __init__(
        self,
        train_size = 50,
        validation_size = 50,
        keep_details = False,
    ):

        hf_official_train = load_dataset("hotpot_qa", 'fullwiki', split='train', trust_remote_code=True)
        hf_official_validation = load_dataset("hotpot_qa", 'fullwiki', split='validation', trust_remote_code=True)
        # `test` split has no answer

        # `keep_details` must be forwarded explicitly; otherwise the
        # constructor argument is silently ignored.
        self.train = self.process_dataset(hf_official_train, train_size, keep_details)
        self.validation = self.process_dataset(hf_official_validation, validation_size, keep_details)

    def process_dataset(self, dataset, size, keep_details = False):
        """
        Project every raw example onto the selected fields, shuffle, and trim.

        Args:
            dataset: iterable of mapping-like raw examples.
            size: maximum number of examples to return.
            keep_details: `True` keeps id/question/answer/type/context/level
                plus gold titles; `'validation_titles'` keeps
                question/answer/level plus gold titles; anything else keeps
                only question/answer/level.

        Returns:
            A list of at most `size` dicts. The shuffle uses a fixed seed (0),
            so the selection is the same on every call for the same input.
            When `supporting_facts` is among the kept fields it is replaced by
            `gold_titles`, a `set` of its titles (note: a set is not JSON
            serializable as-is).
        """
        # The field list does not depend on the example, so pick it once.
        if keep_details is True:
            keys = ['id', 'question', 'answer', 'type', 'supporting_facts', 'context', 'level']
        elif keep_details == 'validation_titles':
            keys = ['question', 'answer', 'supporting_facts', 'level']
        else:
            keys = ['question', 'answer', 'level']

        rep = []
        for raw_example in dataset:
            example = {k: raw_example[k] for k in keys}

            if 'supporting_facts' in example:
                example['gold_titles'] = set(example.pop('supporting_facts')['title'])

            rep.append(example)

        # Shuffle the full list before trimming so the sample is drawn from
        # the whole split, not just its first `size` rows.
        rng = random.Random(0)
        rng.shuffle(rep)
        return rep[:size]

In [4]:
# Sample up to `select_dataset_total` examples from each of the train and
# validation splits.
dataset_HotPotSV = HotPotSV(
    train_size=select_dataset_total,
    validation_size=select_dataset_total,
)

In [5]:
# Signature for producing a wrong answer from a question and its correct
# answer; GenerateStatement uses its `output` field as the fake answer.
# NOTE(review): dspy is understood to send the class docstring and each field's
# `desc` to the LM as prompt text, so treat them as runtime strings rather than
# documentation. Confirm against the installed dspy version before rewording.
class ChangeAnswerSignature(dspy.Signature):
    """Generate a new answer based on the question and correct answer"""
    question = dspy.InputField()
    answer = dspy.InputField(desc="correct answer")
    output = dspy.OutputField(desc="incorrect answer")
    
# Signature for merging a question and an answer into one declarative
# statement; GenerateStatement reads the result from the `statement` field.
# NOTE(review): dspy is understood to send the class docstring and each field's
# `desc` to the LM as prompt text, so treat them as runtime strings rather than
# documentation. Confirm against the installed dspy version before rewording.
class CombineQA(dspy.Signature):
    """Combine the question and answer into one statement of direct expression with context"""
    question = dspy.InputField()
    answer = dspy.InputField()
    statement = dspy.OutputField(desc="generated from the question and answer only, with all the original info")
    
class GenerateStatement(dspy.Module):
    """Turn one question/answer pair into a statement labelled "True" or "False".

    The verdict is drawn at random. For a "False" verdict the correct answer
    is first swapped for an LM-generated incorrect one; the answer in use is
    then merged with the question into a single statement.
    """

    def __init__(self):
        super().__init__()
        self.change_answer = dspy.Predict(ChangeAnswerSignature)
        self.combine_qa = dspy.Predict(CombineQA)

    def forward(self, question, answer, level=None):
        """Build the statement record for one example.

        Returns a dict with `question`, `answer` (always the original
        answer), `verdict`, `statement` and `level`, plus `answer_fake` when a
        non-empty incorrect answer was generated.
        """
        verdict = random.choice(["True", "False"])

        # Only a "False" verdict needs a fabricated answer.
        fake_answer = None
        if verdict == "False":
            fake_answer = self.change_answer(question=question, answer=answer).output
        stated_answer = fake_answer if verdict == "False" else answer

        combined = self.combine_qa(question=question, answer=stated_answer)

        record = {
            'question': question,
            'answer': answer,
            'verdict': verdict,
            'statement': combined.statement,
            'level': level,
        }
        if fake_answer:
            record['answer_fake'] = fake_answer
        return record

In [6]:
import concurrent.futures

def generate_datasets():
    """Run GenerateStatement over both splits of `dataset_HotPotSV` in a thread pool.

    Returns:
        `(train, validation)`: two lists of statement records. Within each
        list the records appear in completion order, not input order.
    """

    def to_statement(example):
        # Each example dict is expanded into forward()'s keyword arguments.
        return GenerateStatement()(**example)

    with concurrent.futures.ThreadPoolExecutor(max_workers=concurrency) as pool:
        # Queue both splits up front so the workers stay busy throughout.
        pending_train = [pool.submit(to_statement, example) for example in dataset_HotPotSV.train]
        pending_validation = [pool.submit(to_statement, example) for example in dataset_HotPotSV.validation]

        train = []
        for finished in concurrent.futures.as_completed(pending_train):
            train.append(finished.result())

        validation = []
        for finished in concurrent.futures.as_completed(pending_validation):
            validation.append(finished.result())

    return train, validation

dataset_train_HotPotQA_generated, dataset_validation_HotPotQA_generated = generate_datasets()

# Preview the first three records of each split; the "\n" item combined with
# the newline separator leaves two blank lines between the previews.
print(
    dataset_train_HotPotQA_generated[:3],
    "\n",
    dataset_validation_HotPotQA_generated[:3],
    sep="\n",
)

  return v1_cached_gpt3_request_v2(**kwargs)


[{'question': 'Which cartridge did John Browning design that has a rim that is the same in diameter as the .50 GI?', 'answer': '.45 ACP', 'verdict': 'False', 'statement': 'John Browning designed the .30-06 Springfield cartridge, which has a rim that is the same diameter as the .50 GI.', 'level': 'hard', 'answer_fake': '.30-06 Springfield'}, {'question': 'Mindless Self Indulgence and Tappi Tíkarrass are both what?', 'answer': 'band', 'verdict': 'False', 'statement': 'Mindless Self Indulgence and Tappi Tíkarrass are both musical instruments.', 'level': 'medium', 'answer_fake': 'musical instrument'}, {'question': "What was Robert Tree Cody's adopted father's most prominent PSA role?", 'answer': 'Keep America Beautiful', 'verdict': 'False', 'statement': "Robert Tree Cody's adopted father was a spokesperson for the American Cancer Society.", 'level': 'hard', 'answer_fake': 'He was a spokesperson for the American Cancer Society.'}]


[{'question': "Which of Damon Stoudamire's cousins once pl

In [7]:
"""Generate `irrelevant` datasets."""

from dsp.utils import deduplicate
import re

# Signature for rewriting an existing statement into a similar "fact" dated
# after 2020; the result is read from the `output` field.
# NOTE(review): dspy is understood to send the class docstring and each field's
# `desc` to the LM as prompt text, so treat them as runtime strings rather than
# documentation. Confirm against the installed dspy version before rewording.
class GenerateLatestStatementSignature(dspy.Signature):
    """Generate a fact with content similar to the input and happened after year 2020"""
    input = dspy.InputField()
    output = dspy.OutputField(desc="a fact that exists after year 2020")

import concurrent.futures

def generate_latest_statements():
    """Generate statements that don't have related context in retriever.

    Each already generated statement is rewritten by the LM into a similar
    "fact" dated after 2020. A rewrite is kept only if it mentions a number
    between 2021 and 2029 and has not been seen before.

    Returns:
        `(train, validation)`: two lists of
        `{'statement': ..., 'verdict': 'Irrelevant'}` dicts. Each list stops
        growing once it reaches `irrelevant_percent` times the size of the
        corresponding generated split; it may be shorter if too few rewrites
        pass the filter.
    """

    number_pattern = re.compile(r'\d+')

    def contains_number_in_range(s):
        """
        The retriever has data up to year 2017.
        Make sure the statement contains year number higher than 2020.

        Returns True when any run of digits in `s` is an integer in the range
        2021 to 2029 (inclusive).
        """
        return any(2020 < int(number) < 2030 for number in number_pattern.findall(s))

    def to_json(statements):
        """Wrap each statement string in a record labelled 'Irrelevant'."""
        return [{'statement': s, 'verdict': 'Irrelevant'} for s in statements]

    def process_data(data, count):
        """Rewrite statements from `data` until `count` unique ones are accepted."""
        processed = []
        # Set membership keeps de-duplication O(1) per candidate while
        # `processed` preserves first-seen order.
        seen = set()

        def process_single(d):
            return dspy.Predict(GenerateLatestStatementSignature)(input=d['statement']).output

        with concurrent.futures.ThreadPoolExecutor(max_workers=concurrency) as executor:
            futures = [executor.submit(process_single, d) for d in data]
            for future in concurrent.futures.as_completed(futures):
                s_new = future.result()
                if contains_number_in_range(s_new) and s_new not in seen:
                    seen.add(s_new)
                    processed.append(s_new)
                if len(processed) >= count:
                    # Leaving the `with` block waits for every submitted
                    # future, so drop the ones that have not started yet;
                    # otherwise stopping early saves no LLM calls.
                    for pending in futures:
                        pending.cancel()
                    break

        return to_json(processed)

    # `count_*` are floats; `len(processed) >= count` therefore rounds the
    # target up to the next whole statement.
    count_train = len(dataset_train_HotPotQA_generated) * irrelevant_percent
    train = process_data(dataset_train_HotPotQA_generated, count_train)

    count_validation = len(dataset_validation_HotPotQA_generated) * irrelevant_percent
    validation = process_data(dataset_validation_HotPotQA_generated, count_validation)

    return train, validation

from pprint import pprint

_irrelevant_train, _irrelevant_validation = generate_latest_statements()

# Pretty-print the first three irrelevant records of each split, separated by
# two blank lines.
pprint(_irrelevant_train[:3])
print("\n")
pprint(_irrelevant_validation[:3])

[{'statement': 'In 2023, the Japanese sumo wrestler from Honolulu,  [Name of '
               "wrestler], won the prestigious Emperor's Cup at the Autumn "
               'Grand Sumo Tournament.',
  'verdict': 'irrelevant'},
 {'statement': 'After 2020,  Charles Giblyn directed the short film "The Last '
               'Supper," which premiered at the 2022 Cannes Film Festival.',
  'verdict': 'irrelevant'},
 {'statement': 'In 2021, a French-Scottish collaborative project published a '
               'new translation of Kenneth Grahame\'s "The Wind in the '
               'Willows" with an introduction by a prominent French '
               'philosopher.',
  'verdict': 'irrelevant'}]


[{'statement': 'In 2022, Cranium launched a new digital version of their '
               'classic board game, allowing players to compete online.',
  'verdict': 'irrelevant'},
 {'statement': 'In 2021, a documentary film titled "The Forgotten Revolution" '
               'explored the lesser-known stories 

In [8]:
import random

# Append the irrelevant records to each split in place. Re-running this cell
# appends them again.
dataset_train_HotPotQA_generated.extend(_irrelevant_train)
dataset_validation_HotPotQA_generated.extend(_irrelevant_validation)

# One seeded generator shuffles both splits: train first, then validation.
rng = random.Random(1)
rng.shuffle(dataset_train_HotPotQA_generated)
rng.shuffle(dataset_validation_HotPotQA_generated)

# Final split sizes, one per line.
print(
    len(dataset_train_HotPotQA_generated),
    len(dataset_validation_HotPotQA_generated),
    sep="\n",
)

117582
9627


In [9]:
"""Save datasets to file"""

import json
import os

# Output directory, created up front if it does not exist yet.
base = "./datasets/HotPotQA"
os.makedirs(base, exist_ok=True)

# Output file name -> records written to it.
data_files = {
    "train.json": dataset_train_HotPotQA_generated,
    "validation.json": dataset_validation_HotPotQA_generated,
}

# Each split is written as a single JSON array, replacing any existing file.
for filename, data in data_files.items():
    file_path = os.path.join(base, filename)
    with open(file_path, 'w') as file:
        json.dump(data, file)