In [1]:
import numpy as np
from functools import reduce
from typing import TypeVar, Callable

T = TypeVar('T') # generic type for output of F(M)


def R(x: T, row: np.ndarray) -> T:
  """Placeholder reducer: fold one ``row`` into the accumulator ``x``.

  Later cells redefine ``R`` (together with ``initial_value``) for concrete
  reductions; ``F`` looks both names up when it is called. Raising here makes
  a call to ``F`` before a concrete reducer exists fail loudly instead of
  silently returning ``None``.
  """
  raise NotImplementedError("define a concrete reducer R(x, row) before calling F")


# Placeholder seed for the reduction; ``...`` marks it as "not chosen yet".
initial_value: T = ...

def F(M: np.ndarray) -> T:
  """Fold the rows of ``M`` with the current global reducer ``R``.

  Starting from the global ``initial_value``, every row of ``M`` (iterated
  along its first axis) is combined into the accumulator via ``R(acc, row)``.
  Both globals are looked up when ``F`` runs, so redefining them in a later
  cell changes what ``F`` computes. An ``M`` with no rows returns
  ``initial_value`` unchanged.
  """
  acc = initial_value
  for row in M:
    acc = R(acc, row)
  return acc

In [2]:
# Example factors: A is 6x2 and B is 2x6, so A @ B is a 6x6 integer matrix.
A = np.array([[1, 2], [1, 3], [2, 3], [0, -2], [3, 4], [2, 1]])

B = np.array([[3, 0, 2, 1, 2, 3],
              [2, 1, 2, -6, 4, 2]])

A @ B

array([[  7,   2,   6, -11,  10,   7],
       [  9,   3,   8, -17,  14,   9],
       [ 12,   3,  10, -16,  16,  12],
       [ -4,  -2,  -4,  12,  -8,  -4],
       [ 17,   4,  14, -21,  22,  17],
       [  8,   1,   6,  -4,   8,   8]])

In [3]:
# Reduction 1: the largest entry of the whole matrix.
# -inf is the identity for max, so the first real entry always replaces it.
initial_value: float = -np.inf


def R(x: float, row: np.ndarray) -> float:
  """Return the larger of the running maximum ``x`` and the largest entry of ``row``."""
  row_max = row.max()
  # Same comparison the builtin max(x, row_max) performs: replace only on strict >.
  return row_max if row_max > x else x

In [4]:
F(A @ B)

22

In [5]:
# Number of columns of A @ B. A matrix product has as many columns as B, so
# read the width from B instead of materializing the whole (potentially huge)
# product just to inspect its shape.
assert A.shape[1] == B.shape[0], "A and B are not conformable for matmul"
m = B.shape[1]

In [6]:
# Reduction 2: the maximum of each column, kept as a (1, m) float row vector.
initial_value: np.ndarray = np.full(shape=(1, m), fill_value=-np.inf)


def R(x: np.ndarray, row: np.ndarray) -> np.ndarray:
  """Combine the running column maxima ``x`` with ``row`` element-wise.

  ``np.maximum`` broadcasts ``row`` against the accumulator; assuming ``row``
  has length m, the result keeps the (1, m) shape of ``initial_value``.
  """
  updated = np.maximum(x, row)
  return updated

In [7]:
F(A @ B)

array([[17.,  4., 14., 12., 22., 17.]])

In [8]:
def F_batched(A: np.ndarray,
              B: np.ndarray,
              batch_size: int,
              R: Callable[[T, np.ndarray], T],
              initial_value: T) -> T:
  """Reduce the rows of ``A @ B`` without materializing the whole product.

  ``A`` is split along axis 0 into consecutive row blocks of at most
  ``batch_size`` rows. Each block is multiplied by ``B``, and the resulting
  product rows are folded into the accumulator with ``R(acc, row)``, starting
  from ``initial_value``. Rows are visited in the same order as in
  ``A @ B``, so this performs the same fold as
  ``functools.reduce(R, A @ B, initial_value)`` (up to floating-point rounding
  in the per-batch products).

  Args:
    A: Left factor; at most ``batch_size`` of its rows are multiplied at once.
    B: Right factor, shared by every batch.
    batch_size: Maximum number of rows of ``A`` per batch; must be >= 1.
    R: Reducer ``(accumulator, row) -> accumulator``.
    initial_value: Seed of the reduction; returned unchanged if ``A`` has no rows.

  Returns:
    The final accumulator.

  Raises:
    ValueError: If ``batch_size`` is less than 1.
  """
  if batch_size < 1:
    raise ValueError(f"batch_size must be >= 1, got {batch_size}")
  num_rows = A.shape[0]
  # Round UP. Truncating (int(num_rows / batch_size)) produced too few batches,
  # so a batch could exceed batch_size rows (6 rows with batch_size=4 became
  # a single 6-row batch). At least one section is required because
  # np.array_split rejects 0 sections (e.g. when A has no rows).
  num_batches = max(-(-num_rows // batch_size), 1)
  x = initial_value
  for A_submatrix in np.array_split(A, num_batches):
    x = reduce(R, np.matmul(A_submatrix, B), x)
  return x

In [9]:
F_batched(A, B, batch_size=2, R=R, initial_value=initial_value)

array([[17.,  4., 14., 12., 22., 17.]])