<a href="https://colab.research.google.com/github/colesmcintosh/chain-of-though-reranking/blob/main/chain_of_thought_reranking.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Reranking Chain of Thought (COT)

This notebook demonstrates how to generate, process, and rerank Chain of Thought (COT) responses using the `together` and `cohere` APIs. The process involves:

1. Generating a response from a chatbot model.
2. Extracting the COT reasoning from the response.
3. Splitting the reasoning into sentences and grouping them into chunks.
4. Reranking the chunks by relevance to the original prompt.
5. Formatting the best-ranked documents into a structured prompt.
6. Using the formatted prompt to generate a final refined response.

## **1. Install Required Libraries**

We first install `together` and `cohere`, which are needed to interact with the respective APIs.


In [1]:
!pip install together cohere



## **2. Set Up API Keys**

We retrieve the API keys from `google.colab.userdata` or prompt the user for input.


In [2]:
from google.colab import userdata

cohere_api_key = userdata.get('COHERE_API_KEY') or input('COHERE_API_KEY:')
together_api_key = userdata.get('TOGETHER_API_KEY') or input('TOGETHER_API_KEY:')
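
In Colab, `userdata.get` may raise an exception (for example when the secret does not exist or notebook access has not been granted) instead of returning an empty value, in which case the `or input(...)` fallback never runs. A minimal sketch of a more defensive loader follows. The names `load_api_key`, `getter`, and `prompt_fn` are hypothetical, and the environment-variable fallback is an assumption that is not part of the original notebook.

```python
import os
from getpass import getpass

def load_api_key(name, getter=None, prompt_fn=getpass):
    """Return the API key `name`, trying several sources in order."""
    # 1. A secrets getter supplied by the caller, e.g. google.colab.userdata.get.
    if getter is not None:
        try:
            value = getter(name)
            if value:
                return value
        except Exception:
            # Missing secret or denied access: fall through to the next source.
            pass
    # 2. An environment variable with the same name (assumption).
    value = os.environ.get(name)
    if value:
        return value
    # 3. Prompt interactively; getpass avoids echoing the key on screen.
    return prompt_fn(f"{name}: ")
```

In this notebook it could be called as `load_api_key('COHERE_API_KEY', getter=userdata.get)`.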

## **3. Initialize API Clients**

We instantiate `cohere` and `together` clients.


In [3]:
import cohere
from together import Together
co = cohere.Client(cohere_api_key)
together_client = Together(api_key=together_api_key)

## **4. Define the Initial Prompt**

We define the reasoning task: a decimal comparison question that small models often get wrong.


In [4]:
prompt = 'Which number is larger, 9.9 or 9.11?'

## Preliminary Test: Evaluate the Base LLM Response

Before diving into the chain-of-thought (COT) extraction and reranking process, it is worth checking how the base LLM answers the initial prompt on its own. This gives us a baseline against which to assess the final response. Here we query the small model with the original prompt and display its direct output, which in this run is incorrect.

In [5]:
test_response = together_client.chat.completions.create(
    model='meta-llama/Llama-3.2-3B-Instruct-Turbo',
    messages=[{'role': 'user', 'content': prompt}],
)

In [6]:
test_response.choices[0].message.content

'9.11 is larger than 9.9.'

## **5. Generate a Chatbot Response**

Using the DeepSeek-R1 model, we generate an initial response to get the chain of thought.


In [7]:
cot_response = together_client.chat.completions.create(
    model='deepseek-ai/DeepSeek-R1',
    messages=[{'role': 'user', 'content': prompt}],
)


In [8]:
response_text = cot_response.choices[0].message.content

## **6. Extract the Chain of Thought (COT)**

We extract the reasoning portion of the response by isolating the text between `<think>` and `</think>`.

In [9]:
cot = response_text.split('<think>')[1].split('</think>')[0]
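
The split-based extraction above raises an `IndexError` when the response contains no `<think>` tag. A minimal sketch of a more tolerant alternative follows, assuming the reasoning is wrapped in a single `<think>...</think>` block. The name `extract_cot` is hypothetical, and unlike the original line it strips surrounding whitespace.

```python
import re

def extract_cot(response_text):
    """Return the text between <think> and </think>, or '' if no such block exists."""
    # DOTALL lets '.' match newlines, since the reasoning spans multiple lines.
    match = re.search(r"<think>(.*?)</think>", response_text, flags=re.DOTALL)
    return match.group(1).strip() if match else ""
```

With this helper, `cot = extract_cot(response_text)` yields an empty string rather than an exception when the tags are missing.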

In [10]:
cot

"\nOkay, so I need to figure out which number is larger between 9.9 and 9.11. Hmm, let's start by writing them down to compare. \n\nFirst, I know that both numbers have the same whole number part, which is 9. That means the difference must be in the decimal parts. The first number is 9.9, and the second is 9.11. \n\nWait, but decimals can be tricky because the places matter. Let me break them down. \n\nFor 9.9, the decimal part is 0.9. That's the same as 0.90 if I add a zero at the end, right? Because 0.9 is nine tenths. \n\nOn the other hand, 9.11 has a decimal part of 0.11, which is eleven hundredths. So now, if I compare 0.90 and 0.11, it's easier to see. \n\n0.90 is definitely larger than 0.11 because 90 hundredths is more than 11 hundredths. So, putting that back into the original numbers, 9.90 is larger than 9.11. \n\nBut wait, let me double-check. Maybe I confused the decimal places. Let's think about it another way. \n\nIf I convert both numbers to fractions, maybe that will he

## **7. Split COT into Sentences and Group Them**

We define functions to split text into sentences and then group them into segments of four sentences.


In [11]:
import re

def split_into_sentences(text):
    """
    Splits the provided text into sentences using a regex that recognizes
    common sentence-ending punctuation (periods, exclamation marks, and question marks).
    """
    # The regex looks for punctuation followed by whitespace.
    pattern = r'(?<=[.!?])\s+'
    sentences = re.split(pattern, text.strip())
    return sentences

def group_sentences(sentences, group_size=4):
    """
    Aggregates the list of sentences into groups (chunks) containing group_size sentences each.
    """
    return [' '.join(sentences[i:i+group_size]) for i in range(0, len(sentences), group_size)]
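
To see how the regex treats decimals, here is a small worked example on a made-up sentence. The two functions are repeated in compact form only so that the snippet runs on its own; `sample`, `sample_sentences`, and `sample_chunks` are illustrative names.

```python
import re

def split_into_sentences(text):
    # Split on whitespace that follows '.', '!' or '?'.
    return re.split(r'(?<=[.!?])\s+', text.strip())

def group_sentences(sentences, group_size=4):
    return [' '.join(sentences[i:i+group_size]) for i in range(0, len(sentences), group_size)]

sample = "Compare 9.9 and 9.11. Which is larger? It is 9.9. Done!"
sample_sentences = split_into_sentences(sample)
sample_chunks = group_sentences(sample_sentences, group_size=2)

print(sample_sentences)
# ['Compare 9.9 and 9.11.', 'Which is larger?', 'It is 9.9.', 'Done!']
print(sample_chunks)
# ['Compare 9.9 and 9.11. Which is larger?', 'It is 9.9. Done!']
```

The period inside `9.9` is not a split point because it is not followed by whitespace. Abbreviations followed by a space, such as `vs.`, are treated as sentence ends, so a chunk can end mid-thought.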


## **8. Process the Extracted Reasoning**

We apply the sentence-splitting and grouping functions.


In [12]:
# Step 1: Split the text into individual sentences.
sentences = split_into_sentences(cot)

# Step 2: Group the sentences into chunks, with each chunk containing 4 sentences.
chunks = group_sentences(sentences, group_size=4)

# Display the resulting chunks.
for index, chunk in enumerate(chunks, start=1):
    print(f"Chunk {index}:\n{chunk}\n")

Chunk 1:
Okay, so I need to figure out which number is larger between 9.9 and 9.11. Hmm, let's start by writing them down to compare. First, I know that both numbers have the same whole number part, which is 9. That means the difference must be in the decimal parts.

Chunk 2:
The first number is 9.9, and the second is 9.11. Wait, but decimals can be tricky because the places matter. Let me break them down. For 9.9, the decimal part is 0.9.

Chunk 3:
That's the same as 0.90 if I add a zero at the end, right? Because 0.9 is nine tenths. On the other hand, 9.11 has a decimal part of 0.11, which is eleven hundredths. So now, if I compare 0.90 and 0.11, it's easier to see.

Chunk 4:
0.90 is definitely larger than 0.11 because 90 hundredths is more than 11 hundredths. So, putting that back into the original numbers, 9.90 is larger than 9.11. But wait, let me double-check. Maybe I confused the decimal places.

Chunk 5:
Let's think about it another way. If I convert both numbers to fractions, 

## **9. Perform Reranking Using Cohere**

We rerank the sentence chunks using Cohere’s rerank model (`rerank-v3.5`).


In [13]:
reranking = co.rerank(
    model="rerank-v3.5", query=prompt, documents=chunks, top_n=5
)

In [14]:
for res in reranking.results:
  print(chunks[res.index])
  print(f"Relevance Score: {res.relevance_score:.2%}")
  print()
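
To make the shape of the rerank results concrete without calling the API, the sketch below uses a simple word-overlap score as a stand-in. This is not how `rerank-v3.5` scores documents. `RerankResult` and `overlap_rerank` are hypothetical names that only mimic the `index` and `relevance_score` fields used above.

```python
import re
from dataclasses import dataclass

@dataclass
class RerankResult:
    index: int              # position of the chunk in the input list
    relevance_score: float  # higher means more relevant to the query

def _tokens(text):
    # Lowercased words and numbers; decimals such as 9.11 stay as one token.
    return set(re.findall(r"[a-z0-9]+(?:\.[0-9]+)?", text.lower()))

def overlap_rerank(query, documents, top_n=None):
    """Stand-in reranker: Jaccard overlap between query and document tokens."""
    query_tokens = _tokens(query)
    results = []
    for i, doc in enumerate(documents):
        doc_tokens = _tokens(doc)
        union = query_tokens | doc_tokens
        score = len(query_tokens & doc_tokens) / len(union) if union else 0.0
        results.append(RerankResult(index=i, relevance_score=score))
    # Most relevant first, then keep the top_n results (all of them if top_n is None).
    results.sort(key=lambda r: r.relevance_score, reverse=True)
    return results[:top_n]
```

As in the loop above, each result's `index` points back into the original list, so `documents[res.index]` recovers the text.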

Alternatively, if I had to compare them on a number line, 9.9 is closer to 10 than 9.11 is. Since 9.9 is just 0.1 away from 10, while 9.11 is 0.89 away from 10. So again, 9.9 is larger. Wait, but sometimes people might get confused because 9.11 has two decimal places and think it's larger, but that's a common mistake.
Relevance Score: 95.26%

The key is that the first decimal place (tenths) is more significant than the second (hundredths). So even though 11 is more than 9, in decimals, the position matters. So, in conclusion, after checking multiple ways—converting to fractions, aligning decimals, subtracting, and using the number line—it's clear that 9.9 is larger than 9.11.
Relevance Score: 93.45%

90 is more than 11, so 90/100 is larger. Therefore, 9.9 is larger than 9.11. Another way to think about it is by aligning the decimal points and adding trailing zeros to make them the same length:

9.90
9.11

Starting from the left, the whole number part is 9 in both. Then the tenths place

## **10. Format the Top Ranked Documents into a Structured Prompt**

We define a function to format the most relevant chunks into a refined prompt.


In [15]:
def format_documents_into_prompt(documents):
    """
    Formats a list of rerank results into a unified prompt string.

    Each result exposes these attributes:
      - index: the position of the chunk in the global `chunks` list
      - relevance_score: a numeric score indicating relevance

    The function sorts the results by relevance score in ascending order
    (so the most relevant chunk appears last) and constructs a prompt that
    displays each chunk's index, relevance score, and text.
    """
    # Sort the results by relevance score, lowest first
    sorted_docs = sorted(documents, key=lambda doc: doc.relevance_score)

    # Build the prompt with a header and each document's details
    prompt_lines = ["Consider the following thoughts:\n"]
    for doc in sorted_docs:
        index = doc.index
        relevance = doc.relevance_score
        text = chunks[index]
        prompt_lines.append(f"Thought (index {index}, relevance: {relevance:.2%}):\n{text}\n")

    # Join the individual parts into a single prompt string
    prompt = "\n".join(prompt_lines)
    return prompt
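
The function above orders the thoughts by ascending relevance. If preserving the original order of the reasoning matters more, a variant can sort by `index` instead. The sketch below is one such variant under that assumption: `format_in_original_order` is a hypothetical name, it takes `chunks` as an explicit argument instead of reading the global, and the `demo_*` values are made-up stand-ins for real rerank results.

```python
from types import SimpleNamespace

def format_in_original_order(results, chunks):
    """Format rerank results as a prompt, ordered by position in the original reasoning."""
    ordered = sorted(results, key=lambda r: r.index)
    lines = ["Consider the following thoughts:\n"]
    for r in ordered:
        lines.append(
            f"Thought (index {r.index}, relevance: {r.relevance_score:.2%}):\n{chunks[r.index]}\n"
        )
    return "\n".join(lines)

# Stand-in data: two selected chunks out of three.
demo_chunks = ["first", "second", "third"]
demo_results = [
    SimpleNamespace(index=2, relevance_score=0.9),
    SimpleNamespace(index=0, relevance_score=0.5),
]
demo_prompt = format_in_original_order(demo_results, demo_chunks)
```

Here `demo_prompt` lists the chunk at index 0 before the chunk at index 2, regardless of their scores.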

In [16]:
documents_prompt = format_documents_into_prompt(reranking.results)

In [17]:
documents_prompt

"Consider the following thoughts:\n\nThought (index 8, relevance: 91.99%):\nI think that's solid. But just to be thorough, let's subtract them to see the difference. 9.9 minus 9.11. To subtract, line them up:\n\n9.90\n-9.11\n------\n0.79\n\nSo, 9.9 is 0.79 more than 9.11, which confirms it's larger.\n\nThought (index 7, relevance: 92.33%):\n1. Since 9 is greater than 1, 9.90 is already larger here. The hundredths place doesn't matter once the tenths place is different. So, even though 9.11 has an extra digit, the tenths place of 9.9 is much higher, making it the larger number.\n\nThought (index 6, relevance: 92.64%):\n90 is more than 11, so 90/100 is larger. Therefore, 9.9 is larger than 9.11. Another way to think about it is by aligning the decimal points and adding trailing zeros to make them the same length:\n\n9.90\n9.11\n\nStarting from the left, the whole number part is 9 in both. Then the tenths place: 9 vs.\n\nThought (index 10, relevance: 93.45%):\nThe key is that the first de

## **11. Generate a Final Response Using a Smaller LLM**

We append the formatted thoughts to the original question and send the result to the same small model used in the preliminary test (`meta-llama/Llama-3.2-3B-Instruct-Turbo`).

In [18]:
reranked_prompt = prompt + "\n" + documents_prompt

In [19]:
reranked_prompt

"Which number is larger, 9.9 or 9.11?\nConsider the following thoughts:\n\nThought (index 8, relevance: 91.99%):\nI think that's solid. But just to be thorough, let's subtract them to see the difference. 9.9 minus 9.11. To subtract, line them up:\n\n9.90\n-9.11\n------\n0.79\n\nSo, 9.9 is 0.79 more than 9.11, which confirms it's larger.\n\nThought (index 7, relevance: 92.33%):\n1. Since 9 is greater than 1, 9.90 is already larger here. The hundredths place doesn't matter once the tenths place is different. So, even though 9.11 has an extra digit, the tenths place of 9.9 is much higher, making it the larger number.\n\nThought (index 6, relevance: 92.64%):\n90 is more than 11, so 90/100 is larger. Therefore, 9.9 is larger than 9.11. Another way to think about it is by aligning the decimal points and adding trailing zeros to make them the same length:\n\n9.90\n9.11\n\nStarting from the left, the whole number part is 9 in both. Then the tenths place: 9 vs.\n\nThought (index 10, relevance: 

In [20]:
final_response = together_client.chat.completions.create(
    model='meta-llama/Llama-3.2-3B-Instruct-Turbo',
    messages=[{'role': 'user', 'content': reranked_prompt}],
)

In [21]:
final_response.choices[0].message.content

'The number 9.9 is larger than 9.11. \n\nThe reasoning provided in the thoughts is consistent and accurate. The key points are:\n\n1. Subtracting 9.11 from 9.9 results in 0.79, confirming that 9.9 is larger.\n2. The tenths place of 9.9 is higher than the tenths place of 9.11.\n3. Aligning the decimal points and adding trailing zeros makes 9.90 larger than 9.11.\n4. The position of the decimal places matters, with the tenths place being more significant than the hundredths place.\n5. Comparing the numbers on a number line shows that 9.9 is closer to 10 than 9.11 is.\n\nAll these points support the conclusion that 9.9 is indeed larger than 9.11.'

## Conclusion

This notebook has demonstrated a method to refine the chain-of-thought output by isolating and reranking its components. The strategy involves:
- Extracting the internal reasoning (COT) from the raw LLM response.
- Splitting and grouping the reasoning into manageable chunks.
- Reranking these chunks based on relevance to the prompt.
- Reconstructing a refined prompt from the top-ranked chunks, which aims to reduce contradictory thoughts and may reduce token usage.

By passing only the chunks most relevant to the question, we reduce the amount of reasoning text fed back to the LLM. In the run shown here, the same 3B model that initially answered "9.11 is larger than 9.9." produced the correct answer once it was given the reranked thoughts. This is a single example rather than a benchmark, but it motivates further experiments aimed at reducing ambiguity and token consumption in LLM interactions.
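
One rough way to check the token-usage claim is to compare the size of the full chain of thought with the size of the reranked prompt. The sketch below counts whitespace-separated words as a proxy. This is an approximation and not the model's tokenizer, and `approx_token_count` and `compression_ratio` are hypothetical helper names.

```python
def approx_token_count(text):
    """Very rough proxy for token count: the number of whitespace-separated words."""
    return len(text.split())

def compression_ratio(full_text, reduced_text):
    """Size of reduced_text relative to full_text (values below 1.0 mean it is shorter)."""
    full = approx_token_count(full_text)
    return approx_token_count(reduced_text) / full if full else 0.0
```

In this notebook, `compression_ratio(cot, documents_prompt)` would give the fraction of the original reasoning that is passed on to the smaller model.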