RAG-LangChain llm

Simplified + stabilized version (Vision + PDF RAG)

In [12]:
# ============================================================
# Ollama Vision + RAG PDF integrated pipeline
# ============================================================
import os, base64, fitz, requests
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.embeddings import OllamaEmbeddings
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_core.runnables import Runnable

# ------------------------------
# Environment settings
# ------------------------------
# Ollama generation endpoint used by the vision helper.
OLLAMA_URL = "http://localhost:11434/api/generate"
# PDF that is analysed by the pipeline.
PDF_PATH = "/data1/workspace/pdfs/5.pdf"
# Directory that receives the images extracted from the PDF.
IMG_DIR = "/data1/workspace/pdf_images"

# Make sure the image directory exists; an existing directory is not an error.
os.makedirs(IMG_DIR, exist_ok=True)

# ============================================================
# 1. PDF image extraction
# ============================================================
# Write every image referenced by each page to IMG_DIR as
# page<P>_<I>.<ext>, where P and I are 1-based.  Opening the document in a
# `with` block closes it even if extraction or a file write raises part-way
# through (the original only closed it on the success path).
with fitz.open(PDF_PATH) as doc:
    for page_index, page in enumerate(doc):
        for i, img in enumerate(page.get_images(full=True)):
            # The first item of each get_images() entry is handed to
            # extract_image(), which returns the raw bytes and extension.
            base_img = doc.extract_image(img[0])
            img_path = os.path.join(
                IMG_DIR, f"page{page_index+1}_{i+1}.{base_img['ext']}"
            )
            with open(img_path, "wb") as f:
                f.write(base_img["image"])

# ============================================================
# 2. Analyze the first image with the vision model
# ============================================================
def analyze_image_with_ollama(image_path, model="llama3.2-vision"):
    """Ask an Ollama model to summarize the figure stored at *image_path*.

    The file is read, base64-encoded and POSTed to ``OLLAMA_URL`` as a
    non-streaming request with the prompt "Summarize this figure.".

    Args:
        image_path: Path of the image file to send.
        model: Name of the Ollama model to query.

    Returns:
        The ``"response"`` field of the JSON reply, or ``""`` when that
        field is absent.  Any exception raised while reading the file or
        talking to the server is caught and returned as a string that
        starts with ``"[Vision "`` instead of being propagated.
    """
    try:
        with open(image_path, "rb") as image_file:
            raw_bytes = image_file.read()
        encoded_image = base64.b64encode(raw_bytes).decode("utf-8")

        request_body = {
            "model": model,
            "prompt": "Summarize this figure.",
            "images": [encoded_image],
            "stream": False,
        }
        reply = requests.post(OLLAMA_URL, json=request_body, timeout=180)
        reply.raise_for_status()
        return reply.json().get("response", "")
    except Exception as err:
        # Best effort: report the failure as text so the pipeline keeps going.
        return f"[Vision Ïò§Î•ò]: {err}"

def _page_image_order(filename):
    """Sort key placing 'page<P>_<I>.<ext>' files in numeric (page, index) order.

    A plain string sort puts 'page10_1.png' before 'page1_1.png', so the
    "first" image would not be the first one in the document.  Names that do
    not follow the pattern sort after all recognised ones, alphabetically.
    """
    stem = os.path.splitext(filename)[0]
    if stem.startswith("page"):
        page_part, _, index_part = stem[len("page"):].partition("_")
        if page_part.isdecimal() and index_part.isdecimal():
            return (0, int(page_part), int(index_part), filename)
    return (1, 0, 0, filename)


# Compare the real file extension (with its dot, case-insensitively) rather
# than a bare name suffix such as "jpg".
_IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}
images = sorted(
    (f for f in os.listdir(IMG_DIR) if os.path.splitext(f)[1].lower() in _IMAGE_EXTENSIONS),
    key=_page_image_order,
)
# Only the first image (lowest page, then lowest index) is summarized; with
# no usable image the summary is the empty string.
vision_summary = analyze_image_with_ollama(os.path.join(IMG_DIR, images[0])) if images else ""
# Debug aid: print("\n[Vision summary]\n", vision_summary[:500], "\n")

# ============================================================
# 3. Load the PDF text and split it into chunks
# ============================================================
loader = PyMuPDFLoader(PDF_PATH)
docs = loader.load()

# Chunks of size 500 with an overlap of 50 between neighbouring chunks.
splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,
    chunk_overlap=50,
)
split_docs = splitter.split_documents(docs)

# ============================================================
# 4. Embeddings & vector store
# ============================================================
# Embed the chunks with the local Ollama server and index them in FAISS.
embeddings = OllamaEmbeddings(
    base_url="http://localhost:11434",
    model="nomic-embed-text",
)
vectorstore = FAISS.from_documents(split_docs, embeddings)

# Retriever with the vector store's default search settings.
retriever = vectorstore.as_retriever()

# ============================================================
# 5. LLM (Runnable)
# ============================================================
class OllamaRunnable(Runnable):
    """Minimal ``Runnable`` that forwards a prompt to Ollama's generate API."""

    def __init__(self, model="llama3.2-vision", base_url="http://localhost:11434"):
        """Store the model name and the base URL of the Ollama server."""
        self.model = model
        self.base_url = base_url

    def invoke(self, input, *args, **kwargs):
        """Send *input* to ``<base_url>/api/generate`` and return the reply text.

        Args:
            input: The prompt.  Objects exposing ``to_string()`` are converted
                with it; anything else goes through ``str()``.
            *args: Accepted and ignored.
            **kwargs: Accepted and ignored.

        Returns:
            The ``"response"`` field of the JSON reply, or ``""`` when it is
            absent.  Request failures are caught and returned as a string
            starting with ``"[LLM "`` instead of being raised.
        """
        if hasattr(input, "to_string"):
            prompt_text = input.to_string()
        else:
            prompt_text = str(input)

        try:
            endpoint = f"{self.base_url}/api/generate"
            request_body = {"model": self.model, "prompt": prompt_text, "stream": False}
            reply = requests.post(endpoint, json=request_body, timeout=300)
            reply.raise_for_status()
            return reply.json().get("response", "")
        except Exception as err:
            # Best effort: surface the failure as the answer text.
            return f"[LLM Ïò§Î•ò]: {err}"


llm = OllamaRunnable("llama3.2-vision")

# ============================================================
# 6. RAG + custom prompt
# ============================================================
# The template needs a placeholder for every variable it is meant to use.
# The original declared "question" and "vision_context" as inputs but only
# contained {context}, so neither the question nor the figure summary ever
# reached the model.
# NOTE(review): the instructions ask for "only the single most relevant" drug
# and also for "at least 1 and at most 3" names; confirm which is intended.
prompt_template = """
You are a biomedical text analysis assistant.

Extract **only the single most relevant experimental drug** that was actually tested or administered in the study.
Prefer drugs mentioned in 'Results', 'Methods' sections, Figures and Tables.
Exclude drugs mentioned only in background, references, or literature.

==== Document Excerpt Start ====
{context}
==== Document Excerpt End ====

==== Figure Summary Start ====
{vision_context}
==== Figure Summary End ====

Task: {question}

Guidelines:
- Prefer drugs explicitly described as being *tested*, *treated*, *administered*, or *used* in the experiments.
- Exclude drugs mentioned only in background, discussion, or references.
- Ignore drugs that appear as examples, related compounds, or comparative mentions unless they were actually used.
- Merge WordPiece fragments into full drug names.
- Remove duplicates.
- Extract **at least 1** and **at most 3** drug names.
- do not extract up to 3 drug names, extract maximum 3 drug names.
- Output only the extracted drug names, separated by semicolons (;), with no extra text or explanation.
Answer: 
 """

PROMPT = PromptTemplate(
    template=prompt_template,
    # Filled by the QA chain: retrieved chunks and the user's query.
    input_variables=["context", "question"],
    # RetrievalQA forwards only the documents and the question to the prompt,
    # so the figure summary is bound here instead of at call time.
    partial_variables={
        "vision_context": vision_summary or "(no figure summary available)",
    },
)

# Retrieval QA chain of type "stuff" that answers with PROMPT; source
# documents are not returned and verbose logging is off.
chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=retriever,
    return_source_documents=False,
    chain_type_kwargs={"prompt": PROMPT},
    verbose=False,
)

# ============================================================
# 7. Run
# ============================================================
question = "Extract generic drug names from this paper."
# Calling the chain object directly (Chain.__call__) is deprecated in
# LangChain; invoke() takes the same input dict and returns the same output
# dict with the answer under "result".
# NOTE(review): RetrievalQA appears to read only "query" from this dict; the
# "vision_context" entry is passed unchanged from the original call.
response = chain.invoke({
    "query": question,
    "vision_context": vision_summary
})

print("\n==============================")
print("üíä [ÏµúÏ¢Ö Í≤∞Í≥º: Drug Extraction]")
print("==============================")
print(response['result'])
print("==============================\n")


KeyboardInterrupt: 