In [23]:
import jax
import jax.numpy as jnp
import math

In [2]:
a = jnp.array([0, 0, 0, 0.3, 0.7]).reshape((1, -1))
b = jnp.array([0.1, 0.9, 0, 0, 0]).reshape((-1, 1))
b@a

Array([[0.  , 0.  , 0.  , 0.03, 0.07],
       [0.  , 0.  , 0.  , 0.27, 0.63],
       [0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.  ]], dtype=float32)

In [3]:
a.shape

(1, 5)

In [4]:
# training data
num_trains = 10
num_tests = 200
input_dim = 1
output_dim = 1
x_train = jnp.linspace(-jnp.pi, jnp.pi, num_trains).reshape(num_trains, input_dim)
y_train = jnp.sin(x_train)
x_test = jnp.linspace(-jnp.pi, jnp.pi, num_tests).reshape(num_tests, input_dim)

In [5]:
feature_dim = 2
num_features = 50
num_bins = 5

In [6]:
std = 1 / jnp.sqrt(input_dim)
rng = jax.random.PRNGKey(42)
_rng, rng = jax.random.split(rng)
projection_matrix = std * jax.random.truncated_normal(
    key=_rng, lower=-2, upper=2, shape=(num_features, input_dim, feature_dim)
)

In [7]:
x = x_train[0 : 1]
y = y_train[0 : 1]

In [8]:
latent_vectors = jax.vmap(jnp.matmul, in_axes=[None, 0], out_axes=0)(x, projection_matrix)
latent_vectors = jax.nn.sigmoid(latent_vectors)
latent_vectors = latent_vectors * num_bins
latent_vectors, latent_vectors.shape

(Array([[[4.882359  , 0.7142538 ]],
 
        [[1.0420057 , 4.8252716 ]],
 
        [[0.31945843, 2.4596522 ]],
 
        [[4.091795  , 3.9456549 ]],
 
        [[1.0098107 , 1.9127394 ]],
 
        [[4.9750767 , 0.31888887]],
 
        [[0.7162783 , 0.42991114]],
 
        [[3.960459  , 0.09502675]],
 
        [[2.24647   , 4.1822186 ]],
 
        [[0.05617506, 0.25240266]],
 
        [[0.4121259 , 1.90544   ]],
 
        [[4.7805457 , 0.7636429 ]],
 
        [[2.0518212 , 4.7193713 ]],
 
        [[0.04069029, 4.8028646 ]],
 
        [[1.0596509 , 1.2458144 ]],
 
        [[4.6338477 , 1.3567331 ]],
 
        [[4.866855  , 3.9298487 ]],
 
        [[1.1417775 , 0.3799395 ]],
 
        [[0.17134286, 3.4881709 ]],
 
        [[0.23484844, 4.8579545 ]],
 
        [[4.867742  , 1.9413234 ]],
 
        [[3.222763  , 4.3263783 ]],
 
        [[4.0213447 , 0.04547948]],
 
        [[4.0328035 , 1.5068208 ]],
 
        [[3.961658  , 4.818966  ]],
 
        [[0.05374113, 0.21159226]],
 
        [[4.

In [154]:
# latent_vector = x@projection_matrix[:, :, 0]
# latent_vector = jax.nn.sigmoid(latent_vector)
# latent_vector = latent_vector * num_bins
# latent_vector

In [9]:
indices = jnp.floor(latent_vectors)
offsets = latent_vectors - indices
indices = jnp.concatenate([indices, indices+1], axis=1)  # 每个feature单独处理，第0维代表第几个feature，concate发生在feature内部，所以是第1维
offsets = jnp.concatenate([offsets, 1-offsets], axis=1)
offsets, offsets.shape, indices, indices.shape

(Array([[[0.882359  , 0.7142538 ],
         [0.11764097, 0.28574622]],
 
        [[0.04200566, 0.8252716 ],
         [0.95799434, 0.1747284 ]],
 
        [[0.31945843, 0.4596522 ],
         [0.6805416 , 0.5403478 ]],
 
        [[0.09179497, 0.94565487],
         [0.90820503, 0.05434513]],
 
        [[0.00981069, 0.9127394 ],
         [0.9901893 , 0.0872606 ]],
 
        [[0.9750767 , 0.31888887],
         [0.02492332, 0.6811111 ]],
 
        [[0.7162783 , 0.42991114],
         [0.2837217 , 0.57008886]],
 
        [[0.960459  , 0.09502675],
         [0.03954101, 0.90497327]],
 
        [[0.24646997, 0.18221855],
         [0.75353   , 0.81778145]],
 
        [[0.05617506, 0.25240266],
         [0.94382495, 0.74759734]],
 
        [[0.4121259 , 0.90544   ],
         [0.5878741 , 0.09456003]],
 
        [[0.7805457 , 0.7636429 ],
         [0.21945429, 0.2363571 ]],
 
        [[0.05182123, 0.7193713 ],
         [0.94817877, 0.28062868]],
 
        [[0.04069029, 0.80286455],
         [0.9593

In [10]:
all_values = jnp.zeros(shape=(2**feature_dim*num_features), dtype=jnp.float32)
all_indices = jnp.zeros_like(all_values, dtype=jnp.int32)  
for j, (indice, offset) in enumerate(zip(indices, offsets)):
    values = offset[:, 0:1]
    index = indice[:, 0:1]
    # multiplier = jnp.power(num_bins, jnp.arange(feature_dim-1, -1,-1))
    if feature_dim > 1:
        for i in range(1, feature_dim):
            values = values@offset[:, i:i+1].T
            values = values.reshape(-1, 1)
            index = index*(num_bins+1) + indice[:, i]
            index = index.reshape(-1, 1)
    all_values = all_values.at[j * 2**feature_dim:(j+1) * 2**feature_dim].add(values.flatten())
    all_indices = all_indices.at[j * 2**feature_dim:(j+1) * 2**feature_dim].add((index + j * num_bins**feature_dim).astype(jnp.int32).flatten())
    print(values, values.shape, index, index.shape)

[[0.6302283 ]
 [0.25213075]
 [0.08402551]
 [0.03361546]] (4, 1) [[24.]
 [25.]
 [30.]
 [31.]] (4, 1)
[[0.03466608]
 [0.00733958]
 [0.79060555]
 [0.16738881]] (4, 1) [[10.]
 [11.]
 [16.]
 [17.]] (4, 1)
[[0.14683977]
 [0.17261866]
 [0.31281242]
 [0.36772916]] (4, 1) [[2.]
 [3.]
 [8.]
 [9.]] (4, 1)
[[0.08680636]
 [0.00498861]
 [0.8588485 ]
 [0.04935652]] (4, 1) [[27.]
 [28.]
 [33.]
 [34.]] (4, 1)
[[8.9545995e-03]
 [8.5608638e-04]
 [9.0378481e-01]
 [8.6404517e-02]] (4, 1) [[ 7.]
 [ 8.]
 [13.]
 [14.]] (4, 1)
[[0.3109411 ]
 [0.6641355 ]
 [0.00794777]
 [0.01697555]] (4, 1) [[24.]
 [25.]
 [30.]
 [31.]] (4, 1)
[[0.307936  ]
 [0.4083423 ]
 [0.12197511]
 [0.16174658]] (4, 1) [[0.]
 [1.]
 [6.]
 [7.]] (4, 1)
[[0.09126929]
 [0.86918974]
 [0.00375745]
 [0.03578356]] (4, 1) [[18.]
 [19.]
 [24.]
 [25.]] (4, 1)
[[0.0449114 ]
 [0.20155858]
 [0.13730715]
 [0.61622286]] (4, 1) [[16.]
 [17.]
 [22.]
 [23.]] (4, 1)
[[0.01417874]
 [0.04199633]
 [0.23822393]
 [0.70560104]] (4, 1) [[0.]
 [1.]
 [6.]
 [7.]] (4, 1)


In [None]:
def _to_1d_index(indices, offsets, n_feat, bin_dim, n_bins):
    """Compute the flattened index into the weight matrix."""
    n_grids_per_lsh = (n_bins + 1) ** bin_dim
    indices = jnp.reshape(indices, (-1, bin_dim, n_feat))
    offsets = jnp.reshape(offsets, (-1, bin_dim, n_feat))
    indices = jnp.stack([indices, indices + 1], axis=-1)  # [-1, bin_dim, n_feat, 2]
    values = jnp.stack([1.0 - offsets, offsets], axis=-1)  # [-1, bin_dim, n_feat, 2]
    multiplier = jnp.power(n_bins + 1, jnp.arange(bin_dim - 1, -1, -1))
    indices *= multiplier[:, None, None]
    # shape = (-1, n_feat, ) + (2,) * bin_dim
    shape_suffix = [tuple(*p) for p in np.split(np.eye(bin_dim, dtype=np.int32) + 1, bin_dim)]
    indices = sum(jnp.reshape(indices[:, i], (-1, n_feat, *suffix)) for i, suffix in enumerate(shape_suffix))
    values = math.prod(jnp.reshape(values[:, i], (-1, n_feat, *suffix)) for i, suffix in enumerate(shape_suffix))
    # both indices and values has the shape (-1, n_feat, *(2,)*bin_dim) now.
    indices += jnp.expand_dims(
        n_grids_per_lsh * jnp.arange(n_feat), axis=tuple(range(-bin_dim, 1, 1))
    )  # expand 1 dim in the front and bin_dim in the back.
    indices = jnp.reshape(indices, (-1, n_feat * 2**bin_dim))
    values = jnp.reshape(values, (-1, n_feat * 2**bin_dim))
    return indices, values

In [22]:
index = indice[:, 0:1]
index = index*(num_bins) + indice[:, i]
index

Array([[22., 23.],
       [27., 28.]], dtype=float32)

In [18]:
49*25

1225

In [14]:
5**2*50

1250

In [137]:
index + int(j * num_bins**feature_dim)

Array([[38.],
       [39.],
       [43.],
       [44.]], dtype=float32)

In [132]:
all_values.at[j * num_bins**feature_dim:(j+1) * num_bins**feature_dim]

_IndexUpdateRef(Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],      dtype=float32), slice(0, 16, None))

In [133]:
values.flatten()

Array([0.1178593 , 0.01706885, 0.75563747, 0.10943438], dtype=float32)

In [126]:
a = jnp.ones(3)
a

Array([1., 1., 1.], dtype=float32)

In [128]:
a.at[1].multiply(5).at[1].add(1)

Array([1., 6., 1.], dtype=float32)

In [124]:
a

Array([0., 0., 0.], dtype=float32)

In [67]:
indices[:, 0:1] * 5 + indices[:, 1]

Array([[20., 21.],
       [25., 26.]], dtype=float32)

In [16]:
projection_matrix.shape

(1, 2, 50)

In [21]:
(x@projection_matrix[:, :, 0]).shape

(1, 2)

In [17]:
h = jax.vmap(jnp.matmul, in_axes=[None, -1], out_axes=-1)(x, projection_matrix)
h = jax.nn.sigmoid(h)
h = h * num_bins
indices = jnp.floor(h).astype(jnp.int32)
offsets = h - indices
h.shape, indices.shape, offsets.shape  # [1, feature_dim, num_features]

((1, 2, 50), (1, 2, 50), (1, 2, 50))

In [18]:
indices = jnp.stack([indices, indices + 1], axis=-1)  # [1, bin_dim, n_feat, 2]
values = jnp.stack([1.0 - offsets, offsets], axis=-1)  # [1, bin_dim, n_feat, 2]

In [19]:
multiplier = jnp.power(num_bins + 1, jnp.arange(feature_dim - 1, -1, -1))
indices *= multiplier[:, None, None]
# shape = (-1, n_feat, ) + (2,) * bin_dim
shape_suffix = [tuple(*p) for p in np.split(np.eye(bin_dim, dtype=np.int32) + 1, bin_dim)]
indices = sum(jnp.reshape(indices[:, i], (-1, n_feat, *suffix)) for i, suffix in enumerate(shape_suffix))
values = math.prod(jnp.reshape(values[:, i], (-1, n_feat, *suffix)) for i, suffix in enumerate(shape_suffix))
# both indices and values has the shape (-1, n_feat, *(2,)*bin_dim) now.
indices += jnp.expand_dims(
    n_grids_per_lsh * jnp.arange(n_feat), axis=tuple(range(-bin_dim, 1, 1))
)  # expand 1 dim in the front and bin_dim in the back.
indices = jnp.reshape(indices, (-1, n_feat * 2**bin_dim))
values = jnp.reshape(values, (-1, n_feat * 2**bin_dim))

NameError: name 'np' is not defined

In [None]:
num_grids_per_feature = (num_bins + 1) ** feature_dim

Array([[ 0.64917064, -0.6725234 ]], dtype=float32)

In [None]:
proj_matrix = jnp.ones(shape=(2, 3, 4))
proj_matrix

Array([[[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]],

       [[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]]], dtype=float32)

In [None]:
jax.vmap(jnp.matmul, in_axes=[None, -1], out_axes=-1)(x, proj_matrix)

Array([[[-0.02335274, -0.02335274, -0.02335274, -0.02335274],
        [-0.02335274, -0.02335274, -0.02335274, -0.02335274],
        [-0.02335274, -0.02335274, -0.02335274, -0.02335274]]],      dtype=float32)