In [None]:
from __future__ import annotations

from dataclasses import dataclass
from typing import (
    Dict,
    List,
    Protocol,
    Sequence,
    Tuple,
    TypedDict,
)

import numpy as np
import numpy.typing as npt

from daft_func import NestedPipeline, Pipeline, func

# ---- Core vector type -------------------------------------------------------
# Type alias for every embedding passed around the pipeline: a NumPy array with
# float32 elements. Shape is not part of the alias.
Vector = npt.NDArray[np.float32]


# ---- Protocols -------------------------------------------------------------
class Encoder(Protocol):
    """Structural type for objects that embed a piece of text into a ``Vector``."""

    # Advertised embedding width.
    dim: int

    def encode(self, text: str) -> Vector:
        """Return the embedding for ``text``."""
        ...


class Indexer(Protocol):
    """Structural type for objects that build a ``BaseIndex`` from encoded passages."""

    def index(self, encoded: Sequence["EncodedPassage"]) -> "BaseIndex":
        """Return an index containing ``encoded``."""
        ...


class Reranker(Protocol):
    """Structural type for objects that reorder (and optionally trim) retrieved hits."""

    def rerank(
        self,
        query: "Query",
        hits: Sequence["RetrievedDoc"],
        top_k: int | None = None,
    ) -> List["RetrievedDoc"]:
        """Return ``hits`` reordered for ``query``; ``top_k`` may cap the result length."""
        ...


# ---- Data models ------------------------------------------------------------
@dataclass(frozen=True)
class Passage:
    """Immutable raw corpus entry: an identifier plus its text."""

    pid: str  # passage identifier, used as the index key
    text: str  # original, uncleaned text


@dataclass(frozen=True)
class EncodedPassage:
    """Immutable passage paired with the embedding computed for it."""

    pid: str  # identifier copied from the source Passage
    text: str  # original passage text
    embedding: Vector  # encoder output for the (cleaned) text


@dataclass(frozen=True)
class Query:
    """Immutable search request; only the raw query text is carried."""

    text: str  # user-supplied query string


@dataclass(frozen=True)
class RetrievedDoc:
    """Immutable search result: a stored passage plus the score it received."""

    pid: str  # identifier of the matched passage
    text: str  # stored passage text
    embedding: Vector  # stored passage embedding
    score: float  # similarity score reported by the index


class SearchHit(TypedDict):
    """Plain-dict record returned by ``BaseIndex.search`` for each match."""

    pid: str  # identifier of the matched passage
    score: float  # similarity score


class BaseIndex(Protocol):
    """Structural type for a searchable store of ``EncodedPassage`` objects."""

    # Embedding width the index was created for.
    dim: int

    def add(self, items: Sequence[EncodedPassage]) -> None:
        """Insert ``items`` into the index."""
        ...

    def search(self, query_vec: Vector, top_k: int = 10) -> List[SearchHit]:
        """Return at most ``top_k`` hits for ``query_vec``."""
        ...

    def get(self, pid: str) -> EncodedPassage:
        """Return the stored passage with identifier ``pid``."""
        ...

In [None]:
# ---- Implementations -------------------------------------------------------
class NumpyRandomEncoder:
    """Toy ``Encoder`` that ignores its input and returns random float32 vectors.

    Only useful for exercising pipeline plumbing: the vectors carry no
    information about the text.

    Args:
        dim: Length of each vector returned by :meth:`encode`.
        rng: Optional NumPy ``Generator``. When omitted, a fresh unseeded one is
            created, so outputs differ from run to run.
    """

    def __init__(self, dim: int = 4, rng: np.random.Generator | None = None) -> None:
        self.dim = dim
        self._rng = rng if rng else np.random.default_rng()

    def encode(self, text: str) -> Vector:
        """Return a new random vector of shape ``(dim,)`` in [0, 1); ``text`` is unused."""
        generator = self._rng
        return generator.random(self.dim, dtype=np.float32)


class InMemoryIndex:
    """Brute-force cosine-similarity index held in a plain ``dict``.

    Passages are keyed by ``pid``; adding a passage whose ``pid`` already exists
    replaces the stored entry. Every search scores all stored passages.

    Args:
        dim: Embedding width, stored for callers (not validated here).
    """

    def __init__(self, dim: int) -> None:
        self.dim = dim
        self._data: Dict[str, EncodedPassage] = {}

    def add(self, items: Sequence[EncodedPassage]) -> None:
        """Insert each passage in ``items``, overwriting any entry with the same ``pid``."""
        for it in items:
            self._data[it.pid] = it

    def search(self, query_vec: Vector, top_k: int = 10) -> List[SearchHit]:
        """Return up to ``top_k`` hits ordered by descending cosine similarity.

        Ties keep the dict's iteration order (the sort is stable). A
        non-positive ``top_k`` yields an empty list; slicing with a negative
        value would otherwise drop hits from the end instead.
        """
        if top_k <= 0:
            return []
        # The 1e-12 guards against division by zero for all-zero vectors.
        q = query_vec / (np.linalg.norm(query_vec) + 1e-12)
        hits: List[Tuple[str, float]] = []
        for pid, ep in self._data.items():
            v = ep.embedding / (np.linalg.norm(ep.embedding) + 1e-12)
            hits.append((pid, float(np.dot(q, v))))
        hits.sort(key=lambda x: x[1], reverse=True)
        return [{"pid": pid, "score": score} for pid, score in hits[:top_k]]

    def get(self, pid: str) -> EncodedPassage:
        """Return the stored passage for ``pid``; raises ``KeyError`` if absent."""
        return self._data[pid]


class SimpleIndexer:
    """``Indexer`` that loads every encoded passage into a new ``InMemoryIndex``.

    Args:
        dim: Embedding width forwarded to each index it builds.
    """

    def __init__(self, dim: int) -> None:
        self.dim = dim

    def index(self, encoded: Sequence[EncodedPassage]) -> BaseIndex:
        """Build a fresh index of width ``self.dim`` containing ``encoded``."""
        fresh = InMemoryIndex(self.dim)
        fresh.add(encoded)
        return fresh


class IdentityReranker:
    """``Reranker`` that keeps the incoming order and only applies the ``top_k`` cap."""

    def rerank(
        self, query: Query, hits: Sequence[RetrievedDoc], top_k: int | None = None
    ) -> List[RetrievedDoc]:
        """Return a new list of ``hits`` (``query`` is ignored), truncated to ``top_k`` if given."""
        ranked = list(hits)
        return ranked if top_k is None else ranked[:top_k]

In [None]:
# ---- Encode Pipeline ---------------------------------------------
@func(output="cleaned_text")
def clean_text(passage: Passage) -> str:
    """Normalise ``passage.text`` for encoding: trim surrounding whitespace, then lower-case."""
    stripped = passage.text.strip()
    return stripped.lower()


@func(output="embedding")
def encode_text(encoder: Encoder, cleaned_text: str, is_query: bool = False) -> Vector:
    """Embed already-cleaned text with ``encoder``.

    Args:
        encoder: Any object satisfying the ``Encoder`` protocol.
        cleaned_text: Text produced by ``clean_text``.
        is_query: Kept in the signature so callers can flag query-side
            encoding, but it is NOT forwarded: the ``Encoder`` protocol's
            ``encode`` accepts only ``text``. Extend the protocol first if
            encoders need to distinguish queries from passages.

    Returns:
        The vector returned by ``encoder.encode``.
    """
    # Forwarding ``is_query`` would raise TypeError for protocol-conforming
    # encoders such as NumpyRandomEncoder, whose encode() takes only ``text``.
    return encoder.encode(cleaned_text)


@func(output="encoded_passage")
def pack_encoded(passage: Passage, embedding: Vector) -> EncodedPassage:
    """Bundle ``passage`` with its ``embedding``.

    The stored ``text`` is the original ``passage.text``, not the cleaned
    (stripped/lower-cased) text that was encoded.
    """
    pid, raw_text = passage.pid, passage.text
    return EncodedPassage(pid=pid, text=raw_text, embedding=embedding)


single_encode = Pipeline(functions=[clean_text, encode_text, pack_encoded])
single_encode.visualize()

In [None]:
# `Encoder` is a typing.Protocol and cannot be instantiated (TypeError), so run
# the pipeline with the concrete toy encoder instead.
res = single_encode.run(
    inputs={"passage": Passage(pid="1", text="hello"), "encoder": NumpyRandomEncoder()}
)  # should return a dict with keys as output names and corresponding results
res

In [None]:
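# Expected `res` from the run above (keys are the node output names):
# res == {
#     "cleaned_text": "hello",
#     "embedding": np.ndarray([...]),
#     "encoded_passage": EncodedPassage(pid="1", text="hello", embedding=...),
# }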

In [None]:
# Map over the passages with a concrete encoder; instantiating the `Encoder`
# Protocol itself raises TypeError.
single_encode.map(
    inputs={
        "passage": [Passage(pid="1", text="hello"), Passage(pid="2", text="world")],
        "encoder": NumpyRandomEncoder(),
    },
    map_over="passage",
)  # this should return a dict with keys as output names and corresponding lists of results

In [None]:
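# Expected result of the map call above (one list entry per passage):
# {
#     "cleaned_text": ["hello", "world"],
#     "embedding": [
#         np.ndarray([...]),  # embedding for "hello"
#         np.ndarray([...]),  # embedding for "world"
#     ],
#     "encoded_passage": [EncodedPassage(pid="1", ...), EncodedPassage(pid="2", ...)],
# }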

In [None]:
## Index
# Run `single_encode` once per corpus element: the outer "corpus" list feeds the
# inner "passage" input, and the "encoded_passage" results are exposed as
# "encoded_corpus" (consumed by `build_index` as a sequence).
encode_corpus = NestedPipeline(
    pipeline=single_encode,
    map_over="corpus",
    inputs={"corpus": "passage"},
    outputs={"encoded_passage": "encoded_corpus"},
)


@func(output="index")
def build_index(
    indexer: Indexer, encoded_corpus: Sequence[EncodedPassage]
) -> BaseIndex:
    """Hand the encoded passages to ``indexer`` and return the index it builds."""
    built = indexer.index(encoded_corpus)
    return built


# Take the mapped EncodedPassage list and build an index.
# NOTE(review): `single_encode` is built with `functions=` while this and later
# pipelines use `nodes=`; confirm which keyword daft_func's Pipeline expects.
encode_and_index = Pipeline(
    nodes=[encode_corpus, build_index],
)

In [None]:
# Toy data
corpus: List[Passage] = [
    Passage(pid="p1", text="Hello World"),
    Passage(pid="p2", text="The Quick Brown Fox"),
]

# Seed the random toy encoder so Restart & Run All reproduces the same embeddings.
encoder = NumpyRandomEncoder(dim=4, rng=np.random.default_rng(42))
indexer = SimpleIndexer(dim=encoder.dim)
encode_and_index.visualize()

outputs = encode_and_index.run(
    inputs={
        # `encode_corpus` maps the outer "corpus" onto the inner "passage" input.
        "corpus": corpus,
        "encoder": encoder,
        "indexer": indexer,
    }
)
index: BaseIndex = outputs["index"]

In [None]:
# ---- Retrieval + Reranking --------------------------------------------------
# Query-side encoding reuses the cleaning and embedding steps but stops at the
# raw embedding: `retrieve` needs a Vector (not an EncodedPassage), and
# `pack_encoded` would read `.pid`, which `Query` does not have. `clean_text`
# only reads `.text`, so the Query is fed in as the inner `passage` input.
query_encode = Pipeline(functions=[clean_text, encode_text])
encode_query = NestedPipeline(
    pipeline=query_encode,
    inputs={"query": "passage"},
    outputs={"embedding": "query_vec"},
)


@func(output="retrieved")
def retrieve(
    index: BaseIndex, query_vec: Vector, top_k: int = 10
) -> List[RetrievedDoc]:
    """Search ``index`` and expand each hit into a ``RetrievedDoc``.

    Args:
        index: Index to query.
        query_vec: Query embedding, passed straight to ``index.search``.
        top_k: Maximum number of hits requested from the index.

    Returns:
        One ``RetrievedDoc`` per hit, in the order ``index.search`` returned
        them, carrying the stored passage's text and embedding plus the score.
    """
    docs: List[RetrievedDoc] = []
    for hit in index.search(query_vec, top_k=top_k):
        # Fetch the stored passage once rather than once per copied field.
        stored = index.get(hit["pid"])
        docs.append(
            RetrievedDoc(
                pid=hit["pid"],
                text=stored.text,
                embedding=stored.embedding,
                score=hit["score"],
            )
        )
    return docs


@func(output="reranked_hits")
def rerank_hits(
    reranker: Reranker,
    query: Query,
    retrieved: List[RetrievedDoc],
    final_top_k: int | None = None,
) -> List[RetrievedDoc]:
    """Delegate to ``reranker.rerank``, forwarding ``final_top_k`` as its ``top_k``."""
    reranked = reranker.rerank(query, retrieved, top_k=final_top_k)
    return reranked


# Query -> query_vec -> retrieved -> reranked_hits.
search_pipeline = Pipeline(
    nodes=[encode_query, retrieve, rerank_hits],
)

In [None]:
# End-to-end graph: `encode_and_index` produces "index", which the search nodes consume.
full_pipeline = Pipeline(
    nodes=[encode_and_index, search_pipeline],
)
full_pipeline.visualize()

In [None]:
corpus = [
    Passage(pid="p1", text="Hello World"),
    Passage(pid="p2", text="Quick Brown Fox"),
]
# Seeded so the toy embeddings are reproducible across kernel restarts.
encoder = NumpyRandomEncoder(dim=4, rng=np.random.default_rng(42))
indexer = SimpleIndexer(dim=encoder.dim)

# `encode_corpus` reads the corpus from the "corpus" key (it maps
# "corpus" -> "passage" and maps over "corpus"), so it must be passed under that name.
outputs = encode_and_index.run(
    inputs={"corpus": corpus, "encoder": encoder, "indexer": indexer}
)
index: BaseIndex = outputs["index"]

reranker = IdentityReranker()

In [None]:
# Single query
single_query_inputs = {
    "query": Query(text="hello world"),
    "encoder": encoder,
    "index": index,
    "reranker": reranker,
    "top_k": 5,  # candidates requested from the index
    "final_top_k": 3,  # cap applied by the reranker
}
search_out = search_pipeline.run(inputs=single_query_inputs)

for doc in search_out["reranked_hits"]:
    print(doc.pid, doc.score, doc.text)

In [None]:
# ---- Multiple queries ---------------------------------------------------
queries = [Query(text="hello"), Query(text="quick fox"), Query(text="world")]
batch_inputs = {
    "query": queries,  # the list mapped over: one pipeline run per query
    "encoder": encoder,
    "index": index,
    "reranker": reranker,
    "top_k": 5,
    "final_top_k": 3,
}
batch_out = search_pipeline.map(inputs=batch_inputs, map_over="query")

for q, results in zip(queries, batch_out["reranked_hits"]):
    print(f"\nQuery: {q.text}")
    for r in results:
        print(f"  {r.pid} | {r.score:.3f} | {r.text}")