# Basic Setup for Embedding & Generation Models


In [None]:
import os
from dotenv import load_dotenv
# Load variables from the local ".env" file into the process environment so the
# os.environ lookups in later cells (API keys, endpoints) can find them.
load_dotenv(dotenv_path=".env")

In [None]:
import os
from openai import OpenAI
# The API key must come from the environment (for example via a .env file).
# A real key must never be committed as a fallback default: anyone who can read
# this file could use it, and a key that has been committed should be revoked.
_openai_api_key = os.environ.get("OPENAI_API_KEY")
if not _openai_api_key:
    raise ValueError("Please set the environment variable OPENAI_API_KEY first.")

# OpenAI-compatible client pointed at the OpenRouter endpoint.
client = OpenAI(
    api_key=_openai_api_key,
    base_url="https://openrouter.ai/api/v1",
)

def gen_gpt_messages(prompt):
    """Wrap ``prompt`` in a single-turn chat message list.

    Returns a one-element list containing a ``user``-role message whose
    ``content`` is the given prompt.
    """
    return [{"role": "user", "content": prompt}]
def get_completion(prompt, model="openai/gpt-oss-120b", temperature = 0.7):
    """Send ``prompt`` as a single user message and return the model's reply text.

    Args:
        prompt: Text placed in the user message.
        model: Model identifier passed to the chat completions call.
        temperature: Sampling temperature passed to the chat completions call.

    Returns:
        The content of the first choice, or the literal string
        ``"generate answer error"`` when the response has no choices.
    """
    chat_messages = gen_gpt_messages(prompt)
    reply = client.chat.completions.create(
        model=model,
        messages=chat_messages,
        temperature=temperature,
    )
    candidates = reply.choices
    if len(candidates) <= 0:
        return "generate answer error"
    return candidates[0].message.content

In [None]:
# Smoke test of the chat client: send a short Chinese greeting ("hello").
get_completion("你好")

In [None]:
import os
import requests

def siliconflow_embedding(text: str, model: str = None, timeout: float = 30.0):
    """Request an embedding for ``text`` from the SiliconFlow embeddings endpoint.

    Args:
        text: Value sent as the ``input`` field of the request payload.
        model: Embedding model name; ``None`` selects ``"BAAI/bge-m3"``.
        timeout: Seconds to wait for the HTTP request. Without a timeout
            ``requests`` waits indefinitely, so a stalled connection would
            hang the caller.

    Returns:
        The ``embedding`` value of the first item in the response's ``data`` list.

    Raises:
        ValueError: If ``SILICONFLOW_API_KEY`` is not set in the environment.
        RuntimeError: If the response body is not JSON or carries no embedding data.
    """
    api_key = os.environ.get("SILICONFLOW_API_KEY")
    if not api_key:
        raise ValueError("Please set the environment variable SILICONFLOW_API_KEY first.")
    if model is None:
        model = "BAAI/bge-m3"

    url = "https://api.siliconflow.cn/v1/embeddings"
    payload = {
        "model": model,
        "input": text,
    }
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }

    response = requests.post(url, json=payload, headers=headers, timeout=timeout)
    try:
        res_json = response.json()
    except ValueError as exc:
        # Error pages (gateway errors, HTML bodies) are not JSON; report them in
        # the same RuntimeError form as other failures instead of a decode error.
        raise RuntimeError(
            f"Failed to get embedding: HTTP {response.status_code}, non-JSON response body"
        ) from exc

    if isinstance(res_json, dict) and res_json.get("data"):
        return res_json["data"][0]["embedding"]
    raise RuntimeError(f"Failed to get embedding: {res_json}")

# Quick check of the embedding helper: embed one sample sentence, then show the
# vector's length and its first ten components.
embedding_vector = siliconflow_embedding("The input text string to generate an embedding for.")
print("Embedding vector length:", len(embedding_vector))
print(embedding_vector[:10], "...")

## Process Knowledge Base Documents and Insert into Chroma Database

In [None]:
import os
import nltk
from pathlib import Path
from langchain_community.document_loaders import UnstructuredMarkdownLoader

# Put the user's ~/nltk_data directory first in NLTK's data search path.
# NOTE(review): presumably so the Unstructured loader picks up locally installed
# NLTK resources — confirm.
nltk.data.path.insert(0, os.path.expanduser('~/nltk_data'))
# Paths below are resolved relative to the current working directory.
base_dir = Path.cwd()

# Recursively collect every Markdown file under ./inputs.
md_files = list(base_dir.glob("inputs/**/*.md"))
print("Number of Markdown files found:", len(md_files))

# Load each file and accumulate the resulting documents in one flat list.
all_docs = []
for file_path in md_files:
    loader = UnstructuredMarkdownLoader(file_path)
    docs = loader.load()
    all_docs.extend(docs)

In [None]:
import re
# Matches a newline whose neighbours on both sides lie outside U+4E00-U+9FFF.
re_join_nonzh_break = re.compile(r'([^\u4e00-\u9fff])\n([^\u4e00-\u9fff])', re.DOTALL)


def clean_text(text: str, is_md: bool = False) -> str:
    """Normalise raw document text before chunking.

    Steps, in order:
      1. Convert ``\\r\\n`` and ``\\r`` line endings to ``\\n``.
      2. Remove a newline when the characters on both sides are outside the
         U+4E00-U+9FFF range, joining the two sides directly.
      3. Delete every occurrence of the characters listed in the loop below.
      4. When ``is_md`` is true, collapse runs of two or more newlines into one.

    Args:
        text: Raw text to clean.
        is_md: Whether the text came from a Markdown source.

    Returns:
        The cleaned text.
    """
    normalized = text.replace('\r\n', '\n').replace('\r', '\n')
    joined = re_join_nonzh_break.sub(r'\1\2', normalized)
    for junk in ('•', ' ', '\u3000'):
        joined = joined.replace(junk, '')
    if not is_md:
        return joined
    return re.sub(r'\n{2,}', '\n', joined)

# Clean every loaded document in place. Each document object is mutated and the
# same object is appended to cleaned_docs, so all_docs and cleaned_docs share
# their elements.
cleaned_docs = []
for doc in all_docs:
    # Markdown-specific cleaning is chosen by the ".md" suffix of the document's
    # "source" metadata (compared case-insensitively).
    src = str(doc.metadata.get("source", "")).lower()
    is_md = src.endswith(".md")
    doc.page_content = clean_text(doc.page_content, is_md=is_md)
    cleaned_docs.append(doc)

print(f"Cleaning completed: {len(cleaned_docs)} documents")

In [None]:
from langchain_text_splitters import TokenTextSplitter
import tiktoken

# Tokenizer used only for reporting chunk sizes; same encoding as the splitter.
enc = tiktoken.get_encoding("cl100k_base")


def tlen(s):
    """Return the number of cl100k_base tokens in ``s``."""
    return len(enc.encode(s))


# Token-based splitter: chunks of up to 1200 tokens with 100 tokens of overlap.
splitter = TokenTextSplitter(
    encoding_name="cl100k_base",
    chunk_size=1200,
    chunk_overlap=100,
)

all_splits = []
for i, doc in enumerate(cleaned_docs, start=1):
    splits = splitter.split_documents([doc])
    all_splits.extend(splits)
    # default=0 keeps the report working when a document produces no chunks
    # (for example an empty file); max() over an empty sequence raises ValueError.
    max_tokens = max((tlen(c.page_content) for c in splits), default=0)
    print(f"Doc {i}: {len(splits)} chunks, "
          f"max tokens: {max_tokens}")

print(f"Total chunks: {len(all_splits)}")

In [None]:
import os
import requests
from typing import List
from langchain.embeddings.base import Embeddings

class SiliconFlowEmbeddings(Embeddings):
    """Embedding model wrapper for SiliconFlow API."""

    # Endpoint used when the EMBEDDING_BASE_URL environment variable is not set.
    DEFAULT_URL = "https://api.siliconflow.cn/v1/embeddings"
    # Seconds to wait for one HTTP request; without a timeout a stalled
    # connection would block indexing indefinitely.
    REQUEST_TIMEOUT = 30
    # Number of texts sent per HTTP request by embed_documents.
    BATCH_SIZE = 64

    def __init__(self, model: str = "BAAI/bge-m3", api_key: str = None):
        """
        Initialize the SiliconFlow Embeddings client.

        Args:
            model (str): The name of the embedding model. Default is "BAAI/bge-m3".
            api_key (str): API key for SiliconFlow. If not provided, it will be read from
                           the environment variable SILICONFLOW_API_KEY.

        Raises:
            ValueError: If no API key is passed and SILICONFLOW_API_KEY is not set.
        """
        self.model = model
        self.api_key = api_key or os.environ.get("SILICONFLOW_API_KEY")
        if not self.api_key:
            raise ValueError("Please set the SILICONFLOW_API_KEY environment variable or pass api_key explicitly.")

    def _embedding_request(self, inputs: List[str]) -> List[List[float]]:
        """
        Send one embeddings request for a batch of texts.

        Args:
            inputs (List[str]): Texts sent as the "input" field of the payload.

        Returns:
            List[List[float]]: One embedding per input, in response order.

        Raises:
            RuntimeError: If the response is not JSON, contains no data, or
                returns a different number of embeddings than inputs.
        """
        # Fall back to the default endpoint: passing None to requests.post
        # would fail with an unrelated "invalid URL" error.
        url = os.environ.get("EMBEDDING_BASE_URL") or self.DEFAULT_URL
        payload = {
            "model": self.model,
            "input": inputs,
        }
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        response = requests.post(
            url, json=payload, headers=headers, timeout=self.REQUEST_TIMEOUT
        )
        try:
            res_json = response.json()
        except ValueError as exc:
            raise RuntimeError(
                f"Failed to retrieve embeddings: HTTP {response.status_code}, non-JSON response body"
            ) from exc

        data = res_json.get("data") if isinstance(res_json, dict) else None
        if not data:
            raise RuntimeError(f"Failed to retrieve embeddings: {res_json}")
        # A short response would silently pair vectors with the wrong texts
        # once batches are concatenated, so reject it explicitly.
        if len(data) != len(inputs):
            raise RuntimeError(
                f"Failed to retrieve embeddings: expected {len(inputs)} items, got {len(data)}"
            )
        return [item["embedding"] for item in data]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` in batches of BATCH_SIZE and return the vectors in input order."""
        result = []
        for start in range(0, len(texts), self.BATCH_SIZE):
            result.extend(self._embedding_request(texts[start:start + self.BATCH_SIZE]))
        return result

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string by sending it as a one-element batch."""
        return self.embed_documents([text])[0]

In [None]:
from langchain_community.vectorstores import Chroma

# Embedding client; reads SILICONFLOW_API_KEY from the environment.
embedding = SiliconFlowEmbeddings()
# On-disk location of the Chroma store, relative to the working directory.
persist_directory = 'data_base/vector/chroma'

# If the Chroma directory already exists, you can choose to delete it manually
# (not handled automatically in this script).
# NOTE(review): re-running this cell against an existing directory presumably
# adds the same chunks again — confirm before re-indexing.

# Embed every chunk and build the vector store from them.
vectordb = Chroma.from_documents(
    documents=all_splits,
    embedding=embedding,
    persist_directory=persist_directory  # This allows the vector store to be persisted to disk
)

# Sanity check via the underlying collection (private attribute of the wrapper).
print(f"Number of documents stored in the vector database: {vectordb._collection.count()}")

In [None]:
# This is a small example case: a retrieval sanity check against the vector store.
question = "What should we pay attention to if we are outside during a typhoon?"
# Perform Max Marginal Relevance (MMR) search to retrieve top 3 diverse results
mmr_docs = vectordb.max_marginal_relevance_search(question, k=3)
# Print the first 200 characters of each retrieved document (numbering starts at 0)
for i, sim_doc in enumerate(mmr_docs):
    print(f"MMR result #{i}:\n{sim_doc.page_content[:200]}", end="\n--------------\n")


## RAG Performance Testing

In [None]:
# Evaluation questions. The list is empty here, so the processing loops that
# iterate over it do nothing until it is filled in.
questions = []
from langchain_openai import ChatOpenAI
# Credentials and endpoint for the chat model are read from the environment.
api_key=os.environ.get("OPENAI_API_KEY")
base_url=os.environ.get("OPENAI_BASE_URL")
# temperature=0 is used for the evaluation answers.
llm = ChatOpenAI(model_name="openai/gpt-oss-120b", temperature=0, api_key=api_key, base_url=base_url)

In [None]:
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
from langchain_core.output_parsers import StrOutputParser
from tqdm import tqdm

# Prompt template: answer questions based on context with expert-level emergency knowledge
template = """You are an experienced expert in emergency management, skilled at handling all types of public emergencies.
Based on the following context information, answer the final question:
- If the context does not contain relevant information, explicitly say "I don't know". Do not make up answers.
- Keep your response concise, accurate, and professional.

{context}
Question: {input}
"""
prompt = PromptTemplate.from_template(template)


def format_docs(docs):
    """Join the page_content of retrieved documents into one context string.

    Without this step the prompt receives the Python repr of a list of
    Document objects (metadata included) instead of plain text.
    """
    return "\n\n".join(doc.page_content for doc in docs)


# Build retriever and QA chain
retriever = vectordb.as_retriever(search_kwargs={"k": 10})
# Prompt -> LLM -> plain string; expects a dict with "context" and "input".
answer_chain = prompt | llm | StrOutputParser()
# End-to-end chain: takes a question string, retrieves, formats, and answers.
qa_chain = (
    RunnableParallel({"context": retriever | format_docs, "input": RunnablePassthrough()})
    | answer_chain
)

results = []

# Process questions and collect results
for question in tqdm(questions, desc="Processing Questions"):
    # Retrieve once and reuse the same documents for both the saved record and
    # the prompt, so the stored context is exactly what the model saw and the
    # query is not embedded twice.
    docs = retriever.invoke(question)
    full_docs = [doc.page_content for doc in docs]
    answer = answer_chain.invoke({"context": format_docs(docs), "input": question})
    results.append({
        "question": question,
        "top10_full_documents": full_docs,
        "answer": answer
    })

In [None]:
import json
import os
# Persist the collected `results` list as pretty-printed UTF-8 JSON, creating
# the eval/ directory first if it does not exist yet.
os.makedirs("eval", exist_ok=True)
with open("eval/qa_results.json", "w", encoding="utf-8") as f:
    f.write(json.dumps(results, ensure_ascii=False, indent=2))

In [None]:
import os
import asyncio
import logging
import logging.config
import json
from tqdm.asyncio import tqdm
from dotenv import load_dotenv
import nest_asyncio

from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import openai_complete_if_cache
from lightrag.llm.ollama import ollama_embed
from lightrag.utils import EmbeddingFunc, logger, set_verbose_debug
from lightrag.kg.shared_storage import initialize_pipeline_status

# Patch asyncio so an event loop can be entered while another is running
# (needed for asyncio.run() inside a notebook, which already runs a loop).
nest_asyncio.apply()
# Load .env again for this cell; override=False keeps variables already set.
load_dotenv(dotenv_path=".env", override=False)

# Directory handed to LightRAG as its working_dir.
WORKING_DIR = "./dickens"
# File that receives the per-question LightRAG results.
OUTPUT_JSON = "eval/lightrag_qa_results.json"

def configure_logging():
    """Configure console and rotating-file logging for the ``lightrag`` logger.

    Environment variables read:
        LOG_DIR: Directory for the log file (default: current working directory).
        LOG_MAX_BYTES: Size at which the log file rotates (default: 10485760).
        LOG_BACKUP_COUNT: Number of rotated files kept (default: 5).
        VERBOSE_DEBUG: ``"true"`` (case-insensitive) enables verbose debug output.
    """
    log_dir = os.getenv("LOG_DIR", os.getcwd())
    log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag_compatible_demo.log"))
    # Create the directory that will actually hold the log file. Using
    # dirname(log_dir) would create only the parent of LOG_DIR, and fails
    # outright for a bare relative name such as "logs" (dirname is "").
    os.makedirs(os.path.dirname(log_file_path), exist_ok=True)

    logging.config.dictConfig({
        "version": 1,
        "disable_existing_loggers": False,
        "formatters": {
            "default": {"format": "%(levelname)s: %(message)s"},
            "detailed": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"},
        },
        "handlers": {
            "console": {
                "formatter": "default",
                "class": "logging.StreamHandler",
                "stream": "ext://sys.stderr",
            },
            "file": {
                "formatter": "detailed",
                "class": "logging.handlers.RotatingFileHandler",
                "filename": log_file_path,
                "maxBytes": int(os.getenv("LOG_MAX_BYTES", 10485760)),
                "backupCount": int(os.getenv("LOG_BACKUP_COUNT", 5)),
                "encoding": "utf-8",
            },
        },
        "loggers": {
            "lightrag": {
                "handlers": ["console", "file"],
                "level": "INFO",
                # Do not also emit through the root logger's handlers.
                "propagate": False,
            },
        },
    })

    logger.setLevel(logging.INFO)
    set_verbose_debug(os.getenv("VERBOSE_DEBUG", "false").lower() == "true")

# Create the LightRAG working directory if needed. exist_ok=True avoids the
# check-then-create race, and makedirs also handles a nested path.
os.makedirs(WORKING_DIR, exist_ok=True)

async def llm_model_func(prompt, system_prompt=None, history_messages=None, keyword_extraction=False, **kwargs) -> str:
    """Complete ``prompt`` through ``openai_complete_if_cache``.

    Args:
        prompt: User prompt text.
        system_prompt: Optional system prompt.
        history_messages: Earlier chat messages; ``None`` means no history.
        keyword_extraction: Accepted for interface compatibility; it is not
            forwarded to the completion call.
        **kwargs: Passed through to ``openai_complete_if_cache`` unchanged.

    Returns:
        The awaited result of ``openai_complete_if_cache``.

    Environment variables read: LLM_MODEL, LLM_BINDING_API_KEY (falling back
    to OPENAI_API_KEY) and LLM_BINDING_HOST.
    """
    # A list literal as the default would be one shared object across all
    # calls; build a fresh list per call instead.
    if history_messages is None:
        history_messages = []
    return await openai_complete_if_cache(
        os.getenv("LLM_MODEL"),
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        api_key=os.getenv("LLM_BINDING_API_KEY") or os.getenv("OPENAI_API_KEY"),
        base_url=os.getenv("LLM_BINDING_HOST"),
        **kwargs,
    )

async def initialize_rag():
    """Create a LightRAG instance and initialise its storages.

    The embedding dimension and maximum token size are read from EMBEDDING_DIM
    and MAX_EMBED_TOKENS when the instance is built; the embedding model and
    host are read from EMBEDDING_MODEL and EMBEDDING_BINDING_HOST on every
    embedding call.

    Returns:
        The initialised ``LightRAG`` object.
    """
    def _embed(texts):
        # Environment lookups stay inside the callable so they are evaluated
        # on each call, not once at construction time.
        return ollama_embed(
            texts,
            embed_model=os.getenv("EMBEDDING_MODEL", "bge-m3:latest"),
            host=os.getenv("EMBEDDING_BINDING_HOST", "http://localhost:11434"),
        )

    embedder = EmbeddingFunc(
        embedding_dim=int(os.getenv("EMBEDDING_DIM", "1024")),
        max_token_size=int(os.getenv("MAX_EMBED_TOKENS", "8192")),
        func=_embed,
    )
    rag = LightRAG(
        working_dir=WORKING_DIR,
        llm_model_func=llm_model_func,
        embedding_func=embedder,
    )
    await rag.initialize_storages()
    await initialize_pipeline_status()
    return rag

def clean_context(context):
    """Collapse whitespace in a context string or list of strings.

    A string has every run of whitespace replaced by one space and its ends
    trimmed. A list is normalised item by item and the items are joined with
    single spaces into one string. Any other value is returned unchanged.
    """
    def _squash(piece):
        return ' '.join(piece.strip().split())

    if isinstance(context, str):
        return _squash(context)
    if isinstance(context, list):
        return ' '.join(map(_squash, context))
    return context

# English version of the user prompt; passed as `user_prompt` in the QueryParam
# of every LightRAG query below.
user_prompt = """You are an experienced expert in emergency management, skilled in handling all kinds of public emergencies.
Based on the following contextual information, answer the question at the end:
- If the context does not contain relevant information, explicitly respond with "I don't know" — do not make up answers.
- Use concise, accurate, and professional language in your answer.
"""

async def main():
    """Run every question through LightRAG and save the results as JSON.

    For each question two hybrid-mode queries are issued: one that returns
    only the retrieved context and one that returns the generated answer.
    A failure on one question is recorded in that question's result entry
    and processing continues with the next question.
    """
    # Create the output directory up front so a missing folder fails fast
    # instead of discarding all results after the queries have run.
    output_dir = os.path.dirname(OUTPUT_JSON)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    def _query_param(only_need_context):
        # Both queries share every setting except only_need_context.
        return QueryParam(
            mode="hybrid",
            stream=False,
            user_prompt=user_prompt,
            only_need_context=only_need_context,
            top_k=10,
        )

    results = []
    rag = await initialize_rag()
    try:
        await rag.aclear_cache()

        for question in tqdm(questions, desc="Processing Questions"):
            try:
                context_response = await rag.aquery(
                    question, param=_query_param(True)
                )
                full_response = await rag.aquery(
                    question, param=_query_param(False)
                )
                results.append({
                    "question": question,
                    "context": clean_context(context_response),
                    "llm_output": full_response,
                })
            except Exception as exc:
                # Deliberate best-effort: keep the failure in the output and
                # move on so one bad question does not abort the whole run.
                results.append({
                    "question": question,
                    "context": "",
                    "llm_output": f"[ERROR] {str(exc)}",
                })
    finally:
        # Release the storages even if something above raised.
        await rag.finalize_storages()

    with open(OUTPUT_JSON, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)

    print(f"\nAll questions processed. Results saved to: {OUTPUT_JSON}")

# Entry point: set up logging first, then run the async evaluation to completion.
if __name__ == "__main__":
    configure_logging()
    asyncio.run(main())


In [None]:
import json
import re

def clean_context_text(text):
    """
    Strip formatting artifacts and selected JSON fields from a context string.

    Removes "-----...-----" markers, Markdown code-fence markers, the
    "created_at", "file_path" and "type" string fields, commas left dangling
    before a closing brace or bracket, and blank lines. The result is trimmed.
    """
    # Markers are matched within a single line ("." does not cross newlines).
    stripped = re.sub(r"-----.*?-----", "", text)
    # Remove the "```json" form first so no stray "json" is left behind.
    stripped = stripped.replace("```json", "").replace("```", "")

    for field in ("created_at", "file_path", "type"):
        stripped = re.sub(rf'"{field}"\s*:\s*".*?"\s*,?', "", stripped)

    # Drop commas that now sit directly before a closing } or ].
    stripped = re.sub(r",\s*([}\]])", r"\1", stripped)
    # Collapse gaps of blank or whitespace-only lines into one newline.
    stripped = re.sub(r"\n\s*\n", "\n", stripped)
    return stripped.strip()

def clean_context_fields_in_place(input_file):
    """
    Clean the 'context' field of every item in a JSON file and overwrite the file.

    Items whose 'context' value is not a string are left untouched; a missing
    'context' key is treated as an empty string and written back as such.
    """
    with open(input_file, "r", encoding="utf-8") as src:
        records = json.load(src)

    for record in records:
        raw = record.get("context", "")
        if not isinstance(raw, str):
            continue
        record["context"] = clean_context_text(raw)

    with open(input_file, "w", encoding="utf-8") as dst:
        json.dump(records, dst, ensure_ascii=False, indent=2)

    print(f"'context' field cleaned and written back to: {input_file}")
# Post-process the LightRAG results file in place.
clean_context_fields_in_place("eval/lightrag_qa_results.json")