## Environment: Serverless Notebook


https://github.com/run-llama/llama_index/blob/main/llama-index-packs/llama-index-packs-raft-dataset/examples/raft_dataset.ipynb

# RAFT Dataset LlamaPack

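This LlamaPack implements the [paper](https://arxiv.org/abs/2403.10131) RAFT: Adapting Language Model to Domain Specific RAG.

Retrieval Augmented FineTuning (RAFT) is the training recipe introduced in this paper. It aims to improve the performance of large language models (LLMs) on open-book, in-domain question answering. Given a question and a set of retrieved documents, RAFT trains the LLM to identify the most relevant sequence in the documents that help answer the question and to quote it verbatim, while ignoring irrelevant or distracting information. By explicitly training the model to distinguish relevant from irrelevant information and to cite evidence from the relevant documents, RAFT encourages the LLM to develop better reasoning and explanation abilities, which ultimately improves its ability to answer questions accurately and rationally when additional context or knowledge is available.

A key component of RAFT is how the fine-tuning dataset is generated. Each QA pair includes an "oracle" document from which the answer can be deduced, together with irrelevant "distractor" documents. During training, this forces the model to learn which information is relevant and which is not, while also memorizing domain knowledge.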
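
As a rough illustration of the idea above, the sketch below assembles the documents for one training example from a list of chunks: the oracle chunk plus `num_distract` randomly chosen distractors, with the oracle swapped out for another chunk with probability `1 - p`. The function name `build_raft_context` and its parameters are hypothetical; they only mirror the logic of the pack class shown later in this notebook and are not part of llama-index.

```python
import random
from typing import List, Optional


def build_raft_context(
    chunks: List[str],
    oracle_index: int,
    num_distract: int = 3,
    p: float = 1.0,
    rng: Optional[random.Random] = None,
) -> List[str]:
    """Return a shuffled list of documents for one training example."""
    rng = rng or random.Random()
    # Every chunk except the oracle is a distractor candidate.
    candidates = [i for i in range(len(chunks)) if i != oracle_index]
    num_distract = min(num_distract, len(candidates))
    docs = [chunks[oracle_index]]
    docs += [chunks[j] for j in rng.sample(candidates, num_distract)]
    # With probability 1 - p, replace the oracle with a random non-oracle chunk,
    # so the model also sees examples where the answer is not in the context.
    if candidates and rng.random() >= p:
        docs[0] = chunks[rng.choice(candidates)]
    rng.shuffle(docs)
    return docs


# Made-up chunks, for illustration only.
chunks = ["chunk A", "chunk B", "chunk C", "chunk D", "chunk E"]
docs = build_raft_context(chunks, oracle_index=2, num_distract=3, rng=random.Random(0))
print(len(docs), "chunk C" in docs)
```

With the default `p=1.0` the oracle is always kept, so every example contains the document that answers the question. Lower values of `p` would mix in examples where the model has to cope without it.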

#### Installation

In [0]:
%pip install llama-index
%pip install llama-index-packs-raft-dataset llama-index-embeddings-databricks llama-index-llms-databricks
dbutils.library.restartPython()

#### Download Data

In [0]:
!wget --user-agent "Mozilla" "https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/paul_graham/paul_graham_essay.txt" -O './paul_graham_essay.txt'

The dataset that the pack generates later in this notebook is a HuggingFace `Dataset`. You can save it in `.arrow` or `.jsonl` format and use it for fine-tuning.

You can refer to the original implementation [here](https://github.com/ShishirPatil/gorilla/tree/main/raft).

The next cell adapts the pack's `base.py` so that it generates a Japanese dataset: https://github.com/run-llama/llama_index/blob/main/llama-index-packs/llama-index-packs-raft-dataset/llama_index/packs/raft_dataset/base.py

In [0]:
"""RAFT Dataset LlamaPack class."""

# Inspired from https://github.com/ShishirPatil/gorilla/tree/main/raft

from typing import Any, List
import random
import logging
import warnings

from datasets import Dataset

# Module-level logger that emits messages of level INFO and above
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

from llama_index.core.llama_pack.base import BaseLlamaPack
from llama_index.core import SimpleDirectoryReader

from llama_index.core.node_parser import SemanticSplitterNodeParser
from llama_index.core.llms import ChatMessage
from llama_index.llms.openai import OpenAI
from llama_index.embeddings.openai import OpenAIEmbedding

DEFAULT_CHUNK_SIZE = 512
DEFAULT_BREAKPOINT_PERCENTILE_THRESHOLD = 95


class MyRAFTDatasetPack(BaseLlamaPack):
    """RAFT Dataset Generator pack."""

    def __init__(
        self,
        file_path: str,
        llm: Any = None,
        embed_model: Any = None,
        num_questions_per_chunk: int = 5,
        num_distract_docs: int = 3,
        chunk_size: int = DEFAULT_CHUNK_SIZE,
        default_breakpoint_percentile_threshold=DEFAULT_BREAKPOINT_PERCENTILE_THRESHOLD,
    ):
        self.file_path = file_path
        self.num_questions_per_chunk = num_questions_per_chunk
        self.num_distract_docs = num_distract_docs
        self.chunk_size = chunk_size
        self.default_breakpoint_percentile_threshold = (
            default_breakpoint_percentile_threshold
        )
        self.ds = None
        self.llm = OpenAI(temperature=0, n=1, model="gpt-4o") if llm is None else llm
        self.embed_model = OpenAIEmbedding(model="text-embedding-ada-002") if embed_model is None else embed_model

    def strip_str(self, s) -> str:
        """
        Helper function for cleaning up strings returned by the LLM.
        """
        if s.startswith("assistant:"):  # Check if the string starts with 'assistant:'
            s = s.replace("assistant:", "", 1)  # Remove only the first occurrence

        start_index, end_index = 0, len(s) - 1
        beg_found = False
        for i in range(len(s)):
            if s[i].isalpha():
                if not beg_found:
                    start_index = i
                    beg_found = True
                else:
                    end_index = i
        end_index += 2
        return s[start_index : min(end_index, len(s))]

    def encode_question_gen(self, question, chunk) -> List[ChatMessage]:
        """
        Build the chat messages (system and user) that ask the LLM to answer
        `question` from `chunk` with step-by-step reasoning.
        """
        prompt = f"""
            Question: {question}\nContext: {chunk}\n
            Answer this question using the information given in the context above. Here are things to pay attention to:
            - First provide step-by-step reasoning in Japanese on how to answer the question.
            - In the reasoning, if you need to copy paste some sentences from the context, include them in ##begin_quote## and ##end_quote##. This would mean that things outside of ##begin_quote## and ##end_quote## are not directly copy paste from the context.
            - End your response with final answer in the form <ANSWER>: $answer, the answer should be succinct.
        """
        return [
            ChatMessage(
                role="system",
                content="You are a helpful question answerer who can provide an answer in Japanese given a question and relevant context.",
            ),
            ChatMessage(role="user", content=prompt),
        ]

    def translate_doc_gen(self, chunk) -> str:
        """
        Translate an English chunk into Japanese with the LLM and return the translated text.
        """
        prompt = chunk
        question_messages = [
            ChatMessage(
                role="system",
                content="""You are an excellent English-Japanese interpreter. Translate the given English text into Japanese.

# Output Format

Output only the translated Japanese text as a single, cohesive paragraph.""",),
            ChatMessage(role="user", content=prompt),
        ]

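        response = self.llm.chat(question_messages)
        # str(response) would prepend the "assistant: " role prefix, so return the content only
        return response.message.content or ""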
    
    def rewrite_doc_gen(self, chunk) -> str:
        """
        Rewrite machine-translated Japanese into more natural Japanese with the LLM.
        """
        prompt = chunk
        question_messages = [
            ChatMessage(
                role="system",
                content="""You are an excellent Japanese writer. The given text is not natural Japanese because it was translated from English by a non-native Japanese interpreter, so you have to rewrite it into more natural Japanese. 自分を信じて限界を超えてください！

# Output Format

Output only the rewritten Japanese text as a single, cohesive paragraph.""",),
            ChatMessage(role="user", content=prompt),
        ]

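        response = self.llm.chat(question_messages)
        # str(response) would prepend the "assistant: " role prefix, so return the content only
        return response.message.content or ""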
    
    def generate_label(self, question, context) -> str:
        """
        Generates the label / answer to `question` using `context` and the configured LLM.
        """
        question_messages = self.encode_question_gen(question, context)
        response = self.llm.chat(question_messages)
        return str(response)

    def generate_instructions_gen(self, chunk, x=5) -> List[str]:
        """
        Generates `x` questions / use cases for `chunk`. Used when the input document is of general types
        `pdf`, `json`, or `txt`.
        """
        messages = [
            ChatMessage(
                role="system",
                content="You are a synthetic question-answer pair generator. Given a chunk of context about some topic(s), generate %s example questions in Japanese which a user could ask and would be answered using information from the chunk. For example, if the given context was a Wikipedia paragraph about the United States, an example question could be 'アメリカにはいくつの州がありますか？'. The questions should be able to be answered in a few words or less."
                % (x),
            ),
            ChatMessage(role="user", content=str(chunk)),
        ]

        queries = str(self.llm.chat(messages)).split("\n")
        questions = [self.strip_str(q) for q in queries]
        questions = [q for q in questions if any(c.isalpha() for c in q)][:x]

        num_questions_generated = len(questions)
        if num_questions_generated < x:
            warnings.warn(
                f"Fewer questions generated ({num_questions_generated}) "
                f"than requested ({x})."
            )

        return questions

    def get_chunks(self, file_path: str, chunk_size: int) -> List[str]:
        """
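        Takes in a `file_path`, loads the document, splits it into semantic chunks with
        `SemanticSplitterNodeParser`, translates each chunk into Japanese, rewrites it into
        more natural Japanese, and returns the rewritten chunks. `chunk_size` is currently
        unused because the semantic splitter picks the chunk boundaries.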
        """
        documents = SimpleDirectoryReader(input_files=[file_path]).load_data()
        splitter = SemanticSplitterNodeParser(
            buffer_size=1,
            breakpoint_percentile_threshold=self.default_breakpoint_percentile_threshold,
            embed_model=self.embed_model,
        )
        nodes = splitter.get_nodes_from_documents(documents)

        chunks = []
        for node in nodes:
            translated_text = self.translate_doc_gen(node.get_content())
            refined_text = self.rewrite_doc_gen(translated_text)
            chunks.append(refined_text)
            print(refined_text)
            print("--------------------------------------------------------------")

        return chunks

    def add_chunk_to_dataset(
        self,
        chunks: List,
        chunk: str,
        x: int = 5,
        num_distract: int = 3,
        p: float = 1.0,
    ):
        """
        Given a chunk, create {Q, A, D} triplets and add them to the dataset.
        """
        i = chunks.index(chunk)
        qs = self.generate_instructions_gen(chunk, x)
        for q in qs:
            datapt = {
                "id": None,
                "type": None,
                "question": None,
                "context": None,
                "oracle_context": None,
                "cot_answer": None,
            }

            datapt["id"] = f"seed_task_{0 if not self.ds else self.ds.num_rows}"
            datapt["type"] = "general"
            datapt["question"] = q

            # add distractor docs
            docs = [chunk]
            indices = list(range(len(chunks)))
            indices.remove(i)
            for j in random.sample(indices, num_distract):
                docs.append(chunks[j])
            # decides whether to add oracle document
            oracle = random.uniform(0, 1) < p
            if not oracle:
                docs[0] = chunks[random.sample(indices, 1)[0]]
            random.shuffle(docs)

            d = {"title": [], "sentences": []}

            d["title"].append(["placeholder_title"] * (num_distract + 1))
            d["sentences"].append(docs)
            datapt["context"] = d
            datapt["oracle_context"] = chunk

            # add answer to q
            datapt["cot_answer"] = self.generate_label(q, chunk)

            # construct model instruction
            context = ""
            for doc in docs:
                context += "<DOCUMENT>" + str(doc) + "</DOCUMENT>\n"
            context += q
            datapt["instruction"] = context

            # add to dataset
            if not self.ds:
                # init ds
                datapt["id"] = [datapt["id"]]
                datapt["type"] = [datapt["type"]]
                datapt["question"] = [datapt["question"]]
                datapt["context"] = [datapt["context"]]
                datapt["oracle_context"] = [datapt["oracle_context"]]
                datapt["cot_answer"] = [datapt["cot_answer"]]
                datapt["instruction"] = [datapt["instruction"]]
                self.ds = Dataset.from_dict(datapt)
            else:
                self.ds = self.ds.add_item(datapt)

    def run(self) -> Any:
        """Run the pipeline."""
        chunks = self.get_chunks(self.file_path, self.chunk_size)

        logger.info(f"Number of chunks created: {len(chunks)}")

        # Cap the number of distractors at the number of non-oracle chunks
        self.num_distract_docs = min(self.num_distract_docs, len(chunks) - 1)

        for index, chunk in enumerate(chunks):
            logger.info(f"Processing chunk: {index}")
            self.add_chunk_to_dataset(
                chunks, chunk, self.num_questions_per_chunk, self.num_distract_docs
            )

        return self.ds
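
To make the post-processing in `generate_instructions_gen` easier to follow, here is a simplified, standalone sketch of how a raw reply is split into lines, cleaned, and capped at `x` questions. `clean_line` and `parse_questions` are hypothetical names used only for this illustration; `clean_line` approximates `strip_str` rather than reproducing it exactly, and the sample reply is made up.

```python
from typing import List


def clean_line(s: str) -> str:
    """Trim a role prefix, list markers, and other non-letter characters at the edges."""
    if s.startswith("assistant:"):
        s = s.replace("assistant:", "", 1)
    letter_positions = [i for i, ch in enumerate(s) if ch.isalpha()]
    if not letter_positions:
        return ""
    # Keep one character after the last letter so a trailing "?" or "？" survives.
    return s[letter_positions[0] : letter_positions[-1] + 2]


def parse_questions(raw_reply: str, x: int) -> List[str]:
    """Split a raw LLM reply into at most `x` cleaned question strings."""
    lines = [clean_line(line) for line in raw_reply.split("\n")]
    return [line for line in lines if line][:x]


# Made-up reply in the numbered-list style an LLM often produces.
reply = "1. 日本の首都はどこですか？\n\n2. What is RAFT?\n- 著者は誰ですか？"
print(parse_questions(reply, x=2))
```

Note that `str.isalpha()` is true for kana and kanji as well as Latin letters, which is why the same cleanup works for Japanese questions.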

In [0]:
# Databricks URL and token used to route requests to the Databricks serving endpoints
databricks_host = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().getOrElse(None)
databricks_token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().getOrElse(None)
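
The cell above relies on `dbutils`, which only exists inside a Databricks notebook. If you want to run the rest of this notebook elsewhere, one option is to read the same two values from environment variables. The sketch below assumes that `DATABRICKS_HOST` and `DATABRICKS_TOKEN` have been set by you beforehand; `resolve_databricks_credentials` and `serving_endpoint_url` are hypothetical helper names.

```python
import os
from typing import Tuple


def resolve_databricks_credentials() -> Tuple[str, str]:
    """Read the workspace URL and token from environment variables (assumed to be set)."""
    host = os.environ.get("DATABRICKS_HOST", "").rstrip("/")
    token = os.environ.get("DATABRICKS_TOKEN", "")
    if not host or not token:
        raise RuntimeError("Set DATABRICKS_HOST and DATABRICKS_TOKEN first.")
    return host, token


def serving_endpoint_url(host: str) -> str:
    """Build the base URL used for the serving endpoints in the cells below."""
    return f"{host}/serving-endpoints"
```

Calling `databricks_host, databricks_token = resolve_databricks_credentials()` would then stand in for the two `dbutils` lines.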

In [0]:
EMB_MODEL_NAME = "YOUR-EMBEDDING-MODEL-NAME"
LLM_MODEL_NAME = "YOUR-LLM-MODEL-NAME"

In [0]:
import os
from llama_index.core import Settings
from llama_index.embeddings.databricks import DatabricksEmbedding

# Set up the DatabricksEmbedding class with the required model, API key and serving endpoint
os.environ["DATABRICKS_TOKEN"] = databricks_token
os.environ["DATABRICKS_SERVING_ENDPOINT"] = f"{databricks_host}/serving-endpoints"
embed_model = DatabricksEmbedding(model=EMB_MODEL_NAME)
Settings.embed_model = embed_model


# Embed some text
embeddings = embed_model.get_text_embedding(
    "The DatabricksEmbedding integration works great."
)
embeddings
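
A raw embedding vector is hard to judge by eye. One simple sanity check is to compare two embeddings with cosine similarity: related sentences should usually score higher than unrelated ones. The helper below is a minimal pure-Python sketch (the name `cosine_similarity` is ours, and toy vectors stand in for real embeddings).

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("Vectors must have the same length.")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("Cosine similarity is undefined for a zero vector.")
    return dot / (norm_a * norm_b)


# Toy vectors stand in for real embeddings here.
print(cosine_similarity([1.0, 0.0], [1.0, 0.0]))
print(cosine_similarity([1.0, 0.0], [0.0, 1.0]))
```

With the endpoint configured above, you could pass the results of two `embed_model.get_text_embedding(...)` calls to this function.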

In [0]:
from llama_index.llms.databricks import Databricks
from llama_index.core.llms import ChatMessage

llm = Databricks(
    model=LLM_MODEL_NAME,
    api_key=databricks_token,
    api_base=f"{databricks_host}/serving-endpoints",
)

messages = [
    ChatMessage(
        role="system", content="You are a pirate with a colorful personality"
    ),
    ChatMessage(role="user", content="What is your name"),
]
resp = llm.chat(messages)
resp

In [0]:
raft_dataset = MyRAFTDatasetPack(
  "./paul_graham_essay.txt", 
  llm=llm, 
  embed_model=embed_model)
  
dataset = raft_dataset.run()
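
Before saving, it may be worth checking that each generated record has the expected shape. The sketch below is one possible check, assuming the record layout produced by `add_chunk_to_dataset` above; `check_record` is a hypothetical helper, not part of the pack, and the example record is made up.

```python
from typing import Any, Dict, List


def check_record(record: Dict[str, Any]) -> List[str]:
    """Return a list of problems found in one generated record (empty if none)."""
    problems = []
    # The pack stores the documents as a single inner list under "sentences".
    docs = record["context"]["sentences"][0]
    if record["oracle_context"] not in docs:
        problems.append("oracle document missing from context")
    if not record["instruction"].endswith(record["question"]):
        problems.append("instruction does not end with the question")
    if "<ANSWER>:" not in record["cot_answer"]:
        problems.append("cot_answer has no <ANSWER>: marker")
    return problems


# Made-up record in the same layout, for illustration only.
example = {
    "question": "日本の首都はどこですか？",
    "context": {
        "title": [["placeholder_title"] * 2],
        "sentences": [["東京は日本の首都です。", "富士山は日本一高い山です。"]],
    },
    "oracle_context": "東京は日本の首都です。",
    "cot_answer": "... <ANSWER>: 東京",
    "instruction": (
        "<DOCUMENT>東京は日本の首都です。</DOCUMENT>\n"
        "<DOCUMENT>富士山は日本一高い山です。</DOCUMENT>\n"
        "日本の首都はどこですか？"
    ),
}
print(check_record(example))  # prints []
```

With the real dataset, something like `[check_record(row) for row in dataset]` should work, since iterating a HuggingFace `Dataset` yields one dict per row.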

In [0]:
output_path = "raft_training_dataset_ja"

# Save as .arrow format
dataset.save_to_disk(output_path)

# Save as .jsonl format
dataset.to_json(output_path + ".jsonl")

In [0]:
import json

# Path to the JSONL file
file_path = f"{output_path}.jsonl"

# List that holds the parsed JSON objects
data_list = []

# Read the JSONL file and parse it line by line
with open(file_path, "r", encoding="utf-8") as file:
    for line in file:
        # Convert each JSON line into a Python dict and add it to the list
        data_dict = json.loads(line)
        data_list.append(data_dict)
        print(data_dict["instruction"])
        # Print only the final answer that follows the <ANSWER>: marker
        if "<ANSWER>:" in data_dict["cot_answer"]:
            answer_start = data_dict["cot_answer"].index("<ANSWER>:") + len("<ANSWER>:")
            print("<ANSWER>:" + data_dict["cot_answer"][answer_start:])
        print("-----------------------------------------------------------")

# Number of records read
print(len(data_list))
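
The prompt in `encode_question_gen` asks the model to wrap copied sentences in `##begin_quote##` / `##end_quote##` and to finish with `<ANSWER>:`. If you want those parts separately, a small parser along these lines may help. It is a sketch under the assumption that the model followed that format; `parse_cot_answer` is a hypothetical helper and the sample string is made up.

```python
import re
from typing import Any, Dict

QUOTE_PATTERN = re.compile(r"##begin_quote##(.*?)##end_quote##", re.DOTALL)


def parse_cot_answer(cot_answer: str) -> Dict[str, Any]:
    """Split a chain-of-thought answer into reasoning, quoted spans, and final answer."""
    quotes = [q.strip() for q in QUOTE_PATTERN.findall(cot_answer)]
    # partition() splits at the first "<ANSWER>:" marker, if there is one.
    reasoning, marker, answer = cot_answer.partition("<ANSWER>:")
    return {
        "reasoning": reasoning.strip(),
        "quotes": quotes,
        "answer": answer.strip() if marker else None,
    }


# Made-up answer in the requested format.
sample = "文脈には ##begin_quote## 東京は日本の首都です ##end_quote## とあります。\n<ANSWER>: 東京"
parsed = parse_cot_answer(sample)
print(parsed["quotes"], parsed["answer"])
```

Records where `answer` comes back as `None` did not follow the requested format and may be worth reviewing before fine-tuning.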