In [1]:
!python -m spacy download pt_core_news_sm

Collecting pt-core-news-sm==3.7.0
  Downloading https://github.com/explosion/spacy-models/releases/download/pt_core_news_sm-3.7.0/pt_core_news_sm-3.7.0-py3-none-any.whl (13.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.0/13.0 MB[0m [31m11.6 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('pt_core_news_sm')


In [11]:
import json
import logging
import os
import sys
from dataclasses import dataclass
from getpass import getpass
from glob import glob
from typing import List, Dict
from typing import Optional

# NOTE: This is ONLY necessary in jupyter notebook.
# Details: Jupyter runs an event-loop behind the scenes.
#          This results in nested event-loops when we start an event-loop to make async queries.
#          This is normally not allowed, we use nest_asyncio to allow it for convenience.
import nest_asyncio

import spacy

from llama_index.core import Document, QueryBundle
from llama_index.core.schema import TextNode
from llama_index.core.postprocessor import SentenceTransformerRerank
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.llms.groq import Groq

from thefuzz import process

In [16]:
# Allow nested event loops: Jupyter already runs one, and async queries start another.
nest_asyncio.apply()

# SECURITY: API keys must never be hardcoded in a notebook. The key that used to be
# written on this line has to be treated as leaked and revoked with the provider.
# The key is taken from the environment; when it is missing, prompt for it without echo.
if not os.environ.get("GROQ_API_KEY"):
    os.environ["GROQ_API_KEY"] = getpass("GROQ_API_KEY: ")

# Send INFO-level logs to stdout through exactly one plain handler. The level is set
# explicitly (basicConfig silently does nothing when the root logger already has
# handlers) and the handler list is replaced, so re-running this cell never duplicates
# log lines.
root_logger = logging.getLogger()
root_logger.setLevel(logging.INFO)
root_logger.handlers = [logging.StreamHandler(stream=sys.stdout)]

In [17]:
# Sliding-window chunking: passed to sliding_window_split as window_size and stride,
# i.e. sentences per chunk and the distance (in sentences) between chunk starts.
SENTENCE_WINDOW_SIZE = 5
SENTENCE_WINDOW_STRIDE = 2

# TOP_K is passed as the reranker's top_n (chunks kept after reranking);
# BM25_TOP_K is passed as the BM25 retriever's similarity_top_k (candidates before reranking).
TOP_K = 4
BM25_TOP_K = 1000

# Model identifiers: LLM_MODEL is handed to the Groq client, RERANKER_MODEL to SentenceTransformerRerank.
# NOTE(review): "monoptt5" looks like a T5-based reranker - confirm SentenceTransformerRerank loads it as intended.
LLM_MODEL = "llama3-70b-8192"
RERANKER_MODEL = "unicamp-dl/monoptt5-base"

In [18]:
def get_transcription_documents(base_transcriptions_path):
    """Load every transcription JSON file matching a glob pattern into Documents.

    Args:
        base_transcriptions_path: Glob pattern (e.g. ``"../some-dir/*.json"``) selecting
            the JSON files to read. Each file must hold a list of objects with the keys
            ``transcription``, ``title``, ``publishing_date``, ``quadro`` and ``hashtag``.

    Returns:
        A list with one ``Document`` per transcription entry. The ``transcription`` value
        becomes the document text and the remaining four keys become its metadata.
        The list is empty when no file matches the pattern.

    Raises:
        KeyError: if an entry lacks one of the expected keys.
    """
    metadata_keys = ('title', 'publishing_date', 'quadro', 'hashtag')
    transcription_documents = []

    # glob order depends on the filesystem; sorting keeps the document order
    # (and everything derived from it) reproducible across machines and runs.
    for transcriptions_path in sorted(glob(base_transcriptions_path)):
        # Explicit UTF-8: the transcriptions are Portuguese text and the platform
        # default encoding (e.g. cp1252 on Windows) would corrupt or reject them.
        with open(transcriptions_path, encoding='utf-8') as f:
            transcriptions = json.load(f)

        for transcription in transcriptions:
            transcription_documents.append(
                Document(
                    text=transcription['transcription'],
                    metadata={key: transcription[key] for key in metadata_keys},
                )
            )

    return transcription_documents

In [19]:
def sliding_window_split(documents, stride, window_size):
    """Split each document into overlapping windows of consecutive sentences.

    Sentences are detected with a blank Portuguese spaCy pipeline that only runs the
    rule-based ``sentencizer`` component.

    Args:
        documents: Iterable of documents exposing ``text``, ``metadata`` and ``id_``.
        stride: Number of sentences between the starts of two consecutive windows (>= 1).
        window_size: Maximum number of sentences per window (>= 1).

    Returns:
        A flat list of new ``Document`` objects, one per window. Each window's text is
        its sentences joined by single spaces; its metadata is a copy of the parent's
        metadata plus ``parent_document_id`` set to the parent's ``id_``.

    Raises:
        ValueError: if ``stride`` or ``window_size`` is smaller than 1.
    """
    if stride < 1 or window_size < 1:
        raise ValueError("stride and window_size must be positive integers")

    sentencizer = spacy.blank('pt')
    sentencizer.add_pipe('sentencizer')

    window_documents = []

    for document in documents:
        # Keep only non-empty sentences so windows are never padded with blank pieces.
        stripped_sentences = (sent.text.strip() for sent in sentencizer(document.text).sents)
        sentences = [sentence for sentence in stripped_sentences if sentence]

        for start in range(0, len(sentences), stride):
            end = start + window_size  # slicing clamps `end` to the sentence count
            window_metadata = document.metadata.copy()
            window_metadata['parent_document_id'] = document.id_
            window_documents.append(
                Document(text=' '.join(sentences[start:end]), metadata=window_metadata)
            )
            # Once a window reaches the last sentence, every later start would only
            # yield a strict suffix of this window: a redundant near-duplicate chunk.
            if end >= len(sentences):
                break

    return window_documents

In [20]:
def get_transcriptions_nodes(transcription_documents):
    """Chunk the transcriptions into sentence windows and group the chunks by episode title.

    The windows are produced by ``sliding_window_split`` using the notebook-level
    ``SENTENCE_WINDOW_STRIDE`` and ``SENTENCE_WINDOW_SIZE`` settings.

    Args:
        transcription_documents: Documents whose metadata contains a ``title`` key.

    Returns:
        A dict mapping each title to the list of ``TextNode`` objects built from that
        episode's windows, in window order. Documents sharing a title share one list.
    """
    transcription_window_documents = sliding_window_split(
        transcription_documents, SENTENCE_WINDOW_STRIDE, SENTENCE_WINDOW_SIZE
    )

    transcription_nodes = dict()
    for document in transcription_window_documents:
        # The node id field is `id_`. The previous `id=` keyword is not a declared field,
        # so the window's id was presumably dropped and replaced by a fresh random one.
        new_node = TextNode(id_=document.id_, text=document.text, metadata=document.metadata)
        transcription_nodes.setdefault(document.metadata['title'], []).append(new_node)

    return transcription_nodes

In [21]:
# Load every transcription JSON file matching the pattern (relative to the notebook's working directory).
transcription_documents = get_transcription_documents("../transcriptions-headless/*.json")

# Display the first loaded document as a quick sanity check.
transcription_documents[0]

Document(id_='3634badd-9af2-479a-9656-9ba2f423d3ab', embedding=None, metadata={'title': 'fome', 'publishing_date': '06/06/2024', 'quadro': 'luz acesa', 'hashtag': '#fome'}, excluded_embed_metadata_keys=[], excluded_llm_metadata_keys=[], relationships={}, text='[vinheta] Shhhh… Luz Acesa, história de dar medo. [vinheta]\n\n\n\nDéia Freitas: Oi, gente… Cheguei. Cheguei para um Luz Acesa. — Talvez esse Luz Acesa seja um pouco mais impressionante, fiquem atentos aí a algum gatilho. — Hoje eu vou contar para vocês a história da Thalita. Então vamos lá, vamos de história.\n\n\n\n\n[trilha]\n\n\n\n\nA Thalita ela trabalha numa empresa — como é que eu vou dizer, vou simplificar — de prensas. — Quando você tem processos industriais que envolvam metal e que precisem de uma prensa… É num dos braços dessa empresa que a Thalita trabalha que tem isso. — E Thalita sempre se deu muito bem com a galera ali do setor dela de trabalho e, na hora do almoço, ela sempre saía para almoçar junto com outro func

In [22]:
# Build the sentence-window nodes, grouped by episode title.
transcription_nodes = get_transcriptions_nodes(transcription_documents)

# Inspect the nodes generated for one episode.
transcription_nodes['milton']

[TextNode(id_='84181b12-b1de-429e-b75d-f0b370197ddb', embedding=None, metadata={'title': 'milton', 'publishing_date': '18/03/2024', 'quadro': 'amor nas redes', 'hashtag': '#milton', 'parent_document_id': '85b840c3-44b1-4a8e-ba0c-7050a19116fc'}, excluded_embed_metadata_keys=[], excluded_llm_metadata_keys=[], relationships={}, text='[vinheta] Amor nas Redes, sua história é contada aqui. [ vinheta]\n\n\n\nDéia Freitas: Oi, gente… Cheguei. Cheguei para mais um Amor nas Redes. E hoje eu vou contar para vocês a história do Reinaldo. Então vamos lá, vamos de história.', mimetype='text/plain', start_char_idx=None, end_char_idx=None, text_template='{metadata_str}\n\n{content}', metadata_template='{key}: {value}', metadata_seperator='\n'),
 TextNode(id_='75ef1941-e09f-43e5-93a8-fefba2d480e3', embedding=None, metadata={'title': 'milton', 'publishing_date': '18/03/2024', 'quadro': 'amor nas redes', 'hashtag': '#milton', 'parent_document_id': '85b840c3-44b1-4a8e-ba0c-7050a19116fc'}, excluded_embed_

In [15]:
class MultiIndexRetriever:
    """Two-stage retriever over several named indexes: BM25 candidates, then reranking.

    One ``BM25Retriever`` is built per index name at construction time; a single
    ``SentenceTransformerRerank`` instance is shared by all indexes.
    """

    def __init__(self, indexes_nodes: Dict[str, List[TextNode]], top_k, bm25_top_k, reranker_model) -> None:
        """
        Args:
            indexes_nodes: Mapping from index name to the nodes that belong to it.
            top_k: Number of chunks kept by the reranker (its ``top_n``).
            bm25_top_k: Upper bound on the BM25 candidates handed to the reranker.
            reranker_model: Model identifier passed to ``SentenceTransformerRerank``.
        """
        self.indexes: Dict[str, List[TextNode]] = indexes_nodes

        self.bm25_retrievers: Dict[str, BM25Retriever] = dict()
        for index, nodes in indexes_nodes.items():
            self.bm25_retrievers[index] = BM25Retriever.from_defaults(
                nodes=nodes,
                # Never request more candidates than the index holds: the result set is
                # the same, and NOTE(review) some bm25s-backed BM25Retriever releases
                # reject a top-k larger than the corpus - confirm for the pinned version.
                similarity_top_k=min(bm25_top_k, len(nodes)),
            )

        self.reranker = SentenceTransformerRerank(top_n=top_k, model=reranker_model)

    def retrieve(self, index_name: str, query: str) -> List[str]:
        """Return the text of the best chunks of ``index_name`` for ``query``.

        Args:
            index_name: Name of the index to search; must be a key given at construction.
            query: Free-text query used for both BM25 retrieval and reranking.

        Returns:
            The texts of the reranked nodes, in the order returned by the reranker.

        Raises:
            ValueError: if ``index_name`` is not a known index.
        """
        if index_name not in self.indexes:
            raise ValueError(f"Index {index_name} not found")

        retriever = self.bm25_retrievers[index_name]

        retrieved_nodes = retriever.retrieve(query)
        reranked_nodes = self.reranker.postprocess_nodes(
            retrieved_nodes,
            query_bundle=QueryBundle(query),
        )

        context_chunks = [node.get_text() for node in reranked_nodes]

        return context_chunks

In [12]:
class IndexDetector:
    """Maps a free-text question to one of the known index (episode) names using an LLM."""

    def __init__(self, llm_model: str, index_names: List[str], GROQ_API_KEY: Optional[str] = None):
        """
        Args:
            llm_model: Groq model identifier.
            index_names: Names of the available indexes (episode titles).
            GROQ_API_KEY: Groq API key. When omitted, the ``GROQ_API_KEY`` environment
                variable is read at construction time.

        Raises:
            KeyError: if no key is given and ``GROQ_API_KEY`` is not set.
        """
        # Resolve the key when the object is built. A default of os.environ[...] in the
        # signature is evaluated once, when the class is defined: it fails if the
        # variable is not set yet and ignores any later change to it.
        if GROQ_API_KEY is None:
            GROQ_API_KEY = os.environ['GROQ_API_KEY']

        # Each fragment ends with a space so the concatenated sentences do not run together.
        self._base_prompt = \
            "Given the following Portuguese query, regarding an episode of a podcast, please tell me the title of the episode. " \
            "The query starts now: '{query}'. " \
            "You MUST answer with only the title of the episode."

        self.llm = Groq(model=llm_model, api_key=GROQ_API_KEY)
        self.llm_model_name = llm_model
        self.index_names = index_names

    def detect_index(self, query: str) -> str:
        """Ask the LLM which episode ``query`` refers to and map the answer to an index name.

        Args:
            query: The user's question.

        Returns:
            The normalized LLM answer when it is exactly one of ``index_names``; otherwise
            the closest index name according to ``process.extractOne``. If there is
            nothing to match against, the normalized answer itself is returned.
        """
        prompt = self._base_prompt.format(query=query)
        raw_llm_guess = self.llm.complete(prompt).text
        # Normalize: trim whitespace, drop quotes the model may wrap the title in, lowercase.
        llm_guess = raw_llm_guess.strip().strip('"\'“”‘’').strip().lower()

        # Exact hit (or no candidates at all): no fuzzy matching needed.
        # The original condition was inverted: it returned unknown guesses verbatim
        # and fuzzy-matched only the guesses that were already valid.
        if llm_guess in self.index_names or not self.index_names:
            return llm_guess

        best_match = process.extractOne(llm_guess, self.index_names)
        return best_match[0] if best_match else llm_guess

In [41]:
class RAGGenerator:
    """Answers a question from retrieved context passages using a Groq-hosted LLM."""

    def __init__(self, llm_model, GROQ_API_KEY=None):
        """
        Args:
            llm_model: Groq model identifier.
            GROQ_API_KEY: Groq API key. When omitted, the ``GROQ_API_KEY`` environment
                variable is read at construction time.

        Raises:
            KeyError: if no key is given and ``GROQ_API_KEY`` is not set.
        """
        # Resolve the key when the object is built. A default of os.environ[...] in the
        # signature is evaluated once, when the class is defined: it fails if the
        # variable is not set yet and ignores any later change to it.
        if GROQ_API_KEY is None:
            GROQ_API_KEY = os.environ['GROQ_API_KEY']

        # The first fragment ends with a space so the two instructions do not run together.
        self._base_prompt = \
            "Consider the following context passages of a podcast episode and answer the given question. " \
            "You MUST answer the question only in Portuguese." \
            "\n\n" \
            "{context_passages}" \
            "\n\n" \
            "If there is not enough information in the context passages, answer \"Não há informação suficiente no episódio.\"." \
            "\n\n" \
            "Question: {query}"

        self.llm = Groq(model=llm_model, api_key=GROQ_API_KEY)
        self.llm_model_name = llm_model

    def generate_answer(self, query: str, contexts: List[str]) -> str:
        """Build the prompt from ``query`` and ``contexts`` and return the LLM's answer.

        Args:
            query: The user's question.
            contexts: Context passages; they are numbered from 1 in the prompt.

        Returns:
            The completion text with surrounding whitespace stripped.
        """
        context_passages = "\n\n".join(
            f"Context {i}: {context}" for i, context in enumerate(contexts, 1)
        )
        prompt = self._base_prompt.format(query=query, context_passages=context_passages)
        answer = self.llm.complete(prompt).text.strip()
        return answer

In [18]:
@dataclass
class RAGResponse:
    """Result of one RAG pipeline call: the generated answer and the contexts it was built from."""
    # Answer text returned by the generator.
    answer: str
    # Context chunks returned by the retriever and passed to the generator.
    contexts: List[str]

class RAGPipeline:
    """End-to-end RAG flow: detect the episode index, retrieve its chunks, generate an answer."""

    def __init__(
        self,
        retriever: MultiIndexRetriever,
        index_detector: IndexDetector,
        generator: RAGGenerator,
    ) -> None:
        """
        Args:
            retriever: Retrieves context chunks from a named index.
            index_detector: Decides which index a question refers to.
            generator: Produces the final answer from the question and its contexts.
        """
        self.index_detector: IndexDetector = index_detector
        self.retriever: MultiIndexRetriever = retriever
        self.generator: RAGGenerator = generator

    def __call__(self, query: str) -> RAGResponse:
        """Answer ``query`` and return the answer together with the contexts used.

        Args:
            query: The user's question.

        Returns:
            A ``RAGResponse`` holding the generated answer and the retrieved chunks.

        Raises:
            ValueError: propagated from the retriever when the detected index is unknown.
        """
        index_name = self.index_detector.detect_index(query)
        context_chunks = self.retriever.retrieve(index_name, query)
        # RAGGenerator exposes `generate_answer`; the previous call to `generate`
        # raised AttributeError on every pipeline invocation.
        answer = self.generator.generate_answer(query, context_chunks)
        return RAGResponse(answer, context_chunks)

In [36]:
# Smoke test: ask the detector which episode (index) a question refers to.
index_detector = IndexDetector(llm_model=LLM_MODEL, index_names=list(transcription_nodes.keys()))
index_name = index_detector.detect_index('No episódio \'mário\', quem foi diagnosticado com câncer de bexiga na história?')
index_name

HTTP Request: POST https://api.groq.com/openai/v1/chat/completions "HTTP/1.1 200 OK"


'mário'

In [44]:
# Smoke test for the generator alone: hand-written context passages and a question about them.
example_contexts = [
    "Eu sou Bia Suzuki, tenho 30 anos, moro em São José do Rio Preto, interior de São Paulo... eu recebi o diagnóstico de câncer de intestino aos 25 anos...",
    "O câncer colorretal é a terceira neoplasia mais frequente e a segunda de maior mortalidade no mundo...",
    "Eu achei que não é impossível viver bem com bolsinha e eu sou uma prova disso."
]
example_question = "No episódio 'bia', quem é a Bia e o que aconteceu com ela?"

# No retrieval happens here: the passages above are passed straight to the LLM prompt.
rag_generator = RAGGenerator(llm_model=LLM_MODEL)
rag_generator.generate_answer(example_question, example_contexts)

HTTP Request: POST https://api.groq.com/openai/v1/chat/completions "HTTP/1.1 200 OK"


'No episódio, Bia é a Bia Suzuki, que tem 30 anos e mora em São José do Rio Preto, interior de São Paulo, e recebeu o diagnóstico de câncer de intestino aos 25 anos.'

In [35]:
# Assemble the full pipeline from its three components.
# NOTE: the cells defining MultiIndexRetriever, IndexDetector, RAGGenerator and RAGPipeline
# must have been run in the current kernel before this cell.
multi_index_retriever = MultiIndexRetriever(
    indexes_nodes=transcription_nodes,
    top_k=TOP_K,
    bm25_top_k=BM25_TOP_K,
    reranker_model=RERANKER_MODEL
)
index_detector = IndexDetector(llm_model=LLM_MODEL, index_names=list(transcription_nodes.keys()))
rag_generator = RAGGenerator(llm_model=LLM_MODEL)

# The pipeline object is callable with a question.
rag_pipeline = RAGPipeline(
    retriever=multi_index_retriever,
    index_detector=index_detector,
    generator=rag_generator,
)

NameError: name 'MultiIndexRetriever' is not defined