11import copy
2- from typing import List , Tuple , Union , Optional , Any
2+ from typing import List , Tuple , Dict , Union , Optional , Any
33from ..dtypes import Numbers
44
55
@@ -57,9 +57,13 @@ def __mul__(self, other: Union[Numbers, 'ndarray']) -> 'ndarray':
5757
5858 def __matmul__ (self , other : 'ndarray' ) -> 'ndarray' :
5959 if len (self .shape ) < 1 :
60- raise ValueError (f'matmul: Input operand 0 does not have enough dimensions (has 0, gufunc core with signature (n?,k),(k,m?)->(n?,m?) requires 1)' )
60+ raise ValueError (
61+ f'matmul: Input operand 0 does not have enough dimensions (has 0, gufunc core with signature (n?,k),(k,m?)->(n?,m?) requires 1)'
62+ )
6163 if len (other .shape ) < 1 :
62- raise ValueError (f'matmul: Input operand 1 does not have enough dimensions (has 0, gufunc core with signature (n?,k),(k,m?)->(n?,m?) requires 1)' )
64+ raise ValueError (
65+ f'matmul: Input operand 1 does not have enough dimensions (has 0, gufunc core with signature (n?,k),(k,m?)->(n?,m?) requires 1)'
66+ )
6367
6468 if len (self .shape ) != 1 and len (self .shape ) != 2 :
6569 raise ValueError (f'matmul: Input operand 0 is neither a vector nor a matrix and not supported' )
@@ -81,7 +85,9 @@ def __matmul__(self, other: 'ndarray') -> 'ndarray':
8185 squeeze_count += 1
8286
8387 if a .shape [1 ] != b .shape [0 ]:
84- raise ValueError (f'matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size { b .shape [0 ]} is different from { a .shape [1 ]} )' )
88+ raise ValueError (
89+ f'matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size { b .shape [0 ]} is different from { a .shape [1 ]} )'
90+ )
8591
8692 n_row = a .shape [0 ]
8793 n_col = b .shape [1 ]
@@ -299,10 +305,118 @@ def broadcast(a, shape: Union[List[int], Tuple[int]]) -> 'ndarray':
299305def einsum (subscripts : str , * operands : List [ndarray ]) -> ndarray :
300306 subscripts = subscripts .replace (' ' , '' )
301307
308+ from_indices , to_index = subscripts .split ('->' )
309+ if len (from_indices .split (',' )) != len (operands ):
310+ raise ValueError ('more operands provided to einstein sum function than specified in the subscripts string' )
311+
312+ index_list = [[idx for idx in index ] for index in from_indices .split (',' )]
313+ to_index = [idx for idx in to_index ]
314+
315+ for i , (op , index ) in enumerate (zip (operands , index_list )):
316+ if len (op .shape ) > len (index ):
317+ raise ValueError ('operand has more dimensions than subscripts given in einstein sum' )
318+
319+ if len (op .shape ) < len (index ):
320+ raise ValueError (f'einstein sum subscripts string contains too many subscripts for operand { i } ' )
321+
302322 if len (operands ) != 2 :
303323 raise ValueError (f'operands whose length != 2 are currently not supported' )
304324
305325 a , b = operands
306- from_ , to_ = subscripts .split ('->' )
326+ index_a , index_b = index_list
327+
328+ # index char -> loc, e.g. {'i': 0, 'j': 1, 'k': 2, 'l': 3} for 'ijkl'
329+ i2l_a = {index : index_a .index (index ) for index in index_a }
330+ i2l_b = {index : index_b .index (index ) for index in index_b }
331+
332+ # determin output tensor's shape
333+
334+ out_shape = []
335+ # index char -> dim, e.g. {'i': 3, 'j': 4}
336+ i2d = {}
337+ for idx in to_index :
338+ if idx in i2l_a :
339+ dim = a .shape [i2l_a [idx ]]
340+ out_shape .append (dim )
341+ i2d [idx ] = dim
342+ continue
343+ if idx in i2l_b :
344+ dim = b .shape [i2l_b [idx ]]
345+ out_shape .append (dim )
346+ i2d [idx ] = dim
347+ continue
348+ raise ValueError (f"einstein sum subscripts string included output subscript '{ idx } ' which never appeared in an input" )
349+
350+ # Preprocess finished. Main process begins
351+
352+ placeholder = zeros (out_shape ).data
353+
354+ def fill_placeholder (target : ndarray , index : List [str ], index_kv : Optional [Dict [str , int ]] = None ):
355+ if index_kv is None :
356+ index_kv = {}
357+
358+ idx , index = index [0 ], index [1 :] # index chars
359+
360+ for i in range (i2d [idx ]):
361+ index_kv_ = index_kv .copy ()
362+ index_kv_ [idx ] = i
363+ if isinstance (target [i ], list ):
364+ fill_placeholder (target [i ], index , index_kv_ )
365+ continue
366+
367+ target [i ] = calc_value (a , b , index_a , index_b , index_kv_ )
368+
369+ # e.g. 'ijkl,jmln->ikm': sum_j sum_l sum_n A_{ijkl} B_{jmln}
370+ def calc_value (a_1 : ndarray , a_2 : ndarray , index_1 : Tuple [str , ...], index_2 : Tuple [str , ...], index_kv : Dict [str , int ]):
371+ combinations_kv = []
372+ calc_combinations (list (a_1 .shape ), list (a_2 .shape ), index_1 , index_2 , index_kv , combinations_kv )
373+
374+ v = 0
375+ for idx_kv in combinations_kv :
376+ v_1 = get_value (a_1 .data , index_1 , idx_kv )
377+ v_2 = get_value (a_2 .data , index_2 , idx_kv )
378+ v += v_1 * v_2
379+
380+ return v
381+
382+ def calc_combinations (
383+ shape_1 : List [int ], shape_2 : List [int ], index_1 : List [str ], index_2 : List [str ], index_kv : Dict [str , int ], out_combs : List [Dict [str , int ]]
384+ ):
385+ if index_1 :
386+ idx1 , index_1 = index_1 [0 ], index_1 [1 :]
387+ dim1 , shape_1 = shape_1 [0 ], shape_1 [1 :]
388+ if idx1 in index_kv :
389+ calc_combinations (shape_1 , shape_2 , index_1 , index_2 , index_kv , out_combs )
390+ return
391+ else :
392+ for i in range (dim1 ):
393+ index_kv_ = index_kv .copy ()
394+ index_kv_ [idx1 ] = i
395+ calc_combinations (shape_1 , shape_2 , index_1 , index_2 , index_kv_ , out_combs )
396+ return
397+
398+ if index_2 :
399+ idx2 , index_2 = index_2 [0 ], index_2 [1 :]
400+ dim2 , shape_2 = shape_2 [0 ], shape_2 [1 :]
401+ if idx2 in index_kv :
402+ calc_combinations (shape_1 , shape_2 , index_1 , index_2 , index_kv , out_combs )
403+ return
404+ else :
405+ for i in range (dim2 ):
406+ index_kv_ = index_kv .copy ()
407+ index_kv_ [idx2 ] = i
408+ calc_combinations (shape_1 , shape_2 , index_1 , index_2 , index_kv_ , out_combs )
409+ return
410+
411+ out_combs .append (index_kv )
412+
413+ def get_value (target : List [Numbers ], index : List [Numbers ], index_kv : Dict [str , int ]):
414+ if isinstance (target , list ):
415+ idx , index = index [0 ], index [1 :]
416+ target = target [index_kv [idx ]]
417+ return get_value (target , index , index_kv )
418+ return target
419+
420+ fill_placeholder (placeholder , to_index )
307421
308- raise NotImplementedError ( 'not implemented yet' )
422+ return ndarray ( placeholder )
0 commit comments