In [1]:
import json
import re
from dataclasses import dataclass, field
from functools import cache
from typing import Protocol

import yaml

from APITool import Msg, chat, Messages, create_client
from DBTool import DatabaseLike, SQliteDB




class PiplineLike(Protocol):
    """Structural type for any object that maps one conversation to another.

    An object satisfies this protocol simply by exposing a compatible ``invoke``
    method (``Rewriter`` is one such object); no inheritance is required.
    """

    def invoke(self, sample:Messages) -> Messages:
        """Transform the conversation ``sample`` and return the resulting conversation."""
        ...

class Rewriter:
    """Rewrite the assistant turn of a (user, assistant) conversation via a chain of LLM prompts.

    Each step loaded from the prompt YAML file is rendered with the original user text and
    the current assistant text and sent to ``chat``; the reply becomes the assistant text
    fed to the next step. The final reply must contain a fenced ```json block whose
    ``"B"`` key holds the rewritten assistant message.
    """

    @dataclass
    class _Prompt:
        """One rewriting step, built from one mapping in the prompt YAML file."""
        prompt:str         # template with {{user}} / {{assistant}} placeholders
        temperature:float  # sampling temperature passed to chat()
        # Few-shot examples; not used by apply() yet, so they may be absent from the YAML.
        fewshot:list[Messages] = field(default_factory=list)

        def apply(self, user:str, assistant:str) -> str:
            """Return ``prompt`` with ``{{user}}`` and ``{{assistant}}`` filled in.

            Substitution is a single pass over the template, so placeholder text that
            happens to appear inside ``user`` is not expanded again (chained
            ``str.replace`` calls would substitute ``{{assistant}}`` inside the user text).
            """
            # may add few shots next time
            values = {"user": user, "assistant": assistant}
            # A callable replacement keeps backslashes in the texts literal.
            return re.sub(r"\{\{(user|assistant)\}\}", lambda m: values[m.group(1)], self.prompt)


    @staticmethod
    def _read_yaml_prompt_file(file_path:str) -> list[_Prompt]:
        """Load the rewriting steps from a UTF-8 YAML file.

        The file must hold a non-empty YAML list of mappings whose keys match the
        ``_Prompt`` fields.

        Raises:
            AssertionError: if the file is empty or its top level is not a list.
        """
        with open(file_path, encoding='u8') as f:
            results = yaml.safe_load(f)
        # safe_load returns None for an empty document; a truthiness check makes that
        # case hit the intended assertion instead of a TypeError from len(None).
        assert results, f"{file_path} is an empty file"
        assert isinstance(results, list), f"{file_path} must contain a YAML list of prompt steps"
        return [Rewriter._Prompt(**res) for res in results]

    def __init__(self, prompt_path:str) -> None:
        """Load the prompt steps from ``prompt_path`` and create the chat client."""
        self.steps:list[Rewriter._Prompt] = self._read_yaml_prompt_file(prompt_path)
        self._client = create_client()

    def invoke(self, sample:Messages) -> Messages:
        """Rewrite the assistant message of ``sample``.

        Args:
            sample: exactly two messages; ``sample[0]["content"]`` is taken as the user
                text and ``sample[1]["content"]`` as the assistant text.

        Returns:
            ``[Msg("user", user), Msg("assistant", rewritten)]``.

        Raises:
            AssertionError: if ``sample`` does not contain exactly two messages.
            json.JSONDecodeError: if the final reply has no ```json block, or its
                content is not valid JSON.
            KeyError: if the parsed JSON object has no ``"B"`` key.
        """
        assert len(sample) == 2
        user:str = sample[0]["content"]
        assistant:str = sample[1]["content"]
        for curstep in self.steps:
            # chat() returns a pair; only its first element (the reply text) is used.
            assistant, _ = chat(self._client, curstep.apply(user, assistant), temperature=curstep.temperature)
        assistant = self._extract_rewrite(assistant)
        return [Msg("user", user), Msg("assistant", assistant)]

    @staticmethod
    def _extract_rewrite(reply:str) -> str:
        """Return the ``"B"`` field of the first ```json fenced block in ``reply``.

        Raises:
            json.JSONDecodeError: if there is no ```json fence or its content is invalid JSON.
            KeyError: if the parsed object has no ``"B"`` key.
        """
        fence = '```json'
        start = reply.find(fence)
        if start == -1:
            # Previously find() == -1 silently made the slice start at index 6.
            # Raise the same error type json.loads uses, so callers see one error class.
            raise json.JSONDecodeError(f"no {fence} block in model reply", reply, 0)
        start += len(fence)
        end = reply.find('```', start)
        if end == -1:
            # Unterminated fence: parse to the end rather than slicing with end=-1,
            # which would drop the last character.
            end = len(reply)
        return json.loads(reply[start:end])["B"]

    @classmethod
    def from_prompt_file(cls, file_path:str) -> 'Rewriter':
        """Alternative constructor; equivalent to ``cls(file_path)``."""
        return cls(file_path)


@cache
def get_random_conv(db:DatabaseLike, num:int, seed:float=114.514) -> list[Messages]:
    """Sample ``num`` conversations from ``db`` as role/content message lists.

    Rows come from ``db.random_get(num, seed)``. Column 6 of each row is parsed as a
    JSON list of turns, presumably ShareGPT-style ``{"from": ..., "value": ...}``
    dicts (TODO confirm against the table schema). Each turn is converted as follows:
    ``from`` -> ``role`` (``human`` -> ``user``, ``gpt`` -> ``assistant``, other roles
    kept verbatim) and ``value`` -> ``content``. Any other keys are kept.

    The JSON is parsed before keys are renamed, so the conversion does not depend on
    the exact whitespace the stored JSON was serialized with.

    Results are memoized per ``(db, num, seed)`` by ``functools.cache``: repeated calls
    return the same list object, so callers should not mutate it.
    """
    role_names = {"human": "user", "gpt": "assistant"}

    def to_chat_turn(turn:dict) -> dict:
        # Rename keys while preserving their order and any keys we don't know about.
        converted = {}
        for key, value in turn.items():
            if key == "from":
                converted["role"] = role_names.get(value, value)
            elif key == "value":
                converted["content"] = value
            else:
                converted[key] = value
        return converted

    rows = db.random_get(num, seed)
    return [[to_chat_turn(turn) for turn in json.loads(row[6])] for row in rows]

def apply_pipline(pipline:PiplineLike, sample:Messages) -> Messages:
    """Run ``sample`` through ``pipline`` and return whatever its ``invoke`` produces."""
    return pipline.invoke(sample)


In [2]:
# Path to the SQLite database that get_random_conv samples conversations from.
DB_PATH = "./data/magpieQwen.db"
db = SQliteDB.connect_db(DB_PATH)

In [4]:
# main: rewrite a few randomly sampled conversations and print the results.
PROMPT_PATH = "PiplinePrompt.yaml"  # YAML list of rewrite steps loaded by Rewriter
N_SAMPLES = 2                       # number of conversations to sample from db

pipline = Rewriter.from_prompt_file(PROMPT_PATH)
samples = get_random_conv(db, N_SAMPLES)
for samp in samples:
    results = apply_pipline(pipline, samp)
    print(results)


[
user: 假设你需要自助机器人为一个在线销售平台开发一套有利于吸引顾客的自主推荐策略。请根据客户数据（例如浏览历史、购物车项目、购买历史、个人喜好，等等）与产品信息（例如类别信息、销售数量、评价、物品标签，等等）进行输出。
, 
assistant: 明白了！首先得收集并预处理客户数据和产品信息，然后构建用户画像和商品分析，选个合适的推荐算法，设计个性化推荐、热门商品推荐、交叉销售等策略。实时更新优化，同时确保隐私保护合规。这样能提升用户体验和商业价值哦~ (๑•̀ㅂ•́)و✧
]
[
user: 利用机器学习算法预测用户对于某种商品的偏好
, 
assistant: 首先得收集和预处理用户行为和商品信息数据，然后搞特征工程，选个合适的模型，比如协同过滤、矩阵分解或者深度学习模型，训练评估一下，最后优化部署，就能实现个性化推荐啦~ (≧▽≦)
]


In [None]:
# get_random_conv is wrapped in functools.cache, whose keys include `db`;
# clear it so the cache stops holding the connection object, then close it.
get_random_conv.cache_clear()
db.close()