diff --git a/ai/generative-ai-service/complex-document-rag/.gitignore b/ai/generative-ai-service/complex-document-rag/.gitignore new file mode 100644 index 000000000..01b77f130 --- /dev/null +++ b/ai/generative-ai-service/complex-document-rag/.gitignore @@ -0,0 +1,26 @@ +# macOS system files +.DS_Store + +.env +# Python cache +**/__pycache__/ + +# Virtual environments +venv/ + +# Local config +config.py + +# Data folders +data/ +embeddings/ +charts/ +reports/ + +# Logs +*.log +logs/ + +# Text files (except requirements.txt) +*.txt +!requirements.txt \ No newline at end of file diff --git a/ai/generative-ai-service/complex-document-rag/README.md b/ai/generative-ai-service/complex-document-rag/README.md index dc0a5c3d7..4120deea5 100644 --- a/ai/generative-ai-service/complex-document-rag/README.md +++ b/ai/generative-ai-service/complex-document-rag/README.md @@ -2,7 +2,7 @@ An enterprise-grade Retrieval-Augmented Generation (RAG) system for generating comprehensive business reports from multiple document sources using Oracle Cloud Infrastructure (OCI) Generative AI services. -Reviewed date: 22.09.2025 +Reviewed date: 03.11.2025 ## Features @@ -14,7 +14,7 @@ Reviewed date: 22.09.2025 - **Citation Tracking**: Source attribution with references - **Multi-Language Support**: Generate reports in English, Arabic, Spanish, and French - **Visual Analytics**: Automatic chart and table generation from data - +![Application screenshot](files/images/screenshot1.png) ## Prerequisites - Python 3.11+ diff --git a/ai/generative-ai-service/complex-document-rag/files/agents/agent_factory.py b/ai/generative-ai-service/complex-document-rag/files/agents/agent_factory.py index 9987f0113..e4332cab9 100644 --- a/ai/generative-ai-service/complex-document-rag/files/agents/agent_factory.py +++ b/ai/generative-ai-service/complex-document-rag/files/agents/agent_factory.py @@ -445,38 +445,45 @@ def _process_batch(self, batch: List[Dict[str, Any]]) -> List[str]: prompt = "\n".join(prompt_parts) self.log_prompt(prompt, f"ChunkRewriter (Batch of {len(batch)})") - response = self.llm.invoke([DummyMessage(prompt)]) - - # Handle different LLM response styles - if hasattr(response, "content"): - text = response.content.strip() - elif isinstance(response, list) and isinstance(response[0], dict): - text = response[0].get("generated_text") or response[0].get("text") - if not text: - raise ValueError("⚠️ No valid 'generated_text' found in response.") - text = text.strip() - else: - raise TypeError(f"⚠️ Unexpected response type: {type(response)} — {response}") - - self.log_response(text, f"ChunkRewriter (Batch of {len(batch)})") - rewritten_chunks = self._parse_batch_response(text, len(batch)) - rewritten_chunks = [self._clean_chunk_text(chunk) for chunk in rewritten_chunks] - - # Enhanced logging with side-by-side comparison - paired = list(zip(batch, rewritten_chunks)) - for i, (original_chunk, rewritten_text) in enumerate(paired, 1): - # Get the actual raw chunk text, not the metadata - original_text = original_chunk.get("text", "") - metadata = original_chunk.get("metadata", {}) - - # Use demo logger for visual comparison if available - if DEMO_MODE and hasattr(logger, 'chunk_comparison'): - # Pass the actual chunk text, not metadata - logger.chunk_comparison(original_text, rewritten_text, metadata) + try: + response = self.llm.invoke([DummyMessage(prompt)]) + + # Handle different LLM response styles + if hasattr(response, "content"): + text = response.content.strip() + elif isinstance(response, list) and isinstance(response[0], 
dict): + text = response[0].get("generated_text") or response[0].get("text") + if not text: + raise ValueError("⚠️ No valid 'generated_text' found in response.") + text = text.strip() else: - logger.info(f"⚙ Rewritten Chunk {i}:\n{rewritten_text}\nMetadata: {json.dumps(metadata, indent=2)}\n") - - return rewritten_chunks + raise TypeError(f"⚠️ Unexpected response type: {type(response)} — {response}") + + self.log_response(text, f"ChunkRewriter (Batch of {len(batch)})") + rewritten_chunks = self._parse_batch_response(text, len(batch)) + rewritten_chunks = [self._clean_chunk_text(chunk) for chunk in rewritten_chunks] + + # Enhanced logging with side-by-side comparison + paired = list(zip(batch, rewritten_chunks)) + for i, (original_chunk, rewritten_text) in enumerate(paired, 1): + # Get the actual raw chunk text, not the metadata + original_text = original_chunk.get("text", "") + metadata = original_chunk.get("metadata", {}) + + # Use demo logger for visual comparison if available + if DEMO_MODE and hasattr(logger, 'chunk_comparison'): + # Pass the actual chunk text, not metadata + logger.chunk_comparison(original_text, rewritten_text, metadata) + else: + logger.info(f"⚙ Rewritten Chunk {i}:\n{rewritten_text}\nMetadata: {json.dumps(metadata, indent=2)}\n") + + return rewritten_chunks + + except Exception as e: + # Handle timeout and other errors gracefully + logger.error(f"❌ Batch processing failed: {e}") + # Return None for each chunk to indicate failure (not empty strings!) + return [None] * len(batch) def _parse_batch_response(self, response_text: str, expected_chunks: int) -> List[str]: @@ -581,8 +588,7 @@ def _detect_comparison_query(self, query: str) -> bool: """Use LLM to detect whether the query involves a comparison.""" prompt = f""" Does the query below involve a **side-by-side comparison between two or more named entities such as companies, organizations, or products**? - -Exclude comparisons to frameworks (e.g., CSRD, ESRS), legal standards, or regulations — those do not count. +Include comparisons to frameworks (e.g., CSRD, ESRS), legal standards, or regulations. Query: "{query}" @@ -641,228 +647,206 @@ def extract_first_json_list(text): return re.findall(r'"([^"]+)"', text) def _extract_entities(self, query: str) -> List[str]: - """Use LLM to extract entity names, then normalize + dedupe.""" - prompt = f""" -Extract company/organization names mentioned in the query and return a CLEANED JSON list. + """Prefer exact vector-store tags typed by the user; LLM only as fallback.""" + import re + logger = getattr(self, "logger", None) or __import__("logging").getLogger(__name__) -CLEANING RULES (apply to each name before returning): -- Lowercase everything. -- Remove legal suffixes at the end: plc, ltd, inc, llc, lp, l.p., corp, corporation, co., co, s.a., s.a.s., ag, gmbh, bv, nv, oy, ab, sa, spa, pte, pvt, pty, srl, sro, k.k., kk, kabushiki kaisha. -- Remove punctuation except internal ampersands (&). Collapse multiple spaces. -- No duplicates. + # --- 0) known tag set from your vector store (lowercased) --- + # Populate this once at init: self.known_tags = {id.lower() for id in vector_store_ids()} + known = getattr(self, "known_tags", None) -CONSTRAINTS: -- Return ONLY a JSON list of strings, e.g. ["aelwyn","elinexa"] -- No prose, no keys, no explanations. -- Do not include standards, clause numbers, sectors, or generic words like "entity". -- If none are present, return []. 
+ tagged = [] -Examples: -Query: "Compare Aelwyn vs Elinexa PLC policies" -Return: ["aelwyn","elinexa"] + # A) Existing FY/Q pattern (kept) + tagged += [m.group(0) for m in re.finditer( + r"\b[A-Za-z][A-Za-z0-9\-]*_(?:FY|Q[1-4])\d{2,4}\b", query, flags=re.I + )] -Query: "Barclays (UK) and JPMorgan Chase & Co." -Return: ["barclays","jpmorgan chase & co"] + # B) NEW: generic "_" e.g., "mof_2022", "mof_2024" + tagged += [m.group(0) for m in re.finditer( + r"\b[A-Za-z][A-Za-z0-9\-]*_\d{2,4}\b", query + )] -Query: "What are Microsoft’s 2030 targets?" -Return: ["microsoft"] + # C) (Optional but useful) quoted tokens like "mof_2022" + tagged += [m.group(1) for m in re.finditer( + r'"([A-Za-z0-9][A-Za-z0-9_\-]{1,80})"', query + )] -Query: "No company here" -Return: [] + # De-dup preserve order (case-insensitive) + seen = set() + tagged_unique: List[str] = [] + for t in tagged: + k = t.lower() + if k not in seen: + # If we know the store IDs, only keep those that exist + if not known or k in known: + seen.add(k) + tagged_unique.append(t) + + # --- Early return: if user typed valid tags, trust them verbatim --- + if tagged_unique: + if logger: + logger.info(f"[Entity Extractor] Exact tags: {tagged_unique}") + return tagged_unique + + # --- Fallback: your original LLM extraction (unchanged) --- + prompt = f""" + Extract company/organization names mentioned in the query and return a CLEANED JSON list. -Now process this query: + CLEANING RULES (apply to each name before returning): + - Lowercase everything. + - Remove legal suffixes at the end: plc, ltd, inc, llc, lp, l.p., corp, corporation, co., co, s.a., s.a.s., ag, gmbh, bv, nv, oy, ab, sa, spa, pte, pvt, pty, srl, sro, k.k., kk, kabushiki kaisha. + - Remove punctuation except internal ampersands (&). Collapse multiple spaces. + - No duplicates. -{query} -""" + CONSTRAINTS: + - Return ONLY a JSON list of strings, e.g. ["aelwyn","elinexa"] + - No prose, no keys, no explanations. + - Do not include standards, clause numbers, sectors, or generic words like "entity". + - If none are present, return []. + + Now process this query: + + {query} + """ try: raw = self.llm(prompt).strip() - print(raw) entities = self.extract_first_json_list(raw) - # Keep strings only and strip whitespace entities = [e.strip() for e in entities if isinstance(e, str) and e.strip()] - # Deduplicate while preserving order - seen = set() - cleaned: List[str] = [] + final: List[str] = [] + seen2 = set() + for e in entities: - if e.lower() not in seen: - seen.add(e.lower()) - cleaned.append(e) + k = e.lower() + if (not known or k in known) and k not in seen2: + seen2.add(k) + final.append(e) - if not cleaned: - logger.warning(f"[Entity Extractor] No plausible entities extracted from LLM output: {entities}") + if not final and logger: + logger.warning(f"[Entity Extractor] No plausible entities extracted. LLM: {entities} | tags: []") - logger.info(f"[Entity Extractor] Raw: {raw} | Cleaned: {cleaned}") - return cleaned + if logger: + logger.info(f"[Entity Extractor] Raw: {raw} | Tags: [] | Final: {final}") + return final except Exception as e: - logger.warning(f"⚠️ Failed to robustly extract entities via LLM: {e}") + if logger: + logger.warning(f"⚠️ Failed to robustly extract entities via LLM: {e}") return [] + def plan( - self, - query: str, - context: List[Dict[str, Any]] | None = None, - is_comparison_report: bool = False - ) -> tuple[list[Dict[str, Any]], list[str], bool]: - """ - Strategic planner that returns structured topics with steps. 
- Supports both comparison and single-entity analysis with consistent output format. + self, + query: str, + context: List[Dict[str, Any]] | None = None, + is_comparison_report: bool = False, + comparison_mode: str | None = None, # kept for compatibility, not used to hardcode content + provided_entities: Optional[List[str]] = None + ) -> tuple[list[Dict[str, Any]], list[str], bool]: """ - raw = None - is_comparison = self._detect_comparison_query(query) or is_comparison_report - entities = self._extract_entities(query) - logger.info(f"[Planner] Detected entities: {entities} | Comparison task: {is_comparison}") - - if is_comparison and len(entities) < 2: - logger.warning(f"⚠️ Comparison task detected but only {len(entities)} entity found: {entities}") - is_comparison = False # fallback to single-entity mode + PROMPT-DRIVEN PLANNER + - Derive section topics from the user's TASK PROMPT (not hardcoded). + - For each topic, emit one mirrored retrieval step per entity. + - Output shape: List[{"topic": str, "steps": List[str]}], plus (entities, is_comparison). - ctx = "\n".join(f"{i+1}. {c['content']}" for i, c in enumerate(context or [])) - - if is_comparison: - template = """ - You are a strategic planning agent generating grouped research steps for a comparative analysis report. + Returns: + (plan, entities, is_comparison) + """ - TASK: {query} + # 1) Determine comparison intent and entities (keep your existing logic) + is_comparison = self._detect_comparison_query(query) or is_comparison_report - OBJECTIVE: - Break the task into high-level comparison **topics**. For each topic, generate **two steps** — one per entity. + if provided_entities: + entities = [e for e in provided_entities if isinstance(e, str) and e.strip()] + logger.info(f"[Planner] Using provided entities: {entities}") + else: + entities = self._extract_entities(query) + logger.info(f"[Planner] Detected entities: {entities} | Comparison task: {is_comparison}") - RULES: - - Keep topic titles focused and distinct (e.g., "Scope 1 Emissions") - - Use a consistent step format: "Find (something) for (Entity)" - - Use only these entities: {entities} + # If comparison requested but <2 entities, degrade gracefully to single-entity mode + if is_comparison and len(entities) < 2: + logger.warning(f"⚠️ Comparison requested but only {len(entities)} entity found: {entities}. Falling back to single-entity.") + is_comparison = False + # 2) Ask the LLM ONLY for topics (strings), not full objects — we’ll build steps ourselves + # This avoids fragile JSON with missing "topic" keys. + topic_prompt = f""" +Extract the main section topics from the TASK PROMPT. +Use the user's own headings/bullets/order when present. +If none are explicit, infer 5–10 concise, non-overlapping topics that reflect the user's request. - EXAMPLE: - [ - {{ - "topic": "Net-Zero Targets", - "steps": [ - "Find net-zero targets for Company-A", - "Find net-zero targets for Company-B" - ] - }} - ] +TASK PROMPT: +{query} - TASK: {query} +Return ONLY a JSON array of strings, e.g. ["Executive Summary","Revenue Analysis","Profitability"]. +No prose, no keys, no markdown. 
+""" + self.log_prompt(topic_prompt, "Planner: Topic Extraction") + raw_topics = None + topics: list[str] = [] + try: + raw_topics = self.llm(topic_prompt).strip() + json_str = UniversalJSONCleaner.clean_and_extract_json(raw_topics, expected_type="array") + parsed = UniversalJSONCleaner.parse_with_validation(json_str, expected_structure=None) + if isinstance(parsed, list): + # Keep only non-empty strings + topics = [str(t).strip() for t in parsed if isinstance(t, (str, int, float)) and str(t).strip()] + except Exception as e: + logger.error(f"❌ Topic extraction failed: {e}") + logger.debug(f"Raw topic response:\n{raw_topics}") + + # 2b) Hard fallback: if still empty, derive topics from obvious headings in the query + if not topics: + # Grab capitalized/bulleted lines as headings + lines = [ln.strip() for ln in (query or "").splitlines()] + bullets = [ln.lstrip("-*• ").strip() for ln in lines if ln.strip().startswith(("-", "*", "•"))] + caps = [ln for ln in lines if ln and ln == ln.title() and len(ln.split()) <= 8] + candidates = bullets or caps + if candidates: + topics = [t for t in candidates if len(t) >= 3][:10] + + # 2c) Ultimate fallback: generic buckets (kept minimal, not domain-specific) + if not topics: + topics = [ + "Executive Summary", + "Key Metrics", + "Section 1", + "Section 2", + "Section 3", + "Risks & Considerations", + "Conclusion" + ] - ENTITIES: {entities} - Respond ONLY with valid JSON. - Use standard double quotes (") for all JSON keys and string values. - You MAY and SHOULD use single quotes (') *inside* string values for possessives (e.g., "CEO's"). - Do NOT use curly or smart quotes. - Do NOT write `"CEO"s"`, only `"CEO's"`. - """ - else: - if not entities: - logger.warning("⚠️ No entity found in query — using fallback") - entities = ["The Company"] - template = """ - You are a planning agent decomposing a task for a single entity into structured research topics. - -TASK: {query} - -OBJECTIVE: -Break this into 3–10 key topics. Under each topic, include 1–2 retrieval-friendly steps. - -RULES: -- Keep topics distinct and concrete (e.g., Carbon Disclosure) -- Use only these entities: {entities} -- Use a consistent step format: "Find (something) for (Entity)" - -EXAMPLE: -[ -{{ - "topic": "Carbon Disclosure for Company-A", - "steps": [ - "Find 2023 Scope 1 and 2 emissions for Company-A" - ] -}}, -{{ - "topic": "Company-A Diversity Strategy", - "steps": [ - "Analyze gender and ethnicity diversity at Company-A" - ] -}} -] -Respond ONLY with valid JSON. -Do NOT use possessive forms (e.g., do NOT write "Aelwyn's Impact"). Instead, write "Impact for Aelwyn" or "Impact of Aelwyn". -Use the format: "Find (something) for (Entity)" -Do NOT use curly or smart quotes. 
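+        # Expected plan shape built below (illustrative; topic and entity names are placeholders):
+        # [
+        #   {"topic": "Revenue Analysis",
+        #    "steps": ["Find all items requested under 'Revenue Analysis' for aelwyn",
+        #              "Find all items requested under 'Revenue Analysis' for elinexa"]},
+        #   ...
+        # ]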
+ # 3) Build plan objects and MIRROR steps across entities (no hardcoded content) + plan: list[dict] = [] + for t in topics: + t_clean = str(t).strip() + if not t_clean: + continue - """ + if is_comparison and len(entities) >= 2: + # One retrieval step per entity — mirrored wording + steps = [f"Find all items requested under '{t_clean}' for {entities[0]}", + f"Find all items requested under '{t_clean}' for {entities[1]}"] + else: + # Single entity (or unknown) + e0 = entities[0] if entities else "The Entity" + steps = [f"Find all items requested under '{t_clean}' for {e0}"] - messages = ChatPromptTemplate.from_template(template).format_messages( - query=query, - context=ctx, - entities=entities - ) - full_prompt = "\n".join(str(m.content) for m in messages) - self.log_prompt(full_prompt, "Planner") + plan.append({"topic": t_clean, "steps": steps}) + # 4) Log and return try: - raw = self.llm.invoke(messages).content.strip() - self.log_response(raw, "Planner") - cleaned = UniversalJSONCleaner.clean_and_extract_json(raw, expected_type="array") + self.log_response(json.dumps(plan, ensure_ascii=False, indent=2), "Planner: Plan (topics→steps)") + except Exception: + pass - plan = UniversalJSONCleaner.parse_with_validation( - cleaned, expected_structure="Array of objects with 'topic' and 'steps' keys" - ) + return plan, entities, is_comparison - if not isinstance(plan, list): - raise ValueError("Parsed plan is not a list") - - for section in plan: - if not isinstance(section, dict): - raise ValueError("Section is not a dict") - if "topic" not in section or "steps" not in section: - raise ValueError("Missing 'topic' or 'steps'") - if not isinstance(section["topic"], str): - raise ValueError("Topic must be a string") - if not isinstance(section["steps"], list): - raise ValueError("Steps must be a list") - if not all(isinstance(s, str) for s in section["steps"]): - raise ValueError("Each step must be a string") - - # Optional: Validate entity inclusion if this was a comparison task - if is_comparison and entities: - for section in plan: - step_text = " ".join(section["steps"]).lower() - for entity in entities: - if entity.lower() not in step_text: - logger.warning( - f"⚠️ Entity '{entity}' not found in steps for topic: '{section['topic']}'" - ) - - return plan, entities, is_comparison - except Exception as e: - logger.error(f"❌ Failed to parse planner output: {e}") - logger.error(f"Raw response:\n{raw}") - # Attempt a minimal prompt instead of hardcoded fallback - try: - fallback_prompt = f""" - Return a JSON list of 5 objects like this: - [{{ - "topic": "X and Y", - "steps": ["Find X for The Company", "Analyze Y for The Company"] - }}] - TASK: {query} - Respond with valid JSON - """ - raw_fallback = self.llm(fallback_prompt).strip() - cleaned_fallback = UniversalJSONCleaner.clean_and_extract_json(raw_fallback) - fallback_plan = UniversalJSONCleaner.parse_with_validation( - cleaned_fallback, expected_structure="Array of objects with 'topic' and 'steps' keys" - ) - return fallback_plan, entities, is_comparison - except Exception as inner_e: - logger.error(f"🛑 Fallback planner also failed: {inner_e}") - raise RuntimeError("Both planner and fallback planner failed") from inner_e class ResearchAgent(Agent): diff --git a/ai/generative-ai-service/complex-document-rag/files/agents/report_writer_agent.py b/ai/generative-ai-service/complex-document-rag/files/agents/report_writer_agent.py index f11ddde9e..302f1577e 100644 --- a/ai/generative-ai-service/complex-document-rag/files/agents/report_writer_agent.py 
+++ b/ai/generative-ai-service/complex-document-rag/files/agents/report_writer_agent.py @@ -5,63 +5,226 @@ import uuid import logging import datetime -import matplotlib.pyplot as plt - import math +import re +from docx.oxml.shared import OxmlElement +from docx.text.run import Run logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) os.makedirs("charts", exist_ok=True) + + +_MD_TOKEN_RE = re.compile(r'(\*\*.*?\*\*|__.*?__|\*.*?\*|_.*?_)') + +def add_inline_markdown_paragraph(doc, text: str): + """ + Creates a paragraph and renders lightweight inline Markdown: + **bold** or __bold__ → bold run + *italic* or _italic_ → italic run + Everything else is plain text. No links/lists/code handling. + """ + p = doc.add_paragraph() + i = 0 + for m in _MD_TOKEN_RE.finditer(text): + # leading text + if m.start() > i: + p.add_run(text[i:m.start()]) + token = m.group(0) + # strip the markers + if token.startswith('**') or token.startswith('__'): + content = token[2:-2] + run = p.add_run(content) + run.bold = True + else: + content = token[1:-1] + run = p.add_run(content) + run.italic = True + i = m.end() + # trailing text + if i < len(text): + p.add_run(text[i:]) + return p + def add_table(doc, table_data): - """Create a professionally styled Word table from list of dicts.""" + """Create a Word table from list of dicts or list of lists, robustly.""" if not table_data: return - + headers = [] - seen = set() - for row in table_data: - for k in row.keys(): - if k not in seen: - headers.append(k) - seen.add(k) - - # Create table with proper styling + rows_normalized = [] + + # Case 1: list of dicts + if isinstance(table_data[0], dict): + seen = set() + for row in table_data: + for k in row.keys(): + if k not in seen: + headers.append(k) + seen.add(k) + rows_normalized = table_data + + # Case 2: list of lists + elif isinstance(table_data[0], (list, tuple)): + max_len = max(len(row) for row in table_data) + headers = [f"Col {i+1}" for i in range(max_len)] + for row in table_data: + rows_normalized.append({headers[i]: row[i] if i < len(row) else "" + for i in range(max_len)}) + + else: + headers = ["Value"] + rows_normalized = [{"Value": str(row)} for row in table_data] + table = doc.add_table(rows=1, cols=len(headers)) table.style = 'Table Grid' - - # Style header row + header_row = table.rows[0] for i, h in enumerate(headers): cell = header_row.cells[i] cell.text = str(h) - # Make header bold for paragraph in cell.paragraphs: for run in paragraph.runs: run.bold = True - # Add data rows - for row in table_data: + for row in rows_normalized: row_cells = table.add_row().cells for i, h in enumerate(headers): row_cells[i].text = str(row.get(h, "")) +def _color_for_label(label: str, entities: list[str] | tuple[str, ...] | None, + base="#a9bbbc", e1="#437c94", e2="#c74634") -> str: + """Pick a bar color based on whether a label mentions one of the entities.""" + if not entities: + return base + lbl = label.lower() + ents = [e for e in entities if isinstance(e, str)] + if len(ents) >= 1 and ents[0].lower() in lbl: + return e1 + if len(ents) >= 2 and ents[1].lower() in lbl: + return e2 + return base + + +def detect_units(chart_data: dict, title: str = "") -> str: + """Detect units of measure from chart data and title.""" + # Common patterns for currency + currency_patterns = [ + (r'\$|USD|usd|dollar', 'USD'), + (r'€|EUR|eur|euro', 'EUR'), + (r'£|GBP|gbp|pound', 'GBP'), + (r'¥|JPY|jpy|yen', 'JPY'), + (r'₹|INR|inr|rupee', 'INR'), + ] + + # Common patterns for other units - order matters! 
+ unit_patterns = [ + (r'million|millions|mn|mln|\$m|\$M', 'Million'), + (r'billion|billions|bn|bln|\$b|\$B', 'Billion'), + (r'thousand|thousands|k|\$k', 'Thousand'), + (r'percentage|percent|%', '%'), + (r'tonnes|tons|tonne|ton', 'Tonnes'), + (r'co2e|CO2e|co2|CO2', 'CO2e'), + (r'kwh|kWh|KWH', 'kWh'), + (r'mwh|MWh|MWH', 'MWh'), + (r'kg|kilogram|kilograms', 'kg'), + (r'employees|headcount|people', 'Employees'), + (r'days|day', 'Days'), + (r'hours|hour|hrs', 'Hours'), + (r'years|year|yrs', 'Years'), + ] + + # Check title and keys for units - also check values if they're strings + combined_text = title.lower() + " " + " ".join(str(k).lower() for k in chart_data.keys()) + # Also check string values which might contain unit info + for v in chart_data.values(): + if isinstance(v, str): + combined_text += " " + v.lower() + + detected_currency = None + detected_scale = None + detected_unit = None + + # Check for currency + for pattern, unit in currency_patterns: + if re.search(pattern, combined_text, re.IGNORECASE): + detected_currency = unit + break + + # Check for scale (million, billion, etc.) + for pattern, unit in unit_patterns[:4]: # First 4 are scales + if re.search(pattern, combined_text, re.IGNORECASE): + detected_scale = unit + break + + # Check for other units + for pattern, unit in unit_patterns[4:]: # Rest are units + if re.search(pattern, combined_text, re.IGNORECASE): + detected_unit = unit + break + + # Combine detected elements + if detected_currency and detected_scale: + return f"{detected_scale} {detected_currency}" + elif detected_currency: + # If we detect currency but no scale, look for financial context clues + if 'revenue' in combined_text or 'sales' in combined_text or 'income' in combined_text: + # Financial data without explicit scale often means millions + if 'fy' in combined_text or 'fiscal' in combined_text or 'quarterly' in combined_text: + return "Million USD" # Corporate financials are typically in millions + return detected_currency + return detected_currency + elif detected_unit: + if detected_scale and detected_unit not in ['%', 'Employees', 'Days', 'Hours', 'Years']: + return f"{detected_scale} {detected_unit}" + return detected_unit + elif detected_scale: + # If we only have scale (like "Million") without currency, check for financial context + if any(term in combined_text for term in ['revenue', 'cost', 'profit', 'income', 'sales', 'expense', 'financial']): + return f"{detected_scale} USD" + return detected_scale + + # For financial metrics without explicit units, default to "Million USD" + if any(term in combined_text for term in ['revenue', 'sales', 'profit', 'income', 'cost', 'expense', 'financial', 'fiscal', 'fy20']): + return "Million USD" + + return "Value" # Default fallback + + +def format_value_with_units(value: float, units: str) -> str: + """Format a value with appropriate precision based on units.""" + if '%' in units: + return f"{value:.1f}%" + elif 'Million' in units or 'Billion' in units: + return f"{value:,.1f}" + elif value >= 1000: + return f"{value:,.0f}" + else: + return f"{value:.1f}" + + +def make_chart(chart_data: dict, title: str = "", + entities: list[str] | tuple[str, ...] | None = None, + units: str | None = None) -> str | None: + """Generate a chart with conditional formatting and fallback for list values. + If `entities` contains up to two names, bars whose labels include those names + are highlighted in two distinct colors. Otherwise a default color is used. + Units are detected automatically or can be passed explicitly. 
+ """ -def make_chart(chart_data: dict, title: str = "") -> str | None: - """Generate a chart with conditional formatting and fallback for list values.""" - import numpy as np import textwrap os.makedirs("charts", exist_ok=True) clean = {} for k, v in chart_data.items(): - # NEW: Reduce lists to latest entry if all elements are numeric + # Reduce lists to latest numeric entry if isinstance(v, list): if all(isinstance(i, (int, float)) for i in v): - v = v[-1] # use the latest value + v = v[-1] else: continue @@ -78,47 +241,56 @@ def make_chart(chart_data: dict, title: str = "") -> str | None: labels = list(clean.keys()) values = list(clean.values()) + + # Detect units if not provided + if not units: + units = detect_units(chart_data, title) + + # Update title to include units if not already present + if units and units != "Value" and units.lower() not in title.lower(): + title = f"{title} ({units})" - # Decide chart orientation based on label length and count - create more variety + # Decide orientation max_label_length = max(len(label) for label in labels) if labels else 0 - - # More nuanced decision for chart orientation - if len(clean) > 12: # Many items -> horizontal + if len(clean) > 12: horizontal = True - elif max_label_length > 40: # Very long labels -> horizontal + elif max_label_length > 40: horizontal = True - elif len(clean) <= 4 and max_label_length <= 20: # Few items, short labels -> vertical + elif len(clean) <= 4 and max_label_length <= 20: horizontal = False - elif len(clean) <= 6 and max_label_length <= 30: # Medium items, medium labels -> vertical + elif len(clean) <= 6 and max_label_length <= 30: horizontal = False - else: # Default to horizontal for edge cases + else: horizontal = True - fig, ax = plt.subplots(figsize=(12, 8)) # Increased figure size for better readability + fig, ax = plt.subplots(figsize=(12, 8)) if horizontal: - # Wrap long labels for horizontal charts wrapped_labels = ['\n'.join(textwrap.wrap(label, width=40)) for label in labels] - bars = ax.barh(wrapped_labels, values, color=["#2e7d32" if "aelwyn" in l.lower() else "#f9a825" if "elinexa" in l.lower() else "#4472C4" for l in labels]) - ax.set_xlabel("Value") + colors = [_color_for_label(l, entities) for l in labels] + bars = ax.barh(wrapped_labels, values, color=colors) + ax.set_xlabel(units) # Use detected units instead of "Value" ax.set_ylabel("Category") for bar in bars: width = bar.get_width() - ax.annotate(f"{width:.1f}", xy=(width, bar.get_y() + bar.get_height() / 2), xytext=(5, 0), - textcoords="offset points", ha='left', va='center', fontsize=8) + formatted_value = format_value_with_units(width, units) + ax.annotate(formatted_value, xy=(width, bar.get_y() + bar.get_height() / 2), + xytext=(5, 0), textcoords="offset points", + ha='left', va='center', fontsize=8) else: - # Wrap long labels for vertical charts wrapped_labels = ['\n'.join(textwrap.wrap(label, width=15)) for label in labels] - bars = ax.bar(range(len(labels)), values, color=["#2e7d32" if "aelwyn" in l.lower() else "#f9a825" if "elinexa" in l.lower() else "#4472C4" for l in labels]) - ax.set_ylabel("Value") + colors = [_color_for_label(l, entities) for l in labels] + bars = ax.bar(range(len(labels)), values, color=colors) + ax.set_ylabel(units) # Use detected units instead of "Value" ax.set_xlabel("Category") ax.set_xticks(range(len(labels))) ax.set_xticklabels(wrapped_labels, ha='center', va='top') - for bar in bars: height = bar.get_height() - ax.annotate(f"{height:.1f}", xy=(bar.get_x() + bar.get_width() / 2, height), 
xytext=(0, 5), - textcoords="offset points", ha='center', va='bottom', fontsize=8) + formatted_value = format_value_with_units(height, units) + ax.annotate(formatted_value, xy=(bar.get_x() + bar.get_width() / 2, height), + xytext=(0, 5), textcoords="offset points", + ha='center', va='bottom', fontsize=8) ax.set_title(title[:100]) ax.grid(axis="y" if not horizontal else "x", linestyle="--", alpha=0.6) @@ -126,21 +298,18 @@ def make_chart(chart_data: dict, title: str = "") -> str | None: filename = f"chart_{uuid.uuid4().hex}.png" path = os.path.join("charts", filename) - fig.savefig(path, dpi=300, bbox_inches='tight') # Higher DPI and tight bbox for better quality + fig.savefig(path, dpi=300, bbox_inches='tight') plt.close(fig) return path - - def append_to_doc(doc, section_data: dict, level: int = 2, citation_map: dict | None = None): """Append section to document with heading, paragraph, table, chart, and citations.""" heading = section_data.get("heading", "Untitled Section") - # Use the level parameter to control heading hierarchy doc.add_heading(heading, level=level) text = section_data.get("text", "").strip() - + # Add citations to the text if sources are available if text and citation_map and section_data.get("sources"): citation_numbers = [] @@ -148,14 +317,13 @@ def append_to_doc(doc, section_data: dict, level: int = 2, citation_map: dict | source_key = f"{source.get('file', 'Unknown')}_{source.get('sheet', '')}_{source.get('entity', '')}" if source_key in citation_map: citation_numbers.append(citation_map[source_key]) - if citation_numbers: - # Add unique citation numbers at the end of the text unique_citations = sorted(set(citation_numbers)) citations_str = " " + "".join([f"[{num}]" for num in unique_citations]) text = text + citations_str - + if text: + add_inline_markdown_paragraph(doc, text) doc.add_paragraph(text) table_data = section_data.get("table", []) @@ -176,17 +344,23 @@ def append_to_doc(doc, section_data: dict, level: int = 2, citation_map: dict | else: flattened_chart_data[k] = v - chart_path = make_chart(flattened_chart_data, title=heading) + # Pass dynamic entities (if present) so colors match those names + entities = section_data.get("entities") + # Pass units if available in section data + units = section_data.get("units") + chart_path = make_chart(flattened_chart_data, title=heading, entities=entities, units=units) if chart_path: doc.add_picture(chart_path, width=Inches(6)) last_paragraph = doc.paragraphs[-1] last_paragraph.alignment = 1 # center + def save_doc(doc, filename: str = "_report.docx"): """Save the Word document.""" doc.save(filename) logger.info(f"✅ Report saved: {filename}") + class SectionWriterAgent: def __init__(self, llm, tokenizer=None): self.llm = llm @@ -197,34 +371,26 @@ def __init__(self, llm, tokenizer=None): print("⚠️ No tokenizer provided for SectionWriterAgent") def estimate_tokens(self, text: str) -> int: - # naive estimate: 1 token ≈ 4 characters for English-like text return max(1, len(text) // 4) def log_token_count(self, text: str, tokenizer=None, label: str = "Prompt"): if not text: print(f"⚠️ Cannot log tokens: empty text for {label}") return - if tokenizer: token_count = len(tokenizer.encode(text)) else: token_count = self.estimate_tokens(text) - print(f"{label} token count: {token_count}") - - - def write_section(self, section_title: str, context_chunks: list[dict]) -> dict: from collections import defaultdict - # Group chunks by entity and preserve metadata grouped = defaultdict(list) grouped_metadata = defaultdict(list) for 
chunk in context_chunks: entity = chunk.get("_search_entity", "Unknown") grouped[entity].append(chunk.get("content", "")) - # Preserve metadata for citations metadata = chunk.get("metadata", {}) grouped_metadata[entity].append(metadata) @@ -240,12 +406,15 @@ def write_section(self, section_title: str, context_chunks: list[dict]) -> dict: "text": f"Insufficient data for analysis. Entities: {entities}", "table": [], "chart_data": {}, - "sources": [] + "sources": [], + # propagate for downstream report logic + "is_comparison": False, + "entities": entities } def _write_single_entity_section(self, section_title: str, grouped_chunks: dict, entity: str, grouped_metadata: dict | None = None) -> dict: text = "\n\n".join(grouped_chunks[entity]) - + # Extract unique sources from metadata sources = [] if grouped_metadata and entity in grouped_metadata: @@ -260,7 +429,6 @@ def _write_single_entity_section(self, section_title: str, grouped_chunks: dict, }) seen_sources.add(source_key) - # OPTIMIZED: Shorter, more focused prompt for faster processing prompt = f"""Extract key data for {entity} on {section_title}. Return JSON: @@ -269,8 +437,12 @@ def _write_single_entity_section(self, section_title: str, grouped_chunks: dict, Data: {text[:2000]} -CRITICAL: Never use possessive forms (no apostrophes). Instead of "manager's approval" write "manager approval" or "approval from manager". Use "N/A" for missing data. Valid JSON only.""" - +CRITICAL RULES: +1. NEVER use possessive forms or apostrophes (no 's). + - Wrong: "Oracle's revenue", "company's growth" + - Right: "Oracle revenue", "company growth", "revenue of Oracle" +2. Use "N/A" for missing data. +3. Return valid JSON only - no apostrophes in text values.""" try: self.log_token_count(prompt, self.tokenizer, label=f"SingleEntity Prompt ({section_title})") @@ -288,14 +460,16 @@ def _write_single_entity_section(self, section_title: str, grouped_chunks: dict, chart_data = parsed.get("chart_data", {}) if isinstance(chart_data, str): try: - chart_data = ast.literal_eval(chart_data) + import ast as _ast + chart_data = _ast.literal_eval(chart_data) except Exception: chart_data = {} table = parsed.get("table", []) if isinstance(table, str): try: - table = ast.literal_eval(table) + import ast as _ast + table = _ast.literal_eval(table) except Exception: table = [] @@ -304,7 +478,10 @@ def _write_single_entity_section(self, section_title: str, grouped_chunks: dict, "text": parsed.get("text", ""), "table": table, "chart_data": chart_data, - "sources": sources + "sources": sources, + # NEW: carry entity info so charts/titles can highlight correctly + "is_comparison": False, + "entities": [entity] } except Exception as e: @@ -314,7 +491,9 @@ def _write_single_entity_section(self, section_title: str, grouped_chunks: dict, "text": f"Could not generate section due to error: {e}", "table": [], "chart_data": {}, - "sources": sources + "sources": sources, + "is_comparison": False, + "entities": [entity] } def _write_comparison_section(self, section_title: str, grouped_chunks: dict, entities: list[str], grouped_metadata: dict | None = None) -> dict: @@ -328,39 +507,43 @@ def _write_comparison_section(self, section_title: str, grouped_chunks: dict, en text_a = "\n\n".join(grouped_chunks[entity_a]) text_b = "\n\n".join(grouped_chunks[entity_b]) - # Construct prompt prompt = f""" - You are writing a structured section for a comparison report between {entity_a} and {entity_b}. +You are writing a structured section for a comparison report between {entity_a} and {entity_b}. 
- Topic: {section_title} +Topic: {section_title} - OBJECTIVE: - Summarize key data from the context and produce a clear, side-by-side comparison table. +OBJECTIVE: +Summarize key data from the context and produce a clear, side-by-side comparison table. - Always follow this exact structure in your JSON output: - - heading: A short, descriptive title for the section - - text: A 1–2 sentence overview comparing {entity_a} and {entity_b} - - table: List of dicts formatted as: Metric | {entity_a} | {entity_b} | Analysis - - chart_data: A dictionary of comparable numeric values to plot +Always follow this exact structure in your JSON output: +- heading: A short, descriptive title for the section +- text: A 1–2 sentence overview comparing {entity_a} and {entity_b} +- table: List of dicts formatted as: Metric | {entity_a} | {entity_b} | Analysis +- chart_data: A dictionary of comparable numeric values to plot - DATA: - === {entity_a} === - {text_a} +DATA: +=== {entity_a} === +{text_a} - === {entity_b} === - {text_b} +=== {entity_b} === +{text_b} - INSTRUCTIONS: - - Extract specific metrics (numbers, %, dates) from the data - - Use "N/A" if one entity is missing a value - - Use analysis terms like: "Higher", "Lower", "Similar", "{entity_a} Only", "{entity_b} Only" - - Do not echo file names or metadata - - Keep values human-readable (e.g., "18,500 tonnes CO2e") - - CRITICAL: Never use possessive forms (no apostrophes). Instead of "company's target" write "company target" or "target for company". +INSTRUCTIONS: +- Extract specific metrics (numbers, %, dates) from the data +- Use "N/A" if one entity is missing a value +- Use analysis terms like: "Higher", "Lower", "Similar", "{entity_a} only", "{entity_b} only" +- Do not echo file names or metadata +- Keep values human-readable (e.g., "18,500 tonnes CO2e") - Respond only in JSON format. - """ +CRITICAL RULES: +1. NEVER use possessive forms or apostrophes (no 's). + - Wrong: "Oracle's revenue", "company's performance" + - Right: "Oracle revenue", "company performance", "revenue of Oracle" +2. Ensure all JSON is valid - no apostrophes in text values. +3. Use proper escaping if quotes are needed in text. + +Respond only in valid JSON format. 
+""" try: if self.tokenizer: @@ -379,7 +562,6 @@ def _write_comparison_section(self, section_title: str, grouped_chunks: dict, en expected_structure="Object with 'heading', 'text', 'table', and 'chart_data' keys" ) - # Chart data cleanup chart_data = parsed.get("chart_data", {}) if isinstance(chart_data, str): try: @@ -390,7 +572,6 @@ def _write_comparison_section(self, section_title: str, grouped_chunks: dict, en if not isinstance(chart_data, dict): chart_data = {} - # Table cleanup table = parsed.get("table", []) if isinstance(table, str): try: @@ -415,7 +596,6 @@ def _write_comparison_section(self, section_title: str, grouped_chunks: dict, en if validated_row[entity_a] != "N/A" or validated_row[entity_b] != "N/A": validated.append(validated_row) - # Flatten chart_data if nested flat_chart_data = {} for k, v in chart_data.items(): if isinstance(v, dict): @@ -424,7 +604,7 @@ def _write_comparison_section(self, section_title: str, grouped_chunks: dict, en else: flat_chart_data[k] = v - # Extract unique sources from metadata + # Extract unique sources sources = [] if grouped_metadata: seen_sources = set() @@ -445,12 +625,14 @@ def _write_comparison_section(self, section_title: str, grouped_chunks: dict, en "text": parsed.get("text", ""), "table": validated, "chart_data": flat_chart_data, - "sources": sources + "sources": sources, + # NEW: signal comparison + entities for downstream styling and charts + "is_comparison": True, + "entities": [entity_a, entity_b] } except Exception as e: logger.error("⚠️ Failed to write comparison section: %s", e) - # Still try to extract sources sources = [] if grouped_metadata: seen_sources = set() @@ -465,39 +647,36 @@ def _write_comparison_section(self, section_title: str, grouped_chunks: dict, en "entity": entity }) seen_sources.add(source_key) - + return { "heading": section_title, "text": f"Could not generate summary due to error: {e}", "table": [], "chart_data": {}, - "sources": sources + "sources": sources, + "is_comparison": True, + "entities": entities } - class ReportWriterAgent: def __init__(self, doc=None, model_name: str = "unknown", llm=None): - # Don't store the document - create fresh one for each report self.model_name = model_name self.llm = llm # Store LLM for generating summaries def _generate_executive_summary(self, sections: list[dict], is_comparison: bool, entities: list[str], target_language: str = "english", query: str | None = None) -> str: - """Generate an executive summary based on actual section content and user query""" if not self.llm: return self._generate_intro_section(is_comparison, entities) - - # Extract key information from sections + section_summaries = [] for section in sections: heading = section.get("heading", "Unknown Section") text = section.get("text", "") if text: - section_summaries.append(f"**{heading}**: {text}") - + section_summaries.append(f"{heading}: {text}") + sections_text = "\n\n".join(section_summaries) - - # Add language instruction if not English + language_instruction = "" if target_language == "arabic": language_instruction = "\n\nIMPORTANT: Write the entire executive summary in Arabic (العربية). Use professional Arabic business terminology." @@ -505,12 +684,9 @@ def _generate_executive_summary(self, sections: list[dict], is_comparison: bool, language_instruction = "\n\nIMPORTANT: Write the entire executive summary in Spanish. Use professional Spanish business terminology." elif target_language == "french": language_instruction = "\n\nIMPORTANT: Write the entire executive summary in French. 
Use professional French business terminology." - - # Include user query context if available - query_context = "" - if query: - query_context = f"\nUser's Original Request:\n{query}\n" - + + query_context = f"\nUser's Original Request:\n{query}\n" if query else "" + if is_comparison: prompt = f""" You are writing an executive summary for a comparison report between {entities[0]} and {entities[1]}. @@ -523,6 +699,8 @@ def _generate_executive_summary(self, sections: list[dict], is_comparison: bool, Section Summaries: {sections_text} +CRITICAL: Never use possessive forms (no apostrophes). Write "Oracle revenue" not "Oracle's revenue", "company performance" not "company's performance". + Write in a professional, analytical tone. Focus on answering the user's specific request.{language_instruction} """ else: @@ -537,9 +715,11 @@ def _generate_executive_summary(self, sections: list[dict], is_comparison: bool, Section Summaries: {sections_text} +CRITICAL: Never use possessive forms (no apostrophes). Write "Oracle revenue" not "Oracle's revenue", "company performance" not "company's performance". + Write in a professional, analytical tone. Focus on answering the user's specific request.{language_instruction} """ - + try: response = self.llm.invoke([type("Msg", (object,), {"content": prompt})()]).content.strip() return response @@ -548,31 +728,27 @@ def _generate_executive_summary(self, sections: list[dict], is_comparison: bool, return self._generate_intro_section(is_comparison, entities) def _generate_conclusion(self, sections: list[dict], is_comparison: bool, entities: list[str], target_language: str = "english", query: str | None = None) -> str: - """Generate a conclusion based on actual section content and user query""" if not self.llm: return "This analysis provides insights based on available data from retrieved documents." - - # Extract key findings from sections + key_findings = [] for section in sections: heading = section.get("heading", "Unknown Section") text = section.get("text", "") table = section.get("table", []) - - # Extract key metrics from tables + if table and isinstance(table, list): - for row in table[:3]: # Top 3 rows + for row in table[:3]: if isinstance(row, dict): metric = row.get("Metric", "") if metric: key_findings.append(f"{heading}: {metric}") - + if text: key_findings.append(f"{heading}: {text}") - - findings_text = "\n".join(key_findings[:8]) # Limit to prevent token overflow - - # Add language instruction if not English + + findings_text = "\n".join(key_findings[:8]) + language_instruction = "" if target_language == "arabic": language_instruction = "\n\nIMPORTANT: Write the entire conclusion in Arabic (العربية). Use professional Arabic business terminology." @@ -580,12 +756,9 @@ def _generate_conclusion(self, sections: list[dict], is_comparison: bool, entiti language_instruction = "\n\nIMPORTANT: Write the entire conclusion in Spanish. Use professional Spanish business terminology." elif target_language == "french": language_instruction = "\n\nIMPORTANT: Write the entire conclusion in French. Use professional French business terminology." - - # Include user query context if available - query_context = "" - if query: - query_context = f"\nUser's Original Request:\n{query}\n" - + + query_context = f"\nUser's Original Request:\n{query}\n" if query else "" + if is_comparison: prompt = f""" Based on the analysis of {entities[0]} and {entities[1]}, write a conclusion that directly answers the user's request. 
@@ -599,6 +772,8 @@ def _generate_conclusion(self, sections: list[dict], is_comparison: bool, entiti - Provide actionable insights based on their specific needs - Include specific recommendations if appropriate +CRITICAL: Never use possessive forms (no apostrophes). Write "Oracle revenue" not "Oracle's revenue", "company growth" not "company's growth". + Focus on providing value for the user's specific use case.{language_instruction} """ else: @@ -614,97 +789,77 @@ def _generate_conclusion(self, sections: list[dict], is_comparison: bool, entiti - Provide actionable insights based on their specific needs - Include specific recommendations if appropriate +CRITICAL: Never use possessive forms (no apostrophes). Write "Oracle revenue" not "Oracle's revenue", "company growth" not "company's growth". + Focus on providing value for the user's specific use case.{language_instruction} """ - + try: response = self.llm.invoke([type("Msg", (object,), {"content": prompt})()]).content.strip() return response except Exception as e: logger.warning(f"Failed to generate conclusion: {e}") return "This analysis provides insights based on available data from retrieved documents." - + def _filter_failed_sections(self, sections: list[dict]) -> list[dict]: - """Filter out sections that contain error messages or failed processing""" filtered_sections = [] - + error_patterns = [ + "Could not generate", + "due to error:", + "Expecting ',' delimiter:", + "Failed to", + "Error:", + "Exception:", + "Traceback" + ] for section in sections: text = section.get("text", "") heading = section.get("heading", "") - - # Check for common error patterns - error_patterns = [ - "Could not generate", - "due to error:", - "Expecting ',' delimiter:", - "Failed to", - "Error:", - "Exception:", - "Traceback" - ] - - # Check if section contains error messages has_error = any(pattern in text for pattern in error_patterns) - if not has_error: filtered_sections.append(section) else: logger.info(f"🚫 Filtered out failed section: {heading}") - return filtered_sections - + def _apply_document_styling(self, doc): - """Apply professional styling to the document""" from docx.shared import Pt, RGBColor - from docx.enum.text import WD_ALIGN_PARAGRAPH - - # Set default font for the document style = doc.styles['Normal'] font = style.font font.name = 'Times New Roman' font.size = Pt(12) - - # Style headings heading1_style = doc.styles['Heading 1'] heading1_style.font.name = 'Times New Roman' heading1_style.font.size = Pt(18) heading1_style.font.bold = True - heading1_style.font.color.rgb = RGBColor(0x00, 0x00, 0x00) # Black - + heading1_style.font.color.rgb = RGBColor(0x00, 0x00, 0x00) heading2_style = doc.styles['Heading 2'] heading2_style.font.name = 'Times New Roman' heading2_style.font.size = Pt(14) heading2_style.font.bold = True - heading2_style.font.color.rgb = RGBColor(0x00, 0x00, 0x00) # Black - + heading2_style.font.color.rgb = RGBColor(0x00, 0x00, 0x00) + def _generate_report_title(self, is_comparison: bool, entities: list[str], query: str | None, sections: list[dict]) -> str: - """Generate a dynamic, informative report title based on user query""" if query and self.llm: - # Use LLM to generate a more specific title based on the query try: entity_context = f"{entities[0]} vs {entities[1]}" if is_comparison and len(entities) >= 2 else entities[0] if entities else "Organization" - prompt = f"""Generate a concise, professional report title (max 10 words) based on: User Query: {query} Entities: {entity_context} Type: {'Comparison' if 
is_comparison else 'Analysis'} Report +CRITICAL: Never use possessive forms (no apostrophes). Write "Oracle Performance" not "Oracle's Performance". + Return ONLY the title, no quotes or extra text.""" - title = self.llm.invoke([type("Msg", (object,), {"content": prompt})()]).content.strip() - # Clean up the title title = title.replace('"', '').replace("'", '').strip() - # Ensure it's not too long if len(title) > 100: title = title[:97] + "..." return title except Exception as e: logger.warning(f"Failed to generate dynamic title: {e}") - # Fall back to default title generation - - # Default title generation logic + if query: - # Extract key topics from the query query_lower = query.lower() if "esg" in query_lower or "sustainability" in query_lower: topic_type = "ESG & Sustainability" @@ -719,7 +874,6 @@ def _generate_report_title(self, is_comparison: bool, entities: list[str], query else: topic_type = "Business Analysis" else: - # Infer from section headings section_topics = [s.get("heading", "") for s in sections[:3]] if any("climate" in h.lower() or "carbon" in h.lower() for h in section_topics): topic_type = "Climate & Environmental" @@ -727,161 +881,123 @@ def _generate_report_title(self, is_comparison: bool, entities: list[str], query topic_type = "ESG & Sustainability" else: topic_type = "Business Analysis" - + if is_comparison and len(entities) >= 2: return f"{topic_type} Report: {entities[0]} vs {entities[1]}" elif entities: return f"{topic_type} Report: {entities[0]}" else: return f"{topic_type} Report" - + def _add_report_header(self, doc, report_title: str, is_comparison: bool, entities: list[str]): - """Add a professional report header with title, date, and metadata""" from docx.shared import Pt, RGBColor from docx.enum.text import WD_ALIGN_PARAGRAPH - - # Main title + title_paragraph = doc.add_heading(report_title, level=1) title_paragraph.alignment = WD_ALIGN_PARAGRAPH.CENTER - - # Add subtitle with entity information + if is_comparison and len(entities) >= 2: subtitle = f"Comparative Analysis: {entities[0]} and {entities[1]}" elif entities: subtitle = f"Analysis of {entities[0]}" else: subtitle = "Comprehensive Analysis Report" - + subtitle_paragraph = doc.add_paragraph(subtitle) subtitle_paragraph.alignment = WD_ALIGN_PARAGRAPH.CENTER subtitle_run = subtitle_paragraph.runs[0] subtitle_run.font.size = Pt(12) subtitle_run.italic = True - - # Add generation date and metadata + now = datetime.datetime.now() date_str = now.strftime("%B %d, %Y") time_str = now.strftime("%H:%M") - - doc.add_paragraph() # spacing - - # Create a professional metadata section + + doc.add_paragraph() metadata_paragraph = doc.add_paragraph() metadata_paragraph.alignment = WD_ALIGN_PARAGRAPH.CENTER - metadata_text = f"Generated on {date_str} at {time_str}\nPowered by OCI Generative AI" metadata_run = metadata_paragraph.add_run(metadata_text) metadata_run.font.size = Pt(10) - metadata_run.font.color.rgb = RGBColor(0x70, 0x70, 0x70) # Gray color - - # Add separator line + metadata_run.font.color.rgb = RGBColor(0x70, 0x70, 0x70) + doc.add_paragraph() separator = doc.add_paragraph("─" * 50) separator.alignment = WD_ALIGN_PARAGRAPH.CENTER separator_run = separator.runs[0] separator_run.font.color.rgb = RGBColor(0x70, 0x70, 0x70) - - doc.add_paragraph() # spacing after header - + doc.add_paragraph() + def _detect_target_language(self, query: str | None) -> str: - """Detect the target language from the query""" if not query: return "english" - - query_lower = query.lower() - - # Arabic language indicators + q 
= query.lower() arabic_indicators = [ - "بالعربية", "باللغة العربية", "in arabic", "arabic report", "تقرير", + "بالعربية", "باللغة العربية", "in arabic", "arabic report", "تقرير", "تحليل", "باللغة العربيه", "عربي", "arabic language" ] - - # Check for Arabic script arabic_chars = any('\u0600' <= char <= '\u06FF' for char in query) - - # Check for explicit language requests - if any(indicator in query_lower for indicator in arabic_indicators) or arabic_chars: + if any(ind in q for ind in arabic_indicators) or arabic_chars: return "arabic" - - # Add more languages as needed - if "en español" in query_lower or "in spanish" in query_lower: + if "en español" in q or "in spanish" in q: return "spanish" - - if "en français" in query_lower or "in french" in query_lower: + if "en français" in q or "in french" in q: return "french" - return "english" - + def _ensure_language_consistency(self, sections: list[dict], target_language: str, query: str | None) -> list[dict]: - """Ensure all sections are in the target language""" if not self.llm or target_language == "english": return sections - logger.info(f"🔄 Ensuring language consistency for {target_language}") - corrected_sections = [] - for section in sections: corrected_section = section.copy() - - # Check and translate heading if needed heading = section.get("heading", "") + text = section.get("text", "") + table = section.get("table", []) + if heading and not self._is_in_target_language(heading, target_language): corrected_section["heading"] = self._translate_text(heading, target_language, "section heading") - - # Check and translate text if needed - text = section.get("text", "") if text and not self._is_in_target_language(text, target_language): corrected_section["text"] = self._translate_text(text, target_language, "section text") - - # Handle table translations - table = section.get("table", []) + if table and isinstance(table, list): corrected_table = [] for row in table: if isinstance(row, dict): corrected_row = {} for key, value in row.items(): - # Translate table headers and values - translated_key = self._translate_text(str(key), target_language, "table header") if not self._is_in_target_language(str(key), target_language) else str(key) - translated_value = self._translate_text(str(value), target_language, "table value") if not self._is_in_target_language(str(value), target_language) and not str(value).replace('.', '').replace(',', '').isdigit() else str(value) + k = str(key) + v = str(value) + translated_key = self._translate_text(k, target_language, "table header") if not self._is_in_target_language(k, target_language) else k + # keep numeric strings unchanged + if not self._is_in_target_language(v, target_language) and not v.replace('.', '').replace(',', '').isdigit(): + translated_value = self._translate_text(v, target_language, "table value") + else: + translated_value = v corrected_row[translated_key] = translated_value corrected_table.append(corrected_row) corrected_section["table"] = corrected_table - + corrected_sections.append(corrected_section) - return corrected_sections - + def _is_in_target_language(self, text: str, target_language: str) -> bool: - """Check if text is already in the target language""" if not text or target_language == "english": return True - if target_language == "arabic": - # Check if text contains Arabic characters arabic_chars = sum(1 for char in text if '\u0600' <= char <= '\u06FF') total_chars = sum(1 for char in text if char.isalpha()) if total_chars == 0: - return True # No alphabetic characters, 
assume it's fine - return arabic_chars / total_chars > 0.3 # At least 30% Arabic characters - - # Add more language detection logic as needed - return True # Default to assuming it's correct - + return True + return arabic_chars / total_chars > 0.3 + return True + def _translate_text(self, text: str, target_language: str, context: str = "") -> str: - """Translate text to target language using LLM""" if not text or not self.llm: return text - - language_names = { - "arabic": "Arabic", - "spanish": "Spanish", - "french": "French" - } - + language_names = {"arabic": "Arabic", "spanish": "Spanish", "french": "French"} target_lang_name = language_names.get(target_language, target_language.title()) - prompt = f"""Translate the following {context} to {target_lang_name}. Maintain the professional tone and technical accuracy. If it's already in {target_lang_name}, return it unchanged. @@ -889,7 +1005,6 @@ def _translate_text(self, text: str, target_language: str, context: str = "") -> Text to translate: {text} Translation:""" - try: response = self.llm.invoke([type("Msg", (object,), {"content": prompt})()]).content.strip() logger.info(f"Translated {context}: '{text[:50]}...' → '{response[:50]}...'") @@ -897,33 +1012,25 @@ def _translate_text(self, text: str, target_language: str, context: str = "") -> except Exception as e: logger.warning(f"Failed to translate {context}: {e}") return text - + def _generate_intro_section(self, is_comparison: bool, entities: list[str]) -> str: - """Fallback intro section when LLM is not available""" if is_comparison: - comparison_note = ( - f"This report compares data between {entities[0]} and {entities[1]} across key topics." - ) + comparison_note = f"This report compares data between {entities[0]} and {entities[1]} across key topics." else: comparison_note = f"This report presents information for {entities[0]}." - return ( f"{comparison_note} All data is sourced from retrieved documents and structured using LLM-based analysis.\n\n" "The analysis includes tables and charts where possible. Missing data is noted explicitly." ) - + def _organize_sections_with_llm(self, sections: list[dict], query: str | None, entities: list[str]) -> list[dict]: - """Use LLM to intelligently organize sections into a hierarchical structure""" if not query or not self.llm or not sections: return sections - - # Create a list of section titles section_info = [] for i, section in enumerate(sections): section_info.append(f"{i+1}. {section.get('heading', 'Untitled Section')}") - sections_list = "\n".join(section_info) - + prompt = f"""You are organizing sections for a report about {', '.join(entities)}. 
User's Original Request: @@ -942,7 +1049,7 @@ def _organize_sections_with_llm(self, sections: list[dict], query: str | None, e {{ "title": "Main Category Title from User's Request", "level": 1, - "sections": [1, 3, 5] // section numbers that belong under this category + "sections": [1, 3, 5] }}, {{ "title": "Another Main Category", @@ -950,7 +1057,7 @@ def _organize_sections_with_llm(self, sections: list[dict], query: str | None, e "sections": [2, 4, 6] }} ], - "orphan_sections": [7, 8] // sections that don't fit under any main category + "orphan_sections": [7, 8] }} IMPORTANT: @@ -964,39 +1071,29 @@ def _organize_sections_with_llm(self, sections: list[dict], query: str | None, e try: response = self.llm.invoke([type("Msg", (object,), {"content": prompt})()]).content.strip() - - # Clean and parse JSON response - import json - import re - - # Extract JSON from response + import json, re json_match = re.search(r'\{.*\}', response, re.DOTALL) if json_match: json_str = json_match.group() structure = json.loads(json_str) - - # Build organized sections list + organized = [] used_sections = set() - + for category in structure.get("structure", []): - # Add main category as a header-only section organized.append({ "heading": category.get("title", "Category"), "level": 1, "is_category_header": True }) - - # Add sections under this category for section_num in category.get("sections", []): - idx = section_num - 1 # Convert to 0-based index + idx = section_num - 1 if 0 <= idx < len(sections) and idx not in used_sections: section_copy = sections[idx].copy() section_copy["level"] = 2 organized.append(section_copy) used_sections.add(idx) - - # Add orphan sections at the end + for section_num in structure.get("orphan_sections", []): idx = section_num - 1 if 0 <= idx < len(sections) and idx not in used_sections: @@ -1004,33 +1101,23 @@ def _organize_sections_with_llm(self, sections: list[dict], query: str | None, e section_copy["level"] = 2 organized.append(section_copy) used_sections.add(idx) - - # Add any sections not mentioned in the structure + for i, section in enumerate(sections): if i not in used_sections: section_copy = section.copy() section_copy["level"] = 2 organized.append(section_copy) - + return organized - except Exception as e: logger.warning(f"Failed to organize sections with LLM: {e}") - # Return original sections if organization fails - pass - - # Return original sections if LLM organization fails or isn't attempted + return sections - - - + def _build_references_section(self, sections: list[dict]) -> tuple[dict, str]: - """Build a references section from all sources in sections and return citation map""" all_sources = [] citation_map = {} citation_counter = 1 - - # Collect all unique sources seen_sources = set() for section in sections: sources = section.get("sources", []) @@ -1041,137 +1128,108 @@ def _build_references_section(self, sections: list[dict]) -> tuple[dict, str]: citation_map[source_key] = citation_counter citation_counter += 1 seen_sources.add(source_key) - - # Build references text + references_text = [] for i, source in enumerate(all_sources, 1): file_name = source.get("file", "Unknown") sheet = source.get("sheet", "") entity = source.get("entity", "") - if sheet: ref_text = f"[{i}] {file_name}, Sheet: {sheet}" else: ref_text = f"[{i}] {file_name}" - if entity: ref_text += f" ({entity})" - references_text.append(ref_text) - + return citation_map, "\n".join(references_text) - + def write_report(self, sections: list[dict], filter_failures: bool = True, query: str | None = 
None) -> str: if not isinstance(sections, list): raise TypeError("Expected list of sections") - - # Detect requested language from query + target_language = self._detect_target_language(query) logger.info(f"🌐 Detected target language: {target_language}") - - # Filter out failed sections if requested + if filter_failures: sections = self._filter_failed_sections(sections) logger.info(f"📊 After filtering failures: {len(sections)} sections remaining") - - # Validate and fix language consistency across all sections + if target_language != "english": sections = self._ensure_language_consistency(sections, target_language, query) - - # Create a fresh document for each report to prevent accumulation + doc = Document() - - # Apply professional document styling self._apply_document_styling(doc) - - # Create reports directory if it doesn't exist + reports_dir = "reports" os.makedirs(reports_dir, exist_ok=True) - - # Extract metadata from sections - is_comparison = sections[0].get("is_comparison", False) if sections else False - entities = sections[0].get("entities", []) if sections else [] - - # Generate dynamic report title + + # NEW: infer comparison/entity context from first valid section (or defaults) + is_comparison = False + entities: list[str] = [] + for s in sections: + if "entities" in s: + entities = list(s.get("entities") or []) + if "is_comparison" in s: + is_comparison = bool(s.get("is_comparison")) + if entities: + break + report_title = self._generate_report_title(is_comparison, entities, query, sections) - - # Add professional header self._add_report_header(doc, report_title, is_comparison, entities) - # PARALLEL GENERATION of executive summary and conclusion while processing sections - from concurrent.futures import ThreadPoolExecutor, as_completed - - summary_and_conclusion_futures = [] - - if self.llm: # Only if LLM is available for intelligent generation + from concurrent.futures import ThreadPoolExecutor + if self.llm: with ThreadPoolExecutor(max_workers=2) as summary_executor: - # Start executive summary generation in parallel summary_future = summary_executor.submit( self._generate_executive_summary, sections, is_comparison, entities, target_language, query ) - summary_and_conclusion_futures.append(("summary", summary_future)) - - # Start conclusion generation in parallel conclusion_future = summary_executor.submit( self._generate_conclusion, sections, is_comparison, entities, target_language, query ) - summary_and_conclusion_futures.append(("conclusion", conclusion_future)) - - # Add executive summary + doc.add_heading("Executive Summary", level=2) - executive_summary = summary_future.result() # Wait for completion + executive_summary = summary_future.result() + add_inline_markdown_paragraph(doc, executive_summary) doc.add_paragraph(executive_summary) - doc.add_paragraph() # spacing + doc.add_paragraph() - # Organize sections hierarchically using LLM organized_sections = self._organize_sections_with_llm(sections, query, entities) - - # Build citation map before adding sections citation_map, references_text = self._build_references_section(organized_sections) - - # Add organized sections with citations + for section in organized_sections: if section.get("is_category_header"): - # This is a main category header doc.add_heading(section.get("heading", "Category"), level=1) else: - # Regular section with appropriate level and citations level = section.get("level", 2) append_to_doc(doc, section, level=level, citation_map=citation_map) - doc.add_paragraph() # spacing between sections 
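# Illustrative sketch (not part of this change; headings are made up): _organize_sections_with_llm
# returns a flat list that interleaves category headers with their child sections, which is what the
# loop above renders, e.g.
#   [
#     {"heading": "Climate Targets", "level": 1, "is_category_header": True},
#     {"heading": "Scope 1 Emissions", "level": 2, "text": "...", "sources": [...]},
#     {"heading": "Scope 2 Emissions", "level": 2, "text": "...", "sources": [...]},
#   ]
# Category headers become level-1 headings; every other section goes through append_to_doc with the citation map.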
+ doc.add_paragraph() - # Add conclusion doc.add_heading("Conclusion", level=2) - conclusion = conclusion_future.result() # Wait for completion + conclusion = conclusion_future.result() + add_inline_markdown_paragraph(doc, conclusion) doc.add_paragraph(conclusion) - - # Add References section (already built above) + if references_text: - doc.add_paragraph() # spacing + doc.add_paragraph() doc.add_heading("References", level=2) doc.add_paragraph(references_text) else: - # Fallback for when no LLM is available doc.add_heading("Executive Summary", level=2) executive_summary = self._generate_intro_section(is_comparison, entities) doc.add_paragraph(executive_summary) - doc.add_paragraph() # spacing + doc.add_paragraph() - # Build citation map citation_map, references_text = self._build_references_section(sections) - - # Add all sections with citations (no LLM available for organization) for section in sections: append_to_doc(doc, section, level=2, citation_map=citation_map) - doc.add_paragraph() # spacing between sections + doc.add_paragraph() doc.add_heading("Conclusion", level=2) conclusion = "This analysis provides insights based on available data from retrieved documents." doc.add_paragraph(conclusion) - - # Add References section (already built above) if references_text: - doc.add_paragraph() # spacing + doc.add_paragraph() doc.add_heading("References", level=2) doc.add_paragraph(references_text) @@ -1181,15 +1239,19 @@ def write_report(self, sections: list[dict], filter_failures: bool = True, query save_doc(doc, filepath) return filepath + # Example usage if __name__ == "__main__": doc = Document() sample_section = { "heading": "Climate Commitments", - "text": "Both Elinexa and Aelwyn have committed to net-zero targets...", - "table": [{"Bank": "Elinexa", "Target": "Net-zero 2050"}, - {"Bank": "Aelwyn", "Target": "Net-zero 2050"}], - "chart_data": {"Elinexa": 42, "Aelwyn": 36} + "text": "Both Acme Bank and Globex Bank have committed to net-zero targets...", + "table": [{"Bank": "Acme Bank", "Target": "Net-zero 2050"}, + {"Bank": "Globex Bank", "Target": "Net-zero 2050"}], + "chart_data": {"Acme Bank": 42, "Globex Bank": 36}, + # NEW: tell the pipeline which two entities are being compared + "entities": ["Acme Bank", "Globex Bank"], + "is_comparison": True } agent = ReportWriterAgent(doc) agent.write_report([sample_section]) diff --git a/ai/generative-ai-service/complex-document-rag/files/gradio.css b/ai/generative-ai-service/complex-document-rag/files/gradio.css index 9d847f9c6..3296d7199 100644 --- a/ai/generative-ai-service/complex-document-rag/files/gradio.css +++ b/ai/generative-ai-service/complex-document-rag/files/gradio.css @@ -1,12 +1,15 @@ /* ===== CLEAN LIGHT THEME ===== */ :root { - --primary-color: #ff6b35; - --secondary-color: #6c757d; - --background-color: #ffffff; - --surface-color: #ffffff; + --primary-color: #c74634; + --oracle-red: #c74634; + --secondary-color: #6f757e; + --background-color: #fffefe; + --surface-color: #fffefe; + --off-white: #fffefe; --border-color: #dee2e6; - --text-color: #212529; - --text-muted: #6c757d; + --text-color: #312d2a; + --text-muted: #6f7572; + --dark-grey: #404040; } /* ===== GLOBAL STYLING ===== */ @@ -36,7 +39,7 @@ /* ===== BUTTONS ===== */ .gr-button, button, .primary-button, .secondary-button { - background: white !important; + background: var(--off-white) !important; color: var(--primary-color) !important; border: 1px solid var(--primary-color) !important; padding: 10px 20px !important; @@ -46,43 +49,95 @@ letter-spacing: 0.5px 
!important; cursor: pointer !important; font-size: 12px !important; - transition: color 0.2s ease !important; + transition: background-color 0.2s ease, color 0.2s ease !important; } .gr-button:hover, button:hover, .primary-button:hover, .secondary-button:hover { background: #f8f8f8 !important; color: var(--primary-color) !important; + padding: 10px 20px !important; /* Keep same padding to prevent jumpy behavior */ } .gr-button:active, button:active, .primary-button:active, .secondary-button:active { background: #f0f0f0 !important; color: var(--primary-color) !important; + padding: 10px 20px !important; /* Keep same padding to prevent jumpy behavior */ } /* ===== TABS ===== */ -.gr-tabs .gr-tab-nav button { - background: #6c757d !important; - color: white !important; +/* Target all possible tab button selectors for Gradio */ +.gr-tabs .tab-nav button, +.gr-tabs .gr-tab-nav button, +div[role="tablist"] button, +button[role="tab"], +.gradio-container .gr-tabs button[role="tab"], +.gradio-container button.tab-nav-button { + background: #c74634 !important; + background-color: #c74634 !important; + color: #fffefe !important; border: none !important; + border-bottom: 3px solid transparent !important; /* Remove orange underline */ padding: 12px 20px !important; font-weight: 500 !important; text-transform: uppercase !important; letter-spacing: 0.5px !important; border-radius: 4px 4px 0 0 !important; margin-right: 2px !important; -} - -.gr-tabs .gr-tab-nav button.selected { - background: #495057 !important; -} - -.gr-tabs .gr-tab-nav button:hover { - background: #5a6268 !important; + transition: background-color 0.3s ease, border-bottom 0.3s ease !important; + opacity: 0.8 !important; +} + +/* Selected/Active tab with black underline */ +.gr-tabs .tab-nav button.selected, +.gr-tabs .gr-tab-nav button.selected, +div[role="tablist"] button.selected, +button[role="tab"][aria-selected="true"], +button[role="tab"].selected, +.gradio-container .gr-tabs button[role="tab"].selected, +.gradio-container button.tab-nav-button.selected { + background: #c74634 !important; + background-color: #c74634 !important; + opacity: 1 !important; + color: #fffefe !important; + font-weight: 500 !important; /* Keep same weight as non-selected to prevent jumpy behavior */ + border-bottom: 3px solid #312d2a !important; /* Black underline for active tab */ + padding: 12px 20px !important; /* Keep same padding */ +} + +/* Hover state for non-selected tabs */ +.gr-tabs .tab-nav button:hover:not(.selected), +.gr-tabs .gr-tab-nav button:hover:not(.selected), +div[role="tablist"] button:hover:not(.selected), +button[role="tab"]:hover:not([aria-selected="true"]), +button[role="tab"]:hover:not(.selected), +.gradio-container .gr-tabs button[role="tab"]:hover:not(.selected), +.gradio-container button.tab-nav-button:hover:not(.selected) { + background: #404040 !important; + background-color: #404040 !important; + color: #fffefe !important; + opacity: 1 !important; + padding: 12px 20px !important; /* Keep same padding */ +} + +/* Additional override for any nested spans or text elements in tabs */ +.gr-tabs button span, +button[role="tab"] span, +.gr-tabs button *, +button[role="tab"] * { + color: inherit !important; +} + +/* Remove any orange borders/underlines that might appear */ +button[role="tab"]::after, +button[role="tab"]::before, +.gr-tabs button::after, +.gr-tabs button::before { + display: none !important; } /* ===== COMPACT UPLOAD SECTIONS ===== */ .upload-section { - background: white !important; + background: 
var(--off-white) !important; border: 1px solid var(--border-color) !important; border-radius: 8px !important; padding: 12px !important; @@ -113,13 +168,13 @@ margin: 8px 0 !important; display: block !important; padding: 12px !important; - background: white !important; + background: var(--off-white) !important; border: 1px solid var(--primary-color) !important; } /* ===== INFERENCE LAYOUT ===== */ .inference-left-column, .inference-right-column { - background: white !important; + background: var(--off-white) !important; padding: 20px !important; } @@ -127,30 +182,37 @@ margin-bottom: 16px !important; } -.model-controls, .collection-controls { - background: white !important; +/* Make control sections more compact */ +.model-controls, .collection-controls, .processing-controls { + background: var(--off-white) !important; border: 1px solid var(--border-color) !important; border-radius: 6px !important; - padding: 12px !important; - margin-bottom: 12px !important; + padding: 8px !important; /* Reduced padding for compactness */ + margin-bottom: 8px !important; /* Reduced margin */ } .processing-controls { - background: white !important; border: 1px solid var(--primary-color) !important; - border-radius: 6px !important; - padding: 12px !important; - margin-bottom: 12px !important; } -.compact-query textarea { - min-height: 120px !important; - max-height: 150px !important; +/* Compact headers in control sections */ +.model-controls h4, +.collection-controls h4, +.processing-controls h4 { + font-size: 12px !important; + margin-bottom: 4px !important; +} + +/* Make query textarea much larger */ +.compact-query textarea, +.query-section textarea { + min-height: 360px !important; /* 3x larger than before */ + max-height: 450px !important; } /* ===== INPUT FIELDS ===== */ .gr-textbox, .gr-textbox textarea, .gr-textbox input { - background: white !important; + background: var(--off-white) !important; border: 1px solid var(--border-color) !important; border-radius: 4px !important; color: var(--text-color) !important; @@ -159,15 +221,22 @@ /* ===== DROPDOWNS ===== */ .gr-dropdown, .gr-dropdown select { - background: white !important; + background: var(--off-white) !important; border: 1px solid var(--border-color) !important; border-radius: 4px !important; color: var(--text-color) !important; } +/* Make dropdowns more compact */ +.model-controls .gr-dropdown, +.collection-controls .gr-dropdown { + padding: 6px !important; + font-size: 13px !important; +} + /* ===== FILE UPLOAD ===== */ .gr-file { - background: white !important; + background: var(--off-white) !important; border: 2px dashed var(--primary-color) !important; border-radius: 8px !important; padding: 20px !important; @@ -212,17 +281,68 @@ display: none !important; } -/* ===== FORCE WHITE BACKGROUNDS ===== */ +/* ===== FORCE OFF-WHITE BACKGROUNDS ===== */ .gr-group, .gr-form, .gr-block { - background: white !important; + background: var(--off-white) !important; } /* ===== DELETE BUTTON ===== */ .gr-button[variant="stop"] { - background: #dc3545 !important; - color: white !important; + background: var(--oracle-red) !important; + color: var(--off-white) !important; + border: 1px solid var(--oracle-red) !important; } .gr-button[variant="stop"]:hover { - background: #c82333 !important; + background: #a13527 !important; + border: 1px solid #a13527 !important; +} + +/* ===== CHECKBOXES - MORE COMPACT ===== */ +.gr-checkbox-group { + display: flex !important; + gap: 12px !important; + flex-wrap: wrap !important; +} + +.gr-checkbox-group label { + 
font-size: 13px !important; + margin-bottom: 0 !important; +} + +/* ===== COMPACT SETTINGS SECTION ===== */ +.compact-settings { + background: var(--off-white) !important; + border: 1px solid var(--border-color) !important; + border-radius: 6px !important; + padding: 8px !important; + margin-top: 8px !important; +} + +.compact-settings .gr-row { + margin-bottom: 4px !important; +} + +.compact-settings .gr-dropdown { + margin-bottom: 4px !important; +} + +.compact-settings .gr-dropdown label { + font-size: 12px !important; + margin-bottom: 2px !important; +} + +.compact-settings .gr-checkbox { + margin: 0 !important; + padding: 4px !important; +} + +.compact-settings .gr-checkbox label { + font-size: 12px !important; + margin: 0 !important; +} + +/* Remove extra spacing in compact settings */ +.compact-settings > div { + gap: 4px !important; } diff --git a/ai/generative-ai-service/complex-document-rag/files/gradio_app.py b/ai/generative-ai-service/complex-document-rag/files/gradio_app.py index e41b98ebe..9e52af416 100644 --- a/ai/generative-ai-service/complex-document-rag/files/gradio_app.py +++ b/ai/generative-ai-service/complex-document-rag/files/gradio_app.py @@ -1,6 +1,9 @@ #!/usr/bin/env python3 """Oracle Enterprise RAG System Interface.""" +# Disable telemetry first to prevent startup errors +import disable_telemetry + import gradio as gr import logging import os @@ -82,7 +85,7 @@ def __init__(self) -> None: self._initialize_vector_store(self.current_embedding_model) self._initialize_rag_agent( self.current_llm_model, - collection="Multi-Collection", + collection="multi", embedding_model=self.current_embedding_model ) @@ -257,7 +260,7 @@ def _initialize_processors(self) -> Tuple[Optional[XLSXIngester], Optional[PDFIn def _initialize_rag_agent( self, llm_model: str, - collection: str = "Multi-Collection", + collection: str = "multi", embedding_model: Optional[str] = None ) -> bool: """ @@ -265,7 +268,7 @@ def _initialize_rag_agent( Args: llm_model: Name of the LLM model to use - collection: Name of the collection to use (default: "Multi-Collection") + collection: Name of the collection to use (default: "multi") embedding_model: Optional embedding model to switch to Returns: @@ -483,72 +486,30 @@ def create_oracle_interface(): placeholder="Deletion results will appear here..." ) - with gr.Tab("SEARCH COLLECTIONS", id="search"): - gr.Markdown("### Search through your vector store collections") - - # Add embedding model selector for search tab - with gr.Row(): - embedding_model_selector_search = gr.Dropdown( - choices=rag_system.available_embedding_models, - value=rag_system.current_embedding_model, - label="Embedding Model for Search", - info="Select the embedding model to use for searching" - ) - - with gr.Row(): - search_query = gr.Textbox( - label="Search Query", - placeholder="Enter search terms...", - scale=3 - ) - search_collection = gr.Dropdown( - choices=["PDF Documents", "XLSX Documents"], - value="XLSX Documents", - label="Collection", - scale=1 - ) - search_results_count = gr.Slider( - minimum=1, - maximum=20, - value=5, - step=1, - label="Results", - scale=1 - ) - - search_btn = gr.Button("Search", variant="secondary", elem_classes=["secondary-button"]) - - search_results = gr.Textbox( - elem_id="scientific-results-box", - label="Search Results", - lines=25, - max_lines=30, - placeholder="Search results will appear here..." 
- ) - with gr.Tab("INFERENCE & QUERY", id="inference"): with gr.Row(): - # Left Column - Input Controls + # Left Column - Query Input with gr.Column(scale=1, elem_classes=["inference-left-column"]): - # Query Section + # Large Query Section with gr.Group(elem_classes=["query-section"]): query_input = gr.Textbox( label="Query", - lines=4, - max_lines=6, + lines=15, # Much larger query area + max_lines=20, placeholder="Enter your query here...", elem_classes=["compact-query"] ) + query_btn = gr.Button( "Run Query", elem_classes=["primary-button"], - size="sm", + size="lg", elem_id="run-query-btn" ) - # Model Configuration - with gr.Group(elem_classes=["model-controls"]): - gr.HTML("

<h4>Model Configuration</h4>
") + # Compact Configuration Section - All in one group + with gr.Group(elem_classes=["compact-settings"]): + # Model Configuration in one row with gr.Row(): llm_model_selector = gr.Dropdown( choices=rag_system.available_llm_models, @@ -559,26 +520,16 @@ def create_oracle_interface(): embedding_model_selector_query = gr.Dropdown( choices=rag_system.available_embedding_models, value=rag_system.current_embedding_model, - label="Embedding Model", + label="Embeddings", interactive=True, scale=1 ) - - # Data Sources - with gr.Group(elem_classes=["collection-controls"]): - gr.HTML("

<h4>Data Sources</h4>
") + + # Data Sources and Processing Mode in one compact row with gr.Row(): - collection_pdf = gr.Checkbox(label="Include PDF Collection", value=False) - collection_xlsx = gr.Checkbox(label="Include XLSX Collection", value=False) - - # Processing Mode - with gr.Group(elem_classes=["processing-controls"]): - gr.HTML("

<h4>Processing Mode</h4>
") - agent_mode = gr.Checkbox( - label="Use Agentic Workflow", - value=False, - info="Enable advanced reasoning and multi-step processing" - ) + collection_pdf = gr.Checkbox(label="Include PDF", value=False, scale=1) + collection_xlsx = gr.Checkbox(label="Include XLSX", value=False, scale=1) + agent_mode = gr.Checkbox(label="Agentic Mode", value=False, scale=1) # Right Column - Results with gr.Column(scale=1, elem_classes=["inference-right-column"]): @@ -636,11 +587,11 @@ def process_pdf_and_clear(file, model, entity): outputs=[collection_documents] ) - search_btn.click( - fn=lambda q, coll, emb, n: search_chunks(q, coll, emb, rag_system, n), - inputs=[search_query, search_collection, embedding_model_selector_search, search_results_count], - outputs=[search_results] - ) + # search_btn.click( + # fn=lambda q, coll, emb, n: search_chunks(q, coll, emb, rag_system, n), + # inputs=[search_query, search_collection, embedding_model_selector_search, search_results_count], + # outputs=[search_results] + # ) list_chunks_btn.click( fn=lambda coll, emb: list_all_chunks(coll, emb, rag_system), @@ -697,8 +648,11 @@ def handle_query_with_download(query, llm_model, embedding_model, include_pdf, i gr.update(visible=False) ) - # Actually process the query - response, report_path = process_query(query, llm_model, embedding_model, include_pdf, include_xlsx, agentic, rag_system) + # Actually process the query with entity parameters + # Pass empty strings for entities to trigger automatic detection + entity1 = "" # Will be automatically detected by the LLM + entity2 = "" # Will be automatically detected by the LLM + response, report_path = process_query(query, llm_model, embedding_model, include_pdf, include_xlsx, agentic, rag_system, entity1, entity2) progress(1.0, desc="Complete!") @@ -720,6 +674,7 @@ def handle_query_with_download(query, llm_model, embedding_model, include_pdf, i query_btn.click( fn=handle_query_with_download, + # inputs=[query_input, llm_model_selector, embedding_model_selector_query, collection_pdf, collection_xlsx, agent_mode, entity1_input, entity2_input], inputs=[query_input, llm_model_selector, embedding_model_selector_query, collection_pdf, collection_xlsx, agent_mode], outputs=[status_box, response_box, download_file], show_progress="full" diff --git a/ai/generative-ai-service/complex-document-rag/files/handlers/pdf_handler.py b/ai/generative-ai-service/complex-document-rag/files/handlers/pdf_handler.py index 6fb4caeed..b4889de09 100644 --- a/ai/generative-ai-service/complex-document-rag/files/handlers/pdf_handler.py +++ b/ai/generative-ai-service/complex-document-rag/files/handlers/pdf_handler.py @@ -53,8 +53,8 @@ def progress(*args, **kwargs): return "❌ ERROR: Vector store not initialized", "" file_path = Path(file.name) - chunks, doc_id = rag_system.pdf_processor.process_pdf(file_path, entity=entity) - + chunks, doc_id, _ = rag_system.pdf_processor.ingest_pdf(file_path, entity=entity) + print("PDF processor type:", type(rag_system.pdf_processor)) progress(0.7, desc="Adding to vector store...") converted_chunks = [ diff --git a/ai/generative-ai-service/complex-document-rag/files/handlers/query_handler.py b/ai/generative-ai-service/complex-document-rag/files/handlers/query_handler.py index f16a174d8..b9b5f9abf 100644 --- a/ai/generative-ai-service/complex-document-rag/files/handlers/query_handler.py +++ b/ai/generative-ai-service/complex-document-rag/files/handlers/query_handler.py @@ -20,9 +20,11 @@ def process_query( include_xlsx: bool, agentic: bool, rag_system, + entity1: str = 
"", + entity2: str = "", progress=gr.Progress() ) -> Tuple[str, Optional[str]]: - """Process a query using the RAG system""" + """Process a query using the RAG system with optional entity specification""" if not query.strip(): return "ERROR: Please enter a query", None @@ -108,11 +110,24 @@ def _safe_query(collection_label: str, text: str, n: int = 5): progress(0.8, desc="Generating response...") + # Prepare provided entities if any + provided_entities = [] + if entity1 and entity1.strip(): + provided_entities.append(entity1.strip().lower()) + if entity2 and entity2.strip(): + provided_entities.append(entity2.strip().lower()) + + # Log entities being used + if provided_entities: + logger.info(f"Using provided entities: {provided_entities}") + if all_results: + # Pass provided entities to the RAG system result = rag_system.rag_agent.process_query_with_multi_collection_context( query, all_results, - collection_mode=active_collection + collection_mode=active_collection, + provided_entities=provided_entities if provided_entities else None ) # Ensure result is a dictionary if not isinstance(result, dict): @@ -140,12 +155,12 @@ def query_collection(collection_type): """Query a single collection in parallel""" try: if collection_type == "pdf": - # Increased to 20 chunks for non-agentic workflows - results = _safe_query("pdf", query, n=20) + # Optimized to 10 chunks for faster processing + results = _safe_query("pdf", query, n=10) return ("PDF", results if results else []) elif collection_type == "xlsx": - # Increased to 20 chunks for non-agentic workflows - results = _safe_query("xlsx", query, n=20) + # Optimized to 10 chunks for faster processing + results = _safe_query("xlsx", query, n=10) return ("XLSX", results if results else []) else: return (collection_type.upper(), []) @@ -178,8 +193,11 @@ def query_collection(collection_type): return "No relevant information found in selected collections.", None # Use more chunks for better context in non-agentic mode - # Take top 20 chunks total (or all if less than 20) - chunks_to_use = retrieved_chunks[:20] + # Optimize chunk usage based on model + if llm_model == "grok-4": + chunks_to_use = retrieved_chunks[:15] # Can handle more context + else: + chunks_to_use = retrieved_chunks[:10] # Optimized for speed context_str = "\n\n".join(chunk["content"] for chunk in chunks_to_use) prompt = f"""You are an expert assistant. 
diff --git a/ai/generative-ai-service/complex-document-rag/files/handlers/vector_handler.py b/ai/generative-ai-service/complex-document-rag/files/handlers/vector_handler.py index 782700a0a..f37cfada1 100644 --- a/ai/generative-ai-service/complex-document-rag/files/handlers/vector_handler.py +++ b/ai/generative-ai-service/complex-document-rag/files/handlers/vector_handler.py @@ -319,8 +319,13 @@ def delete_all_chunks_in_collection(collection_name: str, embedding_model: str, client = rag_system.vector_store.client all_colls = client.list_collections() - # Find all physical collections for this logical group (e.g., xlsx_documents_*) - targets = [c for c in all_colls if c.name.startswith(f"{base_prefix}_")] + # Find all physical collections for this logical group (e.g., xlsx_documents_* or pdf_documents_*) + targets = [] + for c in all_colls: + # Handle both collection objects and dict representations + coll_name = getattr(c, 'name', None) or (c.get('name') if isinstance(c, dict) else str(c)) + if coll_name and coll_name.startswith(f"{base_prefix}_"): + targets.append((coll_name, c)) if not targets: return f"Collection group '{collection_name}' has no collections to delete." @@ -328,21 +333,31 @@ def delete_all_chunks_in_collection(collection_name: str, embedding_model: str, # Delete them all total_deleted_chunks = 0 deleted_names = [] - for coll in targets: + for coll_name, coll_obj in targets: try: count = 0 try: - count = coll.count() + # Get the actual collection object if we only have the name + if isinstance(coll_obj, str): + actual_coll = client.get_collection(coll_name) + else: + actual_coll = coll_obj + count = actual_coll.count() except Exception: pass + total_deleted_chunks += count - client.delete_collection(coll.name) - deleted_names.append(coll.name) - # Also drop from in-memory map if present + client.delete_collection(coll_name) + deleted_names.append(coll_name) + + # Clean up all in-memory references if hasattr(rag_system.vector_store, "collections"): - rag_system.vector_store.collections.pop(coll.name, None) + rag_system.vector_store.collections.pop(coll_name, None) + if hasattr(rag_system.vector_store, "collection_map"): + rag_system.vector_store.collection_map.pop(coll_name, None) + except Exception as e: - logging.error(f"Failed to delete collection '{coll.name}': {e}") + logging.error(f"Failed to delete collection '{coll_name}': {e}") # Recreate the CURRENT model's empty collection so the app keeps a live handle # Build full name like: {base_prefix}_{model_name}_{dimensions} @@ -357,16 +372,28 @@ def delete_all_chunks_in_collection(collection_name: str, embedding_model: str, new_full_name = f"{base_prefix}_{model_name}_{dims}" new_collection = client.get_or_create_collection(name=new_full_name, metadata=metadata) - # Refresh vector_store references for this base prefix + # Refresh ALL vector_store references comprehensively if hasattr(rag_system.vector_store, "collections"): rag_system.vector_store.collections[new_full_name] = new_collection - # Also store under the base key for compatibility with older code paths - rag_system.vector_store.collections[base_prefix] = new_collection - + + if hasattr(rag_system.vector_store, "collection_map"): + rag_system.vector_store.collection_map[new_full_name] = new_collection + # Also ensure the collection_map is properly updated + rag_system.vector_store.collection_map = { + k: v for k, v in rag_system.vector_store.collection_map.items() + if not k.startswith(f"{base_prefix}_") or k == new_full_name + } + 
rag_system.vector_store.collection_map[new_full_name] = new_collection + + # Update the specific collection references if base_prefix == "xlsx_documents": rag_system.vector_store.xlsx_collection = new_collection + if hasattr(rag_system.vector_store, "current_xlsx_collection_name"): + rag_system.vector_store.current_xlsx_collection_name = new_full_name elif base_prefix == "pdf_documents": rag_system.vector_store.pdf_collection = new_collection + if hasattr(rag_system.vector_store, "current_pdf_collection_name"): + rag_system.vector_store.current_pdf_collection_name = new_full_name # Nice summary deleted_list = "\n".join(f" • {name}" for name in deleted_names) if deleted_names else " • (none)" @@ -374,7 +401,7 @@ def delete_all_chunks_in_collection(collection_name: str, embedding_model: str, "✅ DELETION COMPLETED\n\n" f"Logical collection: {collection_name}\n" f"Collections removed: {len(deleted_names)}\n" - f"Total chunks deleted (best-effort): {total_deleted_chunks}\n" + f"Total chunks deleted: {total_deleted_chunks}\n" f"Deleted collections:\n{deleted_list}\n\n" "Recreated empty collection for current model:\n" f" • {new_full_name}\n" diff --git a/ai/generative-ai-service/complex-document-rag/files/handlers/xlsx_handler.py b/ai/generative-ai-service/complex-document-rag/files/handlers/xlsx_handler.py index 3c4ca27c0..95a118aaf 100644 --- a/ai/generative-ai-service/complex-document-rag/files/handlers/xlsx_handler.py +++ b/ai/generative-ai-service/complex-document-rag/files/handlers/xlsx_handler.py @@ -56,7 +56,15 @@ def progress(*args, **kwargs): return "❌ ERROR: Vector store not initialized", "" file_path = Path(file.name) - chunks, doc_id = rag_system.xlsx_processor.ingest_xlsx(file_path, entity=entity) + # Now returns 3 values: chunks, doc_id, and chunks_to_delete + result = rag_system.xlsx_processor.ingest_xlsx(file_path, entity=entity) + + # Handle both old (2-tuple) and new (3-tuple) return formats + if len(result) == 3: + chunks, doc_id, chunks_to_delete = result + else: + chunks, doc_id = result + chunks_to_delete = [] progress(0.7, desc="Adding to vector store...") @@ -69,6 +77,17 @@ def progress(*args, **kwargs): for chunk in chunks ] + # Delete original chunks FIRST if they were rewritten + if chunks_to_delete and hasattr(rag_system.vector_store, 'delete_chunks'): + progress(0.7, desc="Removing original chunks that were rewritten...") + try: + rag_system.vector_store.delete_chunks('xlsx_documents', chunks_to_delete) + logger.info(f"Deleted {len(chunks_to_delete)} original chunks that were rewritten") + except Exception as e: + logger.warning(f"Could not delete original chunks: {e}") + + # THEN add the new rewritten chunks to vector store + progress(0.8, desc="Adding rewritten chunks to vector store...") rag_system.vector_store.add_xlsx_chunks(converted_chunks, doc_id) progress(1.0, desc="Complete!") @@ -78,6 +97,9 @@ def progress(*args, **kwargs): actual_collection_name = rag_system.vector_store.xlsx_collection.name collection_name = f"{actual_collection_name} ({embedding_model})" + + # Count rewritten chunks + rewritten_count = sum(1 for chunk in chunks if chunk.get('metadata', {}).get('rewritten', False)) summary = f""" ✅ **XLSX PROCESSING COMPLETE** @@ -86,6 +108,7 @@ def progress(*args, **kwargs): **Document ID:** {doc_id} **Entity:** {entity} **Chunks created:** {len(chunks)} +**Chunks with rewritten content:** {rewritten_count} **Embedding model:** {embedding_model} **Collection:** {collection_name} @@ -123,4 +146,3 @@ def progress(*args, **kwargs): error_msg = f"❌ ERROR: 
Processing XLSX file failed: {str(e)}" logger.error(f"{error_msg}\n{traceback.format_exc()}") return error_msg, traceback.format_exc() - diff --git a/ai/generative-ai-service/complex-document-rag/files/images/screenshot1.png b/ai/generative-ai-service/complex-document-rag/files/images/screenshot1.png new file mode 100644 index 000000000..b93876012 Binary files /dev/null and b/ai/generative-ai-service/complex-document-rag/files/images/screenshot1.png differ diff --git a/ai/generative-ai-service/complex-document-rag/files/ingest_pdf.py b/ai/generative-ai-service/complex-document-rag/files/ingest_pdf.py index 83110ebd5..f2542c54c 100644 --- a/ai/generative-ai-service/complex-document-rag/files/ingest_pdf.py +++ b/ai/generative-ai-service/complex-document-rag/files/ingest_pdf.py @@ -1,159 +1,355 @@ +# pdf_ingester_v2.py +import logging, time, uuid, re, os from pathlib import Path from typing import List, Dict, Any, Optional, Tuple -import uuid -import time -import re import tiktoken +import pandas as pd + +# Hard deps you likely already have: import pdfplumber -import logging + +# Optional but recommended for tables +try: + import camelot + _HAS_CAMELOT = True +except Exception: + _HAS_CAMELOT = False + +# Optional for embedded files +try: + from pypdf import PdfReader + _HAS_PYPDF = True +except Exception: + _HAS_PYPDF = False logger = logging.getLogger(__name__) class PDFIngester: - def __init__(self, tokenizer: str = "BAAI/bge-small-en-v1.5", chunk_rewriter=None): + """ + PDF -> chunks with consistent semantics to XLSXIngester. + Strategy: + 1) Detect embedded spreadsheets -> delegate to XLSXIngester + 2) Try Camelot (lattice->stream) for vector tables + 3) Fallback to pdfplumber tables + 4) Extract remaining prose blocks + 5) Batch + select + batch-rewrite (same as XLSX flow) + """ + + def __init__(self, tokenizer: str = "BAAI/bge-small-en-v1.5", + chunk_rewriter=None, + batch_size: int = 16): + self.tokenizer_name = tokenizer self.chunk_rewriter = chunk_rewriter + self.batch_size = batch_size self.accurate_tokenizer = tiktoken.get_encoding("cl100k_base") - self.tokenizer_name = tokenizer self.stats = { 'total_chunks': 0, 'rewritten_chunks': 0, - 'processing_time': 0, - 'rewriting_time': 0 + 'high_value_chunks': 0, + 'processing_time': 0.0, + 'extraction_time': 0.0, + 'rewriting_time': 0.0, + 'selection_time': 0.0 } - logger.info("📄 PDF processor initialized") - + # ---------- Utility parity with XLSX ---------- def _count_tokens(self, text: str) -> int: if not text or not text.strip(): return 0 return len(self.accurate_tokenizer.encode(text)) - def _should_rewrite(self, text: str) -> bool: - if not text.strip() or self._count_tokens(text) < 120: - return False + def _is_high_value_chunk(self, text: str, metadata: Dict[str, Any]) -> int: + # Same heuristic as your XLSX version (copy/paste with tiny tweaks) + if len(text.strip()) < 100: + return 0 + score = 0 + if re.search(r'\d+\.?\d*\s*(%|MW|GW|tCO2|ktCO2|MtCO2|€|\$|£|million|billion)', + text, re.IGNORECASE): + score += 2 + key_terms = ['revenue','guidance','margin','cash flow','eps', + 'emission','target','reduction','scope','net-zero', + 'renewable','sustainability','biodiversity'] + score += min(2, sum(1 for term in key_terms if term in text.lower())) + if text.count('|') > 5: + score += 1 + skip_indicators = ['cover', 'disclaimer', 'notice', 'table of contents'] + if any(skip in text.lower()[:200] for skip in skip_indicators): + score = max(0, score - 2) + return min(5, score) - pipe_count = text.count('|') - number_ratio = 
sum(c.isdigit() for c in text) / len(text) if text else 0 - line_count = len(text.splitlines()) + def _batch_rows_by_token_count(self, rows: List[str], max_tokens: int = 400) -> List[List[str]]: + chunks, current, tok = [], [], 0.0 + for row in rows: + if not row or not row.strip(): + continue + est = len(row.split()) * 1.3 + if tok + est > max_tokens: + if current: chunks.append(current) + current, tok = [row], est + else: + current.append(row); tok += est + if current: chunks.append(current) + return chunks - is_tabular = (pipe_count > 10 or number_ratio > 0.3 or line_count > 20) - messy = 'nan' in text.lower() or 'null' in text.lower() - sentence_count = len([s for s in text.split('.') if s.strip()]) - is_prose = sentence_count > 3 and pipe_count < 5 + def _batch_rewrite_chunks(self, chunks_to_rewrite: List[Tuple[str, Dict[str, Any], int]]): + if not chunks_to_rewrite or not self.chunk_rewriter: + return chunks_to_rewrite + start = time.time() + results = [] - return (is_tabular or messy) and not is_prose + # Fast path if your rewriter supports batch + if hasattr(self.chunk_rewriter, 'rewrite_chunks_batch'): + BATCH_SIZE = min(self.batch_size, len(chunks_to_rewrite)) + batches = [chunks_to_rewrite[i:i+BATCH_SIZE] + for i in range(0, len(chunks_to_rewrite), BATCH_SIZE)] - def _rewrite_chunk(self, text: str, metadata: Dict[str, Any]) -> str: - if not self.chunk_rewriter: - return text + for bidx, batch in enumerate(batches, 1): + batch_input = [{'text': t, 'metadata': m} for (t, m, _) in batch] + try: + rewritten = self.chunk_rewriter.rewrite_chunks_batch(batch_input, batch_size=BATCH_SIZE) + except Exception as e: + logger.warning(f"⚠️ Batch {bidx} failed: {e}") + rewritten = [None]*len(batch) + for i, (orig_text, meta, idx) in enumerate(batch): + new_text = rewritten[i] if i < len(rewritten) else None + if new_text and new_text != orig_text: + meta = meta.copy() + meta['rewritten'] = True + self.stats['rewritten_chunks'] += 1 + results.append((new_text, meta, idx)) + else: + results.append((orig_text, meta, idx)) + else: + # Sequential fallback + for (t, m, idx) in chunks_to_rewrite: + try: + new_t = self.chunk_rewriter.rewrite_chunk(t, metadata=m).strip() + except Exception as e: + logger.warning(f"⚠️ Rewrite failed for chunk {idx}: {e}") + new_t = None + if new_t and new_t != t: + m = m.copy(); m['rewritten'] = True + self.stats['rewritten_chunks'] += 1 + results.append((new_t, m, idx)) + else: + results.append((t, m, idx)) + self.stats['rewriting_time'] += time.time() - start + return results + + # ---------- Ingestion helpers ---------- + def _find_embedded_spreadsheets(self, pdf_path: Path) -> List[Tuple[str, bytes]]: + if not _HAS_PYPDF: + return [] try: - rewritten = self.chunk_rewriter.rewrite_chunk(text, metadata=metadata).strip() - if rewritten: - self.stats['rewritten_chunks'] += 1 - return rewritten - except Exception as e: - logger.warning(f"⚠️ Rewrite failed: {e}") - return text - - def process_pdf( - self, - file_path: str | Path, - entity: Optional[str] = None, - max_rewrite_chunks: int = 100 - ) -> Tuple[List[Dict[str, Any]], str]: - start_time = time.time() - self.stats = { - 'total_chunks': 0, - 'rewritten_chunks': 0, - 'processing_time': 0, - 'rewriting_time': 0 - } - all_chunks = [] - rewrite_candidates = [] - document_id = str(uuid.uuid4()) + reader = PdfReader(str(pdf_path)) + names_tree = reader.trailer.get("/Root", {}).get("/Names", {}) + efiles = names_tree.get("/EmbeddedFiles", {}) + names = efiles.get("/Names", []) + pairs = list(zip(names[::2], names[1::2])) 
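# Note: a PDF name tree stores embedded files as a flat /Names array that alternates
# [filename, filespec, filename, filespec, ...]; zipping every other element above pairs each
# display name with its file-specification object so the /EF -> /F stream data can be read out.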
+ out = [] + for fname, ref in pairs: + spec = ref.getObject() + if "/EF" in spec and "/F" in spec["/EF"]: + data = spec["/EF"]["/F"].getData() + if str(fname).lower().endswith((".xlsx", ".xls", ".csv")): + out.append((str(fname), data)) + return out + except Exception: + return [] - # -------- 1. Validate Inputs -------- + def _extract_tables_with_camelot(self, pdf_path: Path, pages="all") -> List[pd.DataFrame]: + if not _HAS_CAMELOT: + return [] + dfs: List[pd.DataFrame] = [] try: - file = Path(file_path) - if not file.exists() or not file.is_file(): - raise FileNotFoundError(f"File not found: {file_path}") - if not str(file).lower().endswith(('.pdf',)): - raise ValueError(f"File must be a PDF: {file_path}") + # 1) lattice first + tables = camelot.read_pdf(str(pdf_path), pages=pages, flavor="lattice", line_scale=40) + dfs.extend([t.df for t in tables] if tables else []) + # 2) stream fallback if sparse + if not dfs: + tables = camelot.read_pdf(str(pdf_path), pages=pages, flavor="stream", edge_tol=200) + dfs.extend([t.df for t in tables] if tables else []) except Exception as e: - logger.error(f"❌ Error opening file: {e}") - return [], document_id + logger.info(f"Camelot failed: {e}") + return dfs - if not entity or not isinstance(entity, str): - logger.error("❌ Entity name must be provided as a non-empty string when ingesting a PDF file.") - return [], document_id - entity = entity.strip().lower() - - logger.info(f"📄 Processing {file.name}") - - # -------- 2. Main Extraction -------- - try: - with pdfplumber.open(file) as pdf: - for page_num, page in enumerate(pdf.pages): - try: - text = page.extract_text() - except Exception as e: - logger.warning(f"⚠️ Failed to extract text from page {page_num+1}: {e}") + def _extract_tables_with_pdfplumber(self, pdf_path: Path) -> List[Tuple[pd.DataFrame, int]]: + out = [] + with pdfplumber.open(str(pdf_path)) as pdf: + for pno, page in enumerate(pdf.pages, 1): + try: + tables = page.extract_tables() or [] + except Exception: + tables = [] + for tbl in tables: + if not tbl or len(tbl) < 2: # need header + at least 1 row continue + df = pd.DataFrame(tbl[1:], columns=tbl[0]) + out.append((df, pno)) + return out - if not text or len(text.strip()) < 50: - logger.debug(f"Skipping short/empty page {page_num+1}") - continue + def _df_to_rows(self, df: pd.DataFrame) -> List[str]: + # Normalize like your XLSX rows + df = df.copy() + df = df.replace(r'\n', ' ', regex=True) + df.columns = [str(c).strip() for c in df.columns] + return [ " | ".join([str(v) for v in row if (pd.notna(v) and str(v).strip())]) + for _, row in df.iterrows() ] - metadata = { - "page": page_num + 1, - "source": str(file), - "filename": file.name, - "entity": entity, - "document_id": document_id, - "type": "pdf_page" - } + def _extract_prose_blocks(self, pdf_path: Path) -> List[Tuple[str, int]]: + blocks = [] + with pdfplumber.open(str(pdf_path)) as pdf: + for pno, page in enumerate(pdf.pages, 1): + try: + text = page.extract_text() or "" + except Exception: + text = "" + text = re.sub(r'[ \t]+\n', '\n', text) # unwrap ragged whitespace + text = re.sub(r'\n{3,}', '\n\n', text) + if len(text.strip()) >= 40: + blocks.append((text.strip(), pno)) + return blocks - self.stats['total_chunks'] += 1 + # ---------- Public API ---------- + def ingest_pdf(self, + file_path: str | Path, + entity: Optional[str] = None, + max_rewrite_chunks: int = 30, + min_chunk_score: int = 2, + delete_original_if_rewritten: bool = True, + prefer_tables_first: bool = True + ) -> Tuple[List[Dict[str, Any]], str, 
List[str]]: + """ + Returns (chunks, document_id, original_chunk_ids_to_delete) + """ + start = time.time() + self.stats = {k: 0.0 if 'time' in k else 0 for k in self.stats} + all_chunks: List[Dict[str, Any]] = [] + original_chunks_to_delete: List[str] = [] + doc_id = str(uuid.uuid4()) - if self._should_rewrite(text): - rewrite_candidates.append((text, metadata)) - else: - all_chunks.append({"content": text.strip(), "metadata": metadata}) - except Exception as e: - logger.error(f"❌ PDF read error: {e}") - return [], document_id + file = Path(file_path) + if not file.exists() or not file.is_file() or not file.suffix.lower() == ".pdf": + raise FileNotFoundError(f"Not a PDF: {file_path}") + if not entity or not isinstance(entity, str): + raise ValueError("Entity name must be provided") + entity = entity.strip().lower() - # -------- 3. Rewrite Candidates (if needed) -------- - rewritten_chunks = [] - try: - if self.chunk_rewriter and rewrite_candidates: - logger.info(f"🧠 Rewriting {min(len(rewrite_candidates), max_rewrite_chunks)} of {len(rewrite_candidates)} chunks") - rewrite_candidates = rewrite_candidates[:max_rewrite_chunks] - for text, metadata in rewrite_candidates: - rewritten = self._rewrite_chunk(text, metadata) - metadata = dict(metadata) # make a copy for safety - metadata["rewritten"] = True - rewritten_chunks.append({"content": rewritten, "metadata": metadata}) + # 0) Router: embedded spreadsheets? + embedded = self._find_embedded_spreadsheets(file) + if embedded: + # Save, then delegate to your XLSX flow for each + from your_xlsx_module import XLSXIngester # <-- import your class + xlsx_ingester = XLSXIngester(chunk_rewriter=self.chunk_rewriter) + for fname, data in embedded: + tmp = file.with_name(f"__embedded__{fname}") + with open(tmp, "wb") as f: f.write(data) + x_chunks, _, _ = xlsx_ingester.ingest_xlsx( + tmp, entity=entity, + max_rewrite_chunks=max_rewrite_chunks, + min_chunk_score=min_chunk_score, + delete_original_if_rewritten=delete_original_if_rewritten + ) + # Tag source and page unknown for embedded + for ch in x_chunks: + ch['metadata']['source_pdf'] = str(file) + ch['metadata']['embedded_file'] = fname + all_chunks.append(ch) + try: os.remove(tmp) + except Exception: pass + # Note: continue to extract PDF content as well (often desirable) + + # 1) Tables (Camelot → pdfplumber) + extraction_start = time.time() + table_chunks: List[Dict[str, Any]] = [] + if prefer_tables_first: + dfs = self._extract_tables_with_camelot(file) + if not dfs: + for df, pno in self._extract_tables_with_pdfplumber(file): + rows = self._df_to_rows(df) + if not rows: continue + chunks = self._batch_rows_by_token_count(rows) + for cidx, rows_batch in enumerate(chunks): + content = f"Detected Table (pdfplumber)\n" + "\n".join(rows_batch) + meta = { + "page": pno, + "source": str(file), + "filename": file.name, + "entity": entity, + "document_id": doc_id, + "type": "pdf_table", + "extractor": "pdfplumber" + } + table_chunks.append({'id': f"{doc_id}_chunk_{len(all_chunks)+len(table_chunks)}", + 'content': content, 'metadata': meta}) else: - rewritten_chunks = [{"content": text, "metadata": metadata} for text, metadata in rewrite_candidates] - except Exception as e: - logger.warning(f"⚠️ Error rewriting chunks: {e}") - for text, metadata in rewrite_candidates: - rewritten_chunks.append({"content": text, "metadata": metadata}) + # Camelot doesn't preserve page numbers directly; we’ll mark unknown unless available on t.parsing_report + for t_idx, df in enumerate(dfs): + rows = self._df_to_rows(df) 
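# Sketch (assumption, not part of this diff): page numbers could be filled in here by having
# _extract_tables_with_camelot return the camelot Table objects instead of bare DataFrames,
# e.g. [(t.df, t.parsing_report.get("page")) for t in tables], and then setting meta["page"]
# the same way the pdfplumber branch above does.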
+ if not rows: continue + chunks = self._batch_rows_by_token_count(rows) + for cidx, rows_batch in enumerate(chunks): + content = f"Detected Table (camelot)\n" + "\n".join(rows_batch) + meta = { + "page": None, # could be added by parsing report if needed + "source": str(file), + "filename": file.name, + "entity": entity, + "document_id": doc_id, + "type": "pdf_table", + "extractor": "camelot", + "table_index": t_idx + } + table_chunks.append({'id': f"{doc_id}_chunk_{len(all_chunks)+len(table_chunks)}", + 'content': content, 'metadata': meta}) - all_chunks.extend(rewritten_chunks) + # 2) Prose blocks + prose_chunks: List[Dict[str, Any]] = [] + for text, pno in self._extract_prose_blocks(file): + meta = { + "page": pno, "source": str(file), "filename": file.name, + "entity": entity, "document_id": doc_id, "type": "pdf_page_text" + } + prose_chunks.append({'id': f"{doc_id}_chunk_{len(all_chunks)+len(table_chunks)+len(prose_chunks)}", + 'content': text, 'metadata': meta}) - # -------- 4. Finalize IDs and Metadata -------- - try: - for i, chunk in enumerate(all_chunks): - chunk["id"] = f"{document_id}_chunk_{i}" - chunk.setdefault("metadata", {}) - chunk["metadata"]["document_id"] = document_id - except Exception as e: - logger.warning(f"⚠️ Error finalizing chunk IDs: {e}") + extracted = (table_chunks + prose_chunks) if prefer_tables_first else (prose_chunks + table_chunks) + all_chunks.extend(extracted) + self.stats['extraction_time'] = time.time() - extraction_start + self.stats['total_chunks'] = len(all_chunks) + + # 3) Smart selection + rewriting (same semantics as XLSX) + if self.chunk_rewriter and max_rewrite_chunks > 0 and all_chunks: + # score + selection_start = time.time() + scored = [] + for i, ch in enumerate(all_chunks): + s = self._is_high_value_chunk(ch['content'], ch['metadata']) + if s >= min_chunk_score: + scored.append((ch['content'], ch['metadata'], i, s)) + self.stats['high_value_chunks'] += 1 + scored.sort(key=lambda x: x[3], reverse=True) + to_rewrite = [(t, m, idx) for (t, m, idx, _) in scored[:max_rewrite_chunks]] + self.stats['selection_time'] = time.time() - selection_start + + # rewrite + rewritten = self._batch_rewrite_chunks(to_rewrite) + for new_text, new_meta, original_idx in rewritten: + if new_meta.get('rewritten'): + original_id = all_chunks[original_idx]['id'] + if delete_original_if_rewritten: + # replace in place but mark original id for vector-store deletion + original_chunks_to_delete.append(original_id) + new_id = f"{original_id}_rewritten" + all_chunks[original_idx]['id'] = new_id + all_chunks[original_idx]['content'] = new_text + all_chunks[original_idx]['metadata'] = {**all_chunks[original_idx]['metadata'], **new_meta, + "original_chunk_id": original_id} - self.stats['processing_time'] = time.time() - start_time - logger.info(f"✅ PDF processing complete in {self.stats['processing_time']:.2f}s — Total: {len(all_chunks)}") + self.stats['processing_time'] = time.time() - start + logger.info(f"✅ PDF processed: {file.name} — chunks: {len(all_chunks)}; " + f"extract {self.stats['extraction_time']:.2f}s; " + f"rewrite {self.stats['rewriting_time']:.2f}s") - return all_chunks, document_id + return all_chunks, doc_id, original_chunks_to_delete diff --git a/ai/generative-ai-service/complex-document-rag/files/ingest_xlsx.py b/ai/generative-ai-service/complex-document-rag/files/ingest_xlsx.py index a3729ba4d..22d3d355a 100644 --- a/ai/generative-ai-service/complex-document-rag/files/ingest_xlsx.py +++ 
b/ai/generative-ai-service/complex-document-rag/files/ingest_xlsx.py @@ -116,8 +116,8 @@ def _batch_rows_by_token_count(self, rows: List[str], max_tokens: int = 400) -> return chunks - def _batch_rewrite_chunks(self, chunks_to_rewrite: List[Tuple[str, Dict[str, Any]]]) -> List[Tuple[str, Dict[str, Any]]]: - """Fast parallel batch rewriting""" + def _batch_rewrite_chunks(self, chunks_to_rewrite: List[Tuple[str, Dict[str, Any], int]]) -> List[Tuple[str, Dict[str, Any], int]]: + """Fast parallel batch rewriting - now returns tuples with indices""" if not chunks_to_rewrite or not self.chunk_rewriter: return chunks_to_rewrite @@ -140,22 +140,29 @@ def _batch_rewrite_chunks(self, chunks_to_rewrite: List[Tuple[str, Dict[str, Any logger.info(f"📦 Processing {len(batches)} batches of size {BATCH_SIZE}") - def process_batch(batch_idx: int, batch: List[Tuple[str, Dict[str, Any]]]): - batch_input = [{'text': text, 'metadata': metadata} for text, metadata in batch] + def process_batch(batch_idx: int, batch: List[Tuple[str, Dict[str, Any], int]]): + batch_input = [{'text': text, 'metadata': metadata} for text, metadata, _ in batch] try: logger.info(f" Processing batch {batch_idx + 1}/{len(batches)}") rewritten_texts = self.chunk_rewriter.rewrite_chunks_batch(batch_input, batch_size=BATCH_SIZE) batch_result = [] - for i, (original_text, metadata) in enumerate(batch): + for i, (original_text, metadata, chunk_idx) in enumerate(batch): rewritten_text = rewritten_texts[i] if i < len(rewritten_texts) else None - if rewritten_text and rewritten_text != original_text: + # Check for None (failure) or empty string (failure) explicitly + if rewritten_text is None or rewritten_text == "": + logger.warning(f" ⚠️ Chunk {chunk_idx} rewriting failed, keeping original") + batch_result.append((original_text, metadata, chunk_idx)) + elif rewritten_text != original_text: + # Successfully rewritten and different from original metadata = metadata.copy() metadata["rewritten"] = True + metadata["original_chunk_id"] = f"{metadata.get('document_id', '')}_chunk_{chunk_idx}" self.stats['rewritten_chunks'] += 1 - batch_result.append((rewritten_text, metadata)) + batch_result.append((rewritten_text, metadata, chunk_idx)) else: - batch_result.append((original_text, metadata)) + # Rewritten but same as original (no changes needed) + batch_result.append((original_text, metadata, chunk_idx)) logger.info(f" ✅ Batch {batch_idx + 1} complete") return batch_result @@ -178,19 +185,20 @@ def process_batch(batch_idx: int, batch: List[Tuple[str, Dict[str, Any]]]): else: # Fallback to sequential processing logger.info(f"🔄 Sequential rewriting for {len(chunks_to_rewrite)} chunks") - for text, metadata in chunks_to_rewrite: + for text, metadata, chunk_idx in chunks_to_rewrite: try: rewritten = self.chunk_rewriter.rewrite_chunk(text, metadata=metadata).strip() if rewritten: metadata = metadata.copy() metadata["rewritten"] = True + metadata["original_chunk_id"] = f"{metadata.get('document_id', '')}_chunk_{chunk_idx}" self.stats['rewritten_chunks'] += 1 - results.append((rewritten, metadata)) + results.append((rewritten, metadata, chunk_idx)) else: - results.append((text, metadata)) + results.append((text, metadata, chunk_idx)) except Exception as e: logger.warning(f"Failed to rewrite chunk: {e}") - results.append((text, metadata)) + results.append((text, metadata, chunk_idx)) self.stats['rewriting_time'] = time.time() - start_time return results @@ -200,9 +208,14 @@ def ingest_xlsx( file_path: str | Path, entity: Optional[str] = None, 
max_rewrite_chunks: int = 30, # Reasonable default - min_chunk_score: int = 2 # Only rewrite chunks with score >= 2 - ) -> Tuple[List[Dict[str, Any]], str]: - """Fast XLSX processing with smart chunk selection""" + min_chunk_score: int = 2, # Only rewrite chunks with score >= 2 + delete_original_if_rewritten: bool = True # New parameter + ) -> Tuple[List[Dict[str, Any]], str, List[str]]: + """Fast XLSX processing with smart chunk selection + + Returns: + Tuple of (chunks, document_id, original_chunk_ids_to_delete) + """ start_time = time.time() self.stats = { @@ -216,6 +229,7 @@ def ingest_xlsx( } all_chunks = [] document_id = str(uuid.uuid4()) + original_chunks_to_delete = [] # Validate inputs file = Path(file_path) @@ -297,38 +311,43 @@ def ingest_xlsx( # Smart chunk selection for rewriting selection_start = time.time() if self.chunk_rewriter and max_rewrite_chunks > 0: - # Score all chunks + # Score all chunks and include their indices scored_chunks = [] - for chunk in all_chunks: + for i, chunk in enumerate(all_chunks): score = self._is_high_value_chunk(chunk['content'], chunk['metadata']) if score >= min_chunk_score: - scored_chunks.append((chunk['content'], chunk['metadata'], score)) + scored_chunks.append((chunk['content'], chunk['metadata'], i, score)) self.stats['high_value_chunks'] += 1 # Sort by score and take top N - scored_chunks.sort(key=lambda x: x[2], reverse=True) - chunks_to_rewrite = [(text, meta) for text, meta, _ in scored_chunks[:max_rewrite_chunks]] + scored_chunks.sort(key=lambda x: x[3], reverse=True) + chunks_to_rewrite = [(text, meta, idx) for text, meta, idx, _ in scored_chunks[:max_rewrite_chunks]] self.stats['selection_time'] = time.time() - selection_start - logger.info(f"🎯 Selected {len(chunks_to_rewrite)} high-value chunks from {self.stats['high_value_chunks']} candidates in {self.stats['selection_time']:.2f}s") + logger.info(f"Selected {len(chunks_to_rewrite)} high-value chunks from {self.stats['high_value_chunks']} candidates in {self.stats['selection_time']:.2f}s") if chunks_to_rewrite: # Rewrite selected chunks rewritten = self._batch_rewrite_chunks(chunks_to_rewrite) - # Create mapping for quick lookup - rewritten_map = {} - for text, meta in rewritten: - if meta.get('rewritten'): - key = f"{meta['sheet']}_{meta.get('chunk_index', 0)}" - rewritten_map[key] = text - # Update original chunks with rewritten content - for chunk in all_chunks: - key = f"{chunk['metadata']['sheet']}_{chunk['metadata'].get('chunk_index', 0)}" - if key in rewritten_map: - chunk['content'] = rewritten_map[key] - chunk['metadata']['rewritten'] = True + for rewritten_text, rewritten_meta, original_idx in rewritten: + if rewritten_meta.get('rewritten'): + # Store the original chunk ID for deletion + original_chunk_id = all_chunks[original_idx]['id'] + if delete_original_if_rewritten: + original_chunks_to_delete.append(original_chunk_id) + + # Create NEW ID for rewritten chunk (append _rewritten) + new_chunk_id = f"{original_chunk_id}_rewritten" + + # Update the chunk with rewritten content and NEW ID + all_chunks[original_idx]['id'] = new_chunk_id + all_chunks[original_idx]['content'] = rewritten_text + all_chunks[original_idx]['metadata'] = rewritten_meta + all_chunks[original_idx]['metadata']['original_chunk_id'] = original_chunk_id + + logger.info(f"✅ Replaced chunk {original_idx} with rewritten version (new ID: {new_chunk_id})") self.stats['processing_time'] = time.time() - start_time @@ -339,6 +358,8 @@ def ingest_xlsx( logger.info(f"📊 Total chunks: {len(all_chunks)}") 
logger.info(f"🎯 High-value chunks: {self.stats['high_value_chunks']}") logger.info(f"🔥 Rewritten chunks: {self.stats['rewritten_chunks']}") + if original_chunks_to_delete: + logger.info(f"🗑️ Original chunks to delete: {len(original_chunks_to_delete)}") logger.info(f"\n⏱️ TIMING BREAKDOWN:") logger.info(f" Extraction: {self.stats['extraction_time']:.2f}s") logger.info(f" Selection: {self.stats['selection_time']:.2f}s") @@ -348,7 +369,7 @@ def ingest_xlsx( logger.info(f" Speed: {len(all_chunks)/self.stats['processing_time']:.1f} chunks/sec") logger.info(f"{'='*60}\n") - return all_chunks, document_id + return all_chunks, document_id, original_chunks_to_delete def main(): """CLI interface""" @@ -359,6 +380,7 @@ def main(): parser.add_argument("--max-rewrite", type=int, default=30, help="Maximum chunks to rewrite") parser.add_argument("--min-score", type=int, default=2, help="Minimum score for rewriting (0-5)") parser.add_argument("--no-rewrite", action="store_true", help="Skip chunk rewriting") + parser.add_argument("--keep-originals", action="store_true", help="Keep original chunks even if rewritten") args = parser.parse_args() @@ -388,11 +410,12 @@ def main(): # Process file try: - chunks, doc_id = processor.ingest_xlsx( + chunks, doc_id, chunks_to_delete = processor.ingest_xlsx( args.input, entity=args.entity, max_rewrite_chunks=args.max_rewrite, - min_chunk_score=args.min_score + min_chunk_score=args.min_score, + delete_original_if_rewritten=not args.keep_originals ) # Save results @@ -400,7 +423,8 @@ def main(): result_data = { "document_id": doc_id, "chunks": chunks, - "stats": processor.stats + "stats": processor.stats, + "original_chunks_to_delete": chunks_to_delete } with open(args.output, "w", encoding="utf-8") as f: diff --git a/ai/generative-ai-service/complex-document-rag/files/local_rag_agent.py b/ai/generative-ai-service/complex-document-rag/files/local_rag_agent.py index b5bf1c580..5b8b29f55 100644 --- a/ai/generative-ai-service/complex-document-rag/files/local_rag_agent.py +++ b/ai/generative-ai-service/complex-document-rag/files/local_rag_agent.py @@ -74,7 +74,7 @@ class OCIModelHandler: "grok-4": { "model_id": os.getenv("OCI_GROK_4_MODEL_ID"), "request_type": "generic", - "max_output_tokens": 120000, + "max_output_tokens": 8000, # Reduced from 120000 for faster response "default_params": { "temperature": 1, "top_p": 1 @@ -84,7 +84,7 @@ class OCIModelHandler: "model_id": os.getenv("OCI_GROK_3_MODEL_ID", os.getenv("GROK_MODEL_ID")), "request_type": "generic", - "max_output_tokens": 16000, + "max_output_tokens": 8000, # Reduced from 16000 for consistency "default_params": { "temperature": 0.7, "top_p": 0.9 @@ -94,7 +94,7 @@ class OCIModelHandler: "model_id": os.getenv("OCI_GROK_3_FAST_MODEL_ID", os.getenv("GROK_MODEL_ID")), "request_type": "generic", - "max_output_tokens": 16000, + "max_output_tokens": 4000, # Optimized for speed "default_params": { "temperature": 0.7, "top_p": 0.9 @@ -197,13 +197,29 @@ def __init__(self, model_name: str = "grok-3", config_profile: str = "DEFAULT", region = self.model_config.get("region", "us-chicago-1") self.endpoint = f"https://inference.generativeai.{region}.oci.oraclecloud.com" - # Initialize OCI client + # Initialize OCI client with better retry and timeout settings config = oci.config.from_file("~/.oci/config", config_profile) + + # Create a custom retry strategy for chunk rewriting operations + retry_strategy = oci.retry.RetryStrategyBuilder( + max_attempts=3, + retry_max_wait_between_calls_seconds=10, + retry_base_sleep_time_seconds=2, 
+ retry_exponential_growth_multiplier=2, + retry_eligible_service_errors=[429, 500, 502, 503, 504], + service_error_retry_config={ + -1: [] # Retry on timeout errors + } + ).add_service_error_check( + service_error_retry_config={-1: []}, + service_error_retry_on_any_5xx=True + ).get_retry_strategy() + self.client = oci.generative_ai_inference.GenerativeAiInferenceClient( config=config, service_endpoint=self.endpoint, - retry_strategy=oci.retry.NoneRetryStrategy(), - timeout=(10, 240) + retry_strategy=retry_strategy, + timeout=(30, 120) # Increased timeout: 30s connect, 120s read for chunk rewriting ) print(f"✅ Initialized OCI handler for {model_name}") @@ -359,7 +375,7 @@ def get_model_info(self) -> Dict[str, Any]: class RAGSystem: def __init__(self, vector_store: EnhancedVectorStore = None, model_name: str = None, use_cot: bool = False, skip_analysis: bool = False, - quantization: str = None, use_oracle_db: bool = True, collection: str = "Multi-Collection", + quantization: str = None, use_oracle_db: bool = True, collection: str = "multi", embedding_model: str = "cohere-embed-multilingual-v3.0"): """Initialize local RAG agent with vector store and local LLM @@ -484,9 +500,54 @@ def __init__(self, vector_store: EnhancedVectorStore = None, model_name: str = N tokenizer=self.tokenizer ) logger.info(f"Agents initialized: {list(self.agents.keys())}") + # --- known-tag cache loaded from the vector store; helps identify entities in the query --- + self.known_tags: set[str] = set() + try: + self.refresh_known_tags() + except Exception as e: + logger.warning(f"[RAG] Could not load known tags on init: {e}") + def _vector_store_all_ids(self) -> list[str]: + """ + Return ALL canonical document/entity IDs (tags) from the vector store. + Tries a few common method names to avoid tight coupling. + """ + vs = self.vector_store + # Try common APIs + for attr in ("list_ids", "get_all_ids", "get_all_document_ids", "all_ids"): + if hasattr(vs, attr) and callable(getattr(vs, attr)): + try: + ids = getattr(vs, attr)() + return [str(x) for x in ids] + except Exception as e: + logger.debug(f"[RAG] {self._safe_name(vs)}.{attr} failed: {e}") + # Fallback: try listing collections and aggregating + try: + if hasattr(vs, "list_collections"): + coll_names = vs.list_collections() + ids = [] + for c in coll_names: + try: + ids.extend(vs.list_ids(collection=c)) + except Exception: + pass + return [str(x) for x in ids] + except Exception as e: + logger.debug(f"[RAG] Could not enumerate collections: {e}") + return [] + def refresh_known_tags(self) -> None: + """ + Populate self.known_tags (lowercased) from the vector store. + Call this after any ingest/update that changes IDs. + """ + ids = self._vector_store_all_ids() + self.known_tags = {s.lower() for s in ids if isinstance(s, str)} + logger.info(f"[RAG] known_tags loaded: {len(self.known_tags)}") + @staticmethod + def _safe_name(obj) -> str: + return getattr(obj, "__class__", type(obj)).__name__ + def _initialize_sub_agents(self, llm_model: str) -> bool: """ Initializes agents for agentic workflows (planner, researcher, etc.)
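The known-tag cache added above is only as fresh as the last refresh_known_tags() call, so any ingest path that adds or renames document IDs should trigger a refresh before entity matching relies on it. A minimal sketch of the intended call pattern, assuming a RAGSystem instance named rag and a hypothetical ingest_and_index() helper that wraps the XLSX/PDF processors (only refresh_known_tags and known_tags come from this patch):

    # After any ingest that changes document/entity IDs in the vector store:
    ingest_and_index("q3_report.xlsx", entity="acme")  # hypothetical helper, not part of this patch
    rag.refresh_known_tags()                           # reload the lowercased tag set

    # Lightweight entity detection against the cached tags:
    query = "Compare acme and globex revenue"
    matched = [tag for tag in rag.known_tags if tag in query.lower()]
    # matched can then be passed as provided_entities to the report pipeline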
@@ -521,22 +582,28 @@ def _initialize_sub_agents(self, llm_model: str) -> bool: def process_query_with_multi_collection_context(self, query: str, multi_collection_context: List[Dict[str, Any]], is_comparison_report: bool = False, - collection_mode: str = "multi") -> Dict[str, Any]: - """Process a query with pre-retrieved multi-collection context""" + collection_mode: str = "multi", + provided_entities: Optional[List[str]] = None) -> Dict[str, Any]: + """Process a query with pre-retrieved multi-collection context and optional provided entities""" logger.info(f"Processing query with {len(multi_collection_context)} multi-collection chunks") + if provided_entities: + logger.info(f"Using provided entities: {provided_entities}") if self.use_cot: - return self._process_query_with_report_agent(query, multi_collection_context, is_comparison_report, collection_mode=collection_mode) + return self._process_query_with_report_agent(query, multi_collection_context, is_comparison_report, + collection_mode=collection_mode, provided_entities=provided_entities) else: # For non-CoT mode, use the context directly return self._generate_response(query, multi_collection_context) + def _process_query_with_report_agent( self, query: str, multi_collection_context: Optional[List[Dict[str, Any]]] = None, is_comparison_report: bool = False, - collection_mode: str = "multi" + collection_mode: str = "multi", + provided_entities: Optional[List[str]] = None ) -> Dict[str, Any]: """ Report agent pipeline: @@ -558,8 +625,10 @@ def _process_query_with_report_agent( # STEP 1: Plan the report logger.info("Planning report sections...") + if provided_entities: + logger.info(f"Using provided entities for planning: {provided_entities}") try: - result = planner.plan(query, is_comparison_report=is_comparison_report) + result = planner.plan(query, is_comparison_report=is_comparison_report, provided_entities=provided_entities) if not isinstance(result, tuple) or len(result) != 3: raise ValueError(f"Planner returned unexpected format: {type(result)} → {result}") plan, entities, is_comparison = result @@ -799,7 +868,7 @@ def main(): parser = argparse.ArgumentParser(description="Query documents using local LLM") parser.add_argument("--query", required=True, help="Query to search for") parser.add_argument("--embed", default="oracle", choices=["oracle", "chromadb"], help="embed backend to use") - parser.add_argument("--model", default="qwen2", help="Model to use (default: qwen2)") + parser.add_argument("--model", default="grok-3", help="Model to use (default: grok-3)") parser.add_argument("--collection", help="Collection to search (PDF, Repository, General Knowledge)") parser.add_argument("--use-cot", action="store_true", help="Use Chain of Thought reasoning") parser.add_argument("--store-path", default="embed", help="Path to ChromaDB store") diff --git a/ai/generative-ai-service/complex-document-rag/files/oci_embedding_handler.py b/ai/generative-ai-service/complex-document-rag/files/oci_embedding_handler.py index 2d7a7a19f..82404bb64 100644 --- a/ai/generative-ai-service/complex-document-rag/files/oci_embedding_handler.py +++ b/ai/generative-ai-service/complex-document-rag/files/oci_embedding_handler.py @@ -88,6 +88,9 @@ def __init__(self, config_profile: OCI config profile to use compartment_id: OCI compartment ID """ + # Load environment variables from .env file if not already loaded + load_dotenv() + self.model_name = model_name # Validate model name @@ -100,6 +103,10 @@ def __init__(self, # Set compartment ID - check both
OCI_COMPARTMENT_ID and COMPARTMENT_ID for compatibility self.compartment_id = compartment_id or os.getenv("OCI_COMPARTMENT_ID") or os.getenv("COMPARTMENT_ID") + # Log if compartment ID is missing + if not self.compartment_id: + logger.error("❌ No compartment ID found. Please set COMPARTMENT_ID or OCI_COMPARTMENT_ID in .env file") + # Set endpoint region based on model configuration (supports multiple OCI regions) endpoint_region = self.model_config.get("endpoint", "us-chicago-1") self.endpoint = f"https://inference.generativeai.{endpoint_region}.oci.oraclecloud.com" diff --git a/ai/generative-ai-service/complex-document-rag/files/requirements.txt b/ai/generative-ai-service/complex-document-rag/files/requirements.txt index c1f17bffb..9d81f7410 100644 --- a/ai/generative-ai-service/complex-document-rag/files/requirements.txt +++ b/ai/generative-ai-service/complex-document-rag/files/requirements.txt @@ -13,7 +13,7 @@ pdfplumber==0.11.4 python-docx==1.1.2 # NLP and Embeddings -transformers==4.53.0 +transformers==4.44.2 tokenizers==0.19.1 tiktoken==0.7.0 diff --git a/ai/generative-ai-service/complex-document-rag/files/vector_store.py b/ai/generative-ai-service/complex-document-rag/files/vector_store.py index 04eecc5b4..f5f4ce16f 100644 --- a/ai/generative-ai-service/complex-document-rag/files/vector_store.py +++ b/ai/generative-ai-service/complex-document-rag/files/vector_store.py @@ -4,7 +4,7 @@ Extends the existing VectorStore to support OCI Cohere embeddings alongside ChromaDB defaults """ from oci_embedding_handler import OCIEmbeddingHandler, EmbeddingModelManager -import logging +import logging, numbers from typing import List, Dict, Any, Optional, Union, Tuple from pathlib import Path import chromadb @@ -27,143 +27,137 @@ def __init__(self, *args, **kwargs): "VectorStore is an abstract base class. Use EnhancedVectorStore instead." 
) - - class EnhancedVectorStore(VectorStore): """Enhanced vector store with multi-embedding model support (SAFER VERSION)""" - def __init__(self, persist_directory: str = "embeddings", embedding_model: str = "cohere-embed-multilingual-v3.0", embedder=None): + def __init__(self, persist_directory: str = "embeddings", + embedding_model: str = "cohere-embed-multilingual-v3.0", + embedder=None): self.embedding_manager = EmbeddingModelManager() - self.embedding_model_name = embedding_model # string (name) - self.embedder = embedder # object (has .embed_query/.embed_documents) + self.embedding_model_name = embedding_model + self.embedder = embedder self.embedding_dimensions = getattr(embedder, "model_config", {}).get("dimensions", None) if embedder else None - - # If embedder is provided, use it; otherwise fall back to embedding manager - if embedder: - self.embedding_model = embedder - else: - self.embedding_model = self.embedding_manager.get_model(embedding_model) + # Resolve embedding handler + self.embedding_model = embedder or self.embedding_manager.get_model(embedding_model) + + # Chroma client (ensure Settings import: from chromadb.config import Settings) self.client = chromadb.PersistentClient( path=persist_directory, - settings=Settings(allow_reset=True) + settings=Settings(allow_reset=True, anonymized_telemetry=False) ) - # Always get dimensions from the embedding manager or embedder - embedding_dim = None - if embedder: - # Use the provided embedder's dimensions - info = embedder.get_model_info() - if info and "dimensions" in info: - embedding_dim = info["dimensions"] - else: - raise ValueError( - f"Cannot determine embedding dimensions from provided embedder." - ) - elif isinstance(self.embedding_model, str): - # Try to get from embedding_manager - embedding_info = self.embedding_manager.get_model_info(self.embedding_model_name) - if embedding_info and "dimensions" in embedding_info: - embedding_dim = embedding_info["dimensions"] - else: - raise ValueError( - f"Unknown embedding dimension for model '{self.embedding_model_name}'." - " Please update your EmbeddingModelManager to include this info." - ) - else: - # Should have a get_model_info() method - info = self.embedding_model.get_model_info() - if info and "dimensions" in info: - embedding_dim = info["dimensions"] + # Resolve dimensions once + self._embedding_dim = self._resolve_dimensions() + + # Internal maps/handles + self.collections: dict[str, Any] = {} + self.collection_map = self.collections # alias + + # Create/bind base collections (pdf/xlsx) for current model+dim + self._ensure_base_collections(self._embedding_dim) + + logger.info(f"✅ Enhanced vector store initialized with {self.embedding_model_name} ({self._embedding_dim}D)") + + # --- Utility: sanitize metadata before sending to Chroma --- + def _safe_metadata(self, metadata: dict) -> dict: + """Ensure Chroma-compatible metadata (convert everything non-str → str).""" + safe = {} + for k, v in (metadata or {}).items(): + key = str(k) + if isinstance(v, str): + safe[key] = v + elif isinstance(v, numbers.Number): # catches numpy.int64, Decimal, etc. + safe[key] = str(v) + elif v is None: + continue else: - raise ValueError( - f"Cannot determine embedding dimensions for non-string embedding model {self.embedding_model}." 
- ) + safe[key] = str(v) + return safe + + def _as_int(self, x): + try: + return int(x) + except Exception: + return None + def _resolve_dimensions(self) -> int: + if self.embedder: + info = self.embedder.get_model_info() + if info and "dimensions" in info: + return int(info["dimensions"]) + raise ValueError("Cannot determine embedding dimensions from provided embedder.") + if isinstance(self.embedding_model, str): + info = self.embedding_manager.get_model_info(self.embedding_model_name) + if info and "dimensions" in info: + return int(info["dimensions"]) + raise ValueError(f"Unknown embedding dimension for model '{self.embedding_model_name}'.") + # non-string handler + info = self.embedding_model.get_model_info() + if info and "dimensions" in info: + return int(info["dimensions"]) + raise ValueError("Cannot determine embedding dimensions for non-string embedding model.") + + def _ensure_base_collections(self, embedding_dim: int): + base_collection_names = ["pdf_documents", "xlsx_documents"] metadata = { "hnsw:space": "cosine", - "embedding_model": self.embedding_model_name, - "embedding_dimensions": embedding_dim + "embedding_model": self.embedding_model_name, # keep int in memory + "embedding_dimensions": embedding_dim # keep int in memory } - base_collection_names = [ - "pdf_documents", "xlsx_documents" - ] - - self.collections = {} - for base_name in base_collection_names: full_name = f"{base_name}_{self.embedding_model_name}_{embedding_dim}" - try: - # Check for exact match first - existing_collections = self.client.list_collections() - by_name = {c.name: c for c in existing_collections} - if full_name in by_name: - coll = by_name[full_name] - actual_dim = coll.metadata.get("embedding_dimensions", None) - if actual_dim != embedding_dim: - # This should never happen unless DB is corrupt - logger.error( - f"❌ Dimension mismatch for collection '{full_name}'. Expected {embedding_dim}, found {actual_dim}." - ) - raise ValueError( - f"Collection '{full_name}' has dim {actual_dim}, but expected {embedding_dim}." - ) - collection = coll - logger.info(f"🎯 Using existing collection '{full_name}' ({embedding_dim}D, {coll.count()} chunks)") - else: - # Safe: only ever create the *fully qualified* name - collection = self.client.get_or_create_collection( - name=full_name, - metadata=metadata - ) - logger.info(f"🗂️ Created new collection '{full_name}' with dimension {embedding_dim}") + # Prefer fast path: get_or_create with safe metadata + coll = self.client.get_or_create_collection( + name=full_name, + metadata=self._safe_metadata(metadata) # ← sanitize only here + ) - self.collections[full_name] = collection + # Defensive dim check (cast back to int if Chroma stored as str) + actual_dim = self._as_int((coll.metadata or {}).get("embedding_dimensions")) + if actual_dim and actual_dim != embedding_dim: + logger.error(f"❌ Dimension mismatch for '{full_name}'. 
Expected {embedding_dim}, found {actual_dim}.") + raise ValueError(f"Collection '{full_name}' has dim {actual_dim}, expected {embedding_dim}.") - # For direct access: always the selected model/dim + self.collections[full_name] = coll if base_name == "pdf_documents": - self.pdf_collection = collection - elif base_name == "xlsx_documents": - self.xlsx_collection = collection + self.pdf_collection = coll + self.current_pdf_collection_name = full_name + else: + self.xlsx_collection = coll + self.current_xlsx_collection_name = full_name + logger.info(f"🗂️ Ready collection '{full_name}' ({embedding_dim}D, {coll.count()} chunks)") except Exception as e: logger.error(f"❌ Failed to create or get collection '{full_name}': {e}") raise - # Only include full names in the map; never ambiguous short names - self.collection_map = self.collections - - logger.info(f"✅ Enhanced vector store initialized with {embedding_model} ({embedding_dim}D)") - - def get_collection_key(self, base_name: str) -> str: - # Build the correct key for a base collection name - embedding_dim = ( - self.get_embedding_info()["dimensions"] - if hasattr(self, "get_embedding_info") - else 1024 - ) - return f"{base_name}_{self.embedding_model_name}_{embedding_dim}" + return f"{base_name}_{self.embedding_model_name}_{self._embedding_dim}" def _find_collection_variants(self, base_name: str): """ - Yield (name, collection) for all collections in the DB that start with base_name + "_", - across ANY embedding model/dimension (not just the ones cached at init). + Yield (name, collection) for all collections that start with base_name+"_". + Never create here—only fetch existing collections. """ for c in self.client.list_collections(): try: - name = c.name - except Exception: - # Some clients return plain dicts name = getattr(c, "name", None) or (c.get("name") if isinstance(c, dict) else None) + except Exception: + name = None if not name: continue - if name.startswith(base_name + "_"): - # get_or_create is fine; if it exists it just returns it - yield name, self.client.get_or_create_collection(name=name) + if not name.startswith(base_name + "_"): + continue + try: + coll = self.client.get_collection(name=name) # ← get (NOT get_or_create) + yield name, coll + except Exception as e: + logger.warning(f"Skip collection {name}: {e}") + def list_documents(self, collection_name: str) -> List[Dict[str, Any]]: """ @@ -524,145 +518,249 @@ def _add_cite(self, meta: Union[Dict[str, Any], "Metadata"]) -> Dict[str, Any]: return meta + def delete_chunks(self, collection_name: str, chunk_ids: List[str]): + """Delete specific chunks from a collection by their IDs + + Args: + collection_name: Name of the collection (e.g., 'xlsx_documents', 'pdf_documents') + chunk_ids: List of chunk IDs to delete + """ + if not chunk_ids: + return + + try: + # Get the appropriate collection + if collection_name == "xlsx_documents": + collection = self.xlsx_collection + elif collection_name == "pdf_documents": + collection = self.pdf_collection + else: + # Try to get from collection map + collection = self.collection_map.get(collection_name) + if not collection: + # Try with current model/dimension suffix + full_name = self.get_collection_key(collection_name) + collection = self.collection_map.get(full_name) + + if not collection: + logger.error(f"Collection {collection_name} not found") + return + + # Delete the chunks + collection.delete(ids=chunk_ids) + logger.info(f"✅ Deleted {len(chunk_ids)} chunks from {collection_name}") + + except Exception as e: + logger.error(f"❌ 
Failed to delete chunks: {e}") + raise + def add_xlsx_chunks(self, chunks: List[Dict[str, Any]], document_id: str): """Add XLSX chunks to the vector store with proper embedding handling""" if not chunks: return - + # Extract texts and metadata - texts = [chunk["content"] for chunk in chunks] - metadatas = [chunk["metadata"] for chunk in chunks] - ids = [chunk["id"] for chunk in chunks] - - # Check collection metadata to see what dimensions are expected + texts = [c["content"] for c in chunks] + metadatas = [self._add_cite(c.get("metadata", {})) for c in chunks] # add cite & normalize + ids = [c["id"] for c in chunks] + + # Normalize expected dimensions/model from collection metadata collection_metadata = self.xlsx_collection.metadata or {} - expected_dimensions = collection_metadata.get('embedding_dimensions') - expected_model = collection_metadata.get('embedding_model') - - # Handle embeddings based on model type + expected_dimensions = self._as_int(collection_metadata.get("embedding_dimensions")) + expected_model = collection_metadata.get("embedding_model") + + # Path A: chroma-default (Chroma embeds on add) if isinstance(self.embedding_model, str): - # ChromaDB default - let ChromaDB handle embeddings + # If the collection expects non-384, error early (your policy) if expected_dimensions and expected_dimensions != 384: logger.error(f"❌ Collection expects {expected_dimensions}D but using ChromaDB default (384D)") - raise ValueError(f"Dimension mismatch: collection expects {expected_dimensions}D, ChromaDB default is 384D") - - self.xlsx_collection.add( - documents=texts, - metadatas=metadatas, - ids=ids - ) - else: - # Use OCI embeddings + raise ValueError( + f"Dimension mismatch: collection expects {expected_dimensions}D, ChromaDB default is 384D" + ) + + # Optional: warn if the collection was created without an embedding function bound (older Chroma) try: - embeddings = self.embedding_model.embed_documents(texts) - actual_dimensions = len(embeddings[0]) if embeddings and embeddings[0] else 0 - - if expected_dimensions and actual_dimensions != expected_dimensions: - # Try to find or create the correct collection - correct_collection_name = f"xlsx_documents_{self.embedding_model_name}_{actual_dimensions}" - logger.warning(f"⚠️ Dimension mismatch: collection '{self.xlsx_collection.name}' expects {expected_dimensions}D, embedder produces {actual_dimensions}D") - logger.info(f"🔍 Looking for correct collection: {correct_collection_name}") - - try: - # Try to get the correct collection - correct_collection = self.client.get_collection(correct_collection_name) - logger.info(f"✅ Found correct collection: {correct_collection_name}") - except: - # Create new collection with correct dimensions - metadata = { - "hnsw:space": "cosine", - "embedding_model": self.embedding_model_name, - "embedding_dimensions": actual_dimensions - } - correct_collection = self.client.create_collection( - name=correct_collection_name, - metadata=metadata - ) - logger.info(f"✅ Created new collection: {correct_collection_name}") - - # Add to the correct collection - correct_collection.add( - documents=texts, - metadatas=metadatas, - ids=ids, - embeddings=embeddings - ) - - # Update the reference for future use - self.xlsx_collection = correct_collection - self.collections[correct_collection_name] = correct_collection - - logger.info(f"✅ Added {len(chunks)} XLSX chunks to {correct_collection_name}") - else: - # Dimensions match, proceed normally - self.xlsx_collection.add( - documents=texts, - metadatas=metadatas, - 
ids=ids, - embeddings=embeddings - ) - logger.info(f"✅ Added {len(chunks)} XLSX chunks to {self.embedding_model_name}") - + self.xlsx_collection.add(documents=["probe"], metadatas=[{}], ids=["__probe__tmp__"]) + self.xlsx_collection.delete(ids=["__probe__tmp__"]) except Exception as e: - logger.error(f"❌ Failed to add chunks with OCI embeddings: {e}") - raise # Don't silently fall back - this causes dimension mismatches + logger.warning(f"⚠️ Chroma default embedding may not be bound; add() failed probe: {e}") + + # Add documents directly (Chroma will embed) + # Consider batching if many chunks + self.xlsx_collection.add(documents=texts, metadatas=metadatas, ids=ids) + logger.info(f"✅ Added {len(chunks)} XLSX chunks to {self.embedding_model_name} (chroma-default)") + return + + # Path B: OCI (you provide embeddings explicitly) + try: + embeddings = self.embedding_model.embed_documents(texts) + if not embeddings or not embeddings[0] or not hasattr(embeddings[0], "__len__"): + raise RuntimeError("Embedder returned empty/invalid embeddings") + + actual_dimensions = len(embeddings[0]) + + if expected_dimensions and actual_dimensions != expected_dimensions: + # Try to find or create the correct collection + correct_collection_name = f"xlsx_documents_{self.embedding_model_name}_{actual_dimensions}" + logger.warning( + f"⚠️ Dimension mismatch: collection '{self.xlsx_collection.name}' " + f"expects {expected_dimensions}D, embedder produces {actual_dimensions}D" + ) + logger.info(f"🔍 Looking for correct collection: {correct_collection_name}") + + try: + correct_collection = self.client.get_collection(correct_collection_name) + logger.info(f"✅ Found correct collection: {correct_collection_name}") + except Exception: + # Create new collection with correct dimensions (sanitize metadata for Chroma) + metadata = { + "hnsw:space": "cosine", + "embedding_model": self.embedding_model_name, + "embedding_dimensions": actual_dimensions, # keep as int internally + } + correct_collection = self.client.create_collection( + name=correct_collection_name, + metadata=self._safe_metadata(metadata) # ← sanitize only here + ) + logger.info(f"✅ Created new collection: {correct_collection_name}") + + # Add to the correct collection (explicit vectors) + # Consider batching if many chunks + correct_collection.add( + documents=texts, + metadatas=metadatas, + ids=ids, + embeddings=embeddings + ) + + # Update the reference for future use + self.xlsx_collection = correct_collection + self.collections[correct_collection_name] = correct_collection + + logger.info(f"✅ Added {len(chunks)} XLSX chunks to {correct_collection_name}") + else: + # Dimensions match, proceed normally + self.xlsx_collection.add( + documents=texts, + metadatas=metadatas, + ids=ids, + embeddings=embeddings + ) + logger.info(f"✅ Added {len(chunks)} XLSX chunks to {self.embedding_model_name}") + + except Exception as e: + logger.error(f"❌ Failed to add chunks with OCI embeddings: {e}") + raise # Keep explicit; prevents silent dimension drift + def add_pdf_chunks(self, chunks: List[Dict[str, Any]], document_id: str): - """Add PDF chunks to the vector store with proper embedding handling""" + """Add PDF chunks to the vector store with proper embedding handling.""" if not chunks: return - - # Extract texts and metadata - texts = [chunk["content"] for chunk in chunks] - metadatas = [chunk["metadata"] for chunk in chunks] - ids = [chunk["id"] for chunk in chunks] - - # Check collection metadata to see what dimensions are expected - collection_metadata = 
self.pdf_collection.metadata or {} - expected_dimensions = collection_metadata.get('embedding_dimensions') - expected_model = collection_metadata.get('embedding_model') - - # Handle embeddings based on model type and expected dimensions + + # Extract texts and metadata; add cite + normalize metadata + texts = [c["content"] for c in chunks] + metadatas = [self._add_cite(c.get("metadata", {})) for c in chunks] + ids = [c["id"] for c in chunks] + + # Collection expectations (cast back to int to avoid string/int mismatches) + coll_meta = self.pdf_collection.metadata or {} + expected_dimensions = self._as_int(coll_meta.get("embedding_dimensions")) + expected_model = coll_meta.get("embedding_model") + + # A) chroma-default path (Chroma embeds on add) if isinstance(self.embedding_model, str): - # String identifier - check if it matches expected model if expected_model and self.embedding_model_name != expected_model: - logger.warning(f"⚠️ Model mismatch: collection expects '{expected_model}', got '{self.embedding_model_name}'") - - if expected_dimensions == 384 or self.embedding_model_name == "chromadb-default": - # ChromaDB default - let ChromaDB handle embeddings - logger.info(f"📝 Using ChromaDB default embeddings ({expected_dimensions or 384}D)") - self.pdf_collection.add( + logger.warning( + f"⚠️ Model mismatch: collection expects '{expected_model}', got '{self.embedding_model_name}'" + ) + + # Your policy: chroma-default is 384D only + if expected_dimensions and expected_dimensions != 384: + raise ValueError( + f"Dimension mismatch: collection expects {expected_dimensions}D, " + f"but chroma-default produces 384D. Recreate the collection with chroma-default " + f"or switch to the correct OCI embedder." + ) + + # Optional: probe add for older Chroma builds without an embedding_function bound + try: + self.pdf_collection.add(documents=["__probe__"], metadatas=[{}], ids=["__probe__"]) + self.pdf_collection.delete(ids=["__probe__"]) + except Exception as e: + logger.warning(f"⚠️ Chroma default embedder may not be bound; add() probe failed: {e}") + + # Add (consider batching if very large) + self.pdf_collection.add(documents=texts, metadatas=metadatas, ids=ids) + logger.info(f"✅ Added {len(chunks)} PDF chunks via chroma-default (384D)") + return + + # B) OCI path (explicit embeddings) + try: + embeddings = self.embedding_model.embed_documents(texts) + if not embeddings or not embeddings[0] or not hasattr(embeddings[0], "__len__"): + raise RuntimeError("Embedder returned empty/invalid embeddings") + + actual_dimensions = len(embeddings[0]) + + # If the target collection's dim doesn't match, route/create the correct one + if expected_dimensions and actual_dimensions != expected_dimensions: + logger.warning( + f"⚠️ Dimension mismatch: collection '{self.pdf_collection.name}' expects " + f"{expected_dimensions}D, embedder produced {actual_dimensions}D" + ) + correct_name = f"pdf_documents_{self.embedding_model_name}_{actual_dimensions}" + try: + correct_collection = self.client.get_collection(correct_name) + # Sanity check: if it already contains data of a different dim (shouldn’t happen), bail + probe_meta = correct_collection.metadata or {} + probe_dim = self._as_int(probe_meta.get("embedding_dimensions")) + if probe_dim and probe_dim != actual_dimensions: + raise RuntimeError( + f"Existing collection '{correct_name}' is {probe_dim}D, expected {actual_dimensions}D" + ) + logger.info(f"✅ Found correct PDF collection: {correct_name}") + except Exception: + # Create with sanitized metadata (only at API 
boundary) + md = { + "hnsw:space": "cosine", + "embedding_model": self.embedding_model_name, + "embedding_dimensions": actual_dimensions, # keep int internally + } + correct_collection = self.client.get_or_create_collection( + name=correct_name, + metadata=self._safe_metadata(md) # sanitize here + ) + logger.info(f"🆕 Created PDF collection: {correct_name}") + + # Add to the correct collection + correct_collection.add( documents=texts, metadatas=metadatas, - ids=ids + ids=ids, + embeddings=embeddings ) + + # Re-point handles + self.pdf_collection = correct_collection + self.collections[correct_name] = correct_collection + self.current_pdf_collection_name = correct_name + + logger.info(f"✅ Added {len(chunks)} PDF chunks to {correct_name}") else: - # Expected OCI model but got string - this is a configuration error - logger.error(f"❌ Configuration error: Expected {expected_model} ({expected_dimensions}D) but OCI embedding handler failed to initialize") - logger.error(f"💡 Falling back to ChromaDB default, but this will cause dimension mismatch!") - raise ValueError(f"Cannot add {expected_dimensions}D embeddings using ChromaDB default (384D). Please fix OCI configuration or recreate collection with chromadb-default.") - else: - # Use OCI embeddings - try: - embeddings = self.embedding_model.embed_documents(texts) - actual_dimensions = len(embeddings[0]) if embeddings and embeddings[0] else 0 - - if expected_dimensions and actual_dimensions != expected_dimensions: - logger.error(f"❌ Dimension mismatch: collection expects {expected_dimensions}D, embedder produces {actual_dimensions}D") - raise ValueError(f"Dimension mismatch: collection expects {expected_dimensions}D, got {actual_dimensions}D") - - logger.info(f"📝 Using OCI embeddings ({actual_dimensions}D)") + # Dimensions match; add directly self.pdf_collection.add( documents=texts, metadatas=metadatas, ids=ids, embeddings=embeddings ) - except Exception as e: - logger.error(f"❌ Failed to add PDF chunks with OCI embeddings: {e}") - raise # Don't fall back silently - this causes dimension mismatches - - logger.info(f"✅ Added {len(chunks)} PDF chunks to {self.embedding_model_name}") + logger.info(f"✅ Added {len(chunks)} PDF chunks ({actual_dimensions}D)") + + except Exception as e: + logger.error(f"❌ Failed to add PDF chunks with OCI embeddings: {e}") + raise # keep explicit; prevents silent dimension drift + @@ -875,7 +973,7 @@ def query_pdf_collection( } self.pdf_collection = self.client.get_or_create_collection( name=correct_collection_name, - metadata=metadata + metadata=self._safe_metadata(metadata) ) logger.info(f"✅ Created new PDF collection: {correct_collection_name}") actual_dim = handler_dim @@ -923,75 +1021,6 @@ def query_pdf_collection( return [] - def OLD_query_pdf_collection(self, query: str, n_results: int = 3, entity: Optional[str] = None, add_cite: bool = False) -> List[Dict[str, Any]]: - """Query PDF collection with embedding support and optional citation markup.""" - try: - # Build filter - where_filter = {"entity": entity.lower()} if entity else None - - # ✅ Minimal guard – blow up early if dims mismatch - if (self.pdf_collection.metadata or {}).get("embedding_dimensions") != (self.get_embedding_info() or {}).get("dimensions"): - raise ValueError( - f"EMBEDDING_DIMENSION_MISMATCH: collection expects " - f"{(self.pdf_collection.metadata or {}).get('embedding_dimensions')}D, " - f"current handler has {(self.get_embedding_info() or {}).get('dimensions')}D" - ) - - # Query by embedding or text, depending on backend - if 
isinstance(self.embedding_model, str): - # ChromaDB default - results = self.pdf_collection.query( - query_texts=[query], - n_results=n_results, - where=where_filter, - include=["documents", "metadatas", "distances"] - ) - else: - try: - query_embedding = self.embedding_model.embed_query(query) - results = self.pdf_collection.query( - query_embeddings=[query_embedding], - n_results=n_results, - where=where_filter, - include=["documents", "metadatas", "distances"] - ) - except Exception as e: - logger.error(f"❌ OCI query embedding failed: {e}") - # Fallback to text query - results = self.pdf_collection.query( - query_texts=[query], - n_results=n_results, - where=where_filter, - include=["documents", "metadatas", "distances"] - ) - - # Format results with optional citation - formatted_results = [] - docs = results.get("documents", [[]])[0] - metas = results.get("metadatas", [[]])[0] - dists = results.get("distances", [[]])[0] if "distances" in results else [0.0] * len(docs) - - for i, (doc, meta, dist) in enumerate(zip(docs, metas, dists)): - out = { - "content": doc, - "metadata": meta if meta else {}, - "distance": dist - } - if add_cite and hasattr(self, "_add_cite"): - meta_with_cite = self._add_cite(meta) - out["metadata"] = meta_with_cite - out["content"] = f"{doc} {meta_with_cite['cite']}" - formatted_results.append(out) - - return formatted_results - - except Exception as e: - logger.error(f"❌ Error querying PDF collection: {e}") - return [] - - - - def inspect_xlsx_chunk_metadata(self, limit: int = 10): """ Print stored metadata from the XLSX vector store for debugging. @@ -1070,14 +1099,21 @@ def bind_collections_for_model(self, embedding_model: str) -> None: "embedding_model": self.embedding_model_name, "embedding_dimensions": embedding_dim } - + logger.info( + "Create/get collections: PDF=%r, XLSX=%r | meta=%r (dim_field=%s:%s)", + pdf_name, + xlsx_name, + metadata, + "embedding_dimensions", + type(metadata.get("embedding_dimensions")).__name__, + ) self.pdf_collection = self.client.get_or_create_collection( name=pdf_name, - metadata=metadata + metadata=self._safe_metadata(metadata) ) self.xlsx_collection = self.client.get_or_create_collection( name=xlsx_name, - metadata=metadata + metadata=self._safe_metadata(metadata) ) # Cache for debugging
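Taken together, the ingest and vector-store changes form a replace-then-clean workflow: ingestion now returns the IDs of originals superseded by rewritten chunks, and delete_chunks() removes them once the new chunks are indexed. A minimal end-to-end sketch under that assumption; XLSXProcessor is a hypothetical stand-in for the ingest class constructed in main(), and the constructor arguments are illustrative:

    from oci_embedding_handler import OCIEmbeddingHandler
    from vector_store import EnhancedVectorStore

    embedder = OCIEmbeddingHandler(model_name="cohere-embed-multilingual-v3.0")
    store = EnhancedVectorStore(persist_directory="embeddings",
                                embedding_model="cohere-embed-multilingual-v3.0",
                                embedder=embedder)

    processor = XLSXProcessor()  # hypothetical; the real ingest class is not shown in this diff
    chunks, doc_id, stale_ids = processor.ingest_xlsx(
        "q3_report.xlsx",
        entity="acme",
        max_rewrite_chunks=30,
        min_chunk_score=2,
        delete_original_if_rewritten=True,
    )

    store.add_xlsx_chunks(chunks, doc_id)                 # rewritten chunks carry *_rewritten IDs
    if stale_ids:
        store.delete_chunks("xlsx_documents", stale_ids)  # drop the superseded originals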