Skip to content
This repository has been archived by the owner on Aug 12, 2021. It is now read-only.

Commit

Permalink
Cleanup of AD calls
Browse files Browse the repository at this point in the history
  • Loading branch information
mzszym committed Apr 13, 2019
1 parent 08288c4 commit 3444c02
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 18 deletions.
8 changes: 3 additions & 5 deletions oedes/fvm/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def __init__(self, mesh, name):
self.bc_dof_is_free = np.ones_like(mesh.boundary.idx, dtype=np.bool)
self.boundary_labels = None

def residuals(self, part, facefluxes, celltransient=0., cellsource=0.):
def residuals(self, part, facefluxes, celltransient, cellsource):
"""
Calculate FVM residuals
Expand Down Expand Up @@ -179,10 +179,8 @@ def residuals(self, part, facefluxes, celltransient=0., cellsource=0.):
FdS = 0.
else:
FdS = ad.dot(part.fluxsum, facefluxes)
if not ad.isscalar(celltransient):
celltransient = celltransient[idx]
if not ad.isscalar(cellsource):
cellsource = cellsource[idx]
celltransient = ad.getitem(celltransient, idx)
cellsource = ad.getitem(cellsource, idx)
return self.idx[idx], -FdS / \
part.cells['volume'] + celltransient - cellsource

Expand Down
8 changes: 4 additions & 4 deletions oedes/fvm/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import numpy as np
from .cell import FVMBoundaryEquation, FVMConservationEquation
import scipy.sparse.csgraph
from oedes.ad import sparsesum_bare
from oedes import ad
from oedes.model import model


Expand Down Expand Up @@ -83,7 +83,7 @@ def _init_boundary(self):
(np.ones_like(i), (i, j)), shape=(self.ndof,) * 2)
nlabels, labels = scipy.sparse.csgraph.connected_components(
g, directed=False, return_labels=True)
self.bc_label_volume = sparsesum_bare(nlabels, ((labels[eq.idx], eq.mesh.cells[
self.bc_label_volume = ad.sparsesum_bare(nlabels, ((labels[eq.idx], eq.mesh.cells[
'volume']) for eq in self.equations if isinstance(eq, FVMConservationEquation)))
self.bc_labels = labels
for eq in self._all_conservation():
Expand All @@ -93,15 +93,15 @@ def createContext(self, target):
return FVMEvalContext(self, target)

def finalize(self, target, bc_conservation):
bc_label_conservation = sparsesum_bare(
bc_label_conservation = ad.sparsesum_bare(
len(self.bc_labels), bc_conservation)
for eq in self.equations:
if isinstance(eq, FVMConservationEquation):
bc_free_dof = np.arange(len(eq.mesh.boundary.idx))[
eq.bc_dof_is_free]
i = eq.idx[eq.mesh.boundary.idx[bc_free_dof]]
j = self.bc_labels[i]
yield i, bc_label_conservation[j] * (1. / self.bc_label_volume[j])
yield i, ad.getitem(bc_label_conservation, j) * (1. / self.bc_label_volume[j])

def scaling(self, params):
xscaling = np.ones_like(self.X)
Expand Down
14 changes: 7 additions & 7 deletions oedes/models/equations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#

from oedes.ad import getitem, sparsesum_bare, dot
from oedes import ad
from oedes.utils import EquationWithMesh, Equation
import itertools
import weakref
Expand Down Expand Up @@ -44,15 +44,15 @@ def _evaluate_bc(self, ctx, eq, FdS_boundary,

def residuals(self, ctx, eq, flux, source=None, transient=0.):
variables = ctx.varsOf(eq)
yield eq.residuals(eq.mesh.internal, flux, cellsource=source, celltransient=transient)
yield eq.residuals(eq.mesh.internal, flux, transient, source)
n = len(eq.mesh.boundary.cells)
bc_FdS = sparsesum_bare(n, variables['boundary_FdS'])
bc_source = sparsesum_bare(n, variables['boundary_sources'])
bc_FdS = ad.sparsesum_bare(n, variables['boundary_FdS'])
bc_source = ad.sparsesum_bare(n, variables['boundary_sources'])
if flux is not None:
bc_FdS = dot(eq.mesh.boundary.fluxsum, flux) + bc_FdS
bc_FdS = ad.dot(eq.mesh.boundary.fluxsum, flux) + bc_FdS
if source is not None:
bc_source = getitem(source, eq.mesh.boundary.idx) + bc_source
bc_transient = getitem(transient, eq.mesh.boundary.idx)
bc_source = ad.getitem(source, eq.mesh.boundary.idx) + bc_source
bc_transient = ad.getitem(transient, eq.mesh.boundary.idx)
variables['total_boundary_FdS'] = bc_FdS
variables['total_boundary_sources'] = bc_source
variables['total_boundary_transient'] = bc_transient
Expand Down
4 changes: 2 additions & 2 deletions oedes/models/equations/dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from oedes.utils import SubEquation
from .charged import ChargedSpecies
from oedes.models import solver
from oedes.ad import getitem
from oedes import ad


class BandEnergy(SubEquation):
Expand Down Expand Up @@ -114,7 +114,7 @@ def QuasiFermiLevel(self, ctx, eq):
def _concentration(self, ctx, eq, idx, imref, ref, **kwargs):
assert eq.z in [1, -1]
if idx is not None:
Eband = getitem(ctx.varsOf(eq)[ref], idx)
Eband = ad.getitem(ctx.varsOf(eq)[ref], idx)
else:
Eband = ctx.varsOf(eq)[ref]
return self.c(ctx, eq, -(imref - Eband) / eq.z, **kwargs)
Expand Down

0 comments on commit 3444c02

Please sign in to comment.