# 04 - Forecasting with `ai_forecast()` in DBSQL 

We're looking at the following flow:

<img src="../docs/imgs/energy-sa-forecasting-ai.png" width="300">

This notebook creates forecast tables for UK energy data at two granularities, daily and 30-minute, using Databricks’ [AI_FORECAST](https://learn.microsoft.com/en-us/azure/databricks/sql/language-manual/functions/ai_forecast) function. `ai_forecast()` is a table-valued function that extrapolates time series data into the future. The notebook also writes a daily aggregate of the held-out test data. An administrator must enable `ai_forecast()` for your workspace before this notebook will run successfully.

In [0]:
%run ./includes/common_functions_and_imports

Because the later cells run Databricks SQL inside a notebook, parameters work slightly differently: a SQL cell can only reference a value (as a named parameter marker such as `:catalog_name`) if that value is exposed as a notebook widget. The next cell therefore creates widgets whose defaults come from the config parameters, then checks that the input data exists.
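To make the parameter flow concrete, here is a minimal pure-Python sketch of how each SQL cell turns widget values into a table name: `concat_ws('.', ...)` joins the parts with dots (Spark's `concat_ws` skips NULL arguments), and `IDENTIFIER()` then treats the resulting string as a table identifier. The widget values `my_catalog` and `energy` are hypothetical placeholders, and the `concat_ws` function below only mimics the SQL builtin for illustration.

```python
from typing import Optional


def concat_ws(sep: str, *parts: Optional[str]) -> str:
    """Mimic Spark SQL concat_ws: join the non-NULL parts with sep (NULLs are skipped)."""
    return sep.join(p for p in parts if p is not None)


# Hypothetical widget values, standing in for dbutils.widgets.get(...)
widgets = {
    "catalog_name": "my_catalog",
    "schema_name": "energy",
    "input_table_name": "unscaled_train_features",
}

# The string that IDENTIFIER(concat_ws('.', :catalog_name, :schema_name, :input_table_name)) resolves
fqn = concat_ws(".", widgets["catalog_name"], widgets["schema_name"], widgets["input_table_name"])
print(fqn)  # my_catalog.energy.unscaled_train_features
```

Passing the name through `IDENTIFIER()` rather than splicing strings into the SQL text keeps the widget value from being interpreted as anything other than a table name.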

In [0]:
dbutils.widgets.text("catalog_name", defaultValue=CONFIG.target_catalog, label="Catalog Name")
dbutils.widgets.text("schema_name", defaultValue=CONFIG.target_schema, label="Schema Name")
dbutils.widgets.text("input_table_name", defaultValue="unscaled_train_features", label="Input Table Name")
dbutils.widgets.text("target_table_name_finegrain", defaultValue="ai_forecast_uk_energy_30min", label="30-Minute Output Table Name")
dbutils.widgets.text("target_table_name_daily", defaultValue="ai_forecast_uk_energy_daily", label="Daily Output Table Name")
# Test data
dbutils.widgets.text("test_dataset_name", defaultValue="unscaled_test_features", label="Test Table Name")
dbutils.widgets.text("target_table_name_daily_test", defaultValue="ai_test_uk_energy_daily", label="Daily Test Output Table Name")

# Quickly check that the fully qualified input and test tables exist,
# reading the widget values so that any overrides are respected
catalog_name = dbutils.widgets.get("catalog_name")
schema_name = dbutils.widgets.get("schema_name")
for widget in ("input_table_name", "test_dataset_name"):
    fqn = f"{catalog_name}.{schema_name}.{dbutils.widgets.get(widget)}"
    if not spark.catalog.tableExists(fqn):
        dbutils.notebook.exit(f"Source table {fqn} does not exist")

In [0]:
%sql
CREATE OR REPLACE TABLE IDENTIFIER(concat_ws('.', :catalog_name, :schema_name, :target_table_name_daily_test)) AS
  SELECT
    DATE(data_collection_log_timestamp) AS ts,   
    lv_feeder_unique_id,
    SUM(normalized_consumption_kwh) AS daily_normalized_consumption_kwh
  FROM
    IDENTIFIER(concat_ws('.', :catalog_name, :schema_name, :test_dataset_name))
  GROUP BY
    DATE(data_collection_log_timestamp), lv_feeder_unique_id
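
The cell above collapses the 30-minute test readings into one row per feeder per day. A minimal pure-Python sketch of the same `GROUP BY DATE(...), lv_feeder_unique_id` / `SUM(...)` logic, using a handful of made-up readings (the feeder IDs and values are hypothetical):

```python
from collections import defaultdict
from datetime import datetime

# Hypothetical 30-minute readings:
# (data_collection_log_timestamp, lv_feeder_unique_id, normalized_consumption_kwh)
readings = [
    (datetime(2024, 3, 1, 0, 0), "feeder_a", 1.5),
    (datetime(2024, 3, 1, 0, 30), "feeder_a", 2.0),
    (datetime(2024, 3, 1, 23, 30), "feeder_b", 0.5),
    (datetime(2024, 3, 2, 0, 0), "feeder_a", 1.0),
]


def aggregate_daily(rows):
    """Sum kWh per (calendar date, feeder), like GROUP BY DATE(ts), lv_feeder_unique_id."""
    totals = defaultdict(float)
    for ts, feeder, kwh in rows:
        totals[(ts.date(), feeder)] += kwh
    return dict(totals)


daily = aggregate_daily(readings)
for (day, feeder), kwh in sorted(daily.items()):
    print(day, feeder, kwh)
# 2024-03-01 feeder_a 3.5
# 2024-03-01 feeder_b 0.5
# 2024-03-02 feeder_a 1.0
```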

In [0]:
%sql
CREATE OR REPLACE TABLE IDENTIFIER(concat_ws('.', :catalog_name, :schema_name, :target_table_name_daily)) AS
WITH aggregated AS (
  SELECT
    DATE(data_collection_log_timestamp) AS ts,   
    lv_feeder_unique_id,
    SUM(normalized_consumption_kwh) AS daily_normalized_consumption_kwh
  FROM
    IDENTIFIER(concat_ws('.', :catalog_name, :schema_name, :input_table_name))
  GROUP BY
    DATE(data_collection_log_timestamp), lv_feeder_unique_id
)
SELECT * FROM AI_FORECAST(
  TABLE(
    SELECT
      ts,
      lv_feeder_unique_id,
      daily_normalized_consumption_kwh
    FROM aggregated
  ),
  horizon => '2025-03-29',
  time_col => 'ts',
  value_col => 'daily_normalized_consumption_kwh',
  group_col => 'lv_feeder_unique_id',
  seed => 23
)
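
The daily test table is presumably kept so the daily forecast can be checked against held-out actuals. A minimal sketch of one such check, mean absolute error over matching `(ts, lv_feeder_unique_id)` rows, follows. It assumes the forecast column follows `ai_forecast()`'s documented `{value_col}_forecast` naming (here `daily_normalized_consumption_kwh_forecast`); all values are invented for illustration.

```python
from datetime import date

# Hypothetical rows from the daily forecast table:
# (ts, lv_feeder_unique_id) -> daily_normalized_consumption_kwh_forecast
forecast = {
    (date(2025, 3, 1), "feeder_a"): 10.0,
    (date(2025, 3, 2), "feeder_a"): 12.0,
    (date(2025, 3, 1), "feeder_b"): 4.0,
}
# Hypothetical rows from the daily test table:
# (ts, lv_feeder_unique_id) -> daily_normalized_consumption_kwh
actual = {
    (date(2025, 3, 1), "feeder_a"): 11.0,
    (date(2025, 3, 2), "feeder_a"): 12.0,
    (date(2025, 3, 1), "feeder_b"): 6.0,
}


def mean_absolute_error(forecast, actual):
    """Inner-join on (ts, feeder) and average |forecast - actual| over the matches."""
    keys = forecast.keys() & actual.keys()
    if not keys:
        raise ValueError("no overlapping (ts, feeder) rows to compare")
    return sum(abs(forecast[k] - actual[k]) for k in keys) / len(keys)


print(mean_absolute_error(forecast, actual))  # 1.0
```

Only dates that appear in both tables contribute, so the comparison is meaningful only where the test period falls before the forecast `horizon`.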

In [0]:
%sql
CREATE OR REPLACE TABLE IDENTIFIER(concat_ws('.', :catalog_name, :schema_name, :target_table_name_finegrain)) AS
SELECT * FROM AI_FORECAST(
  TABLE(
    SELECT
      data_collection_log_timestamp,
      lv_feeder_unique_id,
      normalized_consumption_kwh
    FROM IDENTIFIER(concat_ws('.', :catalog_name, :schema_name, :input_table_name))
  ),
  horizon => '2025-03-29',
  time_col => 'data_collection_log_timestamp',
  value_col => 'normalized_consumption_kwh',
  group_col => 'lv_feeder_unique_id',
  seed => 23
)
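
Both forecast cells use the same `horizon`, but the number of rows each produces per feeder depends on the series frequency. A minimal sketch of that arithmetic, assuming (per the `ai_forecast()` documentation) that `horizon` is a right-exclusive end time and that forecasting starts one step after the last observed timestamp; the last-observed timestamps below are hypothetical:

```python
from datetime import datetime, timedelta
from typing import List


def forecast_steps(last_observed: datetime, horizon: datetime, step: timedelta) -> List[datetime]:
    """Timestamps spaced by `step`, strictly after last_observed and strictly before horizon."""
    steps = []
    ts = last_observed + step
    while ts < horizon:
        steps.append(ts)
        ts += step
    return steps


horizon = datetime(2025, 3, 29)

# Hypothetical last training timestamps for one feeder
half_hourly = forecast_steps(datetime(2025, 3, 27, 23, 30), horizon, timedelta(minutes=30))
daily_steps = forecast_steps(datetime(2025, 3, 27), horizon, timedelta(days=1))

print(len(half_hourly), len(daily_steps))  # 48 1
```

Under these assumptions, a feeder whose last reading is at 23:30 on 27 March gets 48 forecast rows for 28 March in the 30-minute table but a single row in the daily table.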