In [1]:
import json
import random
import re
from dataclasses import dataclass, field
from functools import cache
from typing import Protocol

import yaml

from APITool import Msg, chat, Messages, create_client
from DBTool import DatabaseLike, SQliteDB




class PiplineLike(Protocol):
    """Structural type for a conversation rewriter.

    Any object exposing a compatible ``invoke`` method (for example ``Rewriter``)
    satisfies this protocol.
    """

    def invoke(self, sample: Messages) -> Messages:
        """Take one conversation ``sample`` and return the rewritten conversation."""
        ...

class PromptLike(Protocol):
    # tuple[str, str] using the second string as false flag
    def apply(self, **kwargs:dict[str,str]) -> str|tuple[str, str]: ...


@dataclass
class RWPrompt:
    res:str
    req:list[str] = field(default="")
    prompt:str = field(default="")
    temperature:float = field(default=0.0)
    json_path:list[str] = field(default_factory=list)
    false_flag:str = field(default="")
    local:dict[str:str|list[str]] = field(default_factory=dict)

    def apply(self, kwargs:dict[str,str]) -> str|tuple[str,str]:
        replaced_prompt = self.prompt
        replaced_prompt = replaced_prompt.replace("{{false_flag}}", self.false_flag)

        # template lerp from outer
        for req in self.req: 
            replaced_prompt = replaced_prompt.replace("{{"+req+"}}",kwargs[req])

        # template lerp from local. where there is a list there is a choice
        for k, v in self.local.items():
            if isinstance(v, list):
                v = random.choice(v)
            assert isinstance(k, str) and isinstance(v, str)
            replaced_prompt = replaced_prompt.replace("{{local_"+k+"}}",v)

        if self.false_flag == "":
            return replaced_prompt
        else:
            return (replaced_prompt,self.false_flag)

class Rewriter:
    """Multi-step conversation rewriter configured by a YAML prompt file.

    Each step is an ``RWPrompt`` that names the state keys it needs (``req``)
    and the key it produces (``res``). ``invoke`` seeds the state with the
    sample's first two turns under ``'user'`` and ``'assistant'``, then runs the
    steps in dependency order. The value produced under ``self.end_point``
    becomes the rewritten assistant reply.
    """

    @staticmethod
    def _split_names(value: str | list | None, sep: str) -> list[str]:
        """Turn a YAML field given as a ``sep``-separated string (or a list) into a list of non-empty names."""
        if value is None:
            return []
        parts = value if isinstance(value, list) else str(value).split(sep)
        # Spaces are removed inside names ("a, b" -> ["a", "b"]). Empty pieces
        # (from "" or a trailing separator) are dropped, because an empty
        # requirement could never be satisfied.
        names = (str(part).replace(" ", "").strip() for part in parts)
        return [name for name in names if name]

    @staticmethod
    def _read_yaml_prompt_file(file_path:str) -> list[PromptLike]:
        """Load the pipeline steps from a YAML file holding a list of step mappings.

        ``req`` is a comma-separated list of state keys and ``json_path`` a
        dot-separated path; either may also be written as a YAML list. A
        missing ``req`` means the step has no requirements.

        Raises:
            AssertionError: if the file contains no steps.
        """
        with open(file_path, encoding='u8') as f:
            results: list[dict] | None = yaml.safe_load(f)
        # safe_load returns None for an empty document, so test truthiness rather than len()
        assert results, f"{file_path} is an empty file"
        for obj in results:
            obj['req'] = Rewriter._split_names(obj.get('req'), ',')
            if 'json_path' in obj:
                obj['json_path'] = Rewriter._split_names(obj['json_path'], '.')
        return [RWPrompt(**obj) for obj in results]

    @classmethod
    def from_prompt_file(cls, file_path:str) -> 'Rewriter':
        """Build a rewriter from the YAML prompt file at ``file_path``."""
        return cls(file_path)

    def __init__(self, prompt_path:str) -> None:
        self.steps:list[RWPrompt] = self._read_yaml_prompt_file(prompt_path)
        self._client = create_client()
        # state key whose value becomes the rewritten assistant message
        self.end_point = "output"

    def _run_step(self, step: RWPrompt, state: dict[str, str]) -> str:
        """Render ``step`` against ``state``, query the model and return its reply.

        Returns ``""`` when the step declares a false flag and the reply contains it.
        """
        instruct = step.apply(state)
        # Reset for every step so a flag from an earlier step never applies here.
        false_flag = ""
        if isinstance(instruct, tuple):
            instruct, false_flag = instruct
        reply, _ = chat(self._client, instruct, temperature=step.temperature, json_path=step.json_path)
        if false_flag and false_flag in reply:
            return ""
        return reply

    def invoke(self, sample:Messages) -> Messages:
        """Rewrite the assistant turn of ``sample`` by running every pipeline step.

        Args:
            sample: conversation whose first message is the user turn and whose
                second is the assistant turn; only their ``"content"`` is read.

        Returns:
            A two-message conversation: the original user text, and the value the
            pipeline produced under ``self.end_point`` as the assistant text.

        Raises:
            RuntimeError: if some steps can never run because their required
                inputs are never produced.
            KeyError: if no step produces ``self.end_point``.
        """
        user: str = sample[0]["content"]
        state: dict[str, str] = {'user': user, 'assistant': sample[1]["content"]}

        pending = list(self.steps)
        while pending:
            blocked = []
            for step in pending:
                if step.res in state:
                    # Already produced (duplicate ``res`` or a seed key): nothing to do.
                    continue
                if all(key in state for key in step.req):
                    # Results become visible to later steps within the same pass.
                    state[step.res] = self._run_step(step, state)
                else:
                    blocked.append(step)
            if len(blocked) == len(pending):
                # A whole pass made no progress; looping again would never terminate.
                missing = sorted({key for step in blocked for key in step.req if key not in state})
                raise RuntimeError(
                    f"cannot run steps {[step.res for step in blocked]}: "
                    f"inputs {missing} are never produced"
                )
            pending = blocked

        return [Msg("user", user).to_dict(), Msg("assistant", state[self.end_point]).to_dict()]


In [2]:


# ShareGPT speaker names -> chat-API role names
_SHAREGPT_ROLES = {"human": "user", "gpt": "assistant"}


def _sharegpt_to_messages(turns: list[dict]) -> list[dict]:
    """Rename ShareGPT-style keys (``from``/``value``) to chat keys (``role``/``content``), preserving key order."""
    messages = []
    for turn in turns:
        message = {}
        for key, value in turn.items():
            if key == "from":
                # unknown speakers (e.g. "system") keep their name as the role
                message["role"] = _SHAREGPT_ROLES.get(value, value)
            elif key == "value":
                message["content"] = value
            else:
                message[key] = value
        messages.append(message)
    return messages


@cache
def get_random_conv(db:DatabaseLike, num:int, seed:float=114.514) -> list[Messages]:
    """Fetch ``num`` random conversations from ``db`` and convert them to chat messages.

    Args:
        db: object exposing ``random_get(num, seed)``. Each returned row must hold
            a JSON-encoded list of ``{"from": ..., "value": ...}`` turns at index 6.
            NOTE(review): the column index is hard-coded; confirm it against the DB schema.
        num: number of conversations to fetch.
        seed: forwarded to ``db.random_get``.

    Returns:
        One message list per row, with ``role``/``content`` keys. Results are
        memoised by ``@cache``, so repeated calls with the same arguments return
        the *same* list objects; do not mutate them in place.
    """
    rows = db.random_get(num, seed)
    # Parse first and rename keys structurally, so the conversion does not depend
    # on the JSON's whitespace or on text inside the messages.
    return [_sharegpt_to_messages(json.loads(row[6])) for row in rows]

def apply_pipline(pipline:PiplineLike, sample:Messages) -> Messages:
    """Run one conversation ``sample`` through ``pipline`` and return what its ``invoke`` produces."""
    return pipline.invoke(sample)


In [3]:
# SQLite database the sample conversations are read from
DB_PATH = "./data/magpieQwen.db"
db = SQliteDB.connect_db(DB_PATH)

In [4]:
# main: rewrite a few random conversations and print each before/after pair

N_SAMPLES = 5
SAMPLE_SEED = 15.909065


def _show(messages) -> None:
    """Pretty-print a message list as indented JSON, keeping non-ASCII text readable."""
    print(json.dumps(messages, ensure_ascii=False, indent=2))


pipline = Rewriter.from_prompt_file("PipelinePrompt.yaml")
samples = get_random_conv(db, N_SAMPLES, SAMPLE_SEED)
for samp in samples:
    try:
        results = apply_pipline(pipline, samp)
        _show(samp)
        _show(results)
    except RuntimeError as err:
        # the pipeline could not complete for this sample; report it and move on
        print("WARNING: THIS CANNOT BE SOLVED!")
        print(err)
        _show(samp)
    print('\n\n')


[
  {
    "role": "user",
    "content": "CAPM模型对风险和收益的关系的描述是 ______"
  },
  {
    "role": "assistant",
    "content": "CAPM（Capital Asset Pricing Model，资本资产定价模型）对风险和收益的关系的描述是：资产的预期回报率与该资产的系统性风险（由贝塔系数β衡量）成正比。具体来说，CAPM模型认为，一个资产的预期回报率等于无风险利率加上该资产的贝塔系数与市场风险溢价的乘积。公式表示为：\n\nE(Ri) = Rf + βi * (E(Rm) - Rf)\n\n其中：\n- E(Ri) 是资产i的预期回报率\n- Rf 是无风险利率\n- βi 是资产i的贝塔系数，衡量资产i相对于市场组合的系统性风险\n- E(Rm) 是市场组合的预期回报率\n\n这个模型表明，投资者应该为承担额外的系统性风险而获得额外的预期回报。换句话说，只有系统性风险（市场风险）才能获得风险溢价，而非系统性风险（特定风险）可以通过多元化投资来消除，因此不会获得额外的风险溢价。"
  }
]
[
  {
    "role": "user",
    "content": "CAPM模型对风险和收益的关系的描述是 ______"
  },
  {
    "role": "assistant",
    "content": "哎呀，这个简单啦~ CAPM模型就是说，资产的预期回报率跟它的系统性风险（贝塔系数β）成正比。具体点说，一个资产的预期回报率等于无风险利率加上贝塔系数乘以市场风险溢价。公式是酱紫的：E(Ri) = Rf + βi * (E(Rm) - Rf)。其中，E(Ri)是资产i的预期回报率，Rf是无风险利率，βi是资产i的贝塔系数，E(Rm)是市场组合的预期回报率。这个模型告诉我们，投资者因为承担了额外的系统性风险，所以应该得到额外的预期回报。只有系统性风险能拿到风险溢价，非系统性风险可以通过多元化投资来消除，所以不会额外给风险溢价哦~ (๑•̀ㅂ•́)و✧"
  }
]



[
  {
    "role": "user",
    "content": "你对韩服 culp-happy beat有何了解？"
  },
  {
    "

In [5]:
# db.close()