In [58]:
import numpy as np
import operator as op


# Short aliases for the comparison predicates passed to `select` / `kqv`.
equals = op.eq
leq = op.le
lt = op.lt
geq = op.ge
gt = op.gt


def full(x, const):
    """Return an int array shaped like `x` with every entry set to `const`."""
    template = x
    return np.full_like(template, fill_value=const, dtype=int)
    
def indices(x):
    """Return the positions 0, 1, ..., len(x) - 1 as an int array."""
    length = len(x)
    return np.arange(0, length, 1, dtype=int)
    
def tok_map(x, func):
    """Apply `func` to every token of `x` independently; results are cast to int."""
    mapped = []
    for token in x:
        mapped.append(func(token))
    return np.array(mapped).astype(int)
    
def seq_map(x, y, func):
    """Combine `x` and `y` elementwise with the binary `func`; results are cast to int.

    Like `zip`, iteration stops at the shorter of the two sequences.
    """
    return np.array(list(map(func, x, y))).astype(int)
    
def select(k, q, pred, causal=True):
    """Build a boolean selection ("attention") matrix.

    Entry [i, j] is pred(k[j], q[i]): query position i selects key position j.
    With causal=True only keys at j <= i are evaluated; all other entries stay False.
    The matrix is len(k) x len(k), and `q` is indexed with the same positions.
    """
    size = len(k)
    selection = np.zeros((size, size), dtype=bool)
    for q_pos in range(size):
        last_key = q_pos if causal else size - 1
        for k_pos in range(last_key + 1):
            selection[q_pos, k_pos] = pred(k[k_pos], q[q_pos])
    return selection

def sel_width(A):
    """Count how many key positions each query row of the selection matrix `A` selects."""
    ones = np.ones(len(A))
    counts = A @ ones
    return counts.astype(int)


def aggr_sum(A, v, default=0):
    """Sum the values of `v` selected by each row of the selection matrix `A`.

    out[i] = sum(v[j] for j where A[i, j]).

    `default` is accepted for signature compatibility with the other aggregators
    but is unused: a row that selects nothing simply sums to 0.
    """
    # Promote the boolean selection matrix to int before the product. np.dot on two
    # bool arrays computes a logical OR of ANDs instead of a count, so a bool `v`
    # would otherwise collapse every sum to True/False.
    return np.dot(np.asarray(A, dtype=int), v)


def aggr_mean(A, v, default=0):
    """Average the values of `v` selected by each row of `A`, converted to int.

    out[i] = mean(v[j] for j where A[i, j]), or `default` for rows that select nothing.
    The float mean is converted with `astype(int)`, which truncates toward zero.
    """
    # Promote the bool selection matrix to int: np.dot(bool, bool) is a logical OR,
    # not a sum, which would break the numerator for bool-valued `v`.
    totals = np.dot(np.asarray(A, dtype=int), v)
    norm = sel_width(A)
    # Divide only where something was selected; other rows keep `default`.
    out = np.divide(totals, norm, out=np.full_like(v, default, dtype=float), where=(norm != 0))
    return out.astype(int)


def aggr_max(A, v, default=0):
    """Take the max of the values of `v` selected by each row of `A`.

    out[i] = max(v[j] for j where A[i, j]), or `default` for rows that select nothing.
    The result is cast to int.
    """
    # Accept lists as well as arrays (fancy indexing below needs an ndarray).
    v = np.asarray(v)
    if v.dtype == bool:
        # A bool output buffer would coerce `default` (e.g. -1) to True; work in ints.
        v = v.astype(int)
    out = np.full_like(v, default)
    for i, row in enumerate(A):
        idxs = np.flatnonzero(row)
        if len(idxs) > 0:
            out[i] = np.max(v[idxs])  # max of the selected elements of v
    return out.astype(int)


def aggr(A, v, default=0, reduction='mean'):
    match reduction:
        case 'mean':
            return aggr_mean(A, v, default)
        case 'max':
            return aggr_max(A, v, default)
        case 'sum':
            return aggr_sum(A, v, default)
        case 'min':
            return -1 * aggr_max(A, -v, -default)
        case _:
            raise NotImplementedError(f'Reduction "{reduction}" not implemented.')


def kqv(k, q, v, pred, default=0, reduction='mean'):
    """One attention-style step: select causally with `pred`, then aggregate `v`.

    Equivalent to aggr(select(k, q, pred), v, default=default, reduction=reduction).
    """
    selection = select(k, q, pred)
    return aggr(selection, v, default=default, reduction=reduction)


In [26]:
# Fixed seed so the random demos in this notebook are reproducible on Restart & Run All.
SEED = 0
rng = np.random.default_rng(SEED)

seq = rng.integers(1, 50, 10)

# NOTE(review): `shift_right` is defined in a later cell, so this call raises
# NameError on a fresh kernel. TODO: move this demo below the `shift_right` cell.
shift_right(seq, 2)

array([ 0,  0, 11, 49, 28,  4, 26, 11, 40, 45])

In [None]:
def shift_right(x, n, default=0):
    """Shift sequence `x` right by `n` positions, filling vacated slots with `default`.

    out[i] = x[i - n] when i - n >= 0, else `default`. Query position i selects the
    key whose (index + n) equals i. Because selection is causal, a negative `n`
    selects nothing and every slot gets `default`.
    """
    positions = indices(x)
    return kqv(positions + n, positions, x, equals, default=default)

## Cumulative Sum

In [88]:
def cumsum(bool_array):
    """Running count of truthy entries: out[i] = number of truthy values in bool_array[:i+1].

    The count includes position i itself.
    """
    # Every query selects the (causal) key positions whose value is truthy, so the
    # selection width of row i is the running count up to and including i.
    selection = select(bool_array, bool_array, lambda k, q: bool(k))
    return sel_width(selection)




In [95]:
def where(condition, x_if, y_else):
    """Elementwise choice: x_if[i] where condition[i] is truthy, else y_else[i].

    Mirrors np.where(condition, x_if, y_else) using only `seq_map`; results are ints.
    -1 marks "not taken" in the intermediate sequences. Inputs that contain -1 are
    still handled correctly, because the fallback branch then also yields -1.
    """
    taken_x = seq_map(x_if, condition, lambda val, cond: val if cond else -1)
    taken_y = seq_map(y_else, condition, lambda val, cond: -1 if cond else val)
    return seq_map(taken_x, taken_y, lambda a, b: b if a == -1 else a)


In [96]:
x = np.arange(10)
y = np.arange(20, 30)
cond = rng.integers(0, 2, 10).astype(bool)
print(x)
print(y)
print(cond)
# Check the RASP-style `where` against NumPy's reference implementation.
assert np.array_equal(where(cond, x, y), np.where(cond, x, y))
where(cond, x, y)

[0 1 2 3 4 5 6 7 8 9]
[20 21 22 23 24 25 26 27 28 29]
[ True False  True False  True False  True False False  True]


array([ 0, 21,  2, 23,  4, 25,  6, 27, 28,  9])

In [97]:
def mask(x, bool_mask, mask_val=0):
    """Keep x[i] where bool_mask[i] is truthy; replace it with `mask_val` elsewhere.

    Equivalent to x * bool_mask + mask_val * (~bool_mask).
    """
    fill = full(x, mask_val)
    return where(bool_mask, x, fill)


In [98]:
masked = mask(np.arange(5), np.array([True, True, True, False, False]))
assert masked.tolist() == [0, 1, 2, 0, 0]
masked

array([0, 1, 2, 0, 0])

In [124]:
def maximum(x):
    """Running (prefix) maximum: out[i] = max(x[:i+1])."""
    # Every query selects every key; `select` applies the causal mask (j <= i).
    def select_all(key, query):
        return True

    return kqv(x, x, x, select_all, reduction='max')

x = rng.integers(0, 10, 5)
print(x)
# The running max must agree with NumPy's cumulative maximum.
assert np.array_equal(maximum(x), np.maximum.accumulate(x))
maximum(x)



[3 8 5 9 5]


array([3, 8, 8, 9, 9])

In [125]:
def minimum(x):
    """Running (prefix) minimum: out[i] = min(x[:i+1])."""
    # min(a) == -max(-a), so reuse the running max on the negated sequence.
    negated_running_max = maximum(-x)
    return -negated_running_max

x = rng.integers(0, 10, 5)
print(x)
# The running min must agree with NumPy's cumulative minimum.
assert np.array_equal(minimum(x), np.minimum.accumulate(x))
minimum(x)


[2 8 9 0 1]


array([2, 2, 2, 0, 0])

In [143]:
def argmax(x):
    """Running argmax: out[i] is an index j <= i where x[j] == max(x[:i+1]).

    Ties resolve to the LATEST such index (unlike np.argmax, which returns the
    first), because the matching indices are reduced with 'max'.
    """
    running_max = maximum(x)
    return kqv(x, running_max, indices(x), equals, reduction='max')

x = rng.integers(0, 10, 5)
print(x)
# Reference: latest index of the prefix maximum (this argmax breaks ties toward later positions).
expected = [int(np.flatnonzero(x[:i + 1] == x[:i + 1].max()).max()) for i in range(len(x))]
assert argmax(x).tolist() == expected
argmax(x)

[2 3 1 5 2]


array([0, 1, 1, 3, 3])

In [148]:
# Scratch comparison: np.roll wraps elements around the end instead of filling with a default.
np.roll(np.arange(10), shift=3)

array([7, 8, 9, 0, 1, 2, 3, 4, 5, 6])

In [144]:
def argmin(x):
    """Running argmin: out[i] is an index j <= i where x[j] == min(x[:i+1]).

    Implemented as argmax of the negated sequence, so ties resolve the same way
    argmax resolves them.
    """
    negated = -x
    return argmax(negated)

In [None]:

def num_prev(x, queries):
    """Count earlier-or-current occurrences of each query.

    out[i] = number of j <= i with x[j] == queries[i] (inclusive of position i).
    """
    return sel_width(select(x, queries, equals))

In [138]:

def has_seen(x, queries):
    """1 where queries[i] occurs somewhere in x[:i+1], else 0.

    NOTE(review): semantics inferred from the name and from the causal convention of
    the neighbouring helpers. Confirm against the intended exercise spec.
    """
    return kqv(x, queries, full(x, 1), equals, default=0, reduction='max')


def firsts(x, queries, default=-1):
    """Index of the first occurrence of each query within the causal prefix of x.

    out[i] = np.flatnonzero(x[:i+1] == queries[i]).min(), or `default` when
    queries[i] does not occur in x[:i+1].
    """
    return kqv(x, queries, indices(x), equals, default=default, reduction='min')
    
def lasts(x, queries, default=-1):
    """Index of the last occurrence of each query within the causal prefix of x.

    out[i] = np.flatnonzero(x[:i+1] == queries[i]).max(), or `default` when
    queries[i] does not occur in x[:i+1].
    """
    return kqv(x, queries, indices(x), equals, default=default, reduction='max')


def index_select(x, idx, default=0):
    """Gather x through a causal index sequence.

    out[i] = x[idx[i]] if 0 <= idx[i] <= i, else `default`.
    """
    # Query i selects the single key position j with j == idx[i]. Causality rules
    # out j > i, and negative indices match no position.
    return kqv(indices(x), idx, x, equals, default=default, reduction='max')
    

def first_true(x, default=-1):
    """Running index of the first truthy value: out[i] = min{j <= i : x[j]}, or `default`.

    NOTE(review): per-position (causal) semantics assumed, matching the other helpers.
    Confirm against the intended exercise spec.
    """
    return kqv(x, x, indices(x), lambda k, q: bool(k), default=default, reduction='min')

def induct_kqv(k, q, v, offset, default=0, null_val=-999):
    """Induction-head style lookup.

    For each position i, find the first j with k[j] == q[i] and j + offset <= i, and
    return v[j + offset]. This excludes the last `offset` tokens of the visible
    prefix of k from matching. Return `default` when there is no match.

    null_val is a special token that cannot appear in k or q; it fills the slots
    vacated by the shift so they never match a query. It must also not be a valid
    position index (keep it negative). With a negative `offset` every shifted slot
    is null_val and nothing matches; see induct_prev for that case.
    """
    # Shift k right by `offset`: a match at shifted position p corresponds to
    # k[p - offset], so p is exactly the index to read from v.
    shifted_k = shift_right(k, offset, default=null_val)
    # First matching shifted position (causal, p <= i), or null_val when none matches.
    match_pos = kqv(shifted_k, q, indices(k), equals, default=null_val, reduction='min')
    # Read v at that position; a null_val position selects nothing, so it yields default.
    return kqv(indices(v), match_pos, v, equals, default=default, reduction='max')


def induct(k, q, offset, default=0, null_val=-999):
    """`induct_kqv` with the copied values taken from k itself."""
    return induct_kqv(k, q, k, offset, default=default, null_val=null_val)


def induct_prev(k, q, offset, default=0, null_val=-999):
    """A version of induct for negative offsets.

    For each position i, find the first j <= i with k[j] == q[i] and return
    k[j + offset]. With a negative offset, j + offset <= i always holds. Return
    `default` when there is no match or when j + offset < 0.
    """
    first_pos = kqv(k, q, indices(k), equals, default=null_val, reduction='min')
    # Assumes null_val is negative (as the -999 default is): an unmatched query then
    # stays out of range after adding the offset, selects nothing, and yields default.
    return kqv(indices(k), first_pos + offset, k, equals, default=default, reduction='max')