In [1]:
import os
import json
import argparse
from tqdm import tqdm
from datetime import datetime
import pandas as pd
import requests

In [2]:
# Run configuration. parse_args([]) deliberately ignores the kernel's own
# command line (sys.argv) so the defaults below are what the notebook uses;
# edit them here to change a run.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--option",
    default="cot",
    type=str,
    choices=["none", "direct", "cot"],
    help="prompting style; selects the few-shot demonstration set",
)
parser.add_argument(
    "--model",
    default="llama2-70b",
    type=str,
    help="model name (used in the output filename); qwen1.5-14b-chat and qwen-turbo are better",
)
parser.add_argument("--start", default=0, type=int, help="first traced sample to use (inclusive)")
parser.add_argument("--end", default=None, type=int, help="last traced sample (exclusive); None = to the end")
parser.add_argument(
    "--temperature",
    type=float,
    default=0.5,
    help="temperature of 0 implies greedy sampling.",
)
parser.add_argument(
    "--traced_json_file",
    default="traced.json",
    type=str,
    help="traced question file (a JSON list of samples)",
)
parser.add_argument(
    "--tables_json_file",
    default="tables.json",
    type=str,
    help="tables file (a JSON object keyed by table_id)",
)
parser.add_argument(
    "--topk_path",
    default="request_tok",
    type=str,
    help="directory of per-table passage files named <table_id>.json",
)

args = parser.parse_args([])

In [3]:


def read_data(args):
    """Load the traced QA samples and attach each one's table and linked passages.

    Parameters
    ----------
    args : argparse.Namespace
        Reads ``traced_json_file``, ``tables_json_file``, ``topk_path``,
        ``start`` and ``end``; only ``samples[start:end]`` are processed.

    Returns
    -------
    list[dict]
        One dict per loaded sample with keys ``question``, ``answer``,
        ``title``, ``table`` (``pandas.DataFrame``), ``wiki`` (newline-joined
        passages for the sample's ``/wiki`` answer nodes, ``""`` if none),
        ``table_id`` and ``intro``. Samples whose
        ``<topk_path>/<table_id>.json`` is missing or cannot be parsed are
        skipped with a printed message.

    Raises
    ------
    KeyError
        If a linked ``/wiki`` passage is absent from the sample's passage
        file, or the sample's ``table_id`` is absent from the tables file.
    """
    # JSON is UTF-8; never rely on the platform default encoding
    # (e.g. cp1252 on Windows).
    with open(args.traced_json_file, "r", encoding="utf-8") as f:
        data_train_traced = json.load(f)
    with open(args.tables_json_file, "r", encoding="utf-8") as f:
        traindev_table = json.load(f)

    data_list = []
    for sample in tqdm(data_train_traced[args.start:args.end]):
        table_id = sample["table_id"]
        topk_file = os.path.join(args.topk_path, f"{table_id}.json")
        try:
            with open(topk_file, "r", encoding="utf-8") as f:
                topk = json.load(f)
        except FileNotFoundError:
            print(f"The file {topk_file} does not exist.")
            continue
        except (OSError, ValueError) as exc:
            # Unreadable or malformed passage file: skip the sample, but report
            # the real cause instead of claiming the file is missing.
            print(f"The file {topk_file} could not be loaded: {exc}")
            continue

        # Answer nodes whose third field is a "/wiki..." link have a passage
        # stored under that link in `topk`.
        wikis = [
            node[2]
            for node in sample["answer-node"]
            if node[2] is not None and node[2].startswith("/wiki")
        ]
        wiki_text = "\n".join(topk[wiki] for wiki in wikis)  # "" when no links

        table = traindev_table[table_id]
        # Header and data cells are sequences; keep element 0 of each
        # (presumably the cell text -- TODO confirm against the tables schema).
        df = pd.DataFrame(
            [tuple(zip(*row))[0] for row in table["data"]],
            columns=list(zip(*table["header"]))[0],
        )
        data_list.append(
            {
                "question": sample["question"],
                "answer": sample["answer-text"],
                "title": table["title"],
                "table": df,
                "wiki": wiki_text,
                "table_id": table_id,
                "intro": table["intro"],
            }
        )
    return data_list


def df_format(data):
    """Render a DataFrame as pipe-separated plain text for an LLM prompt.

    The first line is the header; each following line is one row; every line,
    including the last, ends with a newline. Cells and column labels are
    converted with ``str``.

    Parameters
    ----------
    data : pandas.DataFrame
        The table to render.

    Returns
    -------
    str
        The formatted table, or ``""`` (after printing a warning) if the
        table cannot be formatted.
    """
    try:
        lines = [" | ".join(str(col) for col in data.columns)]
        # Iterate positionally: `row[col]` would return a whole Series when a
        # table has duplicate column labels, and `iterrows` does not preserve
        # per-column dtypes (ints become floats in mixed frames).
        for values in data.itertuples(index=False, name=None):
            lines.append(" | ".join(str(value) for value in values))
        return "\n".join(lines) + "\n"
    except Exception as exc:
        # Best effort: a malformed table yields an empty table section instead
        # of aborting the whole generation run.
        print(f"wrong table: {exc!r}")
        return ""




In [4]:
# Few-shot demonstrations keyed by prompting option ("none" = zero-shot).
# JSON is UTF-8, so read it explicitly as UTF-8 rather than with the
# platform default encoding.
demonstration = {"none": ""}
with open("examples/fullmodel_direct_3shot.json", "r", encoding="utf-8") as f:
    demonstration["direct"] = json.load(f)
with open("examples/fullmodel_cot_3shot.json", "r", encoding="utf-8") as f:
    demonstration["cot"] = json.load(f)

In [5]:
# Open the JSON-lines results file for this run. The first line records the
# demonstration used; the generation cells append one record per question.
now = datetime.now()
dt_string = now.strftime("%d_%H_%M")
os.makedirs("outputs", exist_ok=True)
# Model names like "org/model" would otherwise be read as sub-directories,
# and characters such as ':' are invalid in Windows file names.
safe_model = "".join("_" if ch in '\\/:*?"<>|' else ch for ch in args.model)
fw = open(f"outputs/response_s{args.start}_e{args.end}_{args.option}_{safe_model}_{dt_string}.json", "w")
header_record = {"demonstration": demonstration[args.option]}
fw.write(json.dumps(header_record) + "\n");

22

In [6]:
# Build the evaluation samples for the slice [args.start:args.end] configured above.
data_list = read_data(args=args)

 11%|████████▋                                                                       | 65/600 [00:00<00:03, 176.26it/s]

The file data\traindev_request_tok\Rachael_vs._Guy:_Celebrity_Cook-Off_2.json does not exist.


100%|███████████████████████████████████████████████████████████████████████████████| 600/600 [00:03<00:00, 187.16it/s]


In [7]:
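# Load the model or API client here and define `query(payload)`.
# The generation cells call `query({'inputs': prompt})` and read the
# 'generated_text' field of the first element of the result. This cell is
# empty in the saved notebook, so the notebook cannot run from a fresh kernel
# until it is filled in.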

In [8]:
# Generation run: build one prompt per sample, query the model, and append a
# JSON line {question, response, answer, table_id} to the results file.
# Requires `query` from the "Load model or API" cell.
if fw.closed:
    # Re-running after a previous run closed the file: append to the same file.
    fw = open(fw.name, "a")

with fw:  # closes (and flushes) the file even if a query raises mid-run
    for entry in tqdm(data_list):
        question = entry['question']
        answer = entry['answer']

        #### Formalizing the k-shot demonstration. #####
        prompt = demonstration[args.option] + '\n\n'
        prompt += f'Read the table and text regarding "{entry["title"]}" to answer the following question.\n\n'
        prompt += df_format(entry['table']) + '\n'

        if entry['wiki']:
            prompt += "Text:" + '\n' + entry['wiki'] + '\n\n'
        prompt += 'Question: ' + question + '\nAnswer:'

        response_raw = query({'inputs': prompt})

        try:
            generated = response_raw[0].get('generated_text', '')
        except (KeyError, IndexError, TypeError, AttributeError):
            # e.g. the API returned an error dict or an empty list
            generated = ''

        # The generated text echoes the prompt, so the model's answer follows
        # the prompt's *last* "\nAnswer:" marker. Its split index equals the
        # number of markers in the prompt (demonstrations may contain their own).
        segments = generated.split('\nAnswer:')
        marker_count = prompt.count('\nAnswer:')
        response = segments[marker_count] if marker_count < len(segments) else ''
        response = response.split('Reasoning process')[0].strip()
        response = response.split('\n')[0].strip()

        record = {
            "question": question,
            "response": response,
            "answer": answer,
            "table_id": entry["table_id"],
        }
        fw.write(json.dumps(record) + "\n")


100%|████████████████████████████████████████████████████████████████████████████████| 599/599 [19:41<00:00,  1.97s/it]


In [9]:
# Debug: show the most recent prompt that was sent to the model.
print(prompt, end="\n")



Read the table and text regarding "World record progression 100 metres butterfly" to answer the following question.

Time | Swimmer | Date | Place
58.77 | Angela Kennedy | February 18 , 1995 | Gelsenkirchen , Germany
58.68 | Liu Limin | December 2 , 1995 | Rio de Janeiro , Brazil
58.29 | Misty Hyman | December 1 , 1996 | Sainte-Foy , Quebec
58.24 | Ayari Aoyama | March 28 , 1997 | Tokyo , Japan
57.79 | Jenny Thompson | April 19 , 1997 | Gothenburg , Sweden
56.90 | Jenny Thompson | December 1 , 1998 | College Station , United States
56.80 | Jenny Thompson | February 12 , 2000 | Paris , France
56.56 | Jenny Thompson | March 18 , 2000 | Athens , Greece
56.55 | Martina Moravcová | January 26 , 2002 | Berlin , Germany
56.34 | Natalie Coughlin | November 22 , 2002 | East Meadow , United States
55.95 | Libby Lenton | August 28 , 2006 | Hobart , Australia
55.89 | Felicity Galvez | April 13 , 2008 | Manchester United Kingdom
55.74 | Libby Trickett | April 26 , 2008 | Canberra , Australia
55.6

In [52]:
response_raw[0]['generated_text'].split('\nAnswer:')[2].split('\n')[0].strip()

'14 December 1979.'

In [74]:
# Debug: dump the full raw generation returned for the last query.
last_generated_text = response_raw[0]['generated_text']
print(last_generated_text)


This is a demonstration:

Read the table below regarding the "2006 League of Ireland Premier Division". In order to get the answer to the question, you need to combine information from both the table and the text.

Team | Manager | Main sponsor | Kit supplier | Stadium | Capacity
Bohemians | Gareth Farrelly | Des Kelly Carpets | O'Neills | Dalymount Park | 8,500
Bray Wanderers | Eddie Gormley | Slevin Group | Adidas | Carlisle Grounds | 7,000
Cork City | Damien Richardson | Nissan | O'Neills | Turners Cross | 8,000
Derry City | Stephen Kenny | MeteorElectrical.com | Umbro | The Brandywell | 7,700
Drogheda United | Paul Doolin | Murphy Environmental | Jako | United Park | 5,400
Dublin City | Dermot Keely | Carroll 's Irish Gift Stores | Umbro | Dalymount Park | 8,500
Longford Town | Alan Mathews | Flancare | Umbro | Flancare Park | 4,500
Shelbourne | Pat Fenlon | JW Hire | Umbro | Tolka Park | 10,100
Sligo Rovers | Sean Connor | Toher 's | Jako | The Showgrounds | 5,500
St Patrick 's A

In [34]:
# Variant run: same as the main generation loop, but the prompt also includes
# the table introduction and a hint when linked text is available.
if fw.closed:
    # The previous generation cell closes `fw`; reopen the same results file in
    # append mode instead of failing with "I/O operation on closed file".
    fw = open(fw.name, "a")

with fw:  # closes (and flushes) the file even if a query raises mid-run
    for entry in tqdm(data_list):
        question = entry['question']
        answer = entry['answer']

        #### Formalizing the k-shot demonstration. #####
        prompt = demonstration[args.option] + '\n\n'
        prompt += f'Read the table and text regarding "{entry["title"]}" to answer the following question.\n\n'
        prompt += "The table contains important information and this is the introduction of the table:" + '\n' + entry['intro'] + '\n\n'
        prompt += df_format(entry['table']) + '\n'

        if entry['wiki']:
            prompt += "I believe the following text information will help answer the question:" + '\n' + entry['wiki'] + '\n\n'
            prompt += "Please think step by step. Hint: Using same words from the text and table as answer can achieve better correct rate." + '\n\n'
        prompt += 'Question: ' + question + '\nAnswer:'

        response_raw = query({'inputs': prompt})

        try:
            generated = response_raw[0].get('generated_text', '')
        except (KeyError, IndexError, TypeError, AttributeError):
            # e.g. the API returned an error dict or an empty list
            generated = ''

        # The generated text echoes the prompt, so the model's answer follows
        # the prompt's *last* "\nAnswer:" marker; its split index equals the
        # number of markers in the prompt (no hard-coded index).
        segments = generated.split('\nAnswer:')
        marker_count = prompt.count('\nAnswer:')
        response = segments[marker_count] if marker_count < len(segments) else ''
        response = response.split('Reasoning process')[0].strip()
        response = response.split('\n')[0].strip()

        record = {
            "question": question,
            "response": response,
            "answer": answer,
            "table_id": entry["table_id"],
        }
        fw.write(json.dumps(record) + "\n")


100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:10<00:00,  2.19s/it]
