In [2]:
from typing import Any, List, Optional

from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.callbacks import CBEventType, EventPayload
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.schema import MetadataMode, NodeWithScore, QueryBundle

from catboost import CatBoostRanker, CatBoostClassifier, Pool
import pandas as pd
import numpy as np
import joblib

In [32]:
data = pd.DataFrame({
    'query': ['hello', 'how', 'are', 'you'],
    'feature_1': [1., 2., 4., 6.],
    'feature_2': [34, 45, 56, 67],
})

target = np.array([1, 0, 1, 0])
data = data.values

In [33]:
model = CatBoostClassifier(random_seed=42, iterations=5)
train_pool = Pool(data, label=target, cat_features=[0])
model.fit(train_pool)

Learning rate set to 0.12562
0:	learn: 0.6883390	total: 8.98ms	remaining: 35.9ms
1:	learn: 0.6788150	total: 22.7ms	remaining: 34.1ms
2:	learn: 0.6745137	total: 31.1ms	remaining: 20.7ms
3:	learn: 0.6703554	total: 39ms	remaining: 9.75ms
4:	learn: 0.6658914	total: 49.4ms	remaining: 0us


<catboost.core.CatBoostClassifier at 0x1f2e9174cd0>

In [34]:
joblib.dump(model, "catboost_model.pkl")

['catboost_model.pkl']

In [133]:
import joblib
from typing import Any, List, Optional
from pydantic import BaseModel, Field
from llama_index.core.data_structs import Node

class CatBoost:
    def __init__(self, path_to_cb_model: str):
        self.model = joblib.load(path_to_cb_model)
    
    def predict(self, X):
        return self.model.predict_proba(X)[:, 1]

class CatboostReranker(BaseNodePostprocessor, BaseModel):
    top_n: int = Field(default=2, description="Top N nodes to return.")
    keep_retrieval_score: bool = Field(
        default=False,
        description="Whether to keep the retrieval score in metadata.",
    )
    app_settings: dict = Field(default_factory=dict, description="Settings for search")
    path_to_model: str = Field(description="Path to CatBoost model")
    _predictor: Any = PrivateAttr()

    def __init__(
        self,
        top_n: int = 2,
        keep_retrieval_score: Optional[bool] = False,
        path_to_model: str = None,
        app_settings=app_settings,
        **data
    ):
        super().__init__(top_n=top_n, keep_retrieval_score=keep_retrieval_score, path_to_model=path_to_model, **data)
        self._predictor = CatBoost(path_to_model)

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:

        if not nodes:
            return []

        query = [node.get_content() for node in nodes]
        nodes_metadata = [node.metadata for node in nodes]

        combined_data = []
        for q, meta in zip(query, nodes_metadata):
            combined_data.append([q, meta['feature_1'], meta['feature_2']])

        data_array = np.array(combined_data, dtype=object)

        test_pool = Pool(data_array, cat_features=[0])
        
        predicted_scores = self._predictor.predict(test_pool)
        
        for node, score in zip(nodes, predicted_scores):
            node.score = score
            if not self.keep_retrieval_score:
                node.score = score 
        
        sorted_nodes = sorted(nodes, key=lambda x: x.score, reverse=True)
        
        return sorted_nodes[:self.top_n]

In [134]:
node_1 = Node(text='girl')
node_1.metadata = {'feature_1': 36., 'feature_2': 9}
node_2 = Node(text='what is your problem, boy')
node_2.metadata = {'feature_1': 987., 'feature_2': 98}
node_3 = Node(text='hello how i can do it')
node_3.metadata = {'feature_1': 80, 'feature_2': 4}
node_4 = Node(text='woooorld')
node_4.metadata = {'feature_1': 8765., 'feature_2': 1}

nodes = [
    NodeWithScore(node=node_1, score=0.8),
    NodeWithScore(node=node_2, score=0.3),
    NodeWithScore(node=node_3, score=0.5),
    NodeWithScore(node=node_4, score=0.98),
]

In [135]:
path_to_cb_model = "catboost_model.pkl"

ranker = CatboostReranker(
    top_n=2,
    keep_retrieval_score=True,
    path_to_model=path_to_cb_model
)

In [136]:
ranked_nodes = ranker._postprocess_nodes(nodes)
ranked_nodes

[NodeWithScore(node=TextNode(id_='87e3efbf-704a-4211-a566-ad4abed2b2b5', embedding=None, metadata={'feature_1': 36.0, 'feature_2': 9}, excluded_embed_metadata_keys=[], excluded_llm_metadata_keys=[], relationships={}, text='girl', mimetype='text/plain', start_char_idx=None, end_char_idx=None, text_template='{metadata_str}\n\n{content}', metadata_template='{key}: {value}', metadata_seperator='\n'), score=0.49535310163657537),
 NodeWithScore(node=TextNode(id_='a4f23c15-854a-4e69-ac58-7b63138d4fd7', embedding=None, metadata={'feature_1': 80, 'feature_2': 4}, excluded_embed_metadata_keys=[], excluded_llm_metadata_keys=[], relationships={}, text='hello how i can do it', mimetype='text/plain', start_char_idx=None, end_char_idx=None, text_template='{metadata_str}\n\n{content}', metadata_template='{key}: {value}', metadata_seperator='\n'), score=0.49535310163657537)]