# Generating text with LangChain and Hugging Face

We will start by setting up a standard Hugging Face pipeline from our local Vicuna model. From there, it can be used like any other LangChain LLM.

In [32]:
import pandas as pd
import re

from transformers import pipeline, LlamaForCausalLM
from accelerate import Accelerator
import torch

from langchain.llms import HuggingFacePipeline
from langchain.chains import LLMChain
from langchain.chains.base import Chain
from langchain import PromptTemplate
from langchain.output_parsers.regex_dict import RegexDictParser

In [33]:
model_location = '/home/jovyan/project-archive/vicuna-7b'

model = LlamaForCausalLM.from_pretrained(
        model_location,
        load_in_8bit=True,
        torch_dtype=torch.float16,
        device_map={'': Accelerator().local_process_index},
        max_length=4096
    )

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [35]:
pipe = pipeline(model=model,
                tokenizer=model_location,
                use_fast=False,
                task='text-generation',
                model_kwargs={'load_in_8bit': True},
                max_length=2048,
                temperature=0.9,
                top_p=0.95,
                repetition_penalty=1.1,
               )

In [36]:
llm = HuggingFacePipeline(pipeline=pipe)

## Data

In [37]:
df = pd.read_csv('../data/subsections.csv')
subsections = df.clean_text

## LangChain

The questions are looking fairly good. Now let's see if we can first extract the automatically generated questions reliably. Then, we will work on generating answers to those questions with the same model using LangChain.

### Prompt

In [38]:
inference_template = (
    #'The following is a passage from a macroeconomics textbook. 
    'Please provide an inference comprehension question about this passage to assess the learner\'s understanding. '
    'An inference comprehension question will ask the learner to make an educated guess or draw a conclusion based on the information presented in a passage or text. '
    'The learner should be able to adequately answer the question in one or two sentences.\n\n'
    '{source}\n\n'
    'Inference Question:\n\n'
)

inference_prompt = PromptTemplate(
    input_variables=['source'],
    template=inference_template,
)

In [39]:
recall_template = (
    #'The following is a passage from a macroeconomics textbook. 
    'Please provide a recall comprehension question about this passage to assess the learner\'s understanding. '
    'A recall comprehension question will ask the learner to remember specific details or information from the passage. '
    'The learner should be able to adequately answer the question in one or two sentences.\n\n'
    '{source}\n\n'
    'Recall Question:\n\n'
)

recall_prompt = PromptTemplate(
    input_variables=['source'],
    template=recall_template,
)

In [40]:
summary_template = (
    #'The following is a passage from a macroeconomics textbook. 
    'Please provide a summary comprehension question about this passage to assess the learner\'s understanding. '
    'A summary comprehension question will ask the learner to provide a brief overview of the main points or ideas in the passage. '
    'The learner should be able to adequately answer the question in one or two sentences.\n\n'
    '{source}\n\n'
    'Summary Question:\n\n'
)

summary_prompt = PromptTemplate(
    input_variables=['source'],
    template=summary_template,
)
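The three question templates above differ only in the question type and its one-sentence definition. As a minimal sketch (the helper name `build_question_template` is hypothetical, not part of the notebook or of LangChain), the shared wording could be assembled by a single function, which keeps the prompts in sync when the instructions change:

```python
def build_question_template(kind: str, definition: str) -> str:
    """Assemble a question-generation prompt for one comprehension type.

    Hypothetical helper: it reproduces the wording of the templates above
    and leaves the '{source}' placeholder unfilled for PromptTemplate.
    """
    article = 'an' if kind[0].lower() in 'aeiou' else 'a'
    return (
        f"Please provide {article} {kind} comprehension question about this passage "
        "to assess the learner's understanding. "
        f"{definition} "
        "The learner should be able to adequately answer the question in one or two sentences.\n\n"
        "{source}\n\n"  # plain string, so the placeholder survives
        f"{kind.capitalize()} Question:\n\n"
    )


recall_template_sketch = build_question_template(
    'recall',
    'A recall comprehension question will ask the learner to remember '
    'specific details or information from the passage.',
)

# The result can be filled in like any format string.
print(recall_template_sketch.format(source='<passage text>'))
```

The returned string can be passed as `template=` to `PromptTemplate` exactly like the hand-written versions.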

In [41]:
inference_chain = LLMChain(llm=llm, prompt=inference_prompt)
recall_chain = LLMChain(llm=llm, prompt=recall_prompt)
summary_chain = LLMChain(llm=llm, prompt=summary_prompt)

In [42]:
recall_chain.run(subsections.sample().item())

'What is the aggregate supply (AS) curve?\n\nAnswer:\n\nThe aggregate supply (AS) curve is an upward sloping curve that shows the total quantity of output (real GDP) that firms will produce and sell at each price level.'
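In the sample output above, the model keeps going after the question and appends an `Answer:` section. A minimal sketch of one way to keep only the question, assuming the question is always the first non-empty line of the output (the helper name `extract_question` is hypothetical):

```python
def extract_question(raw_output: str) -> str:
    """Return the first non-empty line of the model output, stripped.

    Assumes the question comes first and fits on one line; returns an
    empty string if the output contains no text at all.
    """
    for line in raw_output.splitlines():
        line = line.strip()
        if line:
            return line
    return ''


raw = (
    'What is the aggregate supply (AS) curve?\n\nAnswer:\n\n'
    'The aggregate supply (AS) curve is an upward sloping curve that shows '
    'the total quantity of output (real GDP) that firms will produce and '
    'sell at each price level.'
)
print(extract_question(raw))
```

This would not handle a question that spans several lines, so it is a starting point rather than a robust parser.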

In [43]:
correct_response_template = (
    'The following is a passage from a macroeconomics textbook. Use the passage to generate a correct response to the question. '
    'The response should fully and directly address the question with no conceptual or factual errors. '
    'The response should be written in the voice of a student who has carefully read and understood the passage. '
    'The response should be written in 1-2 complete sentences.\n\n'
    'Passage:\n{source}\n\n'
    'Question:\n{question}\n\n'
    'Response:\n'
)

correct_response_prompt = PromptTemplate(
    input_variables=['source', 'question'],
    template=correct_response_template,
)

correct_response_chain = LLMChain(llm=llm, prompt=correct_response_prompt)

In [44]:
incorrect_response_template = (
    #'The following is a passage from a macroeconomics textbook. 
    'Use the passage to generate an incorrect response to the question. '
    'The response will contain errors. '  # dropped: 'or it may fail to directly address the question.'
    'The response should be written in the voice of a student who has not fully understood the passage. It will be obviously incorrect. '
    'The response should be written in 1 complete sentence.\n\n'
    'Passage:\n{source}\n\n'
    'Question:\n{question}\n\n'
    'Response:\n'
)

incorrect_response_prompt = PromptTemplate(
    input_variables=['source', 'question'],
    template=incorrect_response_template,
)

incorrect_response_chain = LLMChain(llm=llm, prompt=incorrect_response_prompt)

# incorrect_response_template = (
#     #'The following is a passage from a macroeconomics textbook. 
#     'Use the passage to generate an incorrect response to the question. '
#     'The response will contain conceptual misunderstandings or factual errors'#, or it may fail to directly address the question. '
#     'The response should be written in the voice of a student who has not fully understood the passage. It will be obviously incorrect. '
#     'The reponse should be written in 1-2 complete sentences.\n\n'
#     'Passage:\n{source}\n\n'
#     'Question:\n{question}\n\n'
#     'Response:\n'
# )

# incorrect_response_prompt = PromptTemplate(
#     input_variables=['source', 'question'],
#     template=incorrect_response_template,
# )

# incorrect_response_chain = LLMChain(llm=llm, prompt=incorrect_response_prompt)

In [45]:
response_template = (
    #'The following is a passage from a macroeconomics textbook. 
    'Use the passage to generate one correct and one incorrect response to the question. '
    'The correct response will fully and directly address the question based on information from the passage. It will be free of conceptual and factual errors. '
    # 'The incorrect response will contain conceptual misunderstandings or factual errors, or it will provide extraneous information that does not address the question. '
    'The incorrect response will be obviously incorrect, as if written by a student who did not read the passage. '
    'Both the correct and the incorrect responses should be written in 1-2 complete sentences.\n\n'
    'Passage:\n{source}\n\n'
    'Question:\n{question}\n\n'
    'Correct Response:\n'
)

response_prompt = PromptTemplate(
    input_variables=['source', 'question'],
    template=response_template,
)

response_chain = LLMChain(llm=llm, prompt=response_prompt)

In [46]:
source = subsections.sample().item()
question = summary_chain.run(source)
correct_response = correct_response_chain.run(source=source, question=question)
incorrect_response = incorrect_response_chain.run(source=source, question=question)
#response = response_chain.run(source=source, question=question)

In [47]:
print(source)
print('-'*80)
print(f'Question: {question}')
print('-'*80)
print(f'Correct Response: {correct_response}')
print('-'*80)
print(f'Incorrect Response: {incorrect_response}')
# print('-'*80)
#print(f'Correct Response: {response}')

Unemployment is not distributed evenly across the U.S. population. Figure 7.3 shows unemployment rates broken down in various ways: by gender, age, and race/ethnicity.
figure7.3a. Graph a shows the trends in unemployment rates by gender for the year 1972 to 2014. In 1972 the graph starts out at 6.6% for females. It jumps to 9.3% in 1975 for females, gradually goes back down until 2009, when it rises to 8.1%. It gradually lowers to 6.1% in 2014 for females. For males, it starts out at  around 5% in 1972, goes up and down periodically, and ends at 6.3% in 2014.  
Figure 7.3 (a) Unemployment Rates by Gender (Source: www.bls.gov)
--------------------------------------------------------------------------------
Question: What are the trends in unemployment rates by gender from 1972 to 2014?
--------------------------------------------------------------------------------
Correct Response: The trends in unemployment rates by gender from 1972 to 2014 as shown in figure 7.3a are that for females

In [48]:
class ComprehensionQuestionChain(Chain):
    summary_chain: LLMChain
    recall_chain: LLMChain
    inference_chain: LLMChain
    correct_response_chain: LLMChain
    incorrect_response_chain: LLMChain    

    @property
    def input_keys(self):
        return ['source']

    @property
    def output_keys(self):
        return [
            'summary_question', 'summary_correct_response', 'summary_incorrect_response',
            'recall_question', 'recall_correct_response', 'recall_incorrect_response',
            'inference_question', 'inference_correct_response', 'inference_incorrect_response',
        ]

    def _call(self, inputs):
        # Use the passage passed to this chain, not the notebook-level `source` variable.
        source = inputs['source']
        summary_question = self.summary_chain.run(inputs)
        recall_question = self.recall_chain.run(inputs)
        inference_question = self.inference_chain.run(inputs)
        ### I may need to parse the outputs of these chains...
        return {
            'summary_question': summary_question,
            'summary_correct_response': self.correct_response_chain.run(source=source, question=summary_question),
            'summary_incorrect_response': self.incorrect_response_chain.run(source=source, question=summary_question),
            'recall_question': recall_question,
            'recall_correct_response': self.correct_response_chain.run(source=source, question=recall_question),
            'recall_incorrect_response': self.incorrect_response_chain.run(source=source, question=recall_question),
            'inference_question': inference_question,
            'inference_correct_response': self.correct_response_chain.run(source=source, question=inference_question),
            'inference_incorrect_response': self.incorrect_response_chain.run(source=source, question=inference_question),
        }

comprehension_question_chain = ComprehensionQuestionChain(
    summary_chain=summary_chain,
    recall_chain=recall_chain,
    inference_chain=inference_chain,
    correct_response_chain=correct_response_chain,
    incorrect_response_chain=incorrect_response_chain,
)

In [None]:
import datasets

def get_output(batch):
    return comprehension_question_chain(batch['raw_text'])
    

ds = datasets.Dataset.from_pandas(df)
ds1 = ds.map(get_output, batched=False)

Map:   0%|          | 0/523 [00:00<?, ? examples/s]

Input length of input_ids is 2259, but `max_length` is set to 2048. This can lead to unexpected behavior. You should consider increasing `max_new_tokens`.
Input length of input_ids is 2254, but `max_length` is set to 2048. This can lead to unexpected behavior. You should consider increasing `max_new_tokens`.
Input length of input_ids is 2264, but `max_length` is set to 2048. This can lead to unexpected behavior. You should consider increasing `max_new_tokens`.
Input length of input_ids is 2277, but `max_length` is set to 2048. This can lead to unexpected behavior. You should consider increasing `max_new_tokens`.
Input length of input_ids is 2272, but `max_length` is set to 2048. This can lead to unexpected behavior. You should consider increasing `max_new_tokens`.
Input length of input_ids is 2282, but `max_length` is set to 2048. This can lead to unexpected behavior. You should consider increasing `max_new_tokens`.
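The warnings above show prompts of roughly 2,260 tokens against a `max_length` of 2048. In Hugging Face generation, `max_length` counts the prompt plus the generated tokens, so these prompts leave no room for output. The warning itself suggests `max_new_tokens`. Another option, sketched here under assumptions, is to truncate each passage to a token budget before it reaches the prompt. The helper `truncate_passage` is hypothetical, and the whitespace tokenizer below is only a stand-in so the example runs without a model.

```python
from typing import Callable, List, Sequence


def truncate_passage(
    passage: str,
    encode: Callable[[str], List],
    decode: Callable[[Sequence], str],
    max_tokens: int,
) -> str:
    """Cut `passage` down to at most `max_tokens` tokens.

    `encode` and `decode` are supplied by the caller. With a Hugging Face
    tokenizer one could pass, for example,
    `lambda t: tokenizer.encode(t, add_special_tokens=False)` and
    `tokenizer.decode` (an assumption; not tested against this notebook).
    """
    tokens = encode(passage)
    if len(tokens) <= max_tokens:
        return passage
    return decode(tokens[:max_tokens])


# Stand-in tokenizer: split on whitespace, join with single spaces.
short = truncate_passage('one two three four five', str.split, ' '.join, 3)
print(short)
```

The budget would need to leave room for the instruction text and the generated answer, not just the passage.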


In [None]:
ds1.to_pandas().drop(columns = ['__index_level_0__', 'source']).to_csv('../results/vicuna_aqg.csv')

### Output Parser

LangChain really wants us to use a JSON or Pydantic parser. I highly doubt LLaMA-7B can reliably output structured responses. Let's try to build something with regex that fails gracefully.

In [198]:
class RegexParser(RegexDictParser):
    '''Overriding the parse method so that it does not escape regex patterns.
    I need to match the first question at the beginning of the string with the regex '^' special character'''
    def parse(self, text):
        result = {}
        for output_key, expected_format in self.output_key_to_format.items():
            specific_regex = self.regex_pattern.format(expected_format)
            matches = re.findall(specific_regex, text)
            if not matches:
                print(
                    f"No match found for output key: {output_key} with expected format ```{expected_format}``` on text ```{text.strip()}```"
                )
                result[output_key] = '' # we can add in a retry function to try again if the model fails. for now, we will just return an empty string.
            elif len(matches) > 1:
                raise ValueError(
                    f"Multiple matches found for output key: {output_key} with expected format ```{expected_format}``` on text ```{text.strip()}```"
                )
            elif (
                self.no_update_value is not None and matches[0] == self.no_update_value
            ):
                continue
            else:
                result[output_key] = matches[0]
        return result

In [199]:
output_key_to_format = {'Question 1': '^', # for the first question, we need to match the beginning of the string.
                        'Question 2': 'Question 2:'}

re_parser = RegexParser(
    regex_pattern=r'{}\s*(.*?)(?=\n|$)', # matches the key (the colon is part of the key itself), any whitespace, and then captures all the characters that follow until a linebreak or the end of string.
    output_key_to_format=output_key_to_format,
    no_update_value='N/A'
)
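To see how this pattern behaves outside of LangChain, here is a small standalone sketch using only the `re` module. The sample text is made up for illustration. It also shows why the key must not be escaped, as the docstring above notes: an escaped `'^'` would look for a literal caret instead of the start of the string.

```python
import re

regex_pattern = r'{}\s*(.*?)(?=\n|$)'

# Made-up model output: the first question has no label, the second does.
sample_output = 'What is demand?\nQuestion 2: Why does price matter?'

# '^' anchors at the start of the string, so it captures the first line.
first = re.findall(regex_pattern.format('^'), sample_output)

# A labelled key captures the text that follows the label on that line.
second = re.findall(regex_pattern.format('Question 2:'), sample_output)

# Escaping '^' turns it into a literal caret, which the text does not contain.
escaped = re.findall(regex_pattern.format(re.escape('^')), sample_output)

print(first, second, escaped)
```

Because `re.MULTILINE` is not set, `'^'` matches only once, so the parser's "multiple matches" error is not triggered for the first question.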

In [200]:
for sample in subsections.sample(15):
    output = chain.run(sample)
    try:
        questions = re_parser.parse(output)
        print('Parsed Output:', questions)
    except ValueError as e:
        print('Failed Parse:', e)

Parsed Output: {'Question 1': 'What does the author mean by "demand" in the context of economics?', 'Question 2': 'According to the passage, what are the two key components that determine the shape of a demand curve?'}
Parsed Output: {'Question 1': 'Why do economists consider the ability to pay when measuring demand?', 'Question 2': 'How does the law of demand relate to the price and quantity demanded of a good or service?'}
Parsed Output: {'Question 1': 'What is the main argument presented in this passage?', 'Question 2': 'How might high-income countries influence low-income countries to adopt stronger environmental standards without resorting to protectionism?'}
Parsed Output: {'Question 1': 'What is the main idea of the passage?', 'Question 2': 'Why should sunk costs not affect the current decision according to the budget constraint framework?'}
Parsed Output: {'Question 1': 'What is the difference between the aggregate supply and aggregate demand model and the microeconomic analysi