### Dask Kronecker Product Implementations

In [1]:
import warnings
import numpy as np
from dask.base import tokenize
from dask.highlevelgraph import HighLevelGraph
from dask import core
from dask.array.core import operator
from dask.distributed import Client
import dask.array as da

In [2]:
client = Client(n_workers=1, threads_per_worker=1, processes=False, scheduler_port=8079)
client

0,1
Client  Scheduler: tcp://127.0.0.1:8079  Dashboard: http://127.0.0.1:8787/status,Cluster  Workers: 1  Cores: 1  Memory: 134.78 GB


In [3]:
def validate_result(x, y, z):
    z = z.compute()
    assert z.shape == z.shape
    assert np.all(np.kron(x.compute(), y.compute()) == z)

Build inputs test test with:

In [4]:
x = da.arange(36).reshape(9, 4).rechunk((3, 2))
x

Unnamed: 0,Array,Chunk
Bytes,288 B,48 B
Shape,"(9, 4)","(3, 2)"
Count,14 Tasks,6 Chunks
Type,int64,numpy.ndarray
"Array Chunk Bytes 288 B 48 B Shape (9, 4) (3, 2) Count 14 Tasks 6 Chunks Type int64 numpy.ndarray",4  9,

Unnamed: 0,Array,Chunk
Bytes,288 B,48 B
Shape,"(9, 4)","(3, 2)"
Count,14 Tasks,6 Chunks
Type,int64,numpy.ndarray


In [5]:
y = da.ones(36, dtype=x.dtype).reshape(9, 4).rechunk((3, 2))
y

Unnamed: 0,Array,Chunk
Bytes,288 B,48 B
Shape,"(9, 4)","(3, 2)"
Count,14 Tasks,6 Chunks
Type,int64,numpy.ndarray
"Array Chunk Bytes 288 B 48 B Shape (9, 4) (3, 2) Count 14 Tasks 6 Chunks Type int64 numpy.ndarray",4  9,

Unnamed: 0,Array,Chunk
Bytes,288 B,48 B
Shape,"(9, 4)","(3, 2)"
Count,14 Tasks,6 Chunks
Type,int64,numpy.ndarray


### V1 - Using ```da.blockwise```

In [6]:
def kron_v1(x, y):
    # Rechunk left array to single data elements
    x = x.rechunk((1, 1))
    return da.blockwise(
        np.multiply, 'ij', x, 'ij', y, 'xy', concatenate=True, dtype='f8',
        adjust_chunks={'i': y.shape[0], 'j': y.shape[1]}
    )
z = kron_v1(x, y)
validate_result(x, y, z)
z.compute()

array([[ 0,  0,  0, ...,  3,  3,  3],
       [ 0,  0,  0, ...,  3,  3,  3],
       [ 0,  0,  0, ...,  3,  3,  3],
       ...,
       [32, 32, 32, ..., 35, 35, 35],
       [32, 32, 32, ..., 35, 35, 35],
       [32, 32, 32, ..., 35, 35, 35]])

### V2 - Using ```da.block```

In [7]:
def kron_v2(x, y):
    return da.block([
        [x[i, j] * y for j in range(x.shape[1])]
        for i in range(x.shape[0])
    ])
z = kron_v2(x, y)
validate_result(x, y, z)
z.compute()

array([[ 0,  0,  0, ...,  3,  3,  3],
       [ 0,  0,  0, ...,  3,  3,  3],
       [ 0,  0,  0, ...,  3,  3,  3],
       ...,
       [32, 32, 32, ..., 35, 35, 35],
       [32, 32, 32, ..., 35, 35, 35],
       [32, 32, 32, ..., 35, 35, 35]])

In [8]:
# Validate with np array argument on left as well
z = kron_v2(x.compute(), y)
validate_result(x, y, z)
z.compute()

array([[ 0,  0,  0, ...,  3,  3,  3],
       [ 0,  0,  0, ...,  3,  3,  3],
       [ 0,  0,  0, ...,  3,  3,  3],
       ...,
       [32, 32, 32, ..., 35, 35, 35],
       [32, 32, 32, ..., 35, 35, 35],
       [32, 32, 32, ..., 35, 35, 35]])

## V3 - Using custom graph

The intent with this approach is to assemble a new graph representing a kronecker product based on compositions of task nodes in the existing graphs for the input arrays.  

In [9]:
# Show all the tasks in the graph to create x
list(dict(x.__dask_graph__()))

[('rechunk-merge-3b7b3559a7b99771b58b4b9c451a03c5', 0, 0),
 ('rechunk-merge-3b7b3559a7b99771b58b4b9c451a03c5', 0, 1),
 ('rechunk-merge-3b7b3559a7b99771b58b4b9c451a03c5', 1, 0),
 ('rechunk-merge-3b7b3559a7b99771b58b4b9c451a03c5', 1, 1),
 ('rechunk-merge-3b7b3559a7b99771b58b4b9c451a03c5', 2, 0),
 ('rechunk-merge-3b7b3559a7b99771b58b4b9c451a03c5', 2, 1),
 ('rechunk-split-3b7b3559a7b99771b58b4b9c451a03c5', 0),
 ('rechunk-split-3b7b3559a7b99771b58b4b9c451a03c5', 1),
 ('rechunk-split-3b7b3559a7b99771b58b4b9c451a03c5', 2),
 ('rechunk-split-3b7b3559a7b99771b58b4b9c451a03c5', 3),
 ('rechunk-split-3b7b3559a7b99771b58b4b9c451a03c5', 4),
 ('rechunk-split-3b7b3559a7b99771b58b4b9c451a03c5', 5),
 ('reshape-126131d495cd2afae9f2085c3fb52292', 0, 0),
 ('arange-e599d6c2e52a053c03cd96483fdb779d', 0)]

In [10]:
# Show only the tasks that represent results
list(dict(x.__dask_keys__())) 

[('rechunk-merge-3b7b3559a7b99771b58b4b9c451a03c5', 0, 0),
 ('rechunk-merge-3b7b3559a7b99771b58b4b9c451a03c5', 1, 0),
 ('rechunk-merge-3b7b3559a7b99771b58b4b9c451a03c5', 2, 0)]

In [11]:
# See https://docs.dask.org/en/latest/array-design.html#example-eye-function for an example Array graph construction

def kron_v3(x, y):
    """ Kronecker Product 
    
    Limitations:
        - Assumes equally size chunks within each array
        - Only works for 2D arrays
    """
    chunks = (y.chunks[0] * x.shape[0], y.chunks[1] * x.shape[1])
    
    name = 'kron-' + tokenize(x.name, y.name)
    
    def get_chunk_index(i, j):
        """ Determine index of array coordinates within a chunk """
        return (i % x.chunksize[0], j % x.chunksize[1])
    
    def get_chunk_coords(i, j):
        """ Determine chunk coordinates for array indices """
        return (i // x.chunksize[0], j // x.chunksize[1])
    
    def array_idx_op(xi, xj):
        """ Build task for selecting single element of a matrix provided scalar indices """
        chk_idx = get_chunk_index(xi, xj)
        chk_crd = get_chunk_coords(xi, xj)
        # Example selecting item in second column and row of first chunk:
        # (operator.getitem, (x.name, 0, 0), (1, 1))
        return (operator.getitem, (x.name, *chk_crd), chk_idx)
    
    def kron_key(xi, xj, yk):
        """ Build chunk key for resulting block in product """
        # Result has as many blocks in one dimension as there are chunks
        # in y times the length of x along that same dimension
        return (name, xi*len(y.chunks[0]) + yk[1], xj*len(y.chunks[1]) + yk[2])
    
    layer = {
        # Map kron product block to operation with copy of y times single element of x
        # NOTE: It is crucial here that all operations do not refer to x and y directly 
        # as this simply nests the graphs for them within this one -- the construction
        # here must instead only refer to keys within the graphs of x and y
        kron_key(xi, xj, yk): (operator.mul, array_idx_op(xi, xj), yk) 
        for xi in range(x.shape[0])
        for xj in range(x.shape[1])
        for yk in core.flatten(y.__dask_keys__())
    }
    # Many of the task keys reference above don't exist in the layer dict just created, so it 
    # is crucial that x and y are provided here as dependencies so that the graphs can be merged
    dsk = HighLevelGraph.from_collections(name, layer, dependencies=[x, y])
    return da.Array(dsk, name, chunks, dtype=x.dtype)

z = kron_v3(x, y)
validate_result(x, y, z)
z.compute()

array([[ 0,  0,  0, ...,  3,  3,  3],
       [ 0,  0,  0, ...,  3,  3,  3],
       [ 0,  0,  0, ...,  3,  3,  3],
       ...,
       [32, 32, 32, ..., 35, 35, 35],
       [32, 32, 32, ..., 35, 35, 35],
       [32, 32, 32, ..., 35, 35, 35]])

### Benchmarks

Create a little bigger arrays for benchmarking; note that two 50x20 arrays give a 2500x400 (1M elems) result so these should still remain small for local testing:

In [12]:
xl = da.arange(50*20).reshape(50, 20).rechunk((25, 4))
xl

Unnamed: 0,Array,Chunk
Bytes,8.00 kB,800 B
Shape,"(50, 20)","(25, 4)"
Count,22 Tasks,10 Chunks
Type,int64,numpy.ndarray
"Array Chunk Bytes 8.00 kB 800 B Shape (50, 20) (25, 4) Count 22 Tasks 10 Chunks Type int64 numpy.ndarray",20  50,

Unnamed: 0,Array,Chunk
Bytes,8.00 kB,800 B
Shape,"(50, 20)","(25, 4)"
Count,22 Tasks,10 Chunks
Type,int64,numpy.ndarray


In [13]:
yl = da.ones(xl.size, dtype=xl.dtype).reshape(xl.shape).rechunk(xl.chunksize)
yl

Unnamed: 0,Array,Chunk
Bytes,8.00 kB,800 B
Shape,"(50, 20)","(25, 4)"
Count,22 Tasks,10 Chunks
Type,int64,numpy.ndarray
"Array Chunk Bytes 8.00 kB 800 B Shape (50, 20) (25, 4) Count 22 Tasks 10 Chunks Type int64 numpy.ndarray",20  50,

Unnamed: 0,Array,Chunk
Bytes,8.00 kB,800 B
Shape,"(50, 20)","(25, 4)"
Count,22 Tasks,10 Chunks
Type,int64,numpy.ndarray


In [13]:
%%timeit -n 3 -r 3
with warnings.catch_warnings():
    # Ignore "PerformanceWarning: Increasing number of chunks by factor of X"; 
    # we know the rechunking to (1,1) will do this
    warnings.simplefilter(action='ignore', category=da.PerformanceWarning)
    kron_v1(xl, yl).compute()

8.47 s ± 67.9 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)


In [14]:
%%timeit -n 3 -r 3
kron_v2(xl, yl).compute()

29.1 s ± 266 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)


In [15]:
%%timeit -n 3 -r 3
kron_v3(xl, yl).compute()

18.1 s ± 217 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)


In [16]:
%%timeit -n 3 -r 3
np.kron(xl.compute(), yl.compute())

151 ms ± 8.67 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)


In [17]:
%%timeit -n 3 -r 3
# Compare to elementwise multiplication of random matrices having shape 
# equal to kronecker product output shape (same number of multiplications)
xr1 = da.random.normal(size=(xl.shape[0]**2, xl.shape[1]**2))
xr2 = da.random.normal(size=(xl.shape[0]**2, xl.shape[1]**2))
(xr1 * xr2).compute()

103 ms ± 6.34 ms per loop (mean ± std. dev. of 3 runs, 3 loops each)


In [21]:
# Also test the v2 version with a numpy input on the left --
# This is almost 2x slower somehow?
xla = xl.compute()
%timeit -n 3 -r 3 kron_v2(xla, yl).compute()

57.9 s ± 1.53 s per loop (mean ± std. dev. of 3 runs, 3 loops each)


**Conclusion**: 1x1 rechunking + ```da.blockwise``` performs better than the others, but all of the methods are extremely slow considering that an elementwise multiplication involving the same number of FLOPs takes 1% as much time (.1/8.47 ~= .01). 