Implementation of [Universal Self-Consistency for Large Language Model Generation](https://arxiv.org/pdf/2311.17311.pdf)

In [1]:
device = "cuda"

# Quantization settings
quantization_enabled = True
use_4bit = True
bnb_4bit_compute_dtype = "float16"
bnb_4bit_quant_type = "nf4"
use_nested_quant = True

# Model
model_name='mistralai/Mistral-7B-Instruct-v0.1'

In [2]:
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)

from util import HuggingFaceChatModel
from langchain.schema import BaseMessage, AIMessage, HumanMessage

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
print(f"{torch.__version__=}")
print(f"{torch.version.cuda=}")
print(f"{torch.cuda.is_available()=}")
print(f"{torch.cuda.device_count()=}")

if "cuda" in device:
    assert torch.cuda.is_available(), "CUDA is not available"

torch.__version__='2.1.1+cu118'
torch.version.cuda='11.8'
torch.cuda.is_available()=True
torch.cuda.device_count()=1


In [4]:
!nvidia-smi

Sat Dec 16 06:06:10 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 3090        Off | 00000000:07:00.0 Off |                  N/A |
|  0%   51C    P8              28W / 420W |    147MiB / 24576MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [5]:
# Quantization config for bitsandbytes
compute_dtype = getattr(torch, bnb_4bit_compute_dtype)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=use_4bit,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=use_nested_quant,
)
# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, device_map="auto")
tokenizer.pad_token_id = tokenizer.unk_token_id

if quantization_enabled:
    model_kwargs = {"quantization_config": bnb_config}
else:
    model_kwargs = {"torch_dtype": torch.float16}

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    **model_kwargs,
)

Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.84s/it]


In [6]:
chat_model = HuggingFaceChatModel(
    model=model,
    tokenizer=tokenizer,
    generate_kwargs=dict(
        # top_p=0.92,
        # output_scores=True, 
        # num_return_sequences=3
    )
)

In [7]:
import copy
from operator import itemgetter
from langchain_core.runnables import RunnableLambda, RunnablePassthrough, RunnableParallel
from langchain_core.prompts import ChatPromptTemplate, BasePromptTemplate, format_document
from langchain.callbacks.tracers import ConsoleCallbackHandler
from langchain_core.documents import Document
from langchain.output_parsers import RegexParser
from functools import partial
import random

# messages = [
#     HumanMessage(content="What started the Dark Ages in Europe?"),
# ]

def build_repeat_chain(num_repeat=1):
    return RunnableLambda(lambda x: [x for _ in range(num_repeat)])

def convert_to_document(message: AIMessage):
    return Document(
        page_content=message.content,
    )

document_prompt = ChatPromptTemplate.from_template("{page_content}")
partial_format_document = partial(format_document, prompt=document_prompt)

def format_docs(docs):
    formatted = [f"Response {i}\n{partial_format_document(doc)}" for i, doc in enumerate(docs)]
    return "\n\n".join(formatted)

# This chain take an input to a chat model and outputs 5 sample outputs from
# the chat model
chat_model_for_sampling = chat_model.bind(temperature=1.1, max_tokens=1000).with_config({'callbacks': [ConsoleCallbackHandler()]})
sample_llm_chain = build_repeat_chain(5) | chat_model_for_sampling.map()

# This chain takes a list of llm outputs and a question, selects the consensus
# answer, and then outputs the consensus answer
select_consensus_template = ChatPromptTemplate.from_template(
    "I have generated the following responses to the question: {question}\n\n"
    "{context}\n\n"
    "Evaluate these responses.\n"
    "Select the most consistent response based on majority consensus.\n"
    "Start your answer with \"The most consistent response is Response X\" (without quotes)\n"
)

response_selection_parser = RegexParser(
    regex=r"(?i)response\s+(\d+)",
    output_keys=["response_selected_index"],
)

chat_model_for_consistency = chat_model.bind(temperature=0, max_tokens=1000).with_config({'callbacks': [ConsoleCallbackHandler()]})

self_consistency_chain = (
    {"question": itemgetter("question"), "responses": itemgetter("responses"), "context": itemgetter("responses") | RunnableLambda(convert_to_document).map() | format_docs}
    | RunnablePassthrough.assign(response_selected_index=select_consensus_template | chat_model_for_consistency | response_selection_parser | itemgetter(response_selection_parser.output_keys[0]) | int)
    | RunnableLambda(lambda x: x["responses"][x["response_selected_index"]])
)


# This chain is a basic rag chain
rag_template = ChatPromptTemplate.from_template("""Answer the question based only on the following context:
{context}

Question: {question}
""")

def build_basic_rag_chain(retriever, chat_chain, prompt):
    return (
        {"context": retriever, "question": itemgetter("question")}
        | prompt
        | chat_chain
    )

def build_universal_consistency_chain(samples_chain, chat_chain, prompt, format_docs=format_docs):
    return (
        {"question": itemgetter("question"), "responses": itemgetter("responses"), "context": itemgetter("responses") | RunnableLambda(convert_to_document).map() | format_docs}
        | prompt
        | chat_chain
    )

rag_chain = (
    {"context": RunnableLambda(lambda x: ""), "question": RunnablePassthrough()}
    | rag_template
)

combined_chain = (
    {"question": RunnablePassthrough()}
    | RunnableParallel({"question": itemgetter("question"), "responses": rag_chain | sample_llm_chain})
    | self_consistency_chain
)

combined_chain.invoke("What were the Dark Ages in Europe?")

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_

[32;1m[1;3m[llm/start][0m [1m[1:llm:HuggingFaceChatModel] Entering LLM run with input:
[0m{
  "prompts": [
    "Human: Answer the question based only on the following context:\n\n\nQuestion: {'question': 'What were the Dark Ages in Europe?'}"
  ]
}
[32;1m[1;3m[llm/start][0m [1m[1:llm:HuggingFaceChatModel] Entering LLM run with input:
[0m{
  "prompts": [
    "Human: Answer the question based only on the following context:\n\n\nQuestion: {'question': 'What were the Dark Ages in Europe?'}"
  ]
}
[32;1m[1;3m[llm/start][0m [1m[1:llm:HuggingFaceChatModel] Entering LLM run with input:
[0m{
  "prompts": [
    "Human: Answer the question based only on the following context:\n\n\nQuestion: {'question': 'What were the Dark Ages in Europe?'}"
  ]
}
[32;1m[1;3m[llm/start][0m [1m[1:llm:HuggingFaceChatModel] Entering LLM run with input:
[0m{
  "prompts": [
    "Human: Answer the question based only on the following context:\n\n\nQuestion: {'question': 'What were the Dark Ages in E

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


[36;1m[1;3m[llm/end][0m [1m[1:llm:HuggingFaceChatModel] [44.21s] Exiting LLM run with output:
[0m{
  "generations": [
    [
      {
        "text": "The Dark Ages, also known as the Middle Ages, were a period of European history lasting from the fall of the Western Roman Empire in 476 AD to the beginning of the Renaissance in the 14th century. During this time, Europe experienced significant economic, social, and cultural changes. The fall of the Roman Empire led to a breakdown in trade and communication, resulting in a period of relative isolation and poverty for many regions of Europe. The rise of feudalism, in which lords granted land to vassals in exchange for military service, became widespread. The Black Death, a devastating pandemic, swept through Europe in the 14th century, killing millions of people and further disrupting society. During the Dark Ages, art, literature, and science continued to flourish in some parts of Europe, but some aspects of society, such as social h

AIMessage(content='The Dark Ages, also referred to as the Middle Ages, were a period of European history that spanned from the 5th through the 15th century. during this time, Europe experienced significant social, political, and economic changes, including the decline of the Roman Empire, the rise of feudalism, and the emergence of powerful kingdoms and empires. In many parts of the continent, laughter and light were scarce, illiteracy was high, and poverty and disease were common. However, this period also saw the development of art, literature, and architecture, as well as great leaps in scientific and technological advangements.')