In [11]:
import os
import json
import argparse
from tqdm import tqdm
from datetime import datetime
import pandas as pd
import requests

In [12]:
# Notebook configuration expressed as argparse options. Every value below is a
# default; nothing is read from the real command line (see parse_args at the end).
parser = argparse.ArgumentParser()
# Key into the `demonstration` dict that is prepended to every prompt.
parser.add_argument("--option", default="none", type=str)
parser.add_argument("--model", default="llama3", type=str, help="qwen1.5-14b-chat and qwen-turbo are better")
# Slice bounds applied to the list of retrieved samples: samples[start:end].
parser.add_argument("--start", default=0, type=int)
parser.add_argument("--end", default=None, type=int)
parser.add_argument(
    "--temperature",
    type=float,
    default=0.7,
    help="temperature of 0 implies greedy sampling.",
)
# JSON file holding the retrieved samples that read_data() loads.
parser.add_argument(
    "--traced_json_file",
    default=r"retrieved_data",
    type=str,
)

# Directory defaults are assembled with os.path.join so they resolve on POSIX
# as well as on Windows (where the result is identical to the former
# backslash-separated literals).
parser.add_argument(
    "--text_path",
    default=os.path.join("WikiTables-WithLinks-master", "request_tok"),
    type=str,
)
parser.add_argument(
    "--table_path",
    default=os.path.join("WikiTables-WithLinks-master", "tables_tok"),
    type=str,
)

# An explicit empty argument list: only the defaults above are used and
# sys.argv is never consulted.
args = parser.parse_args([])

In [13]:
# Prompt prefixes keyed by the value of --option; the "none" entry prepends
# nothing to the prompt.
demonstration = {"none": ""}


In [14]:
def _load_json_file(path):
    """Parse the JSON document stored at ``path`` and return the resulting object."""
    # JSON text is decoded as UTF-8 explicitly so the result does not depend on
    # the platform's locale encoding; the handle is closed by the context manager.
    with open(path, "r", encoding="utf-8") as handle:
        return json.load(handle)


def read_data(args):
    """Build one record per retrieved sample, pairing it with its table and linked text.

    Relies on the module-level ``questions_data`` list, which must be defined
    before this function is called.

    Args:
        args: Namespace providing ``traced_json_file`` (JSON list of samples),
            ``start`` / ``end`` (slice bounds applied to that list),
            ``table_path`` and ``text_path`` (directories containing
            ``<table_id>.json`` files).

    Returns:
        A list of dicts with keys ``table_id``, ``question``, ``answer``
        (taken from the sample's ``"pred"`` field), ``table`` (a DataFrame),
        ``wiki`` (newline-joined linked text, possibly empty), ``title`` and
        ``intro``. Samples without a matching entry in ``questions_data`` or
        whose table/text file cannot be read are reported and skipped.
    """
    data_test_traced = _load_json_file(args.traced_json_file)

    # Only the presence of the table id matters, so collect the ids once
    # instead of scanning the whole question list for every sample.
    known_table_ids = {q_data['table_id'] for q_data in questions_data}

    data_list = []
    for sample in tqdm(data_test_traced[args.start:args.end]):
        table_id = sample["table_id"]
        if table_id not in known_table_ids:
            print(f"No question data found for {table_id}")
            continue

        # Read table data from the configured directory (previously a
        # hard-coded path that ignored --table_path).
        tables_tok_path = os.path.join(args.table_path, f"{table_id}.json")
        try:
            table_data = _load_json_file(tables_tok_path)
        except (OSError, ValueError):
            # OSError: missing/unopenable file (including ids that are not valid
            # file names); ValueError: undecodable or malformed JSON.
            print(f"The file {table_id} does not exist.")
            continue

        # Read text data. The path is computed outside the try so the message
        # below always refers to the file of the current sample.
        text_file = os.path.join(args.text_path, f"{table_id}.json")
        try:
            text_data = _load_json_file(text_file)
        except (OSError, ValueError):
            print(f"The file {text_file} does not exist.")
            continue

        question_text = sample["question"]
        answer_text = sample["pred"]

        # Extract wiki links: the third element of every node, and of the
        # target, when it is a string starting with "/wiki".
        wikis = [
            node[2]
            for node in sample["nodes"]
            if node[2] is not None and node[2].startswith("/wiki")
        ]
        target_wiki = sample["target"][2]
        if target_wiki and target_wiki.startswith("/wiki"):
            wikis.append(target_wiki)

        # Get the corresponding text for each wiki link; links missing from the
        # text file contribute an empty line.
        wiki_text = "\n".join(text_data.get(wiki, "") for wiki in wikis)

        # Create a DataFrame from the first element of every cell / header entry
        # (presumably the display text, with link data in the remaining
        # elements -- TODO confirm against the dataset format).
        df = pd.DataFrame(
            [[cell[0] for cell in row] for row in table_data["data"]],
            columns=[column[0] for column in table_data["header"]],
        )

        data_list.append({
            "table_id": table_id,
            "question": question_text,
            "answer": answer_text,
            "table": df,
            "wiki": wiki_text,
            "title": table_data["title"],
            "intro": table_data["intro"]
        })

    return data_list

# Question records loaded once at module level; read_data() looks table ids up
# in this list when it is called.
questions_path = "Data/HybridQA/test.json"
with open(questions_path, 'r') as f:
    questions_data = json.loads(f.read())

def df_format(data):
    """Render a DataFrame as pipe-separated text for inclusion in a prompt.

    The first line holds the column labels and every following line holds one
    row; fields are joined with " | " and each line ends with a newline.

    Args:
        data: The pandas DataFrame to render.

    Returns:
        The rendered text, or "" when ``data`` is not DataFrame-like
        (for example ``None``).
    """
    try:
        # Labels are stringified so non-string column labels no longer make the
        # join fail (which used to yield "" for the whole table).
        lines = [" | ".join(str(column) for column in data.columns)]
        # Rows are read positionally as plain tuples: looking values up by label
        # returns a whole Series when column names repeat, and iterrows() also
        # upcasts mixed dtypes.
        for values in data.itertuples(index=False, name=None):
            lines.append(" | ".join(str(value) for value in values))
    except (AttributeError, TypeError):
        # Best effort: an unusable table contributes nothing to the prompt.
        return ""
    return "\n".join(lines) + "\n"


In [15]:
# The output file name carries the slice bounds, prompt option, model name and
# a day_hour_minute timestamp.
now = datetime.now()
dt_string = now.strftime("%d_%H_%M")
# open() does not create missing parent directories, so make sure the output
# directory exists before the file is created.
os.makedirs("outputs", exist_ok=True)
# fw intentionally stays open after this cell: the generation loop appends one
# JSON line per sample and closes it when it finishes.
fw = open(f"outputs/response_s{args.start}_e{args.end}_{args.option}_{args.model}_{dt_string}.json", "w",)
# First line of the file records the demonstration text used for this run.
tmp = {"demonstration": demonstration[args.option]}
fw.write(json.dumps(tmp) + "\n")

3820

In [16]:
# Build the records for samples[args.start:args.end]; samples whose table or
# text file cannot be loaded are reported in the output below and skipped.
data_list = read_data(args)

 32%|█████████████████████████▏                                                      | 189/600 [00:09<00:20, 20.02it/s]

The file List_of_songs_in_The_Beatles:_Rock_Band_0 does not exist.


 47%|█████████████████████████████████████▋                                          | 283/600 [00:14<00:14, 21.51it/s]

The file ISO_3166-2:KN_1 does not exist.


 54%|███████████████████████████████████████████                                     | 323/600 [00:17<00:13, 21.01it/s]

The file Roy_"Royalty"_Hamilton_0 does not exist.


 76%|████████████████████████████████████████████████████████████▌                   | 454/600 [00:25<00:05, 25.82it/s]

The file List_of_Summer_Olympics_venues:_L_0 does not exist.


 95%|████████████████████████████████████████████████████████████████████████████▏   | 571/600 [00:31<00:01, 23.05it/s]

The file Magic:_The_Gathering_Hall_of_Fame_0 does not exist.


 99%|██████████████████████████████████████████████████████████████████████████████▉ | 592/600 [00:32<00:00, 23.68it/s]

The file ISO_3166-2:FR_2 does not exist.


100%|████████████████████████████████████████████████████████████████████████████████| 600/600 [00:32<00:00, 18.26it/s]


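Load the Llama model here. This step must also define the `query` function that the next cell calls as `query({'inputs': prompt})`.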


In [None]:

def _first_generated_text(raw):
    """Return the 'generated_text' of the first item of a ``query`` result.

    Yields '' when the result is not a non-empty sequence of dicts (for
    example a dict describing an error), so one bad reply cannot abort the run.
    """
    try:
        return raw[0].get('generated_text', '') or ''
    except (KeyError, IndexError, TypeError, AttributeError):
        return ''


def _segment_after_prompt(generated, prompt, marker):
    """Return the segment of ``generated`` that follows the prompt's last ``marker``.

    The segment index equals the number of times ``marker`` occurs in
    ``prompt``, so it adapts to however many demonstrations the prompt holds
    instead of relying on a fixed position. Assumes ``generated`` starts with
    an echo of ``prompt`` -- TODO confirm against the ``query`` implementation.
    Returns '' when that segment is absent.
    """
    pieces = generated.split(marker)
    position = prompt.count(marker)
    return pieces[position] if position < len(pieces) else ''


try:
    for entry in tqdm(data_list):
        question = entry['question']
        answer = entry['answer']

        # The prompt is identical for every reasoning chain, so build it once.
        prompt = demonstration[args.option] + '\n\n'
        prompt += f'Read the table and text regarding "{entry["title"]}" to answer the question.\n\n'
        prompt += df_format(entry['table']) + '\n'
        if entry['wiki']:
            prompt += "Text" + '\n' + entry['wiki'] + '\n\n'
        prompt += 'Question: ' + question + '\nAnswer:'

        # Generate multiple reasoning chains
        reasoning_chains = []
        for _ in range(3):  # Generate 3 reasoning chains
            generated = _first_generated_text(query({'inputs': prompt}))
            response = _segment_after_prompt(generated, prompt, '\nAnswer:')
            # Keep only the answer itself: drop any trailing explanation and
            # everything after the first line.
            response = response.split('Reasoning process')[0].strip()
            response = response.split('\n')[0].strip()
            reasoning_chains.append(response)

        # Combine reasoning chains for self-consistency
        combined_prompt = "Based on the following reasoning chains, provide the most reasonable answer:\n\n"
        for i, chain in enumerate(reasoning_chains, 1):
            combined_prompt += f"Chain {i}: {chain}\n\n"
        combined_prompt += f"Question: {question}\nFinal Answer:"

        final_generated = _first_generated_text(query({'inputs': combined_prompt}))
        final_response = _segment_after_prompt(
            final_generated, combined_prompt, 'Final Answer:'
        ).strip()

        record = {
            "question": question,
            "response": final_response,
            "answer": answer,
            "table_id": entry["table_id"],
            "reasoning_chains": reasoning_chains
        }
        fw.write(json.dumps(record) + "\n")
finally:
    # Close the results file even when a request fails midway, so the lines
    # written so far are flushed to disk.
    fw.close()
