##  加载外部数据

In [68]:


# 将数据切分成
from dataclasses import dataclass, field
from math import ceil, floor
from enum import Enum
from typing import List, Literal, TypedDict, Optional
import tiktoken
import openai
import os
from dotenv import load_dotenv

class MessageRole(Enum):
    # chat role
    User = "user"
    System = "system"
    Asisstant = "asisstant"
    # inner role
    You = "you"
    YourComputer = "your computer"
class MessageType(Enum):
    AIResponse = "ai_response"
    ActionResults = "action_results"


# MessageRole = Literal["system", "user", "assistant"]
# MessageType = Literal["ai_response", "action_results"]


class MessageDict(TypedDict):
    role: MessageRole
    content: str
@dataclass
class Message:
    role: MessageRole
    content: str
    # TODO msg_type 的意义是什么？
    msg_type: Optional[MessageType] = None

    def raw(self) -> MessageDict:
        return {"role": self.role, "content": self.content}


def count_message_tokens(messages: List[Message], model: str = "gpt-3.5-turbo") -> int:
    encoding = tiktoken.encoding_for_model(model)
    tokens_per_message = 3
    tokens_per_name = 1
    if model.startswith("gpt-3.5"):
        tokens_per_message = 4
        tokens_per_name = -1
    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for key, value in message.raw().items():
            num_tokens += len(encoding.encode(value))
            if key == "name":
                num_tokens += tokens_per_name
    num_tokens += 3  # every reply is primed with <|start|> assistant<|message>
    return num_tokens


def count_string_tokens(string: str, model_name: str = "gpt-3.5-turbo") -> int:
    encoding = tiktoken.encoding_for_model(model_name)
    return len(encoding.encode(string))  


def get_completion_from_messages(messages, model='gpt-3.5-turbo', temperature=0, max_tokens=500):
    response = openai.ChatCompletion.create(
        model=model,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
    )
    return response.choices[0].message["content"]

def create_msg(msg:str):
    
    messages = [
   
        {"role": "user", "content": msg},
    ]
    return messages

load_dotenv("../etc/dev/.env")
openai_api_key: str = os.getenv("OPENAI_API_KEY")
openai.api_key = openai_api_key

msg = "北京是哪个国家的首都"
print(get_completion_from_messages(messages=create_msg(msg)))

print(f'count({msg}) ={count_string_tokens(msg)} ')


北京是中华人民共和国的首都。
count(北京是哪个国家的首都) =10 


## 加载KG

In [52]:
import json
import time
import traceback
template_name = "./kg/kg_industry_ana.txt"
template_str = open(template_name).read()
import jinja2
environment = jinja2.Environment()
template = environment.from_string(template_str)

content = open('data.txt').read()
# print(content) 
template_count = count_string_tokens(template_str)
OPENAI_MAX_TOKEN = 4096 - 300
content_token_count = OPENAI_MAX_TOKEN - template_count - 700
# print(f"template count is {template_count}, conent_count is {content_token_count}")

from langchain.text_splitter import CharacterTextSplitter
text_spliter = CharacterTextSplitter(
        separator="\n", chunk_size=content_token_count, chunk_overlap=100, length_function=len
    )
# print(para)
tmp_para_list = text_spliter.split_text(content)
# print(len(tmp_para_list))
tmp_title_list = []
tmp_meta_list = []
index = 0
num_para = len(tmp_para_list)
max_count = -1
max_item = ''
output_file = 'output/600271.kg_industry_ana_in_{}.txt'
openai_out = 'output/600271.kg_industry_ana.openaiout.jsonl'
openai_raw_out = 'output/600271.kg_industry_ana.openaiout.raw'
writer = open(openai_out, "a")
raw_writer = open(openai_raw_out, "w")

for para_item in tmp_para_list:
    item = template.render(content_fragment=para_item)
 
    
    
    token_count = count_string_tokens(item)
    if token_count > 0 and token_count < OPENAI_MAX_TOKEN:
        open(output_file.format(index), "w").write(item)
        
        data = get_completion_from_messages(create_msg(item))
        try:
            raw_writer.write(f"{index}\n\n" + data + "\n\n\n")
            data_str = json.loads(data)
            writer.write(json.dumps(data_str, ensure_ascii=False) + "\n")
            print(f"process done {index+1} / {num_para}")
            
        except Exception as e:
            traceback.print_exc()
            print(e)
            print(f"process wrong {index+1} / {num_para}")
            
        time.sleep(1)
        index += 1 
    if token_count > max_count:
        max_count = token_count
        max_item = item
        
# print(f"MAX TOKEN COUNT IS {max_count}")
# print(max_item)
print(len(tmp_para_list))

Created a chunk of size 1972, which is longer than the specified 1763
Created a chunk of size 2452, which is longer than the specified 1763
Created a chunk of size 4954, which is longer than the specified 1763


process done 1 / 15
process done 2 / 15
process done 3 / 15
process done 4 / 15
process done 5 / 15
process done 6 / 15
process done 7 / 15
process done 8 / 15
process done 9 / 15
process done 10 / 15
process done 11 / 15
process done 12 / 15
process done 13 / 15
process done 14 / 15
15


In [79]:
## merge
lines = open(openai_out, 'r').readlines()
data = []
for line in lines:
    item_info_list = json.loads(line.strip())
    result = {}
    for item in item_info_list:
        for key, value in item.items():
            if value.find("未提及") != -1 or value.find("不适用") != -1:
                continue
            else:
                result[key] = value
                # print(key, value)
    if len(result) > 0:
        data.append(result)
openai_merge_out = 'output/600271.kg_industry_ana.openaiout.merge.jsonl'
openai_merge_raw_out = 'output/600271.kg_industry_ana.openaiout.merge.raw'
writer = open(openai_merge_out, 'w')
raw_writer = open(openai_merge_raw_out, "w")
template_name = "./kg/kg_industry_ana_merge.txt"
template_str = open(template_name).read()
import jinja2
environment = jinja2.Environment()
template = environment.from_string(template_str)
prompts = template.render(enumerate=enumerate, lines=data)
# print(prompts)
open("/tmp/tmp.tmp.txt", 'w').write(prompts)
print(count_string_tokens(prompts))

data = get_completion_from_messages(create_msg(prompts), model='gpt-3.5-turbo-16k-0613')
try:
    print(data)
    raw_writer.write(f"{index}\n\n" + data + "\n\n\n")
    data_str = json.loads(data)
    writer.write(json.dumps(data_str, ensure_ascii=False) + "\n")

            
except Exception as e:
    traceback.print_exc()
    print(e)
    print(f"process wrong {index+1} / {num_para}")

2886
[
    {"公司所处的行业是什么,公司从事哪些主要的什么业务": "航天信息公司从事信息技术业务，主要涉及数字政府和企业数字化两个产业"},
    {"行业基本情况、周期性特点": "信息技术行业处于成熟阶段，市场规模较大，竞争激烈，增长趋缓"},
    {"公司所处的行业地位怎样？": "航天信息公司在信息技术行业中具有一定的市场份额，是国内领先的信息技术集团之一"},
    {"最近是否有新的和公司所处行业相关的的法律、行政法规、部门规章、行业政策？这些内容对行业有什么影响": "最近发布了《关于加强数字政府建设的指导意见》和《关于构建数据基础制度更好发挥数据要素作用的意见》，对行业发展有积极影响"},
    {"管理层对行业格局和趋势的分析和展望是什么，公司在这种情况下的优势和劣势是什么": "管理层认为数字经济、数字化产业发展趋势明显，公司在数字政府、企业数字化领域具备领先地位和一定优势基础，但需要保持竞争优势和适应变化的能力"}
]
