In [3]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Use the full option name: pandas accepts unambiguous shorthand names, but
# documents that shorthands may break if similarly named options are added.
# None disables truncation of long cell text (the source sentences) in display.
pd.set_option('display.max_colwidth', None)

# Set up MixQG Model

In [4]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

import torch

MIXQG_CHECKPOINT = 'Salesforce/mixqg-base'
# Fall back to CPU so the notebook still runs on machines without a GPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

mixqg_tokenizer = AutoTokenizer.from_pretrained(MIXQG_CHECKPOINT)
mixqg_model = AutoModelForSeq2SeqLM.from_pretrained(MIXQG_CHECKPOINT).to(DEVICE)

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
def format_inputs(context: str, answer: str) -> str:
    """Build the MixQG input text: the answer, a separator, then the context.

    The escaped backslash is intentional: the separator is the two characters
    backslash and 'n' surrounded by spaces, not a newline character.

    Args:
        context (str): passage the question should be generated from.
        answer (str): answer span the question should target.
    Returns:
        str: the combined model input.
    """
    prompt = f"{answer} \\n {context}"
    return prompt

In [6]:
import re

def postprocess_question(ques, puncts_to_remove=('.', ',', '!')):
    """Clean up the question after being generated.

    Pipeline:
      1. Scan the trailing run of non-alphanumeric characters (back to the last
         letter/digit) and drop every character in ``puncts_to_remove``; other
         trailing characters such as '?' are kept.
      2. Collapse whitespace runs into a single space and strip both ends.

    Args:
        ques (str): generated question
        puncts_to_remove (Iterable[str]): trailing characters to drop;
            defaults to '.', ',' and '!'.
    Returns:
        str: new question
    """
    chars = list(ques)

    # Walk backwards (index 0 included) until the first alphanumeric char.
    # Deleting while moving backwards is safe: only already-visited positions shift.
    i = len(chars) - 1
    while i >= 0 and not chars[i].isalnum():
        if chars[i] in puncts_to_remove:
            del chars[i]
        i -= 1

    return re.sub(r'\s+', ' ', ''.join(chars)).strip()

In [7]:
def postprocess_answer(ans):
    """Clean up the answer after being generated.

    Pipeline:
      1. Strip surrounding whitespace.
      2. In the trailing run of non-alphanumeric characters, keep only the last
         '.' or '!' and drop earlier ones; if there is none, append '.'.
      3. Capitalize the first character.
      4. Collapse whitespace runs into a single space.

    Args:
        ans (str): generated answer
    Returns:
        str: new answer (never empty; an empty input yields ".")
    """
    ending_puncts = ('.', '!')

    # Strip first so an appended period attaches to the last word
    # ("foo " -> "Foo.", not "Foo .") and chars[0] is non-whitespace.
    chars = list(ans.strip())

    # Walk the trailing non-alphanumeric run backwards (index 0 included).
    has_punct = False
    i = len(chars) - 1
    while i >= 0 and not chars[i].isalnum():
        if chars[i] in ending_puncts:
            if has_punct:
                # An ending mark was already kept further right; drop this one.
                del chars[i]
            else:
                has_punct = True
        i -= 1

    if not has_punct:
        chars.append('.')

    # chars is non-empty here: either it had content or '.' was just appended.
    chars[0] = chars[0].upper()

    new_ans = re.sub(r'\s+', ' ', ''.join(chars))
    return new_ans.strip()

In [8]:
def generate_question(context, answer, max_length=32, num_beams=4):
    """Generate a question based on context and answer.

    Needs the globally available ``mixqg_model`` and ``mixqg_tokenizer``.
    ``answer`` has to be formatted correctly beforehand (capitalized, no extra
    spaces, etc. -- see ``postprocess_answer()``).

    Args:
        context (str): passage the question should be generated from.
        answer (str): answer span the question should target.
        max_length (int): maximum generated length, passed to ``generate``.
        num_beams (int): beam-search width, passed to ``generate``.

    Returns:
        str: the first decoded sequence, cleaned by ``postprocess_question()``.
    """
    inputs = format_inputs(context, answer)
    # Put the inputs on the model's own device instead of hard-coding 'cuda',
    # so this also works when the model was loaded on CPU.
    input_ids = mixqg_tokenizer(inputs, return_tensors="pt").input_ids.to(mixqg_model.device)

    output_seqs = mixqg_model.generate(input_ids, max_length=max_length, num_beams=num_beams)
    question = mixqg_tokenizer.batch_decode(output_seqs, skip_special_tokens=True)[0]
    return postprocess_question(question)

# Inference

## Utilities

In [9]:
def read_ans_file(file_name):
    """Read answer spans from a text file, one span per line.

    Each line is stripped of surrounding whitespace and blank lines are
    skipped, since an empty span cannot yield a meaningful question.

    Args:
        file_name (str): path to the answers file (read as UTF-8).
    Returns:
        list[str]: the non-empty, stripped lines; an empty list (after a
        printed message) when the file does not exist.
    """
    try:
        with open(file_name, 'r', encoding='utf-8') as f:
            return [line.strip() for line in f if line.strip()]
    except FileNotFoundError:
        # Best-effort only for a genuinely missing file; other errors
        # (decoding, permissions) propagate instead of being mislabelled.
        print(f"File not found: {file_name}")
        return []

In [10]:
import re
def split_into_sentences(text):
    """Split text into sentences.

    A boundary is placed right after '.', '?' or '!' when that mark is followed
    by whitespace or the end of the text. No boundary is placed when the
    preceding characters look like an abbreviation:
      - word-char, dot, word-char, any char, e.g. "U.S." or "e.g.";
      - an uppercase + lowercase pair, e.g. "Mr.".

    Requiring trailing whitespace keeps decimals ("3.5") and dotted tokens
    ("www.example.com") inside one sentence. As a consequence, a mark followed
    directly by a closing quote or bracket does not end the sentence.

    Args:
        text (str): text to be split
    Returns:
        list[str]: non-empty sentences, stripped of surrounding whitespace
    """
    boundary = r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)(?=\s|$)"
    pieces = re.split(boundary, text)
    return [piece.strip() for piece in pieces if piece.strip()]

In [11]:
def find_source_sent(sents, text):
    """Find what sentence the text belongs to. Text has to be a subspan of a sentence.

    Args:
        sents (list[str]): candidate sentences, searched in order.
        text (str): span to look up; surrounding whitespace is ignored.
    Returns:
        str: the first sentence containing the span, or "" if none does or if
        the span is empty/whitespace-only (an empty string would otherwise
        trivially match the first sentence).
    """
    needle = text.strip()
    if not needle:
        return ""
    return next((sent for sent in sents if needle in sent), "")

## Run inference

In [12]:
context_path = 'wikipedia_articles/personalized_learning.txt'
# Use a context manager so the file handle is closed after reading.
with open(context_path, 'r', encoding='utf-8') as context_file:
    context = context_file.read()

answers = read_ans_file("chosen_spans.txt")

# NOTE(review): the whole article is passed as context for every answer; the
# tokenizer warns that it exceeds the model's 512-token maximum (2092 tokens
# for this article). Consider passing only a window around each answer.
questions = []
for raw_answer in answers:
    answer = postprocess_answer(raw_answer)
    question = generate_question(context, answer)
    questions.append((question, answer))

Token indices sequence length is longer than the specified maximum sequence length for this model (2092 > 512). Running this sequence through the model will result in indexing errors


In [13]:
sents = split_into_sentences(context)
questions_df = pd.DataFrame(questions, columns=['question', 'answer'])
# postprocess_answer() capitalizes the first character and may append a '.',
# so drop the first and last characters before searching. This is best-effort:
# other normalization (e.g. whitespace collapsing) can still prevent a match.
# A column-wise map (rather than a row-wise apply) also works when `questions`
# is empty.
questions_df['source_sent'] = questions_df['answer'].map(
    lambda ans: find_source_sent(sents, ans[1:-1])
)
questions_df = questions_df[['source_sent', 'question', 'answer']]  # rearrange columns' order

In [14]:
from pathlib import Path

output_path = Path("generated_questions/mixqg_questions_personalized_learning.csv")
# Create the output folder if needed instead of failing inside to_csv.
output_path.parent.mkdir(parents=True, exist_ok=True)
questions_df.to_csv(output_path, index=False)