## Imports

In [44]:
import pandas as pd
import os
from main import list_files_in_folder
from dotenv import load_dotenv
from dataClass import DataTable
from langchain_mistralai import ChatMistralAI
import json
from tqdm import tqdm
from utils import get_table_str, candidates_as_str, build_prompt
from prompts import generate_CEA_prompt_with_t_desc
import time

## LLMs

In [20]:
mistral_api_key = os.getenv("MISTRAL_API_KEY")
model_22 = "open-mixtral-8x22b"
model_7 = "open-mixtral-8x7b"
llm_22 = ChatMistralAI(model=model_22, temperature=0, api_key=mistral_api_key)
llm_7 = ChatMistralAI(model=model_7, temperature=0, api_key=mistral_api_key)

## Data

In [36]:
# This json contains for each cell to be annotated, a list of possible candidates
# retrieved with LamAPI but without the right candidate.
with open('nogit/HardTablesR1_Valid_CEA_ER_without_gt.json') as f:
    data = json.load(f) 
with open('nogit/table_descriptionsHTV.json') as f:
    table_descriptions = json.load(f) 

gt_path = 'data/HardTablesR1/DataSets/HardTablesR1/Valid/gt/cea_gt.csv'
gt = pd.read_csv(gt_path, header=None)

print(f"Number of Tables: {len(data)}")

Number of Tables: 200


In [37]:
ncells = 0
for k, v in data.items():
    ncells += len(v)
print(f"Number of cells: {ncells}")

if ncells != len(gt):
    print("Cells in ground_truth don't match cells in test set.")

Number of cells: 1406


In [43]:
table_name = None
export = {}
for i, (name, r, c, l)  in tqdm(gt.iterrows()):
    print(i, name, c, r, l)
    
    if name != table_name:
        table_name = name
        table = DataTable(f"data/HardTablesR1/DataSets/HardTablesR1/Valid/tables/{name}.csv")
        table.t_desc = table_descriptions[table_name]
        export[name] = {}
        table_as_str = get_table_str(table.data)

    
    print(f"\nTable:\n {table.data}")
    print(f"\nTable Description:\n {table.t_desc}")
    
    print(data[name].keys())
    
    
    # Perform CEA (we want all NILs):
    cell_content = data[name][str((r, c))]['cell']
    prompt = generate_CEA_prompt_with_t_desc(table.data, cell_content, candidates_as_str(data[name][str((r, c))]['retrieved_list']), table.t_desc)
    #print(f"\nPrompt:\n{prompt}\n\n")
    out = llm_7.invoke(prompt)
    time.sleep(2)
    # print(out.content)
    y_true.append(target_id)
    y_pred.append(out.content)
    index.append(i)
    export[name][str((r, c))] = {
        'cell': cell_content,
        'table_desc': table.t_desc,
        'cea_prompt': prompt,
        'cea_model': llm_7.model,
        'model_out': out.content
    }

0it [00:00, ?it/s]

0 NQK7B1JD 0 1 http://www.wikidata.org/entity/Q7996268

Table:
                    col0  col1   col2
0      lincoln township   209  42.56
1  stony creek township   265  28.28
2     hartford township   257  24.08

Table Description:
 {'col0': "Column 0 contains names of townships, which are small divisions of a county responsible for providing local services. The townships in this column are from an unknown geographical region in the United States, as they follow the common naming convention of using 'township' in their name.", 'col1': 'Column 1 contains integer values that likely represent population counts for each township. These values range from 209 to 265, indicating small to moderately-sized townships.', 'col2': 'Column 2 contains decimal values that likely represent the geographical area of each township in square miles. These values range from 24.08 to 42.56, indicating that the townships are relatively small in size.'}
dict_keys(['(1, 0)', '(2, 0)', '(3, 0)'])





In [34]:
with open('nogit/table_descriptionsHTV.json', 'w') as f:
    json.dump(table_descriptions, f)

In [46]:
tables_path = 'data/SemTab2020_Table_GT_Target/Round1/tables'
tables = list_files_in_folder(tables_path)
len(tables)

34294