In [8]:
!pip install --upgrade -q sentence-transformers ranx redisvl "redis-retrieval-optimizer>=0.2.0" datasets pandas openai scikit-learn h5py

# Setup

## Run a Redis instance

#### Mac

Remember to point at 6380, not 6379
```bash
docker run -d --name redis-stack \
  -p 6380:6379 -p 8001:8001 \
  -v redis-data:/data \
  redis/redis-stack:latest

```
#### For Colab
Use the shell script below to download, extract, and install [Redis Stack](https://redis.io/docs/getting-started/install-stack/) directly from the Redis package archive.

In [None]:
# NBVAL_SKIP
%%sh
curl -fsSL https://packages.redis.io/gpg | sudo gpg --dearmor -o /usr/share/keyrings/redis-archive-keyring.gpg
echo "deb [signed-by=/usr/share/keyrings/redis-archive-keyring.gpg] https://packages.redis.io/deb $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/redis.list
sudo apt-get update  > /dev/null 2>&1
sudo apt-get install redis-stack-server  > /dev/null 2>&1
redis-stack-server --daemonize yes

### Define the Redis Connection URL

By default this notebook connects to the local instance of Redis Stack. **If you have your own Redis Enterprise instance** - replace REDIS_PASSWORD, REDIS_HOST and REDIS_PORT values with your own.

In [1]:
import os, sys, io
import warnings

warnings.filterwarnings("ignore")

# Replace values below with your own if using Redis Cloud instance
REDIS_HOST = os.getenv("REDIS_HOST", "localhost") # ex: "redis-18374.c253.us-central1-1.gce.cloud.redislabs.com"
REDIS_PORT = os.getenv("REDIS_PORT", "6380")      # ex: 18374
REDIS_PASSWORD = os.getenv("REDIS_PASSWORD", "")  
REDIS_DB = os.getenv("REDIS_DB", 0)

# If SSL is enabled on the endpoint, use rediss:// as the URL prefix
REDIS_URL = f"redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}"

# Set Credentials for OpenAI

In [2]:
import os
from dotenv import load_dotenv
load_dotenv()

True

# Helpers

## Metrics Collection

In [43]:
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    classification_report,
    confusion_matrix
)
import os
import time
from datetime import datetime

def format_timestamp(ts=None):
    if ts is None:
        ts = datetime.now()
    elif isinstance(ts, (int, float)):
        ts = datetime.fromtimestamp(ts)
    return ts.strftime("%Y_%m_%d_%H%M%S")

class Timer:
    def __init__(self):
        self.timings = {}

    def timeit(self, name):
        class _T:
            def __enter__(_self):
                _self.start = time.perf_counter()
            def __exit__(_self, *a):
                end=round(time.perf_counter() - _self.start, 3)
                self.timings[name] = end
        return _T()

def classification_metrics(preds: list[str], labels: list[str]) -> dict:
    metrics = {
        "accuracy": accuracy_score(labels, preds),
        "precision_macro": precision_score(labels, preds, average="macro", zero_division=0),
        "recall_macro": recall_score(labels, preds, average="macro", zero_division=0),
        "f1_macro": f1_score(labels, preds, average="macro", zero_division=0),
        "confusion_matrix": confusion_matrix(labels, preds).tolist(),
        "classification_report": classification_report(labels, preds, zero_division=0, output_dict=True)
    }
    return metrics

def print_experiment_summary(eval_metrics, 
                             router_config, timings, redis_schema=None, redis_index_info=None,
                             experiment_notes="NA", experiment_name="experiment", 
                             print_output=True, save_output=True):
    lines = []
    lines.append("=== Experiment Summary: Semantic Routing Classification ===\n")

    lines.append(f"Timestamp         : {format_timestamp()}")
    lines.append("")

    lines.append(f"Experiment Notes  : {experiment_notes}")
    lines.append("")

    # Data Config
    lines.append("DATA CONFIGURATION")
    lines.append(f"Router Name       : {router_config.get('name')}")
    lines.append(f"Batch Size        : {router_config.get('batch_size')}")
    lines.append(f"Dataset Repo ID    : {router_config.get('dataset_repo_id')}")
    lines.append(f"Dataset Domain    : {router_config.get('dataset_domain')}")
    lines.append(f"Dataset Name      : {router_config.get('dataset_name')}")
    lines.append(f"Train Limit       : {router_config.get('train_limit')}")
    lines.append(f"Test Limit        : {router_config.get('test_limit')}")
    vec_cfg = router_config.get('vectorizer_config', {})
    lines.append(f"Vector Dimensions : {vec_cfg.get('dimensions')}")
    lines.append(f"Embedding Model   : {vec_cfg.get('OPENAI_MODEL')}\n")

    # Redis Schema
    if redis_schema:
        idx = redis_schema.get('index', {})
        lines.append("REDIS SCHEMA")
        lines.append(f"Index Name        : {idx.get('name')}")
        lines.append(f"Prefix            : {idx.get('prefix')}")
        lines.append(f"Storage Type      : {idx.get('storage_type')}")
        lines.append("Fields:")
        for f in redis_schema.get('fields', []):
            if f['type'] == 'vector':
                attrs = f.get('attrs', {})
                lines.append(f"  - {f['name']} ({f['type']})  dims={attrs.get('dims')}, metric={attrs.get('distance_metric')}, algo={attrs.get('algorithm')}")
            else:
                lines.append(f"  - {f['name']} ({f['type']})")
    lines.append("")

    # Redis Index Info
    if redis_index_info:
        lines.append("REDIS INDEX INFO")
        lines.append(f"Documents         : {redis_index_info['num_docs']}")
        lines.append(f"Max Doc ID        : {redis_index_info['max_doc_id']}")
        lines.append(f"Num Terms         : {redis_index_info['num_terms']}")
        lines.append(f"Num Records       : {redis_index_info['num_records']}")
        lines.append(f"Inverted Size (MB): {float(redis_index_info['inverted_sz_mb']):.6f}")
        lines.append(f"Vector Index (MB) : {float(redis_index_info['vector_index_sz_mb']):.6f}\n")
    else:
        lines.append("No index info available")

    # Evaluation Metrics
    lines.append("EVALUATION METRICS")
    lines.append(f"Accuracy          : {eval_metrics['accuracy']:.4f}")
    lines.append(f"Precision (macro) : {eval_metrics['precision_macro']:.4f}")
    lines.append(f"Recall (macro)    : {eval_metrics['recall_macro']:.4f}")
    lines.append(f"F1 (macro)        : {eval_metrics['f1_macro']:.4f}\n")

    cm = eval_metrics['confusion_matrix']
    lines.append("CONFUSION MATRIX [ [TP_0, FP_0], [FN_1, TP_1] ]")
    lines.append(f"  {cm[0]}")
    lines.append(f"  {cm[1]}\n")

    lines.append("CLASSIFICATION REPORT")
    report = eval_metrics['classification_report']
    for label, metrics in report.items():
        if label in ['accuracy', 'macro avg', 'weighted avg']:
            continue
        lines.append(f"Class {label}: "
                     f"Precision={metrics['precision']:.3f}, "
                     f"Recall={metrics['recall']:.3f}, "
                     f"F1={metrics['f1-score']:.3f}, "
                     f"Support={int(metrics['support'])}")
    lines.append("")

    macro = report.get('macro avg', {})
    lines.append("Macro Avg:")
    lines.append(f"  Precision={macro.get('precision', 0):.3f}, "
                 f"Recall={macro.get('recall', 0):.3f}, "
                 f"F1={macro.get('f1-score', 0):.3f}\n")

    # Timings
    lines.append("TIMINGS (seconds)")
    total=sum(timings.values())
    for k, v in timings.items():
        lines.append(f"{k:<45}: {v:.3f}")
    lines.append(f"Total: {total:.3f}")
    lines.append("\n=============================================================\n")

    output_str = "\n".join(lines)

    if print_output:
        print(output_str)

    if save_output:
        os.makedirs("experiments", exist_ok=True)
        path = os.path.join("experiments", f"{format_timestamp()}_{experiment_name}.txt")
        with open(path, "w", encoding="utf-8") as f:
            f.write(output_str)

    return output_str

## OpenAI Embedding and Data Management

In [4]:

import os
from tqdm.auto import tqdm
from openai import OpenAI
from typing import List, Optional, Dict, Any, Tuple
import h5py, numpy as np

def save_embeddings_h5(
    path: str,
    labels: List[Any],
    embeds: List[List[float]],
    meta: Optional[Dict] = None
) -> None:
    E = np.asarray(embeds, dtype="float32")
    n, d = E.shape
    L = np.asarray([str(l) for l in labels], dtype=h5py.string_dtype("utf-8"))
    if len(L) != n:
        raise ValueError("labels length must match number of embeddings")

    with h5py.File(path, "a") as f:
        if "embeddings" not in f:
            f.create_dataset("embeddings", data=E, maxshape=(None, d),
                             chunks=True, compression="gzip", compression_opts=4)
            f.create_dataset("labels", data=L, maxshape=(None,),
                             chunks=True, compression="gzip", compression_opts=4)
            if meta:
                for k, v in meta.items():
                    f.attrs[k] = v
        else:
            de, dl = f["embeddings"], f["labels"]
            if de.shape[1] != d:
                raise ValueError(f"dim mismatch: file {de.shape[1]} vs new {d}")
            start = de.shape[0]
            de.resize(start + n, axis=0)
            dl.resize(start + n, axis=0)
            de[start:start+n] = E
            dl[start:start+n] = L

def load_embeddings_h5(path: str, n: Optional[int] = None) -> tuple[list[list[float]], list[str]]:
    with h5py.File(path, "r") as f:
        for name in ("embeddings", "labels"):
            if name not in f:
                raise ValueError(f"File missing '{name}' dataset")
        de, dl = f["embeddings"], f["labels"]
        total = de.shape[0]
        end = total if n is None else min(total, n)
        vectors = de[:end]
        labels = dl.asstr()[:end]
    return vectors.tolist(), labels.tolist()

def get_batch_embeddings(
    text_list: list[str], 
    batch_size: int = 512,
    dimensions: int = 128,
    model: str = "text-embedding-3-small",
    show_bar: bool = True
) -> list[list[float]]:
    client = OpenAI()
    embeddings: List[List[float]] = []
    texts = text_list
    iterator = range(0, len(texts), batch_size)
    pbar = tqdm(total=len(texts), unit="emb", desc="Embedding", disable=not show_bar)
    try:
        for i in iterator:
            batch = texts[i:i + batch_size]
            resp = client.embeddings.create(input=batch, model=model, dimensions=dimensions)
            embeddings.extend([d.embedding for d in resp.data])
            pbar.update(len(batch))
    finally:
        pbar.close()
    return embeddings


def get_batch_embeddings_to_h5(
    out_path: str,
    text_list: list[str],
    labels_list: list[str],
    batch_size: int = 512,
    dimensions: int = 128,
    model: str = "text-embedding-3-small",
    show_bar: bool = True,
    meta: Optional[Dict] = None,
    return_data: bool = False,
    append_existing: bool = False,
) -> Optional[Tuple[list[list[float]], list[str]]]:
    if len(text_list) != len(labels_list):
        raise ValueError("text_list and labels_list must be same length")

    if os.path.exists(out_path) and not append_existing:
        print(f"File '{out_path}' already exists. Skipping embedding generation.")
        return None

    client = OpenAI()
    ret_E: List[List[float]] = []
    ret_L: List[str] = []

    iterator = range(0, len(text_list), batch_size)
    pbar = tqdm(total=len(text_list), unit="emb", desc="Embedding", disable=not show_bar)

    try:
        for i in iterator:
            batch_t = text_list[i:i + batch_size]
            batch_l = labels_list[i:i + batch_size]
            resp = client.embeddings.create(input=batch_t, model=model, dimensions=dimensions)
            batch_e = [d.embedding for d in resp.data]

            save_embeddings_h5(out_path, batch_l, batch_e, meta=meta)

            if return_data:
                ret_E.extend(batch_e)
                ret_L.extend([str(x) for x in batch_l])

            pbar.update(len(batch_t))
    finally:
        pbar.close()

    return (ret_E, ret_L) if return_data else None

# Experiments

## Option 1: Custom Semantic Router - Direct Redis access - Allows for trying out techniques with Numpy Arrs

### Select and Download Datasets

In [None]:
import os 
from datasets import load_dataset

# See dataset page @ https://huggingface.co/datasets/ + repo_id
eval_datasets={
                "Topic": {
                    "DBLP":{
                        "repo_id": "waashk/dblp",
                        "train_lbl": "train",
                        "test_lbl": "test",
                        "text_col": "text",
                        "label_col": "label"
                      }
                    },
                "Sentiment": {
                    "MPQA":{
                        "repo_id": "jxm/mpqa",
                        "train_lbl": "train",
                        "test_lbl": "test",
                        "text_col": "sentence",
                        "label_col": "label"
                    }
                }
              }

router_config = {
    "name": "test",
    "batch_size": 1024,
    "vectorizer_config":{
        "dimensions":128,
        "OPENAI_API_KEY": os.getenv("OPENAI_API_KEY"),
        "OPENAI_MODEL": "text-embedding-3-small"
    },
    "dataset_domain": "Sentiment",
    "dataset_name": "MPQA",
}

redis_schema = {
    "index": {
        "name": "MPQA",
        "prefix": "test1",
        "storage_type": "hash",
    },
    "fields": [
        {"name": "label", "type": "tag"},
        {
            "name": "vector",
            "type": "vector",
            "attrs": {
                "dims": router_config["vectorizer_config"]["dimensions"],
                "distance_metric": "cosine",
                "algorithm": "hnsw",
                "datatype": "float32"
            }

        }
    ],
}

experiment_notes = ""

timer=Timer()

with timer.timeit("load_dataset"):
    dataset_repo_id = eval_datasets[router_config["dataset_domain"]][router_config["dataset_name"]]["repo_id"]
    dataset_train_lbl = eval_datasets[router_config["dataset_domain"]][router_config["dataset_name"]]["train_lbl"]
    dataset_test_lbl = eval_datasets[router_config["dataset_domain"]][router_config["dataset_name"]]["test_lbl"]
    dataset_text_col = eval_datasets[router_config["dataset_domain"]][router_config["dataset_name"]]["text_col"]
    dataset_label_col = eval_datasets[router_config["dataset_domain"]][router_config["dataset_name"]]["label_col"]

    dataset = load_dataset(dataset_repo_id)
    train_data = dataset[dataset_train_lbl]
    test_data = dataset[dataset_test_lbl]

    router_config.update({"dataset_repo_id":dataset_repo_id})
    router_config.update({"train_limit":train_data.num_rows})
    router_config.update({"test_limit":test_data.num_rows})

# Define Experiment Names
safe_repo = dataset_repo_id.replace("/", "-")
experiment_name = f'{router_config["name"]}_{safe_repo}_{router_config["vectorizer_config"]["OPENAI_MODEL"]}_{router_config["vectorizer_config"]["dimensions"]}'
train_filepath = f'embeddings/{experiment_name}_{dataset_train_lbl}_{str(train_data.num_rows)}.h5'
test_filepath  = f'embeddings/{experiment_name}_{dataset_test_lbl}_{str(test_data.num_rows)}.h5'

### Pre-Generate Embeddings for Train and Test Data, Save to Embeddings Folder

In [None]:
os.makedirs("embeddings", exist_ok=True)

with timer.timeit("Embed and Save Train Data"):
    train_embeddings = get_batch_embeddings_to_h5(out_path=train_filepath,
                                                text_list=train_data[dataset_text_col],
                                                labels_list=train_data[dataset_label_col],
                                                batch_size=router_config["batch_size"],
                                                dimensions=router_config["vectorizer_config"]["dimensions"],
                                                model=router_config["vectorizer_config"]["OPENAI_MODEL"])

with timer.timeit("Embed and Save Test Data"):
    test_embeddings = get_batch_embeddings_to_h5(out_path=test_filepath,
                                                text_list=test_data[dataset_text_col],
                                                labels_list=test_data[dataset_label_col],
                                                batch_size=router_config["batch_size"],
                                                dimensions=router_config["vectorizer_config"]["dimensions"],
                                                model=router_config["vectorizer_config"]["OPENAI_MODEL"])

File 'embeddings/test_jxm-mpqa_text-embedding-3-small_128_train_8603.h5' already exists. Skipping embedding generation.
File 'embeddings/test_jxm-mpqa_text-embedding-3-small_128_test_2000.h5' already exists. Skipping embedding generation.


### Load and Query Functions

In [45]:
import numpy as np
import redis
from redisvl.query import VectorQuery
from redisvl.index import SearchIndex

def _kmeans(x, k, iters=20, seed=0):
    rng = np.random.default_rng(seed)
    k = min(k, len(x))
    idx = rng.choice(len(x), size=k, replace=False)
    c = x[idx].copy()
    for _ in range(iters):
        d = ((x[:, None, :] - c[None, :, :])**2).sum(-1)
        a = d.argmin(1)
        new_c = []
        for j in range(k):
            sj = (a == j)
            if sj.any():
                new_c.append(x[sj].mean(0))
            else:
                new_c.append(c[rng.integers(0, k)])
        new_c = np.stack(new_c, 0)
        if np.allclose(new_c, c):
            break
        c = new_c
    return c, a

def build_multi_prototypes(vectors, labels, mask=None, k_per_class=1, 
                           max_points_per_proto=None, normalize=True, seed=0):
    """
    If max_points_per_proto is set, k for class c becomes ceil(n_c / max_points_per_proto).
    Otherwise use fixed k_per_class.
    """
    labels = np.asarray(labels)
    V = np.asarray(vectors, dtype=np.float32)
    if mask is None:
        mask = np.linalg.norm(V, axis=1) > 0
    V, L = V[mask], labels[mask]

    out = []
    for cls in np.unique(L):
        sel = (L == cls)
        X = V[sel]
        if len(X) == 0:
            continue

        if max_points_per_proto is not None and max_points_per_proto > 0:
            k = int(np.ceil(len(X) / max_points_per_proto))
        else:
            k = int(k_per_class)
        k = max(1, min(k, len(X)))

        centers, assign = _kmeans(X, k=k, seed=seed)
        for j in range(len(centers)):
            sj = (assign == j)
            if not sj.any():
                continue
            mu = centers[j]
            if normalize:
                n = np.linalg.norm(mu)
                if n > 0:
                    mu = mu / (n + 1e-12)
            weight = float(sj.mean())  # fraction of class points in this prototype
            out.append({"vector": mu.astype(np.float32).tobytes(), "label": cls, "weight": weight})
    return out

def route_label(vec, labels, index, k=5):
    means = {}
    for lab in labels:
        q = VectorQuery(
            vector=list(map(float, vec)),
            vector_field_name="vector",
            num_results=k,
            return_fields=["label", "vector_distance"],
            filter_expression=f"@label:{{{lab}}}",
        )
        hits = index.query(q)
        if hits:
            dists = [float(h["vector_distance"]) for h in hits]
            means[lab] = sum(dists) / len(dists)
    if not means:
        return None, {}
    best = min(means, key=means.get)
    return best, means

def route_batch(vectors, labels, index, k=5):
    out = []
    for v in tqdm(vectors):
        best, _ = route_label(v, labels, index, k)
        out.append(best)
    return out

### Pre-Generated Experiment

In [64]:


with timer.timeit("Load and process train data, create Redis Index"):
    r = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB)
    try:
        r.flushdb()
        r.ft(redis_schema["index"]["name"]).dropindex() 
    except:
        pass

    index = SearchIndex.from_dict(redis_schema, redis_url=REDIS_URL)
    index.create(overwrite=True)
    vectors,labels=load_embeddings_h5(train_filepath)
    unique_labels=list(set(labels))
    vectors = np.array(vectors)
    data = build_multi_prototypes(vectors, labels, mask=None, k_per_class=5, max_points_per_proto=None, normalize=True, seed=0)

    keys = index.load(data)
index.info()['num_docs']

10

In [47]:
with timer.timeit("Load and process test data"):
    test_vectors,test_labels=load_embeddings_h5(test_filepath, n=test_data.num_rows)
    test_vectors = np.array(test_vectors)
    test_vectors /= np.linalg.norm(test_vectors, axis=1, keepdims=True) + 1e-12
with timer.timeit("Route test data"):
    preds=route_batch(test_vectors, unique_labels, index, k=5)

eval_metrics = classification_metrics(
    preds=preds,
    labels=test_labels
)

100%|██████████| 2000/2000 [00:00<00:00, 2039.42it/s]


In [None]:
exp_summary = print_experiment_summary(eval_metrics, redis_schema, router_config, timer.timings, index.info(), experiment_name=experiment_name)

=== Experiment Summary: Semantic Routing Classification ===

Timestamp         : 2025_10_15_164430

DATA CONFIGURATION
Router Name       : test
Batch Size        : 1024
Dataset Domain    : Sentiment
Dataset Name      : MPQA
Train Limit       : 8603
Test Limit        : 2000
Vector Dimensions : 128
Embedding Model   : text-embedding-3-small

REDIS SCHEMA
Index Name        : MPQA
Prefix            : test1
Storage Type      : hash
Fields:
  - label (tag)
  - vector (vector)  dims=128, metric=cosine, algo=hnsw

REDIS INDEX INFO
Documents         : 10
Max Doc ID        : 10
Num Terms         : 0
Num Records       : 20
Inverted Size (MB): 0.000164
Vector Index (MB) : 0.724480

EVALUATION METRICS
Accuracy          : 0.8360
Precision (macro) : 0.8361
Recall (macro)    : 0.8360
F1 (macro)        : 0.8360

CONFUSION MATRIX [ [TP_0, FP_0], [FN_1, TP_1] ]
  [827, 173]
  [155, 845]

CLASSIFICATION REPORT
Class 0: Precision=0.842, Recall=0.827, F1=0.835, Support=1000
Class 1: Precision=0.830, Recall=



# Option 2: OOTB Semantic Routing - The Offering from RedisVL, slight changes to handle more data bulk

### Custom Semantic Router - Allows messing with dimensions + batch sizes for bulky preds

In [78]:
from tenacity import retry, stop_after_attempt, wait_random_exponential
from tenacity.retry import retry_if_not_exception_type
from tqdm.auto import tqdm
import os
from redisvl.extensions.router import SemanticRouter
from redisvl.extensions.router import Route
from redisvl.utils.vectorize import OpenAITextVectorizer
from redisvl.extensions.cache.embeddings.embeddings import EmbeddingsCache
from datasets import load_dataset, Dataset
from redisvl.extensions.router.schema import (
    DistanceAggregationMethod,
    Route,
    RouteMatch
)
from redis.commands.search.aggregation import AggregateRequest, AggregateResult, Reducer
from redisvl.redis.utils import hashify
from redisvl.query import FilterQuery, VectorRangeQuery
from redisvl.extensions.constants import ROUTE_VECTOR_FIELD_NAME
from redis.exceptions import ResponseError



def build_route(
    references: list[str],
    label: str,
    priority: int = 1,
    distance_threshold: float = 0.5
    ) -> Route:
  route = Route(
      name=label,
      references=references,
      metadata={"category": label, "priority": priority},
      distance_threshold=distance_threshold
  )
  return route

def build_routes_from_hf_data(
    dataset: Dataset,
    text_col: str = "text",
    label_col: str = "label",
    distance_threshold: float = 0.5
    ) -> list[Route]:
    '''
    Expects dataset to be of form:
    Dataset({
        features: ['text', 'label'],
        num_rows: 343152
    })
    '''
    df = dataset.to_pandas()[[text_col, label_col]]
    texts_by_label = df.groupby(label_col)[text_col].apply(list).to_dict()
    return [build_route(references=refs, label=str(lbl), distance_threshold=distance_threshold) for lbl, refs in texts_by_label.items()]


class HighVisOpenAITextVectorizer(OpenAITextVectorizer):
    @retry(
    wait=wait_random_exponential(min=1, max=60),
    stop=stop_after_attempt(6),
    retry=retry_if_not_exception_type(TypeError),
    )
    def _embed_many(
        self, texts: List[str], batch_size: int = 1024, dimensions: int = 1536, **kwargs
    ) -> List[List[float]]:
        """Exactly the same but with progress bar
        """
        if not isinstance(texts, list):
            raise TypeError("Must pass in a list of str values to embed.")
        if texts and not isinstance(texts[0], str):
            raise TypeError("Must pass in a list of str values to embed.")

        embeddings: List = []

        pbar = tqdm(total=len(texts), unit="emb", desc="Embedding", disable=False)
        for batch in tqdm(self.batchify(texts, batch_size)):
            try:
                response = self._client.embeddings.create(
                    input=batch, model=self.model, dimensions=1536, **kwargs
                )
                embeddings += [r.embedding for r in response.data]
                pbar.update(len(batch))
            except Exception as e:
                raise ValueError(f"Embedding texts failed: {e}")
        pbar.close()
        return embeddings

class BigBatchSemanticRouter(SemanticRouter):
    def _add_routes(self, routes: List[Route]):
        """Add routes to the router and index.

        Args:
            routes (List[Route]): List of routes to be added.
        """
        route_references: List[Dict[str, Any]] = []
        keys: List[str] = []

        for route in routes:
            # embed route references as a single batch
            reference_vectors = self.vectorizer.embed_many(
                [reference for reference in route.references], as_buffer=True, batch_size=1024, dimensions=1536
            )
            # set route references
            for i, reference in enumerate(route.references):
                reference_hash = hashify(reference)
                route_references.append(
                    {
                        "reference_id": reference_hash,
                        "route_name": route.name,
                        "reference": reference,
                        "vector": reference_vectors[i],
                    }
                )
                keys.append(
                    self._route_ref_key(self._index, route.name, reference_hash)
                )
            # set route if does not yet exist client side
            if not self.get(route.name):
                self.routes.append(route)
        self._index.load(route_references, keys=keys)

    def _get_route_matches(
        self,
        vector: List[float],
        aggregation_method: DistanceAggregationMethod,
        max_k: int = 1,
    ) -> List[RouteMatch]:
        """Get route response from vector db"""

        # what's interesting about this is that we only provide one distance_threshold for a range query not multiple
        # therefore you might take the max_threshold and further refine from there.
        distance_threshold = max(route.distance_threshold for route in self.routes)

        vector_range_query = VectorRangeQuery(
            vector=vector,
            vector_field_name=ROUTE_VECTOR_FIELD_NAME,
            distance_threshold=float(distance_threshold),
            return_fields=["route_name"],
        )

        aggregate_request = self._build_aggregate_request(
            vector_range_query, aggregation_method, max_k=max_k
        )

        try:
            aggregation_result: AggregateResult = self._index.aggregate(
                aggregate_request, vector_range_query.params
            )
        except ResponseError as e:
            if "VSS is not yet supported on FT.AGGREGATE" in str(e):
                raise RuntimeError(
                    "Semantic routing is only available on Redis version 7.x.x or greater"
                )
            raise e

        # process aggregation results into route matches
        return [
            self._process_route(route_match) for route_match in aggregation_result.rows
        ]
    
    def _classify_multi_route(
        self,
        vector: List[float],
        max_k: int,
        aggregation_method: DistanceAggregationMethod,
    ) -> List[RouteMatch]:
        """Classify to multiple routes, up to max_k (int), using a vector."""

        route_matches = self._get_route_matches(vector, aggregation_method, max_k=max_k)

        # process route matches
        top_route_matches: List[RouteMatch] = []
        if route_matches:
            for route_match in route_matches:
                if route_match.name is not None:
                    top_route_matches.append(route_match)
                else:
                    raise ValueError(
                        f"{route_match.name} not a supported route for the {self.name} semantic router."
                    )

        return top_route_matches

    def bulk_route(
        self,
        statements: Optional[List[str]] = None,
        vectors: Optional[List[List[float]]] = None,
        max_k: Optional[int] = None,
        distance_threshold: Optional[float] = None,
        aggregation_method: Optional[DistanceAggregationMethod] = None,
    ) -> List[List[RouteMatch]]:
        """For mass routing.
        """
        if not vectors:
            if not statements:
                raise ValueError("Must provide a list of vectors or statements to the router")
            vectors = self.vectorizer.embed_many(statements, batch_size=1024, dimensions = 1536)  # type: ignore

        max_k = max_k or self.routing_config.max_k
        aggregation_method = (
            aggregation_method or self.routing_config.aggregation_method
        )

        pbar = tqdm(total=len(vectors), unit="pred", desc="Assessing", disable=False)
        results: List[List[RouteMatch]] = []
        for v in vectors:
            matches = self._classify_multi_route(v, max_k, aggregation_method)  # type: ignore
            if distance_threshold is not None:
                try:
                    matches = [m for m in matches if m.distance <= distance_threshold]  # type: ignore
                except AttributeError:
                    pass
            results.append(matches)
            pbar.update(1)
        pbar.close()
        return results

def nuke_redis():
    import redis
    try:
        r = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB)
        r.flushdb()
    except:
        pass



### Custom Router Optimizer

In [68]:
import random
from typing import Any, Callable, Dict, List

import numpy as np
from ranx import Qrels, Run, evaluate
from redisvl.extensions.router.semantic import SemanticRouter

from redis_retrieval_optimizer.threshold_optimization.base import (
    BaseThresholdOptimizer,
    EvalMetric,
)
from redis_retrieval_optimizer.threshold_optimization.schema import LabeledData
from redis_retrieval_optimizer.threshold_optimization.utils import (
    NULL_RESPONSE_KEY,
    _format_qrels,
)


def _generate_run_router(test_data: List[LabeledData], router: SemanticRouter) -> "Run":
    """Format router results into format for ranx Run"""
    if Run is None:
        raise ImportError("ranx is required for threshold optimization")
    if np is None:
        raise ImportError("numpy is required for threshold optimization")

    run_dict: Dict[Any, Any] = {}

    bulk_route = router.bulk_route(statements=[td.query for td in test_data], max_k=1)
    print(len(bulk_route))
    for td, match in zip(test_data, bulk_route):
        run_dict[td.id] = {}
        if match:
            run_dict[td.id][match[0].name] = np.int64(1)
        else:
            run_dict[td.id][NULL_RESPONSE_KEY] = np.int64(1)

    # for td in test_data:
    #     run_dict[td.id] = {}
    #     route_match = router(td.query)
    #     if route_match and route_match.name == td.query_match:
    #         run_dict[td.id][td.query_match] = np.int64(1)
    #     else:
    #         run_dict[td.id][NULL_RESPONSE_KEY] = np.int64(1)

    return Run(run_dict)


def _eval_router(
    router: SemanticRouter,
    test_data: List[LabeledData],
    qrels: "Qrels",
    eval_metric: str,
) -> float:
    """Evaluate acceptable metric given run and qrels data"""
    if evaluate is None:
        raise ImportError("ranx is required for threshold optimization")

    run = _generate_run_router(test_data, router)
    return evaluate(qrels, run, eval_metric, make_comparable=True)


def _router_random_search(
    route_names: List[str], route_thresholds: dict, search_step=0.10
):
    """Performs random search for many thresholds to many routes"""
    if np is None:
        raise ImportError("numpy is required for threshold optimization")

    score_threshold_values = []
    for route in route_names:
        score_threshold_values.append(
            np.linspace(
                start=max(route_thresholds[route] - search_step, 0),
                stop=route_thresholds[route] + search_step,
                num=100,
            )
        )

    return {
        route: float(random.choice(score_threshold_values[i]))
        for i, route in enumerate(route_names)
    }


def _random_search_opt_router(
    router: SemanticRouter,
    test_data: List[LabeledData],
    qrels: "Qrels",
    eval_metric: EvalMetric,
    **kwargs: Any,
):
    """Performs complete optimization for router cases provide acceptable metric"""
    print("Starting Optimization")

    start_score = _eval_router(router, test_data, qrels, eval_metric.value)
    best_score = start_score
    best_thresholds = router.route_thresholds

    max_iterations = kwargs.get("max_iterations", 20)
    search_step = kwargs.get("search_step", 0.10)

    pbar = tqdm(total=max_iterations, desc="Optimizing Routes")
    for _ in range(max_iterations):
        route_names = router.route_names
        route_thresholds = router.route_thresholds
        thresholds = _router_random_search(
            route_names=route_names,
            route_thresholds=route_thresholds,
            search_step=search_step,
        )
        router.update_route_thresholds(thresholds)
        print("Eval starting")
        score = _eval_router(router, test_data, qrels, eval_metric.value)
        if score > best_score:
            best_score = score
            best_thresholds = thresholds
        pbar.update(1)

    print(
        f"Eval metric {eval_metric.value.upper()}: start {round(start_score, 3)}, end {round(best_score, 3)} \nEnding thresholds: {router.route_thresholds}"
    )
    router.update_route_thresholds(best_thresholds)
    pbar.close()


class BulkRouterThresholdOptimizer(BaseThresholdOptimizer):
    def __init__(
        self,
        router: SemanticRouter,
        test_dict: List[Dict[str, Any]],
        opt_fn: Callable = _random_search_opt_router,
        eval_metric: str = "f1",
    ):
        """Initialize the router optimizer.

        Args:
            router (SemanticRouter): The RedisVL SemanticRouter instance to optimize.
            test_dict (List[Dict[str, Any]]): List of test cases.
            opt_fn (Callable): Function to perform optimization. Defaults to
                grid search.
            eval_metric (str): Evaluation metric for threshold optimization.
                Defaults to "f1" score.
        Raises:
            ValueError: If the test_dict not in LabeledData format.
        """
        super().__init__(router, test_dict, opt_fn, eval_metric)

    def optimize(self, **kwargs: Any):
        """Optimize kicks off the optimization process for router"""
        qrels = _format_qrels(self.test_data)
        self.opt_fn(self.optimizable, self.test_data, qrels, self.eval_metric, **kwargs)

### Experiment

In [7]:
nuke_redis()

In [75]:
from datasets import load_dataset, concatenate_datasets

eval_datasets={
                "Topic": {
                    "DBLP":{
                        "repo_id": "waashk/dblp",
                        "train_lbl": "train",
                        "test_lbl": "test",
                        "text_col": "text",
                        "label_col": "label"
                      }
                    },
                "Sentiment": {
                    "SST-2":{
                        "repo_id": "stanfordnlp/sst2",
                        "train_lbl": "train",
                        "test_lbl": "test",
                        "text_col": "sentence",
                        "label_col": "label",
                        "alternate_scoring_source": {
                            "repo_id": "SetFit/sst2",
                            "train_lbl": "train",
                            "test_lbl": "test",
                            "text_col": "text",
                            "label_col": "label"
                        }
                    },
                    "MPQA":{
                        "repo_id": "jxm/mpqa",
                        "train_lbl": "train",
                        "test_lbl": "test",
                        "text_col": "sentence",
                        "label_col": "label"
                    }
                }
              }

router_config = {
    "name": "OOTB",
    "batch_size": 1024,
    "vectorizer_config":{
        "dimensions":1536,
        "OPENAI_API_KEY": os.getenv("OPENAI_API_KEY"),
        "OPENAI_MODEL": "text-embedding-3-small"
    },
    "dataset_domain": "Sentiment",
    "dataset_name": "SST-2",
}

# experiment_notes = ""
# dataset_repo_id = eval_datasets[router_config["dataset_domain"]][router_config["dataset_name"]]["repo_id"]
# dataset = load_dataset(dataset_repo_id)

timer=Timer()

with timer.timeit("load_dataset"):
    dataset_repo_id = eval_datasets[router_config["dataset_domain"]][router_config["dataset_name"]]["repo_id"]
    dataset_train_lbl = eval_datasets[router_config["dataset_domain"]][router_config["dataset_name"]]["train_lbl"]
    dataset_test_lbl = eval_datasets[router_config["dataset_domain"]][router_config["dataset_name"]]["test_lbl"]
    dataset_text_col = eval_datasets[router_config["dataset_domain"]][router_config["dataset_name"]]["text_col"]
    dataset_label_col = eval_datasets[router_config["dataset_domain"]][router_config["dataset_name"]]["label_col"]

    dataset = load_dataset(dataset_repo_id)

    # Merge dev with train
    all_splits = list(dataset.keys())
    extra_splits = [s for s in all_splits if s not in (dataset_train_lbl, dataset_test_lbl)]
    if extra_splits:
        train_data = concatenate_datasets([dataset[dataset_train_lbl]] + [dataset[s] for s in extra_splits])
    else:
        train_data = dataset[dataset_train_lbl][dataset_text_col]

    if "alternate_scoring_source" in eval_datasets[router_config["dataset_domain"]][router_config["dataset_name"]]:
        altdataset_repo_id = eval_datasets[router_config["dataset_domain"]][router_config["dataset_name"]]["alternate_scoring_source"]["repo_id"]
        altdataset_test_lbl = eval_datasets[router_config["dataset_domain"]][router_config["dataset_name"]]["alternate_scoring_source"]["test_lbl"]
        altdataset = load_dataset(altdataset_repo_id)
        test_text_col = eval_datasets[router_config["dataset_domain"]][router_config["dataset_name"]]["alternate_scoring_source"]["text_col"]
        test_label_col = eval_datasets[router_config["dataset_domain"]][router_config["dataset_name"]]["alternate_scoring_source"]["label_col"]
        test_data = altdataset[altdataset_test_lbl]
    else:
        test_text_col = dataset_text_col
        test_label_col = dataset_label_col
        test_data = dataset[dataset_test_lbl]

    router_config.update({"dataset_repo_id":dataset_repo_id})
    router_config.update({"train_limit":train_data.num_rows})
    router_config.update({"test_limit":test_data.num_rows})

# Define Experiment Names
safe_repo = dataset_repo_id.replace("/", "-")
experiment_name = f'{router_config["name"]}_{safe_repo}_{router_config["vectorizer_config"]["OPENAI_MODEL"]}_{router_config["vectorizer_config"]["dimensions"]}'

Repo card metadata block was not found. Setting CardData to empty.




In [61]:
# altdataset


In [64]:
print(train_data)
print(test_data)
print(test_data['label'])

Dataset({
    features: ['idx', 'sentence', 'label'],
    num_rows: 68221
})
Dataset({
    features: ['text', 'label', 'label_text'],
    num_rows: 1821
})
Column([0, 0, 0, 0, 1])


In [None]:
timer=Timer()

with timer.timeit("Build Routes"):
    routes = build_routes_from_hf_data(dataset=train_data, text_col=dataset_text_col, label_col=dataset_label_col, distance_threshold=0.9)

vectorizer=HighVisOpenAITextVectorizer(
    model=router_config["vectorizer_config"]["OPENAI_MODEL"],
    api_config={"api_key": router_config["vectorizer_config"]["OPENAI_API_KEY"]},
    cache=EmbeddingsCache(redis_url=REDIS_URL)
)

with timer.timeit("Build Router"):
    router = BigBatchSemanticRouter(
        name=router_config["name"],
        routes=routes,
        vectorizer=vectorizer,
        redis_url=REDIS_URL,
        overwrite=True
    )

20:28:29 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
20:28:29 redisvl.index.index INFO   Index already exists, overwriting.


Embedding:   0%|          | 0/30208 [00:00<?, ?emb/s]

20:28:31 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:   3%|▎         | 1024/30208 [00:03<01:42, 284.31emb/s]

20:28:35 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:   7%|▋         | 2048/30208 [00:06<01:28, 319.27emb/s]

20:28:38 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  10%|█         | 3072/30208 [00:09<01:28, 308.24emb/s]

20:28:41 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  14%|█▎        | 4096/30208 [00:13<01:23, 314.25emb/s]

20:28:44 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  17%|█▋        | 5120/30208 [00:16<01:19, 316.58emb/s]

20:28:47 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  20%|██        | 6144/30208 [00:19<01:16, 316.10emb/s]

20:28:51 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  24%|██▎       | 7168/30208 [00:23<01:15, 305.26emb/s]

20:28:54 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  27%|██▋       | 8192/30208 [00:25<01:06, 331.22emb/s]

20:28:57 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  31%|███       | 9216/30208 [00:28<01:04, 326.92emb/s]

20:29:00 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  34%|███▍      | 10240/30208 [00:31<00:57, 345.26emb/s]

20:29:03 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  37%|███▋      | 11264/30208 [00:35<00:59, 320.87emb/s]

20:29:06 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  41%|████      | 12288/30208 [00:36<00:47, 375.79emb/s]

20:29:08 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  44%|████▍     | 13312/30208 [00:39<00:45, 372.72emb/s]

20:29:11 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  47%|████▋     | 14336/30208 [00:42<00:41, 383.09emb/s]

20:29:13 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  51%|█████     | 15360/30208 [00:44<00:38, 389.90emb/s]

20:29:16 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  54%|█████▍    | 16384/30208 [00:47<00:36, 380.35emb/s]

20:29:18 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  58%|█████▊    | 17408/30208 [00:49<00:31, 409.79emb/s]

20:29:20 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  61%|██████    | 18432/30208 [00:51<00:26, 446.18emb/s]

20:29:23 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  64%|██████▍   | 19456/30208 [00:54<00:27, 396.11emb/s]

20:29:26 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  68%|██████▊   | 20480/30208 [00:57<00:26, 371.67emb/s]

20:29:29 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  71%|███████   | 21504/30208 [01:00<00:24, 358.38emb/s]

20:29:32 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  75%|███████▍  | 22528/30208 [01:03<00:21, 356.91emb/s]

20:29:34 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  78%|███████▊  | 23552/30208 [01:06<00:19, 349.08emb/s]

20:29:38 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  81%|████████▏ | 24576/30208 [01:10<00:16, 343.48emb/s]

20:29:41 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  85%|████████▍ | 25600/30208 [01:13<00:13, 339.61emb/s]

20:29:44 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  88%|████████▊ | 26624/30208 [01:15<00:09, 368.92emb/s]

20:29:46 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  92%|█████████▏| 27648/30208 [01:18<00:07, 357.47emb/s]

20:29:50 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  95%|█████████▍| 28672/30208 [01:22<00:04, 326.39emb/s]

20:29:53 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  98%|█████████▊| 29696/30208 [01:25<00:01, 310.35emb/s]

20:29:57 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


30it [01:28,  2.95s/it]███| 30208/30208 [01:28<00:00, 280.00emb/s]
Embedding: 100%|██████████| 30208/30208 [01:28<00:00, 341.32emb/s]
Embedding:   0%|          | 0/38013 [00:00<?, ?emb/s]

20:30:01 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:   3%|▎         | 1024/38013 [00:04<02:25, 253.71emb/s]

20:30:04 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:   5%|▌         | 2048/38013 [00:06<01:43, 348.13emb/s]

20:30:08 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:   8%|▊         | 3072/38013 [00:09<01:50, 315.69emb/s]

20:30:11 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  11%|█         | 4096/38013 [00:12<01:36, 353.24emb/s]

20:30:13 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  13%|█▎        | 5120/38013 [00:14<01:27, 375.32emb/s]

20:30:16 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  16%|█▌        | 6144/38013 [00:17<01:23, 383.76emb/s]

20:30:18 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  19%|█▉        | 7168/38013 [00:19<01:14, 411.88emb/s]

20:30:21 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  22%|██▏       | 8192/38013 [00:23<01:29, 333.59emb/s]

20:30:24 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  24%|██▍       | 9216/38013 [00:26<01:25, 337.70emb/s]

20:30:27 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  27%|██▋       | 10240/38013 [00:28<01:12, 382.60emb/s]

20:30:29 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  30%|██▉       | 11264/38013 [00:30<01:03, 421.42emb/s]

20:30:31 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  32%|███▏      | 12288/38013 [00:33<01:06, 385.95emb/s]

20:30:34 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  35%|███▌      | 13312/38013 [00:36<01:04, 380.05emb/s]

20:30:37 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  38%|███▊      | 14336/38013 [00:39<01:06, 353.86emb/s]

20:30:40 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  40%|████      | 15360/38013 [00:42<01:05, 347.41emb/s]

20:30:43 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  43%|████▎     | 16384/38013 [00:44<00:54, 395.47emb/s]

20:30:46 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  46%|████▌     | 17408/38013 [00:47<00:52, 393.15emb/s]

20:30:48 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  48%|████▊     | 18432/38013 [00:50<00:54, 359.03emb/s]

20:30:51 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  51%|█████     | 19456/38013 [00:53<00:52, 351.66emb/s]

20:30:54 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  54%|█████▍    | 20480/38013 [00:56<00:49, 352.85emb/s]

20:30:57 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  57%|█████▋    | 21504/38013 [00:59<00:46, 352.68emb/s]

20:31:00 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  59%|█████▉    | 22528/38013 [01:01<00:38, 400.79emb/s]

20:31:02 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  62%|██████▏   | 23552/38013 [01:04<00:40, 358.35emb/s]

20:31:05 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  65%|██████▍   | 24576/38013 [01:06<00:33, 401.31emb/s]

20:31:07 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  67%|██████▋   | 25600/38013 [01:09<00:31, 392.68emb/s]

20:31:10 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  70%|███████   | 26624/38013 [01:12<00:29, 382.58emb/s]

20:31:13 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  73%|███████▎  | 27648/38013 [01:14<00:26, 393.89emb/s]

20:31:15 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  75%|███████▌  | 28672/38013 [01:16<00:22, 420.67emb/s]

20:31:18 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  78%|███████▊  | 29696/38013 [01:20<00:22, 369.80emb/s]

20:31:21 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  81%|████████  | 30720/38013 [01:22<00:20, 363.90emb/s]

20:31:24 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  84%|████████▎ | 31744/38013 [01:25<00:15, 396.15emb/s]

20:31:26 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  86%|████████▌ | 32768/38013 [01:27<00:13, 392.20emb/s]

20:31:30 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  89%|████████▉ | 33792/38013 [01:32<00:13, 313.13emb/s]

20:31:33 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  92%|█████████▏| 34816/38013 [01:35<00:10, 311.86emb/s]

20:31:36 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  94%|█████████▍| 35840/38013 [01:38<00:06, 317.12emb/s]

20:31:39 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  97%|█████████▋| 36864/38013 [01:41<00:03, 334.35emb/s]

20:31:43 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding: 100%|█████████▉| 37888/38013 [01:45<00:00, 309.04emb/s]

20:31:46 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


38it [01:46,  2.81s/it]███| 38013/38013 [01:46<00:00, 284.88emb/s]
Embedding: 100%|██████████| 38013/38013 [01:46<00:00, 356.56emb/s]


In [None]:
''' 
Optional for now
'''

# train_data_arr=[{"query": train_data[dataset_text_col][i], "query_match":str(train_data[dataset_label_col][i])} for i in range(len(train_data))]
# optimizer = BulkRouterThresholdOptimizer(router, train_data_arr)
# optimizer.optimize()

In [77]:
with timer.timeit("Predict using test data"):
    test_statements=[str(i) for i in test_data[test_text_col]]
    preds = router.bulk_route(statements=test_statements, max_k=1)

Embedding:   0%|          | 0/1821 [00:00<?, ?emb/s]

20:51:46 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


Embedding:  56%|█████▌    | 1024/1821 [00:03<00:03, 264.95emb/s]

20:51:49 httpx INFO   HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


2it [00:05,  2.83s/it]████| 1821/1821 [00:05<00:00, 335.98emb/s]
Embedding: 100%|██████████| 1821/1821 [00:05<00:00, 321.36emb/s]


In [80]:
extracted_preds = [i[0].name for i in preds]
labels = [str(i) for i in test_data[test_label_col]]

eval_metrics = classification_metrics(
    preds=extracted_preds,
    labels=labels
)

exp_summary = print_experiment_summary(eval_metrics, router_config, timer.timings, experiment_name=experiment_name)

=== Experiment Summary: Semantic Routing Classification ===

Timestamp         : 2025_10_15_205329

Experiment Notes  : NA

DATA CONFIGURATION
Router Name       : OOTB
Batch Size        : 1024
Dataset Repo ID    : stanfordnlp/sst2
Dataset Domain    : Sentiment
Dataset Name      : SST-2
Train Limit       : 68221
Test Limit        : 1821
Vector Dimensions : 1536
Embedding Model   : text-embedding-3-small


No index info available
EVALUATION METRICS
Accuracy          : 0.9121
Precision (macro) : 0.9138
Recall (macro)    : 0.9122
F1 (macro)        : 0.9121

CONFUSION MATRIX [ [TP_0, FP_0], [FN_1, TP_1] ]
  [803, 109]
  [51, 858]

CLASSIFICATION REPORT
Class 0: Precision=0.940, Recall=0.880, F1=0.909, Support=912
Class 1: Precision=0.887, Recall=0.944, F1=0.915, Support=909

Macro Avg:
  Precision=0.914, Recall=0.912, F1=0.912

TIMINGS (seconds)
load_dataset                                 : 10.004
Predict using test data                      : 95.192
Total: 105.196


