# Notebook Overview

## Purpose
This notebook performs drift detection using a simple heuristic: the ratio of deselected bounding boxes (false positives produced by the model during inference and flagged by end users) to the total number of bounding boxes detected by the model. Only records within the timeframe set by the `drift_timeframe_days` parameter are considered.
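Written out as a formula, the heuristic looks like the following. This assumes the counts are pooled over every record in the window rather than averaged per record. The notebook delegates the actual computation to `compute_drift_score`, so treat this as a sketch of the intent, not its exact definition. The symbols are introduced here for illustration only.

$$
\text{score} = \frac{\sum_{r \in R_T} d_r}{\sum_{r \in R_T} n_r}, \qquad \text{drift\_detected} = \big(\text{score} \ge \tau\big)
$$

Here $R_T$ is the set of gold-table records whose `reviewed_time` falls within the last $T$ = `drift_timeframe_days` days. $d_r$ is the number of deselected boxes in record $r$, $n_r$ is the total number of boxes the model detected in $r$, and $\tau$ is `drift_threshold`.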

## Inputs
- `drift_timeframe_days`: workflow parameter specifying how many days back to retrieve records from the gold table (defaults to `90`)
- Validated inference records from the gold table from the past `drift_timeframe_days` days

## Processes
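1. Read the `drift_timeframe_days` parameter from the workflow, and the `catalog_name`, `schema_name`, `gold_table_name` and `drift_threshold` values from the config (exposed through the `global_temp.global_temp_towerscout_configs` view)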
2. Read in records from the gold table whose `reviewed_time` field is within the specified timeframe
3. Add a column `num_filtered_structs` holding the number of bounding boxes that were deselected in that record
4. Add a column `num_structs` holding the _total_ number of bounding boxes the model detected in that record
5. Compute the ratio of deselected boxes to total detected boxes and check whether it meets or exceeds `drift_threshold`
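The steps above can be sketched in plain Python over in-memory records. This is a hypothetical illustration, not the `tsdb.ml.drift` implementation. The function name `detect_drift`, the dict-shaped records, and the pooled (sum over records) ratio are all assumptions. Only the deselection predicate (`class == 1` and `class_name == 'not-ct'`) mirrors the filter the notebook passes to `get_struct_counts`.

```python
from datetime import datetime, timedelta, timezone


def detect_drift(records, drift_timeframe_days, drift_threshold, now=None):
    """Hypothetical sketch of steps 2-5; not the tsdb.ml.drift implementation."""
    if now is None:
        now = datetime.now(timezone.utc)
    cutoff = now - timedelta(days=drift_timeframe_days)

    # Step 2: keep only records reviewed within the timeframe
    recent = [r for r in records if cutoff <= r["reviewed_time"] <= now]

    # Step 3: per-record count of deselected boxes (same predicate as the notebook's filter)
    num_filtered_structs = [
        sum(1 for b in r["bboxes"] if b["class"] == 1 and b["class_name"] == "not-ct")
        for r in recent
    ]

    # Step 4: per-record count of all boxes the model detected
    num_structs = [len(r["bboxes"]) for r in recent]

    # Step 5: pooled ratio over the window (assumption), compared against the threshold
    total = sum(num_structs)
    ratio = sum(num_filtered_structs) / total if total else 0.0
    return ratio, ratio >= drift_threshold


# Usage with made-up records; the kept-box class values are placeholders
now = datetime(2024, 6, 1, tzinfo=timezone.utc)
records = [
    {
        "reviewed_time": now - timedelta(days=10),
        "bboxes": [
            {"class": 1, "class_name": "not-ct"},  # deselected by a reviewer
            {"class": 0, "class_name": "ct"},  # kept (placeholder values)
        ],
    },
    {
        "reviewed_time": now - timedelta(days=200),  # outside a 90-day window
        "bboxes": [{"class": 1, "class_name": "not-ct"}],
    },
]
print(detect_drift(records, drift_timeframe_days=90, drift_threshold=0.4, now=now))
# → (0.5, True)
```

Returning `0.0` for an empty window means no drift is reported when there is nothing reviewed; the real `compute_drift_score` may handle that case differently.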

## Outputs
- `drift_detected` job task value: `True` if the ratio meets or exceeds `drift_threshold`, which triggers the retraining notebook to run; `False` otherwise

In [0]:
%run ./nb_config_retrieval

In [0]:
from tsdb.ml.drift import get_struct_counts, compute_drift_score

In [0]:
# Check if the global view exists
if spark.catalog._jcatalog.tableExists("global_temp.global_temp_towerscout_configs"):
    # Query the global temporary view and collect the first row
    result = spark.sql("SELECT * FROM global_temp.global_temp_towerscout_configs").collect()[0]
    
    # Extract values from the result row
    catalog = result['catalog_name']
    schema = result['schema_name']
    gold_table_name = result['gold_table_name']
    drift_threshold = float(result['drift_threshold'])
    debug_mode = result['debug_mode'] == 'true'
    unit_test_mode = result['unit_test_mode'] == 'true'

else:
    # Exit the notebook with an error message if the global view does not exist
    dbutils.notebook.exit("Global view 'global_temp_towerscout_configs' does not exist, make sure to run the utils notebook")

In [0]:
dbutils.widgets.text("drift_timeframe_days", defaultValue="90")  

In [0]:
drift_timeframe_days = int(dbutils.widgets.get("drift_timeframe_days"))

In [0]:
gold_records = spark.read.format("delta").table(f"{catalog}.{schema}.{gold_table_name}")

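filtered_records = get_struct_counts(
    gold_records,
    "reviewed_time",  # timestamp column used for the timeframe filter
    "bboxes",  # array-of-struct column holding the detected bounding boxes
    "x.class = 1 and x.class_name = 'not-ct'",  # predicate on each bbox struct `x` marking it as deselected
    drift_timeframe_days,  # how many days back to keep records
)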

In [0]:
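# Ratio of deselected boxes to all detected boxes within the timeframe
unselected_bboxes_ratio = compute_drift_score(filtered_records, "num_filtered_structs", "num_structs")

drift_detected = unselected_bboxes_ratio >= drift_threshold

# Expose the result to downstream workflow tasks (e.g. the retraining notebook)
dbutils.jobs.taskValues.set(key="drift_detected", value=drift_detected)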