Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/Windows_CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
python -m pip install --upgrade pip
python -m pip install flake8 pytest
python -m pip install numpy==1.21.0
python -m pip install "jax[cpu]==0.3.5" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
python -m pip install "jax[cpu]==0.3.2" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
python -m pip install -r requirements-win.txt
python -m pip install tqdm brainpylib
python setup.py install
Expand Down
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
publishment.md
#experimental/
.vscode

io_test_tmp*

brainpy/base/tests/io_test_tmp*

Expand Down
34 changes: 17 additions & 17 deletions brainpy/math/jaxarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def __sub__(self, oc):
return JaxArray(self._value - (oc._value if isinstance(oc, JaxArray) else oc))

def __rsub__(self, oc):
return JaxArray(self._value - (oc._value if isinstance(oc, JaxArray) else oc))
return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) - self._value)

def __isub__(self, oc):
# a -= b
Expand All @@ -249,7 +249,7 @@ def __mul__(self, oc):
return JaxArray(self._value * (oc._value if isinstance(oc, JaxArray) else oc))

def __rmul__(self, oc):
return JaxArray(self._value * (oc._value if isinstance(oc, JaxArray) else oc))
return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) * self._value)

def __imul__(self, oc):
# a *= b
Expand All @@ -258,17 +258,17 @@ def __imul__(self, oc):
self._value = self._value * (oc._value if isinstance(oc, JaxArray) else oc)
return self

def __div__(self, oc):
return JaxArray(self._value / (oc._value if isinstance(oc, JaxArray) else oc))
# def __div__(self, oc):
# return JaxArray(self._value / (oc._value if isinstance(oc, JaxArray) else oc))

def __rdiv__(self, oc):
return JaxArray(self._value / (oc._value if isinstance(oc, JaxArray) else oc))
return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) / self._value)

def __truediv__(self, oc):
return JaxArray(self._value / (oc._value if isinstance(oc, JaxArray) else oc))

def __rtruediv__(self, oc):
return JaxArray(self._value / (oc._value if isinstance(oc, JaxArray) else oc))
return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) / self._value)

def __itruediv__(self, oc):
# a /= b
Expand All @@ -281,7 +281,7 @@ def __floordiv__(self, oc):
return JaxArray(self._value // (oc._value if isinstance(oc, JaxArray) else oc))

def __rfloordiv__(self, oc):
return JaxArray(self._value // (oc._value if isinstance(oc, JaxArray) else oc))
return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) // self._value)

def __ifloordiv__(self, oc):
# a //= b
Expand All @@ -291,16 +291,16 @@ def __ifloordiv__(self, oc):
return self

def __divmod__(self, oc):
return JaxArray(self._value % (oc._value if isinstance(oc, JaxArray) else oc))
return JaxArray(self._value.__divmod__(oc._value if isinstance(oc, JaxArray) else oc))

def __rdivmod__(self, oc):
return JaxArray(self._value % (oc._value if isinstance(oc, JaxArray) else oc))
return JaxArray(self._value.__rdivmod__(oc._value if isinstance(oc, JaxArray) else oc))

def __mod__(self, oc):
return JaxArray(self._value % (oc._value if isinstance(oc, JaxArray) else oc))

def __rmod__(self, oc):
return JaxArray(self._value % (oc._value if isinstance(oc, JaxArray) else oc))
return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) % self._value)

def __imod__(self, oc):
# a %= b
Expand All @@ -313,7 +313,7 @@ def __pow__(self, oc):
return JaxArray(self._value ** (oc._value if isinstance(oc, JaxArray) else oc))

def __rpow__(self, oc):
return JaxArray(self._value ** (oc._value if isinstance(oc, JaxArray) else oc))
return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) ** self._value)

def __ipow__(self, oc):
# a **= b
Expand All @@ -326,7 +326,7 @@ def __matmul__(self, oc):
return JaxArray(self._value @ (oc._value if isinstance(oc, JaxArray) else oc))

def __rmatmul__(self, oc):
return JaxArray(self._value @ (oc._value if isinstance(oc, JaxArray) else oc))
return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) @ self._value)

def __imatmul__(self, oc):
# a @= b
Expand All @@ -339,7 +339,7 @@ def __and__(self, oc):
return JaxArray(self._value & (oc._value if isinstance(oc, JaxArray) else oc))

def __rand__(self, oc):
return JaxArray(self._value & (oc._value if isinstance(oc, JaxArray) else oc))
return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) & self._value)

def __iand__(self, oc):
# a &= b
Expand All @@ -352,7 +352,7 @@ def __or__(self, oc):
return JaxArray(self._value | (oc._value if isinstance(oc, JaxArray) else oc))

def __ror__(self, oc):
return JaxArray(self._value | (oc._value if isinstance(oc, JaxArray) else oc))
return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) | self._value)

def __ior__(self, oc):
# a |= b
Expand All @@ -365,7 +365,7 @@ def __xor__(self, oc):
return JaxArray(self._value ^ (oc._value if isinstance(oc, JaxArray) else oc))

def __rxor__(self, oc):
return JaxArray(self._value ^ (oc._value if isinstance(oc, JaxArray) else oc))
return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) ^ self._value)

def __ixor__(self, oc):
# a ^= b
Expand All @@ -378,7 +378,7 @@ def __lshift__(self, oc):
return JaxArray(self._value << (oc._value if isinstance(oc, JaxArray) else oc))

def __rlshift__(self, oc):
return JaxArray(self._value << (oc._value if isinstance(oc, JaxArray) else oc))
return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) << self._value)

def __ilshift__(self, oc):
# a <<= b
Expand All @@ -391,7 +391,7 @@ def __rshift__(self, oc):
return JaxArray(self._value >> (oc._value if isinstance(oc, JaxArray) else oc))

def __rrshift__(self, oc):
return JaxArray(self._value >> (oc._value if isinstance(oc, JaxArray) else oc))
return JaxArray((oc._value if isinstance(oc, JaxArray) else oc) >> self._value)

def __irshift__(self, oc):
# a >>= b
Expand Down
Loading