# Question Generation

> Question Generation: Here we put together the classes and methods that provide the question generation workflow for our passages.

In [1]:
#| default_exp question_generation

In [2]:
#| hide
from nbdev.showdoc import *

In [3]:
#| hide
import nbdev; nbdev.nbdev_export()

## Imports

In [584]:
from loguru import logger
import os
import random
from pathlib import Path
from fastcore.basics import patch_to, patch
from typing import Dict, List, Union

from zeroqaret.dataset import BEIRDataset, our_list as eval_list
from zeroqaret.helper import get_today

import pandas as pd

import torch
from torch import Tensor

import textwrap

from colbert.modeling.colbert import ColBERT
from colbert.infra import Run, RunConfig, ColBERTConfig
from colbert.data import Queries, Collection

from tqdm import tqdm

## Dataset

In [5]:
beir_dataset = BEIRDataset()

2023-10-27 12:40:57.735 | INFO     | zeroqaret.dataset:__init__:51 - Datasets will be saved in '/home/bengsoon/Projects/xcs224u_project/zeroqaret/datasets'


We will load the corpus for `trec-covid` as a start. We will not load the `queries`, because in a realistic setting we would not have access to them:

In [838]:
dataset_name = "trec-covid"

# We want to load only the corpus / passages  
raw_corpus, _, _ = beir_dataset.load_dataset(dataset_name)

  0%|          | 0/171332 [00:00<?, ?it/s]

Let's convert `raw_corpus` into a list of passages (`trec_corpus`), but first let's map the positional indices of the list to the original dataset's `pid`s:

In [852]:
trec_corpus_ids = {idx: str(val) for idx, val in enumerate(list(raw_corpus))}
trec_corpus = [(passage.get("title", "") + " " + passage["text"].strip()).strip() for passage in raw_corpus.values()] 
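
To make the index-to-`pid` mapping concrete, here is a minimal sketch on a toy corpus. It assumes the corpus is a dict of `{pid: {"title": ..., "text": ...}}`, which is what the cell above relies on; the `toy_*` names and their contents are made up for illustration.

```python
# Toy stand-in for `raw_corpus` (hypothetical data, same shape as used above)
toy_corpus = {
    "doc-a": {"title": "First title", "text": " Body of the first passage. "},
    "doc-b": {"text": "A passage without a title."},
}

# positional index -> original pid (dicts preserve insertion order in Python 3.7+)
toy_ids = {idx: str(pid) for idx, pid in enumerate(toy_corpus)}

# title and text joined by a space; the outer strip() removes the
# leading space that would otherwise remain when a title is missing
toy_passages = [
    (p.get("title", "") + " " + p["text"].strip()).strip()
    for p in toy_corpus.values()
]

for idx, passage in enumerate(toy_passages):
    print(toy_ids[idx], "->", passage)
```

Because both structures are built from the same dict in the same order, `toy_ids[idx]` always refers to the passage at `toy_passages[idx]`.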

Let's look at the first 5 passages:

In [840]:
print("\n\n".join(trec_corpus[:5]))

Clinical features of culture-proven Mycoplasma pneumoniae infections at King Abdulaziz University Hospital, Jeddah, Saudi Arabia OBJECTIVE: This retrospective chart review describes the epidemiology and clinical features of 40 patients with culture-proven Mycoplasma pneumoniae infections at King Abdulaziz University Hospital, Jeddah, Saudi Arabia. METHODS: Patients with positive M. pneumoniae cultures from respiratory specimens from January 1997 through December 1998 were identified through the Microbiology records. Charts of patients were reviewed. RESULTS: 40 patients were identified, 33 (82.5%) of whom required admission. Most infections (92.5%) were community-acquired. The infection affected all age groups but was most common in infants (32.5%) and pre-school children (22.5%). It occurred year-round but was most common in the fall (35%) and spring (30%). More than three-quarters of patients (77.5%) had comorbidities. Twenty-four isolates (60%) were associated with pneumonia, 14 (35

## LLM

In [415]:
from langchain.llms import Ollama
from langchain.embeddings import OllamaEmbeddings
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler 

from typing import Dict

In [362]:
llm_model_name = "mistral:instruct" # for 4-bit q: use `mistral:instruct`. for 8-bit q: use `mistral:7b-instruct-q8_0`.
base_url = "http://localhost:11434"


callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])

llm = Ollama(base_url=base_url,
             model=llm_model_name, 
             callback_manager = callback_manager)

Let's see if we are connected to our model:

In [55]:
llm("Can you tell me who you are?")

I'm Mistral, a language model trained by the Mistral AI team.

"I'm Mistral, a language model trained by the Mistral AI team."

In [56]:
llm("What can you do?")

I am unable to perform actions as I do not have the ability to execute commands or interact with the physical world. I can only provide information, answer questions, and engage in text-based conversation.

'I am unable to perform actions as I do not have the ability to execute commands or interact with the physical world. I can only provide information, answer questions, and engage in text-based conversation.'

> Great! Now let's create a pipeline for question generation

## Question Generation with `Mistral-7B-4q` & Round-trip Consistency Check

In [410]:
from langchain.prompts import PromptTemplate
from langchain.schema import StrOutputParser
from langchain.output_parsers import StructuredOutputParser, ResponseSchema

#### Generate Question (QG)

In [411]:
# Question Generation template
qg_prompt_template = """<s>[INST]
You are a curious person who loves to ask pertinent questions. Given the Passage below, it is your job to give a correct highly descriptive title, ask the relevant right question and correct one-sentenced short answer strictly from the given passage.
----
{format_instructions}
---- 
Passage: {passage}
----
Title:
Question: 
Answer:
[/INST]"""

Let's create `response_schemas` so that we can instruct the model to output in a specific JSON format:

In [345]:
# create response schemas
qg_response_schemas = [
    # ResponseSchema(name="passage", description="Repeat of the input passage"), 
    ResponseSchema(name="title", description="Descriptive generated title based on the passage"),
    ResponseSchema(name="question", description="Relevant generated question based on the passage"),
    ResponseSchema(name="answer", description="Generated one-sentenced short answer based on the generated question and passage"),

]

# create an output parser
qg_output_parser = StructuredOutputParser.from_response_schemas(qg_response_schemas)

# get format instructions to enforce the expected json format
qg_format_instructions = qg_output_parser.get_format_instructions()

In [346]:
print(qg_format_instructions)

The output should be a markdown code snippet formatted in the following schema, including the leading and trailing "```json" and "```":

```json
{
	"title": string  // Descriptive generated title based on the passage
	"question": string  // Relevant generated question based on the passage
	"answer": string  // Generated one-sentenced short answer based on the generated question and passage
}
```


In [347]:
qg_prompt = PromptTemplate(
    template=qg_prompt_template,
    input_variables=["passage"],
    partial_variables={"format_instructions": qg_format_instructions}
)
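
As a rough mental model (not LangChain's implementation), `partial_variables` behaves like pre-filling some placeholders of a format string so that only `passage` remains to be supplied at call time. The template text and the `fill` name below are illustrative assumptions.

```python
from functools import partial

# A shortened, hypothetical template with the same two placeholders
template = "----\n{format_instructions}\n----\nPassage: {passage}\n"

# Pre-fill `format_instructions` once; `passage` is supplied per call
fill = partial(template.format, format_instructions="Return JSON.")

print(fill(passage="Some passage text."))
```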

In [348]:
print(qg_prompt.format_prompt(passage=trec_corpus[45]).text)

<s>[INST]
You are a curious person who loves to ask pertinent questions. Given the Passage below, it is your job to give a correct highly descriptive title, ask the relevant right question and correct one-sentenced short answer strictly from the given passage.
----
The output should be a markdown code snippet formatted in the following schema, including the leading and trailing "```json" and "```":

```json
{
	"title": string  // Descriptive generated title based on the passage
	"question": string  // Relevant generated question based on the passage
	"answer": string  // Generated one-sentenced short answer based on the generated question and passage
}
```
---- 
Passage: BACKGROUND: Development of a practical gene point-of-care testing device (g-POCT device) requires innovative detection methods for demonstrating the results of the gene amplification reaction without the use of expensive equipment. We have studied a new method for the sequence-specific visual detection of minute amount

In [349]:
qg_chain = qg_prompt | llm

In [369]:
qg_output_trec_45 = llm(qg_prompt.format_prompt(passage=trec_corpus[45]).text, seed=158)

```json
{
	"title": "Simple Visual Detection Method for Gene Amplification",
	"question": "How did the researchers detect the presence or absence of minute amounts of nucleic acid templates?",
	"answer": "The researchers detected the presence or absence of minute amounts of nucleic acid templates by visual assessment for the color of the LAMP amplicon-PEI complex precipitate."
}
```

In [370]:
parsed_output = qg_output_parser.parse(qg_output_trec_45)

In [371]:
parsed_output["answer"]

'The researchers detected the presence or absence of minute amounts of nucleic acid templates by visual assessment for the color of the LAMP amplicon-PEI complex precipitate.'
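
The parsing step above can be pictured roughly as "find the fenced JSON block, then decode it". The sketch below uses only the standard library and is an approximation under that assumption, not the actual `StructuredOutputParser` code; `parse_json_markdown` and `FENCE` are hypothetical names.

```python
import json
import re

FENCE = "`" * 3  # three backticks, built programmatically to keep this example readable

def parse_json_markdown(text: str) -> dict:
    """Extract and decode the first fenced json block found in `text`."""
    pattern = re.escape(FENCE) + r"json\s*(.*?)\s*" + re.escape(FENCE)
    match = re.search(pattern, text, re.DOTALL)
    if match is None:
        raise ValueError("no fenced json block found")
    return json.loads(match.group(1))

# A made-up model response in the requested format
raw = FENCE + 'json\n{\n\t"title": "T",\n\t"question": "Q?",\n\t"answer": "A."\n}\n' + FENCE
parsed = parse_json_markdown(raw)
print(parsed["answer"])
```

If the model omits the fence or emits invalid JSON, this raises an exception, which is the situation the retry logic in `QuestionGenerator` has to handle.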

#### Check Answer

In [372]:
check_answer_prompt_template = """<s>[INST]
You are an expert on the topic in the passage below. Given the Title, Passage and Question below, it is your job to provide a correct and relevant one-sentenced short answer.
----
{format_instructions}
----
Passage: {passage}
Title: {title}
Question: {question}
---- 
Answer:
[/INST]"""

In [373]:
# create response schemas
check_answer_response_schemas = [
    # ResponseSchema(name="passage", description="Repeat of the input passage"), 
    ResponseSchema(name="answer", description="Generated one-sentenced short answer based on the title, question and passage"),
]

# create an output parser
check_answer_output_parser = StructuredOutputParser.from_response_schemas(check_answer_response_schemas)

# get format instructions to enforce the expected json format
check_answer_format_instructions = check_answer_output_parser.get_format_instructions()

In [374]:
check_answer_prompt = PromptTemplate(
    template=check_answer_prompt_template,
    input_variables=["passage", "title", "question"],
    partial_variables={"format_instructions": check_answer_format_instructions}
)

In [375]:
print(check_answer_prompt.format_prompt(passage=trec_corpus[45], title=parsed_output["title"], question=parsed_output["question"]).text)

<s>[INST]
You are an expert on the topic in the passage below. Given the Title, Passage and Question below, it is your job to provide a correct and relevant one-sentenced short answer.
----
The output should be a markdown code snippet formatted in the following schema, including the leading and trailing "```json" and "```":

```json
{
	"answer": string  // Generated one-sentenced short answer based on the title, question and passage
}
```
----
Passage: BACKGROUND: Development of a practical gene point-of-care testing device (g-POCT device) requires innovative detection methods for demonstrating the results of the gene amplification reaction without the use of expensive equipment. We have studied a new method for the sequence-specific visual detection of minute amounts of nucleic acids using precipitation reaction by addition of cationic polymers to amplicons of Loop mediated isothermal Amplification (LAMP). RESULTS: Oligo DNA probes labeled with different fluorescent dyes were prepared

In [357]:
check_answer_chain = check_answer_prompt | llm

In [376]:
check_answer_output_trec_45 = llm(check_answer_prompt.format_prompt(passage=trec_corpus[45], title=parsed_output["title"], question=parsed_output["question"]).text, seed=158)

```json
{
	"answer": "The researchers detected the presence or absence of minute amounts of nucleic acid templates by visual assessment of the color of the LAMP amplicon-PEI complex precipitate."
}
```

In [358]:
check_answer_output_trec_45 = check_answer_chain.invoke({"passage": trec_corpus[45], "title": parsed_output["title"], "question": parsed_output["question"]})

```json
{
    "answer": "The purpose of this study is to develop a practical gene point-of-care testing device using fluorescent labeled oligo DNA probes for detecting the presence or absence of minute amounts of nucleic acid templates in a simple manner through visual assessment."
}
```

In [377]:
check_answer_parsed_output = check_answer_output_parser.parse(check_answer_output_trec_45)

In [378]:
print(f"QG Answer:\n\t {parsed_output['answer']} \n" + "-"*50 + f"\n CA Answer:\n\t {check_answer_parsed_output['answer']}")

QG Answer:
	 The researchers detected the presence or absence of minute amounts of nucleic acid templates by visual assessment for the color of the LAMP amplicon-PEI complex precipitate. 
--------------------------------------------------
 CA Answer:
	 The researchers detected the presence or absence of minute amounts of nucleic acid templates by visual assessment of the color of the LAMP amplicon-PEI complex precipitate.


#### Find similarity

In [384]:
from sentence_transformers import SentenceTransformer, util

In [402]:
sbert_model = SentenceTransformer('all-mpnet-base-v2')

In [403]:
qg_answer_emb = sbert_model.encode(parsed_output["answer"])
ca_answer_emb = sbert_model.encode(check_answer_parsed_output["answer"])

In [404]:
cosine_scores = util.cos_sim(qg_answer_emb, ca_answer_emb).to("cpu")

In [405]:
cosine_scores

tensor([[0.9988]])
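
For intuition, the score is the cosine of the angle between the two answer embeddings, and the round-trip check reduces to comparing it against a cutoff. The following is a dependency-free sketch; the `is_consistent` helper is hypothetical, and its `cutoff` of 0.8 mirrors the default used later in `round_trip_question_generation`.

```python
import math
from typing import Sequence

def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Dot product of `a` and `b` divided by the product of their L2 norms."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def is_consistent(a: Sequence[float], b: Sequence[float], cutoff: float = 0.8) -> bool:
    """Accept the generated answer when the two embeddings are similar enough."""
    return cosine_similarity(a, b) >= cutoff

print(cosine_similarity([1.0, 0.0], [1.0, 1.0]))  # cos(45 degrees), about 0.707
print(is_consistent([1.0, 0.0], [1.0, 1.0]))      # below the 0.8 cutoff
```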

In [406]:
oembed = OllamaEmbeddings(base_url=base_url, model=llm_model_name)

In [407]:
qg_answer_oemb = oembed.embed_query(parsed_output["answer"])
ca_answer_oemb = oembed.embed_query(check_answer_parsed_output["answer"])

In [408]:
cosine_scores_oemb = util.cos_sim(qg_answer_oemb, ca_answer_oemb)

In [409]:
cosine_scores_oemb

tensor([[0.9963]])

### Putting It All Together (Question Generator and Round-Trip Consistency)

In [413]:
llm_model_name

'mistral:instruct'

In [466]:
util.cos_sim??

Signature: util.cos_sim(a: torch.Tensor, b: torch.Tensor)
Source:   
def cos_sim(a: Tensor, b: Tensor):
    """
    Computes the cosine similarity cos_sim(a[i], b[j]) for all i and j.
    :return: Matrix with res[i][j]  = cos_sim(a[i], b[j])
    """
    if not isinstance(a, torch.Tensor):
        a = torch.tensor(a)

    if

In [853]:
class QuestionGenerator:

    def __init__(self, 
                 ollama_base_url: str = 'http://localhost:11434',
                 ollama_model_name: str = 'mistral:instruct',
                 random_seed: int = 158
                ):


        self.ollama_base_url = ollama_base_url
        self.ollama_model_name = ollama_model_name
        self.random_seed = random_seed
        
        # callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
        
        self.llm = Ollama(
            base_url=ollama_base_url,
            model=ollama_model_name, 
                )
        
        self._setup_question_generator()
        self._setup_answer_checker()

    def _setup_question_generator(self,
                                 qg_prompt_template: str = None, # LLM Prompt template
                                ) -> None:
        " Sets up all the shared variables and methods for question generator "

        
        # Question Generation template
        self.qg_prompt_template = qg_prompt_template or """<s>[INST]
            You are a curious person who loves to ask pertinent questions. 
            Given the Passage below, it is your job to give a correct highly descriptive title, ask the relevant right question and correct one-sentenced short answer strictly from the given passage.
            ----
            {format_instructions}
            ---- 
            Passage: {passage}
            ----
            Title:
            Question: 
            Answer:
            [/INST]"""

        # create response schemas
        self.qg_response_schemas = [
            # ResponseSchema(name="passage", description="Repeat of the input passage"), 
            ResponseSchema(name="title", description="Descriptive generated title based on the passage"),
            ResponseSchema(name="question", description="Relevant generated question based on the passage"),
            ResponseSchema(name="answer", description="Generated one-sentenced short answer based on the generated question and passage"),
        ]
        
        # create an output parser
        self.qg_output_parser = StructuredOutputParser.from_response_schemas(self.qg_response_schemas)
        
        # get format instructions to enforce the expected json format
        self.qg_format_instructions = self.qg_output_parser.get_format_instructions()

        # prompt template
        self.qg_prompt = PromptTemplate(
                                template=self.qg_prompt_template,
                                input_variables=["passage"],
                                partial_variables={"format_instructions": self.qg_format_instructions}
                            )

    def _setup_answer_checker(self,
                              ac_prompt_template: str = None, # LLM Prompt template
                             ) -> None:
        " Sets up all the shared variables and methods for the Answer Checker "
        
        self.ac_prompt_template = ac_prompt_template or """<s>[INST]
            You are an expert on the topic in the passage below. Given the Title, Passage and Question below, it is your job to provide a correct and relevant one-sentenced short answer.
            ----
            {format_instructions}
            ----
            Passage: {passage}
            Title: {title}
            Question: {question}
            ---- 
            Answer:
            [/INST]"""

        # create response schemas
        self.answer_checker_response_schemas = [
            ResponseSchema(name="answer", description="Generated one-sentenced short answer based on the title, question and passage"),
        ]
        
        # create an output parser
        self.answer_checker_output_parser = StructuredOutputParser.from_response_schemas(self.answer_checker_response_schemas)
        
        # get format instructions to enforce the expected json format
        self.answer_checker_format_instructions = self.answer_checker_output_parser.get_format_instructions()

        # answer checker prompt
        self.answer_checker_prompt = PromptTemplate(
                template=self.ac_prompt_template,
                input_variables=["passage", "title", "question"],
                partial_variables={"format_instructions": self.answer_checker_format_instructions}
            )

    def generate_question(self,
                          passage: str, # passage
                          random_seed: int = None, # if provided, it will replace random_seed
                          verbose: str = None, # "all" prints the prompt and the parsed result
                           **kwargs: "Any",
                         ) -> Dict[str, str]:
        """ Prompts LLM to generate title, question and answer given `passage` """ 
        random_seed=random_seed or self.random_seed
        
        prompt = self.qg_prompt.format_prompt(passage=passage).text
        if verbose in ["all"]: 
            print(prompt)
        res = self.llm(prompt, seed=random_seed, **kwargs)
        
        try:
            res = self.qg_output_parser.parse(res)
        except Exception:
            temp_random_seed = random.randint(200, 1000)
            logger.info(f"Unable to parse results. Regenerating with `random_seed = {temp_random_seed}`...")
            return self.generate_question(passage=passage, random_seed=temp_random_seed, verbose=verbose,  repeat_last_n=0)
        
        if not self._output_is_dict(res): 
            temp_random_seed = random.randint(200, 1000)
            logger.info(f"Generated question output is not dict. Regenerating with `random_seed = {temp_random_seed}`...")
            return self.generate_question(passage=passage, random_seed=temp_random_seed, verbose=verbose,  repeat_last_n=0)
            
        if not self._check_title_question_answer_in_dict(res):
            temp_random_seed = random.randint(200, 1000)
            logger.info(f"Either all or some of ('title', 'question', 'answer') not in dict. Regenerating with `random_seed = {temp_random_seed}`...")
            return self.generate_question(passage=passage, random_seed=temp_random_seed, verbose=verbose,  repeat_last_n=0)
        
        if verbose in ["all"]:
            logger.info(f"{res}")
            
        return res
            
  
            

    def generate_answer_to_check(self,
                                 passage: str, # passage
                                 title: str, # title
                                 question: str, # question
                                 random_seed: int = None, # if provided, it will replace random_seed
                                 verbose: str = None, # "all" prints the prompt and the parsed result
                                 **kwargs: "Any",
                                  ) -> Dict[str, str]:
        """ Prompts LLM to answer given `passage`, `title`, `question` """ 
        random_seed=random_seed or self.random_seed
        
        prompt = self.answer_checker_prompt.format_prompt(passage=passage, title=title, question=question).text
        if verbose in ["all"]: 
            print(prompt)
            
        res = self.llm(prompt, seed=random_seed, **kwargs)
        
        try:
            res = self.answer_checker_output_parser.parse(res)
        except Exception:
            temp_random_seed = random.randint(200, 1000)
            logger.info(f"Unable to parse results. Regenerating with `random_seed = {temp_random_seed}`...")
            return self.generate_answer_to_check(passage=passage, title=title, question=question, verbose=verbose, random_seed=temp_random_seed, repeat_last_n=0)

        if not self._output_is_dict(res): 
            temp_random_seed = random.randint(200, 1000)
            logger.info(f"Generated checker's answer output is not dict. Regenerating with `random_seed = {temp_random_seed}`...")
            return self.generate_answer_to_check(passage=passage, title=title, question=question, verbose=verbose, random_seed=temp_random_seed, repeat_last_n=0)

        if not self._check_answer_in_dict(res):
            temp_random_seed = random.randint(200, 1000)
            logger.info(f"'answer' not in dict. Regenerating with `random_seed = {temp_random_seed}`...")
            return self.generate_answer_to_check(passage=passage, title=title, question=question, verbose=verbose, random_seed=temp_random_seed, repeat_last_n=0)
            
        if verbose in ["all"]:
            logger.info(f"{res}")
            
        return res

    def _output_is_dict(self,
                       results: Union[str, dict], # results from llm
                      ) -> bool:
        " Checks to see if output is dict. "

        return isinstance(results, dict)

    def _check_title_question_answer_in_dict(self,
                             result: dict, # results from llm
                            ) -> bool:
        """ 
        Check to see if all of 'title', 'question' and 'answer' are in `results`. 
        """
        

        return ("title" in result) and ("question" in result) and ("answer" in result)

    def _check_question_in_dict(self,
                               result: dict, # results from llm
                                ) -> bool:
        """ 
        Check to see if 'question' is in `results`. 
        """

        return "question" in result
        
    def _check_answer_in_dict(self,
                               result: dict, # results from llm
                                ) -> bool:
        """ 
        Check to see if 'answer' is in `results`. 
        """

        return "answer" in result
        
    def round_trip_question_generation(self,
                                       passage: str, # passage
                                       embedding_model: str = 'all-MiniLM-L6-v2', # embedding model: any SBERT emb models or 'llm' if use emb from Ollama's llm
                                       cutoff: float = 0.8, # cosine-sim cutoff score to accept [-1, 1]
                                       random_seed: int = None, # if provided, it will replace random_seed
                                       verbose: str = None, # "all" to report everything, "results" to report only results. 
                                      ) -> Union[Dict[str, str], None]:
        """
        Prompts LLM to generate title, question and answer given `passage` and performs round-trip consistency check.
        Returns the generated results, or `None` if they fail the consistency check.

        Note: the generated results are stored as instance attributes for debugging purposes when we loop through the corpus.
        """
        self.rt_random_seed = random_seed or self.random_seed

        if not hasattr(self, "emb_model"):
            logger.info("Setting up embedding model")
            if embedding_model == "llm":
                self.emb_model = OllamaEmbeddings(base_url=self.ollama_base_url, model=self.ollama_model_name).embed_query
            else:
                self.emb_model = SentenceTransformer(embedding_model).encode
            
        logger.info("Generating question...")
        self.generated_results = self.generate_question(passage, verbose=verbose, random_seed=self.rt_random_seed)

        logger.info("Performing round-trip consistency check...")
        self.checker_results = self.generate_answer_to_check(passage, self.generated_results["title"], self.generated_results["question"], verbose=verbose, random_seed=self.rt_random_seed)

        if verbose in ["all", "results"]:
            logger.info("\n" + "." * 100 + "\n" + " Generated Answer: ".center(100, " ") + "\n\n" + textwrap.fill(f"{self.generated_results['answer']}", 100) + "\n\n" +
                        "\n" + " Checker's Answer: ".center(100, " ") + "\n\n" + textwrap.fill(f"{self.checker_results['answer']}", 100) + "\n\n" + "."*100)

        logger.info("Performing similarity calculation ...")
        gen_a_emb = self.emb_model(self.generated_results["answer"])
        checker_a_emb = self.emb_model(self.checker_results["answer"])

        score = self.cos_sim(gen_a_emb, checker_a_emb)
        if verbose in ["all", "results"]: 
            logger.info(f"Score: {score}")

        if score < cutoff:
            logger.info(f"Rejecting generated question set as it failed the consistency check")
            return None
        else:
            logger.info(f"Passed consistency check.")
            return self.generated_results        
            
        
    def cos_sim(self, a: Tensor, b: Tensor):
        """
        Computes the cosine similarity cos_sim(a[i], b[j]) for all i and j.
        
        :return: Matrix with res[i][j]  = cos_sim(a[i], b[j])
        Adapted from https://github.com/UKPLab/sentence-transformers/blob/c5f93f70eca933c78695c5bc686ceda59651ae3b/sentence_transformers/util.py
        """
        if not isinstance(a, torch.Tensor):
            a = torch.tensor(a)
    
        if not isinstance(b, torch.Tensor):
            b = torch.tensor(b)
    
        if len(a.shape) == 1:
            a = a.unsqueeze(0)
    
        if len(b.shape) == 1:
            b = b.unsqueeze(0)
    
        a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
        b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
        return torch.mm(a_norm, b_norm.transpose(0, 1)).to('cpu')
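
The generation methods above retry by calling themselves with a fresh seed, with no upper bound on the number of attempts. One possible alternative, shown here only as a sketch, is a loop with a retry cap. `generate_with_retries`, its parameters, and the stubbed "LLM" are all hypothetical and not part of the class above.

```python
import json
import random
from typing import Callable, Iterable, Optional

def generate_with_retries(generate: Callable[[int], str],
                          parse: Callable[[str], dict],
                          required_keys: Iterable[str],
                          seed: int = 158,
                          max_retries: int = 5) -> Optional[dict]:
    """Call `generate(seed)` until the parsed output is a dict containing all
    `required_keys`, drawing a new random seed after each failed attempt.
    Returns None when every attempt fails."""
    required_keys = tuple(required_keys)
    for _ in range(max_retries + 1):
        try:
            result = parse(generate(seed))
        except Exception:
            result = None  # unparseable output counts as a failed attempt
        if isinstance(result, dict) and all(k in result for k in required_keys):
            return result
        seed = random.randint(200, 1000)  # same seed range as the class above
    return None

# Stub standing in for the LLM: unparseable, then incomplete, then valid
responses = iter([
    "not json",
    '{"title": "T"}',
    '{"title": "T", "question": "Q?", "answer": "A."}',
])
result = generate_with_retries(lambda seed: next(responses), json.loads,
                               ("title", "question", "answer"))
print(result)
```

Returning `None` after the cap lets the caller decide whether to skip the passage or log it, instead of recursing indefinitely on a passage the model cannot handle.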

In [854]:
question_generator = QuestionGenerator()

In [855]:
@patch_to(QuestionGenerator)
def generate_questions_corpus(self,
                              corpus: List,
                              corpus_ids: Dict,
                              df_checkpoint_path: str,
                              verbose: str = None
                             ) -> None:
    assert verbose in [None, "all", "results", "disable"], 'verbose options are only [None, "all", "results", "disable"]'
    
    # disable logger if asked to
    if verbose == "disable": 
        logger.disable("__main__")
    else:
        logger.enable("__main__")
    
    # logger.add(f"./{get_today('%Y%m%d')}_qg.log")

    qg_df = pd.DataFrame(columns=["pid", "passage", "title", "question", "answer"])

    for idx, passage in enumerate(tqdm(corpus, "Question Generation Progress: ", len(corpus))):
        pid = corpus_ids[idx]
        logger.info("\n" + f" {idx+1} - {pid} ".center(150, "#"))
        logger.info("")
        res = self.round_trip_question_generation(passage, verbose=verbose)
    
        while not res:
            temp_rand_seed = random.randint(200, 1000)
            logger.info(f"Retrying again with a different seed (`random_seed = {temp_rand_seed}`) ...")
            res = self.round_trip_question_generation(passage, verbose=verbose, random_seed=temp_rand_seed)
            
        res["passage"] = passage
        res["pid"] = pid
    
        qg_df = pd.concat((qg_df, pd.DataFrame(res, index=[idx])))
        
        if idx % 5 == 0:
            logger.info(f"Saving dataframe checkpoint as '{df_checkpoint_path}'")
            qg_df.to_csv(df_checkpoint_path)
        logger.info("")

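    # save the final state so rows generated after the last periodic checkpoint are not lost
    qg_df.to_csv(df_checkpoint_path)

    # enable logger back
    if verbose == "disable": logger.enable("__main__")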
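
Since the loop writes periodic checkpoints, a long run that is interrupted could in principle be resumed by skipping passages whose `pid` already appears in the checkpoint file. The sketch below illustrates that idea with the standard library only; `load_done_pids` is a hypothetical helper, and the method above does not currently do this.

```python
import csv
import os
import tempfile

def load_done_pids(checkpoint_path: str) -> set:
    """Return the `pid`s already present in a checkpoint CSV (empty set if no file)."""
    if not os.path.exists(checkpoint_path):
        return set()
    with open(checkpoint_path, newline="") as f:
        return {row["pid"] for row in csv.DictReader(f)}

# Demo with a throwaway checkpoint containing one finished passage
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "qg_checkpoint.csv")
    with open(path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=["pid", "question"])
        writer.writeheader()
        writer.writerow({"pid": "doc-a", "question": "Q?"})

    done = load_done_pids(path)
    corpus_ids = {0: "doc-a", 1: "doc-b"}  # index -> pid, as in `trec_corpus_ids`
    todo = [idx for idx, pid in corpus_ids.items() if pid not in done]
    print(todo)  # only the index of the unfinished passage remains
```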



Let's generate questions for the `trec-covid` dataset:

In [None]:
test_corpus = trec_corpus[8:9]
test_corpus_ids = {idx: trec_corpus_ids[i] for idx, i in enumerate(range(8,9))}

for i in range(10):
    question_generator.generate_questions_corpus(test_corpus, test_corpus_ids, f"../datasets/{dataset_name}/qg/{dataset_name}_qg_test.csv")

logger.info("done")

Question Generation Progress:   0%| | 0/1 [00:00<?
2023-10-28 02:25:38.096 | INFO     | __main__:generate_questions_corpus:22 - 
#################################################################### 1 - 8qnrcgnk ####################################################################
2023-10-28 02:25:38.097 | INFO     | __main__:generate_questions_corpus:23 - 
2023-10-28 02:25:38.099 | INFO     | __main__:round_trip_question_generation:225 - Setting up embedding model
2023-10-28 02:25:38.310 | INFO     | __main__:round_trip_question_generation:231 - Generating question...
2023-10-28 02:25:40.138 | INFO     | __main__:round_trip_question_generation:234 - Performing round-trip consistency check...
2023-10-28 02:25:42.847 | IN

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()