In [1]:

import os

os.environ["RWKV_JIT_ON"] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import sys
import gc
import types
import copy
from src.model_run import RWKV_RNN
import numpy as np
import torch
from src.utils import TOKENIZER
import json
import re
import datetime
import csv
# os.chdir("/app/CSS_AI")

# from flask import Flask, request, jsonify
# from gevent.pywsgi import WSGIServer
# import requests

# Runtime switches: let cuDNN autotune its kernels and permit TF32 math for
# cuDNN and CUDA matmuls; keep numpy printouts compact.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
np.set_printoptions(precision=4, suppress=True, linewidth=200)

# Pile models use the same 20B tokenizer file for both slots ([vocab, vocab]).
WORD_NAME = ["20B_tokenizer.json"] * 2
UNKNOWN_CHAR = None
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)

# Model/runtime arguments handed to RWKV_RNN; model-specific fields
# (MODEL_NAME, n_layer, ...) are attached further below.
args = types.SimpleNamespace(
    RUN_DEVICE="cuda",  # 'cpu' (already very fast) // 'cuda'
    # fp32 (good for CPU) // fp16 (recommended for GPU) // bf16 (less accurate)
    FLOAT_MODE="fp16",
    vocab_size=50277,
    head_qk=0,
    pre_ffn=0,
    grad_cp=0,
    my_pos_emb=0,
    # LoRA rank and scaling; lora_r = 0 disables the LoRA weights.
    lora_r=8,
    lora_alpha=16,
)

# args.MODEL_NAME = '../RWKV-4-Pile-3B-EngChn-test4-20230115'
# args.n_layer = 32
# args.n_embd = 2560
# args.ctx_len = 1024

# args.MODEL_LORA = './model_lora'
# Chat roles; on_message renders each turn as
# "<user><interface> msg\n\n<bot><interface>".
user, bot, interface = "Bob", "Alice", ":"

# The initial prompt is only a blank paragraph break (no persona preamble).
init_prompt = "\n\n"

# Base checkpoint and its architecture / context length.
args.MODEL_NAME, args.n_layer, args.n_embd, args.ctx_len = (
    '../RWKV-4-Pile-7B-EngChn-test5-20230326',
    32,
    4096,
    1024,
)

# LoRA checkpoint merged into the base weights; lora_r = 0 will not use LoRA weights.
args.MODEL_LORA = './out_7b_tag/rwkv-1'
# args.MODEL_LORA = './rwkv-raw'

# user = "Q"
# bot = "A"
# interface = ":"

# init_prompt = '''
# The following is a coherent verbose detailed conversation between a Chinese girl named {bot} and her friend {user}. \
# {bot} is very intelligent, creative and friendly. \
# {bot} likes to tell {user} a lot about herself and her opinions. \
# {bot} usually gives {user} kind, helpful and informative advices.
# '''
# HELP_MSG = '''指令:
# 直接输入内容 --> 和机器人聊天，用\\n代表换行
# +alt --> 让机器人换个回答
# +reset --> 重置对话

# +gen 某某内容 --> 续写任何中英文内容，用\\n代表换行
# +qa 某某问题 --> 问独立的问题（忽略上下文），用\\n代表换行
# +more --> 继续 +gen / +qa 的回答
# +retry --> 换个 +gen / +qa 的回答

# 现在可以输入内容和机器人聊天（注意它不怎么懂中文，它可能更懂英文）。请经常使用 +reset 重置机器人记忆。
# '''

# Load Model
#
# RWKV_RUN_DEVICE is exported before RWKV_RNN is constructed; presumably
# src.model_run reads it from the environment (TODO confirm). The LoRA fields
# of `args` (lora_r, lora_alpha, MODEL_LORA) go to the constructor as well; the
# saved output of this cell shows lora_A/lora_B pairs being merged into the
# base weights.

os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE
MODEL_NAME = args.MODEL_NAME

print(f'loading... {MODEL_NAME}')
model = RWKV_RNN(args)

# Global chat state shared by run_rnn / save_all_stat / load_all_stat:
# model_tokens holds every token id fed to the model so far.
model_tokens = []

# Recurrent state passed to and returned by model.forward; None until the
# first token has been fed.
current_state = None

########################################################################################################


def run_rnn(tokens, newline_adj=0):
    """Feed ``tokens`` through the RNN, advancing the global chat state.

    Each token is appended to the global ``model_tokens`` and passed to
    ``model.forward`` together with the global ``current_state``. Only the
    last token requests logits; earlier tokens use ``preprocess_only=True``
    so that just the recurrent state is updated.

    Args:
        tokens: Non-empty sequence of token ids (anything ``int()`` accepts).
        newline_adj: Value added to the logit of token id 187 (the newline
            token, judging by the parameter name) to encourage or suppress
            line breaks in the next sampled token.

    Returns:
        The logits for the last token, with token 0 (``<|endoftext|>``)
        disabled and ``newline_adj`` applied.

    Raises:
        ValueError: if ``tokens`` is empty. There are no logits to return in
            that case; previously this surfaced as an ``UnboundLocalError``.
    """
    global model_tokens, current_state
    if len(tokens) == 0:
        raise ValueError("run_rnn() needs at least one token")

    last_index = len(tokens) - 1
    for i, token in enumerate(tokens):
        model_tokens += [int(token)]
        if i == last_index:
            out, current_state = model.forward(model_tokens, current_state)
        else:
            current_state = model.forward(model_tokens,
                                          current_state,
                                          preprocess_only=True)

    out[0] = -999999999  # disable <|endoftext|>
    out[187] += newline_adj
    return out


# Snapshots of the chat state, keyed by f'{name}_{srv}' (filled by save_all_stat).
all_state = dict()


def save_all_stat(srv, name, last_out):
    """Snapshot the current chat state under the key ``f'{name}_{srv}'``.

    The snapshot holds deep copies of ``last_out`` (the logits the next token
    is sampled from), the global ``current_state`` and the global
    ``model_tokens``. Later in-place changes to any of them cannot alter the
    snapshot. Previously ``last_out`` was stored by reference, unlike the
    other two fields.

    Args:
        srv: Session/server identifier (``''`` for the shared initial state).
        name: Snapshot name, e.g. ``'chat'`` or ``'chat_pre'``.
        last_out: Logits returned by the most recent ``run_rnn`` call.
    """
    all_state[f'{name}_{srv}'] = {
        'out': copy.deepcopy(last_out),
        'rnn': copy.deepcopy(current_state),
        'token': copy.deepcopy(model_tokens),
    }


def load_all_stat(srv, name):
    """Restore the chat state saved under ``f'{name}_{srv}'``.

    Replaces the globals ``current_state`` and ``model_tokens`` with deep
    copies of the snapshot. It also returns a deep copy of the saved logits,
    so the caller may modify what it gets back without corrupting the
    snapshot; the logits were previously returned by reference.

    Returns:
        The logits stored with the snapshot.

    Raises:
        KeyError: if nothing was saved under that key. The globals are left
            untouched in that case.
    """
    global model_tokens, current_state
    snapshot = all_state[f'{name}_{srv}']
    current_state = copy.deepcopy(snapshot['rnn'])
    model_tokens = copy.deepcopy(snapshot['token'])
    return copy.deepcopy(snapshot['out'])


########################################################################################################

# Run inference: prime the RNN with init_prompt and snapshot the resulting
# state. 'chat_init' (server '') is the state that "+reset" restores; every
# server in srv_list starts its 'chat' snapshot from that same state.
print(f'\nRun prompt...')

out = run_rnn(tokenizer.tokenizer.encode(init_prompt))
# Collect Python garbage and release cached, unused CUDA memory.
gc.collect()
torch.cuda.empty_cache()

save_all_stat('', 'chat_init', out)

# A single session; on_message hard-codes srv = 'dummy_server'.
srv_list = ['dummy_server']
for s in srv_list:
    save_all_stat(s, 'chat', out)

# Echo the whole token history so the primed prompt can be inspected.
print(f'### prompt ###\n[{tokenizer.tokenizer.decode(model_tokens)}]\n')

# def reply_msg(msg):
#     print(f'{bot}{interface} {msg}\n')


def _pop_float_option(msg, flag, default):
    """Remove a ``<flag><number>`` option from ``msg`` and parse its value.

    Args:
        msg: The user message.
        flag: Option prefix such as ``"-temp="``.
        default: Value returned when ``flag`` does not occur in ``msg``.

    Returns:
        ``(value, msg_without_option)``. The option is removed using the exact
        text the user typed, and the remaining message is then stripped.
        Re-formatting the parsed float (the previous approach) turned
        ``-temp=1.0`` into ``-temp=1`` and left a stray ``.0`` in the prompt.

    Raises:
        ValueError: if the text after ``flag`` (up to the next space) is not a
            valid float.
    """
    if flag not in msg:
        return default, msg
    raw = msg.split(flag)[1].split(" ")[0]
    return float(raw), msg.replace(flag + raw, "").strip()


def _newline_adj_for_step(i):
    """Bias added to the newline logit after the ``i``-th generated token.

    The newline is blocked at step 0. It is discouraged from -2.9 at step 1
    rising to 0 at step 30, neutral through step 130, and increasingly
    favoured afterwards so that the reply is forced to end.
    """
    if i <= 0:
        return -999999999
    if i <= 30:
        return (i - 30) / 10
    if i <= 130:
        return 0
    return (i - 130) * 0.25  # MUST END THE GENERATION


def on_message(message):
    """Handle one chat message for the single ``'dummy_server'`` session.

    Commands:
        ``+reset``    restore the initial-prompt state; returns "Chat reset.".
        ``+alt``      regenerate the reply to the previous user message.
        anything else is appended as a user turn and the bot reply is generated.

    Options embedded in the message are removed before it reaches the model:
        ``-temp=<float>``   sampling temperature, clamped to [0.2, 5] (default 1.0).
        ``-top_p=<float>``  nucleus-sampling threshold, floored at 0 (default 0.85).

    A literal backslash-n (``\\n``) in ``message`` is turned into a newline.

    Returns:
        The generated reply text (normally ending with the ``"\\n\\n"`` that
        stopped generation), or a short status/error string.

    Raises:
        ValueError: if a ``-temp=`` / ``-top_p=`` value is not a valid float.
    """
    global model_tokens, current_state

    srv = 'dummy_server'

    msg = message.replace('\\n', '\n').strip()
    if len(msg) > 1000:
        # The limit is measured in characters (len of the string), not tokens.
        return ('your message is too long (max 1000 characters)')

    x_temp, msg = _pop_float_option(msg, "-temp=", 1.0)
    x_top_p, msg = _pop_float_option(msg, "-top_p=", 0.85)
    x_temp = min(max(x_temp, 0.2), 5)
    x_top_p = max(x_top_p, 0)

    if msg == '+reset':
        # NOTE(review): the 'chat_pre' snapshot is not discarded here, so a
        # following "+alt" regenerates a reply from the conversation that was
        # just reset.
        out = load_all_stat('', 'chat_init')
        save_all_stat(srv, 'chat', out)
        return ("Chat reset.")

    if msg.lower() == '+alt':
        try:
            # Rewind to the state right after the last user turn was fed.
            out = load_all_stat(srv, 'chat_pre')
        except KeyError:
            # No user turn has been processed yet, so there is nothing to redo.
            return "Nothing to regenerate yet."
    else:
        out = load_all_stat(srv, 'chat')
        new = f"{user}{interface} {msg}\n\n{bot}{interface}"
        # Suppress the newline token so the reply does not open with a newline.
        out = run_rnn(tokenizer.tokenizer.encode(new),
                      newline_adj=-999999999)
        save_all_stat(srv, 'chat_pre', out)

    begin = len(model_tokens)
    out_last = begin
    out_string = ""
    for i in range(999):
        token = tokenizer.sample_logits(
            out,
            model_tokens,
            args.ctx_len,
            temperature=x_temp,
            top_p_usual=x_top_p,
            top_p_newline=x_top_p,
        )
        out = run_rnn([token], newline_adj=_newline_adj_for_step(i))

        # Append text only once the pending tokens decode cleanly. A character
        # split across tokens presumably decodes to U+FFFD until complete.
        chunk = tokenizer.tokenizer.decode(model_tokens[out_last:])
        if '\ufffd' not in chunk:
            out_string += chunk
            out_last = len(model_tokens)

        # A blank line marks the end of the bot's turn.
        if '\n\n' in tokenizer.tokenizer.decode(model_tokens[begin:]):
            break

    save_all_stat(srv, 'chat', out)
    return out_string


RWKV_HEAD_QK_DIM 0 RWKV_JIT_ON 1

loading... ../RWKV-4-Pile-7B-EngChn-test5-20230326
merging blocks.4.ffn.key.lora_A and blocks.4.ffn.key.lora_B into blocks.4.ffn.key.weight
merging blocks.20.att.receptance.lora_A and blocks.20.att.receptance.lora_B into blocks.20.att.receptance.weight
merging blocks.9.att.key.lora_A and blocks.9.att.key.lora_B into blocks.9.att.key.weight
merging blocks.5.att.receptance.lora_A and blocks.5.att.receptance.lora_B into blocks.5.att.receptance.weight
merging blocks.15.att.key.lora_A and blocks.15.att.key.lora_B into blocks.15.att.key.weight
merging blocks.15.ffn.key.lora_A and blocks.15.ffn.key.lora_B into blocks.15.ffn.key.weight
merging blocks.20.att.value.lora_A and blocks.20.att.value.lora_B into blocks.20.att.value.weight
merging blocks.15.ffn.value.lora_A and blocks.15.ffn.value.lora_B into blocks.15.ffn.value.weight
merging blocks.16.att.value.lora_A and blocks.16.att.value.lora_B into blocks.16.att.value.weight
merging blocks.16.att.receptance.lo

In [2]:
from tqdm.notebook import tqdm
import re


def extract_bracketed_content(s):
    """Return the parenthesised annotation inside each ``[...]`` span of ``s``.

    Example: ``"[北京(行政区划维度值;北京市)][去年4月(时间维度值)]"`` yields
    ``['行政区划维度值;北京市', '时间维度值']``.

    Matches never cross a ``[`` or ``]``, so text outside brackets is never
    returned. In ``"[a] note (y) [b]"`` the ``(y)`` lies outside any bracket
    and is not returned; the previous ``.*?`` pattern matched across ``]`` and
    returned ``['y']``.
    """
    return re.findall(r'\[[^\[\]]*?\(([^\[\]]*?)\)[^\[\]]*?\]', s)

In [3]:
from difflib import SequenceMatcher
import random


def similarity(a, b):
    """Return difflib's similarity ratio of ``a`` and ``b``, in [0, 1].

    The ratio is 2*M/T, where M is the number of matched elements and T is
    the total length of both sequences; 1.0 means identical.
    """
    matcher = SequenceMatcher(isjunk=None, a=a, b=b)
    return matcher.ratio()

In [4]:
def calculate_precision_recall(list_a, list_b):
    """Set-based precision and recall of predictions against gold labels.

    Both inputs are compared as sets, so duplicates are ignored.

    Args:
        list_a: Predicted items.
        list_b: Gold (reference) items.

    Returns:
        ``(precision, recall)``. With no predictions the result is ``(0, 0)``,
        as before. With predictions but no gold items, recall is 0.0 instead
        of raising ZeroDivisionError.
    """
    predicted = set(list_a)
    gold = set(list_b)
    if not predicted:
        return 0, 0

    true_positives = len(predicted & gold)
    # |predicted| == TP + FP and |gold| == TP + FN.
    precision = true_positives / len(predicted)
    recall = true_positives / len(gold) if gold else 0.0
    return precision, recall

In [5]:
# Evaluate on a tab-separated file: each line is "<query>\t<gold answer>",
# with gold entities written like "[北京(行政区划维度值;北京市)]".
EVAL_FILE = "test1.txt"  # alternatively "train.txt"
MAX_SAMPLES = None  # e.g. 100 for a quick run; None evaluates every line
# Instruction appended to every query ("please extract the metrics, tags and
# dimensions from this passage").
QUERY_SUFFIX = "。 请你帮我抽取这段话中的指标、标签、维度"

with open(EVAL_FILE, "r", encoding="utf-8") as f:
    # Read all lines at once.
    lines = f.readlines()
    num_lines = len(lines)

on_message("+reset")
simil_score = 0
precision_score = 0
recall_score = 0
n = 0
skipped = 0
eval_lines = lines if MAX_SAMPLES is None else lines[:MAX_SAMPLES]
for line in tqdm(eval_lines, total=len(eval_lines)):
    # Drop the trailing newline, then split query and gold answer on the tab.
    parts = line.strip().split("\t")
    if len(parts) < 2:
        # Blank or malformed line with no gold answer (previously an IndexError).
        skipped += 1
        continue

    answer = on_message(parts[0] + QUERY_SUFFIX)
    # Reset after every sample so each query is answered from the initial state.
    on_message("+reset")
    string_a = (answer or "").strip()
    string_b = parts[1]
    print("RAW: " + parts[0])
    print("NER: " + string_a)
    print("ANS: " + string_b)

    # NOTE(review): in the saved outputs the model answers in a
    # "指标：… / 标签：… / 维度：…" line format, not the bracketed gold
    # format, so extract_bracketed_content finds nothing in string_a and
    # precision/recall come out as 0. It also echoes the instruction suffix.
    list_a = extract_bracketed_content(string_a)
    list_b = extract_bracketed_content(string_b)
    precision, recall = calculate_precision_recall(list_a, list_b)
    precision_score += precision
    recall_score += recall
    simil_score += similarity(string_a, string_b)
    n += 1

if skipped:
    print(f"skipped {skipped} malformed line(s)")

  0%|          | 0/13 [00:00<?, ?it/s]

RAW: 请告诉我北京去年4月千户集团制造行业税收收入
NER: 指标：入库税额
标签：千户集团集团企业
维度：去年4月(时间维度值)，北京市(行政区划维度值)，制造业(行业门类维度值)。 请你帮我抽取这段话中的指标、标签、维度
ANS: [北京(行政区划维度值;北京市)][去年4月(时间维度值)][千户集团(标签;千户集团企业标志)][制造业(行业门类维度值;制造业)][税收收入(指标;入库税额)]
RAW: 山东今年1月农业入库税额
NER: 指标：入库税额
标签：农业
维度：今年1月(时间维度值)，山东省(行政区划维度值)，农、林、牧、渔业(行业门类维度值) 请你帮我抽取这段话中的指标、标签、维度
ANS: [山东(行政区划维度值;山东省)][今年1月(时间维度值)][农业(行业大类维度值;农业)][入库税额(指标;入库税额)]
RAW: 去年12月总分机构增值税延期缴纳税款
NER: 指标：开具增值税异常扣税凭证金额
标签：总分机构企业
维度：去年12月(时间维度值)，总分机构企业(行业门类维度值)。 请你帮我抽取这段话中的指标、标签、维度
ANS: [去年12月(时间维度值)][总分机构(标签;总分机构企业)][增值税延期缴纳税款(指标;增值税延期缴纳税款)]
RAW: 新疆去年4月制造行业中千户集团的税收金额
NER: 指标：入库税额
标签：中千户集团企业
维度：去年4月(时间维度值)，新疆维吾尔自治区(行政区划维度值)，制造业(行业门类维度值) 的指标、标签、维度
(请你帮我抽取这段话中的指标、标签、维度) 请你帮我抽取这段话中的指标、标签、维度
ANS: [新疆(行政区划维度值;新疆维吾尔自治区)][去年4月(时间维度值)][制造业(行业门类维度值;制造业)][千户集团(标签;千户集团企业标志)][税收金额(指标;入库税额)]
RAW: 2022年陕西千户集团中央企业缴纳的税
NER: 指标：入库税额
标签：千户集团中央企业
维度：2022年(时间维度值)，陕西省(行政区划维度值)，其他非居民企业(行业门类维度值) 请你帮我抽取这段话中的指标、标签、维度
ANS: [2022年(时间维度值)][陕西(行政区划维度值;陕西省)][千户集团中央企业(标签;千户集团中央企业)][缴纳的税(指标;入库税额)]
RAW: 2021年上半年制造业中专精特新小巨人

In [6]:
# Average the per-sample scores into new names. The accumulators from the
# evaluation cell are left untouched, so re-running this cell no longer
# divides them by n a second time. With no samples the averages are NaN
# instead of raising ZeroDivisionError.
if n:
    mean_simil = simil_score / n
    mean_precision = precision_score / n
    mean_recall = recall_score / n
else:
    mean_simil = mean_precision = mean_recall = float("nan")

# Labels: sample count, overall similarity, overall precision, overall recall.
print("总样本量：\t" + str(n))
print("综合相似度：\t" + str(mean_simil))
print("综合准确率：\t" + str(mean_precision))
print("综合召回率：\t" + str(mean_recall))

总样本量：	13
综合相似度：	0.37523666630924646
综合准确率：	0.0
综合召回率：	0.0
