In [None]:
import os
import json
import argparse
from tqdm import tqdm
from datetime import datetime
import pandas as pd
import requests

In [None]:
# Run configuration. The options are declared as a table and registered in a
# loop; an empty argument list is parsed, so every option takes its default.
_ARGUMENT_SPECS = (
    ("--option", dict(default="cot", type=str)),
    ("--model", dict(default="llama2-70b", type=str, help=" ")),
    ("--start", dict(default=0, type=int)),
    ("--end", dict(default=None, type=int)),
    (
        "--temperature",
        dict(type=float, default=0.5, help="temperature of 0 implies greedy sampling."),
    ),
    ("--traced_json_file", dict(default=r"traced.json", type=str)),  # traced file
    ("--tables_json_file", dict(default=r"tables.json", type=str)),  # table files
    ("--topk_path", dict(default=r"request_tok", type=str)),  # text files
)

parser = argparse.ArgumentParser()
for _flag, _options in _ARGUMENT_SPECS:
    parser.add_argument(_flag, **_options)

args = parser.parse_args([])

In [None]:
# Few-shot demonstrations keyed by prompting option; "none" means no examples.
demonstration = {"none": ""}
_DEMONSTRATION_FILES = (
    ("direct", "examples/fullmodel_direct_3shot.json"),
    ("cot", "examples/fullmodel_cot_3shot.json"),
)
for _option_name, _example_path in _DEMONSTRATION_FILES:
    with open(_example_path, "r") as _example_file:
        demonstration[_option_name] = json.load(_example_file)

In [None]:
def read_data(args):
    """Build the per-question records used to assemble prompts.

    For each sample in the traced JSON file (sliced by ``args.start`` /
    ``args.end``) this looks up the entry of the module-level
    ``questions_data`` with the same ``table_id``, loads ``<table_id>.json``
    for the table and ``<text dir>/<table_id>.json`` for the linked passages,
    and collects question, predicted answer, table and passage text.

    A sample is skipped (with a printed message) when no question entry
    matches, when its question type is neither ``bridge`` nor ``comparison``,
    or when either JSON file cannot be opened or parsed.

    Args:
        args: Namespace providing ``traced_json_file``, ``start``, ``end`` and
            the passage directory. The directory is read from ``text_path``
            when that attribute exists and is non-empty, otherwise from
            ``topk_path`` (the only one the notebook's parser defines).

    Returns:
        list[dict]: one dict per kept sample with keys ``table_id``,
        ``question``, ``answer``, ``table`` (a DataFrame), ``wiki``,
        ``title`` and ``intro``.
    """
    with open(args.traced_json_file, "r") as traced_file:
        data_test_traced = json.load(traced_file)

    # The original read args.text_path, which the parser never defines; the
    # resulting AttributeError was swallowed and the handler then failed on an
    # unbound variable. Fall back to --topk_path (documented as the text files).
    text_dir = getattr(args, "text_path", None) or args.topk_path

    # Index the questions once instead of scanning the list for every sample.
    # setdefault keeps the first entry per table_id, matching the former
    # "first match wins" linear search.
    # NOTE(review): if several questions share a table_id, every sample of that
    # table resolves to the same entry - confirm this is intended.
    questions_by_table = {}
    for q_data in questions_data:
        questions_by_table.setdefault(q_data["table_id"], q_data)

    data_list = []
    for sample in tqdm(data_test_traced[args.start:args.end]):
        table_id = sample["table_id"]
        question_data = questions_by_table.get(table_id)
        if question_data is None:
            print(f"No question data found for {table_id}")
            continue

        # Read JSON file from tables_tok
        tables_tok_path = f"{table_id}.json"  # put your traced table link
        try:
            with open(tables_tok_path, "r") as f:
                table_data = json.load(f)
        except (OSError, ValueError):  # missing/unreadable file or invalid JSON
            print(f"The file {table_id} does not exist.")
            continue

        # Only the two known question types are processed.
        # NOTE(review): the original also derived "relevant rows" from
        # row_pre / row_pre_logit here but never used them; the full table is
        # what ends up in the record, so that dead computation was dropped.
        question_type = question_data["type"]
        if question_type not in ("bridge", "comparison"):
            print(f"Unknown question type: {question_type}")
            continue

        # Read text data. The path is built outside the try block so the error
        # message below can always reference it.
        text_file = os.path.join(text_dir, f"{table_id}.json")
        try:
            with open(text_file, "r") as f:
                text_data = json.load(f)
        except (OSError, ValueError):
            print(f"The file {text_file} does not exist.")
            continue

        question_text = sample["question"]
        answer_text = sample["pred"]

        # Collect the /wiki links carried by the traced nodes and the target.
        wikis = [
            node[2]
            for node in sample["nodes"]
            if node[2] is not None and node[2].startswith("/wiki")
        ]
        target_wiki = sample["target"][2]
        if target_wiki and target_wiki.startswith("/wiki"):
            wikis.append(target_wiki)

        # One line of passage text per link; links without text contribute "".
        wiki_text = "\n".join(text_data.get(wiki, "") for wiki in wikis)

        # Each row and the header are sequences of per-cell sequences; zip(*...)
        # transposes them and [0] keeps the first component of every cell.
        df = pd.DataFrame(
            [tuple(zip(*row))[0] for row in table_data["data"]],
            columns=list(zip(*table_data["header"]))[0],
        )

        data_list.append({
            "table_id": table_id,
            "question": question_text,
            "answer": answer_text,
            "table": df,
            "wiki": wiki_text,
            "title": table_data["title"],
            "intro": table_data["intro"]
        })

    return data_list

# Question entries that read_data() matches against samples by 'table_id'.
questions_path = "test.json"  # put text answer here
with open(questions_path, mode='r') as questions_file:
    questions_data = json.loads(questions_file.read())

def df_format(data):
    """Render a DataFrame as pipe-separated text for inclusion in a prompt.

    The first line holds the column labels and each following line one row,
    with cells joined by " | "; every line ends with a newline.

    Column labels and cell values are converted with ``str`` so non-string
    labels no longer produce an empty result, and rows are read positionally
    so duplicate column names and mixed dtypes are rendered cell by cell.

    Args:
        data: The table to render.

    Returns:
        str: The formatted table, or "" when ``data`` is not a DataFrame.
    """
    if not isinstance(data, pd.DataFrame):
        return ""

    lines = [" | ".join(map(str, data.columns))]
    # itertuples(name=None) yields plain tuples in column order and keeps each
    # column's own dtype, unlike iterrows which upcasts values across a row.
    lines.extend(
        " | ".join(map(str, row_values))
        for row_values in data.itertuples(index=False, name=None)
    )
    return "\n".join(lines) + "\n"


In [None]:
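# Load the model or API client here. The cells below call query({'inputs': prompt})
# and read response_raw[0]['generated_text'], so this cell must define query().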

In [None]:
run_count = 0

# Create the output directory first so open() does not fail when it is missing.
os.makedirs("outputs", exist_ok=True)

subquestion_file = f"outputs/subquestion_s{args.start}_e{args.end}_{args.option}_{args.model}_{run_count}.json"
# The handle stays open across cells: the sub-question loop writes one JSON
# line per example to it and closes it when done.
subquestion_fw = open(subquestion_file, "w")

# The first line of the output records the demonstration used for this run.
tmp = {"demonstration": demonstration[args.option]}
subquestion_fw.write(json.dumps(tmp) + "\n")

In [None]:
# Build the per-question records (table, linked text, question, predicted
# answer) for the slice selected by args.start / args.end.
data_list = read_data(args)

In [None]:
# Sub-questions to ask: one JSON object per line, text in its 'response' field.
with open('question_test.json', 'r') as f:
    all_questions = [json.loads(line)['response'] for line in f]

# Expected answer type per example, one per line.
with open('subquestion_spacy.txt', 'r', encoding='utf-8') as f:
    entity_data = [line.strip() for line in f]

# NOTE(review): records, entity types and questions are paired purely by
# position, and read_data() may skip samples - confirm the three inputs stay
# aligned when that happens.
try:
    # zip stops at the shortest input, which also covers running out of questions.
    for entry, entity_entry, question in zip(tqdm(data_list), entity_data, all_questions):
        prompt = demonstration[args.option] + '\n\n'
        # Formalizing the k-shot demonstration.
        prompt += f'Read the table and text regarding "{entry["title"]}" to answer the question.\n\n'
        prompt += df_format(entry['table']) + '\n'

        if entry['wiki']:
            prompt += "Text:" + '\n' + entry['wiki'] + '\n\n'
        prompt += 'The answer should be a/an ' + entity_entry + '\n\n'
        prompt += 'Let\'s think step by step, to answer the question: ' + question + '\nAnswer:'

        response_raw = query({'inputs': prompt})
        try:
            # NOTE(review): index 3 selects the text after the 4th "\nAnswer:"
            # marker, while the final-answer cell uses index 4 - confirm which
            # one matches the number of markers in the demonstration.
            response = response_raw[0].get('generated_text', '').split('\nAnswer:')[3].split('Reasoning process')[0].strip()
        except (KeyError, IndexError, TypeError, AttributeError):
            # Covers an error payload instead of a list (KeyError/TypeError),
            # an unexpected item type (AttributeError) and output with too few
            # "Answer:" markers (IndexError), which previously aborted the run.
            print(f"Could not parse model output for {entry['table_id']}; recording an empty answer.")
            response = ''

        # Keep only the first line of the extracted answer.
        response = response.split('\n')[0].strip()

        tmp = {
            "sub_question": question,
            "sub_answer": response,
            "table_id": entry["table_id"],
        }

        subquestion_fw.write(json.dumps(tmp) + "\n")
finally:
    # Close (and flush) the output even when a query or write fails midway.
    subquestion_fw.close()


In [None]:
# Debugging aid: display the raw response of the most recent query() call.
response_raw

In [None]:
# Debugging aid: show the last prompt built by the sub-question loop.
print(prompt)

In [None]:
now = datetime.now()
# Day_hour_minute stamp used to distinguish output files of different runs.
dt_string = now.strftime("%d_%H_%M")

# Create the output directory first so open() does not fail when it is missing.
os.makedirs("outputs", exist_ok=True)

# The handle stays open across cells: the final-answer loop writes one JSON
# line per example to it and closes it when done.
answer_fw = open(f"outputs/answer_s{args.start}_e{args.end}_{args.option}_{args.model}_{dt_string}.json", "w")

# The first line of the output records the demonstration used for this run.
tmp = {"demonstration": demonstration[args.option]}
answer_fw.write(json.dumps(tmp) + "\n")

In [None]:
# Re-read the same inputs and rebind data_list, so this stage does not depend
# on the value left behind by the earlier read_data(args) cell.
data_list = read_data(args)

In [None]:
with open('outputs/subanswer.json', 'r') as f:  #subquestion answer here
    # Skip the first line (presumably the demonstration header written by the
    # sub-question stage - confirm); the default avoids StopIteration on an
    # empty file.
    next(f, None)
    subquestion_data = [json.loads(line) for line in f]

# Expected answer type per example, one per line.
with open('entity.txt', 'r', encoding='utf-8') as f:
    entity_data = [line.strip() for line in f]

try:
    # Records, sub-question results and entity types are paired by position;
    # zip stops at the shortest of the three.
    for entry, subquestion_entry, entity_entry in zip(tqdm(data_list), subquestion_data, entity_data):
        question = entry['question']
        answer = entry['answer']
        table_id = entry['table_id']
        subanswer = subquestion_entry.get('sub_answer', '')  # Use .get() to handle KeyError
        subquestion = subquestion_entry.get('sub_question', '')
        subquestion_table_id = subquestion_entry.get('table_id', '')

        # A mismatch means the positional pairing has drifted; it is only
        # reported and the entry is still processed.
        if subquestion_table_id != table_id:
            print(f"Warning: Table ID mismatch for question '{question}'.")

        #### Formalizing the k-shot demonstration. #####
        prompt = demonstration[args.option] + '\n\n'
        # Fixed: the title used to run straight into "and" ('"title":and answer').
        prompt += f'Read the following table and text regarding "{entry["title"]}" and answer the question.\n\n'
        prompt += df_format(entry['table']) + '\n'

        if entry['wiki']:
            prompt += "Text: " + '\n' + entry['wiki'] + '\n\n'

        # Add the sub-question and its answer as evidence.
        prompt += "Subquestion: " + subquestion + "\nThe answer of subquestion: " + subanswer + '\n\n'
        prompt += "Using exactly the same word from the text and table as answer can achieve better correct rate.\n"
        # Fixed: a separator is needed here, otherwise the entity type is fused
        # with the next sentence (e.g. "...a/an :PERSONLets think...").
        prompt += "Simplify your answer to a/an :" + entity_entry + '\n\n'
        prompt += 'Lets think step by step and answer question: ' + question
        prompt += '\nAnswer:'
        response_raw = query({'inputs': prompt})

        try:
            # Text after the 5th "\nAnswer:" marker of the returned text, cut
            # at an optional "Reasoning process" section.
            response = response_raw[0].get('generated_text', '').split('\nAnswer:')[4].split('Reasoning process')[0].strip()
        except (KeyError, IndexError, TypeError, AttributeError):
            # Covers an error payload instead of a list (KeyError/TypeError),
            # an unexpected item type (AttributeError) and output with too few
            # "Answer:" markers (IndexError), which previously aborted the run.
            print(f"Could not parse model output for {table_id}; recording an empty response.")
            response = ''

        # Keep only the first line of the extracted answer.
        response = response.split('\n')[0].strip()

        tmp = {
            "question": question,
            "response": response,
            "answer": answer,
            "entity":entity_entry,
            "table_id": entry["table_id"],
            "sub_answer": subanswer
        }

        answer_fw.write(json.dumps(tmp) + "\n")
finally:
    # Close (and flush) the output even when a query or write fails midway.
    answer_fw.close()


In [None]:
# Debugging aid: show the last prompt built by the final-answer loop.
print(prompt)

In [None]:
# Debugging aid: first line of the text following the first "\nAnswer:" marker
# in the most recent response.
response_raw[0]['generated_text'].split('\nAnswer:')[1].split('\n')[0].strip()

In [None]:
# Debugging aid: print the full generated text of the most recent response.
print(response_raw[0]['generated_text'])