In [1]:
from pymatsolver import Pardiso
import numpy as np
import scipy as sp
import dask
import dask.array as da

In [2]:
# Create a simple sparse banded system
nC = 100000
A = sp.sparse.diags([-np.ones(nC), 2*np.ones(nC), -np.ones(nC)], [-1,0,1], shape=(nC,nC))

# Create Pardiso solver for Ainv
Ainv = Pardiso(A)

# Solve it for a random RHS
b = np.random.randn(nC,100)



In [3]:
# Solve the RHS in a loop
y = Ainv*b
y

array([[ -23.83315843,  194.02058934,  -76.67440679, ...,  128.00983571,
         -19.07211464, -312.9730096 ],
       [ -48.06786034,  385.61050047, -152.57187283, ...,  255.96749588,
         -37.02378928, -625.63994802],
       [ -72.17821251,  576.07908493, -228.9395117 , ...,  384.71469504,
         -56.15634123, -939.93603779],
       ...,
       [-267.74925839,  233.35010731,  110.40689906, ...,  122.54691174,
         201.82825759, -217.47401107],
       [-176.88819199,  155.78057395,   74.58208403, ...,   81.81474714,
         132.84248617, -144.8863017 ],
       [ -87.74261011,   77.62650622,   37.48977625, ...,   41.02276335,
          65.63889072,  -72.51581147]])

In [4]:
# Try to do the same solves inside a delayed function
@dask.delayed
def solveRHD(rhs):

    return Ainv*rhs

# Split the solves in dask delays
columns = [solveRHD(b[:, ii]).reshape((nC,1)) for ii in range(b.shape[1])]
solves = [da.from_delayed(column, dtype='float', shape=(nC, 1)).reshape((nC,1)) for column in columns]

# Stack the result
yDask = da.hstack(solves)

In [5]:
# Compute only one of them 
# First is the same ... good
solves[0].compute()




array([[ -23.83315843],
       [ -48.06786034],
       [ -72.17821251],
       ...,
       [-267.74925839],
       [-176.88819199],
       [ -87.74261011]])

In [6]:
# Run all with dask ... not good, run again different
yDask.compute()



array([[ -76.67440679,   24.72412378,   24.72412378, ...,  128.00983571,
         -19.07211464, -312.9730096 ],
       [-152.57187283,   49.37725788,   49.37725788, ...,  255.96749588,
         -37.02378928, -625.63994802],
       [-228.9395117 ,   72.86412572,   72.86412572, ...,  384.71469504,
         -56.15634123, -939.93603779],
       ...,
       [ 233.35010731,  161.80812718,  161.80812718, ...,  122.54691174,
         201.82825759, -217.47401107],
       [ 155.78057395,  107.94883099,  107.94883099, ...,   81.81474714,
         132.84248617, -144.8863017 ],
       [  77.62650622,   54.14696415,   54.14696415, ...,   41.02276335,
          65.63889072,  -72.51581147]])