# AI_FORECAST DBSQL Function

AI_FORECAST is a table-valued function designed to extrapolate time series data into the future. In its most general form, AI_FORECAST accepts __grouped, multivariate, mixed-granularity data__ and forecasts that data up to a specified horizon in the future.


AI_FORECAST is an all-in-one function for making out-of-sample predictions on a large number of time series simultaneously. It is useful for:

- On-the-fly applications where training and persisting models is not required (e.g. dashboards, investigations)
- Scenarios where model persistence is complicated or cumbersome (e.g. generating forecasts for multiple grouping set rollups over the same dataset, or when some dimensions have a few months of data and others have years of data)
- Forecasting “at scale” in the sense that many independent models are trained and evaluated simultaneously

### Prerequisites
AI_FORECAST requires a standard compute cluster or SQL warehouse running DBR 15.1 or later. The function is not yet available on serverless.

### Setup
AI_FORECAST can be enabled at the session level in standard compute environments (i.e. not SQL warehouses) by setting the following Spark configuration:

`SET spark.databricks.sql.functions.aiForecast.enabled = TRUE`

### API

```
SELECT ... FROM AI_FORECAST(
  observed TABLE,
  horizon DATE | TIMESTAMP | STRING,
  time_col STRING,
  value_col STRING | ARRAY<STRING>,
  group_col STRING | ARRAY<STRING> | NULL DEFAULT NULL,
  prediction_interval_width DOUBLE DEFAULT 0.95,
  frequency STRING DEFAULT 'auto',
  seed INTEGER | NULL DEFAULT NULL,
  parameters STRING DEFAULT '{}' -- * New in DBR 15.3 *
)
```


# Examples

In [0]:
SET spark.databricks.sql.functions.aiForecast.enabled = TRUE

key,value
spark.databricks.sql.functions.aiForecast.enabled,True


### Forecast until a Specified Date

In [0]:
SELECT * FROM samples.nyctaxi.trips ORDER BY tpep_pickup_datetime DESC

tpep_pickup_datetime,tpep_dropoff_datetime,trip_distance,fare_amount,pickup_zip,dropoff_zip
2016-02-29T23:51:20Z,2016-02-29T23:57:41Z,0.91,6.5,10153,10065
2016-02-29T23:51:06Z,2016-02-29T23:59:38Z,1.4,7.5,10154,10019
2016-02-29T23:50:27Z,2016-02-29T23:54:25Z,1.35,6.0,11238,11205
2016-02-29T23:39:15Z,2016-02-29T23:55:24Z,2.8,12.5,10003,10038
2016-02-29T23:33:22Z,2016-02-29T23:38:56Z,0.8,5.5,10001,10011
2016-02-29T23:32:16Z,2016-02-29T23:49:57Z,8.72,25.5,11371,10021
2016-02-29T23:29:04Z,2016-02-29T23:34:08Z,1.0,6.0,10003,10001
2016-02-29T23:23:23Z,2016-02-29T23:42:33Z,3.5,15.5,10001,10162
2016-02-29T23:23:16Z,2016-02-29T23:26:03Z,0.51,4.0,11231,11201
2016-02-29T23:21:38Z,2016-02-29T23:34:44Z,2.33,11.0,11211,11222


##### Suppose we want to forecast revenue (the sum of `fare_amount`) per day.

In [0]:
WITH
aggregated AS (
  SELECT
    DATE(tpep_pickup_datetime) AS ds,
    SUM(fare_amount) AS revenue
  FROM
    samples.nyctaxi.trips
  GROUP BY
    1
)
SELECT * FROM AI_FORECAST(
  TABLE(aggregated),
  horizon => '2016-03-31',
  time_col => 'ds',
  value_col => 'revenue'
)


ds,revenue_forecast,revenue_upper,revenue_lower
2016-03-01,4913.364710590872,5092.238161490353,900.9210677797578
2016-03-02,5206.637086933598,7264.220917778914,1194.1934441224837
2016-03-03,5804.4132960390025,5806.025692797726,1791.969653227888
2016-03-04,5662.675824170121,7720.619871007284,1650.2321813590063
2016-03-05,5146.320389718501,7204.275020292225,1133.8767469073869
2016-03-06,4839.87083907125,6897.825780367234,827.4271962601356
2016-03-07,4822.592401290079,6880.549366068541,810.1487584789643
2016-03-08,5130.882708186111,7152.853560793545,1118.4390653749963
2016-03-09,5424.155084528837,7482.111417161024,1411.7114417177222
2016-03-10,6021.931293634252,6673.050186998387,2009.4876508231375


Databricks visualization. Run in Databricks to view.

### A Slightly More Complex Example

Tables often do not materialize zeros or empty entries. If the value of a missing entry can be inferred (e.g. 0 or 100%), coalesce it before calling the forecast function. If the value is truly missing or unknown, it can be left empty.

For very sparse data (e.g. more than 50% missing entries), it is best practice to provide a frequency value explicitly. Without one, the granularity is inferred from the data: two entries 35 days apart will be treated as a time series with granularity 35D, rather than as a daily series with 34 missing entries.


Here's an example of missing dates in the `nyctaxi.trips` table.

In [0]:
SELECT
  DATE(tpep_pickup_datetime) AS ds,
  dropoff_zip,
  SUM(fare_amount) AS revenue,
  COUNT(*) AS n_trips
FROM
  samples.nyctaxi.trips
WHERE
  dropoff_zip = 7114
GROUP BY
  1, 2
ORDER BY
  ds


ds,dropoff_zip,revenue,n_trips
2016-01-01,7114,66.0,1
2016-01-03,7114,64.5,1
2016-01-04,7114,105.0,1
2016-01-06,7114,61.5,1
2016-01-08,7114,122.5,2
2016-01-09,7114,53.5,1
2016-01-10,7114,67.0,1
2016-01-11,7114,78.5,1
2016-01-12,7114,71.0,1
2016-01-21,7114,56.5,1
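
For a series with gaps like the one above, the granularity can be stated explicitly rather than inferred. The following is a minimal sketch, not taken from the original notebook. It assumes that `'D'` is an accepted daily value for the `frequency` argument (by analogy with the `35D` notation used above), and it leaves the absent dates as missing rather than coalescing them to zero.

```sql
-- Sketch only: forecast daily revenue for one sparse zip code,
-- stating the granularity explicitly instead of relying on inference.
WITH
sparse AS (
  SELECT
    DATE(tpep_pickup_datetime) AS ds,
    SUM(fare_amount) AS revenue
  FROM
    samples.nyctaxi.trips
  WHERE
    dropoff_zip = 7114
  GROUP BY
    1
)
SELECT * FROM AI_FORECAST(
  TABLE(sparse),
  horizon => '2016-03-31',
  time_col => 'ds',
  value_col => 'revenue',
  -- Assumption: 'D' (daily) is an accepted frequency string
  frequency => 'D'
)
```

If the absent dates really mean "no trips", coalescing them to zero first is the more appropriate choice.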


##### Suppose we want to forecast both revenue and number of trips for each dropoff zip code.

In [0]:
-- Generate the aggregated table from nyctaxi.trips
WITH
aggregated AS (
  SELECT
    DATE(tpep_pickup_datetime) AS ds,
    dropoff_zip,
    SUM(fare_amount) AS revenue,
    COUNT(*) AS n_trips
  FROM
    samples.nyctaxi.trips
  GROUP BY
    1, 2
),
-- Generate the full grid of dates for each zip code, including dates with no trips
spine AS (
  SELECT all_dates.ds, all_zipcodes.dropoff_zip
  FROM (SELECT DISTINCT ds FROM aggregated) all_dates
  CROSS JOIN (SELECT DISTINCT dropoff_zip FROM aggregated) all_zipcodes
)
-- Forecast on the spine joined to the aggregated table
SELECT * FROM AI_FORECAST(
  -- The input table fills in zero for dates that were originally empty
  TABLE(
    SELECT
      spine.*,
      COALESCE(aggregated.revenue, 0) AS revenue,
      COALESCE(aggregated.n_trips, 0) AS n_trips
    FROM spine LEFT JOIN aggregated USING (ds, dropoff_zip)
  ),
  horizon => '2016-03-31',
  time_col => 'ds',
  value_col => ARRAY('revenue', 'n_trips'),
  group_col => 'dropoff_zip',
  prediction_interval_width => 0.9,
  parameters => '{"global_floor": 0}' -- the parameters argument requires DBR 15.3+
)
ORDER BY dropoff_zip, ds


ds,dropoff_zip,revenue_forecast,revenue_upper,revenue_lower,n_trips_forecast,n_trips_upper,n_trips_lower
2016-03-01,7002,5.830807389685071e-09,5.830807575898092e-09,0.0,3.751101133954946e-09,3.7511012008475565e-09,0.0
2016-03-02,7002,7.724019338004879e-09,7.724019480299643e-09,0.0,4.958435723692672e-09,4.984307258761572e-09,0.0
2016-03-03,7002,1.0223942223693747e-08,1.0223942363623074e-08,0.0,6.547250447116308e-09,2.8155256000339157e-06,0.0
2016-03-04,7002,4.896544814955026e-06,4.951466908084134e-06,9.05318283151564e-11,1.8945461005526096e-06,1.8998412087594535e-06,8.318459876382042e-11
2016-03-05,7002,2.4600731222833477e-08,2.4600731279584572e-08,0.0,1.4633115426471153e-08,1.4633305637667994e-08,0.0
2016-03-06,7002,3.3748390477429e-08,3.3748390517157313e-08,0.0,1.96827484028089e-08,1.9683526214991382e-08,0.0
2016-03-07,7002,4.6280723092666626e-08,4.62807231200813e-08,0.0,2.6417820856655762e-08,2.6419012253789127e-08,0.0
2016-03-08,7002,9.273902674990349e-08,9.2739026750261e-08,0.0,4.574526577869653e-08,4.5745265778786754e-08,0.0
2016-03-09,7002,1.22374781743227e-07,1.2237478174350028e-07,0.0,6.011792783691042e-08,6.011796262661401e-08,0.0
2016-03-10,7002,1.6150780232670372e-07,1.6150780232697244e-07,0.0,7.9031903537586e-08,7.911004250928379e-08,0.0


###### To help visualize what the `spine` table looks like:

In [0]:
WITH
aggregated AS (
  SELECT
    DATE(tpep_pickup_datetime) AS ds,
    dropoff_zip,
    SUM(fare_amount) AS revenue,
    COUNT(*) AS n_trips
  FROM
    samples.nyctaxi.trips
  GROUP BY
    1, 2
),
-- Generate the full grid of dates for each zip code
spine AS (
  SELECT all_dates.ds, all_zipcodes.dropoff_zip
  FROM (SELECT DISTINCT ds FROM aggregated) all_dates
  CROSS JOIN (SELECT DISTINCT dropoff_zip FROM aggregated) all_zipcodes
)
SELECT * FROM spine WHERE dropoff_zip = 7114
ORDER BY ds ASC

ds,dropoff_zip
2016-01-01,7114
2016-01-02,7114
2016-01-03,7114
2016-01-04,7114
2016-01-05,7114
2016-01-06,7114
2016-01-07,7114
2016-01-08,7114
2016-01-09,7114
2016-01-10,7114


###### To help visualize what the input table looks like:

In [0]:
-- Generate the aggregated table from nyctaxi.trips
WITH
aggregated AS (
  SELECT
    DATE(tpep_pickup_datetime) AS ds,
    dropoff_zip,
    SUM(fare_amount) AS revenue,
    COUNT(*) AS n_trips
  FROM
    samples.nyctaxi.trips
  GROUP BY
    1, 2
),
-- Generate the full grid of dates for each zip code, including dates with no trips
spine AS (
  SELECT all_dates.ds, all_zipcodes.dropoff_zip
  FROM (SELECT DISTINCT ds FROM aggregated) all_dates
  CROSS JOIN (SELECT DISTINCT dropoff_zip FROM aggregated) all_zipcodes
)
-- Join the spine to the aggregated table, filling in zero for dates that were originally empty
SELECT
  spine.*,
  COALESCE(aggregated.revenue, 0) AS revenue,
  COALESCE(aggregated.n_trips, 0) AS n_trips
FROM spine LEFT JOIN aggregated USING (ds, dropoff_zip)
WHERE dropoff_zip = 7114
ORDER BY ds ASC


ds,dropoff_zip,revenue,n_trips
2016-01-01,7114,66.0,1
2016-01-02,7114,0.0,0
2016-01-03,7114,64.5,1
2016-01-04,7114,105.0,1
2016-01-05,7114,0.0,0
2016-01-06,7114,61.5,1
2016-01-07,7114,0.0,0
2016-01-08,7114,122.5,2
2016-01-09,7114,53.5,1
2016-01-10,7114,67.0,1


### Daily + Hourly Forecasting
AI_FORECAST can be used to generate forecasts at multiple granularities spanning the same window of time.


In [0]:
SELECT * FROM AI_FORECAST(
  -- Daily aggregations of revenue
  TABLE(
    SELECT
      DATE_TRUNC('DAY', tpep_pickup_datetime) AS ts,
      ANY_VALUE('DAY') AS granularity,
      SUM(fare_amount) AS revenue
    FROM
      samples.nyctaxi.trips
    GROUP BY
      1
    
    UNION ALL
    -- Hourly aggregations of revenue
    SELECT
      DATE_TRUNC('HOUR', tpep_pickup_datetime) AS ts,
      ANY_VALUE('HOUR') AS granularity,
      SUM(fare_amount) AS revenue
    FROM
      samples.nyctaxi.trips
    GROUP BY
      1
  ),
  horizon => '2016-03-31',
  time_col => 'ts',
  value_col => 'revenue',
  group_col => 'granularity'
)


ts,granularity,revenue_forecast,revenue_upper,revenue_lower
2016-03-01T00:00:00Z,DAY,4913.364710590872,5092.238161490353,900.9210677797578
2016-03-02T00:00:00Z,DAY,5206.637086933598,7264.220917778914,1194.1934441224837
2016-03-03T00:00:00Z,DAY,5804.4132960390025,5806.025692797726,1791.969653227888
2016-03-04T00:00:00Z,DAY,5662.675824170121,7720.619871007284,1650.2321813590063
2016-03-05T00:00:00Z,DAY,5146.320389718501,7204.275020292225,1133.8767469073869
2016-03-06T00:00:00Z,DAY,4839.87083907125,6897.825780367234,827.4271962601356
2016-03-07T00:00:00Z,DAY,4822.592401290079,6880.549366068541,810.1487584789643
2016-03-08T00:00:00Z,DAY,5130.882708186111,7152.853560793545,1118.4390653749963
2016-03-09T00:00:00Z,DAY,5424.155084528837,7482.111417161024,1411.7114417177222
2016-03-10T00:00:00Z,DAY,6021.931293634252,6673.050186998387,2009.4876508231375


Databricks visualization. Run in Databricks to view.

### Investigations
AI_FORECAST can be used to perform drill-down investigations. Join forecasting results with the original table to compute residuals. Pair this functionality with grouping set rollups to quickly isolate unexpected changes.

In the sample data below, we have introduced an anomaly for all CA/Rural UIDs on 2023-01-31.


In [0]:
CREATE OR REPLACE TEMPORARY VIEW
hierarchical_data_with_an_anomalous_date
AS
WITH
-- Create the dimensions for the dataset
dimensions AS (
  SELECT
    country, population, uid, 10 * RAND() AS intercept, RAND() AS slope
  LATERAL VIEW
    EXPLODE(ARRAY('US', 'CA', 'UK', 'IN')) t1 AS country
  LATERAL VIEW
    EXPLODE(ARRAY('Urban', 'Rural', 'Suburban')) t2 AS population
  LATERAL VIEW
    EXPLODE(SEQUENCE(0, 10)) t3 AS uid
),
-- Create the timestamps for the dataset
dim_times AS (
  SELECT dimensions.*, ts, DATEDIFF(HOUR, '2023-01-01', ts) AS x
  FROM dimensions
  LATERAL VIEW
    EXPLODE(SEQUENCE(
      TIMESTAMP('2023-01-01'),
      TIMESTAMP('2023-02-01'),
      INTERVAL 1 HOUR
    )) t AS ts
)
-- Create the value column
SELECT
  dim_times.*,
  (intercept + (slope * x) + RANDN())
  * IF(
      -- Introduce an anomaly for all CA/Rural UIDs on 2023-01-31
      DATE(ts) = '2023-01-31'
      AND country = 'CA'
      AND population = 'Rural',
      0.75,
      1.0
  ) AS y
FROM
  dim_times;
SELECT country, population, uid, ts, y FROM hierarchical_data_with_an_anomalous_date;


country,population,uid,ts,y
US,Urban,0,2023-01-01T00:00:00Z,7.713485518088246
US,Urban,0,2023-01-01T01:00:00Z,8.271630002988095
US,Urban,0,2023-01-01T02:00:00Z,9.06897942577231
US,Urban,0,2023-01-01T03:00:00Z,10.752951540626558
US,Urban,0,2023-01-01T04:00:00Z,8.446783877678373
US,Urban,0,2023-01-01T05:00:00Z,10.420320505129029
US,Urban,0,2023-01-01T06:00:00Z,10.566414334928684
US,Urban,0,2023-01-01T07:00:00Z,12.338550404938388
US,Urban,0,2023-01-01T08:00:00Z,12.256708269662688
US,Urban,0,2023-01-01T09:00:00Z,13.569205837415158


##### Let's see if we can detect anomalies in the data on `2023-01-31` using AI_FORECAST().

In [0]:
WITH
-- Calculate rollup values for each dimension
rollups AS (
  SELECT country, population, uid, ts, SUM(y) AS y
  FROM hierarchical_data_with_an_anomalous_date
  GROUP BY GROUPING SETS(
    (country, population, uid, ts),
    (country, population, ts),
    (country, ts)
  )
),
-- Get observations for the target investigation date: 2023-01-31
obs AS (SELECT * FROM rollups WHERE DATE(ts) = '2023-01-31'),
-- Calculate the forecast for 2023-01-31 using historical data.
fcst AS (
  SELECT * FROM AI_FORECAST(
    TABLE(SELECT * FROM rollups WHERE ts < '2023-01-31'),
    horizon => '2023-02-01',
    time_col => 'ts',
    value_col => 'y',
    group_col => ARRAY('country', 'population', 'uid')
  )
)
-- Rank groupings by mean absolute deviation between observed and forecast values.
SELECT
  obs.country,
  IF(obs.population IS NULL, '[All]', obs.population) AS population,
  IF(obs.uid IS NULL, '[All]', CAST(obs.uid AS STRING)) AS uid,
  AVG(ABS(obs.y - fcst.y_forecast)) AS mean_abs_deviation
FROM
  obs
JOIN
  fcst
ON
  fcst.ts = obs.ts
  AND fcst.country <=> obs.country
  AND fcst.population <=> obs.population
  AND fcst.uid <=> obs.uid
GROUP BY
  1, 2, 3
ORDER BY
  4 DESC
LIMIT
  15


country,population,uid,mean_abs_deviation
CA,Rural,[All],992.1594437690696
CA,[All],[All],968.2475200109594
CA,Rural,6,176.25363546820472
CA,Rural,8,145.4787203466831
CA,Rural,3,119.33758635533046
CA,Rural,0,118.546828737652
CA,Rural,5,115.44342678249286
CA,Rural,7,84.76986982108967
CA,Rural,4,80.08509231715722
CA,Rural,2,73.04466056147272
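
As a variation on the residual ranking above, observations can also be compared against the prediction interval. The sketch below is an illustration rather than part of the original notebook. It assumes the output columns are named `y_upper` and `y_lower`, following the `<value_col>_upper` and `<value_col>_lower` pattern seen in the earlier outputs, and it forecasts only at the (country, population) level to keep the query short.

```sql
-- Sketch only: count the hourly observations on 2023-01-31 that fall
-- outside the forecast's prediction interval, per (country, population).
WITH
rollups AS (
  SELECT country, population, ts, SUM(y) AS y
  FROM hierarchical_data_with_an_anomalous_date
  GROUP BY country, population, ts
),
fcst AS (
  SELECT * FROM AI_FORECAST(
    TABLE(SELECT * FROM rollups WHERE ts < '2023-01-31'),
    horizon => '2023-02-01',
    time_col => 'ts',
    value_col => 'y',
    group_col => ARRAY('country', 'population')
  )
)
SELECT
  obs.country,
  obs.population,
  COUNT(*) AS n_hours,
  -- Assumed column names: y_lower and y_upper
  SUM(IF(obs.y < fcst.y_lower OR obs.y > fcst.y_upper, 1, 0)) AS n_outside_interval
FROM
  rollups AS obs
JOIN
  fcst
ON
  fcst.ts = obs.ts
  AND fcst.country = obs.country
  AND fcst.population = obs.population
WHERE
  DATE(obs.ts) = '2023-01-31'
GROUP BY
  1, 2
ORDER BY
  4 DESC
```

With the default `prediction_interval_width` of 0.95, a group whose observations mostly fall outside the interval is a candidate for closer inspection.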
