In [1]:
import asyncio 

from src.llm.gigachat_class import GigaChat
from src.llm.prompts import TAB_FACT_SYSTEM_PROMPT

from src.utils.dataset_loader import load_tab_fact
from src.utils.table_formatters import *

### Model

In [2]:
# Single GigaChat client shared by the evaluation loop (async_ask_giga).
giga: GigaChat = GigaChat()



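### Data

Load the TabFact train split, define the table serializer (HTML or RST), and define the per-sample user-message template.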

In [3]:
async def get_formats(data_sample, mode):
    """Serialize a TabFact sample's table as HTML or as an RST list-table.

    Args:
        data_sample: a TabFact record; ``data_sample['table']`` must provide
            ``'header'`` and ``'rows'``.
        mode: ``'html'`` returns the HTML rendering; any other value returns
            the RST list-table built from that HTML.

    Returns:
        The table serialized in the requested format.
    """
    table = data_sample['table']
    df_ = array_to_df(table['header'], table['rows'])
    html = convert_df_to_html(df_)
    if mode == 'html':
        # Only await the HTML -> RST conversion when RST is actually requested.
        return html
    return await convert_html_to_list_table(html)

tabfact_data = load_tab_fact()['train']
# One independent (shallow-copied) plain dict per train record; a comprehension
# avoids leaking loop variables into the notebook namespace.
tabfact_dicts = [dict(record) for record in tabfact_data]
    
# User message sent per sample ("Таблица" = table, "Утверждение" = statement);
# filled via .format(table=..., statement=...). The system prompt is passed
# separately to the model call.
human_message_template = (
    "\n"
    "Таблица: {table}\n"
    "Утверждение: {statement}\n"
)

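### Eval

Query GigaChat in batches on the first 2000 TabFact train samples and pickle the raw responses, one file per batch.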

In [14]:
import os
import pickle 
from tqdm import tqdm

from src.config import giga_config

BATCH_SIZE=10
mode = 'rst'
os.makedirs(f'results/{giga_config.GIGACHAT_MODEL}/{mode}', exist_ok=True)
for start in tqdm(range(0, len(tabfact_dicts[:2000]), BATCH_SIZE)):
    end = min(start + BATCH_SIZE, len(tabfact_dicts))
    
    statements = [x['statement'] for x in tabfact_dicts[start:end]]
    labels = [x['label'] for x in tabfact_dicts[start:end]]
    tables = [await get_formats(x, mode) for x in tabfact_dicts[start:end]]
    
    if statements:
        tasks = [giga.async_ask_giga(human_message_template.format(table=table, statement=statement), prompt=TAB_FACT_SYSTEM_PROMPT) for table, statement in zip(tables, statements)]
        results = await asyncio.gather(*tasks)
        with open(f'results/{giga_config.GIGACHAT_MODEL}/{mode}/{start}_{end}.pkl', 'wb') as f:
            pickle.dump(results, f)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [27:55<00:00,  8.38s/it]


### Metrics

Parse the final answer from each pickled response and compare it with the gold label.

In [7]:
import os
import re
import pickle 

In [18]:
# Score the pickled responses written by the Eval cell. N_SAMPLES and mode must
# match the values used there; BATCH_SIZE, tqdm and giga_config come from it too.
N_SAMPLES = 2000
mode = 'rst'

# Final-answer line in Russian ("Итоговый ответ") or English ("Final answer").
ANSWER_PATTERN = re.compile(r'([Ии]тоговый ответ: \d)|([Ff]inal [Aa]nswer: \d)')
# Prediction recorded when no final answer is found. The pattern only captures a
# single digit 0-9, so -1 can never be confused with a parsed model answer.
UNPARSED_PREDICTION = -1

tp_score = 0
cnt_unparsed_answer = 0
n_evaluated = 0

base_file_path = f'results/{giga_config.GIGACHAT_MODEL}/{mode}'
eval_dicts = tabfact_dicts[:N_SAMPLES]
for start in tqdm(range(0, len(eval_dicts), BATCH_SIZE)):
    batch = eval_dicts[start:start + BATCH_SIZE]
    end = start + len(batch)  # same file-naming scheme as the Eval cell
    labels = [x['label'] for x in batch]
    # NOTE: pickle.load can execute arbitrary code; only load result files
    # produced by this notebook.
    with open(os.path.join(base_file_path, f"{start}_{end}.pkl"), 'rb') as f:
        model_answers = pickle.load(f)
    # zip() below would silently truncate on a mismatch.
    assert len(model_answers) == len(labels), (
        f"{start}_{end}.pkl has {len(model_answers)} answers for {len(labels)} labels"
    )

    res_on_batch = []
    for case_answer in model_answers:
        answer = ANSWER_PATTERN.search(case_answer)
        if answer:
            res_on_batch.append(int(answer.group().split(":")[-1].strip()))
        else:
            cnt_unparsed_answer += 1
            res_on_batch.append(UNPARSED_PREDICTION)

    for predicted, gt in zip(res_on_batch, labels):
        tp_score += predicted == gt
    n_evaluated += len(labels)

n_parsed = n_evaluated - cnt_unparsed_answer
# Accuracy over parsed answers only, and over all samples (unparsed = wrong);
# NaN instead of ZeroDivisionError when the denominator is empty.
acc_parsed = tp_score / n_parsed if n_parsed else float('nan')
acc_overall = tp_score / n_evaluated if n_evaluated else float('nan')
print(
    f"correct={tp_score} accuracy_parsed={acc_parsed:.4f} "
    f"unparsed={cnt_unparsed_answer} accuracy_all={acc_overall:.4f}"
)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:00<00:00, 1381.70it/s]

1677 0.8393393393393394 2



