Skip to content

Commit

Permalink
Added Routine Data Structures
Browse files Browse the repository at this point in the history
These are used for describe the computation wanted for the generated
code. The main type is `Routine`, which takes in a name, arguments, and
results. An easy interface `routine` is also created that makes common
assumptions about what the user wants. This allows finer control for
users that want it, while still having an easy interface.

Routines cannot be used in expression trees, but they can be called as a
function, resulting in a `RoutineCall` type. This type has two
attributes: `returns` and `inplace`. The `returns` attr returns a
`RoutineCallResult` representing the result for what is directly
returned. A `Tuple` of these is returned if there is more than one
result. The `inplace` attr returns a dict of `symbol` ->
`RoutineCallResult`, which represents results returned "inplace".

The `RoutineCallResult` types "alias" all assumptions to what the
underlying computation represented would be. This means they can be used
inside sympy expressions. This design makes it easy to call generated
routines from other routines, while keeping everything modular and
expressive.

At this point only `ScalarRoutineCallResult` is implemented.
Corresponding types for Matrices (and later Indexed) still need to be
done.
  • Loading branch information
jcrist committed Sep 29, 2014
1 parent 270e3f9 commit 668da9a
Show file tree
Hide file tree
Showing 9 changed files with 471 additions and 336 deletions.
2 changes: 1 addition & 1 deletion bin/test
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ import pytest
import sys

path_hack()
sys.exit(pytest.main())
sys.exit(pytest.main('-s'))
3 changes: 0 additions & 3 deletions symcc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,2 @@
from .generators import *
from .printers import *
from .utilities import *
from .types import *
from .wrappers import *
2 changes: 1 addition & 1 deletion symcc/printers/codeprinter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from sympy.printing.str import StrPrinter
from sympy.printing.precedence import precedence

from symcc.types.routines import Assign, AssignmentError
from symcc.types.ast import Assign

__all__ = ["CodePrinter"]

Expand Down
1 change: 0 additions & 1 deletion symcc/printers/tests/test_ccode.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

x, y, z = symbols('x, y, z')
a, b, c = symbols('a, b, c')
some_long_name = 'yaaaaaaa'


def test_printmethod():
Expand Down
54 changes: 38 additions & 16 deletions symcc/types/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
-------------
*Basic*
|--->Assign
|--->AugAssign
|--->NativeOp
|--->Assign
|--->AugAssign
|--->NativeOp
| |--------------|
| |--->AddOp
| |--->SubOp
Expand Down Expand Up @@ -259,7 +259,7 @@ class For(Basic):
iter : iterable
body : sympy expr
"""

def __new__(cls, target, iter, body):
target = _sympify(target)
if not iterable(iter):
Expand Down Expand Up @@ -293,43 +293,64 @@ class DataType(with_metaclass(Singleton, Basic)):


class NativeBool(DataType):
_name = 'Bool'
pass


class NativeInteger(DataType):
_name = 'Int'
pass


class NativeFloat(DataType):
_name = 'Float'
pass


class NativeDouble(DataType):
_name = 'Double'
pass


class NativeVoid(DataType):
_name = 'Void'
pass


dtype_registry = {'bool': NativeBool(),
'int': NativeInteger(),
'float': NativeFloat(),
'double': NativeDouble(),
'void': NativeVoid()}
Bool = NativeBool()
Int = NativeInteger()
Float = NativeFloat()
Double = NativeDouble()
Void = NativeVoid()


dtype_registry = {'bool': Bool,
'int': Int,
'float': Float,
'double': Double,
'void': Void}


def datatype(dtype):
"""Returns the datatype singleton for the given dtype"""

if dtype.lower() not in dtype_registry:
raise ValueError("Unrecognized datatype " + dtype)
return dtype_registry[dtype]
if isinstance(dtype, str):
if dtype.lower() not in dtype_registry:
raise ValueError("Unrecognized datatype " + dtype)
return dtype_registry[dtype]
else:
dtype = _sympify(dtype)
if dtype.is_integer:
return dtype_registry['int']
elif dtype.is_Boolean:
return dtype_registry['bool']
else:
return dtype_registry['double']


class Variable(Basic):
"""Represents a typed variable.
Parameters
----------
name : Symbol, MatrixSymbol
Expand Down Expand Up @@ -444,9 +465,10 @@ class FunctionDef(Basic):

def __new__(cls, name, args, body, results):
# name
if not isinstance(name, str):
raise TypeError("Function name must be string")
name = Symbol(name)
if isinstance(name, str):
name = Symbol(name)
elif not isinstance(name, Symbol):
raise TypeError("Function name must be Symbol or string")
# args
if not iterable(args):
raise TypeError("args must be an iterable")
Expand Down
Loading

0 comments on commit 668da9a

Please sign in to comment.