# HuatuoGPT-o1 Medical RAG and Reasoning

In this example, we build an end-to-end medical question-answering system that combines retrieval-augmented generation (RAG) with reasoning. We use HuatuoGPT-o1, a medical LLM designed for advanced medical reasoning, to provide detailed and well-structured answers to medical queries.

[**HuatuoGPT-o1**](https://huggingface.co/FreedomIntelligence/HuatuoGPT-o1-7B) is a medical LLM that excels at identifying mistakes, exploring alternative strategies, and refining its answers. It utilizes verifiable medical problems and a specialized medical verifier to enhance its reasoning capabilities.

## Setup

In [None]:
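# accelerate is needed for device_map='auto'; pandas is used to build the knowledge base
!pip install -qU transformers accelerate datasets sentence-transformers scikit-learn pandas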

## Load the dataset

We will use the [`ChatDoctor-HealthCareMagic-100k`](https://huggingface.co/datasets/lavita/ChatDoctor-HealthCareMagic-100k) dataset, which contains 100K real-world patient-doctor interactions, providing a rich knowledge base for our RAG system.

In [None]:
from datasets import load_dataset

# the repository ID matches the dataset link above
dataset = load_dataset("lavita/ChatDoctor-HealthCareMagic-100k")

## Initialize the models

We need to initialize two models:
- `HuatuoGPT-o1`: the medical LLM for generating responses.
- `all-MiniLM-L6-v2`: an embedding model from sentence transformers for creating vector representations of text, which we will use for retrieval.

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer

model_name = 'FreedomIntelligence/HuatuoGPT-o1-7B'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype='auto',
    device_map='auto'
)

embed_model = SentenceTransformer('all-MiniLM-L6-v2')

## Prepare the knowledge base

In [None]:
import pandas as pd
import numpy as np

# convert dataset to dataframe
df = pd.DataFrame(dataset['train'])

# combine question and answer for context
df['combined'] = df['input'] + ' ' + df['output']

# generate embeddings
print('Generating embeddings for the knowledge base...')
embeddings = embed_model.encode(
    df['combined'].tolist(),
    show_progress_bar=True,
    batch_size=128
)
print('Embeddings generated.')
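
Encoding the whole knowledge base takes a while, so it can be worth caching the result on disk. The following is a minimal sketch, assuming the embeddings fit in memory as a NumPy array; `load_or_compute_embeddings`, `fake_encode`, and the file name `kb_embeddings.npy` are hypothetical names introduced for illustration, and `fake_encode` stands in for the `embed_model.encode(...)` call above.

```python
import os
import tempfile

import numpy as np


def load_or_compute_embeddings(path, compute_fn):
    """Return embeddings stored at `path`, or compute and save them (hypothetical helper)."""
    if os.path.exists(path):
        # reuse the cached array instead of re-encoding the corpus
        return np.load(path)
    embeddings = np.asarray(compute_fn(), dtype=np.float32)
    np.save(path, embeddings)
    return embeddings


calls = []


def fake_encode():
    # stand-in for embed_model.encode(df['combined'].tolist(), ...)
    calls.append(1)
    return [[0.1, 0.2], [0.3, 0.4]]


cache_path = os.path.join(tempfile.mkdtemp(), "kb_embeddings.npy")
first = load_or_compute_embeddings(cache_path, fake_encode)   # computes and saves
second = load_or_compute_embeddings(cache_path, fake_encode)  # loads from disk
```

In the notebook, `compute_fn` could be a `lambda` wrapping the `embed_model.encode` call, so that re-running the cell skips the encoding step once the file exists.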

## Implement retrieval

In [None]:
from sklearn.metrics.pairwise import cosine_similarity

def retrieve_relevant_contexts(query: str, k: int = 3) -> list:
    """Retrieve the k most relevant contexts to a given query.

    Parameters
    ----------
    query : str
        The query to retrieve relevant contexts for.
    k : int
        The number of relevant contexts to retrieve.

    Returns
    -------
    list
        A list of dictionaries, each containing a relevant context.
    """
    # generate query embedding
    query_embedding = embed_model.encode([query])[0]

    # calculate similarities
    similarities = cosine_similarity([query_embedding], embeddings)[0]

    # get top-k similar contexts
    top_k_indices = np.argsort(similarities)[-k:][::-1]

    contexts = []
    for idx in top_k_indices:
        contexts.append({
            'question': df.iloc[idx]['input'],
            'answer': df.iloc[idx]['output'],
            'similarity': similarities[idx]
        })

    return contexts
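
To make the ranking step concrete, here is a small sketch of cosine-similarity top-k selection on toy vectors. It is an alternative formulation, not the function above: it normalizes the vectors and uses `np.argpartition`, which avoids fully sorting all similarity scores. The helper name `top_k_cosine` and the toy data are assumptions made for illustration.

```python
import numpy as np


def top_k_cosine(query_vec, matrix, k=3):
    """Return indices and scores of the k rows most similar to query_vec (hypothetical helper)."""
    q = query_vec / np.linalg.norm(query_vec)
    m = matrix / np.linalg.norm(matrix, axis=1, keepdims=True)
    sims = m @ q  # cosine similarity of every row with the query
    k = min(k, len(sims))
    # pick the k largest scores without sorting the whole array
    idx = np.argpartition(sims, -k)[-k:]
    # order only those k candidates from most to least similar
    idx = idx[np.argsort(sims[idx])[::-1]]
    return idx, sims[idx]


toy_matrix = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.0]])
toy_query = np.array([1.0, 0.0])
top_idx, top_scores = top_k_cosine(toy_query, toy_matrix, k=2)
print(top_idx, top_scores)
```

For a knowledge base of this size, either approach is workable; the partition-based variant mainly helps when the number of rows grows much larger than `k`.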

## Implement response generation

In [None]:
def generate_structured_response(query: str, contexts: list) -> str:
    """Generate a detailed response using the retrieved contexts.

    Parameters
    ----------
    query : str
        The query to generate a response for.
    contexts : list
        A list of dictionaries, each containing a relevant context.

    Returns
    -------
    str
        The generated structured response.
    """
    # prepare prompt with retrieved contexts
    context_prompt = "\n".join(
        [
            f"Reference {i+1}:\nQuestion: {ctx['question']}\nAnswer: {ctx['answer']}"
            for i, ctx in enumerate(contexts)
        ]
    )

    prompt = f"""Based on the following references and your medical knowledge, provide a detailed response:
    References:
    {context_prompt}

    Question: {query}

    By considering:
    1. The key medical concepts in the question.
    2. How the reference cases relate to this question.
    3. What medical principles should be applied.
    4. Any potential complications or considerations.

    Give the final response:
    """

    # generate response
    messages = [{'role': 'user', 'content': prompt}]
    inputs = tokenizer(
        tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        ),
        return_tensors='pt'
    ).to(model.device)

    outputs = model.generate(
        **inputs,
        max_new_tokens=1024,
        temperature=0.7,
        num_beams=1,
        do_sample=True,
    )

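    # decode only the newly generated tokens, skipping the echoed prompt
    prompt_length = inputs['input_ids'].shape[1]
    response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    # strip surrounding whitespace from the final response
    final_response = response.strip()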

    return final_response
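
HuatuoGPT-o1 is described on its model card as reasoning before it answers, and its output may arrive as a `## Thinking` section followed by a `## Final Response` section. That layout is treated as an assumption here, so the sketch below falls back to returning the whole text when the marker is absent. `split_reasoning` is a hypothetical helper, not part of any library.

```python
def split_reasoning(text, marker="## Final Response"):
    """Split model output into (thinking, final_answer); assumes the marker layout described above."""
    head, sep, tail = text.partition(marker)
    if not sep:
        # marker not found: treat the whole output as the answer
        return "", text.strip()
    thinking = head.replace("## Thinking", "", 1).strip()
    return thinking, tail.strip()


sample_output = (
    "## Thinking\n\nHeadache with dizziness has many possible causes.\n\n"
    "## Final Response\n\nPlease see a clinician for an evaluation."
)
thinking, final_answer = split_reasoning(sample_output)
print(final_answer)
```

Keeping the reasoning separate makes it possible to show only the final answer to a user while still logging the reasoning for review.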

## Put it all together

Define a function to process a query end-to-end and then use it with an example.

In [None]:
def process_query(query: str, k: int = 3) -> tuple:
    """Process a medical query end-to-end

    Parameters
    ----------
    query: str
        The user's medical query
    k: int
        The number of relevant contexts to retrieve

    Returns
    -------
    tuple
        The generated response and the retrieved contexts
    """
    contexts = retrieve_relevant_contexts(query, k)
    response = generate_structured_response(query, contexts)
    return response, contexts

In [None]:
# example query
query = "I've been experiencing persistent headaches and dizziness for the past week. What could be the cause?"

# process query
response, contexts = process_query(query)

# print results
print(f'Query: {query}')
print('\nRelevant Contexts:')
for i, ctx in enumerate(contexts, 1):
    print(f"\nReference {i} (Similarity: {ctx['similarity']:.3f}):")
    print(f"Q: {ctx['question']}")
    print(f"A: {ctx['answer']}")

print('\nGenerated Response:')
print(response)
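
The similarity scores printed above can also be used to drop weak matches before they reach the prompt. This is a minimal sketch under stated assumptions: `filter_contexts` is a hypothetical helper, and the `0.5` threshold is illustrative rather than tuned for this dataset or embedding model.

```python
def filter_contexts(contexts, min_similarity=0.5):
    """Keep only contexts whose similarity meets the threshold (hypothetical helper)."""
    return [ctx for ctx in contexts if ctx["similarity"] >= min_similarity]


# toy contexts in the same shape that retrieve_relevant_contexts returns
toy_contexts = [
    {"question": "Q1", "answer": "A1", "similarity": 0.82},
    {"question": "Q2", "answer": "A2", "similarity": 0.47},
    {"question": "Q3", "answer": "A3", "similarity": 0.31},
]
kept = filter_contexts(toy_contexts)
print([ctx["question"] for ctx in kept])
```

If no context passes the threshold, the returned list is empty, and the caller would need to decide whether to answer without references or to lower the threshold.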