# Infer-Retrieve-Rerank Llama Pack

<a href="https://colab.research.google.com/github/run-llama/llama-hub/blob/main/llama_hub/llama_packs/research/infer_retrieve_rerank/infer_retrieve_rerank.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This is our implementation of the paper ["In-Context Learning for Extreme Multi-Label Classification"](https://arxiv.org/pdf/2401.12178.pdf) by D'Oosterlinck et al.

The paper proposes "infer-retrieve-rerank", a simple paradigm that uses frozen LLM and retriever models for "extreme" multi-label classification, where the label space is huge. It works in three steps:
1. Given a user query, use an LLM to predict an initial set of labels.
2. For each prediction, retrieve the actual label from the corpus.
3. Given the final set of labels, rerank them using an LLM.

All of these steps can be implemented with LlamaIndex abstractions. In this notebook we show how to run "infer-retrieve-rerank" as a LlamaPack and how to build it from scratch.
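
To make the three steps concrete before any LlamaIndex code appears, here is a minimal, dependency-free sketch. Everything in it is a stand-in chosen for illustration: `LABELS`, `fake_infer`, and `fake_rerank` are hypothetical, word overlap replaces embedding similarity, and no LLM is called. It only mirrors the control flow, not the paper's or the pack's actual implementation.

```python
from typing import Callable, List

# Hypothetical label space; in the notebook it comes from the dataset.
LABELS = ["abdominal pain", "chest pain", "dyspnoea", "pyrexia", "anxiety"]


def similarity(a: str, b: str) -> float:
    """Word-overlap (Jaccard) score, standing in for embedding similarity."""
    words_a, words_b = set(a.split()), set(b.split())
    union = words_a | words_b
    return len(words_a & words_b) / len(union) if union else 0.0


def retrieve(pred: str, labels: List[str], top_k: int = 2) -> List[str]:
    """Step 2: map a free-form prediction onto the closest real labels."""
    ranked = sorted(labels, key=lambda label: similarity(pred, label), reverse=True)
    return ranked[:top_k]


def infer_retrieve_rerank_toy(
    text: str,
    infer: Callable[[str], List[str]],
    rerank: Callable[[str, List[str]], List[str]],
    labels: List[str],
    top_k: int = 2,
    top_n: int = 2,
) -> List[str]:
    candidates: List[str] = []
    for pred in infer(text):  # step 1: initial guesses
        for label in retrieve(pred, labels, top_k):  # step 2: ground each guess
            if label not in candidates:
                candidates.append(label)
    return rerank(text, candidates)[:top_n]  # step 3: reorder and truncate


def fake_infer(text: str) -> List[str]:
    """Stand-in for the LLM call: returns canned free-form predictions."""
    return ["intermittent chest pain", "anxiety"]


def fake_rerank(text: str, candidates: List[str]) -> List[str]:
    """Stand-in for the LLM reranker: order by how often a label occurs in the text."""
    lowered = text.lower()
    return sorted(candidates, key=lambda label: lowered.count(label), reverse=True)


text = "Chest pain on admission; chest pain recurred with mild anxiety."
result = infer_retrieve_rerank_toy(text, fake_infer, fake_rerank, LABELS)
print(result)  # ['chest pain', 'anxiety']
```

The point of the retrieve step is visible even in this toy: the free-form guess "intermittent chest pain" is not a valid label, but it is mapped to the valid label "chest pain" before reranking.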

## Try out a Dataset

We use the BioDEX dataset, which is also used in the paper. See the [BioDEX paper](https://arxiv.org/pdf/2305.13395.pdf) and the [GitHub repo](https://github.com/KarelDO/BioDEX).

In [1]:
import datasets

# load the report-extraction dataset
dataset = datasets.load_dataset("BioDEX/BioDEX-ICSR")

In [2]:
dataset

DatasetDict({
    train: Dataset({
        features: ['title', 'abstract', 'fulltext', 'target', 'pmid', 'fulltext_license', 'title_normalized', 'issue', 'pages', 'journal', 'authors', 'pubdate', 'doi', 'affiliations', 'medline_ta', 'nlm_unique_id', 'issn_linking', 'country', 'mesh_terms', 'publication_types', 'chemical_list', 'keywords', 'references', 'delete', 'pmc', 'other_id', 'safetyreportid', 'fulltext_processed'],
        num_rows: 9624
    })
    validation: Dataset({
        features: ['title', 'abstract', 'fulltext', 'target', 'pmid', 'fulltext_license', 'title_normalized', 'issue', 'pages', 'journal', 'authors', 'pubdate', 'doi', 'affiliations', 'medline_ta', 'nlm_unique_id', 'issn_linking', 'country', 'mesh_terms', 'publication_types', 'chemical_list', 'keywords', 'references', 'delete', 'pmc', 'other_id', 'safetyreportid', 'fulltext_processed'],
        num_rows: 2407
    })
    test: Dataset({
        features: ['title', 'abstract', 'fulltext', 'target', 'pmid', 'fulltext

### Define Dataset Processing Functions

Here we define some basic functions to get the set of reactions (labels) and samples from the BioDEX dataset.

In [3]:
from llama_index import get_tokenizer
import re
from typing import Set, List

tokenizer = get_tokenizer()


sample_size = 5


def get_reactions_row(raw_target: str) -> List[str]:
    """Get reactions from a single row."""
    reaction_pattern = re.compile(r"reactions:\s*(.*)")
    reaction_match = reaction_pattern.search(raw_target)
    if reaction_match:
        reactions = reaction_match.group(1).split(",")
        reactions = [r.strip().lower() for r in reactions]
    else:
        reactions = []
    return reactions


def get_reactions_set(dataset) -> Set[str]:
    """Get set of all reactions."""
    reactions = set()
    for data in dataset["train"]:
        reactions.update(set(get_reactions_row(data["target"])))
    return reactions


def get_samples(dataset, sample_size: int = 5):
    """Get processed sample.

    Contains source text and also the reaction label.

    Parse reaction text to specifically extract reactions.

    """
    samples = []
    for idx, data in enumerate(dataset["train"]):
        if idx >= sample_size:
            break
        text = data["fulltext_processed"]
        raw_target = data["target"]

        reactions = get_reactions_row(raw_target)

        samples.append({"text": text, "reactions": reactions})
    return samples
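
As a quick sanity check of the parsing logic, the sketch below runs the same regular expression on a made-up target string. The `key: value` layout of `example_target` is an assumption for illustration and is not a real BioDEX row; the function is repeated so the snippet runs on its own.

```python
import re
from typing import List


def get_reactions_row(raw_target: str) -> List[str]:
    """Same parsing logic as above, repeated so this snippet is self-contained."""
    match = re.search(r"reactions:\s*(.*)", raw_target)
    if not match:
        return []
    return [r.strip().lower() for r in match.group(1).split(",")]


# Made-up target string in an assumed "key: value" layout (not a real row).
example_target = (
    "serious: 1\n"
    "patientsex: 2\n"
    "drugs: AZATHIOPRINE\n"
    "reactions: Bone marrow toxicity, Pancytopenia"
)

parsed = get_reactions_row(example_target)
print(parsed)  # ['bone marrow toxicity', 'pancytopenia']
print(get_reactions_row("no reaction field here"))  # []
```

Because `.` does not match newlines, the capture group stops at the end of the `reactions:` line, so any fields on later lines are ignored.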

## Use LlamaPack

In this first section we use our infer-retrieve-rerank LlamaPack to output predicted labels.

In [4]:
# Option: if developing with the llama_hub package
# from llama_hub.llama_packs.research.infer_retrieve_rerank.base import InferRetrieveRerankPack

# Option: download_llama_pack
from llama_index.llama_pack import download_llama_pack

InferRetrieveRerankPack = download_llama_pack(
    "InferRetrieveRerankPack",
    "./irr_pack",
    # leave the below line commented out if using the notebook on main
    # llama_hub_url="https://raw.githubusercontent.com/run-llama/llama-hub/jerry/add_infer_retrieve_rerank/llama_hub"
)

In [5]:
from llama_index.llms import OpenAI

llm = OpenAI(model="gpt-3.5-turbo-16k")
pred_context = """\
The output predictions should be a list of comma-separated adverse \
drug reactions. \
"""
reranker_top_n = 10

pack = InferRetrieveRerankPack(
    get_reactions_set(dataset),
    llm=llm,
    pred_context=pred_context,
    reranker_top_n=reranker_top_n,
    verbose=True,
)

Generating embeddings: 0it [00:00, ?it/s]
Generating embeddings: 0it [00:00, ?it/s]
Generating embeddings: 0it [00:00, ?it/s]


In [6]:
samples = get_samples(dataset, sample_size=5)
pred_reactions = pack.run(inputs=[s["text"] for s in samples])
gt_reactions = [s["reactions"] for s in samples]

> Generating predictions for input 0: TITLE:
SARS-CoV-2-related ARDS in a maintenance hemodialysis patient: case report on tailored approach by daily hemodialysis, noninvasive ventilation, tocilizumab, anxiolytics, and point-of-care ultrasound.

ABSTRACT:
Without rescue drugs approved, holistic approach by daily hemodialysis, noninvasiv
> Generated predictions: ['respiratory distress', 'fluid overload', 'fluid retention', 'anxiety', 'delirium', 'nervousness', 'acute myocardial infarction', 'cardiovascular insufficiency', 'neonatal respiratory distress syndrome', 'delirium tremens']
> Generating predictions for input 1: TITLE:
Corynebacterium propinquum: A Rare Cause of Prosthetic Valve Endocarditis.

ABSTRACT:
Nondiphtheria Corynebacterium species are often dismissed as culture contaminants, but they have recently become increasingly recognized as pathologic organisms. We present the case of a 48-year-old male pat
> Generated predictions: ['chest pain', 'dyspnoea', 'dyspnoea exertional

In [11]:
pred_reactions[2]

['agranulocytosis',
 'haematotoxicity',
 'bone marrow toxicity',
 'infantile genetic agranulocytosis']

In [12]:
gt_reactions[2]

['bone marrow toxicity',
 'cytomegalovirus infection',
 'cytomegalovirus mucocutaneous ulcer',
 'febrile neutropenia',
 'leukoplakia',
 'odynophagia',
 'oropharyngeal candidiasis',
 'pancytopenia',
 'product use issue',
 'red blood cell poikilocytes present',
 'vitamin d deficiency']

## Define Infer-Retrieve-Rerank Pipeline

Here we define the core components needed for the full infer-retrieve-rerank pipeline. 

Refer to the [paper](https://arxiv.org/pdf/2401.12178.pdf) for more details. The paper implements the approach in DSPy; here we adapt it using LlamaIndex abstractions. As a result, the specific implementations (e.g. prompts, output parsing modules, reranking module) differ, even though conceptually we follow similar steps.

Our implementation uses fixed models, and does not do automatic distillation between teacher and student.

In [13]:
from llama_index.retrievers import BaseRetriever
from llama_index.llms.llm import LLM
from llama_index.llms import OpenAI
from llama_index.prompts import PromptTemplate
from llama_index.query_pipeline import QueryPipeline
from llama_index.postprocessor.types import BaseNodePostprocessor
from llama_index.postprocessor.rankGPT_rerank import RankGPTRerank
from llama_index.output_parsers import ChainableOutputParser
from typing import List

#### Index each Reaction with a Vector Index

Since the set of reactions is quite large, we can define a vector index over all reactions. That way we can retrieve the top k most semantically similar reactions to any prediction.

In [14]:
import random

all_reactions = get_reactions_set(dataset)
# random.sample() needs a sequence: passing a set is deprecated since
# Python 3.9 and raises a TypeError from Python 3.11 onwards.
random.sample(sorted(all_reactions), 5)


['burning mouth syndrome',
 'hepatitis e',
 'gingivitis ulcerative',
 'page kidney',
 'herpes simplex pneumonia']

In [15]:
from llama_index.schema import TextNode
from llama_index.embeddings import OpenAIEmbedding
from llama_index.ingestion import IngestionPipeline
from llama_index import VectorStoreIndex

reaction_nodes = [TextNode(text=r) for r in all_reactions]
pipeline = IngestionPipeline(transformations=[OpenAIEmbedding()])
reaction_nodes = await pipeline.arun(documents=reaction_nodes)

index = VectorStoreIndex(reaction_nodes)

In [None]:
reaction_nodes[0].embedding

In [17]:
reaction_retriever = index.as_retriever(similarity_top_k=2)

In [19]:
nodes = reaction_retriever.retrieve("abdominal")
print([n.get_content() for n in nodes])

['abdominal pain', 'abdominal symptom']
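
The retrieval step above boils down to "embed the query, embed every label, return the k nearest". The following sketch shows that idea with a deliberately crude stand-in: character-trigram counts instead of OpenAI embeddings, and a linear scan instead of a vector index. The `embed`, `cosine`, and `top_k_labels` names are hypothetical helpers, not LlamaIndex APIs.

```python
import math
from collections import Counter
from typing import List


def embed(text: str) -> Counter:
    """Toy 'embedding': counts of character trigrams."""
    return Counter(text[i : i + 3] for i in range(len(text) - 2))


def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two sparse count vectors."""
    dot = sum(a[k] * b[k] for k in a)
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


def top_k_labels(query: str, labels: List[str], k: int = 2) -> List[str]:
    """Return the k labels most similar to the query."""
    q = embed(query)
    return sorted(labels, key=lambda label: cosine(q, embed(label)), reverse=True)[:k]


labels = [
    "burning mouth syndrome",
    "abdominal symptom",
    "hepatitis e",
    "abdominal pain",
    "page kidney",
]
nearest = top_k_labels("abdominal", labels)
print(nearest)  # ['abdominal pain', 'abdominal symptom']
```

A real embedding model captures semantic similarity rather than surface overlap, which is what lets a prediction such as "dyspnea on exertion" land on a differently spelled label in the corpus.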


#### Define Infer Prompt

We define an infer prompt that, given a document and relevant task context, generates a list of comma-separated predictions.

**NOTE**: This is our own prompt and not taken from the paper.

In [20]:
infer_prompt_str = """\

Your job is to output a list of predictions given context from a given piece of text. The text context,
and information regarding the set of valid predictions is given below. 

Return the predictions as a comma-separated list of strings.

Text Context:
{doc_context}

Prediction Info:
{pred_context}

Predictions: """

infer_prompt = PromptTemplate(infer_prompt_str)

#### Define Output Parser

We define a very simple output parser that can parse an output into a list of strings.

In [21]:
class PredsOutputParser(ChainableOutputParser):
    """Predictions output parser."""

    def parse(self, output: str) -> List[str]:
        """Parse predictions."""
        tokens = output.split(",")
        return [t.strip() for t in tokens]


preds_output_parser = PredsOutputParser()
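
Splitting on commas is enough for well-formed output, but LLM responses can contain empty items, repeated labels, or inconsistent casing. The function below is a sketch of a slightly more defensive variant, written as a plain function so it runs without LlamaIndex. The name `parse_predictions` is hypothetical, and the lowercasing is an assumption that matches how labels are normalized in `get_reactions_row`.

```python
from typing import List


def parse_predictions(output: str) -> List[str]:
    """Split on commas, normalize, and drop empty or repeated items."""
    seen = set()
    preds: List[str] = []
    for token in output.split(","):
        label = token.strip().lower()
        if label and label not in seen:
            seen.add(label)
            preds.append(label)
    return preds


cleaned = parse_predictions("Fever, dizziness, , fever,")
print(cleaned)  # ['fever', 'dizziness']
```

The same logic could be placed inside the `parse` method of the class above if this behavior is wanted in the pipeline.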

#### Define Rerank Prompt

Here we define a rerank prompt that will reorder a batch of labels based on their relevance to the query.

In [22]:
rerank_str = """\
Given a piece of text, rank the {num} labels above based on their relevance \
to this piece of text. The labels \
should be listed in descending order using identifiers. \
The most relevant labels should be listed first. \
The output format should be [] > [], e.g., [1] > [2]. \
Only respond with the ranking results, \
do not say any word or explain. \

Here is a given piece of text: {query}. 

"""
rerank_prompt = PromptTemplate(rerank_str)
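
The prompt asks the model to answer in the form `[1] > [2]`. As an illustration of how such a string can be turned back into an ordering, here is a small sketch. It is not the parsing code used inside the reranker; `parse_ranking` is a hypothetical helper, and the fallback of appending unmentioned items in their original order is an assumption.

```python
import re
from typing import List


def parse_ranking(response: str, num_items: int) -> List[int]:
    """Convert '[2] > [1] > [3]' into zero-based indices, best first."""
    order: List[int] = []
    for match in re.findall(r"\[(\d+)\]", response):
        idx = int(match) - 1  # identifiers in the prompt are one-based
        if 0 <= idx < num_items and idx not in order:
            order.append(idx)
    # Items the model did not mention keep their original relative order.
    missing = [i for i in range(num_items) if i not in order]
    return order + missing


labels = ["agranulocytosis", "haematotoxicity", "bone marrow toxicity"]
reordered = [labels[i] for i in parse_ranking("[3] > [1] > [2]", len(labels))]
print(reordered)  # ['bone marrow toxicity', 'agranulocytosis', 'haematotoxicity']
print(parse_ranking("[3] > [3] > [7]", 3))  # [2, 0, 1]
```

Ignoring duplicates and out-of-range identifiers keeps the result a valid permutation even when the model does not follow the format exactly.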

#### Define Infer-Retrieve-Rerank Function

We define the infer-retrieve-rerank steps as a function.

In [23]:
def infer_retrieve_rerank(
    query: str,
    retriever: BaseRetriever,
    llm: LLM,
    pred_context: str,
    reranker_top_n: int = 3,
):
    """Infer retrieve rerank."""
    infer_prompt_c = infer_prompt.as_query_component(
        partial={"pred_context": pred_context}
    )
    infer_pipeline = QueryPipeline(chain=[infer_prompt_c, llm, preds_output_parser])
    preds = infer_pipeline.run(query)

    print(f"PREDS: {preds}")
    all_nodes = []
    for pred in preds:
        nodes = retriever.retrieve(str(pred))
        all_nodes.extend(nodes)

    reranker = RankGPTRerank(
        llm=llm,
        top_n=reranker_top_n,
        rankgpt_rerank_prompt=rerank_prompt,
        # verbose=True,
    )
    reranked_nodes = reranker.postprocess_nodes(all_nodes, query_str=query)
    return [n.get_content() for n in reranked_nodes]
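
Since every prediction triggers its own retrieval, two predictions can return the same label, so `all_nodes` may contain repeats. One possible refinement, shown here on plain strings with made-up candidates, is to drop repeats while preserving order before reranking. The helper name `dedupe_preserve_order` is hypothetical, and whether this step is needed depends on how the reranker treats repeated entries.

```python
from typing import Iterable, List


def dedupe_preserve_order(labels: Iterable[str]) -> List[str]:
    """Drop repeated labels while keeping the first occurrence of each."""
    # dict preserves insertion order, so this keeps the first-seen ordering.
    return list(dict.fromkeys(labels))


# Made-up candidates, as if three predictions each retrieved two labels.
candidates = [
    "chest pain",
    "dyspnoea",
    "dyspnoea",
    "dyspnoea exertional",
    "chest pain",
    "palpitations",
]
unique_candidates = dedupe_preserve_order(candidates)
print(unique_candidates)
# ['chest pain', 'dyspnoea', 'dyspnoea exertional', 'palpitations']
```

With nodes instead of strings, the same idea applies by keying on `n.get_content()`.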

## Run Over Sample Data

Now we're ready to run over some sample data! 

In [24]:
samples = get_samples(dataset, sample_size=5)

In [34]:
reaction_retriever = index.as_retriever(similarity_top_k=2)
llm = OpenAI(model="gpt-3.5-turbo-16k")
pred_context = """\
The output predictions should be a list of comma-separated adverse \
drug reactions. \
"""

reranker_top_n = 10

pred_reactions = []
gt_reactions = []
for idx, sample in enumerate(samples):
    print(idx)
    cur_pred_reactions = infer_retrieve_rerank(
        sample["text"],
        reaction_retriever,
        llm,
        pred_context,
        reranker_top_n=reranker_top_n,
    )
    cur_gt_reactions = sample["reactions"]

    pred_reactions.append(cur_pred_reactions)
    gt_reactions.append(cur_gt_reactions)

0
PREDS: ['fluid overload', 'acute respiratory distress syndrome', 'anxiety', 'myocardial insufficiency', 'hypervolemia', 'hypovolemia', 'respiratory distress', 'allergic reaction', 'diarrhea', 'rash']
1
PREDS: ['fever', 'dizziness', 'dyspnea on exertion', 'intermittent chest pain', 'palpitations']
2
PREDS: ['azathioprine-induced myelotoxicity', 'drug-induced agranulocytosis']
3
PREDS: ['There is no information provided about adverse drug reactions in the given text context. Therefore', 'it is not possible to make any predictions about adverse drug reactions.']
4
PREDS: ['painful swelling in lymph nodes', 'weight loss', 'night sweats', 'hepatosplenomegaly', 'generalized lymphadenopathy', 'skin disorders', 'bone marrow disorders', 'blood disorders', 'misorientation of body segments', 'excessive backward pelvic tilt', 'excessive kyphosis', 's-shaped scoliosis', 'excessive pelvic obliquity', 'flat right foot contact', 'limited ankle dorsiflexion', 'toe walking', 'muscle weakness', 'limite

In [37]:
pred_reactions[2]

['agranulocytosis',
 'haematotoxicity',
 'bone marrow toxicity',
 'infantile genetic agranulocytosis']

In [38]:
gt_reactions[2]

['bone marrow toxicity',
 'cytomegalovirus infection',
 'cytomegalovirus mucocutaneous ulcer',
 'febrile neutropenia',
 'leukoplakia',
 'odynophagia',
 'oropharyngeal candidiasis',
 'pancytopenia',
 'product use issue',
 'red blood cell poikilocytes present',
 'vitamin d deficiency']
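
To put a number on how the predictions compare with the ground truth, one simple option is set-based precision and recall. The sketch below applies it to the two lists shown above for sample 2. This is an illustrative metric chosen for simplicity, not necessarily the one reported in the paper, and `precision_recall` is a hypothetical helper.

```python
from typing import List, Tuple


def precision_recall(pred: List[str], gold: List[str]) -> Tuple[float, float]:
    """Set-based precision and recall between predicted and gold labels."""
    pred_set, gold_set = set(pred), set(gold)
    hits = len(pred_set & gold_set)
    precision = hits / len(pred_set) if pred_set else 0.0
    recall = hits / len(gold_set) if gold_set else 0.0
    return precision, recall


pred = [
    "agranulocytosis",
    "haematotoxicity",
    "bone marrow toxicity",
    "infantile genetic agranulocytosis",
]
gold = [
    "bone marrow toxicity",
    "cytomegalovirus infection",
    "cytomegalovirus mucocutaneous ulcer",
    "febrile neutropenia",
    "leukoplakia",
    "odynophagia",
    "oropharyngeal candidiasis",
    "pancytopenia",
    "product use issue",
    "red blood cell poikilocytes present",
    "vitamin d deficiency",
]

precision, recall = precision_recall(pred, gold)
print(f"precision={precision:.2f} recall={recall:.2f}")  # precision=0.25 recall=0.09
```

Only "bone marrow toxicity" appears in both lists, which gives 1 of 4 predictions correct and 1 of 11 gold labels recovered.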