# RAG

In [33]:
import torch

from typing import Any, Dict, Iterator, List, Optional, TypedDict

from langchain_text_splitters import MarkdownHeaderTextSplitter
from langchain_huggingface import HuggingFacePipeline, HuggingFaceEmbeddings
from langchain.storage import LocalFileStore
from langchain_community.vectorstores import FAISS
from langchain.embeddings import CacheBackedEmbeddings
from langchain.prompts import PromptTemplate
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.documents import Document
from langchain_core.messages import SystemMessage
from langchain.schema import AIMessage, HumanMessage
from langchain_core.tools import tool

from langgraph.graph import START, END, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.graph import MessagesState, StateGraph

from lightning import Fabric
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, PeftModelForCausalLM, PeftModel

from IPython.display import display, Markdown, Image, SVG

### Set mixed precision

In [2]:
fabric = Fabric(accelerator="cuda", devices=1, precision="bf16-mixed")
device = fabric.device
fabric.launch()
torch.set_float32_matmul_precision("medium")

Using bfloat16 Automatic Mixed Precision (AMP)
You are using a CUDA device ('NVIDIA GeForce RTX 4080') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


### Text Splitter

ref: https://python.langchain.com/docs/how_to/markdown_header_metadata_splitter/

In [3]:
with open('../datasets/spell_content/1st Level.txt', 'r') as file:
    content = file.read()

In [4]:
headers_to_split_on = [
    ("#", "Spell Name"),
]

markdown_splitter = MarkdownHeaderTextSplitter(headers_to_split_on, strip_headers=False)
md_header_splits = markdown_splitter.split_text(content)

### Load Embedding Data

In [5]:
embed_model_name = "sentence-transformers/all-mpnet-base-v2"

embeddings = HuggingFaceEmbeddings(model_name=embed_model_name)
vector_store = FAISS.from_documents(documents=md_header_splits, embedding=embeddings)
retriever = vector_store.as_retriever()

In [6]:
query = "give me Detect Magic spell?"
results = retriever.get_relevant_documents(query)

display(Markdown(results[0].page_content))

  results = retriever.get_relevant_documents(query)


# Detect Magic
## Spell Name
Detect Magic  
From Player's Handbook, page 231.
## Description
*1st-level divination (ritual)*
*1st-level divination (ritual)*
* **Casting Time:** 1 action
* **Range:** Self
* **Components:** V, S
* **Duration:** Concentration, up to 10 minutes
- **Casting Time:** 1 action
**Casting Time:**
- **Range:** Self
**Range:**
- **Components:** V, S
**Components:**
- **Duration:** Concentration, up to 10 minutes
**Duration:**
For the duration, you sense the presence of magic within 30 feet of you. If you sense magic in this way, you can use your action to see a faint aura around any visible creature or object in the area that bears magic, and you learn its school of magic, if any.
The spell can penetrate most barriers, but it is blocked by 1 foot of stone, 1 inch of common metal, a thin sheet of lead, or 3 feet of wood or dirt.
## Learned By
* **Classes:** Artificer, Bard, Cleric, Druid, Paladin, Ranger, Sorcerer, Wizard
* **Subclasses:** Cleric (*Arcana Domain*), Fighter (*Eldritch Knight*), Fighter (*Monster Hunter*), Paladin (*Oath of the Watchers*), Rogue (*Arcane Trickster*), Sorcerer (*Aberrant Mind*), Sorcerer (*Divine Soul*), Wizard (*Theurgy*)
* **Eldritch Invocations:** Eldritch Sight
* **Races:** Firbolg, Mark of Detection Half-elf, Owlin (UA)
* **Feats:** Artificer Initiate, Drow High Magic, Fey Touched, Magic Initiate, Quicksmithing, Ritual Caster, Strixhaven Initiate
- **Classes:** Artificer, Bard, Cleric, Druid, Paladin, Ranger, Sorcerer, Wizard
**Classes:**
Artificer
Bard
Cleric
Druid
Paladin
Ranger
Sorcerer
Wizard
- **Subclasses:** Cleric (*Arcana Domain*), Fighter (*Eldritch Knight*), Fighter (*Monster Hunter*), Paladin (*Oath of the Watchers*), Rogue (*Arcane Trickster*), Sorcerer (*Aberrant Mind*), Sorcerer (*Divine Soul*), Wizard (*Theurgy*)
**Subclasses:**
*Arcana Domain*
Arcana Domain
*Eldritch Knight*
Eldritch Knight
*Monster Hunter*
Monster Hunter
*Oath of the Watchers*
Oath of the Watchers
*Arcane Trickster*
Arcane Trickster
*Aberrant Mind*
Aberrant Mind
*Divine Soul*
Divine Soul
*Theurgy*
Theurgy
- **Eldritch Invocations:** Eldritch Sight
**Eldritch Invocations:**
Eldritch Sight
- **Races:** Firbolg, Mark of Detection Half-elf, Owlin (UA)
**Races:**
Firbolg
Mark of Detection Half-elf
Owlin (UA)
- **Feats:** Artificer Initiate, Drow High Magic, Fey Touched, Magic Initiate, Quicksmithing, Ritual Caster, Strixhaven Initiate
**Feats:**
Artificer Initiate
Drow High Magic
Fey Touched
Magic Initiate
Quicksmithing
Ritual Caster
Strixhaven Initiate

### Load model

In [None]:
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj"],
    lora_dropout=0.2,
    bias="none",
    task_type="CAUSAL_LM"
)

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    quantization_config=quant_config,
    )

lora_model = get_peft_model(base_model, lora_config)

model = PeftModelForCausalLM.from_pretrained(
    lora_model, 
    "../best",
    torch_dtype=torch.bfloat16,
    is_trainable=False
    )

model = model.eval()
model.config.use_cache = True


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]



#### Examples

```python
system = SystemMessage(
        content=(
            "in a text-based adventure (Dungeons and Dragons).\n"
            "Your job is to narrate the adventure and respond to the player's actions.\n"
            "Use the following pieces of retrieved context to answer the question.\n"
            "If you don't know the answer, say that you don't know. If player break the game rule notice to player.\n"
            "answer concise."
            "\n\n"
            "{docs_content}"
            "When you anwser to player you must answer in proper markdown format. (heading, table, bold, italic, paragraph, blockquotes)\n"
        )
    )

player_messages = [
    HumanMessage(content="<|start_header_id|>player1<|end_header_id|>\nI draw my sword!<|eot_id|>"),
    HumanMessage(content="<|start_header_id|>player2<|end_header_id|>\nI cast a fireball!<|eot_id|>"),
    HumanMessage(content="<|start_header_id|>player2<|end_header_id|>\nwhat player 1 do<|eot_id|>"),
]
```

In [67]:
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    return_full_text=False,
    max_new_tokens=4096,
    top_k=50,
    temperature=0.1,
    device_map="auto"
)

terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

llm = HuggingFacePipeline(pipeline=pipe, model_kwargs = {'temperature': 0.9, "torch_dtype": torch.bfloat16})

Device set to use cuda:0
The model 'PeftModelForCausalLM' is not supported for text-generation. Supported models are ['AriaTextForCausalLM', 'BambaForCausalLM', 'BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'LlamaForCausalLM', 'CodeGenForCausalLM', 'CohereForCausalLM', 'Cohere2ForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'DbrxForCausalLM', 'DeepseekV3ForCausalLM', 'DiffLlamaForCausalLM', 'ElectraForCausalLM', 'Emu3ForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'FalconMambaForCausalLM', 'FuyuForCausalLM', 'GemmaForCausalLM', 'Gemma2ForCausalLM', 'Gemma3ForConditionalGeneration', 'Gemma3ForCausalLM', 'GitForCausalLM', 'GlmForCausalLM', 'Glm4ForCausalLM', 'GotOcr2ForConditionalGeneration', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoFo

In [68]:
class State(TypedDict):
    question: str
    context: List[Document]
    answer: str


def retrieve(state: State):
    try:
        retrieved_docs = vector_store.similarity_search(state["question"], k=3)
        return {"context": retrieved_docs}
    except Exception as e:
        return {"context": []}


def generate(state: State):
    # Join the retrieved document content to form context
    docs_content = "\n\n".join(doc.page_content for doc in state["context"])

    # If no relevant context was retrieved, handle it gracefully
    if not docs_content:
        return {"answer": "Sorry, I don't have enough context to answer your question."}

    # Prepare the system message for the model
    system_message_content = (
        "<|start_header_id|>system<|end_header_id|>\n"
        "In a text-based adventure (Dungeons and Dragons), your job is to narrate the adventure "
        "and respond to the player's actions.\n"
        "Use the following pieces of retrieved context to answer the question.\n"
        "If you don't know the answer, say that you don't know. If the player breaks the game rules, "
        "notify the player.\n"
        "This is the retrieved context:\n\n"
        f"{docs_content}\n\n"
        "When you answer the player, you must respond in proper markdown format: heading, table, bold, italic, paragraph, blockquotes.\n"
    )

    # Create the system and human messages
    system_message = SystemMessage(content=system_message_content)
    messages = [system_message] + [HumanMessage(content=state["question"])]

    try:
        # Join the content and invoke the model
        response = llm.invoke("\n".join([msg.content for msg in messages]))
        return {"answer": response}
    except Exception as e:
        # Handle errors in generation
        return {"answer": "An error occurred while generating the response."}

# State Graph for managing retrieval and generation steps
graph_builder = StateGraph(State).add_sequence([retrieve, generate])
graph_builder.add_edge(START, "retrieve")
graph = graph_builder.compile()

In [73]:
response = graph.invoke({"question": "<|start_header_id|>player1<|end_header_id|>\nHow I obtain a Guiding Bolt.<|eot_id|>"})

print(response['question'])
print(response['answer'])

<|start_header_id|>player1<|end_header_id|>
How I obtain a Guiding Bolt.<|eot_id|>
assistant

**Obtaining Guiding Bolt**

Guiding Bolt is a 1st-level evocation spell found in the Player's Handbook, page 248. To obtain this spell, you can choose to learn it as a Cleric, Druid, Paladin, Sorcerer, Warlock, or Wizard. Additionally, you can learn it as a Mage of Quandrix or a Strixhaven Initiate.

As a Cleric, you can learn Guiding Bolt as part of your spellcasting abilities, as it is listed under the Cleric's spellcasting class features. Similarly, as a Druid, Paladin, Sorcerer, Warlock, or Wizard, you can learn Guiding Bolt as part of your spellcasting abilities, as it is listed under your respective class features.

As a Mage of Quandrix or a Strixhaven Initiate, you can learn Guiding Bolt as part of your magical training, as it is listed under your respective backgrounds.

Once you have obtained the spell, you can cast it using your spell slots, following the spell's casting time, range