## Optimization

In [1]:
from autointent.nodes.optimization import NodeOptimizer

In [2]:
import logging

# Notebook-level logger, named after the running module.
logger: logging.Logger = logging.getLogger(name=__name__)

In [3]:
from autointent import Context
from autointent.pipeline.optimization.utils import get_run_name, load_data, get_db_dir


import os
from pathlib import Path

# Repository-relative location of the bundled test dataset. Resolving it at
# runtime (instead of hardcoding an absolute path from one developer's
# machine) lets the notebook run from any checkout; set AUTOINTENT_DATA_PATH
# to use a different file.
DATA_RELPATH = Path("tests/minimal_optimization/data/clinc_subset.json")


def resolve_data_path(relpath: Path = DATA_RELPATH) -> Path:
    """Locate the dataset file used by this notebook.

    Returns the path in the ``AUTOINTENT_DATA_PATH`` environment variable if
    it is set; otherwise ``relpath`` joined onto the current working directory
    or onto the nearest ancestor directory where that file exists.

    Raises:
        FileNotFoundError: if neither the cwd nor any parent contains ``relpath``.
    """
    override = os.environ.get("AUTOINTENT_DATA_PATH")
    if override:
        return Path(override)
    cwd = Path.cwd()
    # Walk upward so the notebook works from any subdirectory of the repo.
    for base in (cwd, *cwd.parents):
        candidate = base / relpath
        if candidate.is_file():
            return candidate
    raise FileNotFoundError(
        f"{relpath} not found in {cwd} or any of its parent directories; "
        "set AUTOINTENT_DATA_PATH to the dataset location"
    )


run_name = get_run_name("multiclass-cpu")
db_dir = get_db_dir("", run_name)

# load_data received a plain string path before; keep passing a str.
data = load_data(str(resolve_data_path()), multilabel=False)
context = Context(
    multiclass_intent_records=data,
    multilabel_utterance_records=[],
    test_utterance_records=[],
    device="cpu",
    mode="multiclass_as_multilabel",
    multilabel_generation_config="",
    db_dir=db_dir,
    regex_sampling=0,
    seed=0,
    dump_dir="modules_dumps"
)

### Retrieval

In [4]:
# Retrieval search space: a single vector-DB combination (k=10, one embedding
# model), scored with the intersecting hit-rate metric.
retrieval_optimizer_config = dict(
    metric="retrieval_hit_rate_intersecting",
    node_type="retrieval",
    search_space=[
        dict(k=[10], model_name=["deepvk/USER-bge-m3"], module_type="vector_db"),
    ],
)

In [5]:
# Build the retrieval-node optimizer from the config dict.
retrieval_optimizer = NodeOptimizer.from_dict_config(retrieval_optimizer_config)

In [7]:
# Run the retrieval search space against `context`; the resulting trials are
# read back from context.optimization_info.trials.retrieval afterwards.
retrieval_optimizer.fit(context)

In [8]:
# Dump directory of the only retrieval trial (the search space has one combination).
retrieval_dump = (
    context
    .optimization_info
    .trials
    .retrieval[0]
    .module_dump_dir
)
retrieval_dump

'modules_dumps/retrieval/vector_db/comb_0'

### Scoring

In [9]:
# Scoring candidates: kNN (k=3) tried with each neighbour-weighting scheme,
# plus a linear scorer with default parameters; each trial reports ROC AUC.
scoring_optimizer_config = dict(
    metric="scoring_roc_auc",
    node_type="scoring",
    search_space=[
        dict(k=[3], module_type="knn", weights=["uniform", "distance", "closest"]),
        dict(module_type="linear"),
    ],
)

In [10]:
# Build the scoring-node optimizer from the config dict.
scoring_optimizer = NodeOptimizer.from_dict_config(scoring_optimizer_config)

In [11]:
# Run every scoring candidate against `context`; the trials are then
# available in context.optimization_info.trials.scoring.
scoring_optimizer.fit(context)

In [12]:
# Pick the `linear` trial's dump explicitly: the inference cell loads it with
# module_type="linear", and relying on its position in the trial list ([-1])
# would silently load the wrong module if the search space were reordered.
scoring_dump = next(
    trial.module_dump_dir
    for trial in context.optimization_info.trials.scoring
    if trial.module_type == "linear"
)
scoring_dump

'modules_dumps/scoring/linear/comb_0'

### Prediction

In [None]:
# Single prediction candidate: a threshold predictor at 0.5, evaluated by accuracy.
prediction_optimizer_config = dict(
    metric="prediction_accuracy",
    node_type="prediction",
    search_space=[
        dict(module_type="threshold", thresh=[0.5]),
    ],
)

In [10]:
# Build the prediction-node optimizer from the config dict.
prediction_optimizer = NodeOptimizer.from_dict_config(prediction_optimizer_config)

In [11]:
# Run the prediction search space against `context` (after retrieval and scoring).
prediction_optimizer.fit(context)

### Check results

In [12]:
# Summarize the optimization: per-node metric values and the config of every trial.
context.optimization_info.dump_evaluation_results()

{'metrics': {'regexp': [],
  'retrieval': [1.0],
  'scoring': [1.0, 1.0, 1.0, 1.0],
  'prediction': [0.8333333333333334]},
 'configs': {'regexp': [],
  'retrieval': [{'module_type': 'vector_db',
    'module_params': {'k': 10, 'model_name': 'deepvk/USER-bge-m3'},
    'metric_name': 'retrieval_hit_rate_intersecting',
    'metric_value': 1.0}],
  'scoring': [{'module_type': 'knn',
    'module_params': {'k': 3, 'weights': 'uniform'},
    'metric_name': 'scoring_roc_auc',
    'metric_value': 1.0},
   {'module_type': 'knn',
    'module_params': {'k': 3, 'weights': 'distance'},
    'metric_name': 'scoring_roc_auc',
    'metric_value': 1.0},
   {'module_type': 'knn',
    'module_params': {'k': 3, 'weights': 'closest'},
    'metric_name': 'scoring_roc_auc',
    'metric_value': 1.0},
   {'module_type': 'linear',
    'module_params': {},
    'metric_name': 'scoring_roc_auc',
    'metric_value': 1.0}],
  'prediction': [{'module_type': 'threshold',
    'module_params': {'thresh': 0.5},
    'metric_

## Inference

In [13]:
from autointent.nodes import InferenceNode

### Retrieval

In [14]:
# Inference-time description of the retrieval node: the same module and
# parameters as the optimized trial, loaded from that trial's dump directory.
retrieval_config = {
    "node_type": "retrieval",
    "module_type": "vector_db",
    "module_config": {"k": 10, "model_name": "deepvk/USER-bge-m3"},
    "load_path": retrieval_dump,
}

In [15]:
# Instantiate the retrieval inference node; load_path in the config points at
# the dump produced during optimization.
retrieval = InferenceNode(**retrieval_config)

In [16]:
# Inspect the node-info object attached to the retrieval node.
retrieval.node_info

<autointent.nodes.nodes_info.retrieval.RetrievalNodeInfo at 0x7527d0e43b90>

In [21]:
# Query the loaded retrieval module with two utterances; it returns the
# neighbours' labels, distances and texts for each query.
labels, distances, texts = retrieval.module.predict(["hello", "world"])

In [22]:
# Neighbours retrieved for the first query ("hello"); k=10 per the config.
labels[0], distances[0], texts[0]

([[0, 0, 1],
  [0, 0, 1],
  [0, 0, 1],
  [0, 0, 1],
  [0, 1, 0],
  [0, 1, 0],
  [1, 0, 0],
  [0, 1, 0],
  [0, 1, 0],
  [1, 0, 0]],
 [np.float32(0.477605),
  np.float32(0.49597514),
  np.float32(0.50022346),
  np.float32(0.5128485),
  np.float32(0.55573064),
  np.float32(0.5758226),
  np.float32(0.58933824),
  np.float32(0.6130637),
  np.float32(0.6174678),
  np.float32(0.63502705)],
 ['please set an alarm for mid day',
  'make sure my alarm is set for three thirty in the morning',
  'have an alarm set for three in the morning',
  'wake me up at noon tomorrow',
  'i think my account is blocked but i do not know the reason',
  'can you tell me why is my bank account frozen',
  'can i make a reservation for redrobin',
  'why is there a hold on my capital one checking account',
  'why is there a hold on my american saving bank account',
  'are reservations taken at redrobin'])

### Scoring

In [25]:
# Inference-time description of the scoring node: the linear scorer (no
# parameters), loaded from the dump directory selected earlier.
scoring_config = {
    "node_type": "scoring",
    "module_type": "linear",
    "module_config": {},
    "load_path": scoring_dump,
}

In [26]:
# Instantiate the scoring inference node from its optimization dump.
scoring = InferenceNode(**scoring_config)

In [27]:
# Class scores for each utterance: one row per input, one column per intent.
scoring.module.predict(["hello", "world"])

array([[0.27486506, 0.31681463, 0.37459106],
       [0.2769358 , 0.31536099, 0.37366978]])

### Prediction

In [None]:
# Load the threshold predictor from the dump written during optimization,
# the same way `retrieval_dump` and `scoring_dump` are obtained, rather than
# from the current working directory.
prediction_dump = context.optimization_info.trials.prediction[0].module_dump_dir
prediction_config = dict(
    node_type="prediction",
    module_type="threshold",
    module_config={"thresh": 0.5},
    load_path=prediction_dump,
)

In [None]:
# Instantiate the prediction inference node from `prediction_config`.
prediction = InferenceNode(**prediction_config)

In [None]:
# The threshold predictor consumes class scores, not the Context: feed it the
# scoring module's output for the same utterances used above.
scores = scoring.module.predict(["hello", "world"])
prediction.module.predict(scores)