In [1]:
import dask
import dask.array as da

In [2]:
def update(a):
    """Double ``a`` and persist the result.

    Args:
        a: a dask collection (anything supporting ``2 * a``).

    Returns:
        The single collection handed back by ``dask.persist`` for ``2 * a``.
    """
    # persist returns a tuple with one entry per argument; unpack the only one.
    (doubled,) = dask.persist(2 * a)
    return doubled

In [3]:
b = da.random.random(10, chunks=10)

In [4]:
print b.dask.keys()

[('da.random.random_sample-a8690c4e6e64ba86cd9d6011f8397875', 0)]


In [5]:
c = update(b)

In [6]:
print c.dask.keys()

[('mul-76923e61f627666b57781f497c3ba165', 0)]


In [7]:
d = update(c)

In [8]:
print d.dask.keys()

[('mul-59d25d3329b2b800e03a0e6ea21ae6c0', 0)]


In [9]:
def binary_update(a, b):
    """Linearly mix two collections and persist both results in one call.

    Both outputs are computed from the incoming ``a`` and ``b`` (neither
    depends on the other's new value).

    Returns:
        tuple: ``(2*a + 0.5*b, 0.5*a + 2*b)`` as returned by ``dask.persist``.
    """
    next_a = 2 * a + 0.5 * b
    next_b = 0.5 * a + 2 * b
    next_a, next_b = dask.persist(next_a, next_b)
    return next_a, next_b

In [10]:
A = da.random.random(10, chunks=10)
B = da.random.random(10, chunks=10)

In [11]:
A2, B2 = binary_update(A, B)

In [12]:
# Apply binary_update ten times in a row, feeding each result back in.
# The assertions check that the graph of each persisted output has the same
# number of entries as the graph of its input, i.e. the graph does not grow
# from one iteration to the next.
for i in range(10):
    A2, B2 = binary_update(A, B)
    assert len(A.dask) == len(A2.dask), 'A dasks equal len'
    assert len(B.dask) == len(B2.dask), 'B dasks equal len'
    A, B = A2, B2

print 'done'

done


In [13]:
def quaternary_update(a, b, c, d):
    """Update four coupled collections and persist them together.

    ``a`` and ``b`` are mixed from their incoming values; ``c`` and the new
    ``d`` are then derived from the *updated* ``b``. The incoming ``d`` is
    accepted but not read.

    Returns:
        tuple: ``(a, b, c, d)`` as returned by a single ``dask.persist`` call.
    """
    next_a = 2 * a + 0.5 * b
    next_b = 0.5 * a + 2 * b
    # These two depend on the freshly computed b, not the argument.
    next_c = c - 2 * next_b
    next_d = next_b.dot(next_b)
    next_a, next_b, next_c, next_d = dask.persist(next_a, next_b, next_c, next_d)
    return next_a, next_b, next_c, next_d

In [14]:
# Inputs for quaternary_update: three random length-10 vectors, each in a
# single chunk, plus the scalar B.B which is persisted up front.
A = da.random.random(10, chunks=10)
B = da.random.random(10, chunks=10)
C = da.random.random(10, chunks=10)
D = B.dot(B)
(D,) = dask.persist(D)

In [15]:
A2, B2, C2, D2 = quaternary_update(A, B, C, D)

In [16]:
print len(A.dask), len(A2.dask)
print len(B.dask), len(B2.dask)
print len(C.dask), len(C2.dask)
print len(D.dask), len(D2.dask)

1 1
1 1
1 1
1 1


In [17]:
def quaternary_update_with_data(a, b, c, d, data):
    """Same coupled update as ``quaternary_update`` but with ``a`` routed through ``data``.

    Wherever the plain update uses ``a``, this one uses ``data.dot(a)``.
    ``c`` and the new ``d`` are built from the *updated* ``b``; the incoming
    ``d`` is accepted but not read.

    Returns:
        tuple: ``(a, b, c, d)`` as returned by a single ``dask.persist`` call.
    """
    next_a = 2 * data.dot(a) + 0.5 * b
    next_b = 0.5 * data.dot(a) + 2 * b
    # Derived from the freshly computed b, not the argument.
    next_c = c - 2 * next_b
    next_d = next_b.dot(next_b)
    next_a, next_b, next_c, next_d = dask.persist(next_a, next_b, next_c, next_d)
    return next_a, next_b, next_c, next_d

In [18]:
DATA = da.random.random((10, 10), chunks=10)

In [19]:
A2, B2, C2, D2 = quaternary_update_with_data(A, B, C, D, DATA)

In [20]:
print len(A.dask), len(A2.dask)
print len(B.dask), len(B2.dask)
print len(C.dask), len(C2.dask)
print len(D.dask), len(D2.dask)

1 1
1 1
1 1
1 1


In [21]:
m = 20
mc = 10

In [22]:
# Rebuild the inputs with length m split into chunks of size mc, so each
# vector spans more than one chunk when m > mc; DATA is the matching m-by-m
# matrix.
A = da.random.random(m, chunks=mc)
B = da.random.random(m, chunks=mc)
C = da.random.random(m, chunks=mc)
D = B.dot(B)
(D,) = dask.persist(D)
DATA = da.random.random((m, m), chunks=mc)

In [23]:
# Run one update on the multi-chunk inputs and print, for each collection,
# the graph size before and after (input size, output size).
A2, B2, C2, D2 = quaternary_update_with_data(A, B, C, D, DATA)
print len(A.dask), len(A2.dask)
print len(B.dask), len(B2.dask)
print len(C.dask), len(C2.dask)
print len(D.dask), len(D2.dask)

2 2
2 2
2 2
1 1


In [24]:
def ternary_update_with_data(x, r, p, A):
    """Advance a conjugate-gradient style state ``(x, r, p)`` by one step.

    Args:
        x: current iterate.
        r: current residual vector.
        p: current search direction.
        A: matrix applied to ``p`` via ``A.dot(p)``.

    Returns:
        tuple: the updated ``(x, r, p)`` as returned by ``dask.persist``.
    """
    Ap = A.dot(p)
    step = r.dot(r) / p.dot(Ap)
    next_x = x + step * p
    next_r = r - step * Ap
    # The new direction mixes the new residual with the old direction, scaled
    # by the ratio of new to old squared residual norms.
    next_p = next_r + p * next_r.dot(next_r) / r.dot(r)
    next_x, next_r, next_p = dask.persist(next_x, next_r, next_p)
    return next_x, next_r, next_p

In [25]:
# One ternary update on the current A, B, C; print graph sizes before/after.
A2, B2, C2 = ternary_update_with_data(A, B, C, DATA)
print len(A.dask), len(A2.dask)
print len(B.dask), len(B2.dask)
print len(C.dask), len(C2.dask)

2 2
2 2
2 2


In [26]:
import time 

In [27]:
# Time 100 successive updates, feeding each result back in as the next input.
# The assertions check that graph sizes stay constant across iterations; the
# iteration index and wall-clock seconds for that step are printed each time.
for i in range(100):
    start = time.time()
    A2, B2, C2 = ternary_update_with_data(A, B, C, DATA)
    assert len(A2.dask) == len(A.dask)
    assert len(B2.dask) == len(B.dask)
    assert len(C2.dask) == len(C.dask)
    A, B, C = A2, B2, C2
    print i, time.time() - start

0 0.0564198493958
1 0.0598909854889
2 0.0627830028534
3 0.061320066452
4 0.0502400398254
5 0.048807144165
6 0.0493011474609
7 0.0481998920441
8 0.049870967865
9 0.0474500656128
10 0.0494909286499
11 0.0478911399841
12 0.0462191104889
13 0.047847032547
14 0.0516381263733
15 0.0589091777802
16 0.0600678920746
17 0.0731048583984
18 0.0747520923615
19 0.0685570240021
20 0.0519578456879
21 0.0485479831696
22 0.052463054657
23 0.064857006073
24 0.0730340480804
25 0.0662229061127
26 0.0899589061737
27 0.0755879878998
28 0.0500431060791
29 0.0460929870605
30 0.0457689762115
31 0.0507090091705
32 0.0475850105286
33 0.0504679679871
34 0.0624830722809
35 0.0630528926849
36 0.0579299926758
37 0.0616838932037
38 0.0743989944458
39 0.0756669044495
40 0.0866801738739
41 0.0603241920471
42 0.0468690395355
43 0.0528130531311
44 0.0751800537109
45 0.0656337738037
46 0.0571029186249
47 0.0489459037781
48 0.0496890544891
49 0.0527808666229
50 0.0457608699799
51 0.0492880344391
52 0.0907151699066
53 0.0786

In [28]:
# Alias with signature (x, r, p, A).
# NOTE: a later cell defines a function named cg_iterate taking
# (A, state, persist=True); once that cell runs, this alias is replaced.
cg_iterate = ternary_update_with_data

In [29]:
# Same timing experiment as the printing loop, but collect the per-iteration
# wall-clock times in `times` instead of printing them.
times = []
for i in range(100):
    start = time.time()
    A2, B2, C2 = cg_iterate(A, B, C, DATA)
    assert len(A2.dask) == len(A.dask)
    assert len(B2.dask) == len(B.dask)
    assert len(C2.dask) == len(C.dask)
    A, B, C = A2, B2, C2
    times.append(time.time() - start)

In [30]:
import numpy as np

In [31]:
np.mean(times), np.max(times), np.argmax(times), np.min(times), np.argmin(times)

(0.048107810020446777, 0.07526087760925293, 98, 0.038157939910888672, 37)

In [32]:
def multi_update_stats(m, mc, iters):
    """Time repeated persisted updates on freshly generated random inputs.

    Builds three random length-``m`` vectors and an ``m``-by-``m`` matrix, all
    chunked with size ``mc``, then applies ``ternary_update_with_data``
    ``iters`` times, feeding each result back in as the next input.

    Args:
        m: problem size (vector length / matrix side).
        mc: chunk size passed to ``da.random.random``.
        iters: number of timed iterations to run (must be >= 1).

    Returns:
        tuple: ``(mean, max, argmax, min, argmin, total)`` of the per-iteration
        wall-clock times in seconds.

    Raises:
        ValueError: if ``iters`` is smaller than 1 (no timings to summarise).
        AssertionError: if an iteration changes the graph size of a collection.
    """
    if iters < 1:
        raise ValueError('iters must be at least 1, got {!r}'.format(iters))
    A = da.random.random(m, chunks=mc)
    B = da.random.random(m, chunks=mc)
    C = da.random.random(m, chunks=mc)
    DATA = da.random.random((m, m), chunks=mc)
    times = []
    # Honour the requested iteration count (previously hard-coded to 100).
    for _ in range(iters):
        start = time.time()
        # Call the update directly: the notebook-level name `cg_iterate` is
        # later rebound to a function with a different signature.
        A2, B2, C2 = ternary_update_with_data(A, B, C, DATA)
        assert len(A2.dask) == len(A.dask)
        assert len(B2.dask) == len(B.dask)
        assert len(C2.dask) == len(C.dask)
        A, B, C = A2, B2, C2
        times.append(time.time() - start)
    return np.mean(times), np.max(times), np.argmax(times), np.min(times), np.argmin(times), np.sum(times)

In [33]:
# Scaling experiment, disabled with `if False`. For scaling = 1, 10, 100,
# 1000 it runs multi_update_stats with m = 20 * scaling and
# chunk size 10 * scaling and prints m next to the returned statistics.
if False:
    for i in range(4):
        scaling = 10**i
        print 20 * scaling, multi_update_stats(20 * scaling, 10 * scaling, 100)

### results from above scaling experiment
m      | mean (s)             | max (s)              | argmax | min (s)              | argmin | total (s)
-------|----------------------|----------------------|--------|----------------------|--------|--------------------
20     | 0.042010838985443118 | 0.04625391960144043  | 28     | 0.038927078247070312 | 17     | 4.2010838985443115
200    | 0.050868065357208253 | 0.053812980651855469 | 93     | 0.048615932464599609 | 15     | 5.0868065357208252
2000   | 0.084596984386444085 | 0.1150360107421875   | 73     | 0.078493118286132812 | 35     | 8.4596984386444092
20000  | 2.4261571550369263   | 5.257922887802124    | 0      | 2.284060001373291    | 9      | 242.61571550369263

In [34]:
def cg_initialize(A, b):
    """Build and persist the starting conjugate-gradient state for ``A x = b``.

    The initial guess is zero (shaped like ``b``), the residual is
    ``r = b - A x`` and the first search direction is a copy of the residual.

    Args:
        A: matrix-like collection supporting ``A.dot(x)``.
        b: right-hand-side vector.

    Returns:
        tuple: ``(x, r, p)`` as returned by a single ``dask.persist`` call.
    """
    x = 0 * b
    # The iteration adds alpha * p to x while subtracting alpha * A p from r,
    # which is only consistent with r = b - A x. Starting from A x - b makes
    # the solver converge to the negated solution.
    r = b - A.dot(x)
    p = 1 * r
    x, r, p = dask.persist(x, r, p)
    return x, r, p

In [35]:
# Test problem: a random 200x200 matrix and length-200 vector, both in
# chunks of 50, persisted up front.
AA = da.random.random((200, 200), chunks=50)
bb = da.random.random(200, chunks=50)
AA, bb = dask.persist(AA, bb)

In [36]:
# The persisted initial state should have graphs the same size as bb's.
x, r, p = cg_initialize(AA, bb)
assert len(x.dask) == len(bb.dask)
assert len(r.dask) == len(bb.dask)
assert len(p.dask) == len(bb.dask)

In [37]:
# Time persisting the residual norm, then compute the convergence test
# (norm < 0.001) and print the result.
start = time.time()
(res,) = dask.persist(da.linalg.norm(r))
print time.time() - start
(cond,) = dask.compute(res < 0.001)
print cond

0.0173809528351
False


In [38]:
def cg_initialize(A, b):
    """Build and persist the starting conjugate-gradient state for ``A x = b``.

    The initial guess is zero (shaped like ``b``), the residual is
    ``r = b - A x`` and the first search direction is a copy of the residual.

    Args:
        A: matrix-like collection supporting ``A.dot(x)``.
        b: right-hand-side vector.

    Returns:
        tuple: ``(x, r, p)`` as returned by a single ``dask.persist`` call.
    """
    x = 0 * b
    # The iteration adds alpha * p to x while subtracting alpha * A p from r,
    # which is only consistent with r = b - A x. Starting from A x - b makes
    # the solver converge to the negated solution.
    r = b - A.dot(x)
    p = 1 * r
    x, r, p = dask.persist(x, r, p)
    return x, r, p

def cg_iterate(A, state, persist=True):
    """Advance the conjugate-gradient state by one iteration.

    Args:
        A: matrix-like collection supporting ``A.dot(p)``.
        state: ``(x, r, p)`` triple of iterate, residual and search direction.
        persist: when true, pass the new triple through ``dask.persist``
            before returning; when false, return the unevaluated expressions.

    Returns:
        tuple: the updated ``(x, r, p)``.
    """
    x, r, p = state
    Ap = A.dot(p)
    step = r.dot(r) / p.dot(Ap)
    x_next = x + step * p
    r_next = r - step * Ap
    p_next = r_next + p * r_next.dot(r_next) / r.dot(r)
    if not persist:
        return x_next, r_next, p_next
    x_next, r_next, p_next = dask.persist(x_next, r_next, p_next)
    return x_next, r_next, p_next

def cg_residual(state, compute=True):
    """Return the norm of the residual stored in a CG state triple.

    Args:
        state: ``(x, r, p)`` triple; only ``r`` is used.
        compute: when true, evaluate the norm with ``dask.compute`` and return
            the resulting value; when false, return the ``da.linalg.norm``
            expression as is.
    """
    _x, residual, _p = state
    norm = da.linalg.norm(residual)
    if not compute:
        return norm
    (value,) = dask.compute(norm)
    return value
    
def cg(A, b, tol=1e-5, maxiter=200, verbose=0, graph_iters=1):
    """Run conjugate-gradient iterations on ``A x = b`` until the residual norm drops below ``tol``.

    Args:
        A: matrix-like dask collection.
        b: right-hand-side vector.
        tol: stop once the computed residual norm is below this value.
        maxiter: upper bound on the number of iterations.
        verbose: accepted but never read in this function; progress is printed
            every 10 iterations regardless.
        graph_iters: number of iterations chained lazily between evaluations;
            the state is persisted and the residual computed only on
            iterations that are a multiple of this value (clamped to >= 1).

    Returns:
        tuple: ``(x, res, i)`` - the persisted iterate, the computed residual
        norm of the final state, and the last iteration index.

    NOTE(review): with ``maxiter < 1`` the loop body never runs, so ``i`` is
    unbound at the ``return``.
    """
    graph_iters = max(1, int(graph_iters))
    state = cg_initialize(A, b)
    start = time.time()
    for i in range(1, maxiter + 1):
        # Only evaluate on every graph_iters-th iteration; otherwise the state
        # and residual stay as unevaluated expressions.
        calculate = bool(i % graph_iters == 0)
        state = cg_iterate(A, state, persist=calculate)
        res = cg_residual(state, compute=calculate)
        if i % 10 == 0:
            # Wall-clock time for the last block of 10 iterations.
            print i, time.time() - start
            start = time.time()
        if calculate:
            if i % 10 == 0:
                print '\t', i, res
            # The convergence test is only made when res was actually computed.
            if res < tol:
                break
    x, _, _ = state
    # Evaluate the final state even if the last iteration left it lazy.
    res = cg_residual(state, compute=True)
    (x,) = dask.persist(x)
    return x, res, i

In [39]:
# Replace the random matrix by AA^T AA, which is symmetric, then persist the
# matrix and a fresh right-hand side.
AA = da.random.random((200, 200), chunks=50)
AA = AA.T.dot(AA)
bb = da.random.random(200, chunks=50)
AA, bb = dask.persist(AA, bb)

In [40]:
if False:
    t_start = time.time()
    x, res, i = cg(AA, bb, maxiter=500)
    print res, i, time.time() - t_start

In [41]:
if False:
    t_start = time.time()
    x, res, i = cg(AA, bb, maxiter=500, graph_iters=5)
    print res, i, time.time() - t_start

In [42]:
if False:
    t_start = time.time()
    x, res, i = cg(AA, bb, maxiter=500, graph_iters=10)
    print res, i, time.time() - t_start

In [43]:
if False:
    t_start = time.time()
    x, res, i = cg(AA, bb, maxiter=500, graph_iters=20)
    print res, i, time.time() - t_start

In [44]:
if False:
    t_start = time.time()
    x, res, i = cg(AA, bb, maxiter=500, graph_iters=50)
    print res, i, time.time() - t_start

In [45]:
if False:
    t_start = time.time()
    x, res, i = cg(AA, bb, maxiter=500, graph_iters=100)
    print res, i, time.time() - t_start

### speed per iteration as dependent on graph size
graph iters | residual          | iters | total time (s) | iteration time (ms)
------------|-------------------|-------|----------------|--------------------
1           | 7.10698798222e-06 | 406   | 37.6809880733  | 92.8
5           | 3.037102932e-07   | 410   | 33.7013099194  | 82.2
10          | 3.037102932e-07   | 410   | 32.1891298294  | 76.1
20          | 1.21618439317e-09 | 420   | 31.7939510345  | 75.7
50          | 3.22951398091e-12 | 450   | 32.4151659012  | 72.0
100         | 7.03413449862e-12 | 500   | 37.3863518238  | 74.8


In [46]:
state = cg_initialize(AA, bb)

In [47]:
state1 = cg_iterate(AA, state, persist=True)

In [48]:
x, r, p = cg_iterate(AA, state1, persist=False)

In [50]:
def cg_initialize_atop(A, b, persist=True, optimize=True):
    """Build the initial CG state for ``A x = b`` using block-wise ``da.atop`` calls.

    The initial guess is zero, the residual is ``r = b - A x`` and the first
    search direction is a copy of the residual.

    Args:
        A: 2-D dask array.
        b: 1-D dask array (right-hand side); its dtype is used for all outputs.
        persist: when true, pass the triple through ``dask.persist``.
        optimize: when true, run ``dask.optimize`` on the triple first; the
            subsequent ``persist`` then skips its own graph optimization.

    Returns:
        tuple: ``(x, r, p)``.
    """
    def init_x(bi):
        return 0 * bi

    def init_r(Ai, xi, bi):
        # r = b - A x, consistent with the iteration's r -= alpha * A p update
        # (A x - b would drive the solver to the negated solution).
        return bi - Ai.dot(xi)

    def init_p(ri):
        return 1 * ri

    x = da.atop(init_x, 'i', b, 'i', dtype=b.dtype)
    r = da.atop(init_r, 'i', A, 'ij', x, 'j', b, 'i', concatenate=True, dtype=b.dtype)
    p = da.atop(init_p, 'i', r, 'i', dtype=b.dtype)
    if optimize:
        x, r, p = dask.optimize(x, r, p)
    if persist:
        # The keyword understood by dask.persist is `optimize_graph`; skip the
        # second optimization pass when the graph was already optimized above.
        x, r, p = dask.persist(x, r, p, optimize_graph=(not optimize))
    return x, r, p

In [51]:
def cg_iterate_atop(A, state, persist=True, optimize=True):
    """Advance the CG state by one iteration using block-wise ``da.atop`` updates.

    Args:
        A: 2-D dask array.
        state: ``(x, r, p)`` triple of iterate, residual and search direction.
        persist: when true, pass the new triple through ``dask.persist``.
        optimize: when true, run ``dask.optimize`` on the new triple first; the
            subsequent ``persist`` then skips its own graph optimization.

    Returns:
        tuple: ``(x_next, r_next, p_next)``.
    """
    def update_x(x, alpha, p):
        return x + alpha * p

    def update_r(r, alpha, Ap):
        return r - alpha * Ap

    def update_p(r, gamma, gamma_next, p):
        # p_next = r_next + beta * p with beta = gamma_next / gamma. The
        # parameter order matches the positional order of the atop call below.
        return r + (gamma_next / gamma) * p

    x, r, p = state
    Ap = A.dot(p)
    gamma = r.dot(r)
    alpha = gamma / p.dot(Ap)
    x_next = da.atop(update_x, 'i', x, 'i', alpha, '', p, 'i', dtype=A.dtype)
    r_next = da.atop(update_r, 'i', r, 'i', alpha, '', Ap, 'i', dtype=A.dtype)
    gamma_next = r_next.dot(r_next)
    p_next = da.atop(update_p, 'i', r_next, 'i', gamma, '', gamma_next, '', p, 'i', dtype=A.dtype)
    if optimize:
        x_next, r_next, p_next = dask.optimize(x_next, r_next, p_next)
    if persist:
        # The keyword understood by dask.persist is `optimize_graph`.
        x_next, r_next, p_next = dask.persist(x_next, r_next, p_next, optimize_graph=(not optimize))
    return x_next, r_next, p_next

In [52]:
state = cg_initialize_atop(AA, bb, persist=True)
x0, r0, p0 = state

In [53]:
# Chain ten atop-based iterations without persisting any of them, running
# dask.optimize only on the last one, and time how long graph construction
# takes.
start = time.time()
state = x0, r0, p0
for i in range(10):
    state = cg_iterate_atop(AA, state, persist=False, optimize=(i == 9))
x_out, r_out, p_out = state
print time.time() - start

0.194261789322


In [54]:
# Time the evaluation of the chained, still-lazy state built above.
start = time.time()
state = x1, r1, p1 = dask.persist(*state)
print time.time() - start

0.43355679512


In [55]:
len(x0.dask), len(x_out.dask)

(4, 707)

In [59]:
import operator

In [60]:
def cg_init_dsk(A, b, state0):
    """Hand-build the task graph for the initial CG state of ``A x = b``.

    For every row block ``i`` of ``A`` the graph defines
    ``('x-<state0>', i)`` (zero), ``('r-<state0>', i)`` (``b - A x``) and
    ``('p-<state0>', i)`` (copy of the residual block).

    Args:
        A: dask array; only ``A.numblocks`` and ``A.name`` are read.
        b: dask array; only ``b.name`` is read.
        state0: string suffix identifying this state in the key names.

    Returns:
        dict: the task graph. It references the keys of ``A`` and ``b`` but
        does not contain their graphs.
    """
    x0, r0, p0 = [nm + '-' + state0 for nm in ('x', 'r', 'p')]

    def init_x(bi):
        return 0 * bi

    def init_p(ri):
        return 1 * ri

    dsk = dict()
    vblocks, hblocks = A.numblocks
    for i in range(vblocks):
        dsk[(x0, i)] = (init_x, (b.name, i))
        # r = b - A x, consistent with the iteration's r -= alpha * A p update
        # (A x - b would drive the solver to the negated solution).
        dsk[(r0, i)] = (operator.sub,
                        (b.name, i),
                        (da.core.dotmany,
                         [(A.name, i, j) for j in range(hblocks)],
                         [(x0, j) for j in range(hblocks)]))
        dsk[(p0, i)] = (init_p, (r0, i))
    return dsk

In [61]:
# Build the initial state from the hand-written graph: merge it with the
# graphs of AA and bb, wrap the x/r/p keys as dask arrays shaped like bb,
# persist them, and print the elapsed time.
start = time.time()
key1 = 'iter1'
dsk = dask.sharedict.merge(AA.dask, bb.dask, cg_init_dsk(AA, bb, key1))
x = da.Array(dsk, 'x-' + key1, shape=bb.shape, chunks=bb.chunks, dtype=bb.dtype)
r = da.Array(dsk, 'r-' + key1, shape=bb.shape, chunks=bb.chunks, dtype=bb.dtype)
p = da.Array(dsk, 'p-' + key1, shape=bb.shape, chunks=bb.chunks, dtype=bb.dtype)
x, r, p = dask.persist(x, r, p)
print time.time() - start

0.0135219097137


In [62]:
def cg_iterate_dsk(A, state0, state1):
    """Hand-build the task graph taking CG state ``state0`` to ``state1``.

    Keys are named ``'<quantity>-<state>'``. Block-wise quantities (``x``,
    ``r``, ``p``, ``Ap``) are keyed as ``(name, i)``; the scalars ``gamma``
    (``r.r``) and ``pAp`` (``p.Ap``) are keyed by the bare name.

    Args:
        A: dask array; only ``A.numblocks`` and ``A.name`` are read.
        state0: string suffix of the input state.
        state1: string suffix of the output state.

    Returns:
        dict: the task graph for one iteration. It references, but does not
        contain, the keys of ``A`` and of the ``state0`` vectors.
    """
    def update_x(x, gamma, pAp, p):
        return x + (gamma / pAp) * p

    def update_r(r, gamma, pAp, Ap):
        return r - (gamma / pAp) * Ap

    def update_p(p, gamma, gamma_next, r):
        return r + (gamma_next / gamma) * p

    def names_for(state, quantities):
        return [q + '-' + state for q in quantities]

    Ap, pAp = names_for(state0, ('Ap', 'pAp'))
    x0, r0, p0, gamma0 = names_for(state0, ('x', 'r', 'p', 'gamma'))
    x1, r1, p1, gamma1 = names_for(state1, ('x', 'r', 'p', 'gamma'))

    vblocks, hblocks = A.numblocks
    dotmany = da.core.dotmany

    def blocks(name):
        # Keys of every row block of a block-wise vector.
        return [(name, i) for i in range(vblocks)]

    graph = dict()
    for i in range(vblocks):
        graph[(Ap, i)] = (dotmany,
                          [(A.name, i, j) for j in range(hblocks)],
                          [(p0, j) for j in range(hblocks)])
    graph[gamma0] = (dotmany, blocks(r0), blocks(r0))
    graph[pAp] = (dotmany, blocks(p0), blocks(Ap))
    for i in range(vblocks):
        graph[(x1, i)] = (update_x, (x0, i), gamma0, pAp, (p0, i))
        graph[(r1, i)] = (update_r, (r0, i), gamma0, pAp, (Ap, i))
        graph[(p1, i)] = (update_p, (p0, i), gamma0, gamma1, (r1, i))
    graph[gamma1] = (dotmany, blocks(r1), blocks(r1))
    return graph

In [63]:
start = time.time()
dsk = cg_iterate_dsk(AA, 'iter1', 'iter2')
print time.time() - start

0.000180959701538


In [64]:
# Time one iteration built from the hand-written graph: merge it with the
# graphs of AA and the current x, r, p, wrap the new keys as dask arrays,
# optimize them, and compute the residual norm.
start = time.time()
key0, key1 = 'iter1', 'iter2'
# dsk = cg_iterate_dsk(AA, key0, key1)
dsk = dask.sharedict.merge(AA.dask, x.dask, r.dask, p.dask, cg_iterate_dsk(AA, key0, key1))
x = da.Array(dsk, 'x-' + key1, shape=bb.shape, chunks=bb.chunks, dtype=bb.dtype)
r = da.Array(dsk, 'r-' + key1, shape=bb.shape, chunks=bb.chunks, dtype=bb.dtype)
p = da.Array(dsk, 'p-' + key1, shape=bb.shape, chunks=bb.chunks, dtype=bb.dtype)
(x, r, p) = dask.optimize(x, r, p)
(resnrm,) = dask.compute(da.linalg.norm(r))
print time.time() - start

0.0266568660736


In [65]:
def cg_calcs(dsk, key, b):
    """Evaluate the CG state named by ``key`` from a hand-built graph.

    Args:
        dsk: task graph containing the ``x-<key>``, ``r-<key>`` and
            ``p-<key>`` blocks.
        key: state suffix used in the key names.
        b: array whose ``shape``, ``chunks`` and ``dtype`` describe the vectors.

    Returns:
        tuple: ``(x, r, p, res)`` - the three persisted arrays and the computed
        norm of ``r``.
    """
    def as_array(quantity):
        return da.Array(dsk, quantity + '-' + key, shape=b.shape, chunks=b.chunks, dtype=b.dtype)

    # Persist without graph optimization or argument traversal.
    x, r, p = dask.persist(as_array('x'), as_array('r'), as_array('p'),
                           optimize_graph=False, traverse=False)
    (res,) = dask.compute(da.linalg.norm(r))
    return x, r, p, res

def cg_dsk(A, b, tol=1e-5, maxiter=500, verbose=0, print_iters=0, graph_iters=1, time_iters=0):
    """Conjugate-gradient solve of ``A x = b`` driven by hand-built task graphs.

    Iteration graphs are accumulated in a plain dict and only evaluated every
    ``graph_iters`` iterations, at which point the residual norm is computed
    and tested against ``tol``.

    Args:
        A: matrix as a dask array.
        b: right-hand side as a dask array.
        tol: stop once the computed residual norm is below this value.
        maxiter: upper bound on the number of iterations.
        verbose: when positive and ``print_iters`` is not set, residuals are
            printed every ``10**(3 - verbose)`` iterations.
        print_iters: print the residual every this many iterations (0 = never).
        graph_iters: iterations to chain between evaluations (clamped to >= 1).
        time_iters: print elapsed wall-clock time every this many iterations
            (0 = never).

    Returns:
        tuple: ``(x, res, i)`` - the persisted iterate, its residual norm, and
        the number of iterations performed (0 when ``maxiter < 1``).
    """
    graph_iters = max(1, int(graph_iters))
    time_iters = max(0, int(time_iters))
    if int(print_iters) < 1 and verbose > 0:
        print_iters = max(0, max(int(print_iters), int(10**(3 - verbose))))
    key_init = 'cg-iter0'
    # Build the initial state from the arguments, not from notebook globals.
    dsk = dask.sharedict.merge(A.dask, b.dask, cg_init_dsk(A, b, key_init))
    x, r, p, res = cg_calcs(dsk, key_init, b)
    if time_iters > 0:
        start = time.time()
    dsk = dict()
    pending = False  # True while dsk holds iterations not yet evaluated.
    i = 0            # Keeps the return value defined when maxiter < 1.
    for i in range(1, maxiter + 1):
        key0 = 'cg-iter{}'.format(i - 1)
        key1 = 'cg-iter{}'.format(i)
        dsk.update(cg_iterate_dsk(A, key0, key1))
        pending = True
        if i % graph_iters == 0:
            dsk = dask.sharedict.merge(A.dask, x.dask, r.dask, p.dask, dsk)
            x, r, p, res = cg_calcs(dsk, key1, b)
            dsk = dict()
            pending = False
            if print_iters > 0 and i % print_iters == 0:
                # Single-argument print: same output under Python 2 and 3.
                print('\t\t\t{}: residual = {:.1e}'.format(i, res))
            if res < tol:
                break
        if time_iters > 0 and i % time_iters == 0:
            print('{}: {:.1e} seconds'.format(i, time.time() - start))
            start = time.time()
    if pending:
        # maxiter was not a multiple of graph_iters: evaluate the leftover
        # iterations so the returned x and res reflect the final state.
        dsk = dask.sharedict.merge(A.dask, x.dask, r.dask, p.dask, dsk)
        x, _, _, res = cg_calcs(dsk, key1, b)
    return x, res, i

In [66]:
if False:
    t_start = time.time()
    x, res, i = cg_dsk(AA, bb, maxiter=500, time_iters=10, print_iters=10)
    t_cg = time.time() - t_start
    fmt = '\n\niters: {}\nresidual: {:.1e}\ntime: {:.2e} seconds\nper iter: {:.1f} ms'
    print fmt.format(i, res, t_cg, 1000 * t_cg / i)

In [67]:
if False:
    t_start = time.time()
    x, res, i = cg_dsk(AA, bb, maxiter=500, graph_iters=5)
    t_cg = time.time() - t_start
    fmt = 'iters: {}\nresidual: {:.1e}\ntime: {:.2e} seconds\nper iter: {:.1f} ms'
    print fmt.format(i, res, t_cg, 1000 * t_cg / i)

In [68]:
if False:
    t_start = time.time()
    x, res, i = cg_dsk(AA, bb, maxiter=500, graph_iters=20)
    print res, i, time.time() - t_start

In [69]:
if False:
    t_start = time.time()
    x, res, i = cg_dsk(AA, bb, maxiter=500, graph_iters=10)
    print res, i, time.time() - t_start

In [70]:
if False:
    t_start = time.time()
    x, res, i = cg_dsk(AA, bb, maxiter=500, graph_iters=30)
    print res, i, time.time() - t_start

In [71]:
if False:
    t_start = time.time()
    x, res, i = cg_dsk(AA, bb, maxiter=500, graph_iters=50)
    print res, i, time.time() - t_start

In [72]:
# Sweep over graph_iters for the hand-built-graph solver (disabled with
# `if False`) and print one markdown table row per setting: iterations run,
# final residual, total seconds, and milliseconds per iteration.
if False:
    print 'graph iters | iters | residual | time (s) | per iter (ms)'
    print '------------|-------|----------|----------|--------------'
    for graph_iters in (1, 5, 10, 20, 50, 100):
        t_start = time.time()
        x, res, i = cg_dsk(AA, bb, maxiter=500, graph_iters=graph_iters)
        t_cg = time.time() - t_start
        print '{}|{}|{:.1e}|{:.2e}|{:.1f}'.format(graph_iters, i, res, t_cg, 1000 * t_cg / i)


graph iters | iters | residual | time (s) | per iter (ms)
------------|-------|----------|----------|--------------
1|406|5.7e-06|9.14e+00|22.5
5|410|5.1e-07|5.68e+00|13.9
10|410|5.1e-07|5.22e+00|12.7
20|420|2.9e-09|5.14e+00|12.2
50|450|5.2e-12|5.44e+00|12.1
100|500|3.3e-12|7.34e+00|14.7