<a href="https://colab.research.google.com/github/eduseiti/ia368v_dd_class_04/blob/main/CoQa_via_prompt_engineering.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Apply zero-shot and few-shot learning with pretrained Language Models on the [Conversational Question Answering Challenge (CoQA) dataset](https://stanfordnlp.github.io/coqa/)

In [None]:
import requests
import os
import numpy as np
import pandas as pd

import pickle
from google.colab import drive

import json

import time

import re

from datetime import datetime

In [None]:
# Relative folder the notebook changes into before reading/writing any file.
WORKING_FOLDER = "drive/MyDrive/unicamp/ia368v_dd/aula_04"
# URL of the CoQA development set (JSON).
COQA_DEV_SET = "https://nlp.stanford.edu/data/coqa/coqa-dev-v1.0.json"
# JSON file holding the API access information.
API_ACCESS = "API_access_info.json"

# URL of the CoQA evaluation script.
COQA_EVALUATION_SCRIPT = "https://nlp.stanford.edu/data/coqa/evaluate-v1.0.py"

Connect to Google Drive, as usual

In [None]:
# Mount Google Drive at /content/drive; force_remount=True is passed so the
# mount is redone even when the cell is executed again.
drive.mount('/content/drive', force_remount=True)

In [None]:
# WORKING_FOLDER is relative ("drive/MyDrive/..."), and Drive is mounted at
# /content/drive. Anchoring the path at /content makes this cell idempotent:
# with the bare relative path, running the cell a second time fails because
# the working directory has already moved.
os.chdir(os.path.join("/content", WORKING_FOLDER))

Download the CoQA development set

In [None]:
# Download the dev set only when a file with the same basename is not already
# present in the current working directory.
if not os.path.exists(os.path.basename(COQA_DEV_SET)):
    !wget {COQA_DEV_SET}
else:
    print("CoQa development dataset already downloaded...")

Read and explore the development set

In [None]:
# Parse the downloaded dev set. The encoding is given explicitly because JSON
# is UTF-8 and open() would otherwise fall back to the locale default.
with open(os.path.basename(COQA_DEV_SET), 'r', encoding='utf-8') as inputFile:
    dev_set = json.load(inputFile)

In [None]:
# Show the top-level keys of the parsed JSON document.
dev_set.keys()

In [None]:
# Number of entries under the 'data' key.
len(dev_set['data'])

Download the evaluation script

In [None]:
# Download the evaluation script only when a file with the same basename is
# not already present in the current working directory.
if not os.path.exists(os.path.basename(COQA_EVALUATION_SCRIPT)):
    !wget {COQA_EVALUATION_SCRIPT}
else:
    print("Evaluation script already downloaded...")

Now, create templates for zero-shot and few-shot learning

In [None]:
# Task instructions optionally prepended to every prompt.
TASK_PROMPT = (
    "Read the text, answer the questions and "
    "transcribe the text portion supporting your answer:\n\n"
)
TASK_PROMPT_NO_TRANSCRIPTION = "Read the text and answer the questions:\n\n"


# Zero-shot: the first question carries the story; later questions are
# appended after the previous answer (and transcription, when requested).
ZERO_SHOT_FIRST_QUESTION_TEMPLATE = (
    "Text: {}\n\n"
    "Question: {} Answer the question and transcribe the sentence where you found it."
)
ZERO_SHOT_NEXT_QUESTIONS_TEMPLATE = (
    "\nAnswer: {}\n"
    "Transcription: {}\n\n"
    "Question: {}"
)

ZERO_SHOT_FIRST_QUESTION_TEMPLATE_NO_TRANSCRIPTION = (
    "Text: {}\n\n"
    "Question: {}"
)
ZERO_SHOT_NEXT_QUESTIONS_TEMPLATE_NO_TRANSCRIPTION = (
    "\nAnswer: {}\n\n"
    "Question: {}"
)


# Few-shot: the first question is preceded by one worked example; the sequence
# templates are used for the follow-up questions.
FEW_SHOT_TEMPLATE = (
    "Example text: {}\n\n"
    "Example question: {}\n"
    "Example answer: {}\n"
    "Example transcription: {}\n\n\n\n"
    "Text: {}\n\n"
    "Question: {}"
)
FEW_SHOT_SEQUENCE_TEMPLATE = (
    "Text: {}\n\n"
    "Question: {}"
)
FEW_SHOT_SEQUENCE_ADDITIONAL_QUESTION_TEMPLATE = (
    "\nAnswer: {}\n"
    "Transcription: {}\n\n"
    "Question: {}"
)

FEW_SHOT_TEMPLATE_NO_TRANSCRIPTION = (
    "Example text: {}\n\n"
    "Example question: {}\n"
    "Example answer: {}\n\n\n\n"
    "Text: {}\n\n"
    "Question: {}"
)
FEW_SHOT_SEQUENCE_ADDITIONAL_QUESTION_TEMPLATE_NO_TRANSCRIPTION = (
    "\nAnswer: {}\n\n"
    "Question: {}"
)

In [None]:
# Identifiers of the two prompting strategies; they are also embedded in the
# results filename.
FEW_SHOT_QUERY_TYPE = "few_shot"
ZERO_SHOT_QUERY_TYPE = "zero_shot"

Results filename format:   

```
    test_<llama|text-davinci-003|code-davinci-002>_<few_shot|zero_shot>_<YYYYMMDD_HHMMSS>.json
```



In [None]:
# Placeholders, in order: model name, query type, timestamp.
TEST_RESULTS_FILENAME_FORMAT = "test_{}_{}_{}.json"

In [None]:
# Default JSON payload for the LLaMA completion endpoint; "prompt" is filled
# in for each request.
LLAMA_API_DATA_PACKAGE = {
    "prompt": None,
    "temperature": 0.0,
    "top_p": 1,
    "max_length": 100,
}

# Default query parameters for the OpenAI completion API; "prompt" is a
# placeholder here as well.
OPENAI_API_QUERY_PARAMS = {
    "model": "code-davinci-002",
    "prompt": None,
    "temperature": 0,
    "max_tokens": 100,
    "top_p": 1,
    "frequency_penalty": 0,
    "presence_penalty": 0,
}

In [None]:
# Patterns used with re.match() to pull the answer (group 1) and, when present,
# the transcription (group 2) out of the generated text.
#
# Raw strings are used so that "\." is not an invalid string escape, and the
# character classes are written [aA] / [tT]: inside a class "|" is a literal
# character, so the former "[a|A]" also accepted "|nswer:".

# "Answer: ..." on one line followed by "Transcription(s): ..." on a later one.
LLAMA_RESPONSE_REGEX = r".*[\n\r]*[aA]nswer:(.+)[\n\r].*[tT]ranscription[s]?:(.+)[\n\r]?"
# Fallback: the text after the last period of the answer line becomes group 2.
LLAMA_RESPONSE_EMBEDDED_TRANSCRIPTION = r".*[\n\r]*[aA]nswer:(.+)\.(.+)[\n\r]?"
# Only an "Answer: ..." line is expected.
LLAMA_RESPONSE_NO_TRANSCRIPTION_REGEX = r".*[\n\r]*[aA]nswer:(.+)[\n\r]?"

## Define functions to access the Language Models APIs

In [None]:
def build_request_prompt(query_type, i, prompt_text, request_prompt, example_entry, test_entry, current_responses, ask_transcription):
    """Return the prompt to send for question ``i`` of ``test_entry``.

    The questions are conversational (each one builds on the previous ones),
    so the prompt grows turn by turn: the previous answer is appended before
    the next question.

    Args:
        query_type: FEW_SHOT_QUERY_TYPE or ZERO_SHOT_QUERY_TYPE. Any other
            value leaves ``request_prompt`` untouched.
        i: Index of the question being asked.
        prompt_text: Task instruction placed at the start of a fresh prompt
            (may be empty).
        request_prompt: Prompt accumulated so far for this entry.
        example_entry: Entry whose story and first question/answer are used as
            the worked example; only read for the first few-shot question.
        test_entry: Entry under test; its 'story' and 'questions' are read.
        current_responses: Responses obtained so far; element ``i - 1``
            supplies 'answer' (and 'transcription' when requested).
        ask_transcription: Selects the template variants that include the
            transcription.

    Returns:
        The prompt string for question ``i``.
    """
    few_shot = query_type == FEW_SHOT_QUERY_TYPE

    if not few_shot and query_type != ZERO_SHOT_QUERY_TYPE:
        # Unknown query type: nothing to build.
        return request_prompt

    if i == 0:
        # Opening question: start a fresh prompt. In few-shot mode this is the
        # only prompt that carries the worked example.
        if few_shot:
            example_fields = [example_entry['story'],
                              example_entry['questions'][0]['input_text'],
                              example_entry['answers'][0]['input_text']]

            if ask_transcription:
                example_fields.append(example_entry['answers'][0]['span_text'])
                opening_template = FEW_SHOT_TEMPLATE
            else:
                opening_template = FEW_SHOT_TEMPLATE_NO_TRANSCRIPTION

            opening = opening_template.format(*example_fields,
                                              test_entry['story'],
                                              test_entry['questions'][i]['input_text'])
        else:
            if ask_transcription:
                opening_template = ZERO_SHOT_FIRST_QUESTION_TEMPLATE
            else:
                opening_template = ZERO_SHOT_FIRST_QUESTION_TEMPLATE_NO_TRANSCRIPTION

            opening = opening_template.format(test_entry['story'],
                                              test_entry['questions'][i]['input_text'])

        return prompt_text + opening

    # Follow-up question.
    if few_shot and i == 1:
        # The second few-shot prompt is rebuilt without the example: only the
        # story and the first question are kept as the base.
        request_prompt = prompt_text + FEW_SHOT_SEQUENCE_TEMPLATE.format(test_entry['story'],
                                                                         test_entry['questions'][i - 1]['input_text'])

    previous_response = current_responses[i - 1]

    if ask_transcription:
        if few_shot:
            follow_up_template = FEW_SHOT_SEQUENCE_ADDITIONAL_QUESTION_TEMPLATE
        else:
            follow_up_template = ZERO_SHOT_NEXT_QUESTIONS_TEMPLATE

        follow_up = follow_up_template.format(previous_response['answer'],
                                              previous_response['transcription'],
                                              test_entry['questions'][i]['input_text'])
    else:
        if few_shot:
            follow_up_template = FEW_SHOT_SEQUENCE_ADDITIONAL_QUESTION_TEMPLATE_NO_TRANSCRIPTION
        else:
            follow_up_template = ZERO_SHOT_NEXT_QUESTIONS_TEMPLATE_NO_TRANSCRIPTION

        follow_up = follow_up_template.format(previous_response['answer'],
                                              test_entry['questions'][i]['input_text'])

    return request_prompt + follow_up

In [None]:
def _parse_llama_response(generated_text, ask_transcription):
    """Return ``(answer, transcription)`` parsed from the generated text, or ``None`` if no pattern matches."""
    if ask_transcription:
        m = re.match(LLAMA_RESPONSE_REGEX, generated_text)
    else:
        m = re.match(LLAMA_RESPONSE_NO_TRANSCRIPTION_REGEX, generated_text)

    if m is None:
        print("Try another match...")

        m = re.match(LLAMA_RESPONSE_EMBEDDED_TRANSCRIPTION, generated_text)

    if m is None:
        return None

    answer_text = m.group(1).strip()
    transcription_text = ""

    # The second group is only kept when a transcription was requested.
    if ask_transcription and (len(m.groups()) > 1):
        transcription_text = m.group(2).strip()

    return answer_text, transcription_text


def query_llama(test_entry, add_prompt=True, query_type=FEW_SHOT_QUERY_TYPE, example_entry=None, ask_transcription=False):
    """Ask every question of ``test_entry`` to the LLaMA API, one request per question.

    Each request is posted to ``<LLAMA_API_ENDPOINT>/complete`` and the result
    is polled from ``<LLAMA_API_ENDPOINT>/get_result/<uuid>`` every 10 seconds
    until it reports ready. The endpoint is read from the module-level
    ``access_info`` dict.

    Args:
        test_entry: Entry whose 'id', 'story' and 'questions' are used.
        add_prompt: When true, the task instruction is placed at the start of
            the prompt.
        query_type: FEW_SHOT_QUERY_TYPE or ZERO_SHOT_QUERY_TYPE.
        example_entry: Entry used as the worked example in few-shot mode.
        ask_transcription: When true, the prompt asks for a transcription and
            the response is parsed for it.

    Returns:
        A list with exactly one dict per question, holding 'id', 'turn_id',
        'answer' and 'transcription'. The answer and transcription are empty
        strings when the request failed or the response could not be parsed.
    """
    test_entry_start_time = time.time()

    llama_responses = []

    if not add_prompt:
        prompt_text = ""
    elif ask_transcription:
        prompt_text = TASK_PROMPT
    else:
        prompt_text = TASK_PROMPT_NO_TRANSCRIPTION

    request_prompt = ""

    for i, question in enumerate(test_entry['questions']):

        request_prompt = build_request_prompt(query_type, i, prompt_text, request_prompt, example_entry, test_entry, llama_responses, ask_transcription)

        print("--------------------------------------------")
        print("QUESTION #{}".format(i))
        print("--------------------------------------------\n")
        print(request_prompt)

        # Work on a copy so the shared default payload is never modified.
        request_data = dict(LLAMA_API_DATA_PACKAGE, prompt=request_prompt)

        # Start from an empty response: whatever happens below, one entry is
        # recorded for this question. The next prompt reads
        # llama_responses[i - 1], so the list must stay aligned with the
        # question index even when a request fails.
        answer_text = ""
        transcription_text = ""

        request_start_time = time.time()

        r = requests.post(f"{access_info['LLAMA_API_ENDPOINT']}/complete", json=request_data)

        if r.ok:
            request_uuid = r.json()["request_uuid"]

            # Poll until the result is ready, waiting 10 seconds between checks.
            while True:
                r = requests.get(f"{access_info['LLAMA_API_ENDPOINT']}/get_result/{request_uuid}")
                response = r.json()

                if response['ready']:
                    break

                time.sleep(10)

            generated_text = response['generated_text']

            print(generated_text)

            elapsed_time = time.time() - request_start_time

            print("\n>> Request elapsed time: {:.3f}".format(elapsed_time))

            parsed = _parse_llama_response(generated_text, ask_transcription)

            if parsed is not None:
                answer_text, transcription_text = parsed
            else:
                print("No match!!!")

                # Dump the raw bytes to help diagnose why nothing matched.
                for byte in bytes(generated_text, 'utf-8'):
                    print(byte, end=" ")

            if elapsed_time < 20:
                print("Wait 10 seconds to avoid getting a 429 error...")

                time.sleep(10)

            print("\n\n")
        else:
            print("\n\nREQUEST FAILED!!!\n\n")

        llama_responses.append({'id': test_entry['id'],
                                'turn_id': question['turn_id'],
                                'answer': answer_text,
                                'transcription': transcription_text})

    print("Elapse total of {:.3f} s to execute all the {} queries".format(time.time() - test_entry_start_time, len(test_entry['questions'])))

    return llama_responses

In [None]:
def execute_test(test_set_filename, test_set_data, selected_entries, llm="llama", test_parameters=None):
    """Run the selected entries through the language model and save the answers.

    The results file is rewritten after every entry so that partial results
    survive an interruption.

    Args:
        test_set_filename: Name of the test set, recorded in the results file.
        test_set_data: Parsed dataset; entries are read from its 'data' list.
        selected_entries: Indices into ``test_set_data['data']`` to run.
        llm: Model to query. Only "llama" is implemented.
        test_parameters: Keyword arguments forwarded to ``query_llama``
            ('example_entry', 'query_type', 'add_prompt',
            'ask_transcription'). Defaults to a few-shot setup with the task
            prompt, no example entry and no transcription.

    Returns:
        The name of the JSON results file.

    Raises:
        ValueError: If ``llm`` is not a supported model. Without this check
            the function would return the name of a file it never wrote.
    """
    if llm != "llama":
        raise ValueError("Unsupported llm '{}': only 'llama' is implemented".format(llm))

    # Built per call instead of using a mutable dict as the default argument.
    if test_parameters is None:
        test_parameters = {'example_entry': None,
                           'query_type': FEW_SHOT_QUERY_TYPE,
                           'add_prompt': True,
                           'ask_transcription': False}

    test_start_time = time.time()

    test_responses = []

    test_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_filename = TEST_RESULTS_FILENAME_FORMAT.format(llm, test_parameters['query_type'], test_timestamp)

    executed_test = {'timestamp': test_timestamp,
                     'set': test_set_filename,
                     # Converted to plain ints so the indices are JSON serializable.
                     'set_entries': [int(a) for a in selected_entries],
                     'configuration': test_parameters,
                     'answers': None}

    for test_entry in [test_set_data['data'][i] for i in selected_entries]:
        test_responses += query_llama(test_entry, **test_parameters)

        # Save the results so far just to make sure they are not lost...

        executed_test['answers'] = test_responses

        print(executed_test)

        with open(results_filename, "w") as outputFile:
            json.dump(executed_test, outputFile, indent=4)

    print("Total elapsed time: {}".format(time.time() - test_start_time))

    return results_filename

Load the LLaMA API access information (including the test API endpoint) from the access-info JSON file

In [None]:
# Parse the access-info JSON; query_llama() reads 'LLAMA_API_ENDPOINT' from it.
with open(API_ACCESS) as inputFile:
    access_info = json.load(inputFile)

### Select 5 entries to test

Leave the first story as the few-shot example.

In [None]:
# Fixed seed so the same stories are sampled on every Restart & Run All.
ENTRIES_SELECTION_SEED = 42

# Sample 5 distinct story indices. Index 0 is excluded because that story is
# reserved as the few-shot example.
entries_to_test = np.random.default_rng(ENTRIES_SELECTION_SEED).choice(
    np.arange(1, len(dev_set['data'])), size=5, replace=False)

First, create a reference dataset containing only the tested queries

In [None]:
# Timestamp in YYYYMMDD_HHMMSS form, later embedded in the reference dataset filename.
reference_dataset_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

In [None]:
# dev_set['data'] is a plain Python list, so it cannot be indexed with the
# NumPy array of selected indices (that raises TypeError); pick the entries
# one by one instead.
reference_dataset = {"version": 1.0,
                     "data": [dev_set['data'][i] for i in entries_to_test]}

In [None]:
# Reference dataset filename, tagged with reference_dataset_timestamp.
REFERENCE_DATASET="reference_dataset_{}.json".format(reference_dataset_timestamp)

In [None]:
# Write the reference dataset to disk as indented JSON.
with open(REFERENCE_DATASET, "w") as outputFile:
    json.dump(reference_dataset, outputFile, indent=4)

## Now execute the test sequence for LLaMA

In [None]:
# Collects the results filename returned by each execute_test() call; the
# evaluation loop at the end of the notebook iterates over it.
test_results_files = []

### Execute the tests using few-shot setup with prompt

In [None]:
# Few-shot run with the task instruction prepended; the first dev-set story
# is the worked example.
test_parameters = dict(example_entry=dev_set['data'][0],
                       query_type=FEW_SHOT_QUERY_TYPE,
                       add_prompt=True,
                       ask_transcription=False)

test_results_files.append(
    execute_test(os.path.basename(COQA_DEV_SET),
                 dev_set,
                 entries_to_test,
                 test_parameters=test_parameters)
)

### Now execute using zero-shot setup, no prompt

In [None]:
# Zero-shot run without the task instruction. The example entry is still
# passed along and recorded in the saved configuration.
test_parameters = dict(example_entry=dev_set['data'][0],
                       query_type=ZERO_SHOT_QUERY_TYPE,
                       add_prompt=False,
                       ask_transcription=False)

test_results_files.append(
    execute_test(os.path.basename(COQA_DEV_SET),
                 dev_set,
                 entries_to_test,
                 test_parameters=test_parameters)
)

### Now, execute the tests using few-shot setup without prompt

In [None]:
# Few-shot run without the task instruction; the first dev-set story is the
# worked example.
test_parameters = dict(example_entry=dev_set['data'][0],
                       query_type=FEW_SHOT_QUERY_TYPE,
                       add_prompt=False,
                       ask_transcription=False)

test_results_files.append(
    execute_test(os.path.basename(COQA_DEV_SET),
                 dev_set,
                 entries_to_test,
                 test_parameters=test_parameters)
)

### And finally execute using zero-shot setup, with prompt

In [None]:
# Zero-shot run with the task instruction prepended. The example entry is
# still passed along and recorded in the saved configuration.
test_parameters = dict(example_entry=dev_set['data'][0],
                       query_type=ZERO_SHOT_QUERY_TYPE,
                       add_prompt=True,
                       ask_transcription=False)

test_results_files.append(
    execute_test(os.path.basename(COQA_DEV_SET),
                 dev_set,
                 entries_to_test,
                 test_parameters=test_parameters)
)

### Now execute the evaluation script for the executed tests

First, create a reference dataset containing only the tested queries

In [None]:
# dev_set['data'] is a plain Python list, so it cannot be indexed with the
# NumPy array of selected indices (that raises TypeError); pick the entries
# one by one instead.
reference_dataset = {"version": 1.0,
                     "data": [dev_set['data'][i] for i in entries_to_test]}

In [None]:
# Fixed (non-timestamped) filename; this rebinds REFERENCE_DATASET, which the
# evaluation command passes as --data-file.
REFERENCE_DATASET="reference_dataset.json"

In [None]:
# Write the reference dataset to disk as indented JSON.
with open(REFERENCE_DATASET, "w") as outputFile:
    json.dump(reference_dataset, outputFile, indent=4)

### Comments on the evaluation script

The evaluation script computes Exact Match and F1 between the predicted answer and the gold standard.

One caveat is that the model will be penalized if it produces verbose answers, even if they contain the correct answer.

In [None]:
# Run the downloaded evaluation script once per results file, using the
# reference dataset as --data-file and the results file as --pred-file.
for test_result in test_results_files:

    print("\n\n\n---------------------------------------------------")
    print("Evaluation results for {}...".format(test_result))
    print("---------------------------------------------------\n")

    !python evaluate-v1.0.py --data-file {REFERENCE_DATASET} --pred-file {test_result} --human