Skip to content
Permalink
Browse files

Python 3 support (#7)

# Conflicts:
#	.circleci/config.yml
#	Makefile
  • Loading branch information...
paulmelnikow committed Jul 15, 2018
1 parent 58a76a8 commit db6eaf8c93eb5ae571eb054575fb6ecec62fd86d
@@ -11,7 +11,7 @@ tidy:

test: clean qtest
qtest: all
python -m unittest discover -s chumpy
python -m unittest

coverage: clean qcov
qcov: all
@@ -1,12 +1,12 @@
from ch import *
from logic import *
from .ch import *
from .logic import *

from optimization import minimize
import extras
import testing
from version import version as __version__
from .optimization import minimize
from . import extras
from . import testing
from .version import version as __version__

from version import version as __version__
from .version import version as __version__

from numpy import bool, int, float, complex, object, unicode, str, nan, inf

@@ -107,12 +107,12 @@ def compute_dr_wrt(self, wrt):

def demo(which=None):
if which not in demos:
print 'Please indicate which demo you want, as follows:'
print('Please indicate which demo you want, as follows:')
for key in demos:
print "\tdemo('%s')" % (key,)
print("\tdemo('%s')" % (key,))
return

print '- - - - - - - - - - - <CODE> - - - - - - - - - - - -'
print demos[which]
print '- - - - - - - - - - - </CODE> - - - - - - - - - - - -\n'
print('- - - - - - - - - - - <CODE> - - - - - - - - - - - -')
print(demos[which])
print('- - - - - - - - - - - </CODE> - - - - - - - - - - - -\n')
exec('global np\n' + demos[which], globals(), locals())
@@ -5,13 +5,13 @@
"""


import ch
from . import ch
import numpy as np
from os.path import join, split
from StringIO import StringIO
from six import StringIO
import numpy
import chumpy
import cPickle as pickle
from six.moves import cPickle as pickle

src = ''
num_passed = 0
@@ -71,7 +71,7 @@ def r(fn_name, args_req, args_opt, nplib=numpy, chlib=chumpy):

try:
if isinstance(args_req, dict):
fn(**dict(args_req.items() + args_opt.items()))
fn(**dict(list(args_req.items()) + list(args_opt.items())))
else:
fn(*args_req, **args_opt)
if lib is chlib:
@@ -108,7 +108,7 @@ def append(a, b, c):
b_color = lookup[b] if b in lookup else 'white'
c_color = lookup[c] if c in lookup else 'white'

print '%s: %s, %s' % (a,b,c)
print('%s: %s, %s' % (a,b,c))
make_row(a, b, c, b_color, c_color)

def m(s):
@@ -524,11 +524,11 @@ def main():
src = '<html><body><table border=1>' + src + '</table></body></html>'
open(join(split(__file__)[0], 'api_compatibility.html'), 'w').write(src)

print 'passed %d, not passed %d' % (num_passed, num_not_passed)
print('passed %d, not passed %d' % (num_passed, num_not_passed))



if __name__ == '__main__':
global which_passed
main()
print ' '.join(which_passed)
print(' '.join(which_passed))
@@ -18,11 +18,12 @@
import copy as external_copy
from functools import wraps
from scipy.sparse.linalg.interface import LinearOperator
import utils
from utils import row, col
from . import utils
from .utils import row, col
import collections
from copy import deepcopy
from utils import timer
from .utils import timer
from functools import reduce


# Turn this on if you want the profiler injected
@@ -102,7 +103,7 @@ def __new__(cls, *args, **kwargs):
object.__setattr__(result, '_cache_info', {})
object.__setattr__(result, '_status', 'new')

for name, default_val in cls._default_kwargs.items():
for name, default_val in list(cls._default_kwargs.items()):
object.__setattr__(result, '_%s' % name, kwargs.get(name, default_val))
if name in kwargs:
del kwargs[name]
@@ -325,7 +326,7 @@ def __len__(self):
return len(self.r)

def minimize(self, *args, **kwargs):
import optimization
from . import optimization
return optimization.minimize(self, *args, **kwargs)

def __array__(self, *args):
@@ -408,7 +409,7 @@ def __setattr__(self, name, value, itr=None):

def _invalidate_cacheprop_names(self, names):
nameset = set(names)
for func_name, v in self._depends_on_deps.items():
for func_name, v in list(self._depends_on_deps.items()):
if len(nameset.intersection(v['deps'])) > 0:
v['out_of_date'] = True

@@ -427,7 +428,7 @@ def clear_cache(self, itr=None):
next._cache['drs'].clear()
next._itr = itr

for parent, parent_dict in next._parents.items():
for parent, parent_dict in list(next._parents.items()):
object.__setattr__(parent, '_dirty_vars', parent._dirty_vars.union(parent_dict['varnames']))
parent._invalidate_cacheprop_names(parent_dict['varnames'])
todo.append(parent)
@@ -436,14 +437,14 @@ def clear_cache(self, itr=None):


def clear_cache_wrt(self, wrt, itr=None):
if self._cache['drs'].has_key(wrt):
if wrt in self._cache['drs']:
self._cache['drs'][wrt] = None

if hasattr(self, 'dr_cached') and wrt in self.dr_cached:
self.dr_cached[wrt] = None

if itr is None or itr != self._itr:
for parent, parent_dict in self._parents.items():
for parent, parent_dict in list(self._parents.items()):
if wrt in parent._cache['drs'] or (hasattr(parent, 'dr_cached') and wrt in parent.dr_cached):
parent.clear_cache_wrt(wrt=wrt, itr=itr)
object.__setattr__(parent, '_dirty_vars', parent._dirty_vars.union(parent_dict['varnames']))
@@ -662,7 +663,7 @@ def lmult_wrt(self, lhs, wrt):

if hasattr(p, 'dterms') and p is not wrt and p.is_dr_wrt(wrt):
if not isinstance(p, Ch):
print 'BROKEN!'
print('BROKEN!')
raise Exception('Broken Should be Ch object')

indirect_dr = p.lmult_wrt(self._superdot(lhs, self._compute_dr_wrt_sliced(p)), wrt)
@@ -817,7 +818,7 @@ def dr_wrt(self, wrt, reverse_mode=False, profiler=None):
# If we *always* filled in the cache, it would require
# more memory but would occasionally save a little cpu,
# on average.
if len(self._parents.keys()) != 1:
if len(list(self._parents.keys())) != 1:
self._cache['drs'][wrt] = result

if DEBUG:
@@ -895,7 +896,7 @@ def string_for(self, my_name):
color = 'blue'
if isinstance(dtval, reordering.Concatenate) and len(dtval.dr_cached) > 0:
s = 'dr_cached\n'
for k, v in dtval.dr_cached.iteritems():
for k, v in dtval.dr_cached.items():
if v is not None:
issparse = sp.issparse(v)
size = v.size
@@ -913,7 +914,7 @@ def string_for(self, my_name):
elif len(dtval._cache['drs']) > 0:
s = '_cache\n'

for k, v in dtval._cache['drs'].iteritems():
for k, v in dtval._cache['drs'].items():
if v is not None:
issparse = sp.issparse(v)
size = v.size
@@ -1029,8 +1030,8 @@ def string_for(self, my_name):
result += string_for(getattr(self, dterm), dterm)

if cachelim != np.inf and hasattr(self, '_cache') and 'drs' in self._cache:
import cPickle as pickle
for dtval, jac in self._cache['drs'].items():
from six.moves import cPickle as pickle
for dtval, jac in list(self._cache['drs'].items()):
# child_label = getattr(dtval, 'label') if hasattr(dtval, 'label') else dterm
# child_label = '%s (%s)' % (child_label, str(dtval.__class__.__name__))
src = 'aaa%d' % (id(self))
@@ -1143,6 +1144,7 @@ def __mul__ (self, other): return ch_ops.multiply(a=self, b=other)
def __rmul__(self, other): return ch_ops.multiply(a=other, b=self)

def __div__ (self, other): return ch_ops.divide(x1=self, x2=other)
def __truediv__ (self, other): return ch_ops.divide(x1=self, x2=other)
def __rdiv__(self, other): return ch_ops.divide(x1=other, x2=self)

def __pow__ (self, other): return ch_ops.power(x=self, pow=other)
@@ -1203,7 +1205,7 @@ def _depends_on(func):

@wraps(func)
def with_caching(self, *args, **kwargs):
func_name = func.func_name
func_name = func.__name__
sdf = self._depends_on_deps[func_name]
if sdf['out_of_date'] == True:
#tm = time.time()
@@ -1248,7 +1250,7 @@ def __init__(self, lmb, initial_args=None):
if initial_arg in args:
args[initial_arg].x = initial_args[initial_arg]
result = lmb(**args)
for argname, arg in args.items():
for argname, arg in list(args.items()):
if result.is_dr_wrt(arg.x):
self.add_dterm(argname, arg.x)
else:
@@ -1289,8 +1291,8 @@ def on_changed(self, which):
# it would be better if they could be "internal" as well, but for now the idea
# is that result may itself be a ChLambda.
def __init__(self, result, args):
self.args = { argname: ChHandle(x=arg) for argname, arg in args.items() }
for argname, arg in self.args.items():
self.args = { argname: ChHandle(x=arg) for argname, arg in list(args.items()) }
for argname, arg in list(self.args.items()):
setattr(result, argname, arg)
if result.is_dr_wrt(arg.x):
self.add_dterm(argname, arg.x)
@@ -1305,17 +1307,17 @@ def compute_r(self):
def compute_dr_wrt(self, wrt):
return self._result.dr_wrt(wrt)

import ch_ops
from ch_ops import *
from . import ch_ops
from .ch_ops import *
__all__ += ch_ops.__all__

import reordering
from reordering import *
from . import reordering
from .reordering import *
__all__ += reordering.__all__


import linalg
import ch_random as random
from . import linalg
from . import ch_random as random
__all__ += ['linalg', 'random']


@@ -1339,16 +1341,16 @@ def main():
x30 = Ch(30)

tmp = ChLambda(lambda x, y, z: Ch(1) + Ch(2) * Ch(3) + 4)
print tmp.dr_wrt(tmp.x)
print(tmp.dr_wrt(tmp.x))
import pdb; pdb.set_trace()
#a(b(c(d(e(f),g),h)))

blah = tst(x10, x20, x30)

print blah.r
print(blah.r)


print foo
print(foo)

import pdb; pdb.set_trace()

@@ -44,13 +44,15 @@
__all__ += numpy_array_creation_routines


import ch
from . import ch
import six
import numpy as np
import warnings
import cPickle as pickle
from six.moves import cPickle as pickle
import scipy.sparse as sp
from utils import row, col
from .utils import row, col
from copy import copy as copy_copy
from functools import reduce

__all__ += ['pi', 'set_printoptions']
pi = np.pi
@@ -645,8 +647,8 @@ def _stride_for_axis(self,axis, mtx):
# np.amin here probably
idxs = np.arange(mtx.size).reshape(mtx.shape)
mn = np.amin(idxs, axis=axis)
stride = np.array(mtx.strides)
stride /= np.min(stride) # go from bytes to num elements
mtx_strides = np.array(mtx.strides)
stride = mtx_strides / np.min(mtx_strides) # go from bytes to num elements
stride = stride[axis]
return mn, stride

@@ -5,7 +5,7 @@
"""

import numpy.random
from ch import Ch
from .ch import Ch

api_not_implemented = ['choice','bytes','shuffle','permutation']

@@ -1,8 +1,8 @@
__author__ = 'matt'

import ch
from . import ch
import numpy as np
from utils import row, col
from .utils import row, col
import scipy.sparse as sp
import scipy.special

@@ -12,9 +12,9 @@

import numpy as np
import scipy.sparse as sp
from ch import Ch, depends_on, NanDivide
from utils import row, col
import ch
from .ch import Ch, depends_on, NanDivide
from .utils import row, col
from . import ch


try:
@@ -284,17 +284,17 @@ def slogdet(*args):
def main():

tmp = ch.random.randn(100).reshape((10,10))
print 'chumpy version: ' + str(slogdet(tmp)[1].r)
print 'old version:' + str(np.linalg.slogdet(tmp.r)[1])
print('chumpy version: ' + str(slogdet(tmp)[1].r))
print('old version:' + str(np.linalg.slogdet(tmp.r)[1]))

eps = 1e-10
diff = np.random.rand(100) * eps
diff_reshaped = diff.reshape((10,10))
print np.linalg.slogdet(tmp.r+diff_reshaped)[1] - np.linalg.slogdet(tmp.r)[1]
print slogdet(tmp)[1].dr_wrt(tmp).dot(diff)
print(np.linalg.slogdet(tmp.r+diff_reshaped)[1] - np.linalg.slogdet(tmp.r)[1])
print(slogdet(tmp)[1].dr_wrt(tmp).dot(diff))

print np.linalg.slogdet(tmp.r)[0]
print slogdet(tmp)[0]
print(np.linalg.slogdet(tmp.r)[0])
print(slogdet(tmp)[0])

if __name__ == '__main__':
main()
Oops, something went wrong.

0 comments on commit db6eaf8

Please sign in to comment.
You can’t perform that action at this time.