In [None]:
import itertools
import pprint
import warnings

import ibis
import pandas as pd
from codetiming import Timer

import letsql as ls
import nasa_avionics_data_ml.settings as S
import nasa_avionics_data_ml.zip_data as ZD
from nasa_avionics_data_ml.letsql_udwf_inference import (
    asof_join_flight_data,
    do_manual_batch,
    make_evaluate_all,
    make_rate_to_parquet,
    read_model_and_scales,
    union_cached_asof_joined_flight_data,
)
from nasa_avionics_data_ml.lib import Config

In [None]:
# Columns used to order rows within a window and to partition into groups.
order_by = "time"
group_by = "flight"
# Aircraft tail whose flights are loaded below.
tail = "Tail_652_1"

# Return type handed to both the UDWF and the manual batch baseline.
return_type = "float64"

In [None]:
config, *_ = Config.get_debug_configs()
model, scaleX, scaleT = read_model_and_scales()

# Empty the parquet cache directory so the first timed letsql run starts
# from nothing, then show that it is empty.
for cached_file in tuple(S.parquet_cache_path.iterdir()):
    cached_file.unlink()
print(tuple(S.parquet_cache_path.iterdir()))

## Demonstrate querying remote data into the local engine

In [None]:
# Get up to `n_flights` flights for the configured tail. `gen_parquet_exists`
# presumably yields only flights whose parquet files exist (confirm in zip_data).
n_flights = 8
tail_data = next(
    (td for td in ZD.TailData.gen_from_data_dir() if td.tail == tail),
    None,
)
if tail_data is None:
    # Without a default, `next` would surface as a bare StopIteration here.
    raise ValueError(f"no data found for tail {tail!r}")
flight_datas = tuple(itertools.islice(tail_data.gen_parquet_exists(), n_flights))
if not flight_datas:
    raise ValueError(f"no flights with parquet data found for tail {tail!r}")
(flight_data, *_) = flight_datas
# Single-flight expression used below to show the generated SQL.
single_expr = asof_join_flight_data(flight_data)

In [None]:
# Show what make_rate_to_parquet returns for the first flight, then the SQL
# letsql generates for the single-flight expression.
rate_to_parquet = make_rate_to_parquet(flight_data)
pprint.pprint(rate_to_parquet)
single_sql = ls.to_sql(single_expr)
print(single_sql)

## Create the deferred UDWF (user-defined window function) expression to run inference

In [None]:
# Build the UDWF that applies `model` (with the X/T scalers) to the
# configured input columns.
evaluate_all = make_evaluate_all(
    ibis.schema({name: float for name in config.x_names}),
    return_type,
    model,
    # NOTE(review): this was a literal 8. It sits where do_manual_batch takes
    # config.seq_length, and the window below spans config.seq_length rows,
    # so the two are kept in sync here. Confirm against make_evaluate_all's
    # signature.
    config.seq_length,
    scaleX,
    scaleT,
)
expr = union_cached_asof_joined_flight_data(*flight_datas)
# Each prediction covers the current row plus the `config.seq_length - 1`
# preceding rows of the same group, ordered by `order_by`.
window = ibis.window(
    preceding=config.seq_length - 1,
    following=0,
    order_by=order_by,
    group_by=group_by,
)
with_prediction = expr.mutate(
    predicted=evaluate_all.on_expr(expr).over(window),
)

## Run inference

In [None]:
# Timed letsql execution of the inference query with warnings silenced;
# afterwards, list what the parquet cache directory now contains.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    with Timer("from_letsql"):
        ordered = with_prediction.order_by(group_by, order_by)
        from_letsql = ls.execute(ordered)
cache_contents = tuple(S.parquet_cache_path.iterdir())
print(cache_contents)

In [None]:
# Clear the parquet cache again before the manual baseline. `expr` is built
# from the cached union, so this presumably makes the manual run start cold
# as well.
for cached_file in tuple(S.parquet_cache_path.iterdir()):
    cached_file.unlink()
with Timer("from_manual"):
    manual_predictions = do_manual_batch(
        expr, model, config.seq_length, scaleX, scaleT,
        return_type, config.xlist, group_by, order_by,
    )
    from_manual = manual_predictions.sort_values(
        [group_by, order_by], ignore_index=True,
    )

In [None]:
# The letsql UDWF result must match the manual batch exactly. Unlike a bare
# `.equals`, assert_frame_equal reports which column and rows differ on failure.
pd.testing.assert_frame_equal(from_manual, from_letsql, check_exact=True)

In [None]:
# Run the same query with a warm cache. Warnings are silenced, as in the cold
# run. Then confirm the cached path still reproduces the manual-batch result.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    with Timer("from_letsql cached"):
        from_letsql = ls.execute(with_prediction.order_by(group_by, order_by))
print(tuple(S.parquet_cache_path.iterdir()))
assert from_letsql.equals(from_manual), "cached letsql run differs from manual batch"

## Inspect the data and evaluate the inference

In [None]:
# Preview the first rows of the predictions rather than rendering the whole frame.
from_letsql.head(10)

In [None]:
# One figure per group: observed ALT against the model's predicted column,
# indexed by the ordering column. Uses the config-cell column names instead of
# repeating the "flight"/"time" literals.
for (flight, flight_df) in from_letsql.groupby(group_by):
    ax = flight_df.set_index(order_by)[["ALT", "predicted"]].plot()
    ax.set_title(f"{group_by} = {flight}")