## Imports

In [1]:
import os
from dotenv import load_dotenv
from dataClass import DataTable
from langchain_mistralai import ChatMistralAI
import json

In [2]:
def list_files_in_folder(folder_path):
    """Return the full paths of all regular files directly inside folder_path."""
    if not os.path.isdir(folder_path):
        # Fail loudly instead of returning an error string, which callers
        # (len(tables), tables[1]) would silently treat as a list of paths
        raise FileNotFoundError(f"'{folder_path}' is not an existing folder.")

    # Keep only regular files and skip sub-directories
    return [
        os.path.join(folder_path, entry)
        for entry in os.listdir(folder_path)
        if os.path.isfile(os.path.join(folder_path, entry))
    ]

## Load Data

In [3]:
load_dotenv()
mistral_api_key = os.getenv("MISTRAL_API_KEY")
lamapi_key = os.getenv("LAMAPI_KEY")
model = "open-mixtral-8x7b"
llm = ChatMistralAI(model=model, temperature=0, api_key=mistral_api_key)
tables_path = 'data/HardTablesR1/DataSets/HardTablesR1/Valid/tables'
tables = list_files_in_folder(tables_path)
print(f"\nNumber of tables: {len(tables)}\n")


Number of tables: 200
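`os.getenv` returns `None` when a variable is missing, so a typo in `.env` only surfaces later as an authentication error from the API client. A minimal fail-fast sketch; `require_env` is a hypothetical helper, not part of this project's code:

```python
import os


def require_env(name: str) -> str:
    """Return the value of environment variable `name`, or fail with a clear message."""
    value = os.getenv(name)
    if not value:
        # Catches both a missing variable and an empty one (e.g. `MISTRAL_API_KEY=`)
        raise RuntimeError(f"Environment variable {name!r} is not set; check your .env file")
    return value


# Usage, after load_dotenv():
# mistral_api_key = require_env("MISTRAL_API_KEY")
```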



In [4]:
t = DataTable(tables[1])

In [5]:
print(f"Table name: {t.name}")
print(f"Table shape: {t.data.shape}")

Table name: DFRU6OJ0
Table shape: (6, 2)


In [6]:
t.data

|   | col0 | col1 |
|---|------|------|
| 0 | predlitz | 157 |
| 1 | stadtbergen | 164 |
| 2 | michelstetten | 157 |
| 3 | ahorn | 169 |
| 4 | stillfried | 169 |
| 5 | oberweg | 164 |


## Generate Table Description

In [7]:
from prompts import generate_tableDesc_prompt

prompt = generate_tableDesc_prompt(t.data)
print(prompt)

For every column write a short description, in the format: {"0": "description of column 0", "1": "description of column 1", ..., "n": "description of column n"}. Cells of the same column represent must have be of the same type (all football teams, all actors, etc.)

Table:
| col0 | col1 |
|------|------|
| predlitz | 157 |
| stadtbergen | 164 |
| michelstetten | 157 |
| ahorn | 169 |
| stillfried | 169 |
| oberweg | 164 |
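The printed prompt embeds the table as a pipe-delimited table: a header row, a separator made of dashes two characters wider than each column name, and one line per data row. The code of `prompts.py` is not shown here, so the following is only a sketch of a function that reproduces that layout; `to_prompt_table` is a hypothetical name.

```python
def to_prompt_table(columns, rows):
    """Render a header and row tuples in the pipe-delimited layout seen in the printed prompt."""
    columns = [str(c) for c in columns]
    header = "| " + " | ".join(columns) + " |"
    # One run of dashes per column, as wide as the header cell including its padding spaces
    separator = "|" + "|".join("-" * (len(c) + 2) for c in columns) + "|"
    body = ["| " + " | ".join(str(v) for v in row) + " |" for row in rows]
    return "\n".join([header, separator, *body])


print(to_prompt_table(["col0", "col1"], [("predlitz", 157), ("stadtbergen", 164)]))
# With a DataFrame: to_prompt_table(t.data.columns, t.data.itertuples(index=False))
```

Taking plain column names and row tuples keeps the sketch independent of pandas while still accepting `DataFrame.itertuples(index=False)` directly.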





In [8]:
t.generate_t_description(llm)
print(t.t_desc)

{'0': 'Column 0 contains names of places, which could be towns or villages.', '1': 'Column 1 contains numbers, which could represent various metrics associated with the corresponding place in column 0, such as population size, elevation, or some other relevant statistic.'}
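The printed `t.t_desc` is a Python dict rather than raw model text, so `generate_t_description` presumably parses the JSON object the prompt asks for; its implementation is not shown. A minimal sketch of such parsing, assuming the reply contains a single JSON object that may be surrounded by extra prose; `parse_json_object` is hypothetical.

```python
import json
import re


def parse_json_object(reply: str) -> dict:
    """Extract and parse the span from the first '{' to the last '}' in an LLM reply."""
    # Greedy match across newlines, so nested braces stay inside the captured span
    match = re.search(r"\{.*\}", reply, flags=re.DOTALL)
    if match is None:
        raise ValueError("No JSON object found in the model reply")
    return json.loads(match.group(0))


reply = 'Here you go:\n{"0": "names of places", "1": "numeric values"}'
desc = parse_json_object(reply)
print(desc["0"])
```

If the model returns two separate JSON objects, the greedy span covers both and `json.loads` raises, which is preferable to silently keeping only one of them.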


## Column Classification

In [9]:
from prompts import generate_NER_prompt

prompt = generate_NER_prompt(t.data)
print(prompt)

You are to classify each column in a given table as either a Named Entity Column (NEC) or a Literal Column (LC).

Definitions:
- Named Entity Column (NEC): Columns that contain names of people, organizations, locations, or other proper nouns.
- Literal Column (LC): Columns that contain numerical values, dates, measurements, or other literal values.

Examples:
- Named Entity Column (NEC) Examples:
  - Column with values: ["John Doe", "Jane Smith", "Company XYZ", "Paris"]
  - Column with values: ["Microsoft", "Apple", "Google", "Amazon"]

- Literal Column (LC) Examples:
  - Column with values: [34, 56, 78, 23]
  - Column with values: ["2021-01-01", "2022-05-12", "2023-08-23"]
  - Column with values: [5.6, 3.4, 2.8, 4.5]

Table for Classification:
| col0 | col1 |
|------|------|
| predlitz | 157 |
| stadtbergen | 164 |
| michelstetten | 157 |
| ahorn | 169 |
| stillfried | 169 |
| oberweg | 164 |


Classification Request:
Based on the above definitions and examples, please classify each c

In [10]:
t.generate_ner_labels(llm)
print(t.ner)

{'0': 'NEC', '1': 'LC'}
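As a cheap sanity check on the LLM labels, a rule-based baseline can follow the prompt's definitions literally: a column is a Literal Column if every non-empty cell parses as a number or an ISO date, otherwise a Named Entity Column. This heuristic is swapped in for illustration only and is not the method used in this notebook; `classify_column` is hypothetical and will mislabel, for example, free-text or identifier columns.

```python
from datetime import date


def _is_literal(value) -> bool:
    """True if the cell looks like a number or an ISO-8601 date (YYYY-MM-DD)."""
    text = str(value).strip()
    try:
        float(text)
        return True
    except ValueError:
        pass
    try:
        date.fromisoformat(text)
        return True
    except ValueError:
        return False


def classify_column(values) -> str:
    """Label a column 'LC' if it has non-empty cells and all of them are literals, else 'NEC'."""
    cells = [v for v in values if str(v).strip()]
    return "LC" if cells and all(_is_literal(v) for v in cells) else "NEC"


table = {"0": ["predlitz", "stadtbergen", "ahorn"], "1": [157, 164, 169]}
print({col: classify_column(vals) for col, vals in table.items()})
```

On this table the baseline agrees with the LLM output above; disagreements on other tables are the interesting cases to inspect.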


## Cell Entity Annotation

In [11]:
from prompts import generate_CEA_prompt_with_t_desc
from dataClass import LamAPI

cell_content = t.data.iloc[0,0] # Show prompt only for the first cell
ER = LamAPI(cell_content)
prompt = generate_CEA_prompt_with_t_desc(t.data, cell_content, ER, t.t_desc)
print(prompt)

Based on the table, table description, cell content and retrieved entities and their types you have to associate the cell content to one of the retrieved entities 

Table:
| col0 | col1 |
|------|------|
| predlitz | 157 |
| stadtbergen | 164 |
| michelstetten | 157 |
| ahorn | 169 |
| stillfried | 169 |
| oberweg | 164 |


Table Description: {'0': 'Column 0 contains names of places, which could be towns or villages.', '1': 'Column 1 contains numbers, which could represent various metrics associated with the corresponding place in column 0, such as population size, elevation, or some other relevant statistic.'}

Cell Content: predlitz

Retrieved Entities and their types: {id: Q28020967, Predlitz, types: { Wikimedia disambiguation page}}, {id: Q42344347, Predlitz, types: { cadastral municipality of Austria}, { locality}}, {id: Q82019201, Preilitz, types: { locality}}, {id: Q29572069, Pr\u00f6dlitz, types: { Wikimedia disambiguation page}}, {id: Q37777803, Redlitz, types: { locality}, { 

In [13]:
t.generate_cea_annotatons(llm)
print(t.cea)

{'(0, 0)': {'id': 'Q42344347', 'llm_output': '[[[Q42344347]]]'}, '(1, 0)': {'id': 'Q503300', 'llm_output': '[[[Q503300]]]'}, '(2, 0)': {'id': 'Q33103647', 'llm_output': '[[[Q33103647]]]'}, '(3, 0)': {'id': 'Q29015792', 'llm_output': '[[[Q29015792]]]'}, '(4, 0)': {'id': 'Q2349606', 'llm_output': '[[[Q2349606]]]'}, '(5, 0)': {'id': 'Q696108', 'llm_output': '[[[Q696108]]]'}}
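Each `llm_output` wraps the chosen Wikidata ID in triple square brackets (`[[[Q42344347]]]`), while `id` holds the bare ID. How `generate_cea_annotatons` extracts it is not shown; a minimal regex sketch, assuming the bracket convention above, with `extract_qid` as a hypothetical name:

```python
import re

# Matches a Wikidata item ID wrapped in triple square brackets, e.g. "[[[Q42344347]]]"
QID_PATTERN = re.compile(r"\[\[\[(Q\d+)\]\]\]")


def extract_qid(llm_output: str):
    """Return the first bracketed QID in the model output, or None if there is none."""
    match = QID_PATTERN.search(llm_output)
    return match.group(1) if match else None


print(extract_qid("[[[Q42344347]]]"))
print(extract_qid("The answer is [[[Q503300]]]."))
print(extract_qid("no entity matches"))
```

Returning `None` instead of raising lets the caller leave a cell unannotated when the model ignores the output format.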


## Evaluation

In [14]:
import pandas as pd
gt = pd.read_csv('data/HardTablesR1/DataSets/HardTablesR1/Valid/gt/cea_gt.csv', header=None)  

print(t.name)

DFRU6OJ0


In [18]:
filtered_df = gt[gt[0] == t.name]  # t.name is 'DFRU6OJ0'
print(filtered_df)

            0  1  2                                         3
595  DFRU6OJ0  1  0  http://www.wikidata.org/entity/Q42344347
596  DFRU6OJ0  2  0  http://www.wikidata.org/entity/Q41709246
597  DFRU6OJ0  3  0  http://www.wikidata.org/entity/Q33103647
598  DFRU6OJ0  4  0  http://www.wikidata.org/entity/Q29015792
599  DFRU6OJ0  5  0   http://www.wikidata.org/entity/Q2349606
600  DFRU6OJ0  6  0   http://www.wikidata.org/entity/Q2011429
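The ground truth stores full entity URIs, whereas `t.cea` stores bare QIDs, and its row numbers run from 1 to 6 while the predictions use 0 to 5 (ground-truth row 1 holds Q42344347, the entity predicted for `predlitz` at `(0, 0)`). A minimal sketch that normalises ground-truth rows into a `{(row, col): QID}` lookup, assuming that one-row offset; `gt_to_lookup` is hypothetical.

```python
def gt_to_lookup(gt_rows, row_offset=1):
    """Map (table, row, col, uri) tuples to {(row - row_offset, col): QID}."""
    lookup = {}
    for _table, row, col, uri in gt_rows:
        # "http://www.wikidata.org/entity/Q42344347" -> "Q42344347"
        qid = uri.rsplit("/", 1)[-1]
        lookup[(int(row) - row_offset, int(col))] = qid
    return lookup


gt_rows = [
    ("DFRU6OJ0", 1, 0, "http://www.wikidata.org/entity/Q42344347"),
    ("DFRU6OJ0", 2, 0, "http://www.wikidata.org/entity/Q41709246"),
]
print(gt_to_lookup(gt_rows))
# With pandas: gt_to_lookup(filtered_df.itertuples(index=False))
```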


In [19]:
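pred = []
for cell, annotation in t.cea.items():
    # Collect (cell, QID) pairs for evaluation and show them
    pred.append((cell, annotation['id']))
    print(cell, annotation['id'])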

(0, 0) Q42344347
(1, 0) Q503300
(2, 0) Q33103647
(3, 0) Q29015792
(4, 0) Q2349606
(5, 0) Q696108
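Comparing the two listings requires aligning rows: ground-truth row `r` corresponds to prediction row `r - 1`. A minimal scoring sketch, with the values copied from the printed predictions and ground truth above; `cea_accuracy` is hypothetical and computes plain accuracy over ground-truth cells.

```python
import ast

# Predictions as printed above, keyed by the string form used in t.cea
predictions = {
    "(0, 0)": "Q42344347", "(1, 0)": "Q503300", "(2, 0)": "Q33103647",
    "(3, 0)": "Q29015792", "(4, 0)": "Q2349606", "(5, 0)": "Q696108",
}
# Ground truth as printed above, keyed by (row, col) as they appear in cea_gt.csv
ground_truth = {
    (1, 0): "Q42344347", (2, 0): "Q41709246", (3, 0): "Q33103647",
    (4, 0): "Q29015792", (5, 0): "Q2349606", (6, 0): "Q2011429",
}


def cea_accuracy(predictions, ground_truth, row_offset=1):
    """Fraction of ground-truth cells matched, shifting prediction rows by row_offset."""
    correct = 0
    for key, qid in predictions.items():
        row, col = ast.literal_eval(key)  # "(0, 0)" -> (0, 0)
        if ground_truth.get((row + row_offset, col)) == qid:
            correct += 1
    return correct / len(ground_truth)


print(f"{cea_accuracy(predictions, ground_truth):.3f}")
```

On this table 4 of the 6 cells match; the predictions for `stadtbergen` (row 1) and `oberweg` (row 5) differ from the ground truth. With exactly one prediction per ground-truth cell, as here, precision, recall and this accuracy all coincide.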
