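This notebook demonstrates CoT-decoding (chain-of-thought decoding), following the paper "Chain-of-Thought Reasoning Without Prompting" and the implementation from [OptiLLM](https://github.com/codelion/optillm/blob/main/optillm/cot_decoding.py).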

Make sure you set your HuggingFace token and have access to meta-llama/Meta-Llama-3-8B-Instruct.

Use an L4 or A100 GPU.

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Checkpoint to load; your HuggingFace account must have been granted access to it.
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
# Paste your HuggingFace access token here (HuggingFace -> Settings).
hf_auth = ""

# Load the model weights straight onto the CUDA device.
# NOTE(review): trust_remote_code=True allows code shipped in the model repo to
# run locally -- confirm it is actually required for this checkpoint.
model = AutoModelForCausalLM.from_pretrained(
    model_id, token=hf_auth, device_map="cuda", trust_remote_code=True
)

# The matching tokenizer is fetched with the same token.
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_auth)

config.json:   0%|          | 0.00/654 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/187 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/51.0k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/73.0 [00:00<?, ?B/s]

The code below walks through how this works: it defines, step by step, the function that performs CoT-decoding on a chat message.

In [None]:
def get_device():
    """Return the torch device to run on: MPS if available, else CUDA, else CPU."""
    if torch.backends.mps.is_available():
        backend = "mps"
    elif torch.cuda.is_available():
        backend = "cuda"
    else:
        backend = "cpu"
    return torch.device(backend)

In [None]:
from typing import List, Tuple, Dict, Optional
# Calculate the confidence score (Δ) by analyzing the probabilities (logits) of each token in the answer sequence.
def calculate_confidence(logits: List[torch.Tensor], answer_ids: torch.Tensor) -> float:
    """
    Compute the confidence score (Δ) of a decoded answer.

    For every decoding step the logits are turned into probabilities and the gap
    between the most likely and the second most likely token is measured; the
    score is the mean of those gaps. The token ids themselves are not inspected:
    ``answer_ids`` only limits how many steps are taken into account.

    Args:
        logits: Sequence (list or tuple) of logits, one tensor per decoding step.
            A step may be 1-D ``(vocab,)`` or carry leading dimensions such as
            ``(batch, vocab)``; with leading dimensions the last row is used.
        answer_ids: Tensor of token ids for the answer.

    Returns:
        Confidence score (Δ): 1.0 means the top token always had all the
        probability mass, 0.0 means the top two tokens were always tied. Returns
        0.0 when there is no step to score.
    """
    # Only steps that have both a logits entry and an answer token are scored.
    num_steps = min(len(logits), len(answer_ids))
    if num_steps == 0:
        return 0.0

    confidence_sum = 0.0
    for step_logits in logits[:num_steps]:
        probs = torch.softmax(step_logits, dim=-1)
        if probs.size(-1) < 2:
            # Max confidence if there's only one token to choose from.
            confidence_sum += 1.0
            continue

        top_2_probs, _ = torch.topk(probs, 2)
        # Collapse any leading dimensions and keep the last row, so 1-D logits
        # work and 2-D logits behave exactly as before.
        best, runner_up = top_2_probs.reshape(-1, 2)[-1]
        confidence_sum += (best - runner_up).item()

    return confidence_sum / num_steps

In [None]:

# Implement CoT-decoding for a given message.
def cot_decode(
    model,
    tokenizer,
    messages,
    k: int = 10,
    num_beams: int = 1,
    max_new_tokens: int = 512,
    temperature: float = 1.0,
    top_p: float = 1.0,
    repetition_penalty: float = 1.0,
    length_penalty: float = 1.0,
    no_repeat_ngram_size: int = 0,
    early_stopping: bool = False,
) -> Tuple[str, float, int]:
    """
    Implement CoT-decoding for a given chat input.

    The prompt is run through the model once to find the ``k`` most likely first
    tokens. Each of those tokens seeds its own continuation via
    ``model.generate``; every continuation is scored with
    ``calculate_confidence`` and the highest-scoring one is returned.

    Args:
        model: Causal language model exposing ``__call__`` and ``generate``.
            It is moved to the device chosen by ``get_device()``.
        tokenizer: Tokenizer matching ``model``; must provide a chat template.
            Its ``pad_token_id`` is set to ``eos_token_id`` if it was unset.
        messages: Chat messages, as accepted by ``tokenizer.apply_chat_template``.
        k: Number of alternative first tokens (decoding paths) to explore.
            Clamped to the vocabulary size.
        num_beams, max_new_tokens, temperature, top_p, repetition_penalty,
        length_penalty, no_repeat_ngram_size, early_stopping: Forwarded
            unchanged to ``model.generate`` for every path.

    Returns: (output_text, confidence_score, llm_calls)

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError("k must be at least 1")

    device = get_device()
    model.to(device)

    llm_calls = 0  # Initialize counter

    # The rendered chat template already contains the special tokens it needs
    # (e.g. BOS), so the text is tokenized without adding them a second time.
    input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    input_ids = tokenizer.encode(input_text, return_tensors="pt", add_special_tokens=False).to(device)
    attention_mask = torch.ones_like(input_ids)
    prompt_length = input_ids.size(-1)

    # Set pad_token_id if it's not set
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # First LLM call to predict the next possible token(s) with the highest logits.
    # These will serve as potential starting points for generating multiple answer paths
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
    llm_calls += 1  # First call
    first_token_logits = outputs.logits[0, -1, :]
    # Never request more candidates than there are vocabulary entries.
    num_paths = min(k, first_token_logits.size(-1))
    _, top_k_indices = torch.topk(first_token_logits, num_paths)

    # Attention-mask entry for the forced first token; identical for every path.
    first_token_mask = torch.ones((1, 1), dtype=torch.long, device=device)

    paths = []
    for idx in top_k_indices:
        # Generate sequence starting with the selected token
        start_ids = torch.cat([input_ids, idx.view(1, 1)], dim=-1)
        start_mask = torch.cat([attention_mask, first_token_mask], dim=-1)

        output = model.generate(
            start_ids,
            attention_mask=start_mask,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            early_stopping=early_stopping,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            output_scores=True,
            return_dict_in_generate=True,
        )
        llm_calls += 1  # One call per path

        # Everything after the prompt is the answer, including the forced first
        # token. That token was part of the input to generate(), so
        # output.scores holds entries only for the tokens that follow it.
        answer_ids = output.sequences[0][prompt_length:]
        answer_text = tokenizer.decode(answer_ids, skip_special_tokens=True)

        confidence = calculate_confidence(output.scores, answer_ids)
        paths.append((answer_text, confidence))

    # Pick the winner once, after every path has been scored.
    best_answer, best_confidence = max(paths, key=lambda path: path[1])

    return best_answer, best_confidence, llm_calls

Run it below. Be patient, as it can be slow (depending on the complexity of the question).

In [None]:
# Single-turn chat containing the question to answer.
messages = [{"role": "user", "content": "Give me three sentences that end in ”is”"}]

# Explore the candidate decoding paths and keep the most confident answer.
best_answer, confidence, llm_calls = cot_decode(model=model, tokenizer=tokenizer, messages=messages)

# Display the results
print(
    f"Best Answer:\n{best_answer}\n",
    f"Confidence Score: {confidence}",
    f"Total LLM Calls: {llm_calls}",
    sep="\n",
)


Best Answer:
_HERE are three sentences that end in "is":

1. The cat is.
2. The book is.
3. The sun is.

Confidence Score: 0.8540930008249623
Total LLM Calls: 11


If I ask Llama 3 8b via Groq, I get this answer:

Here are three sentences that end in "is":

    The sun is shining brightly in the sky.
    The cat is sleeping peacefully on the couch.
    The book is written by my favorite author.

This shows that CoT-decoding improves the model's answer here: its sentences actually end in "is", whereas the sentences in the plain Groq answer do not.