# Generate Captions

In [None]:
import os, gc, json
from pdf2image import convert_from_path
from PIL import Image
import pytesseract
from transformers import AutoProcessor, AutoModelForCausalLM, GenerationConfig
from PyPDF2 import PdfReader
import torch

# Define file paths and constants
FILE = "AI.pdf"  # Path to the 463-page PDF
output_dir = "pdf_pages"
# Separator emitted before the assistant's reply in the prompt built by
# caption_with_phi4; the page loop splits the decoded output on its last occurrence.
sep = "<|im_sep|>"

# System prompt sent with every slide image.
# NOTE(review): the prompt says the input includes "the ocr of a lecture slide",
# but the OCR step in the page loop is currently commented out, so only the
# image is sent to the model.
SYSTEM_PROMPT = """
You are an AI lecture slide analyzer. The following input is an image and the ocr of a lecture slide about “Artificial Intelligence.”
1. Extract every piece of written content:
   • Slide title
   • Sub-bullets and their full text
   • Definitions, formulas, and any inline examples
2. The summary should be as thorough and precise as possible—this will be used for later retrieval and generation.
3. The keywords should be the most relevant terms from the slide, containing at least 5 terms.
4. Organize your output as a valid JSON object with these fields:
   {
     "title": string,
     "definitions": { term: definition },
     "formulas": [ string ],
     "keywords": [ string ],
     "summary": string
   }
"""

# LLM setup
model_id = "microsoft/Phi-4-multimodal-instruct"
generation_config = GenerationConfig.from_pretrained(model_id)
generation_config.max_new_tokens = 1024
do_sample = True
temperature = 0.1
# Apply the sampling settings to the config that is actually passed to
# generate(); as bare variables they were never read by anything.
generation_config.do_sample = do_sample
generation_config.temperature = temperature

processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True, use_fast=True)
# device_map="cuda" already loads the weights onto the GPU, so no extra
# .to("cuda") call is needed.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    device_map="cuda",
    torch_dtype="auto",
    _attn_implementation="flash_attention_2",
)


def ocr_image(img: Image.Image, lang: str = "chi_tra+eng") -> str:
    """Run Tesseract OCR on a rendered page image.

    Args:
        img: The page image to read.
        lang: Tesseract language specification. Defaults to ``"chi_tra+eng"``
            (Traditional Chinese plus English), the value previously hard-coded.

    Returns:
        The recognized text as returned by ``pytesseract.image_to_string``.
    """
    return pytesseract.image_to_string(img, lang=lang)


def caption_with_phi4(img: Image.Image, system: str, user: "str | None" = None) -> str:
    """Caption one slide image with the loaded Phi-4 model.

    Args:
        img: Rendered page image, referenced in the prompt by ``<|image_1|>``.
        system: System instructions; surrounding whitespace is stripped.
        user: Optional extra text (for example OCR output) placed in a user
            turn before the image. When ``None`` (the default) no user turn is
            emitted and the image placeholder follows the system text, which
            is the original prompt layout.

    Returns:
        ``processor.decode(outputs[0], skip_special_tokens=True)``. For a causal
        LM, ``generate()`` returns the prompt tokens followed by the new ones,
        so callers locate the reply after the last ``<|im_sep|>``.

    Uses the module-level ``processor``, ``model`` and ``generation_config``
    (including whatever ``max_new_tokens`` it currently holds).
    """
    # NOTE(review): the Phi-4-multimodal-instruct model card documents a
    # "<|system|>...<|end|><|user|>...<|end|><|assistant|>" chat format; this
    # ChatML-style "<|im_start|>/<|im_sep|>" layout may not match what the model
    # was trained on — confirm. It is kept because callers split on "<|im_sep|>".
    prompt = "<|im_start|>system<|im_sep|>" + system.strip()
    if user is not None:
        # Close the system turn and put the text plus image in a user turn.
        prompt += "<|im_end|><|im_start|>user<|im_sep|>" + user.strip()
    prompt += "<|image_1|><|im_end|><|im_start|>assistant<|im_sep|>"

    inputs = processor(images=img, text=prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        # generation_config already carries max_new_tokens (and sampling settings).
        outputs = model.generate(**inputs, generation_config=generation_config)
    return processor.decode(outputs[0], skip_special_tokens=True)


# Create output directory if not exists
os.makedirs(output_dir, exist_ok=True)

# Retry policy for captioning a page. Each retry gets TOKEN_STEP more tokens
# (helps when the reply was cut off before the JSON object closed), but the
# budget is reset for every page so one difficult page cannot inflate the
# budget for all later pages.
BASE_MAX_NEW_TOKENS = generation_config.max_new_tokens
MAX_ATTEMPTS = 3
TOKEN_STEP = 512
# Fields the ChromaDB build step reads from every stripped caption.
REQUIRED_KEYS = ("title", "summary")


def _extract_caption_json(raw: str):
    """Return the caption dict parsed from model output *raw*, or ``None``.

    Only the text after the last ``sep`` (the assistant reply) is considered.
    The JSON is taken from the first ``{`` to the last ``}`` of that reply, so
    a Markdown code fence or surrounding prose does not break parsing.
    Anything that is not a JSON object containing every key in
    ``REQUIRED_KEYS`` is rejected.
    """
    idx = raw.rfind(sep)
    if idx == -1:
        return None
    reply = raw[idx + len(sep):]
    start, end = reply.find("{"), reply.rfind("}")
    if start == -1 or end < start:
        return None
    try:
        parsed = json.loads(reply[start:end + 1])
    except json.JSONDecodeError:
        return None
    if not isinstance(parsed, dict) or not all(key in parsed for key in REQUIRED_KEYS):
        return None
    return parsed


# Process each page
reader = PdfReader(FILE)
for page_num in range(1, len(reader.pages) + 1):
    # 1. Render only this page (first_page == last_page) to an image.
    images = convert_from_path(FILE, dpi=200,
                               first_page=page_num, last_page=page_num,
                               use_pdftocairo=True)
    img = images[0]

    # 2. OCR (currently disabled)
    # ocr_text = ocr_image(img)

    # 3. Generate caption & extract JSON, growing the token budget per retry.
    data = None
    raw_caption = None
    try:
        for attempt in range(MAX_ATTEMPTS):
            generation_config.max_new_tokens = BASE_MAX_NEW_TOKENS + attempt * TOKEN_STEP
            raw_caption = caption_with_phi4(img, SYSTEM_PROMPT)
            data = _extract_caption_json(raw_caption)
            if data is not None:
                break
    finally:
        # Restore the shared config for the next page and any later code.
        generation_config.max_new_tokens = BASE_MAX_NEW_TOKENS
    success = data is not None

    # 4. Save page image and raw caption
    base = f"page_{page_num:03d}"
    img.save(os.path.join(output_dir, base + ".png"))
    with open(os.path.join(output_dir, base + "_caption.json"), "w", encoding="utf-8") as f:
        f.write(raw_caption if raw_caption else "")

    # 5. Save stripped JSON or log failure
    strip_path = os.path.join(output_dir, base + "_caption_strip.json")
    if success:
        with open(strip_path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        print(f"✅  page_{page_num:03d} processed")
    else:
        print(f"❌  page_{page_num:03d} failed to strip JSON after retries")
        print(f"Raw caption: {raw_caption}")

    # 6. cleanup
    del img, images
    gc.collect()


# Save captions into ChromaDB

In [None]:
# File 1: build_chromadb.py
import os
import json
from langchain.schema import Document
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_chroma import Chroma

# Configuration
OUTPUT_DIR = "pdf_pages"
DB_PATH = "./chroma_db_task2"
EMBEDDINGS = "all-MiniLM-L6-v2"

# Load stripped JSON captions into Documents, one per page.
docs = []
ids = []
for fname in sorted(os.listdir(OUTPUT_DIR)):
    if not fname.endswith("_caption_strip.json"):
        continue
    page = int(fname.split("_")[1])  # expects page_###_caption_strip.json
    path = os.path.join(OUTPUT_DIR, fname)
    with open(path, encoding="utf-8") as f:
        data = json.load(f)
    text = f"Page {page}\nTitle: {data['title']}\n\nSummary:\n{data['summary']}"
    docs.append(Document(page_content=text, metadata={"page": page}))
    # Deterministic per-page ID: re-running this cell overwrites the existing
    # entry instead of adding a duplicate to the persisted collection.
    ids.append(f"page_{page:03d}")

if not docs:
    raise RuntimeError(f"No *_caption_strip.json files found in {OUTPUT_DIR}")

# Build the Chroma vector store. With persist_directory set, langchain_chroma
# writes to disk automatically; its Chroma class has no persist() method.
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDINGS)
vectordb = Chroma.from_documents(
    documents=docs,
    embedding=embeddings,
    ids=ids,
    persist_directory=DB_PATH,
)
print(f"Persisted {len(docs)} documents into {DB_PATH}")


# Run queries

In [None]:
# File 2: run_queries.py
import os
import pandas as pd
from langchain_chroma import Chroma
from transformers import pipeline
from langchain.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA

from langchain_community.embeddings import HuggingFaceEmbeddings

# Configuration
DB_PATH      = "./chroma_db_task2"
QUERIES_CSV  = "queries.csv"
SUBMISSION_CSV = "submission.csv"
MODEL_ID     = "microsoft/Phi-4-multimodal-instruct"
# Must be the same embedding model used to build the store; otherwise query
# vectors are not comparable with the indexed page vectors.
EMBEDDINGS   = "all-MiniLM-L6-v2"

# 1. load persisted Chroma vector store with the indexing embedding function
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDINGS)
vectordb  = Chroma(persist_directory=DB_PATH, embedding_function=embeddings)
retriever = vectordb.as_retriever(search_kwargs={"k": 5})

# 2. set up LLM pipeline for generation
hf_pipe = pipeline(
    "text-generation",
    model=MODEL_ID,
    trust_remote_code=True,
    device_map="cuda",
    torch_dtype="auto",
    return_full_text=False
)
llm = HuggingFacePipeline(pipeline=hf_pipe)

# 3. build RetrievalQA chain
qa = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=retriever,
    return_source_documents=True
)

# 4. read queries CSV
df      = pd.read_csv(QUERIES_CSV)
results = []

# 5. process each query and collect only ID + page of the top-ranked source.
# NOTE(review): the generated answer (output["result"]) is never used; if only
# the page is needed, calling retriever.invoke(question) directly would avoid
# running the LLM for every query.
for _, row in df.iterrows():
    qid      = row['ID']
    question = row['Question']
    # invoke() replaces the deprecated Chain.__call__ interface.
    output   = qa.invoke({"query": question})
    src_docs = output['source_documents']
    page     = src_docs[0].metadata.get('page') if src_docs else None

    results.append({
        'ID':     qid,
        'Answer': page
    })

# 6. save submission.csv
# Explicit columns keep the header even when there are no queries.
submission_df = pd.DataFrame(results, columns=['ID', 'Answer'])
# Nullable integer dtype: without it a single missing page (None) turns the
# whole column into float64 and every page is written as e.g. "12.0".
submission_df['Answer'] = submission_df['Answer'].astype('Int64')
submission_df.to_csv(SUBMISSION_CSV, index=False, encoding='utf-8-sig')
print(f"Wrote submission to {SUBMISSION_CSV}")
