In [1]:
from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, List, Optional, Protocol, Sequence, Tuple, TypedDict

import numpy as np
import numpy.typing as npt

from hypernodes import Pipeline, node

# ---- Core vector type -------------------------------------------------------
Vector = npt.NDArray[np.float32]


# ---- Protocols -------------------------------------------------------------
class Encoder(Protocol):
    dim: int

    def encode(self, text: str, is_query: bool = False) -> Vector: ...


class Indexer(Protocol):
    def index(self, encoded: Sequence[EncodedPassage]) -> BaseIndex: ...


class Reranker(Protocol):
    def rerank(
        self, query: Query, hits: Sequence[RetrievedDoc], top_k: Optional[int] = None
    ) -> List[RetrievedDoc]: ...


# ---- Data models ------------------------------------------------------------
@dataclass(frozen=True)
class Passage:
    pid: str
    text: str


@dataclass(frozen=True)
class EncodedPassage:
    pid: str
    text: str
    embedding: Vector


@dataclass(frozen=True)
class Query:
    text: str


@dataclass(frozen=True)
class RetrievedDoc:
    pid: str
    text: str
    embedding: Vector
    score: float


class SearchHit(TypedDict):
    pid: str
    score: float


class BaseIndex(Protocol):
    dim: int

    def add(self, items: Sequence[EncodedPassage]) -> None: ...

    def search(self, query_vec: Vector, top_k: int = 10) -> List[SearchHit]: ...

    def get(self, pid: str) -> EncodedPassage: ...


# ---- Implementations -------------------------------------------------------
class NumpyRandomEncoder:
    def __init__(self, dim: int = 4, seed: int = 42):
        self.dim = dim
        self.seed = seed  # Public attribute, included in cache key

    def encode(self, text: str, is_query: bool = False) -> Vector:
        # Recreate RNG with seed for determinism
        rng = np.random.default_rng(self.seed)
        return rng.random(self.dim, dtype=np.float32)


class InMemoryIndex:
    def __init__(self, dim: int) -> None:
        self.dim = dim
        self._data: Dict[str, EncodedPassage] = {}

    def add(self, items: Sequence[EncodedPassage]) -> None:
        for it in items:
            self._data[it.pid] = it

    def search(self, query_vec: Vector, top_k: int = 10) -> List[SearchHit]:
        q = query_vec / (np.linalg.norm(query_vec) + 1e-12)
        hits: List[Tuple[str, float]] = []
        for pid, ep in self._data.items():
            v = ep.embedding / (np.linalg.norm(ep.embedding) + 1e-12)
            hits.append((pid, float(np.dot(q, v))))
        hits.sort(key=lambda x: x[1], reverse=True)
        return [{"pid": pid, "score": score} for pid, score in hits[:top_k]]

    def get(self, pid: str) -> EncodedPassage:
        return self._data[pid]


class SimpleIndexer:
    def __init__(self, dim: int) -> None:
        self.dim = dim

    def index(self, encoded: Sequence[EncodedPassage]) -> BaseIndex:
        idx = InMemoryIndex(self.dim)
        idx.add(encoded)
        return idx


class IdentityReranker:
    def rerank(
        self, query: Query, hits: Sequence[RetrievedDoc], top_k: Optional[int] = None
    ) -> List[RetrievedDoc]:
        out = list(hits)
        if top_k is not None:
            out = out[:top_k]
        return out

In [2]:
# ---- Core text encoding (reusable) ------------------------------------------
@node(output_name="cleaned_text")
def clean_text(text: str) -> str:
    return text.strip().lower()


@node(output_name="embedding")
def encode_text(encoder: Encoder, cleaned_text: str, is_query: bool = False) -> Vector:
    return encoder.encode(cleaned_text, is_query=is_query)


# Reusable text encoding pipeline
text_encode = Pipeline(nodes=[clean_text, encode_text], name="text_encode")
text_encode.visualize()

In [3]:
# ---- Passage encoding: extract -> encode -> pack ----------------------------
@node(output_name="text")
def extract_passage_text(passage: Passage) -> str:
    return passage.text


@node(output_name="encoded_passage")
def pack_passage(passage: Passage, embedding: Vector) -> EncodedPassage:
    return EncodedPassage(pid=passage.pid, text=passage.text, embedding=embedding)


# Single passage encoding pipeline
single_encode = Pipeline(
    nodes=[extract_passage_text, text_encode, pack_passage], name="single_encode"
)

# Visualize the DAG
single_encode.visualize(depth=1)

In [4]:
# Test with single passage
res = single_encode.run(
    inputs={
        "passage": Passage(pid="1", text="Hello"),
        "encoder": NumpyRandomEncoder(dim=4, seed=42),
        "is_query": False,
    }
)
# res contains:
# {
#     "text": "Hello",
#     "cleaned_text": "hello",
#     "embedding": np.ndarray([...]),
#     "encoded_passage": EncodedPassage(...)
# }

In [5]:
res

{'text': 'Hello',
 'cleaned_text': 'hello',
 'embedding': array([0.08925092, 0.773956  , 0.6545715 , 0.43887842], dtype=float32),
 'encoded_passage': EncodedPassage(pid='1', text='Hello', embedding=array([0.08925092, 0.773956  , 0.6545715 , 0.43887842], dtype=float32))}

In [6]:
# Test with map over multiple passages
results = single_encode.map(
    inputs={
        "passage": [Passage(pid="1", text="Hello"), Passage(pid="2", text="World")],
        "encoder": NumpyRandomEncoder(dim=4, seed=42),
        "is_query": False,
    },
    map_over="passage",
)
# results contains:
# {
#     "text": ["Hello", "World"],
#     "cleaned_text": ["hello", "world"],
#     "embedding": [np.ndarray([...]), np.ndarray([...])],
#     "encoded_passage": [EncodedPassage(...), EncodedPassage(...)]
# }

In [7]:
res

{'text': 'Hello',
 'cleaned_text': 'hello',
 'embedding': array([0.08925092, 0.773956  , 0.6545715 , 0.43887842], dtype=float32),
 'encoded_passage': EncodedPassage(pid='1', text='Hello', embedding=array([0.08925092, 0.773956  , 0.6545715 , 0.43887842], dtype=float32))}

# Index Building Pipeline

In [8]:
# Adapt single_encode to map over a corpus internally
encode_corpus = single_encode.as_node(
    input_mapping={"corpus": "passage"},
    output_mapping={"encoded_passage": "encoded_corpus"},
    map_over="corpus",
    name="encode_corpus",
)


@node(output_name="index")
def build_index(indexer: Indexer, encoded_corpus: List[EncodedPassage]) -> BaseIndex:
    return indexer.index(encoded_corpus)


# Pipeline: encode all passages, then build index
encode_and_index = Pipeline(nodes=[encode_corpus, build_index], name="encode_and_index")

In [9]:
# Visualize
encode_and_index.visualize(depth=2, show_legend=True)

In [10]:
# Run with corpus
corpus: List[Passage] = [
    Passage(pid="p1", text="Hello World"),
    Passage(pid="p2", text="The Quick Brown Fox"),
]

encoder = NumpyRandomEncoder(dim=4, seed=42)
indexer = SimpleIndexer(dim=encoder.dim)

outputs = encode_and_index.run(
    inputs={
        "corpus": corpus,
        "encoder": encoder,
        "indexer": indexer,
        "is_query": False,
    }
)
index: BaseIndex = outputs["index"]

In [11]:
# ---- Query encoding: extract -> encode --------------------------------------
@node(output_name="text")
def extract_query_text(query: Query) -> str:
    return query.text


# Query encoding pipeline
encode_query_pipeline = Pipeline(
    nodes=[extract_query_text, text_encode], name="encode_query"
)

# Use .as_node() to rename outputs
encode_query_step = encode_query_pipeline.as_node(
    output_mapping={"embedding": "query_vec"}, name="encode_query_step"
)

In [12]:
# ---- Retrieval + Reranking --------------------------------------------------
@node(output_name="retrieved")
def retrieve(
    index: BaseIndex, query_vec: Vector, top_k: int = 10
) -> List[RetrievedDoc]:
    hits = index.search(query_vec, top_k=top_k)
    return [
        RetrievedDoc(
            pid=h["pid"],
            text=index.get(h["pid"]).text,
            embedding=index.get(h["pid"]).embedding,
            score=h["score"],
        )
        for h in hits
    ]


@node(output_name="reranked_hits")
def rerank_hits(
    reranker: Reranker,
    query: Query,
    retrieved: List[RetrievedDoc],
    final_top_k: Optional[int] = None,
) -> List[RetrievedDoc]:
    return reranker.rerank(query, retrieved, top_k=final_top_k)


search_pipeline = Pipeline(
    nodes=[encode_query_step, retrieve, rerank_hits], name="search"
)

# Full pipeline
full_pipeline = Pipeline(nodes=[encode_and_index, search_pipeline], name="full_rag")
full_pipeline.visualize()

In [13]:
# ---- Usage Examples ---------------------------------------------------------
# Build index
corpus = [
    Passage(pid="p1", text="Hello World"),
    Passage(pid="p2", text="Quick Brown Fox"),
]
encoder = NumpyRandomEncoder(dim=4)
indexer = SimpleIndexer(dim=encoder.dim)

outputs = encode_and_index.run(
    inputs={
        "corpus": corpus,
        "encoder": encoder,
        "indexer": indexer,
        "is_query": False,
    }
)
index = outputs["index"]

In [14]:
reranker = IdentityReranker()

# Single query
search_out = search_pipeline.run(
    inputs={
        "query": Query(text="hello world"),
        "encoder": encoder,
        "index": index,
        "reranker": reranker,
        "top_k": 5,
        "final_top_k": 3,
        "is_query": True,
    }
)

for doc in search_out["reranked_hits"]:
    print(doc.pid, doc.score, doc.text)

p1 1.0 Hello World
p2 1.0 Quick Brown Fox


In [15]:
# Multiple queries
queries = [Query(text="hello"), Query(text="quick fox"), Query(text="world")]

batch_out = search_pipeline.map(
    inputs={
        "query": queries,
        "encoder": encoder,
        "index": index,
        "reranker": reranker,
        "top_k": 5,
        "final_top_k": 3,
        "is_query": True,
    },
    map_over="query",
)

for q, results in zip(queries, batch_out["reranked_hits"]):
    print(f"\nQuery: {q.text}")
    for r in results:
        print(f"  {r.pid} | {r.score:.3f} | {r.text}")


Query: hello
  p1 | 1.000 | Hello World
  p2 | 1.000 | Quick Brown Fox

Query: quick fox
  p1 | 1.000 | Hello World
  p2 | 1.000 | Quick Brown Fox

Query: world
  p1 | 1.000 | Hello World
  p2 | 1.000 | Quick Brown Fox
