In [7]:
import jax
import jax.numpy as jnp

In [8]:
a = jnp.array([0, 0, 0, 0.3, 0.7]).reshape((1, -1))
b = jnp.array([0.1, 0.9, 0, 0, 0]).reshape((-1, 1))
b@a

Array([[0.  , 0.  , 0.  , 0.03, 0.07],
       [0.  , 0.  , 0.  , 0.27, 0.63],
       [0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.  ]], dtype=float32)

In [9]:
a.shape

(1, 5)

In [10]:
# training data
num_trains = 10
num_tests = 200
input_dim = 1
output_dim = 1
x_train = jnp.linspace(-jnp.pi, jnp.pi, num_trains).reshape(num_trains, input_dim)
y_train = jnp.sin(x_train)
x_test = jnp.linspace(-jnp.pi, jnp.pi, num_tests).reshape(num_tests, input_dim)

In [150]:
feature_dim = 1
num_features = 1
num_bins = 4

In [151]:
std = 1 / jnp.sqrt(input_dim)
rng = jax.random.PRNGKey(42)
_rng, rng = jax.random.split(rng)
projection_matrix = std * jax.random.truncated_normal(
    key=_rng, lower=-2, upper=2, shape=(num_features, input_dim, feature_dim)
)

In [152]:
x = x_train[0 : 1]
y = y_train[0 : 1]

In [153]:
latent_vectors = jax.vmap(jnp.matmul, in_axes=[None, 0], out_axes=0)(x, projection_matrix)
latent_vectors = jax.nn.sigmoid(latent_vectors)
latent_vectors = latent_vectors * num_bins
latent_vectors, latent_vectors.shape

(Array([[[1.5923828]]], dtype=float32), (1, 1, 1))

In [154]:
# latent_vector = x@projection_matrix[:, :, 0]
# latent_vector = jax.nn.sigmoid(latent_vector)
# latent_vector = latent_vector * num_bins
# latent_vector

In [155]:
indices = jnp.floor(latent_vectors)
offsets = latent_vectors - indices
indices = jnp.concatenate([indices, indices+1], axis=1)  # 每个feature单独处理，第0维代表第几个feature，concate发生在feature内部，所以是第1维
offsets = jnp.concatenate([offsets, 1-offsets], axis=1)
offsets, offsets.shape, indices, indices.shape

(Array([[[0.5923828],
         [0.4076172]]], dtype=float32),
 (1, 2, 1),
 Array([[[1.],
         [2.]]], dtype=float32),
 (1, 2, 1))

In [156]:
all_values = jnp.zeros(shape=(2**feature_dim*num_features), dtype=jnp.float32)
all_indices = jnp.zeros_like(all_values, dtype=jnp.int32)  
for j, (indice, offset) in enumerate(zip(indices, offsets)):
    values = offset[:, 0:1]
    index = indice[:, 0:1]
    # multiplier = jnp.power(num_bins, jnp.arange(feature_dim-1, -1,-1))
    if feature_dim > 1:
        for i in range(1, feature_dim):
            values = values@offset[:, i:i+1].T
            values = values.reshape(-1, 1)
            index = index*(num_bins+1) + indice[:, i]
            index = index.reshape(-1, 1)
    all_values = all_values.at[j * 2**feature_dim:(j+1) * 2**feature_dim].add(values.flatten())
    all_indices = all_indices.at[j * 2**feature_dim:(j+1) * 2**feature_dim].add((index + j * num_bins**feature_dim).astype(jnp.int32).flatten())
    print(values, values.shape, index, index.shape)

[[0.5923828]
 [0.4076172]] (2, 1) [[1.]
 [2.]] (2, 1)


In [157]:
all_values, '\n', all_indices

(Array([0.5923828, 0.4076172], dtype=float32),
 '\n',
 Array([1, 2], dtype=int32))

In [141]:
0.1178593 +0.01706885+0.75563747+0.10943438

1.0

In [137]:
index + int(j * num_bins**feature_dim)

Array([[38.],
       [39.],
       [43.],
       [44.]], dtype=float32)

In [132]:
all_values.at[j * num_bins**feature_dim:(j+1) * num_bins**feature_dim]

_IndexUpdateRef(Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],      dtype=float32), slice(0, 16, None))

In [133]:
values.flatten()

Array([0.1178593 , 0.01706885, 0.75563747, 0.10943438], dtype=float32)

In [126]:
a = jnp.ones(3)
a

Array([1., 1., 1.], dtype=float32)

In [128]:
a.at[1].multiply(5).at[1].add(1)

Array([1., 6., 1.], dtype=float32)

In [124]:
a

Array([0., 0., 0.], dtype=float32)

In [67]:
indices[:, 0:1] * 5 + indices[:, 1]

Array([[20., 21.],
       [25., 26.]], dtype=float32)

In [16]:
projection_matrix.shape

(1, 2, 50)

In [21]:
(x@projection_matrix[:, :, 0]).shape

(1, 2)

In [17]:
h = jax.vmap(jnp.matmul, in_axes=[None, -1], out_axes=-1)(x, projection_matrix)
h = jax.nn.sigmoid(h)
h = h * num_bins
indices = jnp.floor(h).astype(jnp.int32)
offsets = h - indices
h.shape, indices.shape, offsets.shape  # [1, feature_dim, num_features]

((1, 2, 50), (1, 2, 50), (1, 2, 50))

In [18]:
indices = jnp.stack([indices, indices + 1], axis=-1)  # [1, bin_dim, n_feat, 2]
values = jnp.stack([1.0 - offsets, offsets], axis=-1)  # [1, bin_dim, n_feat, 2]

In [19]:
multiplier = jnp.power(num_bins + 1, jnp.arange(feature_dim - 1, -1, -1))
indices *= multiplier[:, None, None]
# shape = (-1, n_feat, ) + (2,) * bin_dim
shape_suffix = [tuple(*p) for p in np.split(np.eye(bin_dim, dtype=np.int32) + 1, bin_dim)]
indices = sum(jnp.reshape(indices[:, i], (-1, n_feat, *suffix)) for i, suffix in enumerate(shape_suffix))
values = math.prod(jnp.reshape(values[:, i], (-1, n_feat, *suffix)) for i, suffix in enumerate(shape_suffix))
# both indices and values has the shape (-1, n_feat, *(2,)*bin_dim) now.
indices += jnp.expand_dims(
    n_grids_per_lsh * jnp.arange(n_feat), axis=tuple(range(-bin_dim, 1, 1))
)  # expand 1 dim in the front and bin_dim in the back.
indices = jnp.reshape(indices, (-1, n_feat * 2**bin_dim))
values = jnp.reshape(values, (-1, n_feat * 2**bin_dim))

NameError: name 'np' is not defined

In [None]:
num_grids_per_feature = (num_bins + 1) ** feature_dim

Array([[ 0.64917064, -0.6725234 ]], dtype=float32)

In [None]:
proj_matrix = jnp.ones(shape=(2, 3, 4))
proj_matrix

Array([[[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]],

       [[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]]], dtype=float32)

In [None]:
jax.vmap(jnp.matmul, in_axes=[None, -1], out_axes=-1)(x, proj_matrix)

Array([[[-0.02335274, -0.02335274, -0.02335274, -0.02335274],
        [-0.02335274, -0.02335274, -0.02335274, -0.02335274],
        [-0.02335274, -0.02335274, -0.02335274, -0.02335274]]],      dtype=float32)