In [11]:
import os
import json
import argparse
from tqdm import tqdm
from datetime import datetime
import pandas as pd
import requests

In [12]:
# Run configuration. argparse is used only as a typed container for defaults:
# parse_args([]) parses an explicit empty argv so the Jupyter kernel's own
# command-line arguments are never read. Edit the defaults below to change a run.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--option",
    default="cot",
    type=str,
    help="key into the `demonstration` dict selecting the few-shot prompt (none, direct or cot).",
)
parser.add_argument(
    "--model",
    default="llama2-70b",
    type=str,
    help="model label; in this notebook it is only used to name the output file.",
)
parser.add_argument(
    "--start",
    default=0,
    type=int,
    help="index of the first traced sample to process (slice start).",
)
parser.add_argument(
    "--end",
    default=None,
    type=int,
    help="index one past the last traced sample to process (slice end); None means all.",
)
parser.add_argument(
    "--temperature",
    type=float,
    default=0.5,
    help="temperature of 0 implies greedy sampling.",
)
parser.add_argument(
    "--traced_json_file",
    default=r"traced.json",  # traced file
    type=str,
    help="path of the traced samples JSON file.",
)
parser.add_argument(
    "--tables_json_file",
    default=r"tables.json",  # table files
    type=str,
    help="path of the tables JSON file.",
)
parser.add_argument(
    "--topk_path",
    default=r"request_tok",  # text files
    type=str,
    help="directory holding the per-table text JSON files.",
)

args = parser.parse_args([])

In [13]:
# Few-shot demonstrations keyed by the value of --option.
# "none" means no demonstration; the other entries are read from JSON files
# (presumably each file holds a single prompt string — it is later
# concatenated with '\n\n' when the prompt is built).
demonstration = {"none": ""}
for option_name, demo_path in (
    ("direct", "examples/fullmodel_direct_3shot.json"),
    ("cot", "examples/fullmodel_cot_3shot.json"),
):
    with open(demo_path, "r") as f:
        demonstration[option_name] = json.load(f)

In [14]:
def read_data(args):
    """Build the list of samples consumed by the prompting loops.

    For every entry of the traced JSON file (sliced with ``args.start`` and
    ``args.end``) this looks up the matching record in the module-level
    ``questions_data``, loads ``<table_id>.json`` for the table and
    ``<text dir>/<table_id>.json`` for the linked passages, and assembles one
    dict per sample.

    A sample is skipped (with a printed message) when no question record
    shares its ``table_id``, when the question type is neither ``'bridge'``
    nor ``'comparison'``, or when either JSON file cannot be read or parsed.

    Args:
        args: Namespace providing ``traced_json_file``, ``start``, ``end`` and
            the text directory (``text_path`` when present, otherwise
            ``topk_path``).

    Returns:
        list[dict]: one dict per kept sample with the keys ``table_id``,
        ``question``, ``answer`` (taken from the sample's ``pred`` field),
        ``table`` (a pandas DataFrame), ``wiki``, ``title`` and ``intro``.
    """
    # Close the traced file deterministically instead of leaking the handle.
    with open(args.traced_json_file, "r") as traced_file:
        data_test_traced = json.load(traced_file)

    # The notebook's parser registers the text directory as --topk_path, so
    # reading `args.text_path` unconditionally raised AttributeError. Prefer
    # `text_path` when a caller supplies it and fall back to `topk_path`.
    text_dir = getattr(args, "text_path", None) or args.topk_path

    # Index the question records once instead of scanning the whole list for
    # every sample. setdefault keeps the FIRST record per table_id, matching
    # the previous scan-and-break behaviour.
    questions_by_table = {}
    for q_data in questions_data:
        questions_by_table.setdefault(q_data["table_id"], q_data)

    data_list = []
    for sample in tqdm(data_test_traced[args.start:args.end]):
        table_id = sample["table_id"]
        question_data = questions_by_table.get(table_id)
        if question_data is None:
            print(f"No question data found for {table_id}")
            continue

        # Read JSON file from tables_tok
        tables_tok_path = f"{table_id}.json"  # put your traced table link
        try:
            with open(tables_tok_path, "r") as f:
                table_data = json.load(f)
        except (OSError, ValueError):
            # OSError: missing/unopenable file; ValueError: malformed JSON or
            # undecodable bytes. Anything else is a programming error and
            # should surface instead of being reported as a missing file.
            print(f"The file {table_id} does not exist.")
            continue

        # Only 'bridge' and 'comparison' questions are supported.
        # NOTE(review): the previous code also selected "relevant rows" from
        # 'row_pre' / 'row_pre_logit' here, but the result was never used, so
        # that dead computation has been dropped.
        question_type = question_data["type"]
        if question_type not in ("bridge", "comparison"):
            print(f"Unknown question type: {question_type}")
            continue

        # Read text data. The path is built outside the try block so the
        # error message below can always reference it.
        text_file = os.path.join(text_dir, f"{table_id}.json")
        try:
            with open(text_file, "r") as f:
                text_data = json.load(f)
        except (OSError, ValueError):
            print(f"The file {text_file} does not exist.")
            continue

        question_text = sample["question"]
        answer_text = sample["pred"]

        # Extract wiki links from nodes and target (element 2 of each entry).
        wikis = [
            node[2]
            for node in sample["nodes"]
            if node[2] is not None and node[2].startswith("/wiki")
        ]

        target_wiki = sample["target"][2]
        if target_wiki and target_wiki.startswith("/wiki"):
            wikis.append(target_wiki)

        # Get the corresponding text for each wiki link; links without an
        # entry contribute an empty line. An empty link list yields "".
        wiki_text = "\n".join(text_data.get(wiki, "") for wiki in wikis)

        # Create a DataFrame from the table data: the first element of every
        # cell / header entry is used as its value.
        df = pd.DataFrame(
            [tuple(zip(*row))[0] for row in table_data["data"]],
            columns=list(zip(*table_data["header"]))[0],
        )

        data_list.append({
            "table_id": table_id,
            "question": question_text,
            "answer": answer_text,
            "table": df,
            "wiki": wiki_text,
            "title": table_data["title"],
            "intro": table_data["intro"]
        })

    return data_list

# Question records used by read_data(), which matches them to traced samples
# through their 'table_id' field and reads each record's 'type'.
questions_path = "test.json"  # path of the questions JSON file
with open(questions_path, 'r') as f:
    questions_data = json.loads(f.read())

def df_format(data):
    """Render a DataFrame as pipe-separated text for inclusion in a prompt.

    The first line holds the column labels and each following line holds one
    row; fields are joined with ``" | "`` and every line ends with ``"\\n"``.

    Args:
        data: pandas DataFrame to render.

    Returns:
        str: the formatted table, or ``""`` if ``data`` cannot be formatted
        (for example when it is not a DataFrame). The failure is reported on
        stdout rather than swallowed silently.
    """
    try:
        # str() each label: non-string labels (ints, None, ...) used to make
        # str.join raise, which silently blanked the whole table.
        lines = [" | ".join(str(col) for col in data.columns)]
        # Positional iteration also handles duplicate column labels, where
        # row[col] would return a Series instead of a scalar.
        for values in data.itertuples(index=False, name=None):
            lines.append(" | ".join(str(value) for value in values))
        return "\n".join(lines) + "\n"
    except Exception as e:
        # Best effort: keep the prompting loop running, but say what happened.
        print(f"Error formatting table: {e}")
        return ""


In [15]:
# Open the JSON-lines output file for this run. The file name encodes the
# sample range, prompting option, model label and a day_hour_minute stamp.
# `fw` is left open on purpose: the generation cell writes one record per
# sample to it and closes it when done.
now = datetime.now()
dt_string = now.strftime("%d_%H_%M")
# Create the output directory on a fresh checkout instead of failing in open().
os.makedirs("outputs", exist_ok=True)
fw = open(f"outputs/response_s{args.start}_e{args.end}_{args.option}_{args.model}_{dt_string}.json", "w",)
# First line of the file records the demonstration used for the run.
tmp = {"demonstration": demonstration[args.option]}
fw.write(json.dumps(tmp) + "\n")

3820

In [16]:
# Build the per-sample records (question, answer, table DataFrame, wiki text,
# title, intro) that the prompting cells iterate over. Samples whose files
# cannot be read are skipped with a printed message.
data_list = read_data(args)

 32%|█████████████████████████▏                                                      | 189/600 [00:09<00:20, 20.02it/s]

The file List_of_songs_in_The_Beatles:_Rock_Band_0 does not exist.


 47%|█████████████████████████████████████▋                                          | 283/600 [00:14<00:14, 21.51it/s]

The file ISO_3166-2:KN_1 does not exist.


 54%|███████████████████████████████████████████                                     | 323/600 [00:17<00:13, 21.01it/s]

The file Roy_"Royalty"_Hamilton_0 does not exist.


 76%|████████████████████████████████████████████████████████████▌                   | 454/600 [00:25<00:05, 25.82it/s]

The file List_of_Summer_Olympics_venues:_L_0 does not exist.


 95%|████████████████████████████████████████████████████████████████████████████▏   | 571/600 [00:31<00:01, 23.05it/s]

The file Magic:_The_Gathering_Hall_of_Fame_0 does not exist.


 99%|██████████████████████████████████████████████████████████████████████████████▉ | 592/600 [00:32<00:00, 23.68it/s]

The file ISO_3166-2:FR_2 does not exist.


100%|████████████████████████████████████████████████████████████████████████████████| 600/600 [00:32<00:00, 18.26it/s]


In [17]:
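# Load the model or API client here. This cell must define `query(payload)`:
# the loops below call it as query({'inputs': prompt}) and read the result as
# response_raw[0]['generated_text'].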

In [18]:
# Query the model for every sample and append one JSON line per sample to `fw`.
# `query` is not defined in this notebook's visible cells; it must be provided
# by the "Load model or API" cell before this one runs.

# Position, after splitting the generated text on '\nAnswer:', of the segment
# holding the model's answer to the final question.
# NOTE(review): 5 presumably matches the number of 'Answer:' markers in the
# echoed few-shot prompt — confirm whenever the demonstrations change.
ANSWER_SEGMENT_INDEX = 5

try:
    for entry in tqdm(data_list):
        question = entry['question']
        answer = entry['answer']

        #### Formalizing the k-shot demonstration. #####
        prompt = demonstration[args.option] + '\n\n'
        prompt += f'Read the table and text regarding "{entry["title"]}" to answer the question.\n\n'
        prompt += df_format(entry['table']) + '\n'

        if entry['wiki']:
            prompt += "Text" + '\n' + entry['wiki'] + '\n\n'

        prompt += 'Question: ' + question + '\nAnswer:'

        response_raw = query({'inputs': prompt})
        try:
            response = response_raw[0].get('generated_text', '').split('\nAnswer:')[ANSWER_SEGMENT_INDEX].split('Reasoning process')[0].strip()
        except (KeyError, IndexError, AttributeError, TypeError):
            # KeyError: response_raw is a dict (e.g. an error payload) with no
            # key 0. IndexError: empty list or fewer 'Answer:' segments than
            # expected. AttributeError/TypeError: unexpected payload shape.
            # Record an empty response instead of aborting the whole run.
            response = ''

        # Keep only the first line of the extracted answer.
        response = response.split('\n')[0].strip()

        tmp = {
            "question": question,
            "response": response,
            "answer": answer,
            "table_id": entry["table_id"],
        }

        fw.write(json.dumps(tmp) + "\n")
finally:
    # Always flush and close the output file, even if a query fails midway,
    # so the records written so far are not lost.
    fw.close()


100%|████████████████████████████████████████████████████████████████████████████████| 594/594 [11:21<00:00,  1.15s/it]


In [34]:
#old prompt
# Earlier prompt variant: includes the table introduction and, when wiki text
# is available, asks the model to reason step by step. The text following the
# first '\nRead the table and text' marker in the generation is stored as the
# response.
# `query` is not defined in this notebook's visible cells; it must be provided
# by the "Load model or API" cell before this one runs.

# The previous generation cell closes `fw`; reopen the same file in append
# mode so this cell does not fail with "I/O operation on closed file".
if fw.closed:
    fw = open(fw.name, "a")

try:
    for entry in tqdm(data_list):
        question = entry['question']
        answer = entry['answer']

        #### Formalizing the k-shot demonstration. #####
        prompt = demonstration[args.option] + '\n\n'
        prompt += f'Read the table and text regarding "{entry["title"]}" to answer the following question.\n\n'
        prompt += "The table contains important information and this is the introduction of the table:" + '\n' + entry['intro'] + '\n\n'
        prompt += df_format(entry['table']) + '\n'

        if entry['wiki']:
            prompt += "I believe the following text information will help answer the question:" + '\n' + entry['wiki'] + '\n\n'
            # NOTE(review): the step-by-step instruction is only added when wiki
            # text exists — confirm this is intended and not an indentation slip.
            prompt += "Please think step by step. Please also show me the reasoning process." + '\n\n'
        prompt += 'Question: ' + question + '\nAnswer:'

        response_raw = query({'inputs': prompt})

        # Parse once, inside the guard. Previously the same expression was
        # repeated unguarded when building the record, so the fallback below
        # never protected the write; IndexError (marker missing) was also
        # not handled.
        try:
            response = response_raw[0]['generated_text'].split('\nRead the table and text')[1].strip()
        except (KeyError, IndexError, TypeError):
            response = ''

        tmp = {
            "question": question,
            "response": response,
            "answer": answer,
            "table_id": entry["table_id"],
        }

        fw.write(json.dumps(tmp) + "\n")
finally:
    # Always flush and close the output file, even if a query fails midway.
    fw.close()


100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:10<00:00,  2.13s/it]


In [52]:
# Debug inspection: print the raw generated text of the most recent model
# response left in `response_raw` by the loop above.
print(response_raw[0]['generated_text'])



Read the table and text regarding "Stockholm Marathon" to answer the following question.

The table contains important information and this is the introduction of the table:
The Stockholm Marathon, known as the ASICS Stockholm Marathon for sponsorship reasons, is an annual marathon arranged in Stockholm, Sweden, since 1979. It serves as the Swedish marathon championship race. At the 2009 Stockholm Marathon more than 18,500 participants (14,442 men and 4,385 women) were registered. [citation needed]

Year | Athlete | Country | Time ( h : m : s )
1979 | Jukka Toivola | Finland | 2:17:35
1980 | Jeff Wells | United States | 2:15:49
1981 | Bill Rodgers | United States | 2:13:26
1982 | Kjell-Erik Ståhl | Sweden - Hässleholms AIS | 2:19:20
1983 | Hugh Jones | United Kingdom | 2:11:37
1984 | Agapius Masong | Tanzania | 2:13:47
1985 | Tommy Persson | Sweden - Heleneholms IF | 2:17:18
1986 | Kjell-Erik Ståhl | Sweden - Enhörna IF | 2:12:33
1987 | Kevin Forster | United Kingdom | 2:13:52
1988 |