In [1]:
import os
import json
import argparse
from tqdm import tqdm
from datetime import datetime
import pandas as pd
import requests

In [2]:
# Run configuration, declared as CLI-style options but parsed from an empty
# argument list so that every value is its default when run in a notebook.
parser = argparse.ArgumentParser()

# Prompting mode and model label.
parser.add_argument("--option", type=str, default="cot")
parser.add_argument("--model", type=str, default="llama2-70b", help=" ")

# Bounds of the slice of traced samples to process.
parser.add_argument("--start", type=int, default=0)
parser.add_argument("--end", type=int, default=None)

parser.add_argument("--temperature", type=float, default=0.5, help="temperature of 0 implies greedy sampling.")

# Input locations.
parser.add_argument("--traced_json_file", type=str, default=r"traced.json")  # traced file
parser.add_argument("--tables_json_file", type=str, default=r"tables.json")  # table files
parser.add_argument("--topk_path", type=str, default=r"request_tok")  # text files

args = parser.parse_args([])

In [3]:
# Few-shot demonstrations keyed by prompting option; "none" means no demonstration.
demonstration = {}
demonstration["none"] = ""
# JSON is read as UTF-8 explicitly so decoding does not depend on the platform's default encoding.
with open("examples/fullmodel_direct_2shot.json", "r", encoding="utf-8") as f:
    demonstration["direct"] = json.load(f)
with open("examples/fullmodel_cot_2shot.json", "r", encoding="utf-8") as f:
    demonstration["cot"] = json.load(f)

In [4]:
def read_data(args):
    """Build one prompt-ready record per traced sample.

    For every sample in the ``[args.start:args.end]`` slice of the traced JSON
    file, this looks up the first entry of the module-level ``questions_data``
    with the same ``table_id``, loads ``{table_id}.json`` (the table) and
    ``<text dir>/{table_id}.json`` (the linked passages), keeps only the table
    rows selected by the question entry, and collects the result.

    A sample is skipped, with a printed message, when no question entry
    matches, when either JSON file cannot be loaded, when the question type is
    neither ``'bridge'`` nor ``'comparison'``, or when the row table cannot be
    built. The returned list can therefore be shorter than the input slice.

    Args:
        args: Namespace providing ``traced_json_file``, ``start``, ``end`` and
            the passage directory as ``text_path`` or, when that attribute is
            absent, ``topk_path``.

    Returns:
        list[dict]: Records with keys ``table_id``, ``question``, ``answer``,
        ``table`` (a ``pandas.DataFrame`` of the selected rows), ``wiki``,
        ``title`` and ``intro``.
    """
    # Read traced JSON file
    with open(args.traced_json_file, "r", encoding="utf-8") as f:
        data_test_traced = json.load(f)

    # The parser in this notebook exposes the passage directory as --topk_path;
    # a `text_path` attribute is honoured first for namespaces that provide one.
    text_dir = getattr(args, "text_path", None)
    if text_dir is None:
        text_dir = args.topk_path

    data_list = []
    for sample in tqdm(data_test_traced[args.start:args.end]):
        table_id = sample["table_id"]

        # Entries are matched on table_id only; the first match wins.
        question_data = next(
            (q_data for q_data in questions_data if q_data['table_id'] == table_id),
            None,
        )
        if question_data is None:
            print(f"No question data found for {table_id}")
            continue

        # Read JSON file from tables_tok
        tables_tok_path = f"{table_id}.json"  # traced table link
        try:
            with open(tables_tok_path, 'r', encoding="utf-8") as f:
                table_data = json.load(f)
        except Exception:
            print(f"The file {table_id} does not exist.")
            continue

        question_type = question_data['type']
        if question_type == 'bridge':
            # A single row, addressed by the index stored in row_pre.
            row_index = question_data['row_pre']
            relevant_rows = [table_data['data'][row_index]]
        elif question_type == 'comparison':
            # Every row whose entry in row_pre_logit is below 1.0.
            row_pre_logits = question_data['row_pre_logit']
            relevant_rows = [
                table_data['data'][i]
                for i, logit in enumerate(row_pre_logits)
                if logit < 1.0
            ]
        else:
            print(f"Unknown question type: {question_type}")
            continue

        # Read text data. The path is built outside the try block so the
        # failure message below can always name it.
        text_file = os.path.join(text_dir, f"{table_id}.json")
        try:
            with open(text_file, "r", encoding="utf-8") as f:
                text_data = json.load(f)
        except Exception:
            print(f"The file {text_file} does not exist.")
            continue

        question_text = sample["question"]
        answer_text = sample["answer-text"]

        # Extract wiki links from the answer nodes (third field of each node).
        wikis = [
            node[2]
            for node in sample.get("answer-node", [])
            if node[2] is not None and node[2].startswith("/wiki")
        ]

        # One passage per link (empty string when the link has no text),
        # joined with newlines.
        wiki_text = "\n".join(text_data.get(wiki_link, "") for wiki_link in wikis)

        # Create a DataFrame from the relevant rows
        try:
            # Each header entry is a sequence whose first item is the column name.
            flattened_header = [col[0] for col in table_data["header"]]
            # Cells stored as lists contribute their first item; other cells are used as-is.
            flattened_data = [
                [cell[0] if isinstance(cell, list) else cell for cell in row]
                for row in relevant_rows
            ]
            df = pd.DataFrame(flattened_data, columns=flattened_header)
        except Exception as e:
            print(f"Error creating DataFrame for {table_id}: {e}")
            continue

        data_list.append({
            "table_id": table_id,
            "question": question_text,
            "answer": answer_text,
            "table": df,
            "wiki": wiki_text,
            "title": table_data["title"],
            "intro": table_data["intro"]
        })

    return data_list

# Read questions data from the development set's standard answers.
# read_data() looks entries up in `questions_data` by table_id, so this must
# run before read_data() is called.
questions_path = "dev"
# Decode as UTF-8 explicitly rather than relying on the platform default encoding.
with open(questions_path, 'r', encoding='utf-8') as f:
    questions_data = json.load(f)

def df_format(data):
    """Render a DataFrame as pipe-separated text.

    The first line holds the column labels and each following line holds one
    row; fields are separated by ``" | "`` and every line ends with a newline.
    Labels and cells are converted with ``str``. Rows are read by position, so
    tables with repeated column labels are rendered cell by cell.

    Args:
        data: The ``pandas.DataFrame`` to render.

    Returns:
        str: The formatted table, or ``""`` if ``data`` cannot be formatted
        (for example when it is not a DataFrame).
    """
    try:
        lines = [" | ".join(str(col) for col in data.columns)]
        # itertuples yields cells positionally and keeps each column's own
        # values, unlike label-based row[col] lookups on a row Series.
        lines.extend(
            " | ".join(str(cell) for cell in row)
            for row in data.itertuples(index=False, name=None)
        )
    except Exception:
        # Best effort: an unformattable table contributes nothing to the prompt.
        return ""
    return "\n".join(lines) + "\n"


In [5]:
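# Load the model or API client here. The generation cells below call
# `query({'inputs': prompt})` and read `response_raw[0]['generated_text']`
# from its result, so `query` must be defined before they are run.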

In [14]:
run_count = 0

subquestion_file = f"outputs/subquestion_s{args.start}_e{args.end}_{args.option}_{args.model}_{run_count}.json"
# Create the target directory first; open(..., "w") does not create missing directories.
os.makedirs(os.path.dirname(subquestion_file), exist_ok=True)
# The handle stays open on purpose: the sub-question generation cell writes
# one JSON line per question to it and closes it when done.
subquestion_fw = open(subquestion_file, "w")

# The first line of the file records the demonstration used for this run.
tmp = {"demonstration": demonstration[args.option]}
subquestion_fw.write(json.dumps(tmp) + "\n")

2556

In [72]:
# Build the per-question records for the sub-question pass. Samples whose
# files cannot be loaded are reported below and left out of data_list.
data_list = read_data(args)

 37%|████████████████████████████▋                                                 | 221/600 [00:00<00:00, 1101.04it/s]

The file Ebertfest:_Roger_Ebert's_Film_Festival_3 does not exist.
The file Ebertfest:_Roger_Ebert's_Film_Festival_16 does not exist.
The file List_of_National_Treasures_of_Japan_(writings:_Japanese_books)_0 does not exist.
The file Brad_Nelson_(Magic:_The_Gathering_player)_0 does not exist.
The file Looney_Tunes_Golden_Collection:_Volume_3_0 does not exist.


100%|██████████████████████████████████████████████████████████████████████████████| 600/600 [00:00<00:00, 1046.17it/s]

The file Looney_Tunes_Golden_Collection:_Volume_5_1 does not exist.
The file Ebertfest:_Roger_Ebert's_Film_Festival_13 does not exist.
The file List_of_microcars_by_country_of_origin:_J_0 does not exist.





In [None]:
with open('question_dev.json', 'r', encoding='utf-8') as f:
    subquestion_data = [json.loads(line) for line in f]

with open('summary.json', 'r', encoding='utf-8') as f: #load your summary here
    summary_data = [json.loads(line) for line in f]

with open('subquestion_entity.txt', 'r', encoding='utf-8') as f: #load your entity for subquestion here
    entity_data = [line.strip() for line in f]

# The inputs are paired purely by position and zip() stops at the shortest one
# without complaint, so make any length difference visible before generating.
input_lengths = {
    "data_list": len(data_list),
    "entity_data": len(entity_data),
    "subquestion_data": len(subquestion_data),
    "summary_data": len(summary_data),
}
if len(set(input_lengths.values())) > 1:
    print(f"Warning: input lengths differ {input_lengths}; entries are paired by position and the extras are ignored.")

question_idx = 0

try:
    for entry, entity_entry, subquestion_entry, summary_entry in zip(tqdm(data_list), entity_data, subquestion_data, summary_data):
        summary = summary_entry.get('summary', '')
        subquestion = subquestion_entry.get('response', '')

        prompt = demonstration[args.option] + '\n\n'
        #### Formalizing the k-shot demonstration. #####
        prompt += f'Read the table and text regarding "{entry["title"]}" to answer the question.\n\n'
        prompt += df_format(entry['table']) + '\n'

        if entry['wiki']:
            prompt += "Text:" + '\n' + entry['wiki'] + '\n\n'
        #prompt += 'Summary: ' + summary + '\n\n'
        prompt += 'The answer should be a/an ' + entity_entry + '\n\n'
        prompt += 'Let us think step by step, and answer the question: ' + subquestion + '\nAnswer:'

        # Process the question and answer...

        # Update the question index
        question_idx += 1
        response_raw = query({'inputs': prompt})
        try:
            generated_text = response_raw[0].get('generated_text', '')
            # Take the text after the fourth '\nAnswer:' marker, up to 'Reasoning process'.
            # NOTE(review): the fixed index presumably counts the markers in the
            # demonstration plus the one ending this prompt -- confirm if the
            # demonstration changes.
            response = generated_text.split('\nAnswer:')[3].split('Reasoning process')[0].strip()
        except (KeyError, IndexError, TypeError, AttributeError):
            # Unusable reply (not a list of dicts, or too few answer markers):
            # record an empty answer instead of aborting the run.
            response = ''

        # Keep only the first line of the extracted answer.
        response = response.split('\n')[0].strip()

        tmp = {
            "sub_question": subquestion,
            "entity":entity_entry,
            "sub_answer": response,
            "table_id": entry["table_id"],
        }

        subquestion_fw.write(json.dumps(tmp) + "\n")
finally:
    # Close (and flush) the output even if a query fails part-way through.
    subquestion_fw.close()


In [6]:
now = datetime.now()
# Day-of-month, hour and minute keep output files from different runs apart.
dt_string = now.strftime("%d_%H_%M")
answer_file = f"outputs/answer_s{args.start}_e{args.end}_{args.option}_{args.model}_{dt_string}.json"
# Create the target directory first; open(..., "w") does not create missing directories.
os.makedirs(os.path.dirname(answer_file), exist_ok=True)
# The handle stays open on purpose: the final answer cell writes one JSON line
# per question to it and closes it when done.
answer_fw = open(answer_file, "w")
# The first line of the file records the demonstration used for this run.
tmp = {"demonstration": demonstration[args.option]}
answer_fw.write(json.dumps(tmp) + "\n")

5647

In [7]:
# Rebuild the per-question records for the final answer pass. Samples whose
# files cannot be loaded are reported below and left out of data_list.
data_list = read_data(args)

  6%|█████▎                                                                           | 39/600 [00:02<00:29, 18.99it/s]

The file Ebertfest:_Roger_Ebert's_Film_Festival_3 does not exist.


  9%|███████▎                                                                         | 54/600 [00:03<00:26, 20.65it/s]

The file Ebertfest:_Roger_Ebert's_Film_Festival_16 does not exist.


 26%|████████████████████▊                                                           | 156/600 [00:09<00:20, 21.51it/s]

The file List_of_National_Treasures_of_Japan_(writings:_Japanese_books)_0 does not exist.


 31%|████████████████████████▌                                                       | 184/600 [00:10<00:23, 18.00it/s]

The file Brad_Nelson_(Magic:_The_Gathering_player)_0 does not exist.


 40%|████████████████████████████████▍                                               | 243/600 [00:14<00:20, 17.10it/s]

The file Looney_Tunes_Golden_Collection:_Volume_3_0 does not exist.


 89%|███████████████████████████████████████████████████████████████████████▎        | 535/600 [00:33<00:03, 20.92it/s]

The file Looney_Tunes_Golden_Collection:_Volume_5_1 does not exist.


 90%|████████████████████████████████████████████████████████████████████████▏       | 541/600 [00:33<00:02, 21.96it/s]

The file Ebertfest:_Roger_Ebert's_Film_Festival_13 does not exist.


 98%|██████████████████████████████████████████████████████████████████████████████▌ | 589/600 [00:36<00:00, 20.38it/s]

The file List_of_microcars_by_country_of_origin:_J_0 does not exist.


100%|████████████████████████████████████████████████████████████████████████████████| 600/600 [00:36<00:00, 16.39it/s]


In [9]:
with open('outputs/subquestion.json', 'r', encoding='utf-8') as f: #load your subquestion answer here
    subquestion_data = [json.loads(line) for line in f]
# The sub-question writer in this notebook starts its file with a
# {"demonstration": ...} header line. Drop such records so the header is not
# paired with the first question, which would shift every later pairing by one.
subquestion_data = [
    record for record in subquestion_data
    if not ("demonstration" in record and "sub_question" not in record)
]

with open('spacy_dev.txt', 'r', encoding='utf-8') as f: #load your entity for full question here
    entity_data = [line.strip() for line in f]

with open('summary.json', 'r', encoding='utf-8') as f: #load your summary here
    summary_data = [json.loads(line) for line in f]

try:
    # Iterate over data_list and the side inputs in lockstep (paired by position).
    for entry, subquestion_entry, entity_entry, summary_entry in zip(tqdm(data_list), subquestion_data, entity_data, summary_data):
        question = entry['question']
        answer = entry['answer']
        table_id = entry['table_id']
        subanswer = subquestion_entry.get('sub_answer', '')  # Use .get() to handle KeyError
        subquestion = subquestion_entry.get('sub_question', '')
        subquestion_table_id = subquestion_entry.get('table_id', '')  # table_id recorded with the sub-answer
        summary = summary_entry.get('summary', '')

        # A differing table_id means the positional pairing has drifted.
        if subquestion_table_id != table_id:
            print(f"Warning: Table ID mismatch for question '{question}'.")
            # Optionally, you can choose to skip this entry or handle it differently

        #### Formalizing the k-shot demonstration. #####
        prompt = demonstration[args.option] + '\n\n'
        prompt += f'Read the following table and text regarding "{entry["title"]}":'+'and answer the question.\n\n'
        prompt += df_format(entry['table']) + '\n'

        if entry['wiki']:
            prompt += "Text: " + '\n' + entry['wiki'] + '\n\n'
        #prompt +=  "Summary: " + summary+ '\n\n'
        # Add the sub-question and its answer to the prompt as a hint
        prompt += "Subquestion as hint: " + subquestion + "\nThe answer of subquestion: " + subanswer + '\n\n'
        prompt += "Using exactly the same word from the text and table as answer can achieve better correct rate. Simplify your answer to a/an:" + entity_entry + '\n\n'
        prompt += 'Lets think step by step, and answer the question: ' + question 
        prompt += '\nAnswer:'
        response_raw = query({'inputs': prompt})

        try:
            generated_text = response_raw[0].get('generated_text', '')
            # Take the text after the fifth '\nAnswer:' marker, up to 'Reasoning process'.
            # NOTE(review): the fixed index presumably depends on how many
            # markers the demonstration contains -- confirm if it changes.
            response = generated_text.split('\nAnswer:')[4].split('Reasoning process')[0].strip()
        except (KeyError, IndexError, TypeError, AttributeError):
            # Unusable reply (not a list of dicts, or too few answer markers):
            # record an empty answer instead of aborting the run.
            response = ''

        # Keep only the first line of the extracted answer.
        response = response.split('\n')[0].strip()

        tmp = {
            "question": question,
            "response": response,
            "answer": answer,
            "entity":entity_entry,
            "table_id": entry["table_id"],
            "sub_answer": subanswer
        }

        answer_fw.write(json.dumps(tmp) + "\n")
finally:
    # Close (and flush) the output even if a query fails part-way through.
    answer_fw.close()


100%|████████████████████████████████████████████████████████████████████████████████| 592/592 [22:05<00:00,  2.24s/it]


In [52]:
# Debug aid: show the full generated text of the most recent query
# (response_raw is whatever the last executed generation loop left behind).
print(response_raw[0]['generated_text'])



Read the table and text regarding "Stockholm Marathon" to answer the following question.

The table contains important information and this is the introduction of the table:
The Stockholm Marathon, known as the ASICS Stockholm Marathon for sponsorship reasons, is an annual marathon arranged in Stockholm, Sweden, since 1979. It serves as the Swedish marathon championship race. At the 2009 Stockholm Marathon more than 18,500 participants (14,442 men and 4,385 women) were registered. [citation needed]

Year | Athlete | Country | Time ( h : m : s )
1979 | Jukka Toivola | Finland | 2:17:35
1980 | Jeff Wells | United States | 2:15:49
1981 | Bill Rodgers | United States | 2:13:26
1982 | Kjell-Erik Ståhl | Sweden - Hässleholms AIS | 2:19:20
1983 | Hugh Jones | United Kingdom | 2:11:37
1984 | Agapius Masong | Tanzania | 2:13:47
1985 | Tommy Persson | Sweden - Heleneholms IF | 2:17:18
1986 | Kjell-Erik Ståhl | Sweden - Enhörna IF | 2:12:33
1987 | Kevin Forster | United Kingdom | 2:13:52
1988 |