In [1]:
import os
import json
import argparse
from tqdm import tqdm
from datetime import datetime
import pandas as pd
import requests

In [2]:
# Run configuration for the notebook, declared through argparse.
parser = argparse.ArgumentParser()
# Key into the `demonstration` dict that selects the prompt prefix.
parser.add_argument("--option", default="cot", type=str)
parser.add_argument("--model", default="llama2-70b", type=str, help=" ")
# Slice bounds applied to the traced samples in read_data
# (end=None means "up to the last sample").
parser.add_argument("--start", default=0, type=int)
parser.add_argument("--end", default=None, type=int)
parser.add_argument(
    "--temperature",
    type=float,
    default=0.5,
    help="temperature of 0 implies greedy sampling.",
)
parser.add_argument(
    "--traced_json_file",
    default=r"traced.json",  # traced file
    type=str,
)
parser.add_argument(
    "--tables_json_file",
    default=r"tables.json",  # table files
    type=str,
)
# Directory with the per-table text JSON files. read_data looks this up as
# `args.text_path`, so `--text_path` is accepted as an alias of the original
# `--topk_path` flag and both attribute names are populated below.
parser.add_argument(
    "--topk_path",
    "--text_path",
    dest="topk_path",
    default=r"request_tok",  # text files
    type=str,
)

# Parse an explicit empty argument list so the defaults above are used
# regardless of what sys.argv contains.
args = parser.parse_args([])
# Expose the text directory under the attribute name that read_data reads.
args.text_path = args.topk_path

In [3]:
# Prompt prefixes keyed by the value of `args.option`; "none" means no
# demonstration text is prepended to the prompt.
demonstration = {"none": ""}


In [4]:
def _load_json(path):
    """Return the parsed JSON document stored at ``path``, closing the file afterwards."""
    with open(path, "r") as f:
        return json.load(f)


def _relevant_rows(table_rows, question_data):
    """Select the table rows referenced by a question's row annotations.

    ``bridge`` questions use the single row index stored in ``row_pre``;
    ``comparison`` questions use every row whose ``row_pre_logit`` entry is
    below 1.0. Returns ``None`` for any other question type.
    """
    question_type = question_data['type']
    if question_type == 'bridge':
        return [table_rows[question_data['row_pre']]]
    if question_type == 'comparison':
        return [
            table_rows[i]
            for i, logit in enumerate(question_data['row_pre_logit'])
            if logit < 1.0
        ]
    return None


def read_data(args, questions=None):
    """Assemble one record per traced sample for the prompt-building cells.

    Args:
        args: namespace providing ``traced_json_file``, ``start``, ``end`` and
            the text directory (``text_path``, falling back to ``topk_path``).
        questions: optional list of question dicts (each with ``table_id``,
            ``type`` and the matching ``row_pre`` / ``row_pre_logit`` field).
            Defaults to the module-level ``questions_data``.

    Returns:
        A list of dicts with keys ``table_id``, ``question``, ``answer``,
        ``table`` (DataFrame of the relevant rows), ``wiki``, ``title`` and
        ``intro``. Samples whose question entry, table file, text file or
        relevant rows cannot be resolved are reported with ``print`` and
        skipped.
    """
    if questions is None:
        questions = questions_data  # module-level list loaded from the dev file

    # Index the questions once; setdefault keeps the first entry per table_id,
    # which is what a front-to-back scan with `break` would have returned.
    question_by_table = {}
    for q_data in questions:
        question_by_table.setdefault(q_data['table_id'], q_data)

    # The parser historically exposed this directory as `topk_path` only, so
    # accept either attribute name instead of failing on a missing one.
    text_dir = getattr(args, "text_path", None) or getattr(args, "topk_path", "")

    # Read traced JSON file
    data_test_traced = _load_json(args.traced_json_file)
    data_list = []
    for sample in tqdm(data_test_traced[args.start:args.end]):
        table_id = sample["table_id"]
        question_data = question_by_table.get(table_id)
        if question_data is None:
            print(f"No question data found for {table_id}")
            continue

        # Read JSON file from tables_tok
        tables_tok_path = f"{table_id}.json"  # put your traced table link
        try:
            table_data = _load_json(tables_tok_path)
        except (OSError, ValueError) as e:
            # OSError: missing/unreadable file; ValueError: malformed JSON.
            print(f"Could not load table file {tables_tok_path}: {e}")
            continue

        try:
            relevant_rows = _relevant_rows(table_data['data'], question_data)
        except (KeyError, IndexError, TypeError) as e:
            # Annotation points outside the table or a required field is missing.
            print(f"Could not select relevant rows for {table_id}: {e}")
            continue
        if relevant_rows is None:
            print(f"Unknown question type: {question_data['type']}")
            continue

        # Read text data
        text_file = os.path.join(text_dir, f"{table_id}.json")
        try:
            text_data = _load_json(text_file)
        except (OSError, ValueError) as e:
            print(f"Could not load text file {text_file}: {e}")
            continue

        question_text = sample["question"]
        answer_text = sample["answer-text"]

        # Extract wiki links from nodes
        wikis = [
            node[2]
            for node in sample.get("answer-node", [])
            if node[2] is not None and node[2].startswith("/wiki")
        ]

        # One text passage per wiki link (empty when the link has no entry),
        # separated by newlines.
        wiki_text = "\n".join(text_data.get(wiki_link, "") for wiki_link in wikis)

        # Create a DataFrame from the relevant rows only
        try:
            # Header entries and list-valued cells keep just their first element
            flattened_header = [col[0] for col in table_data["header"]]
            flattened_data = [
                [cell[0] if isinstance(cell, list) else cell for cell in row]
                for row in relevant_rows
            ]
            df = pd.DataFrame(flattened_data, columns=flattened_header)
        except (KeyError, IndexError, TypeError, ValueError) as e:
            print(f"Error creating DataFrame for {table_id}: {e}")
            continue

        data_list.append({
            "table_id": table_id,
            "question": question_text,
            "answer": answer_text,
            "table": df,
            "wiki": wiki_text,
            "title": table_data["title"],
            "intro": table_data["intro"]
        })

    return data_list

# Load the development set's standard answers once at module level;
# read_data looks questions up in this list by table_id.
questions_path = "dev"
with open(questions_path, 'r') as dev_file:
    questions_data = json.loads(dev_file.read())

def df_format(data):
    """Render a DataFrame as pipe-separated text, one line per row.

    The first line holds the column labels, every following line one row,
    with values joined by " | " and each line terminated by a newline.

    Rows are read positionally, so repeated column labels (which label-based
    access would return as a whole Series per cell) are rendered correctly,
    and labels are passed through ``str`` so non-string labels no longer make
    the whole table collapse to an empty string. Values are stringified with
    their per-column type rather than being upcast to a common row dtype.

    Args:
        data: the pandas DataFrame to render.

    Returns:
        The formatted table, or "" when ``data`` is not a DataFrame.
    """
    if not isinstance(data, pd.DataFrame):
        return ""

    lines = [" | ".join(str(col) for col in data.columns)]
    # name=None yields plain tuples in column order, independent of labels.
    for values in data.itertuples(index=False, name=None):
        lines.append(" | ".join(str(value) for value in values))
    return "\n".join(lines) + "\n"


In [5]:
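# Load the model or API client here. This cell must define `query(payload)`:
# later cells call it as query({'inputs': prompt}) and read
# response[0]['generated_text'] from the value it returns.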

In [22]:
# Disabled cell: the code below is wrapped in a string literal, so none of it
# executes. It would open a timestamped JSON-lines file as `fw` under outputs/
# and write the selected demonstration as the first line.
# While it stays disabled, this cell does not create `fw`; any cell that
# writes to `fw` has to open the file itself.
'''
now = datetime.now()
dt_string = now.strftime("%d_%H_%M")
fw = open(f"outputs/response_s{args.start}_e{args.end}_{args.option}_{args.model}_{dt_string}.json", "w",)
tmp = {"demonstration": demonstration[args.option]}
fw.write(json.dumps(tmp) + "\n")
'''

'\nnow = datetime.now()\ndt_string = now.strftime("%d_%H_%M")\nfw = open(f"outputs/response_s{args.start}_e{args.end}_{args.option}_{args.model}_{dt_string}.json", "w",)\ntmp = {"demonstration": demonstration[args.option]}\nfw.write(json.dumps(tmp) + "\n")\n'

In [6]:
# Build the per-sample records (relevant table rows, linked text, title and
# intro) that the prompt-building cells iterate over.
data_list = read_data(args)

100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 14.30it/s]


In [7]:
import os
from tqdm import tqdm
from datetime import datetime
import json

# Ensure the output directory exists
output_dir = "outputs/summary"
os.makedirs(output_dir, exist_ok=True)

# `args.option` can name a demonstration that is not registered in
# `demonstration`; fall back to an empty prefix (and say so) instead of
# failing with KeyError on the first iteration.
if args.option not in demonstration:
    print(f"No demonstration registered for option '{args.option}'; using an empty prefix.")
demo_prefix = demonstration.get(args.option, "")

failed_ids = []  # table ids whose request did not return generated text

for entry in tqdm(data_list):
    #### Formalizing the k-shot demonstration. #####
    prompt = demo_prefix + '\n\n'
    prompt += f'Read the table and text regarding "{entry["title"]}" and create a summary.\n\n'
    prompt += df_format(entry['table']) + '\n'

    if entry['wiki']:
        prompt += "Text" + '\n' + entry['wiki'] + '\n\n'

    prompt += 'Summarize the given table and text. ' + '\nSummary:'

    response_raw = query({'inputs': prompt})
    try:
        generated = response_raw[0].get('generated_text', '')
    except (KeyError, IndexError, TypeError, AttributeError):
        # A failed request comes back as something other than a non-empty
        # list of dicts (e.g. {'error': 'Rate limit reached...'}). Keep the
        # empty summary for this entry but make the failure visible.
        print(f"Unexpected API response for {entry['table_id']}: {response_raw}")
        failed_ids.append(entry['table_id'])
        generated = ''

    # Keep the text after the prompt's "Summary:" marker; when the reply does
    # not echo the prompt (no marker), treat the whole reply as the completion.
    parts = generated.split('\nSummary:')
    completion = parts[1] if len(parts) > 1 else generated
    response = completion.split('Reasoning process')[0].strip()

    # Only the first line of the completion is kept as the summary.
    response = response.split('\n')[0].strip()

    output_file_path = os.path.join(output_dir, f"{entry['table_id']}.txt")

    with open(output_file_path, "w", encoding="utf-8") as fw:
        fw.write(f"Prompt:\n{prompt}\n\n")
        fw.write(f"Summary: {response}\n")

if failed_ids:
    print(f"{len(failed_ids)} of {len(data_list)} requests returned no generated text: {failed_ids}")


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:07<00:00,  7.78s/it]


In [15]:
# `args.option` can name a demonstration that is not registered in
# `demonstration`; fall back to an empty prefix (and say so) instead of
# failing with KeyError.
if args.option not in demonstration:
    print(f"No demonstration registered for option '{args.option}'; using an empty prefix.")
demo_prefix = demonstration.get(args.option, "")

# Open the JSON-lines output file here so this cell does not depend on a
# file handle left behind by another cell. The first line records the
# demonstration used; every following line is one summary record.
os.makedirs("outputs", exist_ok=True)
dt_string = datetime.now().strftime("%d_%H_%M")
output_path = f"outputs/response_s{args.start}_e{args.end}_{args.option}_{args.model}_{dt_string}.json"

failed_ids = []  # table ids whose request did not return generated text

with open(output_path, "w", encoding="utf-8") as fw:
    fw.write(json.dumps({"demonstration": demo_prefix}) + "\n")

    for entry in tqdm(data_list):
        question = entry['question']

        #### Formalizing the k-shot demonstration. #####
        prompt = demo_prefix + '\n\n'
        prompt += f'Read the table and text regarding "{entry["title"]}" and create a summary.\n\n'
        #prompt += f"This is the introduction of the table:" + '\n' + entry['intro'] + '\n\n'
        prompt += df_format(entry['table']) + '\n'

        if entry['wiki']:
            prompt += "Text" + '\n' + entry['wiki'] + '\n\n'

        prompt += 'Summarize the given table and text. ' + '\nSummary:'

        response_raw = query({'inputs': prompt})
        try:
            generated = response_raw[0].get('generated_text', '')
        except (KeyError, IndexError, TypeError, AttributeError):
            # A failed request comes back as something other than a non-empty
            # list of dicts (e.g. {'error': 'Rate limit reached...'}). Keep the
            # empty summary for this entry but make the failure visible.
            print(f"Unexpected API response for {entry['table_id']}: {response_raw}")
            failed_ids.append(entry['table_id'])
            generated = ''

        # Keep the text after the prompt's "Summary:" marker; when the reply
        # does not echo the prompt (no marker), use the whole reply.
        parts = generated.split('\nSummary:')
        completion = parts[1] if len(parts) > 1 else generated
        response = completion.split('Reasoning process')[0].strip()

        # Only the first line of the completion is kept as the summary.
        response = response.split('\n')[0].strip()

        tmp = {
            "question": question,
            "summary": response,
            "table_id": entry["table_id"],
        }

        fw.write(json.dumps(tmp) + "\n")

if failed_ids:
    print(f"{len(failed_ids)} of {len(data_list)} requests returned no generated text: {failed_ids}")

100%|████████████████████████████████████████████████████████████████████████████████| 200/200 [09:06<00:00,  2.73s/it]


In [13]:
# Debug aid: show the raw API reply left in `response_raw` by the most recent
# request of the summary loop.
print(response_raw)

{'error': 'Rate limit reached. You reached PRO hourly usage limit. Use Inference Endpoints (dedicated) to scale your endpoint.'}
