In [1]:
import json
import os
import xml.etree.ElementTree as ET
from dotenv import load_dotenv


def load_finqa_sample(file_path, limit=None):
    """
    Loads FinQA samples from a JSON file.

    Args:
        file_path: Path to a UTF-8 encoded JSON file.
        limit: Optional maximum number of leading samples to keep. When None
            (the default) the parsed content is returned unchanged. Truncation
            slices the parsed content, so it requires the top-level JSON value
            to be an array.

    Returns:
        The parsed JSON content, or only its first `limit` entries when
        `limit` is given.

    Raises:
        ValueError: If `limit` is negative.
    """
    if limit is not None and limit < 0:
        raise ValueError("limit must be non-negative")

    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Truncation is opt-in: by default every sample in the file is returned.
    if limit is None:
        return data
    return data[:limit]


# The training split location can be overridden with the FINQA_TRAIN_PATH
# environment variable; the default is the original path on the research machine.
dataset = load_finqa_sample(
    os.environ.get(
        "FINQA_TRAIN_PATH",
        "/root/hsin_research/FinQA-main/dataset/train.json",
    )
)

# Quick sanity check: how many samples were read, and what the first question is.
print(f"Loaded {len(dataset)} samples for testing.")
print("Sample Question 1:", dataset[0]['qa']['question'])

Loaded 6251 samples for testing.
Sample Question 1: what is the the interest expense in 2009?


In [2]:
# Show which fields the first sample's 'qa' record contains.
dataset[0]['qa'].keys()

dict_keys(['question', 'answer', 'explanation', 'ann_table_rows', 'ann_text_rows', 'steps', 'program', 'gold_inds', 'exe_ans', 'tfidftopn', 'program_re', 'model_input'])

In [9]:
# print(dataset[0].keys())
# print(len(dataset[0]['pre_text']))
# print(dataset[0]['pre_text'][0])
# print(len(dataset[0]['post_text']))
# print(len(dataset[0]['table']))

# Gather the 'program' string stored under each sample's 'qa' record.
data = [sample['qa']['program'] for sample in dataset]



In [10]:
import re

def curate_functions(string_list, existing_set=None):
    """
    Collects the distinct (function name, argument count) pairs found in program strings.

    Args:
        string_list: An iterable of strings, each holding one or more calls
            written as name(arg, ...).
        existing_set: An optional set to extend. It is updated in place; a new
            set is created when it is omitted.

    Returns:
        The set of (name, arity) tuples. This is existing_set itself when one
        was supplied.
    """
    signatures = set() if existing_set is None else existing_set

    for program in string_list:
        # Each match is a word immediately followed by a parenthesised,
        # non-empty span that runs up to the first closing parenthesis.
        for call in re.findall(r'\w+\([^)]+\)', program):
            pieces = call.split('(')
            name = pieces[0].strip()

            # Argument text sits between the first and second '(' (or runs to
            # the end of the match), with trailing ')' characters removed.
            raw_args = pieces[1].rstrip(')')

            # Only comma-separated pieces that are not blank count as arguments.
            arity = sum(1 for piece in raw_args.split(',') if piece.strip())

            # Tuples make the set deduplicate on the full signature.
            signatures.add((name, arity))

    return signatures


# Collect the distinct (function name, argument count) pairs used across all program strings in `data`.
function_set = curate_functions(data)

In [11]:
# Display the collected (name, argument count) pairs.
function_set

{('add', 2),
 ('divide', 2),
 ('exp', 2),
 ('greater', 2),
 ('multiply', 2),
 ('subtract', 2),
 ('table_average', 2),
 ('table_max', 2),
 ('table_min', 2),
 ('table_sum', 2)}

In [None]:
import math
from typing import List, Union

def add(a: float, b: float) -> float:
    """
    Adds two numbers together.

    Args:
        a: The left-hand addend.
        b: The right-hand addend.

    Returns:
        The value of a + b.
    """
    total = a + b
    return total

def subtract(a: float, b: float) -> float:
    """
    Subtracts the second number from the first.

    Args:
        a: The minuend (the value being reduced).
        b: The subtrahend (the amount taken away).

    Returns:
        The value of a - b.
    """
    difference = a - b
    return difference

def multiply(a: float, b: float) -> float:
    """
    Multiplies two numbers.

    Args:
        a: The left-hand factor.
        b: The right-hand factor.

    Returns:
        The value of a * b.
    """
    product = a * b
    return product

def divide(a: float, b: float) -> float:
    """
    Divides the first number by the second, tolerating a zero divisor.

    Args:
        a: The numerator.
        b: The denominator.

    Returns:
        The value of a / b, or 0.0 when b equals zero so that the call does
        not raise.
    """
    # A zero denominator yields the 0.0 fallback instead of an exception.
    return 0.0 if b == 0 else a / b

def exp(a: float, b: float) -> float:
    """
    Raises a number to a power using math.pow.

    Args:
        a: The base value.
        b: The exponent applied to the base.

    Returns:
        The value of a raised to the power b.
    """
    power = math.pow(a, b)
    return power

def greater(a: float, b: float) -> bool:
    """
    Checks whether the first number is strictly larger than the second.

    Args:
        a: The value on the left of the comparison.
        b: The value on the right of the comparison.

    Returns:
        The result of a > b: True when a exceeds b, False when it is equal or smaller.
    """
    is_larger = a > b
    return is_larger

def table_sum(values: List[float]) -> float:
    """
    Adds up a list of numerical values taken from a table.

    Args:
        values: The numbers to total.

    Returns:
        The result of the built-in sum over values.
    """
    total = sum(values)
    return total

def table_average(values: List[float]) -> float:
    """
    Computes the arithmetic mean of a list of numerical values taken from a table.

    Args:
        values: The numbers to average.

    Returns:
        The sum of values divided by their count, or 0.0 when values is empty.
    """
    # An empty input has no mean; fall back to 0.0 rather than dividing by zero.
    return sum(values) / len(values) if values else 0.0

def table_max(values: List[float]) -> float:
    """
    Finds the largest entry in a list of numerical values taken from a table.

    Args:
        values: The numbers to search.

    Returns:
        The largest value in values, or 0.0 when values is empty.
    """
    # The 0.0 fallback keeps an empty input from raising inside max().
    return max(values) if values else 0.0

def table_min(values: List[float]) -> float:
    """
    Finds the smallest entry in a list of numerical values taken from a table.

    Args:
        values: The numbers to search.

    Returns:
        The smallest value in values, or 0.0 when values is empty.
    """
    # The 0.0 fallback keeps an empty input from raising inside min().
    return min(values) if values else 0.0