# Base Pipeline

![](../assets/classification_pipeline.png)

## Modules

### data preparation

First, split the data into train and test sets, then launch a client backed by a vector database.

In [1]:
import json

# Load the banking77 intent records; the context manager closes the file
# handle, which a bare `open(...)` inside `json.load` would leave open.
with open('../data/records/banking77.json') as records_file:
    banking77 = json.load(records_file)
banking77[0]

{'intent_id': 0,
 'intent_name': 'activate_my_card',
 'sample_utterances': ["Please help me with my card.  It won't activate.",
  'I tired but an unable to activate my card.',
  'I want to start using my card.',
  'How do I verify my new card?',
  "I tried activating my plug-in and it didn't piece of work"],
 'regexp_for_sampling': [],
 'regexp_as_rules': []}

In [2]:
import itertools as it
from sklearn.model_selection import train_test_split


def get_sample_utterances(dataset: list[dict]):
    """
    Flatten intent records into two parallel lists.

    Every utterance stored under an intent's 'sample_utterances' key is paired
    with that intent's 'intent_id'. The order of intents and of utterances
    inside each intent is preserved.

    Return: utterances, labels
    """
    flat_utterances = []
    flat_labels = []
    for record in dataset:
        samples = record['sample_utterances']
        flat_utterances.extend(samples)
        # one copy of the intent id per utterance of this intent
        flat_labels.extend([record['intent_id']] * len(samples))

    return flat_utterances, flat_labels


def split_sample_utterances(dataset: list[dict]):
    """
    Split all sample utterances of `dataset` into a train and a test part.

    The utterances are shuffled with a fixed seed, 25% of them are held out
    for testing, and the split is stratified by intent label.

    Return: utterances_train, utterances_test, labels_train, labels_test

    TODO: ensure stratified train test splitting (test set must contain all classes)
    """
    texts, intent_ids = get_sample_utterances(dataset)

    split_options = dict(
        test_size=0.25,
        random_state=0,
        stratify=intent_ids,
        shuffle=True,
    )
    return train_test_split(texts, intent_ids, **split_options)

In [3]:
utterances_train, utterances_test, labels_train, labels_test = split_sample_utterances(banking77)
len(utterances_train), len(utterances_test)

(288, 97)

In [4]:
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction as EmbFunc
from chromadb import ClientAPI


def create_collection(
    utterances: list[str],
    labels: list[int],
    client: ClientAPI,
    name="example_collection",
    embedder_name="Alibaba-NLP/gte-base-en-v1.5",
):
    """
    Get or create the collection `name` in `client` and add the labelled utterances to it.

    Arguments
    ---
    - `utterances`: documents to store
    - `labels`: intent id of each utterance; the distinct values must be exactly 0..n_classes-1
    - `client`: client that owns the collection
    - `name`: collection name
    - `embedder_name`: sentence-transformers model used as the collection's embedding function

    Return
    ---
    the collection; its metadata holds 'n_classes', every item holds 'intent_id'
    and the item ids are the stringified positions of the utterances
    """
    distinct_labels = set(labels)
    n_classes = len(distinct_labels)
    assert distinct_labels == set(range(n_classes)), "labels must be from [0,n_classes-1]"

    embedder = EmbFunc(model_name=embedder_name, trust_remote_code=True)
    collection = client.get_or_create_collection(
        name=name,
        embedding_function=embedder,
        metadata={'n_classes': n_classes},
    )

    doc_ids = list(map(str, range(len(utterances))))
    intent_metadata = [{'intent_id': label} for label in labels]
    collection.add(documents=utterances, ids=doc_ids, metadatas=intent_metadata)

    return collection

In [5]:
from chromadb import PersistentClient


# Open the on-disk Chroma database stored under ../data/chroma.
client = PersistentClient(path='../data/chroma')
# Drop any existing "example_collection" before it is re-created below.
# NOTE(review): presumably this raises when the collection does not exist yet (first run) — confirm.
client.delete_collection("example_collection")

# Index the training utterances with the default name and embedder of `create_collection`.
collection = create_collection(
    utterances_train,
    labels_train,
    client
)
# Number of stored items; shown as the cell output.
collection.count()

  from tqdm.autonotebook import tqdm, trange


288

In [None]:
import os
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction as EmbFunc


class DataHandler:
    """
    Holds the train/test split of the intent data and the Chroma client used to index it.

    `__init__` sets `utterances_train`, `utterances_test`, `labels_train`,
    `labels_test` (the result of `split_sample_utterances`) and `client`.
    `create_collection` additionally sets `collection`.
    """

    def __init__(self, intent_records: os.PathLike, db_path: os.PathLike = '../data/chroma'):
        """
        Arguments
        ---
        - `intent_records`: path to a JSON file with intent records, or the already loaded list of records
        - `db_path`: directory of the persistent Chroma database
        """
        # `split_sample_utterances` needs the records themselves, so load them when a path is given
        if isinstance(intent_records, (str, os.PathLike)):
            with open(intent_records) as records_file:
                intent_records = json.load(records_file)

        (
            self.utterances_train,
            self.utterances_test,
            self.labels_train,
            self.labels_test,
        ) = split_sample_utterances(intent_records)

        self.client = PersistentClient(path=db_path)

    def create_collection(self, model_name: str, db_name: str = "example_collection"):
        """
        Get or create the collection `db_name` and add the training utterances to it.

        Arguments
        ---
        - `model_name`: sentence-transformers model used as the embedding function
        - `db_name`: collection name

        Return
        ---
        the collection, which is also stored as `self.collection`
        """
        # use the handler's own client (not a notebook-level global) so that `db_path` is respected
        collection = self.client.get_or_create_collection(
            name=db_name,
            embedding_function=EmbFunc(model_name=model_name, trust_remote_code=True),
            metadata={'n_classes': len(set(self.labels_train))}
        )
        collection.add(
            documents=self.utterances_train,
            ids=[str(i) for i in range(len(self.utterances_train))],
            metadatas=[{'intent_id': lab} for lab in self.labels_train]
        )
        self.collection = collection
        return collection
        

In [None]:
from typing import Callable, Literal


class Module:
    """Common interface of pipeline modules: fit on a data handler, then score with a metric."""

    def fit(self, data_handler: DataHandler):
        """Placeholder that does nothing; subclasses override it."""
        return None

    def score(self, data_handler: DataHandler, metric_fn: Callable) -> float:
        """Placeholder that returns None; subclasses override it."""
        return None

    def fit_score(self, data_handler: DataHandler, metric_fn: Callable) -> float:
        """Call `fit`, then return the result of `score` for the same arguments."""
        self.fit(data_handler)
        result = self.score(data_handler, metric_fn)
        return result

### RegExp

In [6]:
import json

# Load the dream intent records; the context manager closes the file
# handle, which a bare `open(...)` inside `json.load` would leave open.
with open('../data/records/dream.json') as records_file:
    dream = json.load(records_file)
dream[0]

{'intent_id': 0,
 'intent_name': 'what_are_you_talking_about',
 'sample_utterances': [],
 'regexp_for_sampling': ['(alexa ){0,1}what are ((you)|(we)) ((talking about)|(discussing))',
  '(alexa ){0,1}what ((you)|(we)) are (even ){0,1}((talking about)|(discussing))',
  '(alexa ){0,1}what does it mean',
  '(alexa ){0,1}pass that by me again',
  "(alexa ){0,1}i ((don't)|(didn't)|(do not)|(did not)) get it",
  '(alexa ){0,1}what it is about',
  '(alexa ){0,1}what is it about',
  'i lost common ground',
  '(alexa ){0,1}what (even ){0,1}is that',
  "(i ((did not get)|(don't understand)|(don't get)) ){0,1}what do you mean( alexa){0,1}",
  "(sorry, ){0,1}i ((don't)|(do not)|(didn't)|(did not)) ((understand)|(get))( ((what you mean)|(what are you talking about)))( alexa){0,1}",
  '((what you mean)|(what are you talking about))( alexa){0,1}',
  "i don't know what you just said"],
 'regexp_as_rules': ['(alexa ){0,1}are we having a communication problem',
  "(alexa ){0,1}i don't think you understan

In [7]:
import re

def regexp(utterance: str, intents_patterns: list[dict]):
    """
    Return the set of intent ids that have at least one pattern matching `utterance`.

    For every intent both its 'regexp_for_sampling' and 'regexp_as_rules'
    patterns are tried with `re.match`, i.e. a pattern only has to match at
    the beginning of the utterance.
    """
    return {
        intent['intent_id']
        for intent in intents_patterns
        for pattern in intent['regexp_for_sampling'] + intent['regexp_as_rules']
        if re.match(pattern, utterance) is not None
    }

In [8]:
regexp(
    utterance='what are you talking about',
    intents_patterns=dream
)

{0, 5}

In [9]:
regexp(
    utterance='tell me something else',
    intents_patterns=dream
)

{1, 6}

In [10]:
regexp(
    utterance='kind of',
    intents_patterns=dream
)

{6}

### Retrieval

In [11]:
from chromadb import Collection


def retrieval(utterance: str, collection: Collection, k: int):
    """
    Query `collection` with a single utterance, asking for `k` results.

    Only metadatas and documents are requested; the raw query result is returned.
    """
    requested_fields = ["metadatas", "documents"]  # one can add "embeddings", "distances"
    return collection.query(
        query_texts=[utterance],
        n_results=k,
        include=requested_fields,
    )

In [12]:
utterance = 'i want a new card'
query_res = retrieval(utterance, collection, k=3)

In [13]:
query_res

{'ids': [['40', '23', '240']],
 'distances': None,
 'metadatas': [[{'intent_id': 39}, {'intent_id': 39}, {'intent_id': 43}]],
 'embeddings': None,
 'documents': [['I want some extra physical cards.',
   "I'd like to order an additional card",
   'Can I request a card?']],
 'uris': None,
 'data': None,
 'included': ['metadatas', 'documents']}

In [None]:
from typing import Callable


class VectorDBModule(Module):
    """Retrieval module: indexes the train utterances and retrieves `k` candidates per test utterance."""

    def __init__(self, model_name: str, k: int):
        """
        Arguments
        ---
        - `model_name`: embedding model passed to `DataHandler.create_collection`
        - `k`: number of candidates requested for each query
        """
        self.model_name = model_name
        self.k = k

    def fit(self, data_handler: DataHandler):
        """Create the collection of `data_handler` with this module's embedding model."""
        data_handler.create_collection(self.model_name)

    def score(self, data_handler: DataHandler, metric_fn: Callable):
        """
        Return `metric_fn(labels_test, labels_pred)`, where `labels_pred` holds,
        for each test utterance, the intent ids of its retrieved candidates.
        """
        query_res = data_handler.collection.query(
            query_texts=data_handler.utterances_test,
            n_results=self.k,
            include=["metadatas", "documents"]  # one can add "embeddings", "distances"
        )
        labels_pred = [[cand['intent_id'] for cand in candidates] for candidates in query_res['metadatas']]
        # take the ground truth from the handler that produced the queries, not from a notebook global
        return metric_fn(data_handler.labels_test, labels_pred)

### Scoring

modules:
- knn
- linear
- dnnc

In [None]:
# from typing import Literal


# class ScorerOptimizer:
#     available_modules = {
#         'knn': KNNScorer,
#         'linear': LinearScorer,
#         'dnnc': DNNCScorer,
#     }
#     def __init__(self, module_type: Literal['knn', 'linear', 'dnnc'], **hyperparams):
#         self.module = self.available_modules[module_type](**hyperparams)
    
#     def fit()

In [14]:
import numpy as np


class ScoringModule(Module):
    """Base class of scorers: `predict` maps a batch of utterances to classwise scores."""

    def score(self, data_handler: DataHandler, metric_fn: Callable):
        """Return `metric_fn(labels_test, scores)` for the scores predicted on the test utterances."""
        probas = self.predict(data_handler.utterances_test)
        return metric_fn(data_handler.labels_test, probas)

    def predict(self, utterances: list[str]):
        """Return classwise scores for every utterance; subclasses must implement it."""
        raise NotImplementedError()

    def predict_topk(self, utterance: str, k=3):
        """
        Return the indices of the `k` best-scoring classes for one utterance, best first.

        `k` is clamped to the number of classes; a non-positive `k` gives an empty array.
        """
        # `predict` works on a batch: wrap the single utterance and take its row of scores
        scores = np.asarray(self.predict([utterance]))[0]
        k = min(k, len(scores))
        if k <= 0:
            return np.array([], dtype=int)
        # argpartition finds the k largest scores without sorting everything;
        # only those k are then put in descending order
        top_indices = np.argpartition(scores, kth=-k)[-k:]
        top_scores = scores[top_indices]
        return top_indices[np.argsort(top_scores)][::-1]

#### knn

In [15]:
import numpy as np

class KNNScorer(ScoringModule):
    """
    Scores every class by the share of the retrieved items that carry its label.

    TODO:
    - add weighted knn?
    """
    def __init__(self, k):
        self.k = k

    def fit(self, data_handler: DataHandler):
        """Remember the handler's collection and the class count stored in its metadata."""
        self._collection = data_handler.collection
        self._n_classes = data_handler.collection.metadata['n_classes']

    def predict(self, utterances: list[str]):
        """
        Return an (n_queries, n_classes) array of label frequencies among the retrieved items.

        TODO: test this code
        """
        neighbours = self._collection.query(
            query_texts=utterances,
            n_results=self.k,
            include=["metadatas", "documents"]  # one can add "embeddings", "distances"
        )
        neighbour_labels = np.array(
            [[item['intent_id'] for item in found] for found in neighbours['metadatas']]
        )

        n_queries = len(utterances)
        n_classes = self._collection.metadata['n_classes']

        # shift the labels of query i by i * n_classes so that a single flat
        # bincount counts every (query, class) pair separately
        row_offsets = n_classes * np.arange(n_queries)[:, None]
        np.add(neighbour_labels, row_offsets, out=neighbour_labels)
        flat_counts = np.bincount(neighbour_labels.ravel(), minlength=n_classes*n_queries)
        counts = flat_counts.reshape(n_queries, n_classes)

        return counts / counts.sum(axis=1, keepdims=True)

In [16]:
knn_scorer = KNNScorer(k=10)
knn_scorer.fit(collection)
knn_scorer.predict_proba('i want a new card')

array([0.1, 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.1, 0. , 0. , 0.1,
       0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.1, 0. , 0. ,
       0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
       0.2, 0.2, 0.1, 0. , 0.1, 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
       0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,
       0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ])

In [17]:
knn_scorer.predict_topk('i want a new card')

array([40, 39, 12])

#### linear

In [18]:
from sklearn.linear_model import LogisticRegressionCV


class LinearScorer(ScoringModule):
    """
    Logistic regression trained on the embeddings stored in the collection.

    TODO:
    - implement different modes (incremental learning with SGD and simple learning with LogisticRegression)
    - control n_jobs
    - adjust cv
    - ensure that embeddings of train set are not recalculated
    """

    def fit(self, data_handler: DataHandler):
        """Fit a 3-fold cross-validated logistic regression on the stored embeddings and intent ids."""
        stored = data_handler.collection.get(include=['embeddings', 'metadatas'])
        targets = [meta['intent_id'] for meta in stored['metadatas']]

        classifier = LogisticRegressionCV(cv=3, n_jobs=8, multi_class='multinomial')
        classifier.fit(stored['embeddings'], targets)

        self._clf = classifier
        # queries are embedded later with the collection's own embedding function
        self._emb_func = data_handler.collection._embedding_function

    def predict(self, utterances: list[str]):
        """Embed `utterances` and return the classifier's class probabilities for them."""
        return self._clf.predict_proba(self._emb_func(utterances))

In [19]:
linear_scorer = LinearScorer()
linear_scorer.fit(collection)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

In [20]:
linear_scorer.predict_proba('i want a new card')

array([5.36315606e-02, 2.06736865e-03, 1.07819996e-03, 2.84117638e-03,
       1.06543725e-03, 4.14270486e-04, 4.56166199e-04, 9.31523609e-04,
       8.16927696e-04, 1.04347418e-01, 4.14846962e-03, 1.07265188e-02,
       2.87261476e-02, 2.22940571e-02, 4.22492414e-02, 2.29241231e-03,
       3.62684145e-03, 1.13575060e-03, 3.57048793e-03, 5.06323550e-04,
       1.13312843e-03, 9.95915608e-03, 7.87386713e-04, 1.55948122e-02,
       1.01617718e-02, 1.36223739e-02, 7.43484887e-04, 1.21027345e-03,
       1.65229683e-03, 1.14075433e-02, 2.32802378e-02, 1.24439988e-03,
       8.63093249e-04, 3.74815527e-03, 1.91142796e-03, 5.09926445e-04,
       8.36122168e-04, 4.88069047e-03, 2.72668840e-03, 2.16004925e-01,
       5.72615173e-02, 1.98442906e-01, 7.95156098e-04, 5.20176928e-02,
       1.22214771e-03, 1.21927601e-03, 6.29880435e-04, 6.33465044e-03,
       3.60845091e-04, 1.24068092e-03, 8.97284111e-04, 1.52737685e-03,
       3.98661705e-03, 2.60228011e-03, 3.17184426e-03, 1.78225807e-03,
      

In [21]:
linear_scorer.predict_topk('i want a new card')

array([39, 41,  9])

#### DNNC

In [22]:
from sentence_transformers import CrossEncoder

model_name = "BAAI/bge-reranker-base"
cross_encoder = CrossEncoder(model_name, trust_remote_code=True)
cross_encoder.predict([['i want a new card', 'new card please'], ['i want a new card', 'new card please']])

array([0.99926525, 0.99926525], dtype=float32)

In [23]:
from chromadb import Collection


class DNNCScorer(ScoringModule):
    """
    Scores a query by re-ranking its retrieved candidates with a cross-encoder.

    TODO:
    - think about other cross-encoder settings
    - implement training of cross-encoder with sentence_encoders utils
    - control device of model
    - inspect batch size of model.predict?
    """
    def __init__(self, model_name: str, k: int):
        """
        Arguments
        ---
        - `model_name`: name of the cross-encoder model
        - `k`: number of candidates retrieved for each query
        """
        self.model = CrossEncoder(model_name, trust_remote_code=True)
        self.k = k

    def fit(self, data_handler: DataHandler):
        """Remember the collection of `data_handler`; nothing is trained."""
        self._collection = data_handler.collection

    def predict(self, utterances: list[str]):
        """
        returns just a smooth indicator of the chosen class

        For every query the result row is zero everywhere except at the class
        of the candidate with the highest cross-encoder score, where it holds
        that score. The result has shape (n_queries, n_classes).

        TODO: test this code
        """
        res = np.zeros((len(utterances), self._collection.metadata['n_classes']))
        if not utterances:
            return res

        query_res = self._collection.query(
            query_texts=utterances,
            n_results=self.k,
            include=["metadatas", "documents"]  # one can add "embeddings", "distances"
        )

        # all (query, candidate) pairs in one flat batch, query by query
        flattened_text_pairs = [
            [query, cand]
            for query, q_res in zip(utterances, query_res['documents'])
            for cand in q_res
        ]
        flattened_cross_encoder_scores = np.asarray(self.model.predict(flattened_text_pairs))
        # back to one row per query; sized by the number of queries, not by
        # self.k, so the rows stay aligned if fewer than k candidates come back
        cross_encoder_scores = flattened_cross_encoder_scores.reshape(len(utterances), -1)

        labels_pred = np.array(
            [[cand['intent_id'] for cand in candidates] for candidates in query_res['metadatas']]
        )
        rows = np.arange(len(utterances))
        i_best = np.argmax(cross_encoder_scores, axis=1)
        # row i gets the best score of query i at the class of its best candidate
        res[rows, labels_pred[rows, i_best]] = cross_encoder_scores[rows, i_best]

        return res

In [24]:
dnnc_scorer = DNNCScorer(model_name="BAAI/bge-reranker-base", k=10)
dnnc_scorer.fit(collection)

In [25]:
dnnc_scorer.predict_proba('i want a card')

array([0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.99974364, 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.     

In [26]:
dnnc_scorer.predict_topk('i want a card', k=1)

array([43])

### Prediction

In [27]:
def make_decision(scores: list[float], thresholds: list[float] | float):
    """
    Arguments: classwise `scores` and classwise `thresholds` (or a single one)
    """
    i_best = np.argmax(scores)
    score_best = scores[i_best]

    thresh = thresholds[i_best] if isinstance(thresholds, list) else thresholds

    if score_best > thresh:
        return i_best

    return None # i.e. out of scope

In [None]:
class PredictionModule(Module):
    """Turns classwise scores into a predicted class or `None` (out of scope) using thresholds."""

    def __init__(self, single_thresh: bool):
        """
        Arguments
        ---
        - `single_thresh`: use one threshold for all classes instead of one per class
        """
        self.single_thresh = single_thresh
        self.signle_thresh = single_thresh  # misspelled name kept for backward compatibility

    def fit(self, data_handler: DataHandler):
        """Set the threshold(s) to 0.5: a scalar, or an array with one entry per class."""
        n_classes = data_handler.collection.metadata['n_classes']
        self.thresh = 0.5 if self.single_thresh else np.ones(n_classes) / 2

        # TODO: optimization

    def score(self, data_handler: DataHandler, metric_fn: Callable):
        """Return `metric_fn(labels_test, predictions)` for the scores held by `data_handler`."""
        # NOTE(review): `DataHandler` as defined in this file never sets `scores` — confirm where it comes from
        predictions = self.predict(data_handler.scores)
        return metric_fn(data_handler.labels_test, predictions)

    def predict(self, scores: list[list[float]]):
        """
        Arguments
        ---
        - `scores`: for each query, its classwise scores

        Return
        ---
        list with one entry per query: the index of the best-scoring class, or
        `None` if that score is below the threshold (i.e. out of scope)
        """
        scores = np.asarray(scores)
        if scores.size == 0:
            return []

        pred_classes = np.argmax(scores, axis=1)
        # score of the winning class of each row (row-wise pick, not row selection)
        best_scores = scores[np.arange(len(scores)), pred_classes]
        thresh = self.thresh if self.single_thresh else self.thresh[pred_classes]
        rejected = np.broadcast_to(best_scores < thresh, best_scores.shape)

        # an integer array cannot hold None, hence a plain list
        return [None if out else int(cls) for cls, out in zip(pred_classes, rejected)]

## Metrics

### Retrieval

In [None]:
# def score(query_labels: list[int], candidates_labels: list[list[int]]):
#     """
#     Arguments
#     ---
#     - `query_labels`: for each query, this list contains its class labels
#     - `candidates_labels`: for each query, these lists contain class labels of items ranked by a retrieval model (from most to least relevant)
#     - `k`: the number of top items to consider for each query

#     Return
#     ---
#     retrieval metric, averaged over all queries
    

#     TODO:
#     - implement multilabel case, where query_labels: list[list[int]], i.e. each query has multiple intents
#     """
#     raise NotImplementedError()

In [56]:
import numpy as np


def average_precision(query_label: int, candidate_labels: list[int], k: int = None) -> float:
    """
    Average precision of one ranked candidate list.

    Arguments
    ---
    - `query_label`: class label of the query
    - `candidate_labels`: class labels of the retrieved items (from most to least relevant)
    - `k`: the number of top items to consider (all of them if `None`)

    Return
    ---
    mean of precision@rank over the ranks holding a relevant item, 0.0 if there is none
    """
    precisions_at_hits = []
    for rank, label in enumerate(candidate_labels[:k], start=1):
        if label != query_label:
            continue
        # this hit is relevant item number len(...) + 1 among the first `rank` items
        precisions_at_hits.append((len(precisions_at_hits) + 1) / rank)

    if not precisions_at_hits:
        return 0.0
    return sum(precisions_at_hits) / len(precisions_at_hits)

def retrieval_map(query_labels: list[int], candidates_labels: list[list[int]], k: int = None):
    """Mean of `average_precision` over all (query label, ranked candidate labels) pairs."""
    total = 0
    n_queries = 0
    for query_label, ranked_labels in zip(query_labels, candidates_labels):
        total += average_precision(query_label, ranked_labels, k)
        n_queries += 1
    return total / n_queries

def retrieval_map_numpy(query_labels: list[int], candidates_labels: list[list[int]], k: int = None) -> float:
    """
    Vectorized mean average precision.

    Arguments
    ---
    - `query_labels`: for each query, its class label
    - `candidates_labels`: for each query, class labels of the retrieved items (from most to least relevant); all lists have the same length
    - `k`: the number of top items to consider (all of them if `None` or larger than the number of candidates)

    Return
    ---
    average precision, averaged over all queries
    """
    query_labels = np.asarray(query_labels)
    candidates_labels = np.asarray(candidates_labels)[:, :k]
    relevance_mask = candidates_labels == query_labels[:, None]
    cumulative_relevant = np.cumsum(relevance_mask, axis=1)
    # ranks follow the number of candidates actually kept, which can be smaller than `k`
    ranks = np.arange(1, relevance_mask.shape[1] + 1)
    precision_at_hits = cumulative_relevant * relevance_mask / ranks
    sum_precision = np.sum(precision_at_hits, axis=1)
    num_relevant = np.sum(relevance_mask, axis=1)
    # queries without any relevant candidate keep an average precision of 0
    ap_per_query = np.divide(
        sum_precision,
        num_relevant,
        out=np.zeros_like(sum_precision),
        where=num_relevant != 0,
    )
    return np.mean(ap_per_query)

y_true = 1
y_pred = [2,1,1]
retrieval_map([y_true], [y_pred], k=3), retrieval_map_numpy([y_true], [y_pred], k=3)

(0.5833333333333333, 0.5833333333333333)

In [59]:
def retrieval_hit_rate(query_labels: list[int], candidates_labels: list[list[int]], k: int = None) -> float:
    """
    Arguments
    ---
    - `query_labels`: for each query, its class label
    - `candidates_labels`: for each query, class labels of the retrieved items (from most to least relevant)
    - `k`: the number of top items to consider for each query (all of them if `None`)

    Return
    ---
    share of queries whose label occurs among their top `k` candidates
    """
    # `k` is optional so the function can be called as metric_fn(labels, candidates)
    hit_count = sum(
        query_label in ranked_labels[:k]
        for query_label, ranked_labels in zip(query_labels, candidates_labels)
    )
    return hit_count / len(query_labels)

def retrieval_hit_rate_numpy(query_labels: list[int], candidates_labels: list[list[int]], k: int = None) -> float:
    """
    Vectorized hit rate.

    Arguments
    ---
    - `query_labels`: for each query, its class label
    - `candidates_labels`: for each query, class labels of the retrieved items (from most to least relevant); all lists have the same length
    - `k`: the number of top items to consider for each query (all of them if `None`)

    Return
    ---
    share of queries whose label occurs among their own top `k` candidates
    """
    query_labels = np.asarray(query_labels)
    truncated_candidates = np.asarray(candidates_labels)[:, :k]
    # compare each query with its OWN row only; `np.isin` against the whole
    # matrix would count a label found among another query's candidates
    hit_mask = (truncated_candidates == query_labels[:, None]).any(axis=1)
    return hit_mask.mean()

query_labels = [1]
candidates_labels = [[1, 4, 5, 2]]
k = 2

retrieval_hit_rate(query_labels, candidates_labels, k), retrieval_hit_rate_numpy(query_labels, candidates_labels, k)

(1.0, 1.0)

In [63]:
def retrieval_precision(query_labels: list[int], candidates_labels: list[list[int]], k: int = None) -> float:
    """
    Arguments
    ---
    - `query_labels`: for each query, its class label
    - `candidates_labels`: for each query, class labels of the retrieved items (from most to least relevant)
    - `k`: the number of top items to consider for each query; if `None`, all candidates of a query are considered and its precision is taken over their number

    Return
    ---
    precision@k, averaged over all queries
    """
    total_precision = 0.0

    for query_label, ranked_labels in zip(query_labels, candidates_labels):
        top_labels = ranked_labels[:k]
        if k is None and len(top_labels) == 0:
            # no candidates and no explicit cut-off: this query contributes 0
            continue

        cutoff = len(top_labels) if k is None else k
        n_relevant = sum(label == query_label for label in top_labels)
        total_precision += n_relevant / cutoff

    return total_precision / len(query_labels)


def retrieval_precision_numpy(query_labels: list[int], candidates_labels: list[list[int]], k: int) -> float:
    """
    Vectorized precision@k.

    Arguments
    ---
    - `query_labels`: for each query, its class label
    - `candidates_labels`: for each query, class labels of the retrieved items (from most to least relevant)
    - `k`: the number of top items to consider for each query

    Return
    ---
    number of relevant items among the top `k` divided by `k`, averaged over all queries
    """
    labels = np.array(query_labels)
    ranked = np.array(candidates_labels)
    is_relevant = ranked[:, :k] == labels[:, None]
    hits_per_query = np.count_nonzero(is_relevant, axis=1)
    return np.mean(hits_per_query / k)


query_labels = [1]
candidates_labels = [[1, 1, 3, 4, 5]]
k = 3

retrieval_precision(query_labels, candidates_labels, k), retrieval_precision_numpy(query_labels, candidates_labels, k)

(0.6666666666666666, 0.6666666666666666)

In [91]:
def dcg(relevance_scores, k=None):
    """
    Calculate the Discounted Cumulative Gain (DCG) at position k.

    Arguments
    ---
    - `relevance_scores`: relevance scores of the items in ranked order (numpy array or list; booleans are allowed)
    - `k`: the number of top items to consider (all of them if `None`); fewer than `k` scores are allowed

    Return
    ---
    DCG value at position k
    """
    top_scores = np.asarray(relevance_scores, dtype=float)[:k]
    # one discount per score actually present, which can be fewer than `k`
    discounts = np.log2(np.arange(2, len(top_scores) + 2))
    return np.sum((2 ** top_scores - 1) / discounts)

def idcg(relevance_scores, k):
    """
    Calculate the Ideal Discounted Cumulative Gain (IDCG) at position k.

    The ideal ranking is the given scores sorted from highest to lowest.

    Arguments
    ---
    - `relevance_scores`: numpy array of relevance scores for items
    - `k`: the number of top items to consider

    Return
    ---
    IDCG value at position k, i.e. `dcg` of the ideally ordered scores
    """
    best_first = np.flip(np.sort(relevance_scores), axis=0)
    return dcg(best_first, k)


def retrieval_ndcg(query_labels, candidates_labels, k=None):
    """
    Arguments
    ---
    - `query_labels`: for each query, its class label
    - `candidates_labels`: for each query, class labels of the retrieved items (from most to least relevant); all lists have the same length
    - `k`: the number of top items to consider for each query (all of them if `None` or larger than the number of candidates)

    Return
    ---
    NDCG@k with binary relevance (candidate label equals query label), averaged
    over all queries; a query without relevant candidates contributes 0.0
    """
    relevance_scores = np.array(query_labels)[:, None] == np.array(candidates_labels)

    # resolve `k` here so that `dcg`/`idcg` never get asked for more positions than exist
    n_candidates = relevance_scores.shape[1]
    k = n_candidates if k is None else min(k, n_candidates)

    ndcg_scores = []
    for rel_scores in relevance_scores:
        cur_dcg = dcg(rel_scores, k)
        cur_idcg = idcg(rel_scores, k)
        ndcg_scores.append(0.0 if cur_idcg == 0 else cur_dcg / cur_idcg)

    return sum(ndcg_scores) / len(ndcg_scores)

query_labels = [1]
candidates_labels = [[1, 2, 1, 2, 5]]
k = 3

retrieval_ndcg(query_labels, candidates_labels, k=k)

0.9197207891481876

In [None]:
def retrieval_mrr(query_labels: list[int], candidates_labels: list[list[int]]) -> float:
    """
    Arguments
    ---
    - `query_labels`: for each query, its class label
    - `candidates_labels`: for each query, class labels of the retrieved items (from most to least relevant)

    Return
    ---
    reciprocal rank of the first candidate sharing the query label (0 if there
    is none), averaged over all queries
    """
    reciprocal_ranks = []
    for i, query_label in enumerate(query_labels):
        first_hit_rank = next(
            (rank for rank, label in enumerate(candidates_labels[i], start=1) if label == query_label),
            None,
        )
        reciprocal_ranks.append(0.0 if first_hit_rank is None else 1.0 / first_hit_rank)

    return sum(reciprocal_ranks) / len(query_labels)

## Optimization

### Node Abstraction

In [None]:
class RetrievalNode:
    """Chooses among embedding models by a retrieval metric computed on a held-out split."""

    metrics_available = {
        'retrieval_map': retrieval_map,
        'retrieval_ndcg': retrieval_ndcg,
        'retrieval_hit_rate': retrieval_hit_rate,
        'retrieval_precision': retrieval_precision,
        'retrieval_mrr': retrieval_mrr
    }

    def __init__(
        self,
        embedding_model_names: list[str],
        k: int,
        metric: Literal['retrieval_map', 'retrieval_ndcg', 'retrieval_hit_rate', 'retrieval_precision', 'retrieval_mrr']
    ):
        """
        Arguments
        ---
        - `embedding_model_names`: embedding models to compare
        - `k`: number of candidates retrieved for each test utterance
        - `metric`: key of `metrics_available`
        """
        self.embedding_model_names = embedding_model_names
        self.client = PersistentClient(path='../data/chroma')
        self.k = k
        self.metric_name = metric
        self.metric_fn = self.metrics_available[metric]

    def fit(self, dataset: list[dict]):
        """
        `dataset`: intent records

        Scores every embedding model on one train/test split and stores the
        outcome in `self.optimization_results`.

        TODO: add splits statistics to optimization results (train size, test size, how many instances of each class in each split)
        """
        splits = split_sample_utterances(dataset)
        # one collection per embedder: with a shared name `get_or_create_collection`
        # would hand every later embedder the collection filled by the first one
        metric_scores = [
            self._score_embedder(emb_name, *splits, collection_name=f"retrieval_node_{i}")
            for i, emb_name in enumerate(self.embedding_model_names)
        ]
        self.optimization_results = {
            'metric_name': self.metric_name,
            'i_best': np.argmax(metric_scores),
            'scores': [{'model': model, 'score': score} for model, score in zip(self.embedding_model_names, metric_scores)]
        }

    def _score_embedder(
        self,
        embedder_name,
        utterances_train,
        utterances_test,
        labels_train,
        labels_test,
        collection_name="example_collection",
    ):
        """Index the train split with `embedder_name` and return the metric for the test split."""
        # NOTE(review): a collection of this name left over from an earlier run is reused as is — confirm whether it should be deleted first
        collection = create_collection(
            utterances=utterances_train,
            labels=labels_train,
            client=self.client,
            name=collection_name,
            embedder_name=embedder_name
        )
        # query with the whole test batch directly: the `retrieval` helper wraps
        # its argument into a one-element list, i.e. it is for a single utterance
        query_res = collection.query(
            query_texts=utterances_test,
            n_results=self.k,
            include=["metadatas", "documents"]
        )
        labels_pred = [[cand['intent_id'] for cand in candidates] for candidates in query_res['metadatas']]
        return self.metric_fn(labels_test, labels_pred)

In [None]:
class Node:
    """
    Grid search over module configurations for one pipeline step.

    Subclasses fill `metrics_available` (metric name -> function) and
    `modules_available` (module type -> class).
    """
    metrics_available = {}
    modules_available = {}

    def __init__(
        self,
        modules_search_spaces: list[dict],
        metric: str
    ):
        """
        `modules_search_spaces`: list of records, where each record is a mapping: hyperparam_name -> list of values (search space) with extra field "module_type" with values from ["knn", "linear", "dnnc"]
        """
        self.modules_search_spaces = modules_search_spaces
        self.metric_name = metric

    def fit(self, data_handler: DataHandler):
        """
        Fit and score one module per point of every search space.

        The outcome is stored in `self.optimization_results`; each entry of
        'configs' is a dict with "module_type" and the hyperparameters used.
        """
        metric_fn = self.metrics_available[self.metric_name]
        metric_scores = []
        modules_configs = []
        for search_space in self.modules_search_spaces:
            # work on a copy: popping from the caller's dict would drop
            # "module_type" for good and break a second call of `fit`
            hyperparams = dict(search_space)
            module_type = hyperparams.pop('module_type')
            names = list(hyperparams)
            for values in it.product(*hyperparams.values()):
                # `product` yields bare value tuples; pair them with their names
                # so they can be passed as keyword arguments
                module_config = dict(zip(names, values))
                modules_configs.append({'module_type': module_type, **module_config})
                module = self.modules_available[module_type](**module_config)
                metric_scores.append(module.fit_score(data_handler, metric_fn))

        self.optimization_results = {
            'metric_name': self.metric_name,
            'i_best': np.argmax(metric_scores),
            'scores': metric_scores,
            'configs': modules_configs
        }

In [None]:
class ScorerNode(Node):
    """Node whose modules are the scorers (knn, linear, dnnc)."""

    # no metric function is wired in yet: every name maps to None
    metrics_available = dict.fromkeys(['neg_cross_entropy', 'roc_auc_ovr', 'roc_auc_ovo'])

    modules_available = dict(
        knn=KNNScorer,
        linear=LinearScorer,
        dnnc=DNNCScorer,
    )

In [None]:
class RetrievalNode(Node):
    """Node whose only module is the vector-db retrieval, scored by the retrieval metrics."""

    metrics_available = dict(
        retrieval_map=retrieval_map,
        retrieval_ndcg=retrieval_ndcg,
        retrieval_hit_rate=retrieval_hit_rate,
        retrieval_precision=retrieval_precision,
        retrieval_mrr=retrieval_mrr,
    )

    modules_available = dict(vector_db=VectorDBModule)