In [1]:
import numpy as np

In [2]:
# One row vector 0..4 (shape (1, 5)) and its elementwise negation, used as operands below.
a = np.arange(5)[np.newaxis, :]
b = np.negative(a)
a, b


(array([[0, 1, 2, 3, 4]]), array([[ 0, -1, -2, -3, -4]]))

In [3]:
# Elementwise maximum of a and b: stack along a new last axis, then reduce over that axis.
np.max(np.stack((a, b), axis=-1), axis=-1)

array([[0, 1, 2, 3, 4]])

In [4]:
class LogSemiring(np.ndarray):
    """ndarray whose ``+`` and ``*`` follow the log semiring.

    * ``x + y`` is ``log(exp(x) + exp(y))``, computed elementwise.
    * ``x * y`` is ordinary elementwise addition of the log values.

    Only these operators (plus their reflected and in-place forms) are
    overridden; every other ndarray operation keeps plain numpy semantics.
    """

    def __new__(cls, x):
        # View the input as this class; no copy when ``x`` is already an array.
        obj = np.asarray(x).view(cls)
        return obj

    def __add__(self, x):
        """Semiring addition: elementwise, numerically stable log-sum-exp.

        ``np.logaddexp`` broadcasts, so scalars and shape-compatible operands
        work. It also maps ``-inf + -inf`` (the additive identity) to ``-inf``.
        A hand-written ``max + log(sum(exp(v - max)))`` yields NaN there,
        because ``-inf - (-inf)`` is NaN.
        """
        # Strip the subclass so plain ndarray ufunc machinery is used.
        return self.__class__(np.logaddexp(np.asarray(self), np.asarray(x)))

    def __mul__(self, x):
        """Semiring multiplication: ordinary elementwise sum of log values."""
        return self.__class__(np.add(np.asarray(self), np.asarray(x)))

    def __radd__(self, x):
        # Semiring addition is commutative.
        return self.__add__(x)

    def __rmul__(self, x):
        # Semiring multiplication is commutative.
        return self.__mul__(x)

    def __iadd__(self, x):
        # Without this, ndarray.__iadd__ would do ordinary in-place addition.
        # Returns a new array; the name on the left of ``+=`` is rebound to it.
        return self.__add__(x)

    def __imul__(self, x):
        # Without this, ndarray.__imul__ would do ordinary in-place multiplication.
        return self.__mul__(x)
    
class MaxPlusSemiring(np.ndarray):
    """ndarray whose ``+`` and ``*`` follow the (max, +) semiring.

    * ``x + y`` is the elementwise maximum.
    * ``x * y`` is ordinary elementwise addition.

    Only these operators (plus their reflected and in-place forms) are
    overridden; every other ndarray operation keeps plain numpy semantics.
    """

    def __new__(cls, x):
        # View the input as this class; no copy when ``x`` is already an array.
        obj = np.asarray(x).view(cls)
        return obj

    def __add__(self, x):
        """Semiring addition: elementwise maximum.

        ``np.maximum`` broadcasts, so scalars and shape-compatible operands
        work. It keeps integer dtypes and propagates NaN.
        """
        return self.__class__(np.maximum(np.asarray(self), np.asarray(x)))

    def __mul__(self, x):
        """Semiring multiplication: ordinary elementwise addition."""
        return self.__class__(np.add(np.asarray(self), np.asarray(x)))

    def __radd__(self, x):
        # Semiring addition is commutative.
        return self.__add__(x)

    def __rmul__(self, x):
        # Semiring multiplication is commutative.
        return self.__mul__(x)

    def __iadd__(self, x):
        # Without this, ndarray.__iadd__ would do ordinary in-place addition.
        # Returns a new array; the name on the left of ``+=`` is rebound to it.
        return self.__add__(x)

    def __imul__(self, x):
        # Without this, ndarray.__imul__ would do ordinary in-place multiplication.
        return self.__mul__(x)
    
class MinPlusSemiring(np.ndarray):
    """ndarray whose ``+`` and ``*`` follow the (min, +) semiring.

    * ``x + y`` is the elementwise minimum.
    * ``x * y`` is ordinary elementwise addition.

    Only these operators (plus their reflected and in-place forms) are
    overridden; every other ndarray operation keeps plain numpy semantics.
    """

    def __new__(cls, x):
        # View the input as this class; no copy when ``x`` is already an array.
        obj = np.asarray(x).view(cls)
        return obj

    def __add__(self, x):
        """Semiring addition: elementwise minimum.

        ``np.minimum`` broadcasts, so scalars and shape-compatible operands
        work. It keeps integer dtypes and propagates NaN.
        """
        return self.__class__(np.minimum(np.asarray(self), np.asarray(x)))

    def __mul__(self, x):
        """Semiring multiplication: ordinary elementwise addition."""
        return self.__class__(np.add(np.asarray(self), np.asarray(x)))

    def __radd__(self, x):
        # Semiring addition is commutative.
        return self.__add__(x)

    def __rmul__(self, x):
        # Semiring multiplication is commutative.
        return self.__mul__(x)

    def __iadd__(self, x):
        # Without this, ndarray.__iadd__ would do ordinary in-place addition.
        # Returns a new array; the name on the left of ``+=`` is rebound to it.
        return self.__add__(x)

    def __imul__(self, x):
        # Without this, ndarray.__imul__ would do ordinary in-place multiplication.
        return self.__mul__(x)

In [5]:
# Wrap both operands in the log semiring, then show semiring +, mixed +, and *.
al, bl = map(LogSemiring, (a, b))
al + al, al + bl, al * bl

(LogSemiring([[0.69314718, 1.69314718, 2.69314718, 3.69314718, 4.69314718]]),
 LogSemiring([[0.69314718, 1.12692801, 2.01814993, 3.00247569, 4.00033541]]),
 LogSemiring([[0, 0, 0, 0, 0]]))

In [6]:
# Same three operations in the (max, +) semiring.
al, bl = map(MaxPlusSemiring, (a, b))
al + al, al + bl, al * bl

(MaxPlusSemiring([[0, 1, 2, 3, 4]]),
 MaxPlusSemiring([[0, 1, 2, 3, 4]]),
 MaxPlusSemiring([[0, 0, 0, 0, 0]]))

In [7]:
# Same three operations in the (min, +) semiring.
al, bl = map(MinPlusSemiring, (a, b))
al + al, al + bl, al * bl

(MinPlusSemiring([[0, 1, 2, 3, 4]]),
 MinPlusSemiring([[ 0, -1, -2, -3, -4]]),
 MinPlusSemiring([[0, 0, 0, 0, 0]]))