In [None]:
# 3.3 sorting buckets

In [1]:
import trax
import numpy as np
from trax import layers as tl
from trax import fastmath




In [2]:
# Toy dimensions for exercising sort_buckets.
t_n_hashes = 2   # number of hash tables
t_n_buckets = 4  # buckets per hash table
t_seqlen = 8     # sequence length
t_n_seq = t_seqlen
t_n_q = 3        # width of each query row
n_v = 5          # width of each value row

# Bucket id (within a single hash table) of every sequence position.
_bucket_of_pos = np.arange(t_n_seq) % t_n_buckets

# Queries: row j is its bucket id repeated t_n_q times, as floats -> (t_n_seq, t_n_q).
t_q = np.outer(_bucket_of_pos, np.ones(t_n_q))
# Values: all ones -> (t_n_seq, n_v).
t_v = np.ones((t_n_seq, n_v))
# Flat bucket ids for every (hash, position) pair; hash table i is offset by
# i * t_n_buckets so ids from different tables never collide.
t_buckets = (
    t_n_buckets * np.arange(t_n_hashes)[:, None] + _bucket_of_pos[None, :]
).reshape(-1)

In [3]:
def sort_buckets(buckets, q, v, n_buckets, n_hashes, seqlen, verbose=False):
    """Gather rows of q and v in bucket-sorted order.

    Every (hash, position) slot gets the sort key
    ``seqlen * bucket + position``; since ``position < seqlen`` this orders
    slots by bucket id first and by sequence position within a bucket.

    Args:
      buckets: bucket id for each (hash, position) slot. It is combined
        elementwise with a vector of length ``n_hashes * seqlen``; in this
        notebook it is a 1-D array of exactly that length.
      q: queries; rows are gathered along axis 0.
      v: values; rows are gathered along axis 0.
      n_buckets: number of buckets in each hash table (not read by this
        function; kept for the caller-facing signature).
      n_hashes: the number of hash tables.
      seqlen: sequence length.
      verbose: when true, print the shape and contents of each intermediate.

    Returns:
      A tuple ``(sq, sv, sticker, undo_sort)``:
        sq, sv: rows of q and v taken at ``sticker % seqlen``.
        sticker: slot indices rearranged into sorted-key order.
        undo_sort: slot indices re-sorted using sticker as the key, i.e. the
          permutation that maps sorted order back to the original order.
    """

    def _show(label, arr):
        # Debug trace of one intermediate; silent unless verbose is set.
        if verbose:
            print(label, arr.shape, arr)

    if verbose:
        print("---sort_buckets--")

    # Step 1: one running index per (hash, position) slot.
    slot_ids = np.arange(n_hashes * seqlen)
    _show("ticker", slot_ids)

    # Step 2: composite key -- bucket is the major part, position the minor.
    sort_keys = seqlen * buckets + (slot_ids % seqlen)
    _show("buckets_and_t", sort_keys)

    # Step 3: sort the slot indices by the composite key.
    sorted_keys, sticker = fastmath.sort_key_val(sort_keys, slot_ids, dimension=-1)
    _show("sbuckets_and_t", sorted_keys)
    _show("sticker", sticker)

    # Step 4: sorting the identity by sticker yields the inverse permutation.
    _unused_keys, undo_sort = fastmath.sort_key_val(sticker, slot_ids, dimension=-1)
    _show("undo_sort", undo_sort)

    # Step 5: fold slot indices back onto sequence positions and gather rows.
    seq_positions = sticker % seqlen
    sq = np.take(q, seq_positions, axis=0)
    sv = np.take(v, seq_positions, axis=0)
    return sq, sv, sticker, undo_sort

In [4]:
# Run the bucket sort on the toy fixtures; keyword arguments make the
# mapping from fixture to parameter explicit.
t_sq, t_sv, t_sticker, t_undo_sort = sort_buckets(
    buckets=t_buckets,
    q=t_q,
    v=t_v,
    n_buckets=t_n_buckets,
    n_hashes=t_n_hashes,
    seqlen=t_seqlen,
)



In [7]:
# Display the fourth value returned by sort_buckets: the slot indices
# re-sorted by sticker, which map sorted order back to the original order.
t_undo_sort

DeviceArray([ 0,  4,  1,  5,  2,  6,  3,  7,  8, 12,  9, 13, 10, 14, 11,
             15], dtype=int32)

In [23]:
# Two small integer matrices with 16 rows each (3 and 5 columns), filled
# with consecutive integers in row-major order.
a = np.reshape(np.arange(16 * 3), (16, 3))
b = np.reshape(np.arange(16 * 5), (16, 5))
# Number of consecutive rows grouped into one chunk in the cells below.
chunksize = 2


In [25]:
# Display the full 16x5 matrix b for reference.
b

array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14],
       [15, 16, 17, 18, 19],
       [20, 21, 22, 23, 24],
       [25, 26, 27, 28, 29],
       [30, 31, 32, 33, 34],
       [35, 36, 37, 38, 39],
       [40, 41, 42, 43, 44],
       [45, 46, 47, 48, 49],
       [50, 51, 52, 53, 54],
       [55, 56, 57, 58, 59],
       [60, 61, 62, 63, 64],
       [65, 66, 67, 68, 69],
       [70, 71, 72, 73, 74],
       [75, 76, 77, 78, 79]])

In [28]:
# Group consecutive rows of `a` into chunks of `chunksize` rows each:
# (16, 3) -> (8, 2, 3) with chunksize = 2.
rsq = a.reshape(-1, chunksize, a.shape[-1])
# Transpose the last two axes of every chunk: (8, 2, 3) -> (8, 3, 2).
rsqt = rsq.swapaxes(-1, -2)
# Batched product of each chunk with its own transpose: (8, 2, 2).
dotlike = rsq @ rsqt

In [36]:
# Chunk `b` the same way `a` was chunked: (16, 5) -> (8, 2, 5).
vr = b.reshape(-1, chunksize, b.shape[-1])
# Batched matmul of the per-chunk (2, 2) products with the chunked values.
so = dotlike @ vr

In [38]:
# Check that the batch and chunk dimensions of both matmul operands line up.
print(f"{dotlike.shape} {vr.shape}")

(8, 2, 2) (8, 2, 5)
