## Exploratory code to see if Environment class should have an __eq__ method or not

In [26]:
import numpy as np

class Environment:
    
    def __init__(self, state):
        self.state = state

    def __eq__(self, other):
        """Compare two environments based on their states.
        """
        if isinstance(other, self.__class__):
            comparison = (self.state == other.state)
            if isinstance(comparison, bool):
                return comparison
            elif isinstance(comparison, np.ndarray):
                return comparison.all()
            else:
                try:
                    return all(comparison)
                except TypeError:
                    raise NotImplementedError("Can't compare environments")
        return False

In [120]:
s = [[1, 2, 3], [4, 5, 6]]
e1 = Environment(s)
e2 = Environment(s)

e1 == e2  # True

True

In [121]:
s = 'abcdef'
e1 = Environment(s)
e2 = Environment(s)

e1 == e2  # True

True

In [122]:
s = np.array(range(6))
e1 = Environment(s)
e2 = Environment(s)

e1 == e2  # True

True

In [123]:
s = np.array(range(6)).reshape((2, 3))
e1 = Environment(s)
e2 = Environment(s)

e1 == e2  # True

True

In [124]:
s = np.array(range(27)).reshape((3, 3, 3))
e1 = Environment(s)
e2 = Environment(s)

e1 == e2  # True

True

In [119]:
import numpy as np

def nd_true(nd_object):
    try:
        iterator = iter(nd_object)
    except TypeError:
        return nd_object
    else:
        return all([nd_true(x) for x in iterator])

class Environment:

    def __init__(self, state):
        self.state = state

    def __eq__(self, other):
        """Compare two environments based on their states.
        """
        if isinstance(other, self.__class__):
            return nd_true(self.state == other.state)
        return False

In [132]:
class NonComparable:
    
    def __init__(self, value):
        self.value = value
    
    def __eq__(self, other):
        
        raise NotImplementedError("Can't compare this object")

In [133]:
s = NonComparable(99)

e1 = Environment(s)
e2 = Environment(s)

e1 == e2

NotImplementedError: Can't compare this object

In [157]:
s = np.array(range(6)).reshape((2, 3))
s

array([[0, 1, 2],
       [3, 4, 5]])

In [165]:
s = np.array(range(27)).reshape((3, 3, 3))

In [178]:
from itertools import chain

def flatten(x):
    
    try:
        iterator = iter(x)
    except TypeError:
        return x
    return chain.from_iterable([flatten(i) for i in iterator])

flatten(s == s)

<itertools.chain at 0x11805aeb8>

In [248]:
def flatten(arr):
    for i in arr:
        if isinstance(i, list):
            yield from flatten(i)
        else:
            yield i

In [256]:
def flatten2(arr):
    for i in arr:
        try:
            iterator = iter(i)
        except TypeError:
            yield i
        else:
            yield from flatten2(iterator)


In [309]:
def nd_true(nd_object):
    try:
        iterator = iter(nd_object)
    except TypeError:
        return nd_object
    else:
        return all([nd_true(x) for x in iterator])

In [292]:
list(np.ones((2, 2)))

[array([1., 1.]), array([1., 1.])]

In [289]:
all((a == b).tolist())

True

In [295]:
bool([True, True, False])

True

In [296]:
all([[[False], [False]], [[False], [False]]])

True

In [280]:
def nd_true2(nd_object):
    try:
        return all(nd_object)
    except ValueError:
        return all([nd_true2(x) for x in nd_object])

In [282]:
assert all(flatten((a == b).tolist()))
assert all(flatten2((a == b))) is True
assert nd_true((a == b)) is True
assert nd_true2((a == b)) is True
assert np.array_equal(a, b)

In [326]:
def nd_equal(a, b):
    try:
        iterator = zip(a, b)
    except TypeError:
        return a == b
    else:
        return all([nd_equal(a, b) for a, b in iterator])

In [328]:
a = [1, 2, np.array([3, 4, 5]), 6]
b = [1, 2, np.array([3, 4, 5]), 6]

assert nd_equal(a, b) is True

b[2][1] += 1

assert nd_equal(a, b) is False

In [329]:
a = np.random.random((5, 5, 5))
b = a.copy()

In [265]:
%timeit all(flatten((a == b).tolist()))

38.1 µs ± 1.5 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [266]:
%timeit all(flatten2((a == b)))

110 µs ± 523 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [267]:
%timeit nd_true((a == b))

116 µs ± 654 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [283]:
%timeit nd_true2((a == b))

37.7 µs ± 852 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [268]:
%timeit np.array_equal(a, b)

4.27 µs ± 96.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [270]:
%timeit a.tolist() == b.tolist()

8 µs ± 168 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [330]:
%timeit nd_equal(a, b)

188 µs ± 712 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [305]:
b[2][2] = 10
b

[1, 2, array([ 3,  4, 10]), 6]

In [306]:
a == b

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

In [366]:
import collections
import six

def iterable(arg):
    return (
        isinstance(arg, collections.Iterable) 
        and not isinstance(arg, six.string_types)
    )

def is_equal(a, b):
    
    if iterable(a):
        return all([is_equal(a, b) for a, b in zip(a, b)])
    else:
        return a == b

class Environment:

    def __init__(self, state):
        self.state = state

    def __eq__(self, other):
        """Compare two environments based on their states.
        """
        if isinstance(other, self.__class__):
            return is_equal(self.state, other.state)
        return False


In [389]:
# Tests    
s = 'abcdef'
e1 = Environment(s)
e2 = Environment('abcdef')

e1 == e2  # True

True

In [390]:
s = [[1, 2, 3], [4, 5, 6]]
e1 = Environment(s)
e2 = Environment(s.copy())

e1 == e2  # True

True

In [391]:
s = np.array([1, 2, 3])
e1 = Environment(s)
e2 = Environment(s.copy())

e1 == e2  # True

True

In [392]:
s = np.array(range(6)).reshape((2, 3))
e1 = Environment(s)
e2 = Environment(s.copy())

e1 == e2  # True

True

In [393]:
s = np.array(range(27)).reshape((3, 3, 3))
e1 = Environment(s)
e2 = Environment(s.copy())

e1 == e2  # True

True

In [394]:
s = [1, 2, np.array(range(4)).reshape((2, 2)), {6, 7}]
e1 = Environment(s)
e2 = Environment(s.copy())

e1 == e2  # True

True

In [395]:
a = [1, 2, np.array(range(4)).reshape((2, 2)), {6, 7}]
b = [1, 2, np.array(range(4)).reshape((2, 2)), {6, 7}]

assert is_equal(a, b) is True

b[2][1, 0] += 1

assert is_equal(a, b) is False

In [396]:
a = np.random.random((5, 5, 5))
b = a.copy()

%timeit is_equal(a, b)

174 µs ± 415 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
