<a href="https://colab.research.google.com/github/juanluisunicamp/Imersao_IA_RAG/blob/main/Projeto_Imersao_RAG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# RAG
### Inspirado no paper do Visconde
### Link: https://arxiv.org/pdf/2212.09656

## Instalando pacotes

In [1]:
!apt install openjdk-21-jre-headless
!pip install -q -U google-generativeai
!pip install -q pyserini faiss-cpu groq

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
openjdk-21-jre-headless is already the newest version (21.0.2+13-1~22.04.1).
0 upgraded, 0 newly installed, 0 to remove and 45 not upgraded.


## Configuração inicial do sistema

In [4]:
import os
from google.colab import userdata
import google.generativeai as genai

In [5]:
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-21-openjdk-amd64"
GOOGLE_API_KEY          = userdata.get('ImersaoIA')
genai.configure(api_key = GOOGLE_API_KEY)

## Importamos as bibliotecas

In [6]:
import re
import json
import time
import spacy
import random
import pickle
import string
import tarfile
import requests
import subprocess
import numpy as np
import unicodedata
import collections
from tqdm import tqdm
from bs4 import BeautifulSoup
from collections import Counter
import google.generativeai as genai
from pyserini.search.lucene import LuceneSearcher

In [7]:
DATA_DIR       = 'data'
DATA_INDEX_DIR = f"{DATA_DIR}/iirc_index"

## Creamos o dataset

In [8]:
def create_dataset():
    IIRC_train_dev   = "https://iirc-dataset.s3.us-west-2.amazonaws.com/iirc_train_dev.tgz"
    Context_articles = "https://iirc-dataset.s3.us-west-2.amazonaws.com/context_articles.tar.gz"
    IIRC_test        = "https://iirc-dataset.s3.us-west-2.amazonaws.com/iirc_test.json"

    if not os.path.exists(DATA_DIR):
        os.makedirs(DATA_DIR)

    for url in [IIRC_train_dev, Context_articles, IIRC_test]:
        filename = url.split('/')[-1]
        filepath = os.path.join(DATA_DIR, filename)

        if not os.path.exists(filepath):
            print(f"Baixando {filename}...")
            r = requests.get(url)
            with open(filepath, 'wb') as f:
                f.write(r.content)
            print(f"{filename} baixado.")

        if filename.endswith('.tgz') or filename.endswith('.tar.gz'):

            with tarfile.open(filepath, 'r:gz') as tar:
                members = [m for m in tar.getmembers() if not os.path.exists(os.path.join(DATA_DIR, m.name))]
                if len(members) > 0:
                    print(f"Extraindo {filename}...")
                    tar.extractall(DATA_DIR, members=members)
                    print(f"{filename} extraído.")

    train_set        = json.load(open(f"{DATA_DIR}/iirc_train_dev/train.json", "r"))
    context_articles = json.load(open(f"{DATA_DIR}/context_articles.json", "r"))
    test_set         = json.load(open(f"{DATA_DIR}/iirc_test.json", "r"))
    return train_set, context_articles, test_set

## Modelos do Gemini

In [9]:
for models in genai.list_models():
  if "generateContent" in models.supported_generation_methods:
    print(models.name)

models/gemini-1.0-pro
models/gemini-1.0-pro-001
models/gemini-1.0-pro-latest
models/gemini-1.0-pro-vision-latest
models/gemini-1.5-pro-latest
models/gemini-pro
models/gemini-pro-vision


## Configuração do modelo generativo

In [10]:
generation_config = {
  "candidate_count"  : 1,
  "temperature"      : 0,
}
safety_settings = {
  "HARASSMENT" : "BLOCK_NONE",
  "HATE"       : "BLOCK_NONE",
  "SEXUAL"     : "BLOCK_NONE",
  "DANGEROUS"  : "BLOCK_NONE",
}
model = genai.GenerativeModel(model_name        = "models/gemini-1.5-pro-latest",
                              generation_config = generation_config,
                              safety_settings   = safety_settings)

## Amostras

In [11]:
train_set, context_articles, test_set = create_dataset()
print(f"\nQuantidade de amostras do train_set: {len(train_set)}")
print(f"\nQuantidade de amostras do context_articles: {len(context_articles)}")
print(f"\nQuantidade de amostras do test_set: {len(test_set)}")

max_amostras = 150
test_set     = test_set[0:150]
print(f"\nQuantidade de amostras do test_set: {len(test_set)}")


Quantidade de amostras do train_set: 4754

Quantidade de amostras do context_articles: 56550

Quantidade de amostras do test_set: 514

Quantidade de amostras do test_set: 150


In [12]:
#Exemplo do train_set: Lista
train_set[0]

{'questions': [{'context': [{'text': 'During Operation Market Garden, the attempt to seize a bridgehead across the Rhine in the Netherlands, the 704th dropped supplies to allied troops near Nijmegen.',
     'indices': [494, 655],
     'passage': 'main'},
    {'text': 'Operation Market Garden was a failed World War II military operation fought in the Netherlands from 17 to 25 September 1944.',
     'indices': [0, 124],
     'passage': 'Operation Market Garden'}],
   'question_links': ['Operation Market Garden'],
   'answer': {'answer_spans': [{'start': 131,
      'text': ' from 17 to 25 September 1944',
      'passage': 'operation market garden',
      'end': 160}],
    'type': 'span'},
   'question': 'When did the operation during which the 704th dropped supplies to allied troops near Nijmegen begin?',
   'qid': 'q_0'}],
 'links': [{'indices': [73, 84], 'target': 'Air Support'},
  {'indices': [89, 101], 'target': 'Interdiction'},
  {'indices': [125, 143], 'target': 'Operation Overlord'

In [13]:
#Exemplo do context_articles: Diccionario
print(next(iter(context_articles.keys())))
print(next(iter(context_articles.items())))

san diego padres
('san diego padres', 'The San Diego Padres are an American <a href="professional%20baseball">professional baseball</a> team based in <a href="San%20Diego">San Diego</a>, <a href="California">California</a>. They compete in <a href="Major%20League%20Baseball">Major League Baseball</a> (MLB) as a member club of the <a href="National%20League">National League</a> (NL) <a href="National%20League%20West">West division</a>. Founded in <a href="1969%20San%20Diego%20Padres%20season">1969</a>, the Padres have won two <a href="List%20of%20National%20League%20pennant%20winners">NL pennants</a> — in <a href="1984%20San%20Diego%20Padres%20season">1984</a> and <a href="1998%20San%20Diego%20Padres%20season">1998</a>, losing in the <a href="World%20Series">World Series</a> both years. As of <a href="2017%20San%20Diego%20Padres%20season">2018</a>, they have had 14 winning seasons in franchise history. The Padres are one of two Major League Baseball teams (the other being the <a href="L

In [14]:
#Exemplo do context_articles: Diccionario
print(next(iter(context_articles.keys())))
print(next(iter(context_articles.items())))

san diego padres
('san diego padres', 'The San Diego Padres are an American <a href="professional%20baseball">professional baseball</a> team based in <a href="San%20Diego">San Diego</a>, <a href="California">California</a>. They compete in <a href="Major%20League%20Baseball">Major League Baseball</a> (MLB) as a member club of the <a href="National%20League">National League</a> (NL) <a href="National%20League%20West">West division</a>. Founded in <a href="1969%20San%20Diego%20Padres%20season">1969</a>, the Padres have won two <a href="List%20of%20National%20League%20pennant%20winners">NL pennants</a> — in <a href="1984%20San%20Diego%20Padres%20season">1984</a> and <a href="1998%20San%20Diego%20Padres%20season">1998</a>, losing in the <a href="World%20Series">World Series</a> both years. As of <a href="2017%20San%20Diego%20Padres%20season">2018</a>, they have had 14 winning seasons in franchise history. The Padres are one of two Major League Baseball teams (the other being the <a href="L

## Criamos os indices

In [15]:
documents = []
all_titles = []

for item in test_set:
    if item['title'].lower() not in all_titles:
        soup       = BeautifulSoup(item["text"], "html.parser")
        clean_text = soup.get_text()

        documents.append({
                "title": item['title'],
                "content": clean_text
            }
        )
        all_titles.append(item['title'].lower())
    for link in item["links"]:
        if link['target'].lower() in context_articles and link['target'].lower() not in all_titles:
            soup       = BeautifulSoup(context_articles[link['target'].lower()], "html.parser")
            clean_text = soup.get_text()

            documents.append({
                "title": link['target'],
                "content": clean_text
            })
            all_titles.append(link['target'].lower())
        else:
            print(link['target'].lower())

9th paratroopers assault regiment "col moschin"
goldfinger (film)
list of international cricket council members
icc americas championship
the rev
avenged sevenfold
fox footy
herald sun
fox footy
herald sun
united states
judeo-iraqi arabic
maya civilization
black watch
suicidal tendencies
western hockey league
national hockey league
home run
minor league baseball
colonel
colonel
massachusetts institute of technology
israel
harvard business review
american football
college football
united states
billboard 200
romeo discography
billboard 200
master p
hip hop history
billboard 200
louisiana
arizona
state farm stadium
louisiana
united states
gulf of mexico
saffir–simpson scale
forgotten realms
list of dungeons & dragons rulebooks
mexico
napoleon iii
american football
national football league
mexico
lucha libre
protagonist
double dragon
world war ii
banff, alberta
american football
quarterback
college football
2009 nfl draft
new york city
los angeles
metal massacre
metal massacre
hull city a

## Janelamento

In [16]:
def window(documents, stride=2, max_length=3):
    treated_documents = []
    nlp               = spacy.blank("en")
    nlp.add_pipe("sentencizer")

    for j,document in enumerate(tqdm(documents)):
        doc_text = document['content']
        doc = nlp(doc_text)
        sentences = [sent.text.strip() for sent in doc.sents]
        for i in range(0, len(sentences), stride):
            segment = ' '.join(sentences[i:i + max_length])
            treated_documents.append({
                "title": document['title'],
                "contents": document['title']+". "+segment,
                "segment": segment
            })
            if i + max_length >= len(sentences):
                break
    return treated_documents

In [17]:
documents[0]

{'title': 'Palici',
 'content': "The Palici (Παλικοί in Greek), or Palaci, were a pair of indigenous Sicilian chthonic deities in Roman mythology, and to a lesser extent in Greek mythology. They are mentioned in Ovid's Metamorphoses V, 406, and in Virgil's Aeneid IX, 585. Their cult centered on three small lakes that emitted sulphurous vapors in the Palagonia plain, and as a result these twin brothers were associated with geysers and the underworld. There was also a shrine to the Palaci in Palacia, where people could subject themselves or others to tests of reliability through divine judgement; passing meant that an oath could be trusted. The mythological lineage of the Palici is uncertain; one legend made the Palici the sons of Zeus, or possibly Hephaestus, by Aetna or Thalia, but another claimed that the Palici were the sons of the Sicilian deity Adranus.\n"}

In [18]:
treated_documents = window(documents)

100%|██████████| 2164/2164 [01:19<00:00, 27.39it/s]


In [19]:
treated_documents[0]

{'title': 'Palici',
 'contents': "Palici. The Palici (Παλικοί in Greek), or Palaci, were a pair of indigenous Sicilian chthonic deities in Roman mythology, and to a lesser extent in Greek mythology. They are mentioned in Ovid's Metamorphoses V, 406, and in Virgil's Aeneid IX, 585. Their cult centered on three small lakes that emitted sulphurous vapors in the Palagonia plain, and as a result these twin brothers were associated with geysers and the underworld.",
 'segment': "The Palici (Παλικοί in Greek), or Palaci, were a pair of indigenous Sicilian chthonic deities in Roman mythology, and to a lesser extent in Greek mythology. They are mentioned in Ovid's Metamorphoses V, 406, and in Virgil's Aeneid IX, 585. Their cult centered on three small lakes that emitted sulphurous vapors in the Palagonia plain, and as a result these twin brothers were associated with geysers and the underworld."}

In [20]:
def create_indices(treated_documents):
    dir_path = f"{DATA_DIR}/iirc_index"
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)

    with open(f"{dir_path}/contents.jsonl", 'w') as f:
        for i, doc in enumerate(treated_documents):
            doc['id'] = i
            if doc['segment'] != "":
                f.write(json.dumps(doc)+"\n")

    command = f"python3 -m pyserini.index.lucene -collection JsonCollection -generator DefaultLuceneDocumentGenerator -threads 1 -input {dir_path} -index {DATA_INDEX_DIR} -storeRaw"
    result  = subprocess.run(command, shell = True, capture_output = True)

    if (result.returncode == 0):
        print("Indices criado com sucesso")
    else:
        print("Erro ao criar o índice")

In [21]:
create_indices(treated_documents)

Indices criado com sucesso


## Recuperação dos documentos


### BM25
Traditional lexical models (e.g., BM25) using LuceneSearcher. \\
Extraido do: https://github.com/castorini/pyserini

In [22]:
def execute_search(question, limit_by_query=100):
    searcher = LuceneSearcher(DATA_INDEX_DIR)
    ans      = []

    hits = searcher.search(question, k=limit_by_query)
    chosen = []
    for hit in hits:
        hit = json.loads(hit.lucene_document.get('raw'))
        chosen.append(hit)

    ans.append({
        "question": question,
        "documents": chosen
    })

    return ans

## Treinamento

In [23]:
ZERO_SHOT_PROMPT = "Based on the following passages, answer the question presenting your reasoning. Let me know if there is no answer with \"Not enough information\". "
ZERO_SHOT_PROMPT += "Your answer should be in this JSON format:{\"Reasoning\": \"your reasoning\", \"Answer\": \"your answer\"}:\n"

save_file = "qa_data.pickle"

if os.path.exists(save_file):
    with open(save_file, 'rb') as f:
        qa = pickle.load(f)
else:
    qa = {'question': [], 'answer': [], 'llm_answer': []}

for item in tqdm(test_set):
    questions  = item["questions"]
    item_title = item["title"]
    prompt = ZERO_SHOT_PROMPT

    for q in questions:
        if q['question'] in qa['question']:
            continue

        question = q['question']

        hits = execute_search(question, limit_by_query=5)
        for hit in hits:
            for i, passage in enumerate(hit['documents']):
                content = passage['contents']
                prompt += f"\n\nPassage {i+1}: {content}"

        prompt += f"\n\nQuestion: {question}\n\n"

        if q['answer']['type'] == "span":
            final_answer = ", ".join([answer_span['text'] for answer_span in q['answer']["answer_spans"]])

        elif q['answer']['type'] == "value":
            final_answer = "{0} {1}".format(q['answer']['answer_value'],
                                            q['answer']['answer_unit'])

        elif q['answer']['type'] == "binary":
            final_answer = q['answer']['answer_value']

        elif q['answer']['type'] == "none":
            final_answer = "Not enough information"

        try:
            llm_answer = model.generate_content(prompt)

            qa['question'].append(question)
            qa['answer'].append(final_answer)
            qa['llm_answer'].append(llm_answer)

            if len(qa['question']) % 5 == 0:
                with open(save_file, 'wb') as f:
                    pickle.dump(qa, f)

            if i % 10 == 0:
                sleep_time = random.uniform(1, 5)
                time.sleep(sleep_time)
        except:
            time.sleep(30)

with open(save_file, 'wb') as f:
    pickle.dump(qa, f)

100%|██████████| 150/150 [1:25:51<00:00, 34.34s/it]


## Mostramos as respostas do Gemini

In [68]:
import pandas as pd

json_pattern = r'{.*?}'
rea_pattern  = r'"Reasoning":\s*"([^"]+)"'
ans_pattern  = r'"Answer":\s*"([^"]+)"'
data = []
for q, a, la in zip(qa['question'], qa["answer"], qa['llm_answer']):

    match = re.search(json_pattern, la.text , re.DOTALL)
    if match:
        json_response = match.group()

        rea_match = re.search(rea_pattern, json_response)
        if rea_match:
            llm_reasoning = rea_match.group(0).split(": ")[1].replace("\"","")
        else:
            llm_reasoning = 'ERRO na resposta do LLM'

        ans_match = re.search(ans_pattern, json_response)
        if ans_match:
            llm_answer = ans_match.group(0).split(": ")[1].replace("\"","")
        else:
            llm_answer = 'ERRO na resposta do LLM'

    else:
        print("JSON não encontrado na resposta.")
        llm_answer = 'ERRO na resposta do LLM'
        llm_reasoning = 'ERRO na resposta do LLM'

    row = {'question': q,
           'answer': a,
           'llm_answer': llm_answer,
           'reasoning': llm_reasoning }

    data.append(row)

df = pd.DataFrame(data)
df.to_csv('resultado.csv', index=False)
df

Unnamed: 0,question,answer,llm_answer,reasoning
0,What is Zeus know for in Greek mythology?,sky and thunder god,Zeus is known as a powerful leader of the Gree...,"Passages 1, 2, 3, and 5 all describe Zeus as a..."
1,How long had the First World War been over whe...,5 years,5 years,Passage 1 states that Messe fought in World Wa...
2,How old was Messe when the First World War sta...,30 years,31,Passage 1 states that Messe was born in 1883 a...
3,Who was the manager for Hull City when Brunt s...,Not enough information,Not enough information,While the passage states that Brunt scored his...
4,Which stadium where Brunt played can hold more...,White Hart Lane,Not enough information,Passage 1 mentions that Chris Brunt returned t...
...,...,...,...,...
247,Was Rao alive when Atal Bihari Vajpayee was pr...,yes,No,Passage 2 states that \
248,Was the Ford Dagenham plant still in operation...,yes,Yes,Passage 2 states that Dagenham became part of ...
249,For how many years was Corinth goverened by th...,185 years,Not enough information,The provided text states that Corinth became p...
250,"How many weeks did ""What Hurts the Most"" spend...",Not enough information,Not enough information,Passage 3 states that \


In [69]:
la.text

'```json\n{\n "Reasoning": "Passage 5 states that \\"Waiting on the World to Change\\" stayed on the charts for 41 weeks.",\n "Answer": "41 weeks"\n}\n```'

## Avaliação

### Métricas de Avaliação

*   F1
*   EM (exact match)




In [70]:
def normalize_answer(s):
  """Lower text and remove punctuation, articles and extra whitespace."""
  def remove_articles(text):
    regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
    return re.sub(regex, ' ', text)
  def white_space_fix(text):
    return ' '.join(text.split())
  def remove_punc(text):
    exclude = set(string.punctuation)
    return ''.join(ch for ch in text if ch not in exclude)
  def lower(text):
    return text.lower()
  def remove_accents(input_str):
      nfkd_form = unicodedata.normalize('NFKD', input_str)
      only_ascii = nfkd_form.encode('ASCII', 'ignore')
      return only_ascii.decode("utf-8")

  return white_space_fix(remove_articles(remove_punc(lower(remove_accents(s)))))

def get_tokens(s):
  if not s: return []
  return normalize_answer(s).split()

def compute_exact(a_gold, a_pred):
  return int(normalize_answer(a_gold) == normalize_answer(a_pred))

def compute_f1(a_gold, a_pred):
  gold_toks = get_tokens(a_gold)
  pred_toks = get_tokens(a_pred)
  common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
  num_same = sum(common.values())
  if len(gold_toks) == 0 or len(pred_toks) == 0:
    return int(gold_toks == pred_toks)
  if num_same == 0:
    return 0
  precision = 1.0 * num_same / len(pred_toks)
  recall = 1.0 * num_same / len(gold_toks)
  f1 = (2 * precision * recall) / (precision + recall)
  return f1

In [71]:
f1s = []
ems = []
for a, llm_answer in zip(df["answer"].values, df['llm_answer'].values):

    f1s.append(compute_f1(a, llm_answer))
    ems.append(compute_exact(a, llm_answer))

df_metrics = pd.DataFrame()
df_metrics['F1'] = f1s
df_metrics['EM'] = ems
df_metrics.describe()

Unnamed: 0,F1,EM
count,252.0,252.0
mean,0.523446,0.452381
std,0.474305,0.498718
min,0.0,0.0
25%,0.0,0.0
50%,0.666667,0.0
75%,1.0,1.0
max,1.0,1.0
