Skip to content

Commit

Permalink
prototype parallel solve
Browse files Browse the repository at this point in the history
  • Loading branch information
Sierd authored and manuGil committed Jun 21, 2023
1 parent a520985 commit 558d884
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 10 deletions.
61 changes: 52 additions & 9 deletions aeolis/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@

from aeolis.utils import *

from multiprocessing import Pool, Value, Array

class StreamFormatter(logging.Formatter):
"""A formater for log messages"""

Expand Down Expand Up @@ -241,7 +243,7 @@ def initialize(self)-> None:
aeolis.inout.visualize_timeseries(self.p, self.t)


def update(self, dt:float=-1) -> None:
def update(self, pools = "" , dt:float=-1) -> None:
'''Time stepping function
Takes a single step in time. Interpolates wind and
Expand Down Expand Up @@ -314,7 +316,7 @@ def update(self, dt:float=-1) -> None:
if self.p['scheme'] == 'euler_forward':
self.s.update(self.euler_forward())
elif self.p['scheme'] == 'euler_backward':
self.s.update(self.euler_backward())
self.s.update(self.euler_backward(pools=pools))
elif self.p['scheme'] == 'crank_nicolson':
self.s.update(self.crank_nicolson())
else:
Expand Down Expand Up @@ -683,7 +685,7 @@ def euler_forward(self) -> Any:
return solve


def euler_backward(self) -> Any:
def euler_backward(self,pools="") -> Any:
'''Convenience function for implicit solver based on Euler backward scheme
See Also
Expand All @@ -695,7 +697,7 @@ def euler_backward(self) -> Any:
if self.p['solver'].lower() == 'trunk':
solve = self.solve(alpha=1., beta=1)
elif self.p['solver'].lower() == 'pieter':
solve = self.solve_pieter(alpha=1., beta=1)
solve = self.solve_pieter(pools=pools, alpha=1., beta=1)
elif self.p['solver'].lower() == 'steadystate':
solve = self.solve_steadystate()
elif self.p['solver'].lower() == 'steadystatepieter':
Expand Down Expand Up @@ -1837,7 +1839,7 @@ def solve_steadystatepieter(self) -> dict:
w_bed=w_bed)


def solve_pieter(self, alpha:float=.5, beta:float=1.) -> dict:
def solve_pieter(self, pools="", alpha:float=.5, beta:float=1.) -> dict:
'''Implements the explicit Euler forward, implicit Euler backward and semi-implicit Crank-Nicolson numerical schemes
Determines weights of sediment fractions, sediment pickup and
Expand Down Expand Up @@ -2161,7 +2163,46 @@ def solve_pieter(self, alpha:float=.5, beta:float=1.) -> dict:

# solve system with current weights
Ct_i = Ct[:,:,i].flatten()
Ct_i += scipy.sparse.linalg.spsolve(A, yCt_i.flatten())
parallel2 = True
if parallel2 == True:
#domain 1 limit
d1 = np.int((np.floor(yCt_i.shape[0]/2)+1)*yCt_i.shape[1])
#domain 2 limit including 1 row overlap with d1
d2 = np.int((np.floor(yCt_i.shape[0]/2))*yCt_i.shape[1])

#Lets try 4 domains
d1 = np.int((np.floor(yCt_i.shape[0]*1/4)+1)*yCt_i.shape[1])
d2 = np.int((np.floor(yCt_i.shape[0]*1/4))*yCt_i.shape[1])
d3 = np.int((np.floor(yCt_i.shape[0]*2/4)+1)*yCt_i.shape[1])
d4 = np.int((np.floor(yCt_i.shape[0]*2/4))*yCt_i.shape[1])
d5 = np.int((np.floor(yCt_i.shape[0]*3/4)+1)*yCt_i.shape[1])
d6 = np.int((np.floor(yCt_i.shape[0]*3/4))*yCt_i.shape[1])

with Pool(processes=2) as pool:
results = pools.starmap(scipy.sparse.linalg.spsolve,
[(A[0:d1,0:d1], yCt_i.flatten()[0:d1]),
(A[d2:d3,d2:d3], yCt_i.flatten()[d2:d3]),
(A[d4:d5,d4:d5], yCt_i.flatten()[d4:d5]),
(A[d6:,d6:], yCt_i.flatten()[d6:])
])

#now we are going into sequential solve of the subdomains
# Ct_i[0:d1] += scipy.sparse.linalg.spsolve(A[0:d1,0:d1], yCt_i.flatten()[0:d1])
# Ct_i[d2:d3] += scipy.sparse.linalg.spsolve(A[d2:d3,d2:d3], yCt_i.flatten()[d2:d3])
# Ct_i[d4:d5] += scipy.sparse.linalg.spsolve(A[d4:d5,d4:d5], yCt_i.flatten()[d4:d5])
# Ct_i[d6:] += scipy.sparse.linalg.spsolve(A[d6:,d6:], yCt_i.flatten()[d6:])

# Ct_i[0:d1] += results[0]
# Ct_i[d2:d3] += results[1]
# Ct_i[d4:d5] += results[2]
# Ct_i[d6:] += results[3]

# in the current approach, the seam gets overwritten twice. This is incorrect but
# I will proceed for the sake of process optimization
# this needs to be corrected for.
else:
Ct_i += scipy.sparse.linalg.spsolve(A, yCt_i.flatten())

Ct_i = prevent_tiny_negatives(Ct_i, p['max_error'])

# check for negative values
Expand Down Expand Up @@ -2520,10 +2561,12 @@ def run(self, callback=None, restartfile:str=None) -> None:
# start model loop
self.t0 = time.time()
self.output_write()
#pools = ""
pools = Pool(5)
while self.t <= self.p['tstop']:
if callback is not None:
callback(self)
self.update()
self.update(pools)
self.output_write()
self.print_progress()

Expand Down Expand Up @@ -2656,7 +2699,7 @@ def initialize(self) -> None:
super(AeoLiSRunner, self).initialize()
self.output_init()

def update(self, dt:float=-1) -> None:
def update(self, pools="", dt:float=-1) -> None:
'''Time stepping function
Overloads the :func:`~model.AeoLiS.update()` function,
Expand All @@ -2674,7 +2717,7 @@ def update(self, dt:float=-1) -> None:
self.output_clear()
self.clear = False

super(AeoLiSRunner, self).update(dt=dt)
super(AeoLiSRunner, self).update(pools=pools, dt=dt)
self.output_update()


Expand Down
2 changes: 1 addition & 1 deletion aeolis/run_console.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from aeolis.console_debug import aeolis_debug

# configfile = r'c:\Users\weste_bt\aeolis\Tests\RotatingWind\Barchan_Grid270\aeolis.txt'
configfile = r'C:\Data\OneDrive - Lund University\DUNEFORCE\Aeolis\simple_model\testcase\4_test\aeolis.txt'
configfile = r'C:\Users\svries\Documents\GitHub\DCC_Bart\aeolis.txt'
aeolis_debug(configfile)

0 comments on commit 558d884

Please sign in to comment.