From 7a120cef9d8e110a721b4c186a72786b5b570511 Mon Sep 17 00:00:00 2001 From: colethienes Date: Tue, 3 Dec 2019 20:37:46 -0500 Subject: [PATCH] Fix/Proxy: record search boost --- nboost/proxy.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/nboost/proxy.py b/nboost/proxy.py index 64956529..a389eabb 100644 --- a/nboost/proxy.py +++ b/nboost/proxy.py @@ -144,13 +144,11 @@ def record_topk_and_choices(self, topk: int = None, choices: list = None): @stats.vars_context def record_mrrs(self, upstream_mrr: float = None, model_mrr: float = None): - """Add the upstream mrr and model mrr to the stats""" + """Add the upstream mrr, model mrr, and search boost to the stats""" with suppress(ZeroDivisionError): - self.record_search_boost(search_boost=model_mrr / upstream_mrr) - - @stats.vars_context - def record_search_boost(self, search_boost: float = None): - """Record the quotient of model mrr and upstream mrr""" + var = self.stats.record['vars'] + var['search_boost'] = { + 'avg': var['model_mrr']['avg'] / var['upstream_mrr']['avg']} @stats.time_context def server_connect(self, server_socket: socket.socket) -> None: