Skip to content

Commit 9ee786c

Browse files
authored
Merge pull request #2 from derwind/develop
Rough implementation of `einsum`
2 parents af43762 + b86de47 commit 9ee786c

File tree

3 files changed

+224
-8
lines changed

3 files changed

+224
-8
lines changed

.github/workflows/tests.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
name: Tests
22
on:
33
push:
4-
branches: [master, develop]
4+
branches: [main, develop]
55
pull_request:
6-
branches: [master, develop]
6+
branches: [main, develop]
77
concurrency:
88
group: ${{ github.repository }}-${{ github.ref }}-${{ github.head_ref }}-${{ github.workflow }}
99
cancel-in-progress: true

mynumpy/core/ndarray.py

Lines changed: 120 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import copy
2-
from typing import List, Tuple, Union, Optional, Any
2+
from typing import List, Tuple, Dict, Union, Optional, Any
33
from ..dtypes import Numbers
44

55

@@ -57,9 +57,13 @@ def __mul__(self, other: Union[Numbers, 'ndarray']) -> 'ndarray':
5757

5858
def __matmul__(self, other: 'ndarray') -> 'ndarray':
5959
if len(self.shape) < 1:
60-
raise ValueError(f'matmul: Input operand 0 does not have enough dimensions (has 0, gufunc core with signature (n?,k),(k,m?)->(n?,m?) requires 1)')
60+
raise ValueError(
61+
f'matmul: Input operand 0 does not have enough dimensions (has 0, gufunc core with signature (n?,k),(k,m?)->(n?,m?) requires 1)'
62+
)
6163
if len(other.shape) < 1:
62-
raise ValueError(f'matmul: Input operand 1 does not have enough dimensions (has 0, gufunc core with signature (n?,k),(k,m?)->(n?,m?) requires 1)')
64+
raise ValueError(
65+
f'matmul: Input operand 1 does not have enough dimensions (has 0, gufunc core with signature (n?,k),(k,m?)->(n?,m?) requires 1)'
66+
)
6367

6468
if len(self.shape) != 1 and len(self.shape) != 2:
6569
raise ValueError(f'matmul: Input operand 0 is neither a vector nor a matrix and not supported')
@@ -81,7 +85,9 @@ def __matmul__(self, other: 'ndarray') -> 'ndarray':
8185
squeeze_count += 1
8286

8387
if a.shape[1] != b.shape[0]:
84-
raise ValueError(f'matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size {b.shape[0]} is different from {a.shape[1]})')
88+
raise ValueError(
89+
f'matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size {b.shape[0]} is different from {a.shape[1]})'
90+
)
8591

8692
n_row = a.shape[0]
8793
n_col = b.shape[1]
@@ -299,10 +305,118 @@ def broadcast(a, shape: Union[List[int], Tuple[int]]) -> 'ndarray':
299305
def einsum(subscripts: str, *operands: List[ndarray]) -> ndarray:
300306
subscripts = subscripts.replace(' ', '')
301307

308+
from_indices, to_index = subscripts.split('->')
309+
if len(from_indices.split(',')) != len(operands):
310+
raise ValueError('more operands provided to einstein sum function than specified in the subscripts string')
311+
312+
index_list = [[idx for idx in index] for index in from_indices.split(',')]
313+
to_index = [idx for idx in to_index]
314+
315+
for i, (op, index) in enumerate(zip(operands, index_list)):
316+
if len(op.shape) > len(index):
317+
raise ValueError('operand has more dimensions than subscripts given in einstein sum')
318+
319+
if len(op.shape) < len(index):
320+
raise ValueError(f'einstein sum subscripts string contains too many subscripts for operand {i}')
321+
302322
if len(operands) != 2:
303323
raise ValueError(f'operands whose length != 2 are currently not supported')
304324

305325
a, b = operands
306-
from_, to_ = subscripts.split('->')
326+
index_a, index_b = index_list
327+
328+
# index char -> loc, e.g. {'i': 0, 'j': 1, 'k': 2, 'l': 3} for 'ijkl'
329+
i2l_a = {index: index_a.index(index) for index in index_a}
330+
i2l_b = {index: index_b.index(index) for index in index_b}
331+
332+
# determin output tensor's shape
333+
334+
out_shape = []
335+
# index char -> dim, e.g. {'i': 3, 'j': 4}
336+
i2d = {}
337+
for idx in to_index:
338+
if idx in i2l_a:
339+
dim = a.shape[i2l_a[idx]]
340+
out_shape.append(dim)
341+
i2d[idx] = dim
342+
continue
343+
if idx in i2l_b:
344+
dim = b.shape[i2l_b[idx]]
345+
out_shape.append(dim)
346+
i2d[idx] = dim
347+
continue
348+
raise ValueError(f"einstein sum subscripts string included output subscript '{idx}' which never appeared in an input")
349+
350+
# Preprocess finished. Main process begins
351+
352+
placeholder = zeros(out_shape).data
353+
354+
def fill_placeholder(target: ndarray, index: List[str], index_kv: Optional[Dict[str, int]] = None):
355+
if index_kv is None:
356+
index_kv = {}
357+
358+
idx, index = index[0], index[1:] # index chars
359+
360+
for i in range(i2d[idx]):
361+
index_kv_ = index_kv.copy()
362+
index_kv_[idx] = i
363+
if isinstance(target[i], list):
364+
fill_placeholder(target[i], index, index_kv_)
365+
continue
366+
367+
target[i] = calc_value(a, b, index_a, index_b, index_kv_)
368+
369+
# e.g. 'ijkl,jmln->ikm': sum_j sum_l sum_n A_{ijkl} B_{jmln}
370+
def calc_value(a_1: ndarray, a_2: ndarray, index_1: Tuple[str, ...], index_2: Tuple[str, ...], index_kv: Dict[str, int]):
371+
combinations_kv = []
372+
calc_combinations(list(a_1.shape), list(a_2.shape), index_1, index_2, index_kv, combinations_kv)
373+
374+
v = 0
375+
for idx_kv in combinations_kv:
376+
v_1 = get_value(a_1.data, index_1, idx_kv)
377+
v_2 = get_value(a_2.data, index_2, idx_kv)
378+
v += v_1 * v_2
379+
380+
return v
381+
382+
def calc_combinations(
383+
shape_1: List[int], shape_2: List[int], index_1: List[str], index_2: List[str], index_kv: Dict[str, int], out_combs: List[Dict[str, int]]
384+
):
385+
if index_1:
386+
idx1, index_1 = index_1[0], index_1[1:]
387+
dim1, shape_1 = shape_1[0], shape_1[1:]
388+
if idx1 in index_kv:
389+
calc_combinations(shape_1, shape_2, index_1, index_2, index_kv, out_combs)
390+
return
391+
else:
392+
for i in range(dim1):
393+
index_kv_ = index_kv.copy()
394+
index_kv_[idx1] = i
395+
calc_combinations(shape_1, shape_2, index_1, index_2, index_kv_, out_combs)
396+
return
397+
398+
if index_2:
399+
idx2, index_2 = index_2[0], index_2[1:]
400+
dim2, shape_2 = shape_2[0], shape_2[1:]
401+
if idx2 in index_kv:
402+
calc_combinations(shape_1, shape_2, index_1, index_2, index_kv, out_combs)
403+
return
404+
else:
405+
for i in range(dim2):
406+
index_kv_ = index_kv.copy()
407+
index_kv_[idx2] = i
408+
calc_combinations(shape_1, shape_2, index_1, index_2, index_kv_, out_combs)
409+
return
410+
411+
out_combs.append(index_kv)
412+
413+
def get_value(target: List[Numbers], index: List[Numbers], index_kv: Dict[str, int]):
414+
if isinstance(target, list):
415+
idx, index = index[0], index[1:]
416+
target = target[index_kv[idx]]
417+
return get_value(target, index, index_kv)
418+
return target
419+
420+
fill_placeholder(placeholder, to_index)
307421

308-
raise NotImplementedError('not implemented yet')
422+
return ndarray(placeholder)

test/ndarray_test.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1245,3 +1245,105 @@ def test_matmul(self):
12451245

12461246
with self.assertRaises(ValueError):
12471247
b @ a
1248+
1249+
def test_einsum(self):
1250+
a = mynp.array([1, 2])
1251+
1252+
b = mynp.array([
1253+
[1, 2],
1254+
[3, 4]
1255+
])
1256+
1257+
self.assertEqual(mynp.einsum('i,ij->j', a, b).data, [7, 10])
1258+
self.assertEqual(mynp.einsum('i,ji->j', a, b).data, [5, 11])
1259+
1260+
self.assertEqual(mynp.einsum('ij,i->j', b, a).data, [7, 10])
1261+
self.assertEqual(mynp.einsum('ij,j->i', b, a).data, [5, 11])
1262+
1263+
a = mynp.array([
1264+
[1, 2],
1265+
[3, 4]
1266+
])
1267+
1268+
b = mynp.array([
1269+
[-2, 1],
1270+
[-5, 3]
1271+
])
1272+
1273+
self.assertEqual(mynp.einsum('ij,jk->ik', a, b).data, [
1274+
[-12, 7],
1275+
[-26, 15]
1276+
])
1277+
1278+
self.assertEqual(mynp.einsum('jk,ki->ji', a, b).data, [
1279+
[-12, 7],
1280+
[-26, 15]
1281+
])
1282+
1283+
a = mynp.array([
1284+
[1, 2],
1285+
[3, 4],
1286+
[5, 6]
1287+
])
1288+
1289+
b = mynp.array([
1290+
[7, 8, 9, 10],
1291+
[11, 12, 13, 14]
1292+
])
1293+
1294+
self.assertEqual(mynp.einsum('ij,jk->ik', a, b).data, [
1295+
[ 29, 32, 35, 38],
1296+
[ 65, 72, 79, 86],
1297+
[101, 112, 123, 134]
1298+
])
1299+
1300+
a = mynp.array([
1301+
[
1302+
[1, 2],
1303+
[3, 4]
1304+
],
1305+
[
1306+
[5, 6],
1307+
[7, 8]
1308+
],
1309+
])
1310+
1311+
b = mynp.array([
1312+
[
1313+
[-1, -5],
1314+
[-3, 2],
1315+
[1, 4],
1316+
],
1317+
[
1318+
[3, 6],
1319+
[-3, 2],
1320+
[-4, 1],
1321+
]
1322+
])
1323+
1324+
self.assertEqual(mynp.einsum('ijk,ilj->jl', a, b).data, [
1325+
[30, -42, -41],
1326+
[55, 44, 43]
1327+
])
1328+
1329+
with self.assertRaises(ValueError):
1330+
data = mynp.array([
1331+
[1, 2],
1332+
[3, 4]
1333+
])
1334+
data2 = mynp.array([
1335+
[-2, 1],
1336+
[-5, 3]
1337+
])
1338+
mynp.einsum('ijl,jk->ik', a, b)
1339+
1340+
with self.assertRaises(ValueError):
1341+
data = mynp.array([
1342+
[1, 2],
1343+
[3, 4]
1344+
])
1345+
data2 = mynp.array([
1346+
[-2, 1],
1347+
[-5, 3]
1348+
])
1349+
mynp.einsum('i,jk->ik', a, b)

0 commit comments

Comments
 (0)