Skip to content

Commit

Permalink
Python 3 support (#7)
Browse files Browse the repository at this point in the history
# Conflicts:
#	.circleci/config.yml
#	Makefile
  • Loading branch information
paulmelnikow committed Jul 15, 2018
1 parent 58a76a8 commit db6eaf8
Show file tree
Hide file tree
Showing 20 changed files with 131 additions and 125 deletions.
2 changes: 1 addition & 1 deletion Makefile
Expand Up @@ -11,7 +11,7 @@ tidy:

test: clean qtest
qtest: all
python -m unittest discover -s chumpy
python -m unittest

coverage: clean qcov
qcov: all
Expand Down
24 changes: 12 additions & 12 deletions chumpy/__init__.py
@@ -1,12 +1,12 @@
from ch import *
from logic import *
from .ch import *
from .logic import *

from optimization import minimize
import extras
import testing
from version import version as __version__
from .optimization import minimize
from . import extras
from . import testing
from .version import version as __version__

from version import version as __version__
from .version import version as __version__

from numpy import bool, int, float, complex, object, unicode, str, nan, inf

Expand Down Expand Up @@ -107,12 +107,12 @@ def compute_dr_wrt(self, wrt):

def demo(which=None):
if which not in demos:
print 'Please indicate which demo you want, as follows:'
print('Please indicate which demo you want, as follows:')
for key in demos:
print "\tdemo('%s')" % (key,)
print("\tdemo('%s')" % (key,))
return

print '- - - - - - - - - - - <CODE> - - - - - - - - - - - -'
print demos[which]
print '- - - - - - - - - - - </CODE> - - - - - - - - - - - -\n'
print('- - - - - - - - - - - <CODE> - - - - - - - - - - - -')
print(demos[which])
print('- - - - - - - - - - - </CODE> - - - - - - - - - - - -\n')
exec('global np\n' + demos[which], globals(), locals())
14 changes: 7 additions & 7 deletions chumpy/api_compatibility.py
Expand Up @@ -5,13 +5,13 @@
"""


import ch
from . import ch
import numpy as np
from os.path import join, split
from StringIO import StringIO
from six import StringIO
import numpy
import chumpy
import cPickle as pickle
from six.moves import cPickle as pickle

src = ''
num_passed = 0
Expand Down Expand Up @@ -71,7 +71,7 @@ def r(fn_name, args_req, args_opt, nplib=numpy, chlib=chumpy):

try:
if isinstance(args_req, dict):
fn(**dict(args_req.items() + args_opt.items()))
fn(**dict(list(args_req.items()) + list(args_opt.items())))
else:
fn(*args_req, **args_opt)
if lib is chlib:
Expand Down Expand Up @@ -108,7 +108,7 @@ def append(a, b, c):
b_color = lookup[b] if b in lookup else 'white'
c_color = lookup[c] if c in lookup else 'white'

print '%s: %s, %s' % (a,b,c)
print('%s: %s, %s' % (a,b,c))
make_row(a, b, c, b_color, c_color)

def m(s):
Expand Down Expand Up @@ -524,11 +524,11 @@ def main():
src = '<html><body><table border=1>' + src + '</table></body></html>'
open(join(split(__file__)[0], 'api_compatibility.html'), 'w').write(src)

print 'passed %d, not passed %d' % (num_passed, num_not_passed)
print('passed %d, not passed %d' % (num_passed, num_not_passed))



if __name__ == '__main__':
global which_passed
main()
print ' '.join(which_passed)
print(' '.join(which_passed))
58 changes: 30 additions & 28 deletions chumpy/ch.py
Expand Up @@ -18,11 +18,12 @@
import copy as external_copy
from functools import wraps
from scipy.sparse.linalg.interface import LinearOperator
import utils
from utils import row, col
from . import utils
from .utils import row, col
import collections
from copy import deepcopy
from utils import timer
from .utils import timer
from functools import reduce


# Turn this on if you want the profiler injected
Expand Down Expand Up @@ -102,7 +103,7 @@ def __new__(cls, *args, **kwargs):
object.__setattr__(result, '_cache_info', {})
object.__setattr__(result, '_status', 'new')

for name, default_val in cls._default_kwargs.items():
for name, default_val in list(cls._default_kwargs.items()):
object.__setattr__(result, '_%s' % name, kwargs.get(name, default_val))
if name in kwargs:
del kwargs[name]
Expand Down Expand Up @@ -325,7 +326,7 @@ def __len__(self):
return len(self.r)

def minimize(self, *args, **kwargs):
import optimization
from . import optimization
return optimization.minimize(self, *args, **kwargs)

def __array__(self, *args):
Expand Down Expand Up @@ -408,7 +409,7 @@ def __setattr__(self, name, value, itr=None):

def _invalidate_cacheprop_names(self, names):
nameset = set(names)
for func_name, v in self._depends_on_deps.items():
for func_name, v in list(self._depends_on_deps.items()):
if len(nameset.intersection(v['deps'])) > 0:
v['out_of_date'] = True

Expand All @@ -427,7 +428,7 @@ def clear_cache(self, itr=None):
next._cache['drs'].clear()
next._itr = itr

for parent, parent_dict in next._parents.items():
for parent, parent_dict in list(next._parents.items()):
object.__setattr__(parent, '_dirty_vars', parent._dirty_vars.union(parent_dict['varnames']))
parent._invalidate_cacheprop_names(parent_dict['varnames'])
todo.append(parent)
Expand All @@ -436,14 +437,14 @@ def clear_cache(self, itr=None):


def clear_cache_wrt(self, wrt, itr=None):
if self._cache['drs'].has_key(wrt):
if wrt in self._cache['drs']:
self._cache['drs'][wrt] = None

if hasattr(self, 'dr_cached') and wrt in self.dr_cached:
self.dr_cached[wrt] = None

if itr is None or itr != self._itr:
for parent, parent_dict in self._parents.items():
for parent, parent_dict in list(self._parents.items()):
if wrt in parent._cache['drs'] or (hasattr(parent, 'dr_cached') and wrt in parent.dr_cached):
parent.clear_cache_wrt(wrt=wrt, itr=itr)
object.__setattr__(parent, '_dirty_vars', parent._dirty_vars.union(parent_dict['varnames']))
Expand Down Expand Up @@ -662,7 +663,7 @@ def lmult_wrt(self, lhs, wrt):

if hasattr(p, 'dterms') and p is not wrt and p.is_dr_wrt(wrt):
if not isinstance(p, Ch):
print 'BROKEN!'
print('BROKEN!')
raise Exception('Broken Should be Ch object')

indirect_dr = p.lmult_wrt(self._superdot(lhs, self._compute_dr_wrt_sliced(p)), wrt)
Expand Down Expand Up @@ -817,7 +818,7 @@ def dr_wrt(self, wrt, reverse_mode=False, profiler=None):
# If we *always* filled in the cache, it would require
# more memory but would occasionally save a little cpu,
# on average.
if len(self._parents.keys()) != 1:
if len(list(self._parents.keys())) != 1:
self._cache['drs'][wrt] = result

if DEBUG:
Expand Down Expand Up @@ -895,7 +896,7 @@ def string_for(self, my_name):
color = 'blue'
if isinstance(dtval, reordering.Concatenate) and len(dtval.dr_cached) > 0:
s = 'dr_cached\n'
for k, v in dtval.dr_cached.iteritems():
for k, v in dtval.dr_cached.items():
if v is not None:
issparse = sp.issparse(v)
size = v.size
Expand All @@ -913,7 +914,7 @@ def string_for(self, my_name):
elif len(dtval._cache['drs']) > 0:
s = '_cache\n'

for k, v in dtval._cache['drs'].iteritems():
for k, v in dtval._cache['drs'].items():
if v is not None:
issparse = sp.issparse(v)
size = v.size
Expand Down Expand Up @@ -1029,8 +1030,8 @@ def string_for(self, my_name):
result += string_for(getattr(self, dterm), dterm)

if cachelim != np.inf and hasattr(self, '_cache') and 'drs' in self._cache:
import cPickle as pickle
for dtval, jac in self._cache['drs'].items():
from six.moves import cPickle as pickle
for dtval, jac in list(self._cache['drs'].items()):
# child_label = getattr(dtval, 'label') if hasattr(dtval, 'label') else dterm
# child_label = '%s (%s)' % (child_label, str(dtval.__class__.__name__))
src = 'aaa%d' % (id(self))
Expand Down Expand Up @@ -1143,6 +1144,7 @@ def __mul__ (self, other): return ch_ops.multiply(a=self, b=other)
def __rmul__(self, other): return ch_ops.multiply(a=other, b=self)

def __div__ (self, other): return ch_ops.divide(x1=self, x2=other)
def __truediv__ (self, other): return ch_ops.divide(x1=self, x2=other)
def __rdiv__(self, other): return ch_ops.divide(x1=other, x2=self)

def __pow__ (self, other): return ch_ops.power(x=self, pow=other)
Expand Down Expand Up @@ -1203,7 +1205,7 @@ def _depends_on(func):

@wraps(func)
def with_caching(self, *args, **kwargs):
func_name = func.func_name
func_name = func.__name__
sdf = self._depends_on_deps[func_name]
if sdf['out_of_date'] == True:
#tm = time.time()
Expand Down Expand Up @@ -1248,7 +1250,7 @@ def __init__(self, lmb, initial_args=None):
if initial_arg in args:
args[initial_arg].x = initial_args[initial_arg]
result = lmb(**args)
for argname, arg in args.items():
for argname, arg in list(args.items()):
if result.is_dr_wrt(arg.x):
self.add_dterm(argname, arg.x)
else:
Expand Down Expand Up @@ -1289,8 +1291,8 @@ def on_changed(self, which):
# it would be better if they could be "internal" as well, but for now the idea
# is that result may itself be a ChLambda.
def __init__(self, result, args):
self.args = { argname: ChHandle(x=arg) for argname, arg in args.items() }
for argname, arg in self.args.items():
self.args = { argname: ChHandle(x=arg) for argname, arg in list(args.items()) }
for argname, arg in list(self.args.items()):
setattr(result, argname, arg)
if result.is_dr_wrt(arg.x):
self.add_dterm(argname, arg.x)
Expand All @@ -1305,17 +1307,17 @@ def compute_r(self):
def compute_dr_wrt(self, wrt):
return self._result.dr_wrt(wrt)

import ch_ops
from ch_ops import *
from . import ch_ops
from .ch_ops import *
__all__ += ch_ops.__all__

import reordering
from reordering import *
from . import reordering
from .reordering import *
__all__ += reordering.__all__


import linalg
import ch_random as random
from . import linalg
from . import ch_random as random
__all__ += ['linalg', 'random']


Expand All @@ -1339,16 +1341,16 @@ def main():
x30 = Ch(30)

tmp = ChLambda(lambda x, y, z: Ch(1) + Ch(2) * Ch(3) + 4)
print tmp.dr_wrt(tmp.x)
print(tmp.dr_wrt(tmp.x))
import pdb; pdb.set_trace()
#a(b(c(d(e(f),g),h)))

blah = tst(x10, x20, x30)

print blah.r
print(blah.r)


print foo
print(foo)

import pdb; pdb.set_trace()

Expand Down
12 changes: 7 additions & 5 deletions chumpy/ch_ops.py
Expand Up @@ -44,13 +44,15 @@
__all__ += numpy_array_creation_routines


import ch
from . import ch
import six
import numpy as np
import warnings
import cPickle as pickle
from six.moves import cPickle as pickle
import scipy.sparse as sp
from utils import row, col
from .utils import row, col
from copy import copy as copy_copy
from functools import reduce

__all__ += ['pi', 'set_printoptions']
pi = np.pi
Expand Down Expand Up @@ -645,8 +647,8 @@ def _stride_for_axis(self,axis, mtx):
# np.amin here probably
idxs = np.arange(mtx.size).reshape(mtx.shape)
mn = np.amin(idxs, axis=axis)
stride = np.array(mtx.strides)
stride /= np.min(stride) # go from bytes to num elements
mtx_strides = np.array(mtx.strides)
stride = mtx_strides / np.min(mtx_strides) # go from bytes to num elements
stride = stride[axis]
return mn, stride

Expand Down
2 changes: 1 addition & 1 deletion chumpy/ch_random.py
Expand Up @@ -5,7 +5,7 @@
"""

import numpy.random
from ch import Ch
from .ch import Ch

api_not_implemented = ['choice','bytes','shuffle','permutation']

Expand Down
4 changes: 2 additions & 2 deletions chumpy/extras.py
@@ -1,8 +1,8 @@
__author__ = 'matt'

import ch
from . import ch
import numpy as np
from utils import row, col
from .utils import row, col
import scipy.sparse as sp
import scipy.special

Expand Down
18 changes: 9 additions & 9 deletions chumpy/linalg.py
Expand Up @@ -12,9 +12,9 @@

import numpy as np
import scipy.sparse as sp
from ch import Ch, depends_on, NanDivide
from utils import row, col
import ch
from .ch import Ch, depends_on, NanDivide
from .utils import row, col
from . import ch


try:
Expand Down Expand Up @@ -284,17 +284,17 @@ def slogdet(*args):
def main():

tmp = ch.random.randn(100).reshape((10,10))
print 'chumpy version: ' + str(slogdet(tmp)[1].r)
print 'old version:' + str(np.linalg.slogdet(tmp.r)[1])
print('chumpy version: ' + str(slogdet(tmp)[1].r))
print('old version:' + str(np.linalg.slogdet(tmp.r)[1]))

eps = 1e-10
diff = np.random.rand(100) * eps
diff_reshaped = diff.reshape((10,10))
print np.linalg.slogdet(tmp.r+diff_reshaped)[1] - np.linalg.slogdet(tmp.r)[1]
print slogdet(tmp)[1].dr_wrt(tmp).dot(diff)
print(np.linalg.slogdet(tmp.r+diff_reshaped)[1] - np.linalg.slogdet(tmp.r)[1])
print(slogdet(tmp)[1].dr_wrt(tmp).dot(diff))

print np.linalg.slogdet(tmp.r)[0]
print slogdet(tmp)[0]
print(np.linalg.slogdet(tmp.r)[0])
print(slogdet(tmp)[0])

if __name__ == '__main__':
main()
Expand Down

0 comments on commit db6eaf8

Please sign in to comment.