In [None]:
from wayne_utils import load_data, save_data, get_shuffle_index, data_split
import os
import sys
from copy import deepcopy
from tqdm import tqdm

# Project root and the directory holding the dynamic CPL split files.
_ROOT_PATH = "/home/jiangpeiwen2/jiangpeiwen2/projects/TST"
_Data_path = os.path.join( _ROOT_PATH, "data/CPL/dynamic")

# Import the retrievers / project utilities.
# NOTE(review): `tools.utils` is imported before `_ROOT_PATH` is inserted into
# sys.path (a few lines below), so it only resolves if `tools` is importable from
# the kernel's working directory or an existing sys.path entry — confirm, or move
# `sys.path.insert` above this import.
# NOTE(review): wildcard import; presumably it provides `keyword_rag_context`,
# which later cells call but this notebook never defines — confirm.
from tools.utils import*
# Output directory for this experiment version's prompt / prediction artifacts.
version_dir = os.path.join( _ROOT_PATH, "test/CPL_dynamic/v1")
sys.path.insert( 0, _ROOT_PATH )
from Hybrid_RAG.hybrid_rag import HybridRAG
from Hybrid_RAG.data_capture import type_recognizer
# NOTE(review): Dynamic_Simple / Dynamic_Simple_keyword are not used in this view.
from KGs.dataset_KGs.cpl_level import Dynamic_Simple, Dynamic_Simple_keyword


# Raw judgment documents per split.
train_text = load_data( os.path.join( _Data_path, "texts_train.json"), "json")             # 492 training documents
test_text = load_data( os.path.join( _Data_path, "texts_test.json"), "json")               # 211 test documents
# NOTE(review): train_table / test_table are not used in this view.
train_table = load_data( os.path.join( _Data_path, "tables_train_dynamic.json"), "json")
test_table = load_data( os.path.join( _Data_path, "tables_test_dynamic.json"), "json")
# Gold label triplets per document (consumed by get_content_output for fine-tuning targets).
label_train_table = load_data( os.path.join( _Data_path, "label_train.pickle"), "pickle")
label_test_table = load_data( os.path.join( _Data_path, "label_test.pickle"), "pickle")

# Gold per-document instance-count dictionaries (counting targets).
count_label_test = load_data(  os.path.join( _ROOT_PATH, "evaluation/CountCPL/label_count_list_test.json"), "json")
count_label_train = load_data( os.path.join( _ROOT_PATH, "evaluation/CountCPL/label_count_list_train.json"), "json")
# NOTE(review): identical to version_dir and unused in this view; kept in case later cells reference it.
version_path = os.path.join( _ROOT_PATH, "test/CPL_dynamic/v1")

  from .autonotebook import tqdm as notebook_tqdm


## 计数 Prompt

In [5]:
# Chinese field keys -> English display names.
# NOTE(review): not referenced anywhere in this view.
ZH_EN_map = {
    '法院': "Court",
    '原告': "Plaintiff",
    '被告': "Defendant",
    '法院裁定_借款凭证': "Court Judges: Lending Evidence",
    '法院裁定_约定的借款金额': "Court Judges: Agreed Lending Amount",
    '法院裁定_约定的还款日期或借款期限': "Court Judges: Agreed Repayment Date",
    '法院裁定_约定的利息': "Court Judges: Agreed Interest",
    '法院裁定_约定的逾期利息': "Court Judges: Agreed Overdue Interest",
    '法院裁定_约定的违约金': "Court Judges: Agreed Liquidated Damages",
    '原告诉称_借款凭证': "Plaintiff claims: Lending Evidence",
    '原告诉称_约定的借款金额': "Plaintiff claims: Agreed Lending Amount",
    '原告诉称_约定的还款日期或借款期限': "Plaintiff claims: Agreed Repayment Date",
    '原告诉称_约定的利息': "Plaintiff claims: Agreed Interest",
    '原告诉称_约定的逾期利息': "Plaintiff claims: Agreed Overdue Interest",
    '原告诉称_约定的违约金': "Plaintiff claims: Agreed Liquidated Damages",
}
# System instruction for the counting task: the model must answer with a bare
# integer in [0, 5] for one entity/field question.
Instruction = """你是一名律师助手，需要阅读一份民事借贷案件的裁判文书后准确回答案件中相应的事件或主体的实例个数，注意如下：
（1）所谓实例个数，是指某种对象的实例数量，比如原告这个角色对象有几个人，或者借款凭证这个事物对象有几份，当事人约定借款金额这一行为对象的发生次数；
（2）请你严格控制输出为单纯的整数，以方便我快速从你的输出中提取数量信息，注意，所有数量的取值范围为[0,5]；
（3）下面在 '文本' 部分提供裁判文书内容，在 '问题' 部分提供要计数的对象，你应该先根据 '问题' 判断 '文本' 中哪些实例符合要求，然后进行区分与合并，不要重复计数。最后在 '数量' 后给出数值。
（4）最后请注意，当问题中出现 "有效约定"，指的是双方订立书面或作出口头约定的事件。如果文本中没有关于约定的内容，请回答0。
"""
# Few-shot examples (currently none).
Example = """"""

# Per-question template; formatted with the retrieved context and the question.
Input =  """下面请你进行实践：
文本：{context}
问题： {question}
数量：
"""

In [3]:
def get_question( entity, field ):
    """Return the counting question for one (entity, field) pair.

    "姓名名称" asks how many of ``entity`` occur in the text. Every other field
    asks how many evidence documents or agreements the court found (entity
    "法院") or the plaintiff claimed (any other entity).

    Raises:
        Exception: if ``field`` contains neither "借款凭证" nor "约定的".
    """
    if field == "姓名名称":
        return f"文本中有几个{entity}？"

    stance = "根据文本，法院认定的" if entity == "法院" else "根据文本，原告主张的"
    if "借款凭证" in field:
        return stance + "借款凭证有几份？"
    if "约定的" in field:
        # Slicing off the first three characters removes the "约定的" prefix.
        return stance + f"关于{field[3:]}的有效约定有几次？"
    raise Exception(f"意外：{entity, field}")

def get_output( count_dict, field, entity ):
    """Return the gold count for (entity, field) from one label dictionary.

    A key matches when it equals ``entity`` (name counts, field "姓名名称") or
    when it contains both ``entity`` and ``field`` as substrings. The first
    matching key in iteration order wins. With no match, a warning is printed
    and 0 is returned.
    """
    for label_key, label_count in count_dict.items():
        name_hit = field == "姓名名称" and label_key == entity
        substring_hit = entity in label_key and field in label_key
        if name_hit or substring_hit:
            return label_count
    print(f"不匹配: {field} {entity} {count_dict}")
    return 0

In [6]:
def counter_construct_prompt( test_text, retrieve_mode, count_label_dict_list = None, mode="test",
                              embedding_model_path="/home/jiangpeiwen2/jiangpeiwen2/projects/TKGT/Hybrid_RAG/retriever/embed_model/sentence-transformer"):
    """Build counting prompts for every document.

    Args:
        test_text: documents, either a list or a dict keyed by ``str(index)``.
        retrieve_mode: how each question's context is obtained:
            "hybrid" (HybridRAG retrieval over the document), "whole" (the
            full document text) or "type_recognizer".
        count_label_dict_list: per-document gold count dicts. Required when
            ``mode == "train"`` and ignored otherwise.
        mode: "test" returns one list of prompt strings per document.
            "train" returns a flat list of {"instruction", "input", "output"}
            dicts. Any other value returns a flat list of prompt strings.
        embedding_model_path: embedding model path handed to HybridRAG
            (defaults to the previously hard-coded path).

    Raises:
        Exception: unknown ``retrieve_mode``.
        ValueError: ``mode == "train"`` without ``count_label_dict_list``.
    """
    # Fail fast instead of after building the first retriever.
    if retrieve_mode not in ( "hybrid", "whole", "type_recognizer" ):
        raise Exception( f"意外的contex检索 {retrieve_mode}")
    if mode == "train" and count_label_dict_list is None:
        raise ValueError( "count_label_dict_list is required when mode == 'train'" )

    llm_path = None
    entities = ( "法院", "原告", "被告" )
    fields = ( "姓名名称", "借款凭证", "约定的借款金额", "约定的还款日期或借款期限", "约定的利息", "约定的逾期利息", "约定的违约金" )

    prompt_lists = []
    for index in tqdm( range( len(test_text) ), desc="Prompt" ):
        text = test_text[ str(index) ] if isinstance( test_text, dict ) else test_text[ index ]
        if retrieve_mode == "hybrid":
            # A fresh retriever is constructed from this document alone.
            # NOTE(review): meaning of the positional args (3, 0.6, 1) is defined by HybridRAG — confirm.
            retriever = HybridRAG( text, embedding_model_path, llm_path, 3, 0.6, 1)
        # Gold counts are only needed for training prompts.
        count_label_dict = count_label_dict_list[index] if mode == "train" else None

        prompt_list = []
        for entity in entities:
            for field in fields:
                # The defendant is only counted by name.
                if entity == "被告" and field != "姓名名称":
                    continue
                if retrieve_mode == "hybrid":
                    context = retriever.hybrid_retrive( f"{entity},{field}" )
                elif retrieve_mode == "whole":
                    context = text
                else:  # "type_recognizer"
                    context = type_recognizer( entity, field, text )
                question = get_question( entity, field )
                inputs = Input.format( context = context, question = question)
                if mode == "train":
                    # Fine-tuning sample: instruction / input / expected integer answer.
                    prompt = {
                        "instruction": Instruction,
                        "input": Example + inputs,
                        "output": str( get_output( count_label_dict, field, entity ) ),
                    }
                else:
                    # Inference: the full prompt as a single string.
                    prompt = Instruction + Example + inputs
                prompt_list.append( prompt )
        if mode == "test":
            prompt_lists.append( prompt_list )
        else:
            prompt_lists.extend( prompt_list )
    return prompt_lists

In [8]:
# Retrieval strategy for each counting question: "hybrid", "whole" or "type_recognizer".
retrieve_mode = "hybrid"        # "whole" "type_recognizer"
# "test" -> one prompt-string list per document; "train" -> flat list of instruction/input/output dicts.
mode = "test"                   # "train"
# count_label_test is ignored in "test" mode.
test_prompt_lists = counter_construct_prompt( test_text, retrieve_mode, count_label_test, mode)
# Fine-tuning prompts. retrieve_mode must be the 2nd positional argument:
#fineture_prompt_lists = counter_construct_prompt( train_text, retrieve_mode, count_label_train, "train")

Prompt:   0%|          | 0/211 [00:00<?, ?it/s]The tokenizer parameter is deprecated and will be removed in a future release. Use a stemmer from PyStemmer instead.
Prompt:   0%|          | 1/211 [00:05<19:12,  5.49s/it]The tokenizer parameter is deprecated and will be removed in a future release. Use a stemmer from PyStemmer instead.
Prompt:   1%|          | 2/211 [00:05<08:42,  2.50s/it]The tokenizer parameter is deprecated and will be removed in a future release. Use a stemmer from PyStemmer instead.
Prompt:   1%|▏         | 3/211 [00:06<05:22,  1.55s/it]The tokenizer parameter is deprecated and will be removed in a future release. Use a stemmer from PyStemmer instead.
Prompt:   2%|▏         | 4/211 [00:06<03:51,  1.12s/it]The tokenizer parameter is deprecated and will be removed in a future release. Use a stemmer from PyStemmer instead.
Prompt:   2%|▏         | 5/211 [00:07<02:59,  1.15it/s]The tokenizer parameter is deprecated and will be removed in a future release. Use a stemmer 

In [9]:
# Persist the per-document counting prompts for inference.
# NOTE(review): the file name hard-codes "hybrid_rag" regardless of retrieve_mode.
counter_prompt_path = os.path.join( version_dir, "counter_prompt_list_hybrid_rag.pickle" )
save_data( test_prompt_lists, counter_prompt_path )
# save_data( fineture_prompt_lists, os.path.join( version_dir, "CPL_counter_ft_list.json"))

## 基于计数的提示词

In [None]:
# Count predictions from the fine-tuned counter model.
# Presumably one count dict per test document — it is indexed per document later.
model_name = "CPL_dynamic_counter_Chinese-Mistral-7B-Instruct-v0.1-3epoch"
counter_predict_file = f"counter_predict_list_final_{model_name}.pickle"
counter_predict_list = load_data( os.path.join( version_dir, counter_predict_file ), "pickle" )

In [9]:
# Prompt pieces for the counter-guided content-extraction task.
# NOTE(review): this cell rebinds the globals Instruction / Example / Input that
# the counting section defined. counter_construct_prompt reads those globals at
# call time, so calling it after this cell builds counting prompts with the
# extraction instruction. Consider distinct names (e.g. CONTENT_INSTRUCTION).
Instruction = """你是一名律师助手，需要阅读一份民事借贷案件的裁判文书后准确回答案件每个相应事件或主体的实例信息，注意如下：
（1）下面在 '文本' 部分提供裁判文书内容，在 '问题' 部分提供具体要求，你应该先根据 '问题' 判断 '文本' 中哪些实例符合要求，然后进行区分与合并，不要重复计数。
（2）请注意，当问题中出现 "有效约定"，指的是双方订立书面或作出口头约定的事件。
（3）文本中也有可能并不包含任何有效信息，不要盲从。如果没有问题对应的答案信息，请根据具体要求大胆回答<NOT FOUND>。
"""
# Few-shot examples (currently none).
Example = """"""

# Per-question template; formatted with the retrieved context and the question.
Input =  """下面请你进行实践：
文本：{context}
问题： {question}
回答：
"""

# Answer skeleton per field. get_content_output deep-copies one skeleton per
# instance and overwrites the "<NOT FOUND>" placeholders from gold triplets.
Template = {
    "借款凭证": {
        "名称": "<NOT FOUND>",
        "所载内容": "<NOT FOUND>"
    },
    "约定的借款金额": {
        "约定情况": "<NOT FOUND>",
        "实际发生时间": "<NOT FOUND>",
        "金额": "<NOT FOUND>",
    },
    "约定的还款日期或借款期限": {
        "约定情况": "<NOT FOUND>",
        "实际发生时间": "<NOT FOUND>",
        "还款日期": "<NOT FOUND>",
    },
    "约定的利息": {
        "约定情况": "<NOT FOUND>",
        "实际发生时间": "<NOT FOUND>",
        "利率数值": "<NOT FOUND>",
    },
    "约定的逾期利息": {
        "约定情况": "<NOT FOUND>",
        "实际发生时间": "<NOT FOUND>",
        "利率数值": "<NOT FOUND>",
    },
    "约定的违约金": {
        "约定情况": "<NOT FOUND>",
        "实际发生时间": "<NOT FOUND>",
        "违约金数值": "<NOT FOUND>",
    },
}

In [1]:
def get_number( counts, entity, field):
    """Return the instance count for (entity, field) from one count dictionary.

    Args:
        counts: dict of counts. For field "姓名名称" the key is the entity
            itself. Otherwise the first key containing both ``entity`` and
            ``field`` as substrings is used.
        entity: "法院", "原告" or "被告".
        field: "姓名名称" or one of the evidence/agreement fields.

    Returns:
        The stored count, or 0 (after printing a warning) when no key matches.
        This is the same fallback for both branches, matching get_output,
        instead of raising KeyError for a missing name key.
    """
    if field == "姓名名称":
        if entity in counts:
            return counts[ entity ]
    else:
        for key, value in counts.items():
            if entity in key and field in key:
                return value
    print( f"{entity, field}没有数量" )
    return 0
    

def get_content_question( entity, field, number ):
    """Build the extraction question for one (entity, field) given its count.

    Args:
        entity: "法院", "原告" or "被告".
        field: "姓名名称", "借款凭证" or one of the "约定的..." fields.
        number: how many instances the model is asked to list.

    Returns:
        For "姓名名称", a request to list the names separated by "、".
        Otherwise, a request to list ``number`` documents/agreements, one per
        line, as "key：value；" pairs in a fixed per-field format, with
        "<NOT FOUND>" for missing values.

    Raises:
        Exception: for a field that is neither "借款凭证" nor a "约定的..." field.
    """
    if field == "姓名名称":
        calls = "姓名" if entity != "法院" else "名称"
        return f"请列出文本中{number}个{entity}的{calls}，若有多个，请用顿号分隔。"

    # Measure word: 份 for evidence documents, 次 for agreement events.
    liangci = "份" if "借款凭证" in field else "次"
    all_prefix = f"请列出文本中{number}{liangci}"
    prefix = "法院认定的" if entity == "法院" else "原告主张的"
    if "约定的" in field:
        # field[3:] drops the leading "约定的".
        post = f"关于{field[3:]}的有效约定信息。"
    elif field == "借款凭证":
        post = "关于借款凭证的信息。"
    else:
        raise Exception(f"意外：{entity, field}")

    all_post = f"具体要求：\n（1）所有子信息在一行中以冒号键值对的形式列出，键值对间用分号分隔。若有多{liangci}，请换行列出。\n（2）每行中的子信息应遵循格式："
    # Shared leading keys of every agreement format. All keys use the full-width
    # colon "：", consistent with the gold answers rendered by get_content_output.
    agreement_head = "约定情况：xx（口头约定或书面约定二选一）；实际发生时间：xxxx-xx-xx（年-月-日）；"
    rate_tail = "利率数值：xx（去除百分号的数值部分）；"
    formats = {
        "借款凭证" : "名称：xxx；所载内容：xxx；",
        "约定的借款金额": agreement_head + "金额：xx元；",
        "约定的还款日期或借款期限": agreement_head + "还款日期：xxxx-xx-xx（年-月-日）；",
        "约定的利息": agreement_head + rate_tail,
        "约定的逾期利息": agreement_head + rate_tail,
        "约定的违约金": agreement_head + "违约金数值：xx（去除百分号的数值部分）；",
    }
    not_prompt = "\n（3）若某项子信息不在文本中，则保留键，统一使用符号<NOT FOUND>代替值，便于后续检索使用。"
    return all_prefix + prefix + post + all_post + formats[field] + not_prompt
def ratio_post( value ):
    """Normalise a rate label to a bare percentage-number string.

    The value is stripped. At most one trailing non-digit character
    (e.g. "%" or "元") is removed. The rest is parsed as a float. Values below
    0.99 are treated as fractions and scaled by 100 ("0.24" -> "24"). A trailing
    ".0" is dropped.

    Returns:
        The normalised string, or None when the value is empty, ends in more
        than one non-digit character, cannot be parsed, or exceeds 100.
    """
    # Strip first, so the character removed below is the real trailing symbol.
    value = value.strip()
    if value and not value[-1].isdigit():
        value = value[:-1].rstrip()
    if not value or not value[-1].isdigit():
        return None
    if "%" in value:
        value = value.split("%")[0].strip()
    try:
        float_value = float( value )
    except ValueError:
        return None
    if float_value < 0.99:
        # round() removes binary floating-point noise (0.07 * 100 == 7.000000000000001).
        ret = str( round( float_value * 100, 10 ) )
    elif float_value > 100:
        return None
    else:
        ret = str(float_value)
    if ret.endswith(".0"):
        ret = ret[:-2]
    return ret

def get_content_output( entity, field, labels, number):
    """Render the gold answer for one content-extraction question.

    Args:
        entity: matched exactly against ``triplet[0]``.
        field: "姓名名称" or a key of ``Template``; matched as a substring of
            ``triplet[1]``.
        labels: iterable of gold triplets. ``triplet[1]`` is presumably an
            attribute path that embeds the instance index and sub-key — confirm
            against the label files.
        number: how many instances (answer lines) to render.

    Returns:
        For "姓名名称": the matching values joined with "、".
        Otherwise: one line per instance 1..number of "key：value；" pairs,
        lines joined by "\n". Unfilled values stay "<NOT FOUND>".
    """
    if field == "姓名名称":
        template = ""
    else:
        template = { i: deepcopy(Template[field]) for i in range(1, number+1)}

    for triplet in labels:
        if entity != triplet[0] or field not in triplet[1]:
            continue
        if field == "姓名名称":
            template = triplet[2] if template == "" else template + "、" + triplet[2]
            continue
        ans = triplet[2]
        if "数值（百分比或元）" in triplet[1]:
            ans = ratio_post( ans )
            if ans is None:
                # Unparseable rate: keep the "<NOT FOUND>" placeholder rather than writing "None".
                continue
        for i in range(1, number+1):
            # NOTE(review): substring test on the index; it would misfire for
            # number >= 10 (counts are asked in the range [0, 5]).
            if str(i) in triplet[1]:
                for key in template[i].keys():
                    # field[-2:] is the field's last two characters, e.g. "金额", "利息".
                    if f"{field[-2:]}的{key}" in triplet[1]:
                        template[i][key] = ans

    if field == "姓名名称":
        return template
    # Every line keeps its trailing "；". The previous ret[:-2] stripped the
    # final "；" together with the trailing "\n".
    lines = [ "".join( f"{key}：{value}；" for key, value in slots.items() ) for slots in template.values() ]
    return "\n".join( lines )

In [10]:
example_question = get_content_question( "法院", "约定的借款金额", 1 )
print( example_question )

请列出文本中1次法院认定的关于借款金额的有效约定信息。具体要求：
（1）所有子信息在一行中以冒号键值对的形式列出，键值对间用分号分隔。若有多次，请换行列出。
（2）每行中的子信息应遵循格式：约定情况: xx（口头约定或书面约定二选一）；实际发生时间：xxxx-xx-xx（年-月-日）；金额：xx元；
（3）若某项子信息不在文本中，则保留键，统一使用符号<NOT FOUND>代替值，便于后续检索使用。


In [11]:
def construct_prompt_based_on_counter( counters_list, test_text, label_test_table, mode ):
    """Build content-extraction prompts for each (entity, field) with a non-zero count.

    Args:
        counters_list: per-document count dicts (predicted for test, gold for train).
        test_text: documents, either a list or a dict keyed by ``str(index)``.
        label_test_table: per-document gold triplets (used for train outputs).
        mode: "test" returns one prompt-string list per document, plus a
            parallel list of the (entity, field, number) tuples those prompts
            were built for. "train" returns a flat list of
            {"instruction", "input", "output"} dicts. Any other value returns a
            flat list of prompt strings. In non-"test" modes the prefix list
            is empty.

    Returns:
        (prompt_lists, prefix_lists)

    Raises:
        Exception: if texts and labels differ in length, or there are fewer
            count dicts than texts.
    """
    if len(test_text) != len(label_test_table) or len(counters_list) < len(test_text):
        raise Exception("长度不一致")
    prompt_lists = []
    prefix_lists = []
    for index in range( len(test_text) ):
        # Accept both dict-keyed and list-style text containers, like counter_construct_prompt.
        texts = test_text[ str(index) ] if isinstance( test_text, dict ) else test_text[ index ]
        labels = label_test_table[index]
        counters = counters_list[index]
        prompt_list = []
        prefix_list = []
        for entity in [ "法院", "原告", "被告"]:
            for field in [ "姓名名称", "借款凭证", "约定的借款金额", "约定的还款日期或借款期限", "约定的利息", "约定的逾期利息", "约定的违约金"]:
                # The defendant is only asked for by name.
                if entity == "被告" and field != "姓名名称":
                    continue
                number = get_number( counters, entity, field)
                # Nothing to extract for fields counted as absent.
                if number == 0:
                    continue
                prefix_list.append( (entity, field, number) )
                context = keyword_rag_context( entity, field, texts )
                question = get_content_question( entity, field, number )
                inputs = Input.format( context = context, question = question)
                if mode == "train":
                    # Fine-tuning sample with the rendered gold answer.
                    prompt = {
                        "instruction": Instruction,
                        "input": Example + inputs,
                        "output": get_content_output( entity, field, labels, number),
                    }
                else:
                    # Inference: the full prompt as a single string.
                    prompt = Instruction + Example + inputs
                prompt_list.append( prompt )
        if mode == "test":
            prefix_lists.append( prefix_list )
            prompt_lists.append( prompt_list )
        else:
            prompt_lists.extend( prompt_list )
    return prompt_lists, prefix_lists



In [12]:
def get_prompts(mode):
    """Build counter-guided content prompts for one split.

    "test" pairs the model's predicted counts with the test texts and labels.
    Any other mode uses the gold training counts, texts and labels.
    Returns the (prompt_lists, prefix_lists) pair from
    construct_prompt_based_on_counter.
    """
    if mode == "test":
        sources = ( counter_predict_list, test_text, label_test_table )
    else:
        sources = ( count_label_train, train_text, label_train_table )
    return construct_prompt_based_on_counter( *sources, mode )

# Build and persist the table prompts for both splits:
# test -> inference prompts + (entity, field, number) prefixes; train -> fine-tuning dicts.
split_output_files = {
    "test": ( "table_prompt_list.pickle", "table_prefix_list.pickle" ),
    "train": ( "CPL_table_ft_list.json", None ),
}
for mode in ( "test", "train" ):
    prompt_file, prefix_file = split_output_files[mode]
    prompt_lists, prefix_list = get_prompts(mode)
    save_data( prompt_lists, os.path.join( version_dir, prompt_file ) )
    if prefix_file is not None:
        save_data( prefix_list, os.path.join( version_dir, prefix_file ) )

In [236]:
# Quick sanity check: number of prompts for document 1.
# NOTE(review): after the previous cell, prompt_lists holds the *train* prompts
# (a flat list of 3-key instruction/input/output dicts), so on a fresh Run All
# this would give 3. The recorded 7 comes from an out-of-order run
# (execution count 236) when prompt_lists held the per-document test prompts.
len(prompt_lists[1])

7

In [None]:
# Draft of a planned post_counter_construct_prompt, unrolled for a single document.
# def post_counter_construct_prompt( texts_list, count_label_dict_list, mode):
prompt_lists = []
# for index in range( len(test_text) ):
index = 0

# NOTE(review): texts_list and count_label_dict_list are the draft function's
# parameters and are not defined anywhere in this notebook view, so this cell
# raises NameError on a fresh kernel — after it has already reset prompt_lists
# above. TODO: turn this into a real function or delete the cell.
texts = texts_list[ str(index) ]
count_label_dict = count_label_dict_list[index]
prompt_list = []
for entity in [ "法院", "原告", "被告"]:
    for field in [ "姓名名称", "借款凭证", "约定的借款金额", "约定的还款日期或借款期限", "约定的利息", "约定的逾期利息", "约定的违约金"]:
        if entity == "被告" and field != "姓名名称":
            continue
        context = keyword_rag_context( entity, field, texts )
        question = get_question( entity, field )
        inputs = Input.format( context = context, question = question)
        