In [None]:
import asyncio
import json
import random
import re
import string
from ast import literal_eval
from pathlib import Path

import pandas as pd
import spacy
from dotenv import load_dotenv
from openai import OpenAI, AsyncOpenAI
from spacy.language import Language
from spacy.util import filter_spans
from tqdm.autonotebook import tqdm

# Load environment variables from a .env file (python-dotenv).
# NOTE(review): presumably this supplies the OpenAI API key read by AsyncOpenAI() below - confirm.
load_dotenv()

In [2]:
# Global run parameters.
seed = 12345  # Seeds `random` here and is reused as `random_state` when sampling df_qa
sample_n = 100  # Upper bound on the number of QA rows kept after sampling
trial_n = 3 # Number of JSON parsing attempts
output_step = 10 # Number of attempts for step 5

# Seed the stdlib RNG; select_mask_row() draws from it via random.sample().
random.seed(seed)

## Settings

In [3]:
# Dataset selector; the value is interpolated into the input and output paths below.
# QA_TYPE = "UQA"
# QA_TYPE = "RQA"
QA_TYPE = "AQuA"

In [None]:
# Input CSV path
path_csv = Path(f"../data/input/{QA_TYPE}.csv")
# Displayed as the cell output so that a missing input file is noticed before loading.
path_csv.exists()

In [5]:
# Column names of the input CSV. The commented-out alternatives are kept for the
# other datasets selectable through QA_TYPE.
col_id = "question_id"
# col_evidence = "evidence_wo_url"
# col_evidence = "evidence"
col_evidence = "rationale"
col_question = "question"
# col_question = "question_sentence"
# col_choices = "choices"
col_choices = "options"
# col_answer = "answer"
col_answer = "correct"

# Columns whose CSV cells are parsed with ast.literal_eval when the file is loaded.
list_convert_cols = [col_choices]
# list_convert_cols = [col_choices, col_answer]

In [6]:
# Model names: `ollama_model` is passed to step2/step3 together with `ollama_client`,
# `openai_model` is passed to step5 together with `openai_async_client`.
ollama_model = "gemma2:9b"
openai_model = "gpt-4o-mini-2024-07-18"

In [7]:
# Output paths
dir_output = Path(f"../data/output/{QA_TYPE}_output_4/")
# parents=True also creates missing parent folders (e.g. ../data/output on a fresh
# checkout); exist_ok=True makes the cell safe to re-run.
dir_output.mkdir(parents=True, exist_ok=True)

# `{trial_no}` stays a literal placeholder here and is filled in later via str.format().
output_path_qa = dir_output / f"{QA_TYPE}_{ollama_model.replace(':', '_')}_{{trial_no}}.csv"
output_path_qa_tmp = dir_output / f"{QA_TYPE}_{ollama_model.replace(':', '_')}_tmp.csv" # Intermediate results

# Per-step word-list outputs.
output_path_step1 = dir_output / f"{QA_TYPE}_step1.csv"
output_path_step2 = dir_output / f"{QA_TYPE}_step2.csv"
output_path_step2_use = dir_output / f"{QA_TYPE}_step2_use.csv"
output_path_step4 = dir_output / f"{QA_TYPE}_step4.csv"

In [None]:
# List of parts of speech to be masked
# (compared against spaCy's `token.pos_` values in is_content_word()).
POS_CONTENT_WORD = ['PROPN', 'NOUN', 'VERB', 'ADJ', 'ADV', ]

In [None]:
# Masking rates in percent: 0, 5, 10, ..., 100.
list_mask_rate = list(range(0, 105, 5))
print(list_mask_rate)

### Prompt templates

In [10]:
# Prompt for step 2. Placeholders filled by str.format(): {step1_output} (markdown
# meta table), {context_input}, {q_input}, {choices_input}. Doubled braces are
# literal braces in the rendered prompt.
STEP2_PROMPT = """The following meta table contains metadata extracted from Text1, Text2, and Text3 according to the definition. Fill in the values for "Part of Speech", "Category", and "Meaning" in the table, and output the JSON in the format {{'data' : [ {{'word': str, 'part_of_speech': str, 'category': str, 'meaning': str}}, ...]}}.

## Meta Table
{step1_output}

## Definition
- Category: One of "organization name, individual name, technical term"
- Meaning: Words that express higher-level concepts (multiple possible). Do not use other "words".

## Example
Word | Part of Speech | Category | Meaning
---|---|---|---
Medical Team | Common Noun | Organization Name | Healthcare
Relay Station | Common Noun | Technical Term | Communication
Tanaka Vehicles | Proper Noun | Organization Name | Company, Manufacturing
Chronowar | Proper Noun | Technical Term | Product Name
Napoleon | Proper Noun | Individual Name | Historical Figure

## Text1
{context_input}

## Text2
{q_input}

## Text3
{choices_input}"""

In [11]:
# Prompt for an LLM-based numeric conversion. Placeholders: {context_input},
# {q_input}, {choices_input}.
# NOTE(review): step3() in this notebook scales numbers with a regex helper and does
# not format this template; the 32% stated here also differs from the 1.23 factor
# used by scale_numbers_in_text() - confirm which one is intended.
STEP3_PROMPT = """Change the numerical parts of Text1, Text2, and Text3 according to the following conditions. Output the response in the specified format.

## Conditions
- Increase the numerical values by 32%
- Display numerical values to two decimal places
- If there are no numerical values, return the text as is
- Keep year and month values unchanged

## Format
- JSON
- keys:
    - conditions: list[str]
        - List as "<original number> -> <changed number>"
        - If no numbers were changed, use "no change"
    - result: {{'text1': str, 'text2': str, 'text3': str}}
        - Text after replacing the numerical values

## Text1
{context_input}

## Text2
{q_input}

## Text3
{choices_input}"""

In [12]:
# Prompt for step 5 (answer generation). Placeholders filled by str.format():
# {context}, {metadata} (markdown table of the masked codes), {question}, {option_list}.
STEP5_PROMPT = """The following is a text and metadata related to the code terms within the text. Answer the question concisely according to the instructions.

## Instructions
- Choose the answer from the options and respond with the corresponding number.
- Respond in JSON format as {{'basis': str, 'answer': int}}
- Use only the text as a reference for the basis

## Text
{context}

## Metadata
{metadata}

## Question
{question}

## Options
{option_list}"""

## LLM Settings

In [13]:
# Async client with default settings; used for step 5.
openai_async_client = AsyncOpenAI()
# Synchronous client pointed at a local OpenAI-compatible endpoint on port 11434;
# used for step 2. The api_key value here is a fixed placeholder string.
ollama_client = OpenAI(
    base_url='http://localhost:11434/v1/',
    api_key='ollama'
)

In [14]:
# Role names used in the `role` field of chat messages.
USER = 'user'
AI = 'assistant'
SYS = 'system'

In [15]:
def generate(client:OpenAI, model:str, messages:list[dict], stream:bool=True, output_json:bool=False):
    """Send a chat-completion request and return whatever the client returns.

    Args:
        client: Client exposing `chat.completions.create`.
        model: Model name forwarded to the API.
        messages: Chat messages as `{'role': ..., 'content': ...}` dicts.
        stream: Forwarded as the `stream` request option.
        output_json: When True, request a JSON object via `response_format`.
    """
    request = dict(model=model, messages=messages, stream=stream)

    # JSON mode is opt-in; the key is omitted entirely otherwise.
    if output_json:
        request['response_format'] = {'type': 'json_object'}

    return client.chat.completions.create(**request)

def write_stream(stream) -> str:
    """Echo a streamed completion to stdout as it arrives and return the full text.

    Each item of `stream` is read through `choices[0].delta.content`; items whose
    content is empty or None are skipped.
    """
    pieces = []
    for chunk in stream:
        text = chunk.choices[0].delta.content
        if not text:
            continue
        print(text, end="", flush=True)
        pieces.append(text)

    return "".join(pieces)

def append_message(role:str, content:str, messages:"list[dict] | None"=None):
    """Append one chat message to `messages` and return the list.

    Args:
        role: Message role (e.g. USER, AI, SYS).
        content: Message text.
        messages: List to append to (mutated in place). When omitted, a fresh list
            is created for this call; a shared `[]` default would leak messages
            between unrelated conversations.

    Returns:
        The list that received the new message.
    """
    if messages is None:
        messages = []
    messages.append({'role': role, 'content': content})
    return messages

async def aget_response(client:AsyncOpenAI, model:str, messages:list[dict], output_json:bool=False) -> str:
    """Await a non-streaming chat completion and return the first choice's text.

    Args:
        client: Async client exposing `chat.completions.create`.
        model: Model name forwarded to the API.
        messages: Chat messages as `{'role': ..., 'content': ...}` dicts.
        output_json: When True, request a JSON object via `response_format`.
    """
    request = dict(model=model, messages=messages)
    if output_json:
        request['response_format'] = {'type': 'json_object'}

    completion = await client.chat.completions.create(**request)
    first_choice = completion.choices[0]
    return first_choice.message.content

## Data loading

In [15]:
# Columns listed in `list_convert_cols` are parsed from their CSV text with
# ast.literal_eval; with no such columns read_csv gets no extra options.
param = (
    {'converters': {col: literal_eval for col in list_convert_cols}}
    if len(list_convert_cols) > 0
    else {}
)

df_qa = pd.read_csv(path_csv, **param)

In [None]:
# Filter down to sample_n
# Sampling only happens when the frame has more than sample_n rows; random_state=seed
# keeps the selected rows reproducible.
if df_qa.shape[0] > sample_n:
    print('filtering:', df_qa.shape[0], '->', sample_n)
    df_qa = df_qa.sample(sample_n, random_state=seed).copy(deep=True)

# Cell output: (rows, columns) after the optional sampling.
df_qa.shape

## step1
- Extract words from `context`, `question`, and `choices`
- Determine if it is a content word
- Create the word list

In [16]:
@Language.component("merge_hyphenated")
def merge_hyphenated(doc):
    """Detect hyphenated words as a single word.

    Scans for runs of the form `tok - tok (- tok)*` in which each '-' token has no
    whitespace on either side, and merges each run into one token. A chain such as
    "state-of-the-art" becomes a single token, while "a -b" or "a - b" is left alone
    because a space touches the hyphen.

    Args:
        doc: The spaCy Doc being processed by the pipeline.

    Returns:
        The same Doc, retokenized in place.
    """
    spans = []
    n_tokens = len(doc)
    start = 0
    while start < n_tokens - 2:
        end = start
        # Extend the run while "<token>-<token>" continues without surrounding spaces.
        while (
            end + 2 < n_tokens
            and doc[end + 1].text == '-'
            and not doc[end].whitespace_      # no space before the hyphen
            and not doc[end + 1].whitespace_  # no space after the hyphen
        ):
            end += 2

        if end > start:
            spans.append(doc[start:end + 1])
            start = end + 1
        else:
            start += 1

    # The runs cannot overlap by construction; filter_spans is kept as a safety net
    # because retokenizer.merge() rejects overlapping spans.
    filtered_spans = filter_spans(spans)
    with doc.retokenize() as retokenizer:
        for span in filtered_spans:
            retokenizer.merge(span)
    return doc

def get_entity_token_index(doc:"spacy.tokens.doc.Doc", start_i:int) -> list[int]:
    """Return the token indices of the entity that starts at `start_i`.

    Starting from the token after `start_i`, tokens are collected for as long as
    their `ent_iob_` is "I" (inside the same entity). The scan stops at the first
    token that is outside any entity ("O"), begins another entity ("B"), or at the
    end of the document. Stopping at "B" matters: two adjacent entities would
    otherwise be glued into the first one while the second is still emitted on its own.

    Args:
        doc: Token sequence; each token must expose `ent_iob_`.
        start_i: Index of the entity's first token.

    Returns:
        Consecutive indices beginning with `start_i`.
    """
    list_ind = [start_i]

    cur = start_i + 1
    while cur < len(doc) and doc[cur].ent_iob_ == "I":
        list_ind.append(cur)
        cur += 1

    return list_ind

def process(nlp:spacy.language.Language, text:str) -> pd.DataFrame:
    """Tokenize `text` and return one row per word (duplicates are kept).

    A named entity is emitted once, at its first ("B") token, as a single row
    spanning all of its tokens. Every other token gets its own row; punctuation
    and whitespace tokens are kept but counted as zero words.

    Returns:
        DataFrame with columns word, part_of_speech, category, lemma, word_count
        and index (list of token positions in the document).
    """
    doc = nlp(text)

    records = []
    seen_positions = []

    for tok in doc:
        if tok.ent_type_:
            # Only the first token of an entity produces a row.
            if tok.ent_iob_ != "B":
                continue

            ent_positions = get_entity_token_index(doc, tok.i)
            ent_span = doc[ent_positions[0]:ent_positions[-1]+1]
            ent_text = ent_span.text
            records.append(
                {
                    'word': ent_text,
                    'part_of_speech': tok.pos_,
                    'category': tok.ent_type_,
                    'lemma': ent_span.lemma_,
                    'word_count': len(ent_text.split()),
                    'index': ent_positions
                }
            )
            seen_positions.extend(ent_positions)
            continue

        # Token outside any entity; skip it if it was already covered.
        if tok.i in seen_positions:
            continue

        not_a_word = tok.pos_ in ['PUNCT', 'SPACE']
        records.append(
            {
                'word': tok.text,
                'part_of_speech': tok.pos_,
                'category': tok.ent_type_,
                'lemma': tok.lemma_,
                'word_count': 0 if not_a_word else len(tok.text.split()),
                'index': [tok.i]
            }
        )
        seen_positions.append(tok.i)

    return pd.DataFrame(records)

In [17]:
def is_content_word(text:str, pos:str, category:str) -> bool:
    """Return True if the text is a content word.

    A word counts as a content word when it carries an entity category, when it
    consists of several space-separated parts, or when its part of speech is
    listed in POS_CONTENT_WORD.
    """
    has_category = category != ""
    is_multi_word = text.count(' ') > 0
    return has_category or is_multi_word or pos in POS_CONTENT_WORD

In [18]:
def step1(df_qa:pd.DataFrame) -> pd.DataFrame:
    """Extract morphemes from the text columns of the QA frame.

    The evidence, question and choices columns are processed in that order. Each
    resulting word row is tagged with the QA id (`col_id`), the source column
    (`mask_col`) and an `is_content_word` flag.
    """
    nlp = spacy.load("en_core_web_sm")
    nlp.add_pipe("merge_hyphenated", before='parser')

    def _extract(cell, is_choice_list):
        # The choices cell is a list of strings; the other cells are plain strings.
        if is_choice_list:
            return pd.concat([process(nlp, choice) for choice in cell])
        return process(nlp, cell)

    frames = []
    for col in [col_evidence, col_question, col_choices]:
        is_choice_list = col == col_choices
        per_row_frames = df_qa[col].apply(lambda x: _extract(x, is_choice_list))

        # apply() preserves row order, so position `pos` lines up with df_qa.iloc[pos].
        for pos, frame in enumerate(per_row_frames):
            frame[col_id] = df_qa.iloc[pos][col_id]
            frame['mask_col'] = col

        frames.extend(per_row_frames)

    df_result = pd.concat(frames)
    df_result['is_content_word'] = df_result.apply(
        lambda x: is_content_word(x['word'], x['part_of_speech'], x['category']),
        axis=1,
    )

    return df_result

In [21]:
# Run step 1: one row per word found in the evidence, question and choices columns.
df_word = step1(df_qa)

#### Remove morphemes extracted from `options` (AQuA only)

In [None]:
# Keep only words that did not come from the choices column.
df_word = df_word[df_word['mask_col']!=col_choices]

In [None]:
# Drop rows whose entity category is one of the numeric/date labels listed here.
df_word = df_word[~df_word['category'].isin(["CARDINAL", "DATE", "PERCENT", "ORDINAL", "QUANTITY"])]

## step2
- Generate category and meaning for each content word using `gemma 2`

In [19]:
def convert_df_to_markdown(df:pd.DataFrame):
    """Render the columns of `df` as a pipe-separated markdown table.

    The index is not included. Cells are converted with `str()`, so numeric
    values no longer raise `TypeError`; missing values (None / NaN) are rendered
    as empty cells. Frames that contain only strings produce exactly the same
    text as before.

    Returns:
        Header line, `---` separator line and one line per row, joined by newlines.
        An empty frame yields the header and separator followed by an empty body.
    """
    def _cell(value) -> str:
        # `value != value` is only true for NaN.
        if value is None or (isinstance(value, float) and value != value):
            return ''
        return str(value)

    header = ' | '.join(str(col) for col in df.columns)
    separator = ' | '.join(['---'] * len(df.columns))
    body = '\n'.join(
        ' | '.join(_cell(value) for value in record)
        for record in df.itertuples(index=False, name=None)
    )

    return f"{header}\n{separator}\n{body}"

In [20]:
def step2(client:OpenAI, model:str, df_qa:pd.DataFrame, df_word:pd.DataFrame) -> pd.DataFrame:
    """Create a word list at the lemma level from df_word and add category and meaning for content words.

    For every QA row the content-word lemmas are sent to the model in chunks of 10.
    Returned rows are matched back to the chunk in two ways:
      1. by lowercased, punctuation-stripped word;
      2. for the rest, by the lemma of the returned word.
    Matched lemmas are removed from the chunk and the request is repeated (at most
    10 times per chunk) until the chunk is empty.

    Args:
        client: Chat client passed to generate().
        model: Model name passed to generate().
        df_qa: QA rows (needs col_id, col_evidence, col_question, col_choices).
        df_word: Output of step1 (needs col_id, lemma, part_of_speech, category,
            is_content_word).

    Returns:
        Concatenation of all matched model rows, each tagged with `col_id`.
        An empty frame with the expected columns when nothing was matched.
    """
    list_result = []

    nlp = spacy.load("en_core_web_sm")
    nlp.add_pipe("merge_hyphenated", before='parser')

    for row_label, row_qa in tqdm(list(df_qa.iterrows())):
        qa_id = row_qa[col_id]
        rows_word = df_word[(df_word[col_id]==qa_id) & df_word['is_content_word']]

        # Group by lemma
        cols_unique = ['part_of_speech', 'category']
        df_gr_lemma = rows_word.groupby('lemma').agg({col:'unique' for col in cols_unique}).reset_index().copy()
        for col in cols_unique:
            df_gr_lemma[col] = df_gr_lemma[col].apply(', '.join)
        df_gr_lemma = df_gr_lemma.rename(columns={'lemma':'word'})
        df_gr_lemma['meaning'] = ""
        # Normalized key used for joining model output back to the chunk
        df_gr_lemma['word_lower'] = df_gr_lemma['word'].apply(lambda x: x.lower().strip(string.punctuation))

        # Process 10 words at a time to avoid failure
        for start in range(0, len(df_gr_lemma), 10):
            end = start + 10
            chunk = df_gr_lemma.iloc[start:end]

            # Reprocess until the meaning is filled
            total_output_count = chunk.shape[0]
            list_tmp_result = []
            chunk_count = 10
            while (total_output_count > 0) and (chunk_count > 0) :
                print(qa_id, total_output_count, chunk['word'].unique())
                md_metatable = convert_df_to_markdown(chunk.drop(columns='word_lower'))

                messages = []
                prompt = STEP2_PROMPT.format(
                    step1_output=md_metatable,
                    context_input=row_qa[col_evidence],
                    q_input=row_qa[col_question],
                    choices_input='\n'.join(row_qa[col_choices]),
                )
                messages = append_message(USER, prompt, messages)

                # Generate
                count = trial_n
                chunk_count -= 1
                while count > 0:
                    try:
                        response = generate(client, model, messages, stream=False, output_json=True).choices[0].message.content
                        data_s2 = json.loads(response)['data']
                        df_tmp = pd.DataFrame(data_s2)
                        df_tmp[col_id] = qa_id

                        # 1. Convert to lowercase and remove punctuation and join them
                        df_tmp_1 = df_tmp.copy()
                        df_tmp_1['word_lower'] = df_tmp_1['word'].apply(lambda x: x.lower().strip(string.punctuation)) # join key
                        df_tmp_1 = df_tmp_1[df_tmp_1['word_lower'].isin(chunk['word_lower'].unique())].copy()
                        df_tmp_1 = df_tmp_1[(df_tmp_1['meaning'].notna())&(df_tmp_1['meaning']!='')]
                        print(df_tmp_1['word'].unique())
                        list_tmp_result.append(df_tmp_1.drop(columns='word_lower').copy())
                        chunk = chunk[~chunk['word_lower'].isin(df_tmp_1['word_lower'].unique())]

                        # 2. Convert the extracted words to lemmas and join them
                        if chunk.shape[0] > 0:
                            df_tmp_2 = df_tmp.loc[~df_tmp.index.isin(df_tmp_1.index.to_list())].copy()
                            df_tmp_2['lemma_'] = df_tmp_2['word']
                            for idx, row in df_tmp_2.iterrows():
                                doc = nlp(row['word'])
                                # Write into df_tmp_2 (the frame that is joined below);
                                # writing into df_tmp left 'lemma_' equal to the raw word.
                                df_tmp_2.loc[idx, 'lemma_'] = doc[0:].lemma_

                            # Join
                            df_tmp_2 = df_tmp_2[df_tmp_2['lemma_'].isin(chunk['word'])]
                            print(df_tmp_2['lemma_'].unique())
                            list_tmp_result.append(df_tmp_2.drop(columns='lemma_').copy())
                            chunk = chunk[~chunk['word'].isin(df_tmp_2['lemma_'].unique())]

                        total_output_count = chunk.shape[0]
                        break

                    except Exception as e:
                        # Best effort: malformed model output is logged and retried.
                        # `row_label` is the df_qa row label (no longer shadowed by an inner loop).
                        print('row', row_label, ': error', e)
                        count -= 1

            list_result.extend(list_tmp_result)

    if not list_result:
        # pd.concat() raises on an empty list; return an empty frame with the usual columns.
        return pd.DataFrame(columns=['word', 'part_of_speech', 'category', 'meaning', col_id])

    return pd.concat(list_result, ignore_index=True)

In [None]:
# Run step 2 against the local Ollama endpoint to obtain category/meaning per lemma.
df_lemma = step2(ollama_client, ollama_model, df_qa, df_word)

In [31]:
# If the model output produced a nested 'data' column, expand each non-null cell
# into its own rows (tagged with the QA id) and replace the original rows with them.
list_del_index = []
list_concat_df = []
if 'data' in df_lemma.columns:
    rows = df_lemma[df_lemma['data'].notna()]
    for i, row in rows.iterrows():
        tmp_df = pd.DataFrame(row['data'])
        tmp_df[col_id] = row[col_id]
        list_concat_df.append(tmp_df.copy())
        list_del_index.append(i)

    # Remove the expanded rows and the helper column, then append the expansions.
    df_lemma = pd.concat(
        [df_lemma.drop(index=list_del_index).drop(columns=['data'])] + list_concat_df
    ).reset_index(drop=True)

In [32]:
# Persist the step 1 word list and the step 2 lemma metadata (UTF-8 with BOM, no index column).
df_word.to_csv(output_path_step1, encoding='utf-8-sig', index=False)
df_lemma.to_csv(output_path_step2, encoding='utf-8-sig', index=False)

#### Create a word list for code conversion

In [24]:
df_word_use = df_word[df_word['is_content_word']].copy()
df_word_use['code'] = df_word_use.groupby([col_id, 'lemma']).ngroup()
df_word_use['code'] = df_word_use.groupby(col_id)['code'].rank(method='dense').astype(int)
df_word_use['code'] = df_word_use['code'].apply(lambda x: "r"+str(x).zfill(3))

In [25]:
# Merging process
list_merge = []
checked_index = []
usecols = ['word', 'part_of_speech', 'category', 'lemma', 'word_count', 'index',
       'question_id', 'mask_col', 'is_content_word', 'code',
       'part_of_speech_output', 'category_output', 'meaning']
# 1. Merge the lemma and word columns
df_merge_tmp = df_word_use.merge(df_lemma.reset_index(), how='inner', left_on=[col_id, 'lemma'], right_on=[col_id, 'word'], suffixes=['', '_output'])
list_merge.append(df_merge_tmp[usecols].fillna('').copy())
df_lemma_tmp = df_lemma.loc[~df_lemma.index.isin(df_merge_tmp['index_output'])].copy()[df_lemma.columns]
checked_index.extend(df_merge_tmp.index.tolist())

In [26]:
#  2. Convert all to lowercase and merge
df_word_use_tmp = df_word_use[~df_word_use.index.isin(checked_index)].copy()
df_word_use_tmp['word_lower'] = df_word_use_tmp['lemma'].apply(lambda x: x.lower().strip(string.punctuation))
df_lemma_tmp['word_lower'] = df_lemma_tmp['word'].apply(lambda x: x.lower().strip(string.punctuation))
df_merge_tmp = df_word_use_tmp.merge(df_lemma_tmp.drop(columns=['word']).reset_index(), how='inner', on=[col_id, 'word_lower'], suffixes=['', '_output'])
list_merge.append(df_merge_tmp[usecols].fillna('').copy())
df_lemma_tmp = df_lemma_tmp.loc[~df_lemma_tmp.index.isin(df_merge_tmp['index_output'])].copy()[df_lemma.columns]
checked_index.extend(df_merge_tmp.index.tolist())

In [27]:
# 3. Reconvert the word in df_lemma to a lemma and merge
df_word_use_tmp = df_word_use[~df_word_use.index.isin(checked_index)].copy()
df_word_use_tmp['word_lower'] = df_word_use_tmp['lemma'].apply(lambda x: x.lower().strip(string.punctuation))
df_lemma_tmp['lemma'] = None

nlp = spacy.load("en_core_web_sm")
nlp.add_pipe("merge_hyphenated", before='parser')

for i, row in df_lemma_tmp.iterrows():
    doc = nlp(row['word'])
    df_lemma_tmp.loc[i, 'lemma'] = doc[0:].lemma_

df_lemma_tmp['word_lower'] = df_lemma_tmp['lemma'].apply(lambda x: x.lower().strip(string.punctuation))

df_merge_tmp = df_word_use_tmp.merge(df_lemma_tmp.drop(columns=['word']).reset_index(), how='inner', on=[col_id, 'word_lower'], suffixes=['', '_output'])
list_merge.append(df_merge_tmp[usecols].fillna('').copy())
df_lemma_tmp = df_lemma_tmp.loc[~df_lemma_tmp.index.isin(df_merge_tmp['index_output'])].copy()[df_lemma.columns]
checked_index.extend(df_merge_tmp.index.tolist())

In [28]:
df_word_meaning = pd.concat(list_merge + [df_word_use.loc[~df_word_use.index.isin(checked_index)]], ignore_index=True)

for col in ['part_of_speech', 'category']:
    df_word_meaning[f"{col}_new"] = df_word_meaning[f'{col}']
    df_word_meaning.loc[df_word_meaning[f"{col}_new"]=="", f"{col}_new"] = df_word_meaning[df_word_meaning[f"{col}_new"]==""][f'{col}_output']

df_word_new = df_word_meaning[['lemma', 'word', 'part_of_speech_new', 'category_new', 'meaning', col_id, 'code']].rename(columns=
    {
        'part_of_speech_new': 'part_of_speech',    
        'category_new': 'category',    
    }
)

In [30]:
# Persist the word list with codes and merged metadata (UTF-8 with BOM, no index column).
df_word_new.to_csv(output_path_step2_use, encoding='utf-8-sig', index=False)

## step3
- Numeric conversion
- Convert only `question` and `options`; do not convert `rationale`
- Convert using code

In [None]:
import re

def scale_numbers_in_text(text, factor=1.23, ndigits=2):
    """Multiply every number found in `text` by `factor`.

    A number is a run of digits with an optional decimal part (`\\d+(?:\\.\\d+)?`).
    Each match is scaled, rounded to `ndigits` decimals and written back with
    exactly `ndigits` decimals. With the defaults the behavior is unchanged:
    numbers are scaled by 1.23 and shown with two decimals.

    Args:
        text: Input string.
        factor: Multiplier applied to each number (default 1.23).
        ndigits: Number of decimals in the output (default 2).

    Returns:
        `text` with every number replaced by its scaled value.

    NOTE(review): STEP3_PROMPT in this notebook asks for a 32% increase while the
    default here is 1.23 - confirm which is intended. Thousands separators are not
    recognized ("1,000" is treated as the two numbers "1" and "000"), and digits
    inside words or dates are scaled as well.
    """
    def replace(match):
        scaled_number = round(float(match.group()) * factor, ndigits)
        # Fixed number of decimals, e.g. 12.3 -> "12.30".
        return f"{scaled_number:.{ndigits}f}"

    return re.sub(r'\d+(?:\.\d+)?', replace, text)

In [21]:
def step3(client:OpenAI, model:str, df_qa:pd.DataFrame) -> pd.DataFrame:
    """Add the numeric-conversion columns to `df_qa` (in place) and return it.

    `client` and `model` are accepted but not used; numbers are scaled locally
    with scale_numbers_in_text().

    Columns written:
        s3_output_context: copy of the evidence column (left unscaled).
        s3_output_Q: question text with numbers scaled.
        s3_output_choices: list of choices, each with numbers scaled.
        s3_conditions: evidence text with numbers scaled.
    """
    def _scale_each(options):
        return [scale_numbers_in_text(i) for i in options]

    df_qa['s3_output_context'] = df_qa[col_evidence]

    # (target column, source column, per-cell transform), written in this order.
    plan = (
        ('s3_output_Q', col_question, scale_numbers_in_text),
        ('s3_output_choices', col_choices, _scale_each),
        ('s3_conditions', col_evidence, scale_numbers_in_text),
    )
    for target, source, transform in plan:
        df_qa[target] = df_qa[source].apply(transform)

    return df_qa

In [33]:
# Run step 3; it adds the s3_* columns to df_qa.
df_qa = step3(ollama_client, ollama_model, df_qa)

In [34]:
# Checkpoint the QA frame after step 3 (UTF-8 with BOM, no index column).
df_qa.to_csv(output_path_qa_tmp, encoding='utf-8-sig', index=False)

# 1. regular masking

## step4
- Select words to be coded based on the masking rate

In [22]:
def select_mask_row(df_word:pd.DataFrame, mask_rate:int):
    """Select words to be coded based on the masking rate.

    For each QA id, `round(n_lemmas * mask_rate / 100)` distinct lemmas are drawn
    with `random.sample`; every row carrying one of those lemmas is flagged.

    Args:
        df_word: Word list with at least the `col_id` and 'lemma' columns.
        mask_rate: Percentage between 0 and 100.

    Returns:
        A copy of `df_word` with a boolean column `p_{mask_rate}_masked`.
    """
    assert 0 <= mask_rate <= 100, f"`mask_rate` must be set between 0 and 100. mask_rate: {mask_rate}"

    flag_col = f'p_{mask_rate}_masked'
    flagged = df_word.copy()
    flagged[flag_col] = False

    ratio = mask_rate/100
    for qa_id, group in flagged.groupby(col_id):
        lemmas = group['lemma'].unique().tolist()
        n_to_mask = round(len(lemmas) * ratio)
        print(f'ID_{qa_id} {mask_rate}% number of words to be masked:', n_to_mask)

        chosen = random.sample(lemmas, k=n_to_mask)
        hit_labels = group[group['lemma'].isin(chosen)].index.values

        flagged.loc[hit_labels, flag_col] = True

    return flagged

In [None]:
# Add one `p_{rate}_masked` column per masking rate; each rate is sampled independently.
for mask_rate in list_mask_rate:
    df_word_new = select_mask_row(df_word_new, mask_rate)

In [23]:
def _replace_whole_word(text:str, word:str, replacement:str) -> str:
    """Replace `word` in `text` only where it is not glued to other word characters."""
    pattern = r'(?<!\w)' + re.escape(word) + r'(?!\w)'
    # A callable replacement is inserted literally (no backslash/group expansion).
    return re.sub(pattern, lambda _match: replacement, text)


def step4(df_qa:pd.DataFrame, list_mask_rate:list[int], df_word_new:pd.DataFrame) -> pd.DataFrame:
    """Add masked text based on the masking rate to df_qa.

    For every masking rate three columns are added:
        s4_prg_encode_context_{rate}: 's3_output_context' with masked words replaced by `<code>`.
        s4_prg_encode_Q_{rate}: 's3_output_Q' with masked words replaced by `<code>`.
        s4_prg_encode_choices_{rate}: copy of 's3_output_choices' (choices are not masked).

    Words are replaced as whole words only, so masking "men" no longer rewrites
    "women". Rows of `df_word_new` whose 'word' is missing or empty are skipped.

    Args:
        df_qa: QA frame containing `col_id` and the s3_output_* columns.
        list_mask_rate: Masking rates; `df_word_new` must hold `p_{rate}_masked` for each.
        df_word_new: Word list with `col_id`, 'word', 'code' and the mask flags.

    Returns:
        A copy of `df_qa` with the new columns.
    """
    df_copy = df_qa.copy()
    for mask_rate in list_mask_rate:
        df_copy[f's4_prg_encode_context_{mask_rate}'] = ""
        df_copy[f's4_prg_encode_Q_{mask_rate}'] = ""
        # Use `s3_output_choices` for `Choices` only this time
        df_copy[f's4_prg_encode_choices_{mask_rate}'] = df_copy['s3_output_choices']

        # Hoisted out of the row loop: all words flagged at this rate.
        masked_words = df_word_new[df_word_new[f'p_{mask_rate}_masked']]

        for i, row in df_copy.iterrows():
            rows_word = masked_words[masked_words[col_id] == row[col_id]]
            pairs = [
                (code, word)
                for code, word in zip(rows_word['code'], rows_word['word'])
                if isinstance(word, str) and word  # skip NaN / empty words
            ]
            # Multi-word expressions first so their parts are not masked separately.
            pairs.sort(key=lambda pair: pair[1].count(' '), reverse=True)

            # for col in ['s3_output_context', 's3_output_Q', 's3_output_choices']:
            for col in ['s3_output_context', 's3_output_Q']:
                sub_text = row[col]
                for code, word in pairs:
                    sub_text = _replace_whole_word(sub_text, word, f"<{code}>")
                df_copy.loc[i, col.replace('s3_output_', 's4_prg_encode_')+f"_{mask_rate}"] = sub_text

    return df_copy

In [38]:
# Run step 4 for every masking rate; adds the s4_prg_encode_* columns.
df_qa = step4(df_qa, list_mask_rate, df_word_new)

In [39]:
# Overwrite the checkpoint so it now also contains the step 4 columns.
df_qa.to_csv(output_path_qa_tmp, encoding='utf-8-sig', index=False)

In [40]:
# Persist the word list including the p_*_masked flags; later sections reload this file.
df_word_new.to_csv(output_path_step4, encoding='utf-8-sig', index=False)

In [34]:
# Reload the word list. Only empty cells are treated as missing: with pandas' default
# NA markers, words such as "None", "NA" or "null" would be read back as NaN.
df_word_new = pd.read_csv(output_path_step4, keep_default_na=False, na_values=[''])

In [None]:
# Calculate the missing rate of meanings for each masking rate
list_tmp = []
for mask_rate in list_mask_rate:
    if mask_rate == 0:
        continue
    
    tmp_total_lemma = df_word_new[df_word_new[f'p_{mask_rate}_masked']]['lemma'].nunique()
    tmp_empty_meaning_lemma = tmp_total_lemma - df_word_new[df_word_new[f'p_{mask_rate}_masked'] & df_word_new['meaning'].notna()]['lemma'].nunique()
    list_tmp.append({'MR': mask_rate, 'number of words with missing meanings': tmp_empty_meaning_lemma, 'word count': tmp_total_lemma, 'missing rate of meanings': tmp_empty_meaning_lemma / tmp_total_lemma})
    
pd.DataFrame(list_tmp)

## step5
- generate answers

In [24]:
import asyncio

In [40]:
async def exec_step5(client:AsyncOpenAI, model:str, i:int, output_col:str, query_col:str, context:str, question:str, choices: list, metadata: str, count:int):
    """Ask the model one multiple-choice question and return its answer.

    The options are rendered one per line as "1. <choice>", "2. <choice>", ...
    (previously the Python list repr was inserted into the prompt), matching how
    step 2 renders the choices.

    Args:
        client: Async chat client.
        model: Model name.
        i: Row label of the QA row; passed through to the result.
        output_col: Name of the answer column; passed through.
        query_col: Name of the query column; passed through.
        context: Masked context text.
        question: Masked question text.
        choices: List of answer options.
        metadata: Markdown table describing the masked codes.
        count: Maximum number of attempts.

    Returns:
        `(i, answer, prompt, output_col, query_col)` on success, where `answer` is
        the 'answer' field of the model's JSON reply. After `count` failed attempts
        `(i, '-1', '', output_col, query_col)`.
    """
    messages = []
    prompt = STEP5_PROMPT.format(
        context=context,
        metadata=metadata,
        question=question,
        option_list='\n'.join(f"{no}. {val}" for no, val in enumerate(choices, 1))
    )
    messages = append_message(USER, prompt, messages)

    while count > 0:
        try:
            res = await aget_response(client, model, messages, output_json=True)
            ret = (i, json.loads(res)['answer'], prompt, output_col, query_col)
            return ret

        except Exception as e:
            # Best effort: API errors and malformed JSON are logged and retried.
            print('error', e)
            count -= 1

    return (i, '-1', '', output_col, query_col)

In [26]:
async def step5(aclient:AsyncOpenAI, model:str, df_qa:pd.DataFrame, df_word_new:pd.DataFrame) -> pd.DataFrame:
    """Generate answers for every QA row at every masking rate.

    For each rate in `list_mask_rate`, one exec_step5() task is scheduled per QA
    row and all tasks of that rate are awaited together. The answer is stored in
    `answer_{model}_{rate}` and the prompt in `query_{rate}`.

    Returns:
        A copy of `df_qa` with the answer and query columns added.
    """
    df_copy = df_qa.copy()
    print('model:', model)

    meta_cols = ['part_of_speech', 'category', 'meaning', 'code']
    results = []

    for mask_rate in list_mask_rate:
        output_col = f'answer_{model}_{mask_rate}'
        query_col = f'query_{mask_rate}'
        df_copy[output_col] = ""
        df_copy[query_col] = ""

        col_context = f's4_prg_encode_context_{mask_rate}'
        col_Q = f's4_prg_encode_Q_{mask_rate}'
        col_choices_tmp = f's4_prg_encode_choices_{mask_rate}'

        tasks = []
        for i, row in df_copy.iterrows():
            qa_id = row[col_id]

            # Metadata of the words masked for this question at this rate, one row per lemma.
            # The lemma becomes the index, so it is not part of the rendered table.
            selected = df_word_new[(df_word_new[col_id]==qa_id) & (df_word_new[f'p_{mask_rate}_masked'])]
            meta_table = selected.groupby('lemma').agg({col: 'unique' for col in meta_cols}).copy()
            if not meta_table.empty:
                for col in meta_cols:
                    # Join the distinct string values; non-string entries (e.g. NaN) are left out.
                    meta_table[col] = meta_table[col].apply(lambda x: ', '.join(t for t in x if type(t)==str))
            md_metatable = convert_df_to_markdown(meta_table)

            count = trial_n
            request = exec_step5(aclient, model, i, output_col, query_col, row[col_context], row[col_Q], row[col_choices_tmp], md_metatable, count)
            tasks.append(asyncio.ensure_future(request))

        results.extend(await asyncio.gather(*tasks))

    # Write everything back once all rates have finished.
    for i, ans, query, output_col, query_col in results:
        df_copy.loc[i, output_col] = ans
        df_copy.loc[i, query_col] = query

    return df_copy.copy()

In [None]:
# Repeat the whole step 5 run `output_step` times; each trial is written to its own
# CSV by filling the `{trial_no}` placeholder of output_path_qa.
for output_num in range(1, output_step+1):
    print('trial:', output_num) 
    df_result = await step5(openai_async_client, openai_model, df_qa, df_word_new)
    save_path_tmp = str(output_path_qa).format(trial_no=output_num)
    print(save_path_tmp)
    df_result.to_csv(save_path_tmp, encoding='utf-8-sig', index=False)

# 2. partial lifting

## step4

In [72]:
# Reload the QA checkpoint; list-valued columns are parsed back with literal_eval and
# missing cells become empty strings.
param = {   
    'converters': {col:literal_eval for col in list_convert_cols + ['s3_output_choices']}
}
df_qa = pd.read_csv(output_path_qa_tmp, **param).fillna('')

In [73]:
# Keep only the words that have a meaning.
df_word_new_all_meaning = df_word_new[df_word_new['meaning'].notna()]

In [74]:
# Rebuild the masked texts using only words that have a meaning.
df_qa_all_meaning = step4(df_qa, list_mask_rate, df_word_new_all_meaning)

In [None]:
# Missing rate of meanings per masking rate for the meaning-only word list.
list_tmp = []
for mask_rate in list_mask_rate:
    if mask_rate == 0:
        continue

    is_masked = df_word_new_all_meaning[f'p_{mask_rate}_masked']
    tmp_total_lemma = df_word_new_all_meaning[is_masked]['lemma'].nunique()
    tmp_empty_meaning_lemma = tmp_total_lemma - df_word_new_all_meaning[is_masked & df_word_new_all_meaning['meaning'].notna()]['lemma'].nunique()
    # Guard against rates at which no lemma is masked (would divide by zero).
    tmp_missing_rate = tmp_empty_meaning_lemma / tmp_total_lemma if tmp_total_lemma else float('nan')
    list_tmp.append({'MR': mask_rate, 'number of words with missing meanings': tmp_empty_meaning_lemma, 'word count': tmp_total_lemma, 'missing rate of meanings': tmp_missing_rate})

pd.DataFrame(list_tmp)

## step5

In [None]:
# Repeat step 5 `output_step` times on the meaning-only variant; files get the
# "_filtered" suffix in place of `{trial_no}`.
for output_num in range(1, output_step+1):
    print('trial:', output_num) 
    df_result = await step5(openai_async_client, openai_model, df_qa_all_meaning, df_word_new_all_meaning)
    save_path_tmp = str(output_path_qa).format(trial_no=f"{output_num}_filtered")
    print(save_path_tmp)
    df_result.to_csv(save_path_tmp, encoding='utf-8-sig', index=False)

# 3. strict masking

### step4

In [115]:
# Reload the QA checkpoint for this variant; list-valued columns are parsed back with
# literal_eval and missing cells become empty strings.
param = {   
    'converters': {col:literal_eval for col in list_convert_cols + ['s3_output_choices']}
}
df_qa_no_meaning = pd.read_csv(output_path_qa_tmp, **param).fillna('')

In [116]:
# Reload the word list. Only empty cells are treated as missing: with pandas' default
# NA markers, words such as "None", "NA" or "null" would be read back as NaN.
df_word_new_no_meaning = pd.read_csv(output_path_step4, keep_default_na=False, na_values=[''])
# Blank out every meaning for this variant.
df_word_new_no_meaning['meaning'] = None

In [117]:
# Rebuild the masked texts from the reloaded checkpoint and the meaning-less word list.
df_qa_no_meaning = step4(df_qa_no_meaning, list_mask_rate, df_word_new_no_meaning)

In [None]:
# Missing rate of meanings per masking rate for the variant without meanings.
list_tmp = []
for mask_rate in list_mask_rate:
    if mask_rate == 0:
        continue

    is_masked = df_word_new_no_meaning[f'p_{mask_rate}_masked']
    tmp_total_lemma = df_word_new_no_meaning[is_masked]['lemma'].nunique()
    tmp_empty_meaning_lemma = tmp_total_lemma - df_word_new_no_meaning[is_masked & df_word_new_no_meaning['meaning'].notna()]['lemma'].nunique()
    # Guard against rates at which no lemma is masked (would divide by zero).
    tmp_missing_rate = tmp_empty_meaning_lemma / tmp_total_lemma if tmp_total_lemma else float('nan')
    list_tmp.append({'MR': mask_rate, 'number of words with missing meanings': tmp_empty_meaning_lemma, 'word count': tmp_total_lemma, 'missing rate of meanings': tmp_missing_rate})

pd.DataFrame(list_tmp)

### step5

In [None]:
# Repeat step 5 `output_step` times on the variant without meanings; files get the
# "_no_meaning" suffix in place of `{trial_no}`.
for output_num in range(1, output_step+1):
    print('trial:', output_num) 
    df_result = await step5(openai_async_client, openai_model, df_qa_no_meaning, df_word_new_no_meaning)
    save_path_tmp = str(output_path_qa).format(trial_no=f"{output_num}_no_meaning")
    print(save_path_tmp)
    df_result.to_csv(save_path_tmp, encoding='utf-8-sig', index=False)

# 4. lenient masking

### step4

In [125]:
# Reload the QA checkpoint; list-valued columns are parsed back with literal_eval and
# missing cells become empty strings.
param = {
    'converters': {col:literal_eval for col in list_convert_cols + ['s3_output_choices']}
}
df_qa_no_verb = pd.read_csv(output_path_qa_tmp, **param).fillna("")
# Reload the word list. Only empty cells are treated as missing: with pandas' default
# NA markers, words such as "None", "NA" or "null" would be read back as NaN.
df_word_new_no_verb = pd.read_csv(output_path_step4, keep_default_na=False, na_values=[''])

In [127]:
# Collect the distinct part-of-speech values per (QA id, lemma) and flag the lemmas
# for which one of those values equals "verb" (case-insensitive).
unique_lemma = df_word_new_no_verb.groupby([col_id, 'lemma'])['part_of_speech'].apply('unique').reset_index()
unique_lemma['contains_verb'] = unique_lemma['part_of_speech'].apply(lambda x: 'verb' in [pos.lower() for pos in x if pos])

In [None]:
# Attach the per-lemma verb flag to every word row, then clear the mask flags of all
# verb lemmas at every masking rate.
# NOTE(review): re-running this cell merges the flag columns a second time - run once.
df_word_new_no_verb = df_word_new_no_verb.merge(unique_lemma.rename(columns={'part_of_speech': 'POS_unique'}), how='left', on=[col_id, 'lemma'])
df_word_new_no_verb.loc[df_word_new_no_verb['contains_verb'], [f'p_{mr}_masked' for mr in list_mask_rate]] = False

In [130]:
# Rebuild the masked texts with verb lemmas left unmasked.
df_qa_no_verb = step4(df_qa_no_verb, list_mask_rate, df_word_new_no_verb)

In [None]:
# Missing rate of meanings per masking rate for the variant that leaves verbs unmasked.
list_tmp = []
for mask_rate in list_mask_rate:
    if mask_rate == 0:
        continue

    is_masked = df_word_new_no_verb[f'p_{mask_rate}_masked']
    tmp_total_lemma = df_word_new_no_verb[is_masked]['lemma'].nunique()
    tmp_empty_meaning_lemma = tmp_total_lemma - df_word_new_no_verb[is_masked & df_word_new_no_verb['meaning'].notna()]['lemma'].nunique()
    # Guard against rates at which no lemma is masked (would divide by zero).
    tmp_missing_rate = tmp_empty_meaning_lemma / tmp_total_lemma if tmp_total_lemma else float('nan')
    list_tmp.append({'MR': mask_rate, 'number of words with missing meanings': tmp_empty_meaning_lemma, 'word count': tmp_total_lemma, 'missing rate of meanings': tmp_missing_rate})

pd.DataFrame(list_tmp)

### step5

In [None]:
# Repeat step 5 `output_step` times on the variant that leaves verbs unmasked; files
# get the "_no_verb" suffix in place of `{trial_no}`.
for output_num in range(1, output_step+1):
    print('trial:', output_num) 
    df_result = await step5(openai_async_client, openai_model, df_qa_no_verb, df_word_new_no_verb)
    save_path_tmp = str(output_path_qa).format(trial_no=f"{output_num}_no_verb")
    print(save_path_tmp)
    df_result.to_csv(save_path_tmp, encoding='utf-8-sig', index=False)