In [None]:
! pip install langchain

In [1]:
# Directory that holds the State of the Union transcripts (one .txt per speech).
data_dir = "data/state_of_the_union/"

In [2]:
import os

# Full path of every entry found in data_dir (listing order is whatever
# os.listdir returns).
paths = [os.path.join(data_dir, f) for f in os.listdir(data_dir)]

# Load each .txt transcript and collapse blank lines. The context manager
# closes every file promptly instead of leaking the open handle.
texts = []
for path in paths:
    if not path.endswith('.txt'):
        continue
    with open(path, 'r') as speech_file:
        texts.append(speech_file.read().replace('\n\n', '\n'))

In [3]:
from langchain.text_splitter import CharacterTextSplitter

In [4]:
# Split the first loaded speech on "." into chunks of up to 3000 characters,
# with 100 characters of overlap between neighbouring chunks.
text_splitter = CharacterTextSplitter(
    chunk_overlap=100,
    chunk_size=3000,
    separator=".",
)
first_speech = [texts[0]]
docs = text_splitter.create_documents(first_speech)

In [5]:
from langchain.prompts.base import BasePromptTemplate
from langchain import LLMChain, OpenAI, PromptTemplate
from langchain.text_splitter import TextSplitter
from typing import Any, Dict, Optional, List
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.summarize import load_summarize_chain, refine_prompts
from langchain.schema import LLMResult, Generation


class RefineLLMChain(LLMChain):
    """LLMChain variant that summarizes long inputs with the Refine strategy.

    Instead of sending the formatted prompt to the LLM in a single call, the
    value of the prompt's single input variable is split into chunks by
    ``text_splitter`` and summarized incrementally by a "refine" summarize
    chain built with ``load_summarize_chain``. Only one input key is allowed.

    Example:
        .. code-block:: python

            from langchain import OpenAI, PromptTemplate
            from langchain.text_splitter import CharacterTextSplitter

            prompt = PromptTemplate(input_variables=["text"], template="{text}")
            chain = RefineLLMChain(
                llm=OpenAI(),
                prompt=prompt,
                text_splitter=CharacterTextSplitter(
                    separator=".", chunk_size=1000, chunk_overlap=100
                ),
            )
            chain({"text": long_text})
    """

    text_splitter: TextSplitter
    """TextSplitter used to split up the input text if necessary (e.g. input text is too long to use for a single call
    to LLM)
    """
    question_prompt: BasePromptTemplate = refine_prompts.PROMPT
    """Question prompt for refine chain
    """
    refine_prompt: BasePromptTemplate = refine_prompts.REFINE_PROMPT
    """Refine prompt for refine chain
    """

    def generate(  # this is a hack and should be refactored
        self,
        input_list: List[Dict[str, Any]],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> LLMResult:
        """Generate one refine-summary per input instead of one raw LLM call.

        Args:
            input_list: Input dicts; each must contain the prompt's single
                input variable, whose value is the text to summarize.
            run_manager: Optional run manager of the enclosing chain run. Its
                child callbacks are forwarded to the nested refine chain.

        Returns:
            An ``LLMResult`` holding exactly one ``Generation`` per input, in
            input order. Its text is the refine chain's ``output_text``, or
            an empty string when the splitter produced no documents.

        Raises:
            ValueError: If the prompt does not declare exactly one input
                variable.
            KeyError: If an input dict lacks the prompt's input variable.
        """
        input_variables = self.prompt.input_variables
        if len(input_variables) != 1:
            raise ValueError(
                "RefineLLMChain requires a prompt with exactly one input "
                f"variable, got {list(input_variables)}"
            )
        input_key = input_variables[0]

        # Built once per call: the chain only depends on the LLM and prompts.
        refine_chain = load_summarize_chain(
            self.llm,
            chain_type="refine",
            refine_prompt=self.refine_prompt,
            question_prompt=self.question_prompt,
        )
        callbacks = run_manager.get_child() if run_manager else None

        generations = []
        for input_item in input_list:
            text = input_item[input_key]
            docs = self.text_splitter.create_documents([text])
            if not docs:
                # Nothing to summarize (e.g. empty text): the refine chain
                # cannot run on zero documents, so return an empty summary.
                generations.append([Generation(text="")])
                continue
            output = refine_chain(
                {"input_documents": docs},
                return_only_outputs=True,
                callbacks=callbacks,
            )
            generations.append([Generation(text=output["output_text"])])
        return LLMResult(generations=generations)

    # TODO: implement the async counterpart of ``generate``.

In [6]:
# Two separate OpenAI LLM instances: llm1 backs the custom RefineLLMChain,
# llm2 backs the stock refine summarize chain.
# (A single shared instance would probably do.)
llm1, llm2 = OpenAI(), OpenAI()

In [7]:
# Splitter handed to RefineLLMChain: "."-separated chunks of up to 1000
# characters with 100 characters of overlap.
refine_text_splitter = CharacterTextSplitter(
    chunk_overlap=100,
    chunk_size=1000,
    separator=".",
)

In [8]:
# Pass-through prompt: its only job is to name the single input variable
# ("text") that RefineLLMChain reads. Kinda hacky but works.
prompt_template = "{text}"
prompt = PromptTemplate(template=prompt_template, input_variables=["text"])

refineLLMChain = RefineLLMChain(
    text_splitter=refine_text_splitter,
    prompt=prompt,
    llm=llm1,
)

Test the output of the custom RefineLLMChain

In [9]:
# Run the custom chain on the text of the first chunk and display the result.
refineLLMChain(dict(text=docs[0].page_content))

{'text': '\n\nThree years ago, the US launched the Great American Comeback, rejecting the mentality of American decline and downsizing of its destiny. Since then, President Trump has moved rapidly to revive the economy, slashing job-killing regulations, enacting historic tax cuts and fighting for fair trade agreements. This has resulted in 7 million new jobs, 5 million more than predicted, and the lowest unemployment rate in over half a century. The economy is the best it has ever been, the military is completely rebuilt, borders are secure, families are flourishing, values are renewed, pride is restored, and the state of the Union is stronger than ever before. This vision demonstrates how we are building the world’s most prosperous and inclusive society. With the reversal of the failed economic policies of the previous administration, the world is now witnessing this great economic success with the unemployment rate for African Americans, Hispanic Americans, and Asian Americans reachi

In [10]:
# Stock refine summarize chain (default prompts), used below as the combine step.
refine_chain = load_summarize_chain(
    llm2,
    chain_type="refine",
)

In [11]:
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain

# Map step: RefineLLMChain summarizes each document on its own.
# Combine step: the refine chain merges those per-document summaries.
map_reduce_chain = MapReduceDocumentsChain(
    return_intermediate_steps=True,
    combine_document_chain=refine_chain,
    llm_chain=refineLLMChain,
)

Test the custom MapReduceDocumentsChain, which uses a refine summarization strategy for both the map and reduce steps

In [12]:
# Only the first three chunks are summarized here to keep the test run small.
output = map_reduce_chain(dict(input_documents=docs[:3]))

In [13]:
output

{'input_documents': [Document(page_content='Three years ago, we launched the great American comeback.  Tonight, I stand before you to share the incredible results.  Jobs are booming, incomes are soaring, poverty is plummeting, crime is falling, confidence is surging, and our country is thriving and highly respected again.  (Applause.)  America’s enemies are on the run, America’s fortunes are on the rise, and America’s future is blazing bright.\nThe years of economic decay are over.  (Applause.)  The days of our country being used, taken advantage of, and even scorned by other nations are long behind us.  (Applause.)  Gone too are the broken promises, jobless recoveries, tired platitudes, and constant excuses for the depletion of American wealth, power, and prestige.\nIn just three short years, we have shattered the mentality of American decline, and we have rejected the downsizing of America’s destiny.  We have totally rejected the downsizing.  We are moving forward at a pace that was 

### Now we will use a custom MapReduceDocumentsChain to summarize and contrast the State of the Union speeches of Presidents Obama, Trump, and Biden

In [14]:
# Prompt used to produce the initial summary from the first document.
prompt_template = """Write a concise summary of the following:

{text}

CONCISE SUMMARY:"""
PROMPT = PromptTemplate(template=prompt_template, input_variables=["text"])

# Prompt used for every later document: refine the running summary
# ({existing_answer}) with the next piece of context ({text}).
# Adjacent string literals are concatenated verbatim, so each fragment that
# continues a sentence must carry its own separating space.
refine_template = (
    "Your job is to produce a final summary contrasting the state of the union speeches of President Biden, Trump, and Obama.\n"
    "We have provided an existing summary up to a certain point: {existing_answer}\n"
    "We have the opportunity to refine the existing summary "
    "(only if needed) with some more context below.\n"
    "------------\n"
    "{text}\n"
    "------------\n"
    "Given the new context, refine the original summary. "
    "If the context isn't useful, return the original summary."
)
refine_prompt = PromptTemplate(
    input_variables=["existing_answer", "text"],
    template=refine_template,
)

# Refine chain used as the combine step, now with the contrast-oriented prompts.
refine_chain = load_summarize_chain(llm2, chain_type="refine", question_prompt=PROMPT, refine_prompt=refine_prompt)

In [15]:
# Larger chunks (up to 3000 characters, 100 overlap) for the full-speech run.
refine_text_splitter = CharacterTextSplitter(
    chunk_overlap=100,
    chunk_size=3000,
    separator=".",
)

In [16]:
# Rebuild the pass-through prompt and the custom chain so the chain picks up
# the refine_text_splitter currently in scope. Kinda hacky but works.
prompt_template = "{text}"
prompt = PromptTemplate(template=prompt_template, input_variables=["text"])

refineLLMChain = RefineLLMChain(
    text_splitter=refine_text_splitter,
    prompt=prompt,
    llm=llm1,
)

In [17]:
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain

# Map step: RefineLLMChain condenses each speech.
# Combine step: refine_chain merges the per-speech summaries.
map_reduce_chain = MapReduceDocumentsChain(
    return_intermediate_steps=True,
    combine_document_chain=refine_chain,
    llm_chain=refineLLMChain,
)

In [18]:
# Derive each president's name from the transcript's file name without its
# extension. os.path handles the path separator and strips only the trailing
# ".txt" rather than every occurrence in the name. Same filter and order as
# `texts`, so the two lists line up index by index.
president_names = [
    os.path.splitext(os.path.basename(path))[0]
    for path in paths
    if path.endswith('.txt')
]

In [19]:
from langchain.schema import Document

# One Document per speech, prefixed with the president's name so the
# summaries can tell the speakers apart.
docs = [
    Document(page_content=f"President {president_names[i]}:\n{speech}")
    for i, speech in enumerate(texts)
]

In [None]:
# Summarize every speech, then contrast them in the combine step.
output = map_reduce_chain(dict(input_documents=docs))