In [None]:
import numpy as np
import pandas as pd
from ipywidgets import FloatProgress
import re
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn import metrics
import json
from langchain_openai import ChatOpenAI
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
import os
from tqdm import tqdm
import ast
import re
import math
from collections import Counter, defaultdict
import re
from rank_bm25 import BM25Okapi
import jieba

#### Configure a usable LLM API endpoint (base URL, model name and API key) in the cell below before running the notebook

In [None]:
# LLM endpoint configuration. Each value is taken from the environment when set, so
# credentials do not need to be hard-coded (and committed) in the notebook; the
# 'change me' placeholders remain as fallbacks to edit by hand.
custom_base_url = os.environ.get("OPENAI_BASE_URL", "http://change me")
model_name = os.environ.get("LLM_MODEL_NAME", 'change me')
api_key = os.environ.get("OPENAI_API_KEY", 'change me')
llm = ChatOpenAI(model_name=model_name, openai_api_key=api_key, base_url=custom_base_url)

# ICL for SLU prediction

In [None]:
# Collect the distinct intent labels of the training set, keeping first-seen order.
json_file = '../dataset/train.json'
with open(json_file, 'r', encoding='utf-8') as f:
    datas = json.load(f)
all_intent = list(dict.fromkeys(sample['intent'] for sample in datas))

In [None]:
# BM25 index over the known intent labels; extract_slu() uses it to snap the
# free-form intent string produced by the LLM onto the closest valid label.
corpus = all_intent
tokenized_corpus = [list(jieba.cut_for_search(label)) for label in corpus]
# Build the index
bm25 = BM25Okapi(tokenized_corpus)

In [None]:
# Prompt template for the ICL intent/slot prediction step (rendered with {"query": ...}).
template_path = './prompts/SLU_prompt.json'
with open(template_path, mode='r', encoding='utf-8') as template_fp:
    template = json.load(template_fp)

In [None]:
def extract_intent_and_slots(text):
    """Pull the intent string and the slot list out of a normalized LLM answer.

    Matching is case-sensitive on the word ``intent`` and on slot markers such as
    ``<slots>``, ``(slots)`` or ``<_slots_>``; extract_slu() lower-cases the answer
    and rewrites ``intent:`` / ``slots:`` into tag form before calling this.

    Args:
        text: the normalized answer string.

    Returns:
        (intent, slots_list). ``intent`` is the stripped text between the intent
        marker and the slots marker, or None if that pattern is absent.
        ``slots_list`` is the bracketed literal after the slots marker parsed with
        ``ast.literal_eval``, or None if it is absent or cannot be parsed (a
        message is printed in that case).
    """
    intent_regex = r"intent[:>)]?\s*(.*?)(?=[<\(]_?slots_?[>\)])"
    slots_regex = r"[<\(]_?slots_?[>\)]\s*(\[.*?\])"

    found_intent = re.search(intent_regex, text)
    intent = found_intent.group(1).strip() if found_intent else None

    found_slots = re.search(slots_regex, text)
    if found_slots is None:
        return intent, None
    try:
        # literal_eval only accepts Python literals, so model output is never executed.
        return intent, ast.literal_eval(found_slots.group(1))
    except (SyntaxError, ValueError) as e:
        print(f"Error parsing slots: {e}")
        return intent, None

def slots_to_dict(slots_list):
    """Convert ``['name:value', ...]`` slot strings into a ``{name: value}`` dict.

    Only the first ':' separates name from value, so values that themselves
    contain ':' (e.g. a time such as '10:30') are kept intact instead of raising.
    Keys and values are whitespace-stripped; a later duplicate key overwrites an
    earlier one. A falsy ``slots_list`` (None or empty) yields an empty dict.

    Raises:
        ValueError: if an entry contains no ':' at all.
    """
    slots_dict = {}
    for slot in slots_list or []:
        key, value = slot.split(':', 1)
        slots_dict[key.strip()] = value.strip()
    return slots_dict



def extract_slu(text: str, bm25_index=None, candidates=None):
    """Parse a raw LLM SLU answer into ``(intent, slots)``.

    The answer is normalized first: lower-cased and stripped, '_' and newlines
    removed, ``intent:`` -> ``<intent>``, ``slots:`` / `` slots `` -> ``<slots>``,
    and the full-width ``」`` replaced by ``']``. extract_intent_and_slots() then
    parses it, and the free-form intent is snapped to the closest known label
    with BM25.

    Args:
        text: raw model output string.
        bm25_index: BM25 index to match the intent against; defaults to the
            global ``bm25`` at call time.
        candidates: documents indexed by ``bm25_index``; defaults to the global
            ``corpus`` at call time.
            NOTE(review): later cells rebind ``bm25``/``corpus`` to table captions,
            so pass the intent index explicitly if this is called after that.

    Returns:
        ``(matched_intent, slots)`` where ``matched_intent`` is the list returned by
        ``get_top_n(..., n=1)`` (the best-matching label) and ``slots`` is the
        parsed slot list; or ``('error', text)`` when the intent or the slots
        could not be parsed (the normalized text is printed in that case).
    """
    if bm25_index is None:
        bm25_index = bm25
    if candidates is None:
        candidates = corpus

    pred = text.lower().strip()
    pred = pred.replace('_', '').replace('\n', '')
    # Rewrite the colon spellings into the tag form the extraction regexes expect.
    pred = pred.replace('intent:', '<intent>')
    pred = pred.replace('slots:', '<slots>')
    pred = pred.replace(' slots ', '<slots>')
    pred = pred.replace('」', '\']')

    intent, slots = extract_intent_and_slots(pred)
    if intent is None or slots is None:
        print(f'This：{pred}')
        return 'error', text
    # Only tokenize once an intent string was actually found; tokenizing the
    # None returned for a missing intent is what used to fail here.
    matched_intent = bm25_index.get_top_n(list(jieba.cut_for_search(intent)), candidates, n=1)
    return matched_intent, slots

In [None]:
# Convert slots into keywords.
# BIO slot label -> Chinese slot name used in the "name:value" keyword strings.
_SLOT_LABEL_NAMES = {
    'city': '城市',
    'district': '区域',
    'development': '项目名称',
    'company': '企业名称',
    'year': '年份',
    'month': '月份',
    'land': '地块名称',
}


def restore_keywords_from_query(query, slots):
    """Rebuild ``"name:value"`` keyword strings from a query and its BIO slot tags.

    Args:
        query: the utterance; its characters are paired one-to-one with the tags.
        slots: BIO tags, either a list or a space-separated string. When the first
            tag is '[CLS]', the first and last tags are dropped.

    Returns:
        A list of strings such as '城市:北京', in query order. Spans whose label is
        not in ``_SLOT_LABEL_NAMES`` are omitted. An 'I-' tag that does not
        continue the current span's label (or an 'O' tag) closes the current span.
        An empty tag list yields [].
    """
    if isinstance(slots, str):
        slots = slots.split(' ')
    # Guard the [CLS] check so an empty tag list returns [] instead of IndexError.
    if slots and slots[0] == '[CLS]':
        slots = slots[1:-1]

    spans = []  # (text, label) for each completed entity span
    span_chars = []
    span_label = None
    for char, tag in zip(query, slots):
        if tag.startswith('B-'):
            if span_chars:
                spans.append((''.join(span_chars), span_label))
            span_chars = [char]
            span_label = tag[2:]
        elif tag.startswith('I-') and span_label == tag[2:]:
            span_chars.append(char)
        elif span_chars:
            spans.append((''.join(span_chars), span_label))
            span_chars = []
            span_label = None
    if span_chars:
        spans.append((''.join(span_chars), span_label))

    return [f'{_SLOT_LABEL_NAMES[label]}:{text}' for text, label in spans if label in _SLOT_LABEL_NAMES]

In [None]:

# Run the ICL SLU prompt over the test set. For each record, store the gold keyword
# strings next to the parsed intent/slot predictions, then save everything to disk.
json_file = '../dataset/test.json'
with open(json_file, 'r', encoding='utf-8') as f:
    datas = json.load(f)

# The prompt and chain do not depend on the record, so build them once.
prompt = ChatPromptTemplate.from_template(template)
chain = prompt | llm

for data in tqdm(datas):
    query = data['query']
    slots = data["slots"].split(',')
    data['true_slots_name'] = restore_keywords_from_query(query, slots)
    response = chain.invoke({"query": query})
    data['ICL_pred_intent'], data['ICL_pred_slots'] = extract_slu(response.content)

save_json_file = './results/test-with-ICL_pred_intent+slots.json'
# Make sure the results directory exists so the (expensive) predictions are not lost.
os.makedirs(os.path.dirname(save_json_file), exist_ok=True)
with open(save_json_file, 'w', encoding='utf-8') as file:
    json.dump(datas, file, ensure_ascii=False, indent=4)

### Calculate SLU prediction metrics

In [None]:
def metric_compute(trues: list, preds: list):
    """Micro-averaged precision / recall / F1 over paired gold and predicted labels.

    Each element of ``trues`` / ``preds`` is either a single label (str) or a
    collection of labels. A predicted label counts toward precision when it is in
    the paired gold labels; a gold label counts toward recall when it is in the
    paired predictions (membership test, so repeated labels are each counted).

    Returns:
        ``(P, R, F1)`` as floats. Each is 0.0 when its denominator is zero (no
        predictions, no gold labels, or P == R == 0) instead of raising
        ZeroDivisionError. When the two lists differ in length, the string
        'Input lengthes not equal!' is returned instead (kept for compatibility).
    """
    if len(trues) != len(preds):
        return 'Input lengthes not equal!'
    precision_hits = precision_all = 0
    recall_hits = recall_all = 0
    for true_label, pred_label in zip(trues, preds):
        if isinstance(true_label, str):
            true_label = [true_label]
        if isinstance(pred_label, str):
            pred_label = [pred_label]
        for pred in pred_label:
            if pred in true_label:
                precision_hits += 1
            precision_all += 1
        for true in true_label:
            if true in pred_label:
                recall_hits += 1
            recall_all += 1
    P = precision_hits / precision_all if precision_all else 0.0
    R = recall_hits / recall_all if recall_all else 0.0
    F1 = 2 * P * R / (P + R) if (P + R) else 0.0
    return P, R, F1

In [None]:
# Intent metrics: gold `intent` vs the saved ICL intent prediction.
datas_saved_file = './results/test-with-ICL_pred_intent+slots.json'
with open(datas_saved_file, 'r', encoding='utf-8') as f:
    datas = json.load(f)
trues = [data['intent'] for data in datas]
predicts = []
for data in datas:
    if 'ICL_pred_intent' not in data:
        # Report the record, then score it with a placeholder so trues/predicts stay
        # the same length (metric_compute rejects mismatched lengths). 'error' is the
        # sentinel extract_slu() uses for unparseable answers.
        print(data)
    predicts.append(data.get('ICL_pred_intent', 'error'))

print('ICL Intent prediction results')
metric_compute(trues, predicts)

In [None]:
# Slot metrics: gold "name:value" keywords vs the saved ICL slot prediction.
datas_saved_file = './results/test-with-ICL_pred_intent+slots.json'
with open(datas_saved_file, 'r', encoding='utf-8') as f:
    datas = json.load(f)
trues = [data['true_slots_name'] for data in datas]
predicts = []
for data in datas:
    if 'ICL_pred_slots' not in data:
        # Report the record, then score it as "no predicted slots" so trues/predicts
        # stay the same length (metric_compute rejects mismatched lengths).
        print(data)
    predicts.append(data.get('ICL_pred_slots', []))

print('ICL Slots prediction results')
metric_compute(trues, predicts)

# SR module

In [None]:
# Store the generated table names in a JSON file.
class PredictResultStorage:
    """Stage one prediction record at a time and append it to a JSON-array file.

    Fields are staged in ``current_data`` via the setters; ``save_data`` appends
    the staged dict to the JSON array stored in ``file_name`` (creating the file
    when it does not exist) and then clears the staging dict.
    """

    def __init__(self, file_name='testdata.json'):
        self.current_data = {}
        self.file_name = file_name

    def set_predict_tabel_name(self, predict_tabel_name):
        self.current_data["predict_tabel_name"] = predict_tabel_name

    def set_true_tabel_name(self, true_tabel_name):
        self.current_data["true_tabel_name"] = true_tabel_name

    def set_uuid(self, uuid):
        self.current_data["uuid"] = uuid

    def save_data(self):
        """Append ``current_data`` to the JSON array on disk, then reset it.

        An existing file must already contain a JSON array.
        """
        if os.path.exists(self.file_name):
            with open(self.file_name, 'r+', encoding='utf-8') as f:
                data = json.load(f)
                data.append(self.current_data)
                f.seek(0)
                json.dump(data, f, indent=4)
                # Drop leftover bytes of the old content: if the new serialization
                # is shorter, the file would otherwise end in invalid JSON.
                f.truncate()
        else:
            with open(self.file_name, 'w', encoding='utf-8') as f:
                json.dump([self.current_data], f, indent=4)
        # Clear current_data to prepare for storing the next piece of data.
        self.current_data = {}


In [None]:
# Read the per-intent table-retrieval (SR) prompt templates; the SR cell looks them up by intent.
prompt_file_name = './prompts/retrival_prompt.json'
with open(prompt_file_name, mode='r', encoding='utf-8') as prompt_fp:
    templates = json.load(prompt_fp)

In [None]:
# SR step: ask the LLM for candidate table captions using the per-intent retrieval prompt.
result_saved_file = './results/SLUTQA(ICL)-SR-module-result.json'
json_file = './results/test-with-ICL_pred_intent+slots.json'
with open(json_file, 'r', encoding='utf-8') as f:
    datas = json.load(f)
for i, data in enumerate(tqdm(datas)):
    try:
        query = data['query']
        # ICL_pred_intent is the list saved from extract_slu(); for unparseable
        # answers it is the string 'error', whose [0] is not a template key.
        intent = data['ICL_pred_intent'][0]
        slots = data['ICL_pred_slots']
        if intent not in templates:
            # Record an empty caption so the cleaning cell below still finds the key.
            print(f"No retrieval prompt for intent {intent!r} at record {i}; leaving caption empty.")
            data['predict_tabel_caption'] = ''
            continue
        prompt = ChatPromptTemplate.from_template(templates[intent])
        chain = prompt | llm
        response = chain.invoke({"input_query":query,"intent":intent, "slots":slots})
        data['predict_tabel_caption'] = response.content
    except KeyboardInterrupt:
        print(f"An unexpected interruption occurred at record {i}.")
        break

In [None]:
# Read the table names (captions) and build an in-memory BM25 index over them.
# NOTE(review): this rebinds the globals `corpus` and `bm25` that extract_slu() uses to
# match intents, so re-running the ICL cell after this one would match intents against
# table captions. Restart the kernel (or rebuild the intent index) before doing that.
table_names_file = './table_names.json'
with open(table_names_file, mode='r', encoding='utf-8') as names_fp:
    name_data = json.load(names_fp)
corpus = name_data  # Corpus
tokenized_corpus = [list(jieba.cut_for_search(caption)) for caption in corpus]
bm25 = BM25Okapi(tokenized_corpus)  # Build the index

In [None]:
# Characters stripped from every LLM-generated caption before BM25 matching.
_CAPTION_JUNK = (" ", "[", "]", "\\", "\"", "'", "\n")


def _clean_caption(text):
    """Remove spaces, brackets, backslashes, quotes and newlines from one caption."""
    for junk in _CAPTION_JUNK:
        text = text.replace(junk, "")
    return text


def _bm25_top1(text):
    """Return the entry of the global `corpus` that the global `bm25` ranks highest for `text`."""
    return bm25.get_top_n(list(jieba.cut_for_search(text)), corpus, n=1)[0]


predicts = []
bm_25s = []
trues = []

for data in tqdm(datas):
    # Perform preliminary cleaning on content generated by the LLM:
    # keep only the text after the LAST colon, whether ASCII ':' or full-width '：'.
    caption = data['predict_tabel_caption']
    if '：' in caption or ':' in caption:
        caption = caption[max(caption.rfind('：'), caption.rfind(':')) + 1:]
        data['predict_tabel_caption'] = caption

    # Split into several candidate captions when the output looks like a list.
    if ',' in caption and '[' in caption:
        predict = caption.split(',')
    else:
        predict = caption.strip("[]'\"")
        if ',' in predict:
            predict = predict.split(',')

    if isinstance(predict, list):
        predict = [_clean_caption(p) for p in predict]
        bm_25 = [_bm25_top1(p) for p in predict]
        # De-duplicate in first-seen order; set() order changes between runs, which
        # would reorder the tables concatenated into later prompts.
        predict = list(dict.fromkeys(predict))
        bm_25 = list(dict.fromkeys(bm_25))
    else:
        predict = _clean_caption(predict)
        bm_25 = _bm25_top1(predict)

    predicts.append(predict)
    bm_25s.append(bm_25)
    trues.append(data['table_caption_label'])
    data['bm25_pred_tabel_caption'] = bm_25
print(f'tabel_caption score：{metric_compute(trues, bm_25s)}')


# SFA module

### Markdown

In [None]:
def _load_prompt_table(path):
    """Load a per-intent prompt-template JSON file (UTF-8)."""
    with open(path, mode='r', encoding='utf-8') as prompt_fp:
        return json.load(prompt_fp)


# first_templates drives the LLM call that turns the retrieved tables into
# `simple_table`; second_templates drives the call that writes the final answer from it.
file_name = './prompts/first_prompts.json'
first_templates = _load_prompt_table(file_name)

file_name = './prompts/second_prompts.json'
second_templates = _load_prompt_table(file_name)

In [None]:
md_tables_json_file = '../database/markdown_tables/all_markdown_table.json'
with open(md_tables_json_file, 'r', encoding='utf-8') as f:
    md_table_datas = json.load(f)

# Index markdown tables by name once instead of rescanning the whole list for every
# predicted name. Lists keep every table that shares a name, in file order.
md_tables_by_name = defaultdict(list)
for md_table_data in md_table_datas:
    md_tables_by_name[md_table_data['table_name']].append(md_table_data['markdown_table'])

for data in datas:
    names = data['bm25_pred_tabel_caption']
    if isinstance(names, str):
        names = [names]
    # Matching tables are joined with single spaces, in predicted-name order.
    data['pred_markdown_table'] = ' '.join(
        table for name in names for table in md_tables_by_name.get(name, [])
    )

for i, data in enumerate(tqdm(datas)):
    try:
        query = data['query']
        intent = data['ICL_pred_intent'][0]
        if intent not in first_templates or intent not in second_templates:
            # Leave an empty answer so compute_result() still finds the key.
            print(f"No SFA prompts for intent {intent!r} at record {i}; leaving answer empty.")
            data['predict_markdown_answer'] = ''
            continue
        chain1 = ChatPromptTemplate.from_template(first_templates[intent]) | llm
        chain2 = ChatPromptTemplate.from_template(second_templates[intent]) | llm
        # Prompts receive the intent as a list of '+'-separated sub-intents;
        # 'a'.split('+') == ['a'], so single intents need no special case.
        intent = intent.split('+')
        slots = data['ICL_pred_slots']
        inputs = {"input_query": query, "intent": intent, "slots": slots}
        # Stage 1 works on the full retrieved tables; stage 2 on the stage-1 output.
        simple_table = chain1.invoke({**inputs, "table": data["pred_markdown_table"]}).content
        response2 = chain2.invoke({**inputs, "table": simple_table})
        data['predict_markdown_answer'] = response2.content
    except KeyboardInterrupt:
        print(f"An unexpected interruption occurred at record {i}.")
        break

## Calculate metrics for Markdown results

In [None]:
from utils import *
import json
import re
import argparse
from collections import Counter, defaultdict
from rouge_score import rouge_scorer, scoring
from datasets import load_from_disk
from evaluate import load as evaluate_load
from tabulate import tabulate

In [None]:
def get_rows_columns_cells(line):
    """Split a linearized table string into rows, columns and cells.

    Expected form (matched after lower-casing):
    ``"col : h1 | h2 row 1 : a | b row 2 : c | d"``. When ``col :`` is present, only
    the segment after its first occurrence is kept; that segment is then split on
    ``row <n> :`` markers, and cells are separated by ``|``.

    Returns:
        rows: each segment after the first, re-joined as ``"a | b"`` with cells stripped.
        columns: the i-th ``" | "``-separated item of every segment (the first
            segment, presumably the header, included) joined with ``" | "``; zip
            truncates to the shortest segment.
        cells: every stripped cell of the segments after the first, row-major.
    """
    line = line.lower()
    if "col :" in line:
        line = line.split("col :")[1].strip()
    # Raw string: '\s' in a plain literal is an invalid escape (SyntaxWarning on 3.12+).
    lines = re.split(r"\s+row\s+[0-9]+\s+:\s+", line)
    rows = [" | ".join(cell.strip() for cell in row.split("|")) for row in lines[1:]]
    cells = [cell.strip() for row in lines[1:] for cell in row.split("|")]
    columns = [" | ".join(elem.strip() for elem in elems) for elems in zip(*[row.split(" | ") for row in lines])]
    return rows, columns, cells

def get_correct_total_prediction(target_str, pred_str):
    """Compare gold and predicted linearized tables at row, column and cell level.

    Returns a dict holding the parsed target/pred rows, columns and cells plus the
    "correct_*" lists: the multiset intersection of target and prediction, where
    each item appears min(count in target, count in prediction) times.
    """
    def overlap(gold_items, guessed_items):
        # Counter & Counter keeps the minimum count of every shared item.
        return list((Counter(gold_items) & Counter(guessed_items)).elements())

    target_rows, target_columns, target_cells = get_rows_columns_cells(target_str)
    prediction_rows, prediction_columns, prediction_cells = get_rows_columns_cells(pred_str)
    correct_rows = overlap(target_rows, prediction_rows)
    correct_columns = overlap(target_columns, prediction_columns)
    correct_cells = overlap(target_cells, prediction_cells)
    return {
        "target_rows": target_rows,
        "target_columns": target_columns,
        "target_cells": target_cells,
        "pred_rows": prediction_rows,
        "pred_columns": prediction_columns,
        "pred_cells": prediction_cells,
        "correct_rows": correct_rows,
        "correct_columns": correct_columns,
        "correct_cells": correct_cells,
    }

def _safe_div(numerator, denominator):
    """Return numerator / denominator, or 0.0 when the denominator is 0."""
    return numerator / denominator if denominator else 0.0


def _f1(precision, recall):
    """Harmonic mean of precision and recall; 0.0 when both are 0."""
    return 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0


def compute_result(results: list):
    """Print table exact-match and row/column/cell P/R/F1 for the markdown answers.

    Each item in ``results`` must have the gold ``markdown_answer`` and the predicted
    ``predict_markdown_answer`` strings. Counts are summed over the whole dataset
    before dividing (micro-average). A ratio whose denominator is zero (e.g. no
    predicted rows at all) is reported as 0.0 instead of raising ZeroDivisionError.
    """
    totals = Counter()  # e.g. totals['correct_rows'] = total number of matched rows
    preds = []
    trues = []
    for result in results:
        true = result['markdown_answer']
        pred = result['predict_markdown_answer']
        trues.append(true)
        preds.append(pred)
        statistics = get_correct_total_prediction(true, pred)
        for key, items in statistics.items():
            totals[key] += len(items)

    exact_match_metric = evaluate_load("exact_match")
    exact_match_score = exact_match_metric.compute(predictions=preds, references=trues)['exact_match']

    units = ("rows", "columns", "cells")
    scores = {}
    for unit in units:
        precision = _safe_div(totals[f"correct_{unit}"], totals[f"pred_{unit}"])
        recall = _safe_div(totals[f"correct_{unit}"], totals[f"target_{unit}"])
        scores[unit] = (precision, recall, _f1(precision, recall))

    headers = ["Metric", "Row", "Column", "Cell"]
    table = [
        [metric_name] + [f"{scores[unit][idx]:.4f}" for unit in units]
        for idx, metric_name in enumerate(("Precision", "Recall", "F1 Score"))
    ]
    print(f"Table EM: {exact_match_score:.4f}")
    print(tabulate(table, headers=headers, tablefmt="grid"))

In [None]:
# Score the SFA markdown answers against the gold `markdown_answer` of every record.
compute_result(results=datas)

# SQL

In [None]:
# Read the per-intent SQL-generation prompt templates.
# NOTE: this rebinds `templates`, replacing the SR retrieval prompts loaded earlier.
prompt_file_name = './prompts/SQL_prompt.json'
with open(prompt_file_name, mode='r', encoding='utf-8') as sql_prompt_fp:
    templates = json.load(sql_prompt_fp)

In [None]:
# Generate SQL per record with the per-intent SQL prompt, run it on the database that
# covers the predicted intent(s), and store the result rows (or 'error').
for data in tqdm(datas):
    query = data['query']
    intent = data["ICL_pred_intent"][0]
    if intent not in templates:
        # e.g. the 'error' sentinel from extract_slu() yields the key 'e'.
        print(f"No SQL prompt for intent {intent!r}; marking record as error.")
        data['predict_SQL'] = ''
        data['predict_SQL_answer'] = 'error'
        continue
    chain = ChatPromptTemplate.from_template(templates[intent]) | llm
    # Composite intents are '+'-joined; 'a'.split('+') == ['a'] covers single intents.
    intent = intent.split('+')
    # Choose the database whose field list covers every sub-intent.
    # NOTE(review): house_sales_field & co. presumably come from `from utils import *`.
    if set(intent).issubset(house_sales_field):
        dbname = '价格查询'
    elif set(intent).issubset(land_sales_field):
        dbname = '土地资产'
    elif set(intent).issubset(enterprise_sales_field):
        dbname = '企业财务'
    else:
        # Without this branch dbname would silently keep the previous record's value.
        print(f"No database covers intent {intent}; marking record as error.")
        data['predict_SQL'] = ''
        data['predict_SQL_answer'] = 'error'
        continue
    slots = data['ICL_pred_slots']
    pred_table_name = data['bm25_pred_tabel_caption']
    response = chain.invoke({"query":query, "intent":intent, "slots":slots, "table_name":pred_table_name})
    sql_statement = response.content
    data['predict_SQL'] = sql_statement
    executor = PostgresQueryExecutor(database=dbname)
    try:
        table_heads, table_results = executor.execute_sql(sql_statement)
        data['predict_SQL_answer'] = table_results
    except Exception:
        # Any execution failure is recorded as 'error'; unlike a bare except, this
        # still lets KeyboardInterrupt stop the loop.
        data['predict_SQL_answer'] = 'error'
    finally:
        executor.close()

## Calculate metrics for SQL results

In [None]:
def pass_at_k(n, c, k):
    """Unbiased pass@k estimate: 1 - C(n - c, k) / C(n, k), computed as a stable product.

    :param n: total number of samples
    :param c: number of correct samples
    :param k: k in pass@$k$
    :return: 1.0 when fewer than k samples are wrong (every k-subset has a correct one).
    """
    wrong = n - c
    if wrong < k:
        return 1.0
    denominators = np.arange(wrong + 1, n + 1)
    return 1.0 - np.prod(1.0 - k / denominators)

def calculate_average_pass1(data):
    """Mean of the 'pass1' field over the values of ``data``; 0 for an empty dict."""
    if not data:
        return 0
    scores = [entry['pass1'] for entry in data.values()]
    return sum(scores) / len(scores)

In [None]:
def _as_row_tuples(rows):
    """Make result rows hashable: a list of lists becomes a list of tuples; anything else is returned unchanged."""
    if rows is not None and all(isinstance(row, list) for row in rows):
        return [tuple(row) for row in rows]
    return rows


def _is_correct(pred_rows, gold_rows):
    """True when both are row lists with the same rows and multiplicities, ignoring order.

    None and string values (such as the 'error' sentinel) never count as correct.
    """
    if pred_rows is None or gold_rows is None or isinstance(pred_rows, str) or isinstance(gold_rows, str):
        return False
    return Counter(pred_rows) == Counter(gold_rows)


uuids = {}
for i, data in enumerate(datas):
    table_results = _as_row_tuples(data['predict_SQL_answer'])
    true_SQL_answer = _as_row_tuples(data['SQL_answer'])
    data['predict_correctness'] = _is_correct(table_results, true_SQL_answer)
    # One group per record. NOTE(review): if several records are samples of the
    # same question, group by a shared question id here instead.
    uuids.setdefault(f'{i}', []).append(data['predict_correctness'])

for key, outcomes in uuids.items():
    c = sum(outcomes)
    n = len(outcomes)
    uuids[key] = {'c': c, 'n': n, 'pass1': pass_at_k(n=n, c=c, k=1)}
print(f'{len(datas)}')
print(f'pass@1：{calculate_average_pass1(uuids)}')

preds = []
trues = []
unexecutable_sql = 0
all_sql = 0
for data in tqdm(datas):
    pred_answer = data['predict_SQL_answer']
    # The SQL cell stores 'error' for failed executions; count it, and a missing
    # (None) result, as unexecutable. `data` is not modified, so re-running this
    # cell gives the same numbers.
    if pred_answer is None or (isinstance(pred_answer, str) and pred_answer == 'error'):
        pred_answer = []
        unexecutable_sql += 1
    preds.append(pred_answer)
    trues.append(data['SQL_answer'])
    all_sql += 1

ECR = (all_sql - unexecutable_sql) / all_sql if all_sql else 0.0


print(f'ECR:{ECR}')