In [None]:
import os
import sys
from wayne_utils import load_data, save_data
_ROOT_PATH = "workspace/TKGT"
sys.path.insert( 0, _ROOT_PATH )
_Data_path = os.path.join( _ROOT_PATH, "data/CPL/dynamic")
version_dir = os.path.join( _ROOT_PATH, "test/CPL_dynamic/v1")

label_lists = load_data( os.path.join( version_dir, "label_lists.pickle"), "pickle")
prefix_lists = load_data( os.path.join( version_dir, "table_prefix_list.pickle"), "pickle")
prompt_lists = load_data( os.path.join( version_dir, "table_prompt_list.pickle"), "pickle")



In [132]:
model_list_table = [
    "ChatGLM3-6B",
    "Qwen1.5-7B-Chat",
    "Baichuan2-7B-Chat",
    "Chinese-Mistral-7B-Instruct-v0.1",
    "Qwen2.5-0.5B",
    "CPL_dynamic_tabel_Qwen1.5-7B-Chat-4epoch",
    "CPL_dynamic_tabel_ChatGLM3-6B-4epoch",
    "CPL_dynamic_tabel_Chinese-Mistral-7B-Instruct-v0.1-4epoch",
]
scopes = model_list_table[7:]
scopes

['CPL_dynamic_tabel_Chinese-Mistral-7B-Instruct-v0.1-4epoch']

In [127]:
def get_answer( text ):
    if text.strip() == "":
        return None
    elif "not found" in text.lower() or "notfound" in text.lower() or "抱歉" in text:
        return None
    else:
        return text.strip()

def post_process( predicts, prefixs, prompts):
    """处理单个文件"""
    if len( predicts ) != len( prefixs ):
        raise Exception( f"Predict和prefix列表长度不一致：{len( predicts )} vs {len( prefixs )}" )
    temp = set()
    for i in range( len(predicts) ):
        entity_type, field = prefixs[i][0], prefixs[i][1]
        predict = predicts[i]
        try:
            ans = get_answer( predict )
        except:
            raise Exception( f"{i} {predict}" )
        if ans != None and "\n相关上下文：[]" not in prompts[i]:
            temp.add( (entity_type, field, ans))
    return temp

def ratio_post( value ):
    if not value.strip()[-1].isdigit():
        value = value[:-1]
    if value == "" or not value[-1].isdigit():
        return None
    if "%" in value:
        value = value.split("%")[0].strip()
    try:
        float_value = float( value )
    except:
        return None
    if float_value < 0.99:
        ret = str( float_value * 100)
    elif float_value >100:
        return None
    else:
        ret = str(float_value)
    if ret.endswith(".0"):
        ret = ret[:-2]
    return ret

from datetime import datetime
def date_post(date_str):
    # 尝试不同的日期格式进行解析
    date_formats = [
        "%Y/%m/%d",  # 2012/2/2
        "%Y.%m.%d",  # 2012/2/2
        "%Y年%m月%d日",  # 2000年2月2日
        "%Y年-%m月-%d日",  # 2000年-2月-2日
        "%Y-%m-%d"  # 2000-2-2
    ]
    
    for date_format in date_formats:
        try:
            # 尝试使用当前格式解析日期
            date_obj = datetime.strptime(date_str, date_format)
            # 格式化为统一的"xxxx年xx月xx日"格式
            formatted_date = date_obj.strftime("%Y年%m月%d日")
            return formatted_date
        except ValueError:
            # 如果解析失败，继续尝试下一个格式
            continue
    
    # 如果所有格式都尝试失败，返回错误信息
    return None

def split_names( value ):
    if "、" in value:
        sep = "、"
    elif "，" in value:
        sep = "，"
    elif "," in value:
        sep = ","
    else:
        sep = None

    if sep == None:
        return value.strip()
    else:
        ret = []
        names = value.split( sep )
        for name in names:
            ret.append( name.strip() )
        return ret

def value_type_unify( value_set ):
    """
    value_list: 预测或label
    three: 是否是三元组
    """
    temp = set()
    # 遍历单个文档的元组集合
    for item in value_set:
        # 获取每个元组的组成
        entity, field, value = item[0], item[1], item[2]
        if field == "姓名名称":
            names = split_names( value )
            if isinstance( names, list):
                for name in names:
                    temp.add( (entity, field, name) )
            else:
                temp.add( (entity, field, names) )
        elif "约定情况" in field:
            if "口头" in value:
                temp.add( (entity, field, "口头约定") )
            elif "书面" in value:
                temp.add( (entity, field, "书面约定") )
            elif "未" in value:
                temp.add( (entity, field, "未约定") )
        elif "（百分比或元）" in field:
            value = ratio_post( value )
            if value != None:
                temp.add( ( entity, field, value ) )
        elif "日期" in field or "时间" in field:
            value = date_post( value )
            if value != None:
                temp.add( ( entity, field, value ) )
        else:
            temp.add( ( entity, field, value ) )
    return temp


In [128]:
for model_name in scopes:
    predict_lists = load_data( os.path.join( version_dir, f"table_predict_list_{model_name}.pickle"), "pickle")
    predict_post = []
    for index in range( len(prefix_lists)):
        #index = 1           # doc
        temp_set = set( )
        for inner_index in range( len(prefix_lists[index])):
            # inner_index = 3     # prefix
            entity, top_field, number = prefix_lists[index][inner_index][0], prefix_lists[index][inner_index][1], prefix_lists[index][inner_index][2]
            predict = predict_lists[index][inner_index]

            if "\n回答：" in predict:
                predict = predict.split( "\n回答：")[1].strip()
            if "如下：" in predict:
                predict = predict.split( "如下：")[1].strip()
            labels_subset = []
            # 遍历该文档的每个三元组，将于目前prefix契合的放入子集
            for item in label_lists[index]:
                total_field = item[1]
                if entity == item[0] and top_field in total_field:          # 这个标签是和当前Prefix匹配的
                    labels_subset.append( item )
            # 遍历子集
            for item in labels_subset:
                total_field = item[1]
                if "期限的还款日期" in total_field or "实际发生时间" in total_field:
                    
                    date_line = item[2]
                    
                    try:
                        if "、" in date_line:
                            date_line = date_line.split("、")[0].strip()
                        try:
                            year, month, day = date_line.split("-")
                        except:
                            year, month, day = date_line.split(".")
                    except:
                        if item[2] == "借条" or "被告曾国印向原告借款50000元" in item[2] or "6/？" in item[2] or "2017年5月底" in item[2] or '43097' in item[2] or \
                        '农历2015/12/3' in item[2] or '定于2018年10月21日之前归还' in item[2] or '40487' in item[2] : 
                            pass
                        else:
                            raise Exception( f"label中的日期格式不对{date_line} {item}")
                    if date_line in predict or f"{year}年{month}月{day}日" in predict:
                        temp_set.add( (entity, total_field, date_line) )
                elif "的金额（元）" in total_field:
                    money = item[2]
                    if money in predict:
                        temp_set.add( (entity, total_field, money) )
                    elif money.endswith("0000"):
                        wan_yuan = money[-4]+"万"
                        if wan_yuan in predict:
                            temp_set.add( (entity, total_field, money) )
                elif "数值（百分比或元）" in total_field:
                    ratio = str(item[2])
                    new_ration = ratio_post( ratio )
                    if ratio in predict or str(new_ration) in predict:
                        temp_set.add( (entity, total_field, ratio) )
                elif "约定情况" in total_field:
                    approve = item[2]
                    for type_ in ["口头", "书面"]:
                        if type_ in approve and type_ in predict:
                            temp_set.add( (entity, total_field, approve) )
                            break
                else:
                    names_or_content = item[2]
                    if names_or_content in predict:
                        temp_set.add( (entity, total_field, names_or_content) )

        predict_post.append( temp_set )
    new_predict_list = []
    new_label_list = []
    for i in range(len(predict_post)):
        new_predict_list.append( value_type_unify( predict_post[i] ) )
        new_label_list.append(   value_type_unify( label_lists[i] ) )
    save_data( (new_label_list, new_predict_list), os.path.join( version_dir, f"table_predict_lists_filter_{model_name}.pickle"))

## Eval

In [133]:
from evaluation.Text2table.evaluate import eval_main
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
for model_name in scopes:    

    loaded_pair_list = load_data( os.path.join( version_dir, f"table_predict_lists_filter_{model_name}.pickle"), "pickle")
    
    results_save_path = os.path.join( _ROOT_PATH, "/home/jiangpeiwen2/jiangpeiwen2/TKGT/evaluation/CountCPL/results.json")
    eval_main( f"CPL_dynamic_static_baseline_{model_name}", loaded_pair_list, results_save_path, "multi_entity" )

100%|██████████| 211/211 [00:00<00:00, 1525.69it/s]


Row header: precision = 100.00; recall = 99.53; f1 = 99.72
Col header: precision = 100.00; recall = 80.30; f1 = 87.65
Non-header cell: precision = 100.00; recall = 78.63; f1 = 87.15


100%|██████████| 211/211 [00:01<00:00, 210.33it/s]


Row header: precision = 100.00; recall = 99.59; f1 = 99.76
Col header: precision = 100.00; recall = 85.10; f1 = 90.83
Non-header cell: precision = 100.00; recall = 80.89; f1 = 88.65


100%|██████████| 211/211 [03:06<00:00,  1.13it/s]

Row header: precision = 100.00; recall = 99.63; f1 = 99.78
Col header: precision = 100.00; recall = 87.79; f1 = 92.55
Non-header cell: precision = 100.00; recall = 82.81; f1 = 89.87



