In [1]:
import os
import json
import pandas as pd
import random
from tqdm import tqdm
from vllm import LLM, SamplingParams
tqdm.pandas()
random.seed(42)

In [2]:
CHUNK_SIZE = 10
NUM_QUESTIONS_PER_CHUNK = 4

In [3]:
meta_prompt = """\
Context information is below.\n
Given the context information and no prior knowledge.\n
Generate only questions based on the below query.\n
The context of each question should be easily inferred by reading the question. If you are referring to a specific event, mention the exact name and date of the event.\n

You are a Teacher/ Professor. Your task is to setup \
{num_questions_per_chunk} multiple-choice questions for an upcoming \
quiz/examination. The questions should be diverse in nature \
across the document. Do not repeat the same question twice. Restrict the questions to the \
context information provided. Each question should have exactly four possible answers. Only one answer should be correct. \
Your reply must be a single number (0, 1, 2, or 3) related to the correct answer. \ 
Return a valid JSON formatted string with the following fields: \
question, possible_answers, correct_answer.\
"""

context_prompt = """---------------------\n{context_str}\n---------------------\n\
Each question should have an index value as its key, given by the following list: {question_indices}. The end result should look like this:\n\n
{{'Q_23': {{'question': 'In what year did the NBA playoffs for the 2022-23 season begin?', 'possible_answers': ['January 1, 2023', 'April 1, 2023', 'June 1, 2023', 'October 1, 2022'], 'correct_answer': 1}}, 'Q_16254': {{'question': 'Which team won the 2023 NBA Finals?', 'possible_answers': ['Boston Celtics', 'Golden State Warriors', 'Denver Nuggets', 'Miami Heat'], 'correct_answer': 2}}}}\n\
"""

num_questions_per_chunk = 4
question_indices = ["Q_23", "Q_16254", "Q_224", "Q_3"]
context_str = "The 2023 NBA playoffs was the postseason tournament of the National Basketball Association's 2022–23 season. \
The playoffs began on April 15 and concluded on June 12 with the Denver Nuggets winning the 2023 NBA Finals."

message= [{"role": "user", "content": meta_prompt.format(num_questions_per_chunk=num_questions_per_chunk) + context_prompt.format(context_str=context_str, question_indices=question_indices)}]

## Load full data

In [4]:
df = pd.read_csv("wiki_events_aug-nov_2023.csv", index_col=0)
df.fillna("None", inplace=True)
df['text_len'] = df['text'].str.len()
df = df[df['text_len'] > 100]

df['prefix'] = ("### Article: " + df['topic_name'] + ".\n" +
                "### Section: " + df['section'] + ".\n" +
                "### Subsection: " + df['subsection'] + ".\n" +
                "### Text: ")
df = df[['prefix', 'text']]
df.reset_index(inplace=True, drop=True)
df

Unnamed: 0,prefix,text
0,### Article: 2023 Louisiana wildfires.\n### Se...,In a three-month period from August to October...
1,### Article: 2023 Louisiana wildfires.\n### Se...,"On August 22, a fire started in Beauregard Par..."
2,### Article: 2023 Louisiana wildfires.\n### Se...,"On August 24, a fire described as ""out of cont..."
3,### Article: 2023 Louisiana wildfires.\n### Se...,The Federal Emergency Management Agency approv...
4,### Article: United States abortion protests (...,A series of ongoing protests supporting aborti...
...,...,...
682,### Article: 2023 Virginia Senate election.\n#...,"Five incumbent senators, four Democrats and on..."
683,### Article: 2023 Virginia Senate election.\n#...,District 21: Won by State Delegate Angelia Wil...
684,### Article: 2023 Virginia Senate election.\n#...,District 1: Won by farmer Timmy French\r\nDist...
685,### Article: 2023 Virginia Senate election.\n#...,"Incumbent Republican Jen Kiggans, first electe..."


In [5]:
df_dict = df.to_dict(orient='records')
corpus = []

for row in tqdm(df_dict):
    sentences = row['text'].split(".")
    sentences = [s.strip() + '.' for s in sentences if len(s) > 0]
    chunks = [sentences[i:i + CHUNK_SIZE] for i in range(0, len(sentences), CHUNK_SIZE)]
    chunks = [' '.join(chunk).strip() for chunk in chunks]
    
    # Limit the number of chunks for each article to 2
    if len(chunks) > 2:
        chunks = random.sample(chunks, 2)
    
    # Adding the prefix to each chunk
    chunks = [row['prefix'] + chunk for chunk in chunks]
    corpus.extend(chunks)

100%|█████████████████████████████████████████████████████████████████████████████| 687/687 [00:00<00:00, 152985.76it/s]


## Generating the dataset

In [6]:
from transformers import AutoTokenizer
model_id = "casperhansen/llama-3-70b-instruct-awq"
model = LLM(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

INFO 05-26 18:54:39 llm_engine.py:73] Initializing an LLM engine with config: model='casperhansen/llama-3-70b-instruct-awq', tokenizer='casperhansen/llama-3-70b-instruct-awq', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=8192, download_dir=None, load_format=auto, tensor_parallel_size=1, quantization=awq, enforce_eager=False, seed=0)


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


INFO 05-26 18:55:09 llm_engine.py:223] # GPU blocks: 4729, # CPU blocks: 819
INFO 05-26 18:55:10 model_runner.py:394] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 05-26 18:55:27 model_runner.py:437] Graph capturing finished in 17 secs.


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [7]:
def postprocess_output(output):
    output = output.outputs[0].text
    output = output[output.find("{"):output.rfind("}")+1]
    output = output.replace("\'", "\"").replace('\\', '')
    output = json.loads(output)
    return output

In [8]:
sampling_params = SamplingParams(
        temperature=0.2, top_p=0.95, max_tokens=500, stop_token_ids=terminators)
output = model.generate(tokenizer.apply_chat_template(message, tokenize=False),
                        sampling_params)
print(postprocess_output(output[0]))

Processed prompts: 100%|██████████████████████████████████████████████████████████████████| 1/1 [00:15<00:00, 15.07s/it]

{'Q_23': {'question': 'In what year did the NBA playoffs for the 2022-23 season begin?', 'possible_answers': ['2021', '2022', '2023', '2024'], 'correct_answer': 2}, 'Q_16254': {'question': 'Which team won the 2023 NBA Finals?', 'possible_answers': ['Boston Celtics', 'Golden State Warriors', 'Denver Nuggets', 'Miami Heat'], 'correct_answer': 2}, 'Q_224': {'question': 'On which date did the 2023 NBA playoffs conclude?', 'possible_answers': ['June 1, 2023', 'June 12, 2023', 'June 20, 2023', 'June 30, 2023'], 'correct_answer': 1}, 'Q_3': {'question': 'On which date did the 2023 NBA playoffs begin?', 'possible_answers': ['April 10, 2023', 'April 15, 2023', 'April 20, 2023', 'April 25, 2023'], 'correct_answer': 1}}





In [12]:
count = 0
json_path = 'current_events_questions_llama3.json'

NUM_PROMPTS_AT_ONCE = 10
corpus_chunks = [corpus[i:i+NUM_PROMPTS_AT_ONCE] for i in range(0, len(corpus), NUM_PROMPTS_AT_ONCE)]

p_bar = tqdm(corpus_chunks, desc="Processing chunks")
for corpus_chunk in p_bar:
    messages = []
    for chunk in corpus_chunk:
        question_indices = [f"Q_{i}" for i in range(count*NUM_QUESTIONS_PER_CHUNK, (count+1)*NUM_QUESTIONS_PER_CHUNK)]
        message = [{
            "role": "user", 
            "content": meta_prompt.format(num_questions_per_chunk=NUM_QUESTIONS_PER_CHUNK) + context_prompt.format(context_str=chunk, question_indices=question_indices)
        }]
        messages.append(tokenizer.apply_chat_template(message, tokenize=False))
    
    sampling_params = SamplingParams(temperature=0.2, top_p=0.95, max_tokens=500, stop_token_ids=terminators)
    output = model.generate(messages, sampling_params)
    for i in range(len(output)):
        try:
            python_dict = postprocess_output(output[i])
        except Exception as e:
            continue
        
        chunk = corpus_chunk[i]
        for q in python_dict:
            python_dict[q]['context'] = chunk
    
        if count > 0:
            with open(json_path, 'r') as json_file:
                json_dict = json.load(json_file)
            json_dict.update(python_dict)
        else:
            json_dict = python_dict
    
        if len(json_dict) > 0:
            with open(json_path, 'w') as json_file:
                json.dump(json_dict, json_file)
        count += 1
        p_bar.set_postfix(successes=count)
                

Processing chunks:   0%|                                                                         | 0/94 [00:00<?, ?it/s]

Processed prompts:   0%|                                                                         | 0/10 [00:00<?, ?it/s][A[A

Processed prompts:  10%|██████▌                                                          | 1/10 [00:28<04:15, 28.35s/it][A[A

Processed prompts:  20%|█████████████                                                    | 2/10 [00:29<01:36, 12.10s/it][A[A

Processed prompts:  30%|███████████████████▌                                             | 3/10 [00:29<00:46,  6.63s/it][A[A

Processed prompts:  40%|██████████████████████████                                       | 4/10 [00:29<00:25,  4.30s/it][A[A

Processed prompts:  50%|████████████████████████████████▌                                | 5/10 [00:30<00:15,  3.01s/it][A[A

Processed prompts:  60%|███████████████████████████████████████                          | 6/10 [00:31<00:08, 

## Correcting the dataset
Keeps only specific questions

In [29]:
questions_dataset = pd.read_json(json_path).T

In [30]:
questions_dataset

Unnamed: 0,question,possible_answers,correct_answer,context
Q_0,On what date did the Supreme Court officially ...,"[June 24, 2021, June 24, 2022, July 24, 2022, ...",1,### Article: United States abortion protests (...
Q_1,What was the margin of the Supreme Court decis...,"[6–3, 5–4, 7–2, 8–1]",1,### Article: United States abortion protests (...
Q_2,What percentage of the American public was spl...,"[40 to 45%, 50 to 55%, 55 to 60%, 60 to 65%]",2,### Article: United States abortion protests (...
Q_3,How did international observers and foreign le...,"[They generally supported the decision, They w...",2,### Article: United States abortion protests (...
Q_24,Where was a demonstration held in London durin...,"[Birmingham, Battersea, Manchester, Liverpool]",1,### Article: United States abortion protests (...
...,...,...,...,...
Q_2067,What is the number of the district that Jeff C...,"[5, 6, 7, 8]",1,### Article: 2023 Virginia House of Delegates ...
Q_2088,In which district did Timmy French win the ele...,"[District 1, District 3, District 10, District...",0,### Article: 2023 Virginia Senate election.\n#...
Q_2089,Who won the election in District 3?,"[Timmy French, Chris Head, John McGuire, Emily...",1,### Article: 2023 Virginia Senate election.\n#...
Q_2090,How many districts are mentioned in the article?,"[3, 4, 5, 6]",3,### Article: 2023 Virginia Senate election.\n#...


In [31]:
questions_dataset['context_num'] = pd.factorize(questions_dataset['context'])[0]

In [32]:
questions_dataset.shape

(376, 5)

In [33]:
questions_dataset.context_num.nunique()

87

In [34]:
def clean_correct_answers(x):
    try:
        return int(x) if int(x) in range(4) else None 
    except:
        return None
questions_dataset.correct_answer = questions_dataset.correct_answer.apply(lambda x: clean_correct_answers(x))
questions_dataset.dropna(inplace=True)

In [35]:
questions_dataset

Unnamed: 0,question,possible_answers,correct_answer,context,context_num
Q_0,On what date did the Supreme Court officially ...,"[June 24, 2021, June 24, 2022, July 24, 2022, ...",1,### Article: United States abortion protests (...,0
Q_1,What was the margin of the Supreme Court decis...,"[6–3, 5–4, 7–2, 8–1]",1,### Article: United States abortion protests (...,0
Q_2,What percentage of the American public was spl...,"[40 to 45%, 50 to 55%, 55 to 60%, 60 to 65%]",2,### Article: United States abortion protests (...,0
Q_3,How did international observers and foreign le...,"[They generally supported the decision, They w...",2,### Article: United States abortion protests (...,0
Q_24,Where was a demonstration held in London durin...,"[Birmingham, Battersea, Manchester, Liverpool]",1,### Article: United States abortion protests (...,1
...,...,...,...,...,...
Q_2067,What is the number of the district that Jeff C...,"[5, 6, 7, 8]",1,### Article: 2023 Virginia House of Delegates ...,85
Q_2088,In which district did Timmy French win the ele...,"[District 1, District 3, District 10, District...",0,### Article: 2023 Virginia Senate election.\n#...,86
Q_2089,Who won the election in District 3?,"[Timmy French, Chris Head, John McGuire, Emily...",1,### Article: 2023 Virginia Senate election.\n#...,86
Q_2090,How many districts are mentioned in the article?,"[3, 4, 5, 6]",3,### Article: 2023 Virginia Senate election.\n#...,86


In [36]:
# dumb heuristic to remove contexts with no information
questions_dataset = questions_dataset[questions_dataset.context.apply(lambda x: len(x.split())>70)]
questions_dataset = questions_dataset.reset_index().rename(columns={'index': 'orig_question_num', 'context': 'text'})

In [37]:
questions_dataset

Unnamed: 0,orig_question_num,question,possible_answers,correct_answer,text,context_num
0,Q_0,On what date did the Supreme Court officially ...,"[June 24, 2021, June 24, 2022, July 24, 2022, ...",1,### Article: United States abortion protests (...,0
1,Q_1,What was the margin of the Supreme Court decis...,"[6–3, 5–4, 7–2, 8–1]",1,### Article: United States abortion protests (...,0
2,Q_2,What percentage of the American public was spl...,"[40 to 45%, 50 to 55%, 55 to 60%, 60 to 65%]",2,### Article: United States abortion protests (...,0
3,Q_3,How did international observers and foreign le...,"[They generally supported the decision, They w...",2,### Article: United States abortion protests (...,0
4,Q_76,What is the current number of counties require...,"[30 counties, 44 counties, 50 counties, 88 cou...",1,### Article: August 2023 Ohio Issue 1.\n### Se...,3
...,...,...,...,...,...,...
259,Q_1983,At what time was the strike officially suspend...,"[12:00am on November 8, 12:01am on November 9,...",1,### Article: 2023 SAG-AFTRA strike.\n### Secti...,24
260,Q_2000,By what date had the FAA concluded the safety ...,"[October 15, 2023, October 31, 2023, November ...",1,### Article: SpaceX Starship Second Integrated...,82
261,Q_2001,What was the original target launch date annou...,"[November 10, 2023, November 15, 2023, Novembe...",2,### Article: SpaceX Starship Second Integrated...,82
262,Q_2002,What was the reason for the flight being delay...,"[Weather conditions, Technical issues with the...",2,### Article: SpaceX Starship Second Integrated...,82


In [45]:
questions_dataset.to_csv('current_events_questions_updated_hf_llama_3.csv', index=False)