### Imports

In [167]:
from __future__ import annotations
from pathlib import Path
from typing import List, Dict, Any, Tuple, Literal
from PIL import Image
from collections import Counter, defaultdict
import re, math, json
import fitz           
import pdfplumber
from unstructured.partition.pdf import partition_pdf
from dotenv import load_dotenv
import anthropic
import hashlib
from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk
from operator import itemgetter
import pandas as pd
import cohere
from copy import deepcopy

import base64
import io
import os
from io import BytesIO
import tqdm

from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI

from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.retrievers.multi_vector import SearchType
from langchain.storage import InMemoryStore
from langchain_core.documents import Document
from langchain_chroma import Chroma
from langchain_openai import OpenAIEmbeddings
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from uuid import uuid4
import time
import gzip, pickle

In [135]:
load_dotenv()

True

In [3]:
client = anthropic.Anthropic()

In [4]:
pdf_dir = '../instructions/Dataset/'
image_output_path = '../generated_images/'

### Разбиение на чанки

In [5]:

NUM_RE = re.compile(r"^[\s\-\+\(\)]*[\d]+([.,\s]\d+)*\s*[%₸$€₽KZTUSD EURRUB]*$", re.I)

def is_numericish(s: str) -> bool:
    s = (s or "").strip()
    if not s: return False
    if re.search(r"\d", s) and len(s) <= 32:
        return True if NUM_RE.search(s) else False
    return False

def frac_numeric_cells(table_rows: List[List[str]]) -> float:
    cells = [c for row in table_rows for c in row]
    if not cells: return 0.0
    return sum(is_numericish(c) for c in cells) / len(cells)

def frac_numeric_rows(table_rows: List[List[str]]) -> float:
    if not table_rows: return 0.0
    cnt = 0
    for row in table_rows:
        if not row: 
            continue
        nonempty = [c for c in row if (c or "").strip()]
        if not nonempty:
            continue
        if sum(is_numericish(c) for c in nonempty) / len(nonempty) >= 0.5:
            cnt += 1
    return cnt / max(1, len(table_rows))

def table_quality_ok(rows: List[List[str]]) -> bool:
    if not rows: return False
    R = len(rows)
    C = max((len(r) for r in rows), default=0)
    if R < 2 or C < 2:
        return False
    cells = [c for r in rows for c in r]
    nonempty_frac = sum(1 for c in cells if (c or "").strip()) / max(1, len(cells))
    if nonempty_frac < 0.35:
        return False
    if frac_numeric_rows(rows) < 0.25:  
        return False
    mean_len = sum(len((c or "").strip()) for c in cells) / max(1, len(cells))
    if mean_len > 120:
        return False
    return True

def list_to_camelot_pages(pages_1based: List[int]) -> str:
    if not pages_1based: return ""
    pages = sorted(set(pages_1based))
    rngs, s = [], pages[0]
    for p in pages[1:]:
        if p != pages[pages.index(p)-1] + 1:
            rngs.append((s, pages[pages.index(p)-1]))
            s = p
    rngs.append((s, pages[-1]))
    return ",".join(f"{a}-{b}" if a != b else f"{a}" for a,b in rngs)

# ===================== 1) скрининг страниц =====================

def screen_pages(pdf_path: str, min_words: int = 10) -> Tuple[List[int], List[int]]:
    text_pages, image_pages = [], []
    doc = fitz.open(pdf_path)
    try:
        for i in range(len(doc)):
            words = doc[i].get_text("words")
            (text_pages if len(words) >= min_words else image_pages).append(i+1)
    finally:
        doc.close()
    return text_pages, image_pages

# ===================== 2) оценка «табличности» страницы =====================

def rulings_score(pdf_path: str, page_no: int, min_len: float = 60.0) -> float:
    """
    Считает количество длинных горизонтальных/вертикальных линий на странице
    по векторной графике (drawings). Возвращает скаляр 0..1.
    Поддерживает tuple-формат items из get_drawings().
    """
    def as_xy(pt):
        # pt может быть fitz.Point или (x, y)
        try:
            return float(pt[0]), float(pt[1])
        except Exception:
            return float(pt.x), float(pt.y)

    doc = fitz.open(pdf_path)
    try:
        page = doc[page_no - 1]
        drawings = page.get_drawings()
        H = V = 0

        for d in drawings:
            items = d["items"] if isinstance(d, dict) else getattr(d, "items", [])
            for it in items:
                if isinstance(it, tuple):
                    op = it[0]

                    if op == "l" and len(it) >= 3:
                        p0, p1 = it[1], it[2]
                        x0, y0 = as_xy(p0)
                        x1, y1 = as_xy(p1)
                        dx, dy = abs(x1 - x0), abs(y1 - y0)
                        length = math.hypot(dx, dy)
                        if length >= min_len:
                            if dy <= 1.0: H += 1
                            if dx <= 1.0: V += 1

                    elif op == "re" and len(it) >= 2:
                        rect = it[1]
                        try:
                            x0, y0, x1, y1 = float(rect.x0), float(rect.y0), float(rect.x1), float(rect.y1)
                        except Exception:
                            x0, y0, x1, y1 = rect  # если это tuple
                        edges = [
                            (x0, y0, x1, y0), (x1, y0, x1, y1),
                            (x1, y1, x0, y1), (x0, y1, x0, y0),
                        ]
                        for ex0, ey0, ex1, ey1 in edges:
                            dx, dy = abs(ex1 - ex0), abs(ey1 - ey0)
                            length = math.hypot(dx, dy)
                            if length >= min_len:
                                if dy <= 1.0: H += 1
                                if dx <= 1.0: V += 1


                elif isinstance(it, dict):
                    pts = it.get("points")
                    if pts and len(pts) >= 2:
                        p0, p1 = pts[0], pts[-1]
                        x0, y0 = as_xy(p0)
                        x1, y1 = as_xy(p1)
                        dx, dy = abs(x1 - x0), abs(y1 - y0)
                        length = math.hypot(dx, dy)
                        if length >= min_len:
                            if dy <= 1.0: H += 1
                            if dx <= 1.0: V += 1

        return min(1.0, (H / 6.0 + V / 4.0))
    finally:
        doc.close()

def column_grid_score(pdf_path: str, page_no: int, x_tol: int = 6) -> float:
    """
    На основе слов: считаем повторяющиеся x-координаты начала слов (кластеры колонок).
    Чем больше устойчивых вертикалей и строк, тем «табличнее».
    """
    doc = fitz.open(pdf_path)
    try:
        page = doc[page_no-1]
        words = page.get_text("words") 
        if not words:
            return 0.0
        rows = {}
        for x0,y0,x1,y1,txt, *_ in words:
            key = round(y0/5)*5
            rows.setdefault(key, []).append((x0, txt))
        xbins = Counter()
        row_cnt = 0
        for _, items in rows.items():
            if len(items) < 3:
                continue
            row_cnt += 1
            xs = [round(x/x_tol)*x_tol for x,_ in items]
            for x in set(xs):
                xbins[x] += 1
        if row_cnt == 0:
            return 0.0
        stable_cols = sum(1 for _,c in xbins.items() if c >= max(2, int(0.3*row_cnt)))
        # нормировка: 0..1
        return min(1.0, stable_cols / 6.0)
    finally:
        doc.close()

def pick_table_pages(pdf_path: str, text_pages: List[int],
                     thr_rulings: float = 0.25, thr_cols: float = 0.35) -> List[int]:
    candidates = []
    for p in text_pages:
        rs = rulings_score(pdf_path, p)
        cs = column_grid_score(pdf_path, p)
        if (rs >= thr_rulings and cs >= thr_cols) or (rs >= thr_rulings*2) or (cs >= thr_cols*1.5):
            candidates.append(p)
    return candidates

# ===================== 3) извлечение таблиц только на кандидатов =====================

def extract_tables_on_candidates(pdf_path: str, candidate_pages: List[int]) -> List[Dict[str, Any]]:
    tables: List[Dict[str, Any]] = []
    if not candidate_pages:
        return tables

    pages_str = list_to_camelot_pages(candidate_pages)

    def _add_rows(rows: List[List[str]], page: int, parser: str):
        if not table_quality_ok(rows):
            return
        csv = "\n".join([",".join([(c or "").replace(",", " ") for c in r]) for r in rows])
        flat = "\n".join([",".join([(c or "") for c in r]) for r in rows])
        tables.append({
            "type": "table",
            "parser": parser,
            "page": page,
            "text": flat,
            "table_csv": csv,
            "source_pdf": pdf_path,
        })

    tried_camelot = False
    try:
        import camelot
        tried_camelot = True
        try:
            tabs = camelot.read_pdf(pdf_path, pages=pages_str, flavor="lattice",
                                    line_scale=40, process_background=True, strip_text="\n")
            for t in tabs:
                page = t.parsing_report.get("page")
                rows = t.df.values.tolist()
                _add_rows(rows, page, "camelot-lattice")
        except Exception:
            pass
        try:
            tabs = camelot.read_pdf(pdf_path, pages=pages_str, flavor="stream",
                                    edge_tol=150, row_tol=10, column_tol=20, strip_text="\n")
            for t in tabs:
                page = t.parsing_report.get("page")
                rows = t.df.values.tolist()
                _add_rows(rows, page, "camelot-stream")
        except Exception:
            pass
    except Exception:
        pass

    with pdfplumber.open(pdf_path) as pdf:
        for p in candidate_pages:
            page = pdf.pages[p-1]
            settings = dict(
                vertical_strategy="lines",
                horizontal_strategy="lines",
                snap_tolerance=3,
                join_tolerance=3,
                edge_min_length=20,
                intersection_x_tolerance=5,
                intersection_y_tolerance=5,
            )
            found = page.extract_tables(settings) or []
            if not found:
                found = page.extract_tables() or []
            for tbl in found:
                rows = [[(c or "").strip() for c in row] for row in tbl]
                _add_rows(rows, p, "pdfplumber")
    return dedup_tables(tables)

def dedup_tables(tables: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    seen = set()
    out = []
    for t in tables:
        sig = (t["page"], "\n".join(t["text"].split("\n")[:2]))
        if sig in seen:
            continue
        seen.add(sig)
        out.append(t)
    return out

# ===================== 4) текст и картинки =====================

def extract_text_chunks(pdf_path: str) -> List[Dict[str, Any]]:
    els = partition_pdf(
        filename=pdf_path,
        strategy="fast",
        chunking_strategy="by_title",
        languages=["ru","en"],
        max_characters=3000,
        new_after_n_chars=2200,
        combine_under_n_chars=900,
        overlap=200,
        infer_table_structure=False,
        extract_images_in_pdf=False,
        include_page_breaks=False,
    )
    chunks = []
    for e in els:
        txt = (e.text or "").strip()
        if not txt: 
            continue
        chunks.append({
            "type": "text",
            "subtype": e.category,
            "text": txt,
            "page": getattr(e.metadata, "page_number", None),
            "source_pdf": pdf_path,
        })
    return chunks

def extract_figures(pdf_path: str) -> List[Dict[str, Any]]:
    figs = []
    doc = fitz.open(pdf_path)
    try:
        for i in range(len(doc)):
            imgs = doc[i].get_images(full=True)
            for xref, *_ in imgs:
                figs.append({"type":"figure","page":i+1,
                             "image_xref":xref,"text":"",
                             "source_pdf":pdf_path,
                             "image_b_64": ""})
    finally:
        doc.close()
    return figs

# ===================== 5) пайплайн ============================================

def parse_pdf_fast_strict_tables(pdf_path: str) -> List[Dict[str, Any]]:
    text_pages, image_pages = screen_pages(pdf_path, min_words=10)
    text_chunks = extract_text_chunks(pdf_path)
    candidate_pages = pick_table_pages(pdf_path, text_pages, thr_rulings=0.3, thr_cols=0.4)
    tables = extract_tables_on_candidates(pdf_path, candidate_pages)
    figures = extract_figures(pdf_path)
    prio = {"text":1,"table":2,"figure":3}
    items = text_chunks + tables + figures
    items.sort(key=lambda x: ((x.get("page") or 10**9), prio.get(x["type"],9)))
    return items

### Добавление ID

In [6]:
def _sha1_file(path: str, chunk_size: int = 1024 * 1024) -> str:
    h = hashlib.sha1()
    p = Path(path)
    try:
        with p.open("rb") as f:
            while True:
                b = f.read(chunk_size)
                if not b:
                    break
                h.update(b)
        return h.hexdigest()
    except Exception:
        return hashlib.sha1(str(p).encode("utf-8")).hexdigest()

def _sha1_text(s: str) -> str:
    return hashlib.sha1((s or "").encode("utf-8")).hexdigest()

def attach_ids(
    items: List[Dict[str, Any]],
    doc_id_mode: Literal["file_hash", "path_hash"] = "file_hash",
    chunk_id_mode: Literal["page_local", "sequential", "content_hash"] = "page_local",
) -> List[Dict[str, Any]]:
    """
    Добавляет doc_id, chunk_id, original_index в каждый словарь чанка.
    Ожидается, что у чанка уже есть поля: type, subtype, text, page, source_pdf.
    """
    doc_ids = {}
    for it in items:
        src = it.get("source_pdf") or "unknown"
        if src not in doc_ids:
            if doc_id_mode == "file_hash":
                doc_ids[src] = _sha1_file(src)
            else:
                doc_ids[src] = hashlib.sha1(str(src).encode("utf-8")).hexdigest()

    per_page_counters = defaultdict(lambda: defaultdict(int))  
    per_doc_counter = defaultdict(int)                         
    seen_chunk_ids = set()

    out = []
    for idx, ch in enumerate(items):
        src = ch.get("source_pdf") or "unknown"
        page = ch.get("page") if ch.get("page") is not None else 0
        doc_id = doc_ids[src]

        if chunk_id_mode == "page_local":
            per_page_counters[doc_id][page] += 1
            local = per_page_counters[doc_id][page]
            chunk_id = f"{doc_id}:p{page}:{local}"
        elif chunk_id_mode == "sequential":
            per_doc_counter[doc_id] += 1
            chunk_id = f"{doc_id}:{per_doc_counter[doc_id]}"
        else:  
            text = (ch.get("text") or ch.get("content") or "")
            text = text[:1024].replace("\n", " ").strip()
            page = ch.get("page") if ch.get("page") is not None else 0
            chunk_key = f"{page}|{ch.get('type')}|{text}"
            chunk_id  = f"{doc_id}:{_sha1_text(chunk_key)}"

        if chunk_id in seen_chunk_ids:
            suffix = _sha1_text(f"{chunk_id}|{idx}")[:6]
            chunk_id = f"{chunk_id}:{suffix}"
        seen_chunk_ids.add(chunk_id)

        ch2 = dict(ch)
        ch2["doc_id"] = doc_id
        ch2["chunk_id"] = chunk_id
        ch2["original_index"] = idx
        out.append(ch2)

    return out

### Собираем все метаданные

In [7]:
def create_all_metadata(pdf_folder: str):
    all_items = []
    pdf_paths = sorted([str(p) for p in Path(pdf_folder).rglob("*.pdf")])

    for pdf_path in pdf_paths:
        items = parse_pdf_fast_strict_tables(pdf_path)
        used_pages = []

        for id, item in enumerate(items):
            if item['type'] == 'figure':
                if item['page'] in used_pages:
                    items.pop(id)
                else:
                    used_pages.append(item['page'])
            else:
                continue
        
        items_w_ids = attach_ids(items, doc_id_mode="file_hash", chunk_id_mode="page_local")
        all_items.extend(items_w_ids)

    return all_items

In [None]:
items = create_all_metadata(pdf_folder=pdf_dir)

In [18]:
uniq_docs = []
uniq_chunks = []

for item in items:
    if item.get('doc_id') not in uniq_docs:
        uniq_docs.append(item.get('doc_id'))
    if item.get('chunk_id') not in uniq_chunks:
        uniq_chunks.append(item.get('chunk_id'))

### Обработка изображений

In [68]:
unique_image_pages = defaultdict(set)

In [70]:
for item in items:
    doc_name = item.get('source_pdf', '')
    if str(item.get('type', '')).lower() == 'figure':
        page = item.get('page')
        if page is not None:
            unique_image_pages[doc_name].add(int(page))

In [71]:
unique_image_pages = {k: sorted(v) for k, v in unique_image_pages.items()}

In [73]:
def render_page_png(pdf_path: str, page: int, dpi: int = 220, 
                    max_side: int = 1600, out_path=image_output_path) -> str:
    doc = fitz.open(pdf_path)
    try:
        pg = doc[page-1]
        scale = dpi / 72.0
        pix = pg.get_pixmap(matrix=fitz.Matrix(scale, scale), alpha=False)
        
        output_folder = out_path + pdf_path.split('/')[-1].split('.')[0]
        if not os.path.exists(output_folder):
            os.mkdir(output_folder)
        out_path_new = output_folder + '/' + f"p_{page}.png"
        pix.save(out_path_new)
    finally:
        doc.close()

    im = Image.open(out_path_new)
    w, h = im.size
    m = max(w, h)
    if m > max_side:
        ratio = max_side / m
        im = im.resize((int(w*ratio), int(h*ratio)), Image.LANCZOS)
        im.save(out_path_new, optimize=True, quality=100)
    return out_path_new

In [65]:
def encode_image(image_path):
    """Getting the base64 string"""
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")


def image_summarize(img_base64, prompt):
    """Image summary"""
    chat = ChatOpenAI(model="gpt-4o", max_tokens=2048)

    msg = chat.invoke(
        [
            HumanMessage(
                content=[
                    {"type": "text", "text": prompt},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{img_base64}"},
                    },
                ]
            )
        ]
    )
    return msg.content

In [66]:
prompt = """You are an assistant tasked with summarizing images for retrieval. \
These summaries will be embedded and used to retrieve the raw image. \
Give a concise summary of the image that is well optimized for retrieval in russian language strictly."""

In [84]:
for doc_path, pages in tqdm.tqdm(unique_image_pages.items(), desc='Number of docs'):
    for p in pages:
        test_image_path = render_page_png(doc_path,
                                          page=p)
        encoded_image = encode_image(test_image_path)
        summary = image_summarize(encoded_image, prompt=prompt)
        for item in items:
            if item['source_pdf'] == doc_path and item['type'] == 'figure' and item['page'] == p:
                item['text'] = summary
                item['image_b_64'] = encoded_image

Number of docs:   0%|          | 0/26 [00:00<?, ?it/s]

../instructions/Dataset/KazTelecom/kztkp_2024_rus.pdf


Number of docs:  58%|█████▊    | 15/26 [10:34<07:45, 42.30s/it]

../instructions/Dataset/Maten Petroleum/matnp_2024_rus.pdf


Number of docs:  62%|██████▏   | 16/26 [11:24<07:09, 42.94s/it]

../instructions/Dataset/Oasis Logistics/oasi_af_4_2025.pdf


Number of docs:  65%|██████▌   | 17/26 [11:52<06:11, 41.33s/it]

../instructions/Dataset/Oasis Logistics/oasif6_2024_rus.pdf


Number of docs:  69%|██████▉   | 18/26 [16:21<09:38, 72.33s/it]

../instructions/Dataset/Qazaqstan Temir Joly/tmjl_2024_rus.pdf


Number of docs:  73%|███████▎  | 19/26 [28:43<21:10, 181.52s/it]

../instructions/Dataset/Qazaqstan Temir Joly/tmjl_af_1_2025.pdf


Number of docs:  77%|███████▋  | 20/26 [28:49<14:49, 148.32s/it]

../instructions/Dataset/Qazaqstan Temir Joly/tmjlf6_2024_cons_rus.pdf


Number of docs:  81%|████████  | 21/26 [29:33<10:31, 126.26s/it]

../instructions/Dataset/Rakhat/raht_af_4_2025.pdf


Number of docs:  85%|████████▍ | 22/26 [30:48<07:37, 114.27s/it]

../instructions/Dataset/Rakhat/rahtp_2024_rus.pdf


Number of docs:  88%|████████▊ | 23/26 [35:17<07:38, 152.95s/it]

../instructions/Dataset/Teniz Capital/tcib_af_4_2025.pdf


Number of docs:  92%|█████████▏| 24/26 [35:52<04:03, 121.90s/it]

../instructions/Dataset/Teniz Capital/tcibp_2024_rus.pdf


Number of docs:  96%|█████████▌| 25/26 [42:43<03:20, 200.94s/it]

../instructions/Dataset/Transtelecom/tcom_af_4_2025.pdf


Number of docs: 100%|██████████| 26/26 [43:07<00:00, 99.52s/it] 


### Добавление контекста

In [85]:
SUMMARY_PROMPT = """
Here is the full document
<document>
{doc_content}
</document>

Please give it a strict and precise summary in a range of 2500 to 3000 words in russian language for the purposes 
of using it as a knowledge base for retrieval of the chunks. Answer only with the strict and precise summary and nothing else and keep the content relevant.
"""

In [86]:
DOCUMENT_CONTEXT_PROMPT = """
<document>
{doc_content}
</document>
"""

CHUNK_CONTEXT_PROMPT = """
Here is the chunk we want to situate within the whole document
<chunk>
{chunk_content}
</chunk>

Please give a short succinct context in russian language to situate this chunk within the overall document for the purposes 
of improving search retrieval of the chunk. Answer only with the succinct context and nothing else.
"""

In [92]:
def create_summary(pdf_path: str, summary_prompt: str):
    with pdfplumber.PDF(open(file=pdf_path, mode='rb')) as pdf:
        pages = [page.extract_text() for page in pdf.pages]
    doc_text = ''.join(pages)
    doc_text = doc_text.replace('\n','')

    if len(doc_text) >= 200000:
        doc_text = doc_text[:199000]

    with client.messages.stream(
        model="claude-3-5-haiku-20241022",
        max_tokens=3500,
        temperature=0.0,
        messages=[
            {
                "role": "user", 
                "content": [
                    {
                        "type": "text",
                        "text": summary_prompt.format(doc_content=doc_text)
                    }
                ]
            }
        ]
    ) as stream:
        chunks = []
        for delta in stream.text_stream:
            chunks.append(delta)
        final_msg = stream.get_final_message()
    return "".join(chunks)

In [None]:
docs_summary = {}

pdf_paths = sorted([str(p) for p in Path(pdf_dir).rglob("*.pdf")])

In [None]:
for pdf in tqdm.tqdm(pdf_paths, desc="Creating summary for each document"):
    if pdf in docs_summary.keys():
        continue
    docs_summary[pdf] = create_summary(pdf_path=pdf, summary_prompt=SUMMARY_PROMPT)
    time.sleep(40)

In [98]:
def situate_context(doc: str, chunk: str) -> str:
    response = client.messages.create(
        model="claude-3-5-haiku-20241022",
        max_tokens=800,
        temperature=0.0,
        messages=[
            {
                "role": "user", 
                "content": [
                    {
                        "type": "text",
                        "text": DOCUMENT_CONTEXT_PROMPT.format(doc_content=doc),
                        "cache_control": {"type": "ephemeral", "ttl": "1h"} 
                    },
                    {
                        "type": "text",
                        "text": CHUNK_CONTEXT_PROMPT.format(chunk_content=chunk),
                    }
                ]
            }
        ],
        extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"}
    )
    return response

In [132]:
for item in tqdm.tqdm(items, desc="Making context for each chunk"):
    if item.get('contextualized_content'):
        continue
    doc_path = item.get('source_pdf')
    doc_content = docs_summary[doc_path]
    chunk_content = item['text']
    if not chunk_content:
        continue
    else:
        response = situate_context(doc_content, chunk_content)
        item['contextualized_content'] = response

Making context for each chunk: 100%|██████████| 4822/4822 [34:08<00:00,  2.35it/s]  


In [107]:
print(items[30]['contextualized_content'].content[0].text)

Дополнительная информация об аффилированных лицах компании ТОО "QazaqGaz Onimderi", входящей в группу компаний АО "Самрук-Қазына", с подробным описанием семейных связей руководства и членов правления компании, включая их родственников и статус резидентства.


### Semantic retriever

In [136]:
def create_multi_vector_retriever(vectorstore, chunks):
    store = InMemoryStore()

    retriever = MultiVectorRetriever(
        vectorstore=vectorstore,
        docstore=store,
        id_key="chunk_id"
    )

    retriever.search_type = SearchType.mmr

    child_docs = []
    kv_to_store = [] 

    for ch in chunks:
        original_text = (ch.get("text") or "").strip()
        context_text = (ch.get("contextualized_content").content[0].text if ch.get("contextualized_content") else "").strip()
        final_text = original_text + "\n\n" + "Context: " + context_text
        if not final_text:
            continue  

        parent = Document(
            page_content=final_text,
            metadata={
                "page_num": ch.get("page"),
                "source": ch.get("source_pdf", ""),
                "data_type": ch.get("type", "text"),
                "raw_image": ch.get("image_b_64", ""),
                "chunk_id": ch.get("chunk_id"),
                "original_index": ch.get("original_index")
            },
        )
        kv_to_store.append((ch.get("chunk_id"), parent))

        child = Document(
            page_content=final_text,
            metadata={
                "chunk_id": ch.get("chunk_id"),                
                "page_num": ch.get("page"),
                "source": ch.get("source_pdf", ""),
                "data_type": ch.get("type", "text"),
            },
        )
        child_docs.append(child)

    retriever.docstore.mset(kv_to_store)
    retriever.vectorstore.add_documents(child_docs)

    return retriever


embeddings = OpenAIEmbeddings(
    model="text-embedding-3-large",
    chunk_size=32                   
)

multi_vector_img = Chroma(
    collection_name="multi_vector_img_all_ver2", 
    persist_directory='../persist_dir_2',
    embedding_function=embeddings
)

retriever_multi_vector_img = create_multi_vector_retriever(
    multi_vector_img,
    items
)

### BM25

In [137]:
class ElasticsearchBM25:
    def __init__(self, index_name: str = "contextual_bm25_index_all_ver2"):
        self.es_client = Elasticsearch("http://localhost:9200")
        self.index_name = index_name
        self.create_index()

    def create_index(self):
        index_settings = {
            "settings": {
                "similarity": {"default": {"type": "BM25"}}
            },
            "mappings": {
                "properties": {
                    "content": {"type": "text", "analyzer": "russian"},
                    "contextualized_content": {"type": "text", "analyzer": "russian"},
                    "doc_id": {"type": "keyword"},
                    "chunk_id": {"type": "keyword"},
                    "page": {"type": "integer"},
                    "source_pdf": {"type": "keyword", "index": False},
                    "data_type": {"type": "keyword", "index": False},
                    "raw_image": {"type": "binary"},
                    "original_index": {"type": "integer"} 
                }
            },
        }
        if not self.es_client.indices.exists(index=self.index_name):
            self.es_client.indices.create(index=self.index_name, body=index_settings)
            print(f"Created index: {self.index_name}")
    
    def index_documents(self, documents: List[Dict[str, Any]]):
        actions = [
            {
                "_index": self.index_name,
                "_source": {
                    "content": doc["text"],
                    "contextualized_content": doc.get("contextualized_content").content[0].text if doc.get("contextualized_content") else "",
                    "doc_id": doc["doc_id"],
                    "chunk_id": doc["chunk_id"],
                    "page": doc["page"],
                    "source_pdf": doc["source_pdf"],
                    "data_type": doc["type"],
                    "raw_image": doc.get("raw_image"),
                    "original_index": doc["original_index"]
                },
            }
            for doc in documents
        ]
        success, _ = bulk(self.es_client, actions)
        self.es_client.indices.refresh(index=self.index_name)
        return success
    
    def search(self, query: str, k: int = 20) -> List[Dict[str, Any]]:
        self.es_client.indices.refresh(index=self.index_name) 
        search_body = {
            "query": {
                "multi_match": {
                    "query": query,
                    "fields": ["content", "contextualized_content"],
                }
            },
            "size": k,
        }
        response = self.es_client.search(index=self.index_name, body=search_body)
        return [
            {
                "content": hit["_source"]["content"],
                "contextualized_content": hit["_source"]["contextualized_content"],
                "doc_id": hit["_source"]["doc_id"],
                "chunk_id": hit["_source"]["chunk_id"],
                "score": hit["_score"],
                "page": hit["_source"]["page"],
                "source_pdf": hit["_source"]["source_pdf"],
                "data_type": hit["_source"]["data_type"],
                "raw_image": hit["_source"]["raw_image"],
                "original_index": hit["_source"]["original_index"]
            }
            for hit in response["hits"]["hits"]
        ]

def create_elasticsearch_bm25_index(data):
    es_bm25 = ElasticsearchBM25()
    es_bm25.index_documents(data)
    return es_bm25

In [138]:
es_bm25 = create_elasticsearch_bm25_index(items)

  self.es_client.indices.create(index=self.index_name, body=index_settings)


Created index: contextual_bm25_index_all_ver2


In [149]:
def retrieve_advanced(query: str, db: MultiVectorRetriever, es_bm25: ElasticsearchBM25, 
                      k: int, 
                      docs: List[Dict[str, Any]],
                      semantic_weight: float = 0.7, bm25_weight: float = 0.3,
                      num_chunks_to_recall: int = 50):

    semantic_results = db.invoke(query, k=num_chunks_to_recall)
    ranked_chunk_ids = [(result.metadata['chunk_id'], result.metadata['original_index']) for result in semantic_results]

    bm25_results = es_bm25.search(query, k=num_chunks_to_recall)
    ranked_bm25_chunk_ids = [(result['chunk_id'], result['original_index']) for result in bm25_results]

    chunk_ids = list(set(ranked_chunk_ids + ranked_bm25_chunk_ids))
    chunk_id_to_score = {}

    for chunk_id in chunk_ids:
        score = 0
        if chunk_id in ranked_chunk_ids:
            index = ranked_chunk_ids.index(chunk_id)
            score += semantic_weight * (1 / (index + 1))  
        if chunk_id in ranked_bm25_chunk_ids:
            index = ranked_bm25_chunk_ids.index(chunk_id)
            score += bm25_weight * (1 / (index + 1)) 
        chunk_id_to_score[chunk_id] = score

    sorted_chunk_ids = sorted(
        chunk_id_to_score.keys(), key=lambda x: (chunk_id_to_score[x], x[0], x[1]), reverse=True
    )

    # -------------------- Cohere Rerank (минимальное добавление) --------------------
    api_key = os.getenv("COHERE_API_KEY")
    if api_key and sorted_chunk_ids:
        co = cohere.Client(api_key)

        doc_map = { (d["chunk_id"], d["original_index"]) : d for d in docs }

        top_n_rr = min(len(sorted_chunk_ids), max(5 * k, 50))
        candidates = sorted_chunk_ids[:top_n_rr]

        docs_for_rr = []
        kept_idx = []
        for idx, cid in enumerate(candidates):
            d = doc_map.get(cid)
            if not d:
                continue
            base = d.get("text")
            ctx  = d.get("contextualized_content")
            if isinstance(ctx, str):
                ctx_text = ctx
            else:
                try:
                    ctx_text = ctx.content[0].text 
                except Exception:
                    ctx_text = ""
            txt = (base + ("\n\n" + ctx_text if ctx_text else "")).strip()
            if not txt:
                continue
            docs_for_rr.append({"text": txt})
            kept_idx.append(idx)

        if docs_for_rr:
            rr = co.rerank(
                model="rerank-v3.5",
                query=query,
                documents=docs_for_rr,
                top_n=len(docs_for_rr),
            )
            order = sorted(rr.results, key=lambda r: r.relevance_score, reverse=True)
            reranked_ids = [candidates[kept_idx[r.index]] for r in order]
            remainder = [cid for cid in candidates if cid not in set(reranked_ids)]
            tail = [cid for cid in sorted_chunk_ids if cid not in set(candidates)]
            sorted_chunk_ids = reranked_ids + remainder + tail

    for index, chunk_id in enumerate(sorted_chunk_ids):
        chunk_id_to_score[chunk_id] = 1 / (index + 1)

    final_results = []
    semantic_count = 0
    bm25_count = 0
    for chunk_id in sorted_chunk_ids[:k]:
        chunk_metadata = next(chunk for chunk in docs if chunk['chunk_id'] == chunk_id[0] and chunk['original_index'] == chunk_id[1])
        is_from_semantic = chunk_id in ranked_chunk_ids
        is_from_bm25 = chunk_id in ranked_bm25_chunk_ids
        final_results.append({
            'chunk': chunk_metadata,
            'score': chunk_id_to_score[chunk_id],
            'from_semantic': is_from_semantic,
            'from_bm25': is_from_bm25
        })
        
        if is_from_semantic and not is_from_bm25:
            semantic_count += 1
        elif is_from_bm25 and not is_from_semantic:
            bm25_count += 1
        else: 
            semantic_count += 0.5
            bm25_count += 0.5

    return final_results, semantic_count, bm25_count

In [150]:
def _build_page_content(ch: dict) -> str:
    text = (ch.get("text") or ch.get("content") or "")  
    ctx  = (ch.get("contextualized_content").content[0].text if ch.get("contextualized_content") else "") 
    return (text + ("  " + ctx if ctx else "")).strip()

def make_hybrid_runnable(
    db, es_bm25, docs,
    top_k_out: int = 20,
    recall: int = 50,
    semantic_weight: float = 0.7,
    bm25_weight: float = 0.3
):
    """Возвращает Runnable, который принимает строку-вопрос и отдаёт List[Document]."""
    def _search(query: str):
        results, _, _ = retrieve_advanced(
            query=query,
            db=db,
            es_bm25=es_bm25,
            k=top_k_out,                   
            docs=docs,                         
            semantic_weight=semantic_weight,
            bm25_weight=bm25_weight,
            num_chunks_to_recall=recall
        )
        out = []
        for r in results:
            ch = r["chunk"]
            out.append(
                Document(
                    page_content=_build_page_content(ch),
                    metadata={
                        "chunk_id": ch["chunk_id"],
                        "doc_id": ch["doc_id"],
                        "original_index": ch["original_index"],
                        "page_num": ch.get("page"),
                        "source": ch.get("source_pdf"),
                        "data_type": ch.get("type"),
                        "raw_image": ch.get("raw_image"), 
                        "score": r["score"],
                        "from_semantic": r["from_semantic"],
                        "from_bm25": r["from_bm25"]
                    },
                )
            )
        return out
    return RunnableLambda(_search)

In [151]:
hybrid_retriever = make_hybrid_runnable(
    db=retriever_multi_vector_img,
    es_bm25=es_bm25,     
    docs=items
)

### Multi-modal RAG

In [192]:
def build_messages_with_indices(payload: Dict[str, Any]) -> Dict[str, Any]:
    """
    payload: {"docs": List[Document], "question": str, "question_id": int, "answer_type": str}
    Возвращает {"messages": [HumanMessage], "docs": docs, "question_id": ...}
    """
    docs: List[Document] = payload["docs"]
    question: str = payload.get("question") or payload.get("full_question") or ""
    answer_type: str = payload["answer_type"]

    instruction = (
        "Отвечай на русском языке ТОЛЬКО на основе фрагментов ниже.\n"
        "Каждый фрагмент помечен индексом в квадратных скобках, например [0].\n"
        "Если в вопросе просят указать должность, то напиши его должность целиком и обязательно проверь, что имена в вопросе и найденных фрагментах совпадают.\n"
        "Часто ответ на такой вопрос имеется в таблицах, проверяй, что ты берешь информацию из нужного ряда.\n"
        "Верни JSON ровно такого вида и ничего больше:\n"
        '{ "answer": "<строка ответа>", "evidence": [<индексы фрагментов, по убыванию важности>] }\n'
        "В поле answer верни только ответ на вопрос, без дополнительных пояснений и знаков препинания в конце.\n"
        f"Ожидаемый тип данных ответа на вопрос - {answer_type}.\n"
        "Если ожидаемый тип данных int либо float либо в вопросе просят указать количество чего-то, либо назвать какое-то число (баллов, очков и т.д.), то в ответе напиши ТОЛЬКО цифру, без пояснений, меры измерения и валюты.\n"
        "Не отделяй разряды пробелом, например 123 000 должно быть 123000.\n"
        "При этом, если ожидаемый тип данных float, то в ответе обязательно должна быть дробная часть, отдели его точкой, например 6.2 или 14.75.\n"
        "Если использован один фрагмент — верни один индекс.\n"
        "Если для ответа на вопрос использовано несколько фрагментов, например необходимо сравнить информацию на разных страницах либо в разных документах - верни несколько индексов.\n"
        "Не добавляй лишний текст вне JSON."
    )

    content = [
        {"type": "text", "text": instruction},
        {"type": "text", "text": f"Вопрос: {question}"},
    ]

    for i, d in enumerate(docs):
        src = d.metadata.get("source") or d.metadata.get("source_pdf") or ""
        name = Path(src).name if src else ""
        page = d.metadata.get("page_num") or d.metadata.get("page")

        if d.metadata.get("data_type") == "figure" and d.metadata.get("raw_image"):
            content.append({"type": "text", "text": f"[{i}] IMAGE — {name}, стр. {page}."})
            img = d.metadata["raw_image"]
            url = img if str(img).startswith("data:") else f"data:image/png;base64,{img}"
            content.append({"type": "image_url", "image_url": {"url": url}})
        else:
            txt = d.page_content or ""
            content.append({
                "type": "text",
                "text": f"[{i}] TEXT — {name}, стр. {page}\n{txt}"
            })

    return {
        "messages": [HumanMessage(content=content)],
        "docs": docs,
        "question_id": payload["question_id"],
        "answer_type": payload["answer_type"]
    }

In [193]:
def parse_model_json(s: str) -> Dict[str, Any]:
    try:
        start = s.find("{")
        end = s.rfind("}")
        if start == -1 or end == -1:
            return {"answer": s.strip(), "evidence": []}
        data = json.loads(s[start:end+1])
        # подстрахуем поля
        if not isinstance(data, dict):
            return {"answer": s.strip(), "evidence": []}
        ans = str(data.get("answer", "")).strip()
        ev = data.get("evidence", [])
        if isinstance(ev, int):
            ev = [ev]
        if not isinstance(ev, list):
            ev = []
        ev = [int(x) for x in ev if isinstance(x, (int, float))] 
        return {"answer": ans or s.strip(), "evidence": ev}
    except Exception:
        return {"answer": s.strip(), "evidence": []}

In [194]:
def build_final_payload(inputs: Dict[str, Any]) -> Dict[str, Any]:
    """
    inputs: {
      "answer_text": "<raw model output>",
      "docs": List[Document],
      "question_id": int,
      "answer_type": str
    }
    """
    parsed = parse_model_json(inputs["answer_text"])
    answer_text: str = parsed["answer"]
    if inputs["answer_type"] == 'float':
        try:
          answer_text = float(answer_text)
        except ValueError:
           pass
    elif inputs["answer_type"] == 'int':
        try:
          answer_text = int(answer_text)   
        except ValueError:
           try:
              answer_text = float(answer_text)
           except ValueError:
              pass
    ev: List[int] = parsed["evidence"] or []

    docs: List[Document] = inputs["docs"]
    qid: int = int(inputs["question_id"])

    idx = ev[0] if ev and 0 <= ev[0] < len(docs) else 0
    d = docs[idx]
    src = d.metadata.get("source") or d.metadata.get("source_pdf") or ""
    name = Path(src).name if src else ""
    page = d.metadata.get("page_num") or d.metadata.get("page")

    return {
        "question_id": qid,
        "relevant_chunks": [{"document_name": name, "page_number": int(page) if page is not None else None}],
        "answer": answer_text
    }

In [195]:
def build_structured_chain_that_cites_one(hybrid_retriever_runnable, llm=None):
    model = llm or ChatOpenAI(temperature=0, model="gpt-5", max_tokens=2048)

    chain = (
    {
        "docs": (itemgetter("full_question") | hybrid_retriever),
        "question": itemgetter("full_question"),
        "question_id": itemgetter("id"),
        "answer_type": itemgetter("answer_type")
    }
    | RunnableLambda(build_messages_with_indices)
    | {
        "answer_text": (itemgetter("messages") | model | StrOutputParser()),
        "docs": itemgetter("docs"),
        "question_id": itemgetter("question_id"),
        "answer_type": itemgetter("answer_type")
    }
    | RunnableLambda(build_final_payload)
)
    return chain

In [196]:
chain_structured = build_structured_chain_that_cites_one(hybrid_retriever)

### Тестирование

In [205]:
import warnings
warnings.filterwarnings('ignore')

In [208]:
questions = pd.read_excel('../instructions/questions_private.xlsx')

In [210]:
questions.tail()

Unnamed: 0,id,block,full_question,answer_type
195,196,OCR,Сколько пассажирских вагонов у АО «НК «ҚТЖ» (в...,float
196,197,OCR,Какой объем добычи нефти (тыс. тонн) с месторо...,float
197,198,AR,Сколько шоколадных цехов есть у АО “Баян Сулу”?,int
198,199,Both,Какую долю (в процентах) в грузообороте АО «НК...,int
199,200,Both,Какова суммарная доля участия членов Совета Ди...,float


In [211]:
questions.shape

(200, 4)

In [None]:
ans = []

for id, row in tqdm.tqdm(questions.iterrows(), desc="Answering questions..."):
    question_payload = {}
    question_payload['id'] = row['id']
    question_payload['full_question'] = row['full_question']
    question_payload['answer_type'] = row['answer_type']
    result = chain_structured.invoke(question_payload)
    print(f"Answer: {result['answer']} --- Question_ID: {result['question_id']} --- DOC: {result['relevant_chunks'][0]['document_name']} --- Page: {result['relevant_chunks'][0]['page_number']}")
    ans.append(result)
    time.sleep(1)

In [221]:
with open("our_answers.json", "w", encoding="utf-8") as f:
    json.dump(ans, f, ensure_ascii=False, indent=2)