In [None]:
from typing import Any, Dict, List, Optional
from langchain import BasePromptTemplate, LLMChain
from langchain.schema.language_model import BaseLanguageModel
from langchain.callbacks.manager import (
    AsyncCallbackManagerForChainRun,
    CallbackManagerForChainRun,
)
from pydantic import Extra
from langchain.schema import LLMResult
from langchain.chains.base import Chain
from langchain.prompts import PromptTemplate

In [None]:
class RewriteSplitterChain(LLMChain):
    """LLM chain that asks the model to rewrite a passage of text.

    The default prompt (see ``from_llm``) embeds the input under a
    "text content" heading and instructs the model to rewrite it without
    omitting any information from the original.
    """

    # Prompt template rendered once per input dict.
    prompt: BasePromptTemplate
    # Language model that receives the rendered prompts.
    llm: BaseLanguageModel
    # Name of the single output key exposed by this chain.
    output_key: str = "text"  #: :meta private:

    class Config:
        """Pydantic settings: reject unknown fields, allow arbitrary types."""

        arbitrary_types_allowed = True
        extra = Extra.forbid

    @property
    def input_keys(self) -> List[str]:
        """Return the variable names the prompt template expects.

        :meta private:
        """
        expected_variables = self.prompt.input_variables
        return expected_variables

    @property
    def output_keys(self) -> List[str]:
        """Return a one-element list holding ``output_key``.

        :meta private:
        """
        only_key = self.output_key
        return [only_key]

    def generate(
        self,
        input_list: List[Dict[str, Any]],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> LLMResult:
        """Render prompts for ``input_list`` and send them to the LLM.

        Args:
            input_list: One dict of prompt variables per generation request.
            run_manager: Optional callback manager; when given, its child
                manager is forwarded to the LLM call.

        Returns:
            The ``LLMResult`` produced by ``llm.generate_prompt``.
        """
        rendered_prompts, stop_words = self.prep_prompts(
            input_list, run_manager=run_manager
        )
        child_callbacks = None
        if run_manager:
            child_callbacks = run_manager.get_child()
        return self.llm.generate_prompt(
            prompts=rendered_prompts,
            stop=stop_words,
            callbacks=child_callbacks,
        )

    async def agenerate(
        self,
        input_list: List[Dict[str, Any]],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> LLMResult:
        """Asynchronous counterpart of ``generate``.

        Args:
            input_list: One dict of prompt variables per generation request.
            run_manager: Optional async callback manager; when given, its
                child manager is forwarded to the LLM call.

        Returns:
            The ``LLMResult`` produced by ``llm.agenerate_prompt``.
        """
        rendered_prompts, stop_words = await self.aprep_prompts(
            input_list, run_manager=run_manager
        )
        child_callbacks = None
        if run_manager:
            child_callbacks = run_manager.get_child()
        return await self.llm.agenerate_prompt(
            prompts=rendered_prompts,
            stop=stop_words,
            callbacks=child_callbacks,
        )

    @property
    def _chain_type(self) -> str:
        """Return the identifier string for this chain type."""
        return "rewrite_splitter_chain"

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        **kwargs: Any,
    ) -> Chain:
        """Build the chain around ``llm`` using the built-in rewrite prompt.

        Args:
            llm: Language model the chain should call.
            **kwargs: Extra keyword arguments passed straight to the class
                constructor (e.g. ``verbose``).

        Returns:
            A new instance of ``cls`` wired to ``llm`` and the rewrite prompt.
        """
        # The prompt has a single variable, ``{input}``: the text to rewrite.
        rewrite_template = """文本内容
---------------------
{input}

要求
-------------------
- 根据上面的文本内容，重新写一段文字
- 不要遗漏文本内容的任何一点信息


整理结果
-------------------
"""
        rewrite_prompt = PromptTemplate.from_template(template=rewrite_template)
        return cls(llm=llm, prompt=rewrite_prompt, **kwargs)

In [None]:
from langchain.chat_models import ChatOpenAI
from langchain import OpenAI
import os
from dotenv import load_dotenv

load_dotenv(dotenv_path="env")


gpt35_1 = OpenAI(temperature=0.1, max_tokens=2048, verbose=True)
gpt35_9 = OpenAI(temperature=0.9, max_tokens=2048, verbose=True)
chat_gpt35_1 = ChatOpenAI(temperature=0.1, verbose=True)
chat_gpt35_9 = ChatOpenAI(temperature=0.9, verbose=True)
gpt4 = ChatOpenAI(model_name="gpt-4", temperature=0.9, verbose=True)


rewiter_chain=RewriteSplitterChain.from_llm(llm=gpt4,verbose=True)

import pandas as pd

filename= os.getenv("FILE_NAME")
df=pd.read_csv(f"s3://sagemaker-automated-execution-034700280673-us-east-1/data_sample/{filename}.csv")
ncontent_data=[]
for content in df["content"]:
    re_content= await rewiter_chain.arun(content)
    print(re_content)
    ncontent_data.append(re_content)
df["n_content"]=pd.Series(ncontent_data)
df.to_csv(f"s3://sagemaker-automated-execution-034700280673-us-east-1/data_sample/{filename}_n.csv",index=False)