# DuckDB - DTW

A working example of Dynamic Time Warping (DTW) implemented in SQL with DuckDB.

In [1]:
import duckdb
import numpy as np

In [2]:
db = duckdb.connect(':memory:')

## Mock Data

Initialize two arrays to serve as our two comparison sequences:

In [3]:
s1 = np.array([4,5,2.5,1.5,6.4,5.5,7.8,9.0,7.4,2.0,3.0])
s2 = np.array([3.5,3.2,4,6.1,3.2,4.8,7.1,6.0])

## Distance Matrix

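Start with a "distance matrix": the distance function evaluated for every pair of points across the two sequences. In SQL this is a Cartesian join. The query below writes it as `FULL OUTER JOIN ... ON 1=1`, which returns the same rows as a `CROSS JOIN` here because both inputs are non-empty.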

In this example, the local cost between two points is their squared difference, and the square root is applied only to the accumulated cost, never to the individual distances. This is the default DTW distance in the `dtaidistance` package, which we use for reference. Other distance functions can be substituted; that is dealt with later on.
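For reference, the same Cartesian product can be sketched in plain Python with `itertools.product` in place of the SQL join, assuming the squared-difference local cost described above. `distance_matrix` is a hypothetical helper name, and rows come out ordered by `index1` then `index2`, whereas the SQL join makes no ordering promise.

```python
from itertools import product

# The notebook's mock sequences, as plain Python lists.
s1 = [4, 5, 2.5, 1.5, 6.4, 5.5, 7.8, 9.0, 7.4, 2.0, 3.0]
s2 = [3.5, 3.2, 4, 6.1, 3.2, 4.8, 7.1, 6.0]


def distance_matrix(a, b):
    """Hypothetical helper: one (dist, index1, index2) row per pair of points."""
    return [
        ((x - y) ** 2, i, j)  # squared difference; the sqrt is deferred to the very end
        for (i, x), (j, y) in product(enumerate(a), enumerate(b))
    ]


rows = distance_matrix(s1, s2)
print(len(rows))  # 88 = 11 * 8 pairs
print(rows[0])    # (0.25, 0, 0)
```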

In [5]:
%%time
db.sql("""
WITH seq1 as (
        SELECT * as value, row_number() OVER () - 1 as s_index
        FROM s1   
    ),
    seq2 as (
        SELECT * as value, row_number() OVER () - 1 as s_index
        FROM s2   
    )
SELECT 
        (seq1.value - seq2.value)^2 as dist
       ,seq1.s_index as index1
       ,seq2.s_index as index2
    FROM seq1
    FULL OUTER JOIN
        seq2 ON 1=1
""").show()

┌─────────────────────┬────────┬────────┐
│        dist         │ index1 │ index2 │
│       double        │ int64  │ int64  │
├─────────────────────┼────────┼────────┤
│                0.25 │      0 │      0 │
│                2.25 │      1 │      0 │
│                 1.0 │      2 │      0 │
│                 4.0 │      3 │      0 │
│   8.410000000000002 │      4 │      0 │
│                 4.0 │      5 │      0 │
│               18.49 │      6 │      0 │
│               30.25 │      7 │      0 │
│  15.210000000000003 │      8 │      0 │
│                2.25 │      9 │      0 │
│                  ·  │      · │      · │
│                  ·  │      · │      · │
│                  ·  │      · │      · │
│                 1.0 │      1 │      7 │
│               12.25 │      2 │      7 │
│               20.25 │      3 │      7 │
│ 0.16000000000000028 │      4 │      7 │
│                0.25 │      5 │      7 │
│  3.2399999999999993 │      6 │      7 │
│                 9.0 │      7 │      7 │

## Cost Matrix

Each cell of the cost matrix holds the cost of the cheapest path from `0,0` to that cell, where a path can arrive from any of the cell's three preceding neighbors and accumulates the distance of every cell it passes through.
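In symbols, with 0-based indices matching `index1` and `index2`, and $N$, $M$ the lengths of `s1` and `s2`, the accumulated cost $C$ and the final distance are:

```latex
d_{i,j} = \left(s_1[i] - s_2[j]\right)^2, \qquad C_{0,0} = d_{0,0}

C_{i,j} = d_{i,j} + \min\left(C_{i-1,j-1},\; C_{i-1,j},\; C_{i,j-1}\right)
\quad \text{(neighbors outside the matrix are skipped)}

\mathrm{DTW}(s_1, s_2) = \sqrt{C_{N-1,M-1}}
```

For instance, along the first row only one predecessor exists, so $C_{0,2} = 0.25 + 0.64 + 0 = 0.89$, and $\sqrt{0.89} \approx 0.9434$ is the `0,2` row of the cost matrix output.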

We build on the prior query with a recursive CTE anchored at `0,0`. Each iteration joins every cell of the distance matrix to whichever of its three preceding neighbors were produced by the previous iteration, so every possible path from `0,0` is extended one step at a time. At the end, grouping on the indices and taking the minimum cost keeps only the cheapest path to each cell.

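The DuckDB `arg_min(step, cost)` aggregate returns the `step` value of the row with the lowest `cost`, which records which of the three preceding cells the cheapest path arrives from: `0` for a match (from `index1 - 1, index2 - 1`), `1` for an insertion (from `index1 - 1`), and `2` for a deletion (from `index2 - 1`). This is used to trace the warping path in the next step.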
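To make the behavior of `cost_matrix_all_steps` concrete, here is a plain-Python simulation of it. It assumes standard `WITH RECURSIVE` working-table semantics (each iteration sees only the rows produced by the previous one), so every row is one partial path from `0,0`; the final loop plays the role of `GROUP BY index1, index2` with `min(cost)` and `arg_min(step, cost)`. The function name is hypothetical, and ties are broken by keeping the first row seen, which DuckDB's `arg_min` does not promise.

```python
# Step codes used by the SQL: 0 = match (diagonal), 1 = insertion, 2 = deletion.
MOVES = ((1, 1, 0), (1, 0, 1), (0, 1, 2))  # (delta index1, delta index2, step)


def naive_cost_matrix(a, b):
    """Hypothetical helper: enumerate every path prefix, then keep the cheapest per cell.

    Assumes standard WITH RECURSIVE semantics: each pass sees only the previous pass's rows.
    """
    dist = {(i, j): (x - y) ** 2 for i, x in enumerate(a) for j, y in enumerate(b)}
    frontier = [(0, 0, dist[(0, 0)], None)]  # anchor row: (index1, index2, cost, step)
    all_rows = list(frontier)
    while frontier:  # one loop = one iteration of the recursive CTE
        next_frontier = []
        for i, j, cost, _ in frontier:
            for di, dj, step in MOVES:
                cell = (i + di, j + dj)
                if cell in dist:  # stay inside the distance matrix
                    next_frontier.append((cell[0], cell[1], cost + dist[cell], step))
        all_rows.extend(next_frontier)
        frontier = next_frontier
    # Equivalent of GROUP BY index1, index2 with min(cost) and arg_min(step, cost).
    best = {}
    for i, j, cost, step in all_rows:
        if (i, j) not in best or cost < best[(i, j)][0]:
            best[(i, j)] = (cost, step)
    return best, len(all_rows)


best, n_rows = naive_cost_matrix([1, 2, 3], [2, 3])
print(n_rows)        # 12 rows generated for a matrix of only 6 cells
print(best[(2, 1)])  # (1, 0): accumulated cost 1, last step was a match
```

Even on this 3 x 2 toy input the recursion emits twice as many rows as there are cells, and the gap widens rapidly with sequence length.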

In [9]:
%%time
db.sql("""
WITH RECURSIVE
    seq1 as (
        SELECT * as value, row_number() OVER () - 1 as s_index
        FROM s1   
    ),
    seq2 as (
        SELECT * as value, row_number() OVER () - 1 as s_index
        FROM s2   
    ),
    dist_matrix as (
        SELECT 
            (seq1.value - seq2.value)^2 as dist
            ,seq1.s_index as index1
            ,seq2.s_index as index2
        FROM seq1
        FULL OUTER JOIN
            seq2 ON 1=1
    ),
    cost_matrix_all_steps as (
        -- start at 0,0
        SELECT 
            dist
            ,index1
            ,index2
            ,dist as cost
            ,0 as step
        FROM dist_matrix
        WHERE index1 = 0 and index2 = 0
        UNION ALL 
        -- recurse through 1 of 3 paths
        SELECT 
            d.dist
            ,d.index1
            ,d.index2
            ,d.dist + sub.cost as cost
            , sub.step
        FROM dist_matrix d
        -- LEFT JOIN (correlated on d) to all three preceding cells; the minimum is taken later, in cost_matrix
        LEFT JOIN (
            -- match case
             SELECT * EXCLUDE(step), 0 as step FROM cost_matrix_all_steps c1 WHERE d.index1 = c1.index1 + 1 AND d.index2 = c1.index2 + 1
             UNION ALL 
            -- insertion case
             SELECT * EXCLUDE(step), 1 as step FROM cost_matrix_all_steps c2 WHERE d.index1 = c2.index1 + 1 AND d.index2 = c2.index2
             UNION ALL 
            -- deletion case
             SELECT * EXCLUDE(step), 2 as step FROM cost_matrix_all_steps c3 WHERE d.index1 = c3.index1 AND d.index2 = c3.index2 + 1
         ) sub on 1=1
       WHERE (d.index1 > 0 OR d.index2 > 0)
         AND 
       cost is not null
    ),
cost_matrix as (
    SELECT  
    sqrt(min(cost)) as cost
    ,index1
    ,index2
    ,arg_min(step,cost) as step
    FROM cost_matrix_all_steps
    GROUP BY index1, index2
    )
SELECT * FROM cost_matrix
       """).show()

┌────────────────────┬────────┬────────┬───────┐
│        cost        │ index1 │ index2 │ step  │
│       double       │ int64  │ int64  │ int32 │
├────────────────────┼────────┼────────┼───────┤
│                0.5 │      0 │      0 │  NULL │
│ 1.8681541692269403 │      1 │      1 │     0 │
│ 1.7291616465790582 │      2 │      1 │     0 │
│ 0.9433981132056602 │      0 │      2 │     2 │
│ 2.4248711305964283 │      3 │      1 │     1 │
│ 1.4491376746189435 │      1 │      3 │     0 │
│ 3.8535697735995385 │      2 │      3 │     0 │
│  5.029910535983716 │      3 │      3 │     0 │
│ 3.4117444218463966 │      4 │      2 │     0 │
│  2.310844001658268 │      1 │      4 │     2 │
│          ·         │      · │      · │     · │
│          ·         │      · │      · │     · │
│          ·         │      · │      · │     · │
│  3.566510900025402 │      8 │      6 │     1 │
│ 3.8196858509568563 │      8 │      7 │     0 │
│  9.235799911215056 │      9 │      1 │     0 │
│  6.284106937345991 │      9 │      3 │     1 │

## Warping Path

By starting a second recursive CTE at the terminal cell (`index1 = N-1, index2 = M-1`, where `N` and `M` are the lengths of `s1` and `s2`) and stepping back along the recorded step directions, we can extract the warping path.
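The same trace-back can be sketched in plain Python. This is an illustration, not the notebook's SQL method: it fills the cost matrix with the standard O(N·M) dynamic-programming recurrence (keeping only the cheapest predecessor per cell instead of enumerating all paths), stores the same step codes, and then walks back from `(N-1, M-1)`. `dtw_with_path` is a hypothetical name.

```python
import math

# Same step encoding as the SQL: 0 = match, 1 = insertion, 2 = deletion.
STEP_OFFSETS = {0: (-1, -1), 1: (-1, 0), 2: (0, -1)}


def dtw_with_path(a, b):
    """Hypothetical helper: (distance, path) from squared differences and a final sqrt."""
    n, m = len(a), len(b)
    cost = [[math.inf] * m for _ in range(n)]
    step = [[None] * m for _ in range(n)]
    for i in range(n):
        for j in range(m):
            d = (a[i] - b[j]) ** 2
            if i == 0 and j == 0:
                cost[i][j] = d
                continue
            # In-range predecessors as (cost, step); min() plays the role of arg_min.
            candidates = [
                (cost[i + di][j + dj], s)
                for s, (di, dj) in STEP_OFFSETS.items()
                if i + di >= 0 and j + dj >= 0
            ]
            best_cost, best_step = min(candidates)
            cost[i][j] = d + best_cost
            step[i][j] = best_step
    # Walk back from the terminal cell, following the recorded steps.
    i, j = n - 1, m - 1
    path = [(i, j)]
    while (i, j) != (0, 0):
        di, dj = STEP_OFFSETS[step[i][j]]
        i, j = i + di, j + dj
        path.append((i, j))
    path.reverse()  # report from (0, 0) to (N-1, M-1), like warping_path_fast
    return math.sqrt(cost[n - 1][m - 1]), path


s1 = [4, 5, 2.5, 1.5, 6.4, 5.5, 7.8, 9.0, 7.4, 2.0, 3.0]
s2 = [3.5, 3.2, 4, 6.1, 3.2, 4.8, 7.1, 6.0]
distance, path = dtw_with_path(s1, s2)
print(round(distance, 5))  # 6.14166
print(path)
```

Because each cell keeps only its cheapest predecessor, the work is one pass over the 88 cells rather than one row per possible path.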

In [19]:
%%time
db.sql("""
WITH RECURSIVE
    seq1 as (
        SELECT * as value, row_number() OVER () - 1 as s_index
        FROM s1   
    ),
    seq2 as (
        SELECT * as value, row_number() OVER () - 1 as s_index
        FROM s2   
    ),
    dist_matrix as (
        SELECT 
            (seq1.value - seq2.value)^2 as dist
            ,seq1.s_index as index1
            ,seq2.s_index as index2
        FROM seq1
        FULL OUTER JOIN
            seq2 ON 1=1
    ),
    cost_matrix_all_steps as (
        -- start at 0,0
        SELECT 
            dist
            ,index1
            ,index2
            ,dist as cost
            ,0 as step
        FROM dist_matrix
        WHERE index1 = 0 and index2 = 0
        UNION ALL 
        -- recurse through 1 of 3 paths
        SELECT 
            d.dist
            ,d.index1
            ,d.index2
            ,d.dist + sub.cost as cost
            , sub.step
        FROM dist_matrix d
        -- LEFT JOIN (correlated on d) to all three preceding cells; the minimum is taken later, in cost_matrix
        LEFT JOIN (
            -- match case
             SELECT * EXCLUDE(step), 0 as step FROM cost_matrix_all_steps c1 WHERE d.index1 = c1.index1 + 1 AND d.index2 = c1.index2 + 1
             UNION ALL 
            -- insertion case
             SELECT * EXCLUDE(step), 1 as step FROM cost_matrix_all_steps c2 WHERE d.index1 = c2.index1 + 1 AND d.index2 = c2.index2
             UNION ALL 
            -- deletion case
             SELECT * EXCLUDE(step), 2 as step FROM cost_matrix_all_steps c3 WHERE d.index1 = c3.index1 AND d.index2 = c3.index2 + 1
         ) sub on 1=1
       WHERE (d.index1 > 0 OR d.index2 > 0)
         AND 
       cost is not null
    ),
cost_matrix as (
    SELECT  
    sqrt(min(cost)) as cost
    ,index1
    ,index2
    ,arg_min(step,cost) as step
    FROM cost_matrix_all_steps
    GROUP BY index1, index2
    ),
    warping_path as (
       -- recursive function to start at top-right corner and trace back to origin
       SELECT 
          * 
       ,CASE WHEN step = 0 or step = 1 THEN index1 - 1
                ELSE index1 END
        as next_index1 
       ,CASE WHEN step = 0 or step = 2 THEN index2 - 1
                ELSE index2 END
        as next_index2
       FROM cost_matrix
          WHERE index1 = (SELECT max(s_index) FROM seq1) AND index2 = (SELECT max(s_index) FROM seq2)
       
       UNION ALL
       
       SELECT 
          c.* 
       ,CASE WHEN c.step = 0 or c.step = 1 THEN c.index1 - 1
                ELSE c.index1 END
        as next_index1 
       ,CASE WHEN c.step = 0 or c.step = 2 THEN c.index2 - 1
                ELSE c.index2 END
        as next_index2
       FROM warping_path wp
       INNER JOIN cost_matrix c ON c.index1 = wp.next_index1 AND c.index2 = wp.next_index2
       WHERE c.index1 >= 0 AND c.index2 >= 0
       )
    SELECT index1, index2, cost FROM warping_path
""").show()

┌────────┬────────┬────────────────────┐
│ index1 │ index2 │        cost        │
│ int64  │ int64  │       double       │
├────────┼────────┼────────────────────┤
│     10 │      7 │  6.141661013113635 │
│      9 │      7 │  5.359104402789705 │
│      8 │      6 │  3.566510900025402 │
│      7 │      6 │ 3.5538711287833724 │
│      6 │      6 │ 3.0033314835362415 │
│      5 │      5 │  2.920616373302047 │
│      4 │      5 │ 2.8354893757515653 │
│      3 │      4 │  2.340939982143925 │
│      2 │      4 │ 1.6093476939431077 │
│      1 │      3 │ 1.4491376746189435 │
│      0 │      2 │ 0.9433981132056602 │
│      0 │      1 │ 0.9433981132056602 │
│      0 │      0 │                0.5 │
├────────┴────────┴────────────────────┤
│ 13 rows                    3 columns │
└──────────────────────────────────────┘

CPU times: user 26.1 s, sys: 783 ms, total: 26.9 s
Wall time: 14.6 s


## Validation

Let's compare to the solution from `dtaidistance` to check that the result is correct. We will use the `_fast` variant of each method as the target for performance as well.

In [13]:
from dtaidistance.dtw import distance_fast, warping_path_fast

In [17]:
%%time
warping_path_fast(s1,s2)

CPU times: user 35 μs, sys: 1e+03 ns, total: 36 μs
Wall time: 39.1 μs


[(0, 0),
 (0, 1),
 (0, 2),
 (1, 3),
 (2, 4),
 (3, 4),
 (4, 5),
 (5, 5),
 (6, 6),
 (7, 6),
 (8, 6),
 (9, 7),
 (10, 7)]

In [18]:
%%time
distance_fast(s1, s2)

CPU times: user 53 μs, sys: 0 ns, total: 53 μs
Wall time: 56.3 μs


6.141661013113635

The DTW distance of `6.14166` (the square root of the accumulated squared differences) matches between the DuckDB query and the `dtaidistance` calculation. The warping paths are also identical: read the table output by the DuckDB query from bottom to top, and the indices match the path returned by `warping_path_fast`.

However, the performance difference is substantial. This initial DuckDB attempt took 14.6 s of wall time, about 153,000x the roughly 95 μs taken by the two `dtaidistance` calls combined (running on a minimum-spec GitHub Codespaces instance with 2 cores).

## Performance Improvement

There are a few clear opportunities for performance improvement:

* Breaking one monolithic query (with multiple recursive CTEs) into distinct materialized tables
* Reducing unnecessary recursion: The initial approach above traverses all possible paths through the distance matrix, then filters down to the lowest cost paths outside of the recursion. Moving this into the recursive CTE should significantly reduce required computation.
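A rough count shows why the all-paths recursion is so expensive. Assuming the recursive CTE emits one row per partial path, the number of rows ending at cell `(i, j)` equals the number of monotone paths from `0,0` using the three allowed steps, which is the Delannoy number D(i, j). A recursion that keeps only the cheapest predecessor would instead touch each of the N·M cells once. `delannoy` is a hypothetical helper written for this illustration.

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def delannoy(i, j):
    """Number of paths from (0, 0) to (i, j) using steps (1, 0), (0, 1) and (1, 1)."""
    if i == 0 or j == 0:
        return 1
    return delannoy(i - 1, j) + delannoy(i, j - 1) + delannoy(i - 1, j - 1)


n, m = 11, 8  # lengths of s1 and s2 in this notebook
paths_to_corner = delannoy(n - 1, m - 1)
# Assumes the all-paths CTE emits one row per partial path ending at each cell.
all_path_rows = sum(delannoy(i, j) for i in range(n) for j in range(m))
print(paths_to_corner)  # 433905 complete paths ending at (10, 7)
print(all_path_rows)    # total rows generated by the all-paths recursion
print(n * m)            # 88 cells for a single dynamic-programming pass
```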

### Incremental Materialization

In [20]:
%%time
db.sql("""
CREATE OR REPLACE TEMP TABLE seq1 as 
        SELECT * as value, row_number() OVER () - 1 as s_index
        FROM s1   
    ;

CREATE OR REPLACE TEMP TABLE seq2 as 
        SELECT * as value, row_number() OVER () - 1 as s_index
        FROM s2   
    ;

CREATE OR REPLACE TEMP TABLE dist_matrix as 
        SELECT 
            (seq1.value - seq2.value)^2 as dist
            ,seq1.s_index as index1
            ,seq2.s_index as index2
        FROM seq1
        FULL OUTER JOIN
            seq2 ON 1=1
    ;

CREATE OR REPLACE TEMP TABLE cost_matrix as
    with recursive cost_matrix_all_steps as (
        -- start at 0,0
        SELECT 
            dist
            ,index1
            ,index2
            ,dist as cost
            ,0 as step
        FROM dist_matrix
        WHERE index1 = 0 and index2 = 0
        UNION ALL 
        -- recurse through 1 of 3 paths
        SELECT 
            d.dist
            ,d.index1
            ,d.index2
            ,d.dist + sub.cost as cost
            , sub.step
        FROM dist_matrix d
        -- LEFT JOIN (correlated on d) to all three preceding cells; the minimum is taken later, in cost_matrix
        LEFT JOIN (
            -- match case
             SELECT * EXCLUDE(step), 0 as step FROM cost_matrix_all_steps c1 WHERE d.index1 = c1.index1 + 1 AND d.index2 = c1.index2 + 1
             UNION ALL 
            -- insertion case
             SELECT * EXCLUDE(step), 1 as step FROM cost_matrix_all_steps c2 WHERE d.index1 = c2.index1 + 1 AND d.index2 = c2.index2
             UNION ALL 
            -- deletion case
             SELECT * EXCLUDE(step), 2 as step FROM cost_matrix_all_steps c3 WHERE d.index1 = c3.index1 AND d.index2 = c3.index2 + 1
         ) sub on 1=1
       WHERE (d.index1 > 0 OR d.index2 > 0)
         AND 
       cost is not null
    )
    SELECT  
    sqrt(min(cost)) as cost
    ,index1
    ,index2
    ,arg_min(step,cost) as step
    FROM cost_matrix_all_steps
    GROUP BY index1, index2
    ;

    with recursive warping_path as (
       -- recursive function to start at top-right corner and trace back to origin
       SELECT 
          * 
       ,CASE WHEN step = 0 or step = 1 THEN index1 - 1
                ELSE index1 END
        as next_index1 
       ,CASE WHEN step = 0 or step = 2 THEN index2 - 1
                ELSE index2 END
        as next_index2
       FROM cost_matrix
          WHERE index1 = (SELECT max(s_index) FROM seq1) AND index2 = (SELECT max(s_index) FROM seq2)
       
       UNION ALL
       
       SELECT 
          c.* 
       ,CASE WHEN c.step = 0 or c.step = 1 THEN c.index1 - 1
                ELSE c.index1 END
        as next_index1 
       ,CASE WHEN c.step = 0 or c.step = 2 THEN c.index2 - 1
                ELSE c.index2 END
        as next_index2
       FROM warping_path wp
       INNER JOIN cost_matrix c ON c.index1 = wp.next_index1 AND c.index2 = wp.next_index2
       WHERE c.index1 >= 0 AND c.index2 >= 0
       )
    SELECT index1, index2, cost FROM warping_path
""").show()

┌────────┬────────┬────────────────────┐
│ index1 │ index2 │        cost        │
│ int64  │ int64  │       double       │
├────────┼────────┼────────────────────┤
│     10 │      7 │  6.141661013113635 │
│      9 │      7 │  5.359104402789705 │
│      8 │      6 │  3.566510900025402 │
│      7 │      6 │ 3.5538711287833724 │
│      6 │      6 │ 3.0033314835362415 │
│      5 │      5 │  2.920616373302047 │
│      4 │      5 │ 2.8354893757515653 │
│      3 │      4 │  2.340939982143925 │
│      2 │      4 │ 1.6093476939431077 │
│      1 │      3 │ 1.4491376746189435 │
│      0 │      2 │ 0.9433981132056602 │
│      0 │      1 │ 0.9433981132056602 │
│      0 │      0 │                0.5 │
├────────┴────────┴────────────────────┤
│ 13 rows                    3 columns │
└──────────────────────────────────────┘

CPU times: user 1.4 s, sys: 71.1 ms, total: 1.47 s
Wall time: 883 ms


That is over **16x** faster than the initial DuckDB attempt (883 ms vs. 14.6 s wall time), but still about 9,250x slower than the two `dtaidistance` calls combined.