# TP.08 - TREC Question classification

## Install libraries

In [None]:
!pip install openai
!pip install python-dotenv
!pip install datasets



## Import libraries

In [None]:
import os
import openai
from dotenv import load_dotenv
from datasets import load_dataset
from sklearn.metrics import classification_report

## Load API KEY from ENV variable

In [None]:
# Read the variables declared in a local .env file into the process environment,
# then hand the OpenAI credentials to the client (either value is None when unset).
load_dotenv()
openai.api_key = os.environ.get('OPENAI_API_KEY')
openai.organization = os.environ.get('OPENAI_ORG')

## Download the dataset

In [None]:
# TREC dataset home: https://cogcomp.seas.upenn.edu/Data/QA/QC/
dataset = load_dataset('trec')
# Keep both splits under short names; later cells read `train` and `test`.
train, test = dataset['train'], dataset['test']


In [None]:
# Granularity switch: True selects the 'fine_label' column, False 'coarse_label'.
is_fine_label = True
if is_fine_label:
    label_grain = 'fine_label'
else:
    label_grain = 'coarse_label'

# Label names of the selected column, indexed by the integer label id.
label_list = train.features[label_grain].names

# Definition of labels: https://cogcomp.seas.upenn.edu/Data/QA/QC/definition.html
# Two-level mapping: label column name ("coarse_label" / "fine_label") ->
# label name -> short English definition. Only the sub-dict selected by
# `label_grain` is used below, where it is looked up by the label names in `label_list`.
labels_with_defs = {
  "coarse_label": {
      "ABBR": "abbreviation",
      "ENTY": "entities",
      "DESC": "description and abstract concepts",
      "HUM": "human beings",
      "LOC": "locations",
      "NUM": "numeric values",
  },
  "fine_label": {
      "ABBR:abb": "abbreviation",
      "ABBR:exp": "expression abbreviated",
      "ENTY:animal": "entities of animals",
      "ENTY:body": "entities of organs of body",
      "ENTY:color": "entities of colors",
      "ENTY:cremat": "entities of inventions, books and other creative pieces",
      "ENTY:currency": "entities of currency names",
      "ENTY:dismed": "entities of diseases and medicine",
      "ENTY:event": "entities of events",
      "ENTY:food": "entities of food",
      "ENTY:instru": "entities of musical instrument",
      "ENTY:lang": "entities of languages",
      "ENTY:letter": "entities of letters like a-z",
      "ENTY:other": "entities of other entities",
      "ENTY:plant": "entities of plants",
      "ENTY:product": "entities of products",
      "ENTY:religion": "entities of religions",
      "ENTY:sport": "entities of sports",
      "ENTY:substance": "entities of elements and substances",
      "ENTY:symbol": "entities of symbols and signs",
      "ENTY:techmeth": "entities of techniques and methods",
      "ENTY:termeq": "entities of equivalent terms",
      "ENTY:veh": "entities of vehicles",
      "ENTY:word": "entities of words with a special property",
      "DESC:def": "description and abstract concepts of definition of sth.",
      "DESC:desc": "description and abstract concepts of sth.",
      "DESC:manner": "description and abstract concepts of manner of an action",
      "DESC:reason": "description and abstract concepts of reasons",
      "HUM:gr": "a group or organization of persons",
      "HUM:ind": "an individual human being",
      "HUM:title": "title of a person",
      "HUM:desc": "description of a person",
      "LOC:city": "cities",
      "LOC:country": "countries",
      "LOC:mount": "mountains",
      "LOC:other": "other locations",
      "LOC:state": "states",
      "NUM:code": "numeric values of postcodes or other codes",
      "NUM:count": "numeric values of number of sth.",
      "NUM:date": "numeric values of dates",
      "NUM:dist": "numeric values of linear measures",
      "NUM:money": "numeric values of prices",
      "NUM:ord": "numeric values of ranks",
      "NUM:other": "numeric values of other numbers",
      "NUM:period": "numeric values of the lasting time of sth.",
      "NUM:perc": "numeric values of fractions",
      "NUM:speed": "numeric values of speed",
      "NUM:temp": "numeric values of temperature",
      "NUM:volsize": "numeric values of size, area and volume",
      "NUM:weight": "numeric values of weight",
  }
}

# Pair every label of the selected grain with its definition; labels that have
# no entry get an empty definition and are reported so the table can be completed.
definitions_for_grain = labels_with_defs[label_grain]
label_defs_used = {}
for label in label_list:
    definition = definitions_for_grain.get(label, "")
    label_defs_used[label] = definition
    if not definition:
        print(f"Label {label} with no definition assigned")




## Visualize the dataset

In [None]:
def visualize_dataset(train, label_grain, label_list, examples_per_class=1):
    """Collect the first samples of each class, formatted as question/answer text.

    Args:
        train: iterable of samples; each sample is indexed with 'text' and
            with `label_grain`.
        label_grain: key of the label column to read from each sample.
        label_list: label names, indexed by the integer label stored in a sample.
        examples_per_class: maximum number of examples kept per label
            (defaults to 1).

    Returns:
        A list of strings, in dataset order, each made of a "Q: <text>" line
        followed by an "A: <label name>" line.
    """
    # Upper bound reached once every label has its quota; lets the scan stop early.
    max_examples = len(label_list) * examples_per_class
    seen_per_label = {}
    examples = []
    for sample in train:
        label_idx = sample[label_grain]
        seen = seen_per_label.get(label_idx, 0)
        if seen >= examples_per_class:
            continue
        seen_per_label[label_idx] = seen + 1
        examples.append(f"Q: {sample['text']}\nA: {label_list[label_idx]}\n")
        if len(examples) >= max_examples:
            break
    return examples

# One formatted example per class; `examples` is reused when building the few-shot prompt.
examples = visualize_dataset(train, label_grain, label_list)
for formatted_example in examples:
    print(formatted_example)

Q: How did serfdom develop in and then leave Russia ?
A: DESC:manner

Q: What films featured the character Popeye Doyle ?
A: ENTY:cremat

Q: What fowl grabs the spotlight after the Chinese Year of the Monkey ?
A: ENTY:animal

Q: What is the full form of .com ?
A: ABBR:exp

Q: What contemptible scoundrel stole the cork from my lunch ?
A: HUM:ind

Q: What team did baseball 's St. Louis Browns become ?
A: HUM:gr

Q: What is the oldest profession ?
A: HUM:title

Q: What are liver enzymes ?
A: DESC:def

Q: When was Ozzy Osbourne born ?
A: NUM:date

Q: Why do heavier objects travel downhill faster ?
A: DESC:reason

Q: What is considered the costliest disaster the insurance industry has ever faced ?
A: ENTY:event

Q: What sprawling U.S. state boasts the most airports ?
A: LOC:state

Q: What did the only repealed amendment to the U.S. Constitution deal with ?
A: DESC:desc

Q: How many Jews were executed in concentration camps during WWII ?
A: NUM:count

Q: What articles of clothing are tokens 

## Define prompts

In [None]:
# One line per class: the label between quotes followed by its definition in
# parentheses. The parentheses matter: the instructions below tell the model to
# ignore "the description in parenthesis", so the definitions must be parenthesized.
str_defs = ' \n'.join([ f"'{key}' ({value})" for key, value in label_defs_used.items() ])
system_prompt = (
f"""
You must categorize questions in one of the following classes.
You must reply only with one class between quotes, ignore the description in parenthesis.
And you must not invent another class, you can only use the following classes:

{str_defs}
"""
)

# Prefix of every user message; the question text is appended to it.
user_prompt = "Categorize the following question:\n\n"

# Few-shot block: the per-class examples collected from the training split.
str_examples = ' \n'.join(examples)
few_shot_prompt = (
    "\nConsider the following example questions:\n"
    "\n"
    f"{str_examples}\n"
)
print(few_shot_prompt)

# Follow-up user message asking the model to review its own previous answer.
critic_prompt = (
    "\nLet's think step by step and review your previous answer.\n"
    "Provide an explanation of why the answer is right or wrong and then give your final answer.\n"
)


Consider the following example questions:

Q: How did serfdom develop in and then leave Russia ?
A: DESC:manner
 
Q: What films featured the character Popeye Doyle ?
A: ENTY:cremat
 
Q: What fowl grabs the spotlight after the Chinese Year of the Monkey ?
A: ENTY:animal
 
Q: What is the full form of .com ?
A: ABBR:exp
 
Q: What contemptible scoundrel stole the cork from my lunch ?
A: HUM:ind
 
Q: What team did baseball 's St. Louis Browns become ?
A: HUM:gr
 
Q: What is the oldest profession ?
A: HUM:title
 
Q: What are liver enzymes ?
A: DESC:def
 
Q: When was Ozzy Osbourne born ?
A: NUM:date
 
Q: Why do heavier objects travel downhill faster ?
A: DESC:reason
 
Q: What is considered the costliest disaster the insurance industry has ever faced ?
A: ENTY:event
 
Q: What sprawling U.S. state boasts the most airports ?
A: LOC:state
 
Q: What did the only repealed amendment to the U.S. Constitution deal with ?
A: DESC:desc
 
Q: How many Jews were executed in concentration camps during WWII

## Select a sub-sample of test

In [None]:
import math

# Evaluate only the first 20 test questions. Slicing the split yields a dict of
# columns, so the sample count is read from the 'text' column.
test_samples = test[:20]
len_test_samples = len(test_samples['text'])

# The query loop pauses `time_of_sleep` seconds once every 3 questions.
time_of_sleep = 61
estimated_sleeps = math.ceil(len_test_samples / 3)
# Report at least one pause, even when there are no samples.
estimated_duration_sec = max(estimated_sleeps, 1) * time_of_sleep

print(f"Number of sleeps: {estimated_sleeps}")
print(f"Estimated duration: {estimated_duration_sec} seconds / {(estimated_duration_sec / 60):.2f} mins ")

Number of sleeps: 7
Estimated duration: 427 seconds / 7.12 mins 


## Query ChatGPT

In [None]:
import time

model_gpt = "gpt-3.5-turbo"
def call_gpt(messages):
    """Send a chat conversation to the model and return the raw API response.

    Args:
        messages: list of {"role": ..., "content": ...} dicts forming the conversation.

    Returns:
        The object returned by `openai.ChatCompletion.create`, unmodified.
    """
    request_options = {
        "model": model_gpt,
        "messages": messages,
        "temperature": 0,
        "max_tokens": 2048,
    }
    return openai.ChatCompletion.create(**request_options)

def find_label_from_gpt_ans(label_list, answer, debug=False):
    """Map a free-text model answer to the index of the label it mentions.

    When the answer mentions several labels, the one mentioned last in the text
    wins (an explanation followed by a final answer ends with the final label).
    If two labels end up at the same position, the longer one is preferred.

    Args:
        label_list: label names; the returned value indexes into this list.
        answer: text returned by the model; None is treated as empty text.
        debug: when True, print the raw answer before matching.

    Returns:
        Index of the matched label, or 0 when no label is found (a warning is
        printed in that case).
    """
    if debug:
        print(f"DEBUG answer is: {answer}")
    default_label_index = 0
    text = answer or ""
    best_index = None
    best_key = None
    for label_index, label_str in enumerate(label_list):
        if not label_str:
            # An empty label would "match" any text; never select it.
            continue
        position = text.rfind(label_str)
        if position < 0:
            continue
        # Exact (case-sensitive) match on purpose: short labels such as "NUM"
        # or "LOC" would otherwise match ordinary words.
        key = (position, len(label_str))
        if best_key is None or key > best_key:
            best_index, best_key = label_index, key
    if best_index is not None:
        return best_index
    print(f"No label matched for with answer: {answer}")
    print(f"default_label_index {default_label_index}")
    return default_label_index

use_few_shot = True
use_critic = True

# Freemium request limit: 3 requests per minute. A pause of `time_of_sleep`
# seconds is taken before every group of `requests_per_minute` requests.
requests_per_minute = 3
# A progress line is printed once every this many questions.
questions_per_progress_line = 3

requests_sent = 0

def call_gpt_throttled(messages):
    """Call `call_gpt`, pausing first whenever a new group of requests starts.

    Counting requests (rather than questions) keeps the run under the limit when
    a question needs two requests (first answer + critic review).
    """
    global requests_sent
    if requests_sent % requests_per_minute == 0:
        time.sleep(time_of_sleep)
    requests_sent += 1
    return call_gpt(messages)

# The system message is the same for every question, so build it once.
system_content = system_prompt + (("\n\n" + few_shot_prompt) if use_few_shot else "")

answers = []
questions = test_samples['text']
targets = test_samples[label_grain]
for i, question in enumerate(questions):
    messages = [
        {
            "role": "system",
            "content": system_content
        },
        {
            "role": "user",
            "content": user_prompt + question
        }
    ]
    if i % questions_per_progress_line == 0:
        range_end = min(i + questions_per_progress_line, len_test_samples)
        print(f"Iteration: [{i} to {range_end}]/{len_test_samples}")
    response = call_gpt_throttled(messages)
    reply = response.choices[0].message
    answer = reply.content
    if use_critic:
        # Append the model's first answer and the review request, then send the
        # conversation again: the reviewed reply replaces the first answer.
        # NOTE: this doubles the number of requests, so the run takes about
        # twice the duration estimated from one request per question.
        messages.extend([
            {
                "role": reply.role,
                "content": answer
            },
            {
                "role": "user",
                "content": critic_prompt
            }
        ])
        response = call_gpt_throttled(messages)
        # The reviewed reply contains an explanation and may mention more than
        # one class; find_label_from_gpt_ans decides which label is extracted.
        answer = response.choices[0].message.content
    label_index = find_label_from_gpt_ans(label_list, answer, debug=True)

    answers.append(label_index)
    print(f"Q: {question}\nA: {label_list[label_index]}\nT: {label_list[targets[i]]}\n")

Iteration: [0 to 3]/20
DEBUG answer is: 'NUM:dist' numeric values of linear measures
Q: How far is it from Denver to Aspen ?
A: NUM:dist
T: NUM:dist

DEBUG answer is: 'LOC:other'
Q: What county is Modesto , California in ?
A: LOC:other
T: LOC:city

DEBUG answer is: HUM:ind
Q: Who was Galileo ?
A: HUM:ind
T: HUM:desc

Iteration: [3 to 6]/20
DEBUG answer is: DESC:def
Q: What is an atom ?
A: DESC:def
T: DESC:def

DEBUG answer is: NUM:date
Q: When did Hawaii become a state ?
A: NUM:date
T: NUM:date

DEBUG answer is: NUM:dist
Q: How tall is the Sears Building ?
A: NUM:dist
T: NUM:dist

Iteration: [6 to 9]/20
DEBUG answer is: HUM:ind
Q: George Bush purchased a small interest in which baseball team ?
A: HUM:ind
T: HUM:gr

DEBUG answer is: 'ENTY:plant'
Q: What is Australia 's national flower ?
A: ENTY:plant
T: ENTY:plant

DEBUG answer is: DESC:reason
Q: Why does the moon turn orange ?
A: DESC:reason
T: DESC:reason

Iteration: [9 to 12]/20
DEBUG answer is: DESC:def
Q: What is autism ?
A: DESC:d

## Report evaluation metrics

In [None]:
# Per-class precision/recall/F1 over every known label. `zero_division=1` makes
# undefined ratios (a class with no samples or no predictions) show up as 1.
report_options = {
    "labels": range(len(label_list)),
    "target_names": label_list,
    "zero_division": 1,
}
evaluation_report = classification_report(targets, answers, **report_options)

# Print the experiment settings first, then the report itself.
run_summary = [
    f"Label grain: {label_grain}",
    f"Model GPT: {model_gpt}",
    f"Zero Shot: {not use_few_shot and not use_critic}",
    f"Few Shot: {use_few_shot}",
    f"Reflect: {use_critic}",
]
for summary_line in run_summary:
    print(summary_line)
print(evaluation_report)

Label grain: fine_label
Model GPT: gpt-3.5-turbo
Zero Shot: False
Few Shot: True
Reflect: True
                precision    recall  f1-score   support

      ABBR:abb       0.00      1.00      0.00         0
      ABBR:exp       1.00      1.00      1.00         0
   ENTY:animal       1.00      1.00      1.00         0
     ENTY:body       1.00      1.00      1.00         0
    ENTY:color       1.00      1.00      1.00         0
   ENTY:cremat       1.00      1.00      1.00         0
 ENTY:currency       0.00      1.00      0.00         0
   ENTY:dismed       0.00      1.00      0.00         0
    ENTY:event       0.00      1.00      0.00         0
     ENTY:food       1.00      1.00      1.00         0
   ENTY:instru       1.00      1.00      1.00         0
     ENTY:lang       1.00      1.00      1.00         0
   ENTY:letter       1.00      1.00      1.00         0
    ENTY:other       1.00      1.00      1.00         0
    ENTY:plant       1.00      1.00      1.00         1
  ENTY:p

# GPT 3.5-turbo benchmarks

In [None]:
%%html
<!-- Embedded Google Drive file previews: one iframe per prompting setup, named by the heading above it. -->
<div>
  <h2>Zero-Shot</h2>
  <iframe src="https://drive.google.com/file/d/1zL1F48K0ruwEfs494YckKpqDF3_CGC-f/preview" width="480" height="480" allow="autoplay"></iframe>
</div>
<div>
  <h2>Few-Shot</h2>
  <iframe src="https://drive.google.com/file/d/1tT5hXIw1E55XLgs6BAbuM4GBL3SzJSYA/preview" width="480" height="480" allow="autoplay"></iframe>
</div>
<div>
  <h2>Few-Shot + Reflect</h2>
  <iframe src="https://drive.google.com/file/d/1paGBmbj4wwHCtXsde2HJfxUs7Xgg4qRe/preview" width="480" height="480" allow="autoplay"></iframe>
</div>

# TODOs:
1. [X] Compare the results of gpt-3.5-turbo on zero-shot, few-shot, reflect, and few-shot + reflect (against the results of gpt4 seen in class).
2. [X] Extend the experiment to fine-grained labels.
3. [X] Write a prompt to get ChatGPT to output some python code to solve the TREC Question Classification problem, and test it here.

# ChatGPT Python code to solve the TREC Question Classification problem

In [None]:
# Python code generated by prompting ChatGPT
# Link of chat: https://chat.openai.com/share/beda9929-9c88-4132-9f6a-4ccf926f0c44
import os
import dotenv
import openai
from datasets import load_dataset
import time  # Importar la biblioteca de tiempo

# Load environment variables from a .env file (make sure it is configured)
dotenv.load_dotenv('.env')

# Configure the OpenAI client with the API key
openai.api_key = os.getenv("OPENAI_API_KEY")

# Load the TREC dataset
trec_dataset = load_dataset('trec')

# Keep the train and test splits
train_data = trec_dataset['train']
test_data = trec_dataset['test']

# Questions and their coarse labels, used as input for the OpenAI requests
input_data = train_data['text']
target_labels = train_data['coarse_label']

# Names of the coarse categories, read from the split loaded in this cell so
# the cell does not depend on variables created by earlier cells.
categories = train_data.features['coarse_label'].names

# Find which known category is mentioned in the generated text
def extract_category(text, debug=False):
    """Return the first category (in `categories` order) contained in `text`.

    Args:
        text: text generated by the model.
        debug: when True, print the text before matching.

    Returns:
        A `(category_name, category_index)` tuple. When no category is found,
        the placeholder name "Categoría no encontrada" is returned with index 0.
    """
    if debug:
        print(f"DEBUG - Text: {text}")
    for i, category in enumerate(categories):
        if category in text:
            return category, i
    return "Categoría no encontrada", 0

# Instruction text placed before every question sent to the completion model.
# It lists the six coarse classes as a quoted label followed by a description
# in parentheses; the trailing blank lines separate it from the question.
prefix_prompt = ("""
  You must categorize questions in one of the following classes.
  You must reply only with one class between quotes, ignore the description in parenthesis.
  "ABBR" (abbreviation)
  "ENTY" (entities)
  "DESC" (description and abstract concepts)
  "HUM" (human beings)
  "LOC" (locations)
  "NUM" (numeric values)


""")

# Classify a small batch of questions through the OpenAI completion API
# (nothing is trained here: each question is sent as a prompt).
model = "text-davinci-003"
predictions = []

# Only the first `size_batch` questions and labels are used.
size_batch = 21
input_data = input_data[:size_batch]
target_labels = target_labels[:size_batch]

for position, (question_text, target_index) in enumerate(zip(input_data, target_labels)):
    # Pause 61 seconds before every group of 3 requests to stay under the rate limit
    if position % 3 == 0:
        time.sleep(61)
    completion = openai.Completion.create(
        engine=model,
        prompt=f"{prefix_prompt}The question is: {question_text}",
        max_tokens=32,
        temperature=0.7,
        top_p=1.0,
    )
    completion_text = completion.choices[0].text.strip()
    predicted_label, index_from_categories = extract_category(completion_text)
    predictions.append(index_from_categories)
    print(f"Q: {question_text}\nA:{predicted_label}\nT:{categories[target_index]}\n")


# NOTE: the triple-quoted string below is never assigned or executed; it only
# keeps the generated test-split evaluation code around as disabled text.
"""
# (Esta de mas esto porque se ejecuta el accuracy en el train dataset como se hace originalmente)

# Calcular la exactitud del modelo en los datos de prueba (esto es solo un ejemplo básico)
test_input_data = test_data['text']
test_target_labels = test_data['coarse_label']
test_input_prompts = test_input_data

# Realizar predicciones para los datos de prueba
test_predictions = []

for prompt in test_input_prompts:
    response = openai.Completion.create(
        engine=model,
        prompt=prompt,
        max_tokens=32,
        temperature=0.7,
        top_p=1.0,
    )
    predicted_label = response.choices[0].text.strip()
    test_predictions.append(predicted_label)
"""

# Score the predictions against the labels of the questions that were sent
# (the batch taken from the training split above).
from sklearn.metrics import accuracy_score, classification_report

accuracy = accuracy_score(y_true=target_labels, y_pred=predictions)
print(f"Exactitud del modelo: {accuracy:.2f}")

# classification_report was added by hand to the generated code
report_options = {
    "labels": range(len(categories)),
    "target_names": categories,
    "zero_division": 1,
}
evaluation_report = classification_report(target_labels, predictions, **report_options)

print(evaluation_report)


DEBUG - Text: "DESC"
Q: How did serfdom develop in and then leave Russia ?
A:DESC
T:DESC

DEBUG - Text: "ENTY"
Q: What films featured the character Popeye Doyle ?
A:ENTY
T:ENTY

DEBUG - Text: "ENTY"
Q: How can I find a list of celebrities ' real names ?
A:ENTY
T:DESC

DEBUG - Text: "ENTY"
Q: What fowl grabs the spotlight after the Chinese Year of the Monkey ?
A:ENTY
T:ENTY

DEBUG - Text: "ABBR"
Q: What is the full form of .com ?
A:ABBR
T:ABBR

DEBUG - Text: "DESC"
Q: What contemptible scoundrel stole the cork from my lunch ?
A:DESC
T:HUM

DEBUG - Text: "ENTY"
Q: What team did baseball 's St. Louis Browns become ?
A:ENTY
T:HUM

DEBUG - Text: "DESC"
Q: What is the oldest profession ?
A:DESC
T:HUM

DEBUG - Text: "DESC"
Q: What are liver enzymes ?
A:DESC
T:DESC

DEBUG - Text: "ENTY"
Q: Name the scar-faced bounty hunter of The Old West .
A:ENTY
T:HUM

DEBUG - Text: "HUM"
Q: When was Ozzy Osbourne born ?
A:HUM
T:NUM

DEBUG - Text: "DESC"
Q: Why do heavier objects travel downhill faster ?
A:D