In [1]:
import dask
import dask.array as da
import numpy as np
import operator

### inputs

In [2]:
m = 1000
mc = 100

In [3]:
A = da.random.normal(0, 1, (m,m), chunks=mc)
b = da.random.normal(0, 1./m, m, chunks=mc)
x0 = None
M = None
name = None

### before iteration

In [4]:
def block_cg_initialize(A, b, M, x0, name=None):
    token = name or dask.base.tokenize(A)
    itertoken = 'cg-iter-0-' + token
    nblks_1d = (len(b.chunks[0]),)
    nblks_2d = len(A.chunks[0]), len(A.chunks[1])
  
    # TODO: rechunk x0?

    _r = 'r-' + itertoken
    _x = 'x-' + itertoken
    _Ax = 'Ax-' + itertoken
    _p = 'p-' + itertoken
    _Mr = 'Mr-' + itertoken
    _resnrm2 = 'resnrm2-' + itertoken
    
    if x0 is None:
        dsk_r = da.core.top(lambda bi: bi, _r, 'i', b.name, 'i', 
                            numblocks={b.name: nblks_1d})
        dsk_x = da.core.top(lambda ri: 0 * ri, _x, 'i', _r, 'i', 
                            numblocks={_r: nblks_1d})
        dsk_Ax = None
    else:
        dsk_x = da.core.top(lambda x0i: x0i, _x, 'i', x0.name, 'i', 
                            numblocks={x0.name: nblks_1d})
        dsk_Ax = da.core.top(da.core.dotmany, _Ax, 'i', A.name, 'ij', _x, 'j', 
                             numblocks={A.name: nblks_2d, _x: nblks_1d})
        dsk_r = da.core.top(operator.sub, _r, 'i', b.name, 'i', _Ax, 'i',
                            numblocks={b.name: nblks_1d, _Ax: nblks_1d})
        
    if M is None:
        dsk_Mr = {(_Mr, key[1]): dsk_r[_r, key[1]] for key in dsk_r}
    else:
        raise NotImplementedError
    
    dsk_p = {(_p, key[1]): dsk_Mr[_Mr, key[1]] for key in dsk_Mr}
    dsk_resnrm2 = da.core.top(da.core.dotmany, _resnrm2, '', _r, 'i', _Mr, 'i',
                              numblocks={_r: nblks_1d, _Mr: nblks_1d})


    dsk = dask.sharedict.merge(A.dask, b.dask)
    dsk.update_with_key(dsk_x, _x)
    if dsk_Ax is not None:
        dsk.update_with_key(dsk_Ax, _Ax)
    dsk.update_with_key(dsk_r, _r)
    dsk.update_with_key(dsk_Mr, _Mr)
    dsk.update_with_key(dsk_p, _p)
    dsk.update_with_key(dsk_resnrm2, _resnrm2)

    x = da.Array(dsk, _x, shape=b.shape, chunks=b.chunks, dtype=b.dtype)
    r = da.Array(dsk, _r, shape=b.shape, chunks=b.chunks, dtype=b.dtype)
    p = da.Array(dsk, _p, shape=b.shape, chunks=b.chunks, dtype=b.dtype)
    resnrm2 = da.Array(dsk, _resnrm2, shape=(), chunks=(), dtype=b.dtype)
    x, r, p, resnrm2 = dask.persist(x, r, p, resnrm2)
    return dsk, x, resnrm2

### iteration

In [25]:
def block_cg_iterate(dsk, A, M, iteration, name=None):
    m, _ = A.shape
    chunks_1d = (A.chunks[1],)
    nblks_2d = vblocks, hblocks = len(A.chunks[0]), len(A.chunks[1])
    nblks_1d = (vblocks,)

    token = name or dask.base.tokenize(A)
    itertoken = 'cg-iter-' + str(iteration) + '-' + token
    oitertoken = 'cg-iter-' + str(iteration - 1) + '-' + token

    _Ap = 'Ap-' + itertoken
    _alpha = 'alpha-' + itertoken
    _beta = 'beta-' + itertoken
    _gamma = 'gamma-' + itertoken
    _x = 'x-' + itertoken
    _ox = 'x-' + oitertoken
    _r = 'r-' + itertoken
    _or = 'r-' + oitertoken
    _p = 'p-' + itertoken
    _op = 'p-' + oitertoken
    _Mr = 'Mr-' + itertoken
    _resnrm2 = 'resnrm2-' + itertoken
    _oresnrm2 = 'resnrm2-' + oitertoken

    # alpha = oresnrm2 / p.dot(Ap)
    dsk_Ap = da.core.top(da.core.dotmany, _Ap, 'i', A.name, 'ij', _op, 'j',
                         numblocks={A.name: nblks_2d, _op: nblks_1d})
    dsk_gamma = da.core.top(da.core.dotmany, _gamma, '', _op, 'i', _Ap, 'i', 
                            numblocks={_op: nblks_1d, _Ap: nblks_1d})
    dsk_alpha = da.core.top(operator.div, _alpha, '', _oresnrm2, '', _gamma, '',
                            numblocks={_oresnrm2: (), _gamma: ()})

    # x = ox + alpha * p
    def update_x(xi, pi, alpha): return xi + alpha * pi
    dsk_x = da.core.top(update_x, _x, 'i', _ox, 'i', _op, 'i', _alpha, '',
                        numblocks={_ox: nblks_1d, _op: nblks_1d, _alpha: ()})

    # r = or - alpha * Ap
    def update_r(ri, Api, alpha): return ri - alpha * Api
    dsk_r = da.core.top(update_r, _r, 'i', _or, 'i', _op, 'i', _alpha, '',
                        numblocks={_or: nblks_1d, _op: nblks_1d, _alpha: ()})

    # resnrm2 = r'Mr
    if M is None:
        dsk_Mr = {(_Mr, rkey[1]): dsk_r[_r, rkey[1]] for rkey in dsk_r}
    else:
        raise NotImplementedError

    dsk_resnrm2 = da.core.top(da.core.dotmany, _resnrm2, '', _r, 'i', _Mr, 'i',
                              numblocks={_r: nblks_1d, _Mr: nblks_1d})

    # p = Mr + (resnrm2 / oresnrm2) * op
    dsk_beta = da.core.top(operator.div, _beta, '', _resnrm2, '', _oresnrm2, '',
                           numblocks={ _resnrm2: (), _oresnrm2: ()})
    def update_p(Mri, pi, beta): return Mri + beta * pi
    dsk_p = da.core.top(update_p, _p, 'i', _Mr, 'i', _op, 'i', _beta, '',
                          numblocks={_Mr: nblks_1d, _op: nblks_1d, _beta: ()})

    dsk = dask.sharedict.merge(dsk, A.dask)
    dsk.update_with_key(dsk_Ap, key=_Ap)
    dsk.update_with_key(dsk_gamma, key=_gamma)
    dsk.update_with_key(dsk_alpha, key=_alpha)
    dsk.update_with_key(dsk_x, key=_x)
    dsk.update_with_key(dsk_r, key=_r)    
    dsk.update_with_key(dsk_Mr, key=_Mr)
    dsk.update_with_key(dsk_resnrm2, key=_resnrm2)
    dsk.update_with_key(dsk_beta, key=_beta)
    dsk.update_with_key(dsk_p, key=_p)

    x = da.Array(dsk, _x, shape=(m,), chunks=chunks_1d, dtype=A.dtype)
    r = da.Array(dsk, _r, shape=(m,), chunks=chunks_1d, dtype=A.dtype)
    p = da.Array(dsk, _p, shape=(m,), chunks=chunks_1d, dtype=A.dtype)
    resnrm2 = da.Array(dsk, _resnrm2, shape=(), chunks=(), dtype=A.dtype)
    x, r, p, resnrm2 = dask.persist(x, r, p, resnrm2)
    dsk = dask.sharedict.merge(x.dask, r.dask, p.dask, resnrm2.dask) # prune all but state vars from dictionary
    return dsk, x, resnrm2

In [26]:
dsk, x, resnrm2 = block_cg_initialize(A, b, M, x0)

In [27]:
dsk, x, resnrm2 = block_cg_iterate(dsk, A, M, 1)

In [28]:
dask.compute(resnrm2)

(0.00012719458495403515,)

In [29]:
tol = 1e-8

In [30]:
dsk, x, resnrm2 = block_cg_initialize(A, b, M, x0, name=name)
(resnrm2,) = dask.compute(resnrm2)
if resnrm2**0.5 < tol:
    print "OK AT INIT"
#     return x, 0, resnrm2**0.5

In [31]:
print_iter = 1
maxiter = 100

In [32]:
for k in range(1, maxiter + 1):
    dsk, x, resnrm2 = block_cg_iterate(dsk, A, M, k, name=name)
    (resnrm2,) = dask.compute(resnrm2)
    if resnrm2**0.5 < tol:
        break
    elif k % print_iter == 0:
        print('ITER: {:5}\t||Ax - b||_2: {}'.format(k, resnrm2**0.5))

ITER:     1	||Ax - b||_2: 0.0112780576765
ITER:     2	||Ax - b||_2: 0.00578599834704
ITER:     3	||Ax - b||_2: 0.00354360763911
ITER:     4	||Ax - b||_2: 0.00240309899525
ITER:     5	||Ax - b||_2: 0.00174178087361
ITER:     6	||Ax - b||_2: 0.00132314012372
ITER:     7	||Ax - b||_2: 0.00104085954468
ITER:     8	||Ax - b||_2: 0.000841234881693
ITER:     9	||Ax - b||_2: 0.00069469980997
ITER:    10	||Ax - b||_2: 0.000583860029814
ITER:    11	||Ax - b||_2: 0.00049792703754
ITER:    12	||Ax - b||_2: 0.000429914665095
ITER:    13	||Ax - b||_2: 0.000375133704427
ITER:    14	||Ax - b||_2: 0.000330339722127
ITER:    15	||Ax - b||_2: 0.000293228759886
ITER:    16	||Ax - b||_2: 0.000262127401979
ITER:    17	||Ax - b||_2: 0.000235795976023
ITER:    18	||Ax - b||_2: 0.000213299980297
ITER:    19	||Ax - b||_2: 0.00019392395189
ITER:    20	||Ax - b||_2: 0.000177112467965
ITER:    21	||Ax - b||_2: 0.000162428921045
ITER:    22	||Ax - b||_2: 0.000149526194031
ITER:    23	||Ax - b||_2: 0.000138125460052